# Parameter-Efficient Fine-Tuning (PEFT) methods
from peft import (
    LoraConfig,
    get_peft_model,
)
import peft
from peft.tuners.lora import *  # for PRILoRA patch
from peft.utils import _get_submodules
import os

# Log the installed peft version once, when this module is imported.
print("PRILoRA: peft version: {}".format(peft.__version__))

def _install_prilora_patches(model_args):
    """Monkey-patch peft's LoRA classes in place with the PRILoRA variants.

    The replacements are assigned on the peft classes themselves, so they
    affect every LoRA model created afterwards in this process:

    * ``LoraModel._find_and_replace``: each newly wrapped module gets a rank
      interpolated linearly between ``model_args.lin_start`` and
      ``model_args.lin_stop`` according to its encoder layer number.
    * ``LoraModel._create_new_module``: also forwards ``lora_config.path``
      to the new LoRA module.
    * ``LoraLayer.__init__`` / ``update_layer`` / ``reset_lora_parameters``:
      registers an ``ema`` buffer and builds ``lora_A`` / ``lora_B`` as
      bias-free ``nn.Linear`` modules.
    * ``Linear.__init__`` / ``forward`` / ``prune_weights``: keeps an
      exponential moving average of the input norms and, unless
      ``model_args.prilora_prune`` is ``False``, periodically zeroes the
      lowest-importance entries in each row of ``lora_A``.

    Args:
        model_args: run configuration. This function and the patches it
            installs read ``model_name_or_path``, ``lin_start``, ``lin_stop``
            and ``prilora_prune`` from it.
    """
    import math
    import warnings

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def _find_and_replace(self, adapter_name):
        # Wraps every target module with a LoRA module. New (non-LoraLayer)
        # targets get the PRILoRA per-layer rank.
        lora_config = self.peft_config[adapter_name]
        self._check_quantization_dependency()
        is_target_modules_in_base_model = False
        key_list = [key for key, _ in self.model.named_modules()]
        total_ranks = 0
        for key in key_list:
            if not self._check_target_module_exists(lora_config, key):
                continue

            is_target_modules_in_base_model = True
            parent, target, target_name = _get_submodules(self.model, key)

            if isinstance(target, LoraLayer) and isinstance(target, torch.nn.Conv2d):
                target.update_layer_conv2d(
                    adapter_name,
                    lora_config.r,
                    lora_config.lora_alpha,
                    lora_config.lora_dropout,
                    lora_config.init_lora_weights,
                )
            elif isinstance(target, LoraLayer) and isinstance(target, torch.nn.Embedding):
                target.update_layer_embedding(
                    adapter_name,
                    lora_config.r,
                    lora_config.lora_alpha,
                    lora_config.lora_dropout,
                    lora_config.init_lora_weights,
                )
            elif isinstance(target, LoraLayer):
                target.update_layer(
                    adapter_name,
                    lora_config.r,
                    lora_config.lora_alpha,
                    lora_config.lora_dropout,
                    lora_config.init_lora_weights,
                )
            else:
                # PRILoRA patch: the rank depends on the encoder layer number.
                if 'deberta' not in model_args.model_name_or_path:
                    raise Exception("wrong model name")

                # Assumes keys look like "deberta.encoder.layer.<i>.<...>" so
                # that dotted component 3 is the 0-based layer index. TODO confirm.
                layer_number = int(key.split(".")[3]) + 1

                # NOTE(review): hard-coded for a 12-layer DeBERTa. With a deeper
                # model the ratio exceeds 1 and ranks overshoot lin_stop.
                num_layers = 12
                start_rank = model_args.lin_start
                end_rank = model_args.lin_stop

                if not hasattr(self, 'print_1'):
                    print(f"doing gradual increase: {start_rank}->{end_rank}")
                    self.print_1 = True

                new_rank = int(
                    start_rank + (float(layer_number) / num_layers) * (end_rank - start_rank))

                # _create_new_module reads the rank from lora_config, so the
                # shared config is overwritten for every replaced module.
                lora_config.r = new_rank
                total_ranks += new_rank
                print(f'PRILoRA patch, layer {key} new rank {new_rank} total rank = {total_ranks}')

                lora_config.path = key
                lora_config.lora_model = self

                new_module = self._create_new_module(lora_config, adapter_name, target)
                new_module.key = key  # patch by PRILoRA
                self._replace_module(parent, target_name, new_module, target)

        if not is_target_modules_in_base_model:
            raise ValueError(
                f"Target modules {lora_config.target_modules} not found in the base model. "
                f"Please check the target modules and try again."
            )

    peft.LoraModel._find_and_replace = _find_and_replace

    def _create_new_module(self, lora_config, adapter_name, target):
        # Builds the LoRA replacement for `target`; also passes the module path
        # (set on lora_config by _find_and_replace) through as the "path" kwarg.
        bias = hasattr(target, "bias") and target.bias is not None
        kwargs = {
            "r": lora_config.r,
            "lora_alpha": lora_config.lora_alpha,
            "lora_dropout": lora_config.lora_dropout,
            "fan_in_fan_out": lora_config.fan_in_fan_out,
            "init_lora_weights": lora_config.init_lora_weights,
            "path": lora_config.path  # patch by PRILoRA
        }
        loaded_in_4bit = getattr(self.model, "is_loaded_in_4bit", False)
        loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)

        if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
            eightbit_kwargs = kwargs.copy()
            eightbit_kwargs.update(
                {
                    "has_fp16_weights": target.state.has_fp16_weights,
                    "memory_efficient_backward": target.state.memory_efficient_backward,
                    "threshold": target.state.threshold,
                    "index": target.index,
                }
            )
            new_module = Linear8bitLt(
                adapter_name, target.in_features, target.out_features, bias=bias, **eightbit_kwargs
            )
        elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit):
            fourbit_kwargs = kwargs.copy()
            fourbit_kwargs.update(
                {
                    "compute_dtype": target.compute_dtype,
                    "compress_statistics": target.weight.compress_statistics,
                    "quant_type": target.weight.quant_type,
                }
            )
            new_module = Linear4bit(adapter_name, target.in_features, target.out_features, bias=bias,
                                    **fourbit_kwargs)
        elif isinstance(target, torch.nn.Embedding):
            embedding_kwargs = kwargs.copy()
            embedding_kwargs.pop("fan_in_fan_out", None)
            in_features, out_features = target.num_embeddings, target.embedding_dim
            new_module = Embedding(adapter_name, in_features, out_features, **embedding_kwargs)
        elif isinstance(target, torch.nn.Conv2d):
            out_channels, in_channels = target.weight.size()[:2]
            kernel_size = target.weight.size()[2:]
            stride = target.stride
            padding = target.padding
            new_module = Conv2d(adapter_name, in_channels, out_channels, kernel_size, stride, padding, **kwargs)
        else:
            if isinstance(target, torch.nn.Linear):
                in_features, out_features = target.in_features, target.out_features
                if kwargs["fan_in_fan_out"]:
                    warnings.warn(
                        "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
                        "Setting fan_in_fan_out to False."
                    )
                    kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
            elif isinstance(target, Conv1D):
                in_features, out_features = (
                    target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
                )
                kwargs["is_target_conv_1d_layer"] = True
                if not kwargs["fan_in_fan_out"]:
                    warnings.warn(
                        "fan_in_fan_out is set to False but the target module is `Conv1D`. "
                        "Setting fan_in_fan_out to True."
                    )
                    kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
            else:
                raise ValueError(
                    f"Target module {target} is not supported. "
                    f"Currently, only `torch.nn.Linear` and `Conv1D` are supported."
                )
            new_module = Linear(adapter_name, in_features, out_features, bias=bias, **kwargs)

        return new_module

    peft.LoraModel._create_new_module = _create_new_module

    def new_init(self, in_features: int, out_features: int, **kwargs):
        # Replacement for LoraLayer.__init__. It adds the `ema` buffer used for
        # pruning, and keeps lora_A / lora_B as ModuleDicts of nn.Linear.
        self.r = {}
        self.lora_alpha = {}
        self.scaling = {}
        self.lora_dropout = nn.ModuleDict({})

        # Starts as None; Linear.forward assigns the first input-norm tensor.
        self.register_buffer('ema', None)

        self.lora_A = nn.ModuleDict({})
        self.lora_B = nn.ModuleDict({})

        # For Embedding layer
        self.lora_embedding_A = nn.ParameterDict({})
        self.lora_embedding_B = nn.ParameterDict({})
        # Mark the weight as unmerged
        self.merged = False
        self.disable_adapters = False
        self.in_features = in_features
        self.out_features = out_features
        self.kwargs = kwargs

    peft.tuners.lora.LoraLayer.__init__ = new_init

    def new_update_layer_function(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
        # Registers adapter `adapter_name` with rank r. lora_A maps
        # in_features -> r and lora_B maps r -> out_features (both bias-free).
        self.r[adapter_name] = r
        self.lora_alpha[adapter_name] = lora_alpha
        if lora_dropout > 0.0:
            lora_dropout_layer = nn.Dropout(p=lora_dropout)
        else:
            lora_dropout_layer = nn.Identity()

        self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
        # Actual trainable parameters
        if r > 0:
            self.lora_A.update(
                nn.ModuleDict({adapter_name: nn.Linear(self.in_features, r, bias=False)}))
            self.lora_B.update(nn.ModuleDict({adapter_name: nn.Linear(r, self.out_features, bias=False)}))
            self.scaling[adapter_name] = lora_alpha / r

        if init_lora_weights:
            self.reset_lora_parameters(adapter_name)

        self.to(self.weight.device)

    peft.tuners.lora.LoraLayer.update_layer = new_update_layer_function

    def new_reset_lora_parameters(self, adapter_name):
        if adapter_name in self.lora_A.keys():
            # A: the nn.Linear default (Kaiming uniform, a=sqrt(5)); B: zeros, so
            # the adapter's initial contribution is zero.
            nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B[adapter_name].weight)

        if adapter_name in self.lora_embedding_A.keys():
            # Embedding adapters: A zeros, B standard normal.
            nn.init.zeros_(self.lora_embedding_A[adapter_name])
            nn.init.normal_(self.lora_embedding_B[adapter_name])

    peft.tuners.lora.LoraLayer.reset_lora_parameters = new_reset_lora_parameters

    def linear__init__(
            self,
            adapter_name: str,
            in_features: int,
            out_features: int,
            r: int = 0,
            lora_alpha: int = 1,
            lora_dropout: float = 0.0,
            fan_in_fan_out: bool = False,
            # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
            is_target_conv_1d_layer: bool = False,
            **kwargs,
    ):
        # Replacement for peft's lora.Linear.__init__. It also stores the module
        # path, initialises the PRILoRA counters, and optionally installs the
        # pruning pre-hook.
        init_lora_weights = kwargs.pop("init_lora_weights", True)
        path = kwargs.pop("path", True)  # patch by PRILoRA
        self.path = path
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoraLayer.__init__(self, in_features=in_features, out_features=out_features)
        # Freezing the pre-trained weight matrix
        self.weight.requires_grad = False

        self.fan_in_fan_out = fan_in_fan_out
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

        nn.Linear.reset_parameters(self)
        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
        self.active_adapter = adapter_name
        self.is_target_conv_1d_layer = is_target_conv_1d_layer

        # PRILoRA state
        self.decay_rate = 0.9  # EMA decay for the input-norm average
        self.forward_counter = 0
        self.forward_counter_after_a = 0
        self.num_layer_pruned = 0

        # Only an explicit False disables pruning. Any other value is used as
        # the sparsity in prune_weights.
        if model_args.prilora_prune is not False:
            if not hasattr(self.__class__, 'prilora_prune'):
                print(
                    f"********** prilora_prune: {model_args.prilora_prune}")
                self.__class__.prilora_prune = True
            self.register_forward_pre_hook(self.prune_weights)

    peft.tuners.lora.Linear.__init__ = linear__init__

    def new_forward(self, x: torch.Tensor):
        # Replacement for lora.Linear.forward. It also updates self.ema with the
        # detached input norm on every adapter-active, unmerged call.
        previous_dtype = x.dtype
        if self.active_adapter not in self.lora_A.keys():
            return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
        if self.disable_adapters:
            if self.r[self.active_adapter] > 0 and self.merged:
                self.unmerge()
            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
        elif self.r[self.active_adapter] > 0 and not self.merged:
            # L2 norm over dims (0, 1). Assumes x is (batch, seq, in_features),
            # giving one value per input feature (TODO confirm). This runs in
            # eval mode too.
            x_norm = x.detach().norm(p=2, dim=(0, 1))
            if self.ema is None:
                self.ema = x_norm
                self.forward_counter = 1
            else:
                # Update the exponential moving average
                self.ema = self.decay_rate * self.ema + (1 - self.decay_rate) * x_norm
                self.forward_counter += 1

            # The regular LINEAR part
            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

            x_after_A = self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))

            result += (
                    self.lora_B[self.active_adapter](x_after_A)
                    * self.scaling[self.active_adapter]
            )
        else:
            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

        result = result.to(previous_dtype)

        return result

    peft.tuners.lora.Linear.forward = new_forward

    def do_prune(weight_matrix, avg_input_vector, dim, sparsity):
        """Zero low-importance entries of ``weight_matrix`` in place.

        The importance is ``|W| * avg_input_vector``. Along ``dim`` (1 = within
        each row, 0 = within each column), entries whose importance is below
        the value at sorted position ``int(size * sparsity)`` are set to zero.
        ``sparsity`` must be < 1, or that position is out of range.
        """
        with torch.no_grad():
            S = torch.abs(weight_matrix) * avg_input_vector

            sorted_weights, _ = torch.sort(S, dim=dim)

            if dim == 1:
                threshold_values = sorted_weights[:, int(weight_matrix.shape[dim] * sparsity)]
                mask = S < threshold_values[:, None]
            elif dim == 0:
                threshold_values = sorted_weights[int(weight_matrix.shape[dim] * sparsity), :]
                mask = S < threshold_values[None, :]

            weight_matrix[mask] = 0

    def prune_weights(self, module, input):
        # Forward pre-hook. After a warm-up of more than 20 forward calls, it
        # prunes lora_A on every call where forward_counter is a multiple of 40.
        if not (self.forward_counter > 20 and self.forward_counter % 40 == 0):
            return

        self.num_layer_pruned += 1

        if not hasattr(self.__class__, 'dbgmes2'):
            print(f"\n\n*** PRILoRA: {model_args.prilora_prune} *******")
            self.__class__.dbgmes2 = True

        do_prune(weight_matrix=self.lora_A[self.active_adapter].weight,
                 avg_input_vector=self.ema,
                 dim=1,  # 1 for select best in a row
                 sparsity=model_args.prilora_prune,
                 )

    peft.tuners.lora.Linear.prune_weights = prune_weights


def prilora_after_model_loaded(model_args, model, training_args, data_args=None):
    """Wrap a freshly loaded model with LoRA adapters, optionally as PRILoRA.

    When ``model_args.apply_lora`` is true, this function:

    1. checks the installed peft version;
    2. installs the PRILoRA monkey-patches if ``model_args.prilora_exp``
       contains ``'prilora'``;
    3. prints the run configuration and sets ``training_args.run_name``;
    4. freezes all parameters of ``model`` and returns ``get_peft_model(...)``.

    Args:
        model_args: run configuration (reads ``apply_lora``, ``prilora_exp``,
            ``model_name_or_path``, ``lora_alpha``, ``lora_dropout``, plus the
            fields used by the patches).
        model: the loaded base model.
        training_args: training configuration. ``run_name`` is overwritten.
        data_args: unused.

    Returns:
        The peft-wrapped model when ``apply_lora`` is true, otherwise ``model``
        unchanged.

    Raises:
        Exception: if the peft version is not ``0.5.0.dev0``, if the model is
            not a DeBERTa checkpoint, or if ``lora_alpha`` is not 8, 16 or 32.
    """
    if not model_args.apply_lora:
        # Previously this path fell through and returned None.
        return model

    if peft.__version__ != "0.5.0.dev0":
        raise Exception("peft version changed")

    # Guard the membership test: `'prilora' in None` would raise TypeError.
    if model_args.prilora_exp is not None:
        print(f"doing exp: {model_args.prilora_exp}")
        if 'prilora' in model_args.prilora_exp:  # LORA with pruning of A according to importance
            _install_prilora_patches(model_args)

    # Build the state dict once; calling state_dict() per key rebuilds it each time.
    print("Model's state_dict:")
    for param_name, param_value in model.state_dict().items():
        print(param_name, "\t", param_value.size())

    print(f"lr: {training_args.learning_rate}")
    print(f"batch size: {training_args.per_device_train_batch_size}")
    print(f"num_train_epochs: {training_args.num_train_epochs}")
    print(f"load_best_model_at_end: {training_args.load_best_model_at_end}")

    training_args.run_name = f"process {os.getpid()}"

    print(f"lora_dropout = {model_args.lora_dropout}")
    if 'deberta' not in model_args.model_name_or_path:
        raise Exception("wrong model")
    target_modules = ["query_proj",
                      "value_proj",
                      "key_proj",
                      "attention.output.dense",
                      "intermediate.dense",
                      "output.dense"
                      ]

    if hasattr(model_args, 'lora_alpha') and model_args.lora_alpha not in (8, 16, 32):
        raise Exception('Remove at your own risk')

    print(f"PRILoRA: target lora modules: {target_modules}")
    print(f"PRILoRA: Lora alpha: {model_args.lora_alpha}")

    task_type = "SEQ_CLS"

    # r=8 is the nominal rank. When the PRILoRA patches are installed,
    # _find_and_replace overwrites it for each wrapped module.
    lora_config = LoraConfig(
        r=8,
        lora_alpha=model_args.lora_alpha,
        target_modules=target_modules,
        lora_dropout=model_args.lora_dropout,
        bias="none",
        task_type=task_type
    )

    print("freeze all weights")
    for parameter in model.parameters():
        parameter.requires_grad = False

    model = get_peft_model(model, lora_config)

    model.print_trainable_parameters()

    return model
