import copy
import ipdb
import torch
from torch import Tensor, device, nn
from torch.nn import CrossEntropyLoss
from transformers import AutoTokenizer, AutoModelWithLMHead, AutoModelForMaskedLM, AutoModelForCausalLM, AutoModel, BartForConditionalGeneration
from worldformer2.tokenization.custom_tokenizers import get_tokenizer
import logging
from worldformer2.tools.logging_util import basic_logging
basic_logging()

class BartModel(nn.Module):
    """Two BART encoder-decoders (text/action and graph) plus an aggregator.

    Components, all created in ``register_components``:

    * ``text_encoder_decoder`` / ``graph_encoder_decoder``: two
      ``BartForConditionalGeneration`` models loaded from
      ``"facebook/bart-base"``; their encoders and decoders are also exposed
      as ``text_encoder``/``action_decoder`` and
      ``graph_encoder``/``graph_decoder``.
    * ``aggregator``: a DistilBERT-architecture model built from a modified
      config (not loaded from pretrained weights) that consumes the
      concatenated encoder states.

    The ``config`` object must provide ``tie_embeddings``,
    ``input_text_n_vocab`` and ``task_mode``.
    """

    # Task modes that activate each stage of ``forward``.
    _TEXT_ENCODER_MODES = ('full_multitask', 'action_only', 'action_multi', 'graph_multi')
    _GRAPH_ENCODER_MODES = ('full_multitask', 'graph_only', 'action_multi', 'graph_multi')
    _AGGREGATOR_MODES = ('full_multitask', 'action_multi', 'graph_multi')
    _ACTION_DECODER_MODES = ('full_multitask', 'action_only', 'action_multi')
    _GRAPH_DECODER_MODES = ('full_multitask', 'graph_only', 'graph_multi')

    def __init__(self, config, **kwargs):
        """Store ``config`` and build every sub-model.

        Args:
            config: object exposing ``tie_embeddings``, ``input_text_n_vocab``
                and ``task_mode`` as attributes.
            **kwargs: accepted for call-site compatibility; not used here.
        """
        super().__init__()
        self.config = config
        self.register_components()

    def register_components(self):
        """Create the BART models and the aggregator, then optionally tie embeddings."""
        self.register_bart()
        self.register_aggregator()
        if self.config.tie_embeddings:
            self.tie_embeddings()

        logging.info(f"Total params: {self.count_parameters(self)}")

    def tie_embeddings(self):
        """Make the graph model reuse the text model's embedding and LM-head weights."""
        self.graph_encoder_decoder.model.shared.weight = self.text_encoder_decoder.model.shared.weight
        self.graph_encoder_decoder.lm_head.weight = self.text_encoder_decoder.lm_head.weight

    def register_bart(self):
        """Load both BART models and resize their vocabularies.

        Both models are resized to ``config.input_text_n_vocab`` tokens.
        """
        logging.info("Setting up bart encoder decoder")
        self.text_encoder_decoder = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
        self.text_encoder_decoder.resize_token_embeddings(self.config.input_text_n_vocab)
        self.text_encoder = self.text_encoder_decoder.get_encoder()
        self.action_decoder = self.text_encoder_decoder.get_decoder()

        self.graph_encoder_decoder = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
        self.graph_encoder_decoder.resize_token_embeddings(self.config.input_text_n_vocab)
        self.graph_encoder = self.graph_encoder_decoder.get_encoder()
        self.graph_decoder = self.graph_encoder_decoder.get_decoder()

    def register_aggregator(self):
        """Build the aggregator from a modified ``distilbert-base-cased`` config."""
        # TODO option for pretrained/from scratch
        # Local import: only the configuration is needed, so fetch it directly
        # instead of instantiating the whole pretrained model just to read it.
        from transformers import AutoConfig

        aggregator_config = AutoConfig.from_pretrained("distilbert-base-cased")

        # Hparams from paper
        aggregator_config.max_position_embeddings = 2048
        aggregator_config.num_hidden_layers = 2
        aggregator_config.n_heads = 2
        aggregator_config.hidden_dim = 4096

        self.aggregator = AutoModel.from_config(aggregator_config)

    def _encode(self, encoder, input_ids, attention_mask):
        """Run ``encoder`` and return its tuple output (index 0 is the last hidden state)."""
        return encoder(input_ids=input_ids,
                       attention_mask=attention_mask,
                       head_mask=None,
                       inputs_embeds=None,
                       output_attentions=False,
                       output_hidden_states=True,
                       return_dict=False,
                       )

    def _aggregate(self, text_hidden, text_mask, graph_hidden, graph_mask):
        """Return the mask-weighted mean of the aggregator output over dim 1."""
        # Text states come first, graph states second, along dim 1.
        aggregator_inputs = torch.cat([text_hidden, graph_hidden], 1)
        aggregator_mask = torch.cat([text_mask, graph_mask], 1)
        aggregator_outputs = self.aggregator(inputs_embeds=aggregator_inputs,
                                             attention_mask=aggregator_mask)
        masked_sum = torch.sum(aggregator_mask.unsqueeze(2) * aggregator_outputs['last_hidden_state'], 1)
        # Clamp so a fully masked row yields zeros instead of 0/0 = NaN.
        valid_counts = aggregator_mask.sum(1).clamp(min=1)
        return masked_sum / valid_counts.unsqueeze(1)

    def _decode(self, decoder, input_ids, attention_mask,
                encoder_hidden_states, encoder_attention_mask, state_vec):
        """Run ``decoder`` with cross-attention over ``encoder_hidden_states``."""
        # NOTE(review): ``state_vector`` is not a parameter of the stock
        # Hugging Face ``BartDecoder.forward``; this presumably relies on a
        # project-patched decoder -- confirm.
        return decoder(input_ids=input_ids,
                       attention_mask=attention_mask,
                       encoder_hidden_states=encoder_hidden_states,
                       encoder_attention_mask=encoder_attention_mask,
                       head_mask=None,
                       cross_attn_head_mask=None,
                       past_key_values=None,
                       inputs_embeds=None,
                       use_cache=False,
                       output_attentions=False,
                       output_hidden_states=True,
                       return_dict=False,
                       state_vector=state_vec,
                       )

    def forward(self,
                    text_encoder_input_ids=None,
                    text_encoder_attention_mask=None,
                    graph_encoder_input_ids=None,
                    graph_encoder_attention_mask=None,
                    action_decoder_input_ids=None,
                    action_decoder_attention_mask=None,
                    graph_decoder_input_ids=None,
                    graph_decoder_attention_mask=None,
                    action_state_token_ids=None,
                    graph_state_token_ids=None
                   ):
        """Compute action and/or graph logits according to ``config.task_mode``.

        Stages run per mode:

        * text encoder: ``full_multitask``, ``action_only``, ``action_multi``,
          ``graph_multi``
        * graph encoder: ``full_multitask``, ``graph_only``, ``action_multi``,
          ``graph_multi``
        * aggregator (needs both encoders): ``full_multitask``,
          ``action_multi``, ``graph_multi``
        * action decoder: ``full_multitask``, ``action_only``, ``action_multi``
        * graph decoder: ``full_multitask``, ``graph_only``, ``graph_multi``

        Args:
            text_encoder_input_ids, text_encoder_attention_mask: inputs of the
                text encoder.
            graph_encoder_input_ids, graph_encoder_attention_mask: inputs of
                the graph encoder.
            action_decoder_input_ids, action_decoder_attention_mask: inputs of
                the action decoder.
            graph_decoder_input_ids, graph_decoder_attention_mask: inputs of
                the graph decoder.
            action_state_token_ids, graph_state_token_ids: accepted but not
                used by this method.

        Returns:
            dict with keys ``'action_logits'`` and ``'graph_logits'``; a value
            is ``None`` when its decoder is not run for the current mode.
        """
        task_mode = self.config.task_mode

        action_logits = None
        graph_logits = None
        # Stays None in the single-task modes, where the aggregator is skipped;
        # without this default the decoder calls below hit an unbound local.
        state_vec = None

        if task_mode in self._TEXT_ENCODER_MODES:
            text_encoder_outputs = self._encode(self.text_encoder,
                                                text_encoder_input_ids,
                                                text_encoder_attention_mask)

        if task_mode in self._GRAPH_ENCODER_MODES:
            graph_encoder_outputs = self._encode(self.graph_encoder,
                                                 graph_encoder_input_ids,
                                                 graph_encoder_attention_mask)

        if task_mode in self._AGGREGATOR_MODES:
            state_vec = self._aggregate(text_encoder_outputs[0],
                                        text_encoder_attention_mask,
                                        graph_encoder_outputs[0],
                                        graph_encoder_attention_mask)

        if task_mode in self._ACTION_DECODER_MODES:
            action_decoder_outputs = self._decode(self.action_decoder,
                                                  action_decoder_input_ids,
                                                  action_decoder_attention_mask,
                                                  text_encoder_outputs[0],
                                                  text_encoder_attention_mask,
                                                  state_vec)
            action_logits = (self.text_encoder_decoder.lm_head(action_decoder_outputs[0])
                             + self.text_encoder_decoder.final_logits_bias)

        if task_mode in self._GRAPH_DECODER_MODES:
            graph_decoder_outputs = self._decode(self.graph_decoder,
                                                 graph_decoder_input_ids,
                                                 graph_decoder_attention_mask,
                                                 graph_encoder_outputs[0],
                                                 graph_encoder_attention_mask,
                                                 state_vec)
            graph_logits = (self.graph_encoder_decoder.lm_head(graph_decoder_outputs[0])
                            + self.graph_encoder_decoder.final_logits_bias)

        output = {'action_logits': action_logits,
                  'graph_logits': graph_logits,
                 }

        return output

    # utils
    def count_parameters(self, model):
        """Return the number of trainable (``requires_grad``) scalars in ``model``."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)



"""
class Aggregator(nn.Module):
    def __init__(self, config, **kwargs):
        super().__init__()
        self.config = config
        ipdb.set_trace()
"""

if __name__ == "__main__":
    # Smoke test: build the model once. This loads "facebook/bart-base" and
    # the "distilbert-base-cased" config, so it needs network or cache access.
    from types import SimpleNamespace

    # The previous code called an undefined ``Worldformer`` with a bare int;
    # BartModel needs a config exposing the three attributes below.
    smoke_config = SimpleNamespace(
        # NOTE(review): 50265 is the published facebook/bart-base vocabulary
        # size -- replace with the project's tokenizer size if it differs.
        input_text_n_vocab=50265,
        tie_embeddings=True,
        task_mode="full_multitask",
    )
    tmp = BartModel(smoke_config)


