import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BlenderbotForConditionalGeneration, AutoModelForSeq2SeqLM, BlenderbotConfig
import os

class Model(nn.Module):
    '''
    Baseline ("original") seq2seq model trained with plain MLE.

    Wraps a pretrained Hugging Face seq2seq checkpoint. ``forward`` runs the
    encoder and decoder with teacher forcing and returns the token-level
    cross-entropy (negative log-likelihood) of the target labels.
    '''
    def __init__(self, args):
        """Load the pretrained seq2seq model.

        Args:
            args: config namespace; reads ``args.tau`` and ``args.bert_model``
                (model name or path given to
                ``AutoModelForSeq2SeqLM.from_pretrained``).
        """
        super().__init__()

        # Kept as an attribute for callers; not used inside this class.
        # NOTE(review): presumably a temperature used by subclasses or the
        # training loop — confirm.
        self.tau = args.tau

        self.bot_model = AutoModelForSeq2SeqLM.from_pretrained(args.bert_model)

        # Alternative (disabled): build a Blenderbot from the checkpoint's
        # config.json only, i.e. without loading the pretrained weights.
        # output_config_file = os.path.join(args.bert_model, 'config.json')
        # config = BlenderbotConfig.from_json_file(output_config_file)
        # self.bot_model = BlenderbotForConditionalGeneration(config)

    def hidden2vocab(self, hiddens):
        """Project decoder hidden states to vocabulary logits.

        Computes ``lm_head(hiddens) + final_logits_bias``, the same output
        head Blenderbot uses for generation.
        NOTE(review): requires ``self.bot_model`` to expose
        ``final_logits_bias`` (Blenderbot/BART-style models do).

        Args:
            hiddens: decoder hidden states, [b, t, d].

        Returns:
            Logits, [b, t, v].
        """
        return self.bot_model.lm_head(hiddens) + self.bot_model.final_logits_bias

    def forward(self, model_inputs):
        """Compute the teacher-forced MLE loss for one batch.

        Args:
            model_inputs: mapping with tensors under the keys
                ``'input_ids'`` and ``'attention_mask'`` (encoder side),
                ``'decoder_input_ids'`` and ``'decoder_attention_mask'``
                (decoder side), and ``'labels'``: target token ids with one
                entry per decoder position (b * t elements in total).

        Returns:
            Scalar mean cross-entropy over all label positions whose value is
            not -100 (the default ``ignore_index`` of ``F.cross_entropy``).
            NOTE(review): ``decoder_attention_mask`` is not applied to the
            loss, so padded target positions only drop out if their label
            is -100 — confirm the data pipeline does this.
        """
        input_ids = model_inputs['input_ids']
        attention_mask = model_inputs['attention_mask']
        decoder_input_ids = model_inputs['decoder_input_ids']
        decoder_attention_mask = model_inputs['decoder_attention_mask']
        labels = model_inputs['labels']

        encoder_outputs = self.bot_model.get_encoder()(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        hidden_states = encoder_outputs['last_hidden_state']  # [b, t, d]

        decoder_outputs = self.bot_model.get_decoder()(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=hidden_states,
            # Cross-attention skips the same padded source positions as the
            # encoder's self-attention.
            encoder_attention_mask=attention_mask,
        )
        sequence_output = decoder_outputs[0]  # [b, t, d]

        lm_logits = self.hidden2vocab(sequence_output)  # [b, t, v]
        vocab_size = lm_logits.size(-1)

        # reshape instead of view: view(-1) raises on a non-contiguous labels
        # tensor (e.g. a slice of a larger batch); reshape copies only if
        # it has to.
        nll = F.cross_entropy(
            lm_logits.reshape(-1, vocab_size), labels.reshape(-1)
        )
        return nll
