import torch
import torch.nn.functional as F
from torch import nn

from transformers import Qwen2Tokenizer, Qwen2Config
from .modeling_qwen1_5 import Qwen2Model, Qwen2PromptModel

# feedback with qwen1_5
class LanguageQwenModel_MT(nn.Module):
    """Multi-task sequence classifier on top of a Qwen1.5 (Qwen2-architecture) encoder.

    Each sample's encoder hidden state at its last non-padding token is routed,
    by the sample's task name, to a task-specific linear head. The task named
    ``'stsb'`` is treated as regression (outputs clipped to [0, 5], MSE loss);
    every other task is treated as classification (cross-entropy loss).

    Args:
        config: Dict providing ``config['tasks']['pre_round']`` and
            ``config['tasks']['curr_round']`` (task-name sequences that are
            concatenated with ``+``), and
            ``config['task_info'][task]['label_num']`` (output size of the head
            for ``task``).
        text_encoder: Pretrained checkpoint name or path handed to
            ``from_pretrained``. Required despite the ``None`` default.
        use_prompt: If True, build a ``Qwen2PromptModel`` instead of a plain
            ``Qwen2Model``.
        only_prompt: With ``use_prompt``, call ``tune_only_prompt()`` on the encoder.
        fix_prompt: With ``use_prompt``, call ``fix_prompt()`` on the encoder.
        fix_prompt_pre_round: With ``use_prompt``, call ``fix_prompt_by_task``
            with the previous-round tasks (presumably freezes their prompts).
        prompt_config: Written into the encoder config as ``prompt_config``
            when ``use_prompt`` is set.
        fix_word_embeddings: If True, set ``requires_grad = False`` on every
            parameter of ``text_encoder.embed_tokens``.
        multitask_train_prompt: With ``use_prompt``, written into the encoder
            config when truthy.

    Raises:
        ValueError: If ``text_encoder`` is None.
    """

    def __init__(self,
                 config,
                 text_encoder=None,
                 use_prompt=False,
                 only_prompt=False,
                 fix_prompt=False,
                 fix_prompt_pre_round=False,
                 prompt_config=None,
                 fix_word_embeddings=False,
                 multitask_train_prompt=False
                 ):
        super().__init__()
        if text_encoder is None:
            # from_pretrained(None) would fail deep inside transformers with an
            # unhelpful message; fail early instead.
            raise ValueError('text_encoder must be a pretrained checkpoint name or path')

        self.tasks_pre_round = config['tasks']['pre_round']
        self.tasks_curr_round = config['tasks']['curr_round']
        self.task_info = config['task_info']
        # Heads are indexed in this order: previous-round tasks, then current-round tasks.
        self.total_task = self.tasks_pre_round + self.tasks_curr_round

        # Config & encoder initialisation.
        qwen1_5_config = Qwen2Config.from_pretrained(text_encoder)
        if not use_prompt:
            self.text_encoder = Qwen2Model.from_pretrained(text_encoder, config=qwen1_5_config)
        else:
            qwen1_5_config.update({'prompt_config': prompt_config})
            if multitask_train_prompt:
                qwen1_5_config.update({'multitask_train_prompt': multitask_train_prompt})
            self.text_encoder = Qwen2PromptModel.from_pretrained(
                text_encoder, config=qwen1_5_config, tasks=self.total_task)
            if only_prompt:
                self.text_encoder.tune_only_prompt()
            if fix_prompt:
                self.text_encoder.fix_prompt()
            if fix_prompt_pre_round:
                self.text_encoder.fix_prompt_by_task(self.tasks_pre_round)

        # One linear classification/regression head per task, used for routing.
        text_width = self.text_encoder.config.hidden_size
        self.heads = nn.ModuleList(
            nn.Linear(text_width, self.task_info[t]['label_num']) for t in self.total_task
        )
        self.task2head_map = {t: self.heads[i] for i, t in enumerate(self.total_task)}
        self.task2head_index = {t: i for i, t in enumerate(self.total_task)}

        # Optionally freeze the word embeddings.
        if fix_word_embeddings:
            for name, para in self.text_encoder.embed_tokens.named_parameters():
                print('[FIX PARA]', name)
                para.requires_grad = False

        # Report the number of trainable parameters (in millions).
        num_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print('[INFO] Finetune para num:', num_params / 1e6)

    def forward(self, text, tasks, targets=None, train=True):
        """Run the encoder and the per-task heads.

        Args:
            text: Tokenized batch exposing ``input_ids`` and ``attention_mask``
                (presumably a tokenizer ``BatchEncoding`` -- verify against caller).
            tasks: Sequence of task names, one per sample in the batch.
            targets: Per-sample labels aligned with ``tasks``. Required when
                ``train`` is True; optional for evaluation.
            train: If True, switch the module to train mode and return the loss;
                otherwise switch to eval mode and return predictions.

        Returns:
            With ``train=True``: the sum, over the tasks present in the batch, of
            each task's mean loss. With ``train=False``: a list of
            ``{"prediction", "target", "task"}`` dicts, one per sample, grouped
            by task in order of first appearance in ``tasks``. ``"target"`` is
            None when ``targets`` was not given.

        Raises:
            ValueError: If ``train`` is True and ``targets`` is None, or if the
                batch, ``tasks`` and ``targets`` lengths disagree.
        """
        if train:
            if targets is None:
                raise ValueError('targets are required when train=True')
            self.train()
            pooled = self._encode(text, tasks)
            return self._compute_loss(pooled, tasks, targets)

        self.eval()
        # Results are reduced to Python scalars, so no autograd graph is needed.
        with torch.no_grad():
            pooled = self._encode(text, tasks)
            return self._collect_predictions(pooled, tasks, targets)

    def _encode(self, text, tasks):
        """Run the encoder on ``text`` and return one pooled embedding per sample."""
        text_output = self.text_encoder(text.input_ids,
                                        attention_mask=text.attention_mask,
                                        return_dict=True,
                                        tasks=tasks
                                        )
        return self._pool_last_token(text.input_ids, text_output.last_hidden_state)

    def _pool_last_token(self, input_ids, hidden_states):
        """Select each row's hidden state at the token just before its first pad token.

        This mirrors the pooling in the Hugging Face Qwen sequence-classification
        code. A row with no pad token yields argmax 0, so ``-1 % seq_len`` wraps
        to the last position. A row whose first pad sits at index 0 (left
        padding) also wraps to the last position. If the encoder config has no
        ``pad_token_id``, the last position is used for every row.
        """
        batch_size = hidden_states.shape[0]
        seq_len = input_ids.shape[-1]
        pad_token_id = self.text_encoder.config.pad_token_id
        if pad_token_id is None:
            last_index = torch.full((batch_size,), seq_len - 1, dtype=torch.long)
        else:
            last_index = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
            last_index = last_index % seq_len
        last_index = last_index.to(hidden_states.device)
        batch_index = torch.arange(batch_size, device=hidden_states.device)
        return hidden_states[batch_index, last_index]

    def _compute_loss(self, pooled, tasks, targets):
        """Return the summed per-task mean loss for the routed batch."""
        text_embeds_batch = self.reform_batch(pooled, tasks)
        targets_batch = self.reform_batch(targets, tasks)

        loss = 0.0
        for key, embeds in text_embeds_batch.items():
            if len(embeds) == 0:
                continue
            prediction = self.task2head_map[key](embeds)
            target = targets_batch[key]
            if key == 'stsb':
                # NOTE(review): clip() passes zero gradient for predictions
                # outside [0, 5]; confirm this is intended during training.
                prediction = prediction.clip(0., 5.)
                target = target.to(device=prediction.device, dtype=torch.float)
                if len(target.shape) == 1:
                    target = target.view(-1, 1)
                loss += F.mse_loss(prediction, target)
            else:
                target = target.to(device=prediction.device, dtype=torch.long)
                loss += F.cross_entropy(prediction, target)
        return loss

    def _collect_predictions(self, pooled, tasks, targets):
        """Return per-sample prediction dicts (regression value or argmax class)."""
        text_embeds_batch = self.reform_batch(pooled, tasks)
        targets_batch = None if targets is None else self.reform_batch(targets, tasks)

        results = []
        for key, embeds in text_embeds_batch.items():
            if len(embeds) == 0:
                continue
            prediction = self.task2head_map[key](embeds)
            if key == 'stsb':
                pred_class = prediction.clip(0., 5.)
            else:
                _, pred_class = prediction.max(1)
            task_targets = None if targets_batch is None else targets_batch[key]
            for i in range(len(pred_class)):
                target_value = None if task_targets is None else task_targets[i].item()
                results.append({"prediction": pred_class[i].item(), "target": target_value, "task": key})
        return results

    # Route tensors according to task name.
    def reform_batch(self, input_tensor, tasks):
        """Group the rows of ``input_tensor`` by their task name.

        Args:
            input_tensor: Tensor (or sequence of tensors) with one entry per sample.
            tasks: Sequence of task names, one per entry of ``input_tensor``.

        Returns:
            Dict keyed by task name, in order of first appearance in ``tasks``,
            mapping to the stacked entries of that task.

        Raises:
            ValueError: If ``input_tensor`` and ``tasks`` lengths differ (zip
                would otherwise silently drop the extra entries).
        """
        if len(input_tensor) != len(tasks):
            raise ValueError(
                f'reform_batch: got {len(input_tensor)} entries for {len(tasks)} tasks')
        text_batch = {t: [] for t in tasks}
        for input_t, task in zip(input_tensor, tasks):
            text_batch[task].append(input_t)
        for key in text_batch:
            if len(text_batch[key]) != 0:
                text_batch[key] = torch.stack(text_batch[key])
        return text_batch

