import torch
import torch.nn.functional as F
from torch import nn

from transformers import T5Tokenizer, T5Config
from .modeling_t5 import T5ForConditionalGeneration, T5PromptForConditionalGeneration

# feedback with t5
class LanguageT5Model_MT(nn.Module):
    """Multi-task wrapper around a (optionally prompt-augmented) T5 seq2seq model.

    In training mode ``forward`` returns the seq2seq loss; in inference mode it
    greedily generates and returns decoded strings.

    Args:
        config: Mapping providing ``config['tasks']['pre_round']`` and
            ``config['tasks']['curr_round']``. Required despite the ``None``
            default. The two entries are concatenated with ``+`` (presumably
            lists of task names -- confirm against callers).
        text_encoder: Identifier passed to ``T5Config.from_pretrained`` and the
            model's ``from_pretrained`` (a pretrained model name or path).
        tokenizer: Tokenizer whose ``batch_decode`` turns generated ids into text.
        use_prompt: Build ``T5PromptForConditionalGeneration`` instead of the
            plain ``T5ForConditionalGeneration``.
        only_prompt: (prompt model only) call ``tune_only_prompt()``.
        fix_prompt: (prompt model only) call ``fix_prompt()``.
        fix_prompt_pre_round: (prompt model only) call ``fix_prompt_by_task``
            with the previous-round tasks.
        prompt_config: Stored on the T5 config as ``prompt_config`` when
            ``use_prompt`` is set.
        fix_word_embeddings: Set ``requires_grad=False`` on the parameters of
            ``shared``, ``encoder.embed_tokens`` and ``decoder.embed_tokens``.
        multitask_train_prompt: When truthy (and ``use_prompt``), stored on the
            T5 config as ``multitask_train_prompt``.
    """

    def __init__(self,
                 config=None,
                 text_encoder=None,
                 tokenizer=None,
                 use_prompt=False,
                 only_prompt=False,
                 fix_prompt=False,
                 fix_prompt_pre_round=False,
                 prompt_config=None,
                 fix_word_embeddings=False,
                 multitask_train_prompt=False
                 ):
        super().__init__()

        self.tasks_pre_round = config['tasks']['pre_round']
        self.tasks_curr_round = config['tasks']['curr_round']

        self.tokenizer = tokenizer
        # config & model init
        t5_config = T5Config.from_pretrained(text_encoder)
        if not use_prompt:
            self.text_encoder = T5ForConditionalGeneration.from_pretrained(text_encoder, config=t5_config)
        else:
            t5_config.update({'prompt_config': prompt_config})
            if multitask_train_prompt:
                t5_config.update({'multitask_train_prompt': multitask_train_prompt})
            # The prompt model is told about every task (previous + current round).
            self.text_encoder = T5PromptForConditionalGeneration.from_pretrained(
                text_encoder, config=t5_config, tasks=self.tasks_pre_round + self.tasks_curr_round)
            if only_prompt:
                self.text_encoder.tune_only_prompt()
            if fix_prompt:
                self.text_encoder.fix_prompt()
            if fix_prompt_pre_round:
                self.text_encoder.fix_prompt_by_task(self.tasks_pre_round)

        # Freeze the token-embedding tables (same modules, same order, same
        # log lines as freezing them one by one).
        if fix_word_embeddings:
            embedding_modules = (
                self.text_encoder.shared,
                self.text_encoder.encoder.embed_tokens,
                self.text_encoder.decoder.embed_tokens,
            )
            for module in embedding_modules:
                for name, para in module.named_parameters():
                    print('[FIX PARA]', name)
                    para.requires_grad = False

        # Count trainable parameters (parameters() de-duplicates shared tensors).
        num_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print('[INFO] Finetune para num:', num_params / 1e6)

    def forward(self, input_ids, attention_mask, targets=None, tasks=None, train=True):
        """Compute the training loss or generate decoded predictions.

        Note: this switches the whole module into train/eval mode as a side
        effect (``self.train()`` / ``self.eval()``).

        Args:
            input_ids: Encoder input ids forwarded to the T5 model.
            attention_mask: Encoder attention mask forwarded to the T5 model.
            targets: Label ids passed as ``labels``; used only when ``train``.
            tasks: Task identifier(s) forwarded to the underlying model.
            train: If True, return the loss; otherwise return predictions.

        Returns:
            ``text_output.loss`` when ``train`` is True, otherwise the list
            produced by ``tokenizer.batch_decode`` on the generated ids.
        """
        if train:
            self.train()
            # get loss
            text_output = self.text_encoder(input_ids=input_ids,
                                            attention_mask=attention_mask,
                                            labels=targets,
                                            return_dict=True,
                                            tasks=tasks
                                            )
            return text_output.loss

        self.eval()
        # Inference only: without no_grad the encoder pass below (which runs
        # outside ``generate``) would record an autograd graph and keep its
        # activations alive for nothing.
        with torch.no_grad():
            # get embeddings
            encoder_outputs = self.text_encoder.get_encoder_output(
                input_ids=input_ids,
                attention_mask=attention_mask,
                tasks=tasks
            )
            # get output (reuses the precomputed encoder outputs)
            output_sequences = self.text_encoder.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                do_sample=False,  # disable sampling to test if batching affects output
                max_new_tokens=20,
                encoder_outputs=encoder_outputs
            )
        prediction = self.tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
        return prediction

