import torch
import torch.nn.functional as F
from torch import nn

from transformers import T5Tokenizer, T5Config
from .modeling_t5 import T5ForConditionalGeneration, T5PromptForConditionalGeneration

# Adaptation of T5 (plain or prompt-augmented) for conditional generation.
class LanguageT5Model(nn.Module):
    """Wrap a T5 seq2seq model (optionally prompt-augmented) for train/inference.

    ``forward`` returns the model loss when ``train=True``. Otherwise it
    generates text and returns the decoded strings.

    Args:
        text_encoder: Pretrained model name or path. It is passed both to
            ``T5Config.from_pretrained`` and to the model's ``from_pretrained``.
        tokenizer: Tokenizer used to decode generated ids (``batch_decode``).
        use_prompt: If True, build ``T5PromptForConditionalGeneration`` with
            ``prompt_config`` stored on the config. Otherwise build a plain
            ``T5ForConditionalGeneration``.
        only_prompt: Only with ``use_prompt``: call ``tune_only_prompt()`` on
            the model.
        fix_prompt: Only with ``use_prompt``: call ``fix_prompt()`` on the
            model.
        prompt_config: Value stored under ``'prompt_config'`` in the T5
            config. The ``__main__`` demo passes a JSON file path; the exact
            format expected is defined in ``.modeling_t5``.
        fix_word_embeddings: Set ``requires_grad=False`` on the shared
            embedding and on the encoder/decoder ``embed_tokens`` parameters.
        use_bitfit: Call ``tune_only_bias()`` on the model. This is only
            honoured when ``use_prompt`` is False.
    """

    def __init__(self,
                 text_encoder=None,
                 tokenizer=None,
                 use_prompt=False,
                 only_prompt=False,
                 fix_prompt=False,
                 prompt_config=None,
                 fix_word_embeddings=False,
                 use_bitfit=False
                 ):
        super().__init__()

        self.tokenizer = tokenizer
        # config & model init
        t5_config = T5Config.from_pretrained(text_encoder)
        if not use_prompt:
            self.text_encoder = T5ForConditionalGeneration.from_pretrained(text_encoder, config=t5_config)
            # NOTE: BitFit is only wired up for the non-prompt model.
            if use_bitfit:
                self.text_encoder.tune_only_bias()
        else:
            t5_config.update({'prompt_config': prompt_config})
            self.text_encoder = T5PromptForConditionalGeneration.from_pretrained(text_encoder, config=t5_config)
            if only_prompt:
                self.text_encoder.tune_only_prompt()
            if fix_prompt:
                self.text_encoder.fix_prompt()

        if fix_word_embeddings:
            self._freeze_word_embeddings()

        # Report trainable parameters in millions.
        print('[INFO] Finetune para num:', self._count_trainable_params() / 1e6)

    def _freeze_word_embeddings(self):
        """Freeze the shared embedding and the encoder/decoder token embeddings."""
        # These modules may share the same weight (tied embeddings), in which
        # case the same parameter is printed once per module.
        embedding_modules = (
            self.text_encoder.shared,
            self.text_encoder.encoder.embed_tokens,
            self.text_encoder.decoder.embed_tokens,
        )
        for module in embedding_modules:
            for name, para in module.named_parameters():
                print('[FIX PARA]', name)
                para.requires_grad = False

    def _count_trainable_params(self):
        """Return the total element count of parameters with ``requires_grad`` set."""
        # ``parameters()`` de-duplicates shared tensors, matching ``named_parameters()``.
        return sum(para.numel() for para in self.parameters() if para.requires_grad)

    def forward(self, input_ids, attention_mask, targets=None, train=True, max_new_tokens=512):
        """Compute the training loss or generate predictions.

        Args:
            input_ids: Encoder input token ids.
            attention_mask: Attention mask matching ``input_ids``.
            targets: Label ids passed as ``labels`` to the model (train only).
            train: If True, switch to train mode and return the loss.
                Otherwise switch to eval mode and return decoded strings.
            max_new_tokens: Generation length cap (inference only).

        Returns:
            The model's ``loss`` when ``train`` is True. Otherwise the list
            of strings produced by ``tokenizer.batch_decode``.
        """
        if train:
            self.train()
            text_output = self.text_encoder(input_ids=input_ids,
                                            attention_mask=attention_mask,
                                            labels=targets,
                                            return_dict=True,
                                            )
            return text_output.loss

        self.eval()
        # Inference never back-propagates. Without no_grad the explicit encoder
        # pass would record an autograd graph for every trainable parameter.
        with torch.no_grad():
            encoder_outputs = self.text_encoder.get_encoder_output(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            output_sequences = self.text_encoder.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # sampling disabled so outputs are reproducible
                encoder_outputs=encoder_outputs,
            )
        return self.tokenizer.batch_decode(output_sequences, skip_special_tokens=True)

if __name__ == "__main__":
    # Smoke test: compute one loss and one generation pass with t5-small.
    # NOTE: this module uses a relative import (``.modeling_t5``), so run it
    # as part of its package (``python -m <package>.<module>``), not as a
    # standalone script.
    import warnings
    warnings.filterwarnings('ignore')

    seed = 42
    torch.manual_seed(seed)

    text_encoder = 't5-small'
    tokenizer = T5Tokenizer.from_pretrained(text_encoder)
    model = LanguageT5Model(text_encoder,
                            tokenizer=tokenizer,
                            use_prompt=True,
                            only_prompt=True,
                            prompt_config='../../configs/config_prompt_p20.json')
    model.eval()

    input_sequences = ['translate English to German: The house is wonderful.',
                       'translate English to German: I like to work in NYC.']
    encoding = tokenizer(
            input_sequences,
            padding="longest",
            max_length=512,
            truncation=True,
            return_tensors="pt",
    )
    target_encoding = tokenizer(
            ['output', 'output a'],
            padding="longest",
            max_length=128,
            truncation=True,
            return_tensors="pt",
        )
    labels = target_encoding.input_ids
    # Padded label positions must be -100 so the cross-entropy loss ignores
    # them; otherwise the shorter target is also trained to emit pad tokens.
    labels[labels == tokenizer.pad_token_id] = -100
    input_ids, attention_mask = encoding.input_ids, encoding.attention_mask

    output = model(input_ids, attention_mask, labels, train=True)
    print(output)

    output = model(input_ids, attention_mask, train=False)
    print(output)
