import torch
from torch import nn
from transformers import PretrainedConfig
from src.models.bart.hypernet_bart import BartModel, BartConfig
from src.models.modeling_t5_lora import T5ForConditionalGeneration
from src.utils.utils import load_lora_dict_by_task

class HyperLoRAModelForPretrain(nn.Module):
    """Hypernetwork that generates LoRA weights for a downstream T5 model.

    ``self.hyperlora`` (a ``BartModel``) encodes one demonstration example and
    returns a nested dict of LoRA weight tensors. That dict is handed to
    ``self.down_plm`` (a LoRA-enabled ``T5ForConditionalGeneration``) through
    its ``generate_adapters`` argument. When ``model_args.finetune`` is falsy,
    the generated weights are also regressed (MSE) onto per-task reference
    LoRA weights obtained from ``load_lora_dict_by_task``.
    """

    def __init__(
            self,
            config,
            model_args=None,
            lora_args=None,
            encoder=None,
            pretrain_task_names=None,
        ):
        """Build the hypernetwork and the downstream LoRA-enabled T5 model.

        Args:
            config: config object (must provide ``to_dict()``). Its entries
                that are missing from the hypernetwork's ``BartConfig`` are
                copied onto that ``BartConfig``.
            model_args: must provide ``hyper_model_name_or_path``,
                ``model_name_or_path``, ``loss_beta`` and ``finetune``.
            lora_args: adapter config forwarded to the T5 model.
            encoder: unused; accepted for interface compatibility.
            pretrain_task_names: iterable of task names, where position ``i``
                is task id ``i``. ``None`` means no tasks.
        """
        super().__init__()
        self.config = config
        self.model_args = model_args
        self.lora_args = lora_args

        new_config = BartConfig.from_pretrained(model_args.hyper_model_name_or_path)
        self._inherit_missing_config(config, new_config)
        self.hyperlora = BartModel.from_pretrained(
            model_args.hyper_model_name_or_path,
            config=new_config,
            torch_dtype=torch.bfloat16,
        )

        # Weight of the L2 (weight-regression) term in the total loss.
        self.beta = model_args.loss_beta

        # define downstream plm
        self.down_plm = T5ForConditionalGeneration.from_pretrained(
            model_args.model_name_or_path,
            adapter_config=lora_args,
            torch_dtype=torch.bfloat16,
        )
        self.down_plm.apply_lora()
        # Expose the wrapped model's generation config on this wrapper; the
        # author notes this avoids an error (presumably from code that reads
        # ``model.generation_config``; confirm).
        self.generation_config = self.down_plm.generation_config

        # Materialize once so a one-shot iterator fills both maps, and so the
        # ``None`` default means "no tasks" instead of raising a TypeError.
        task_names = [] if pretrain_task_names is None else list(pretrain_task_names)
        self.pretrain_task_names = task_names
        self.id2task = dict(enumerate(task_names))
        # Inverting id2task in order keeps "last index wins" for duplicate names.
        self.task2id = {task: i for i, task in self.id2task.items()}

        self.freeze_plm()

    @staticmethod
    def _inherit_missing_config(source_config, target_config):
        """Copy entries of ``source_config.to_dict()`` that are absent from
        ``target_config.to_dict()`` onto ``target_config`` via ``setattr``.

        The target key set is re-read after every ``setattr``. That way each
        membership test sees exactly what a fresh ``to_dict()`` call would,
        even if one ``setattr`` adds more than one key.
        """
        target_keys = set(target_config.to_dict())
        for key, value in source_config.to_dict().items():
            if key not in target_keys:
                setattr(target_config, key, value)
                target_keys = set(target_config.to_dict())

    def _hyper_forward(self, demo_input_ids, demo_attention_mask):
        """Run the hypernetwork on the first demonstration only.

        Row ``0`` of the demo tensors is taken and a leading batch dimension
        of size 1 is re-added, so one set of LoRA weights serves the batch.
        """
        return self.hyperlora(
            input_ids=demo_input_ids[0].unsqueeze(0),
            attention_mask=demo_attention_mask[0].unsqueeze(0),
        )

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
        task_ids=None,
        demo_input_ids=None,
        demo_attention_mask=None
    ):
        """Generate LoRA weights from a demonstration and run the T5 model with them.

        Returns:
            dict with ``'lm_loss'``, ``'lora_l2_loss'`` (already multiplied
            by ``beta``), ``'loss'`` and ``'logits'``.
        """
        # One batch must contain a single task: both the demonstration and
        # the reference LoRA weights are chosen per batch, not per example.
        assert len(set(task_ids.tolist())) == 1

        # generate lora parameters
        generate_lora_weights = self._hyper_forward(demo_input_ids, demo_attention_mask)

        # apply to plm
        output = self.down_plm(input_ids=input_ids,
                               attention_mask=attention_mask,
                               labels=labels,
                               generate_adapters=generate_lora_weights)

        lm_loss = output.loss
        if not self.model_args.finetune:
            lora_l2_loss = self.compute_l2_loss(task_ids, generate_lora_weights)
            loss = lm_loss + self.beta * lora_l2_loss
        else:
            # NOTE(review): in finetune mode the reported 'lora_l2_loss' is
            # beta * lm_loss, which is a placeholder and not an L2 term.
            # Confirm this is intended for logging.
            lora_l2_loss = lm_loss
            loss = lm_loss

        return {'lm_loss': lm_loss, 'lora_l2_loss': self.beta * lora_l2_loss, 'loss': loss, 'logits': output.logits}

    def print_mean_params(self):
        """Print the SUM of per-tensor parameter means for both sub-models.

        This is a cheap fingerprint for checking whether weights changed. It
        is not the mean over all parameters.
        """
        hyper_mean = sum(param.data.mean().item() for param in self.hyperlora.parameters())
        down_mean = sum(param.data.mean().item() for param in self.down_plm.parameters())
        print('hyper_mean: {}, down_mean: {}'.format(hyper_mean, down_mean))

    def freeze_plm(self):
        """Set ``requires_grad = False`` on the parameters that must not train.

        * Pretraining (``model_args.finetune`` falsy): freeze hypernetwork
          parameters whose name contains ``'encoder'``, plus all of
          ``down_plm``. Other hypernetwork parameters keep their current
          ``requires_grad``.
        * Finetune mode: freeze every parameter of both sub-models, so
          nothing in this wrapper is trainable.
        """
        finetune = self.model_args.finetune
        for name, param in self.hyperlora.named_parameters():
            if finetune or 'encoder' in name:
                param.requires_grad = False

        for param in self.down_plm.parameters():
            param.requires_grad = False

    @staticmethod
    def _select_generated_layer(generate_lora_weights, path):
        """Follow ``path[:-1]`` into the nested weight dict and return
        ``tensor[0, path[-1], :, :]``, i.e. one layer's matrix for batch row 0.
        """
        *keys, layer_idx = path
        node = generate_lora_weights
        for key in keys:
            node = node[key]
        return node[0, layer_idx, :, :]

    def compute_l2_loss(self, task_ids, generate_lora_weights):
        """Sum of MSE losses between generated and reference LoRA weights.

        Args:
            task_ids: tensor of task ids. Only ``task_ids[0]`` selects the
                task; ``forward`` asserts the batch holds a single task.
            generate_lora_weights: hypernetwork output, indexed as
                ``{'encoder': {'lora_qa'|'lora_qb'|'lora_va'|'lora_vb': T},
                'decoder': {'decoder_attn'|'cross_attn': {<same keys>: T}}}``.
                Each ``T`` is sliced as ``T[0, layer, :, :]``.

        Returns:
            The summed loss as a tensor. Returns the int ``0`` if no reference
            module name matched.
        """
        task_name = self.id2task[task_ids[0].item()]
        golden_lora_dict = load_lora_dict_by_task(self.model_args, task_name)
        l2_loss, l2_loss_fct = 0, nn.MSELoss()
        device = task_ids.device

        # (module-name substring -> projection substring -> path into
        # generate_lora_weights). The last path element is a running layer
        # counter: the k-th matching reference module is compared with
        # generated layer k. This assumes golden_lora_dict iterates in layer
        # order; confirm against load_lora_dict_by_task.
        # Keys are tried in insertion order. 'EncDecAttention' precedes
        # 'decoder' because a cross-attention module name may also contain
        # 'decoder'.
        # Rebuilt on every call because the counters are mutated below.
        lora2dict = {
            'encoder': {'q.lora_A': ['encoder', 'lora_qa', 0],
                        'q.lora_B': ['encoder', 'lora_qb', 0],
                        'v.lora_A': ['encoder', 'lora_va', 0],
                        'v.lora_B': ['encoder', 'lora_vb', 0]},
            'EncDecAttention': {'q.lora_A': ['decoder', 'cross_attn', 'lora_qa', 0],
                                'q.lora_B': ['decoder', 'cross_attn', 'lora_qb', 0],
                                'v.lora_A': ['decoder', 'cross_attn', 'lora_va', 0],
                                'v.lora_B': ['decoder', 'cross_attn', 'lora_vb', 0]},
            'decoder': {'q.lora_A': ['decoder', 'decoder_attn', 'lora_qa', 0],
                        'q.lora_B': ['decoder', 'decoder_attn', 'lora_qb', 0],
                        'v.lora_A': ['decoder', 'decoder_attn', 'lora_va', 0],
                        'v.lora_B': ['decoder', 'decoder_attn', 'lora_vb', 0]},
        }
        for module_name in golden_lora_dict:
            for key, projections in lora2dict.items():
                if key not in module_name:
                    continue
                for kv, path in projections.items():
                    if kv in module_name:
                        generated = self._select_generated_layer(generate_lora_weights, path)
                        target = golden_lora_dict[module_name].to(device=device, dtype=torch.bfloat16)
                        l2_loss += l2_loss_fct(generated, target)
                        path[-1] += 1
                # Only the first module type whose substring matches is used.
                break

        return l2_loss

    @torch.no_grad()
    def generate(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
        task_ids=None,
        demo_input_ids=None,
        demo_attention_mask=None,
        **kwargs,
    ):
        """Decode with the T5 model using LoRA weights generated from the demo.

        ``labels`` is accepted but unused. Extra ``kwargs`` are forwarded to
        ``down_plm.generate``.
        """
        # assert one batch contains one task
        assert len(set(task_ids.tolist())) == 1

        generate_lora_weights = self._hyper_forward(demo_input_ids, demo_attention_mask)
        return self.down_plm.generate(input_ids=input_ids,
                                      attention_mask=attention_mask,
                                      generate_adapters=generate_lora_weights,
                                      **kwargs)

    @torch.no_grad()
    def generate_lora_weights(
        self,
        demo_input_ids=None,
        demo_attention_mask=None
    ):
        """Return the hypernetwork's LoRA weights for the first demonstration (no grad)."""
        return self._hyper_forward(demo_input_ids, demo_attention_mask)

class HyperLoRAModelForFinetune(nn.Module):
    """T5 model whose LoRA weights are installed once and then trained directly.

    ``generate_adapters`` is applied to the T5 model through
    ``apply_ft_lora`` at construction time (presumably the output of the
    pretraining hypernetwork; confirm with callers). After that, only
    parameters whose name contains ``'lora'`` remain trainable.
    """

    def __init__(
            self,
            config,
            model_args=None,
            lora_args=None,
            generate_adapters=None
        ):
        """Load the T5 model, install the given LoRA weights and freeze the rest.

        Args:
            config: stored as ``self.config``.
            model_args: must provide ``model_name_or_path``.
            lora_args: adapter config forwarded to the T5 model.
            generate_adapters: LoRA weights passed to ``apply_ft_lora``.
        """
        super().__init__()
        self.config = config
        self.model_args = model_args
        self.lora_args = lora_args

        plm = T5ForConditionalGeneration.from_pretrained(
            model_args.model_name_or_path,
            adapter_config=lora_args,
            torch_dtype=torch.bfloat16,
        )
        self.down_plm = plm
        plm.apply_ft_lora(generate_adapters)
        # Mirror the wrapped model's generation config on this wrapper (the
        # author notes this avoids an error).
        self.generation_config = plm.generation_config

        self.freeze_plm()

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
    ):
        """Run the T5 model and report its LM loss in every loss slot.

        There is no hypernetwork here, so ``'lora_l2_loss'`` and ``'loss'``
        both equal the LM loss. The keys match the pretraining wrapper's
        output dict.
        """
        outputs = self.down_plm(input_ids=input_ids,
                                attention_mask=attention_mask,
                                labels=labels)
        lm_loss = outputs.loss
        return {
            'lm_loss': lm_loss,
            'lora_l2_loss': lm_loss,
            'loss': lm_loss,
            'logits': outputs.logits,
        }

    def freeze_plm(self):
        """Make exactly the parameters whose name contains ``'lora'`` trainable."""
        for param_name, weight in self.down_plm.named_parameters():
            weight.requires_grad = 'lora' in param_name

    @torch.no_grad()
    def generate(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
        **kwargs,
    ):
        """Decode with the T5 model. ``labels`` is ignored; ``kwargs`` are forwarded."""
        return self.down_plm.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )



# config
class HyperLoRAConfig(PretrainedConfig):
    """``PretrainedConfig`` subclass carrying three extra HyperLoRA fields.

    Args:
        num_downplm_layers: stored as ``self.num_downplm_layers`` (default 48).
        plm_hidden_size: stored as ``self.plm_hidden_size`` (default 1024).
        lora_rank: stored as ``self.lora_rank`` (default 16).
        **kwargs: forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    def __init__(self, num_downplm_layers=48, plm_hidden_size=1024, lora_rank=16, **kwargs):
        # The custom fields are assigned (left to right) before the base
        # initializer runs; keep this ordering.
        self.num_downplm_layers, self.plm_hidden_size, self.lora_rank = (
            num_downplm_layers,
            plm_hidden_size,
            lora_rank,
        )
        super().__init__(**kwargs)

