import logging, copy
import numpy as np
import torch
import torch.nn as nn
from collections import namedtuple
from torch.nn import CrossEntropyLoss, NLLLoss
from transformers import MBartForConditionalGeneration, MT5ForConditionalGeneration, XLMRobertaModel
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from transformers.models.mbart.modeling_mbart import shift_tokens_right
import ipdb

logger = logging.getLogger(__name__)

class GenerativeModel(nn.Module):
    """Wrapper that picks a seq2seq backbone from ``config.model_name``.

    Supported names (checked in this order):

    * ``"facebook/mbart-large-50"``            -> ``MBartForConditionalGeneration``
    * ``"google/mt5-*"``                       -> ``MT5ForConditionalGeneration``
    * ``"copy+google/mt5-*"``                  -> ``MT5Copy``
    * ``"copy+facebook/mbart-large-50*"``      -> ``MBartCopy``
    * ``"xlm-roberta-*"``                      -> ``XLMRDocoderModel`` (no copy)
    * ``"copy+xlm-roberta-*"``                 -> ``XLMRDocoderModel`` (copy)

    After loading, the backbone's token embeddings are resized to
    ``len(tokenizer)``.
    """

    def __init__(self, config, tokenizer):
        """Load the backbone named by ``config.model_name``.

        Args:
            config: object exposing ``model_name`` and ``cache_dir``.
            tokenizer: tokenizer whose length defines the embedding size and
                which is used to decode generated ids in :meth:`predict`.

        Raises:
            ValueError: if ``config.model_name`` matches none of the supported
                patterns.
        """
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer
        logger.info(f'Loading pre-trained model {config.model_name}')
        name = config.model_name
        if name == "facebook/mbart-large-50":
            self.model = MBartForConditionalGeneration.from_pretrained(name, cache_dir=config.cache_dir)
        elif name.startswith("google/mt5-"):
            self.model = MT5ForConditionalGeneration.from_pretrained(name, cache_dir=config.cache_dir)
        elif name.startswith("copy+google/mt5-"):
            # The copy variants need cross-attention weights, hence output_attentions=True.
            base_name = name.split('copy+', 1)[1]
            self.model = MT5Copy.from_pretrained(base_name, cache_dir=config.cache_dir, output_attentions=True)
        elif name.startswith("copy+facebook/mbart-large-50"):
            base_name = name.split('copy+', 1)[1]
            self.model = MBartCopy.from_pretrained(base_name, cache_dir=config.cache_dir, output_attentions=True)
        elif name.startswith("xlm-roberta-"):
            self.model = XLMRDocoderModel(config, copy=False)
        elif name.startswith("copy+xlm-roberta-"):
            self.model = XLMRDocoderModel(config, copy=True)
        else:
            # Fail loudly here instead of with an AttributeError on self.model below.
            raise ValueError(f"Unsupported model name: {name!r}")
        self.model.resize_token_embeddings(len(self.tokenizer))

    def forward(self, batch):
        """Return the training loss for ``batch``.

        ``batch`` must expose ``enc_idxs``, ``enc_attn``, ``dec_idxs``,
        ``dec_attn`` and ``lbl_idxs``; they are forwarded unchanged to the
        backbone, whose ``'loss'`` output is returned.
        """
        outputs = self.model(input_ids=batch.enc_idxs,
                             attention_mask=batch.enc_attn,
                             decoder_input_ids=batch.dec_idxs,
                             decoder_attention_mask=batch.dec_attn,
                             labels=batch.lbl_idxs,
                             return_dict=True)
        return outputs['loss']

    def predict(self, batch, out_start_code, num_beams=1, max_length=50):
        """Generate one decoded string per input row of ``batch``.

        Args:
            batch: object exposing ``enc_idxs`` and ``enc_attn``.
            out_start_code: token id forced as the first generated token; it is
                dropped (``None``) when it equals the backbone's
                ``decoder_start_token_id``.
            num_beams: beam width handed to ``generate``.
            max_length: maximum generated length handed to ``generate``.

        Returns:
            list[str]: decoded outputs, special tokens skipped.

        The module is put in eval mode for the duration of the call and its
        previous train/eval mode is restored afterwards, even on error.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                if out_start_code == self.model.config.decoder_start_token_id:
                    out_start_code = None

                # The copy models read the source ids from this attribute when
                # generate() calls forward() without input_ids. With beam search
                # each source row is repeated num_beams times consecutively.
                if num_beams == 1:
                    self.model._cache_input_ids = batch.enc_idxs
                else:
                    self.model._cache_input_ids = batch.enc_idxs.repeat_interleave(num_beams, dim=0)

                outputs = self.model.generate(input_ids=batch.enc_idxs,
                                              attention_mask=batch.enc_attn,
                                              num_beams=num_beams,
                                              max_length=max_length,
                                              forced_bos_token_id=out_start_code)
        finally:
            self.train(was_training)

        return [
            self.tokenizer.decode(outputs[bid], skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for bid in range(len(batch.enc_idxs))
        ]

class Prefix_fn_cls():
    """Produces, per batch row, the token ids found in that row's input plus a fixed set of special ids.

    NOTE(review): the ``get(batch_id, previous_token)`` signature looks like a
    constrained-decoding callback -- confirm against the caller.
    """

    def __init__(self, tokenizer, special_tokens, input_enc_idxs):
        """Store the inputs and pre-tokenise ``special_tokens`` into one flat id list.

        Args:
            tokenizer: callable returning a mapping with an ``'input_ids'``
                list of id lists when given ``special_tokens``.
            special_tokens: tokens whose ids are always allowed.
            input_enc_idxs: indexable by batch id; each row must support
                ``.tolist()``.
        """
        self.tokenizer = tokenizer
        self.input_enc_idxs = input_enc_idxs
        encoded = self.tokenizer(special_tokens, add_special_tokens=False)['input_ids']
        flat_ids = []
        for token_ids in encoded:
            flat_ids.extend(token_ids)
        self.special_ids = flat_ids

    def get(self, batch_id, previous_token):
        """Return the de-duplicated ids of input row ``batch_id`` followed by the special ids.

        ``previous_token`` is accepted but not used.
        """
        source_ids = set(self.input_enc_idxs[batch_id].tolist())
        return [*source_ids, *self.special_ids]

class XLMRDocoderModel(nn.Module):
    """XLM-RoBERTa encoder combined with a newly initialised Transformer decoder.

    The decoder reuses the encoder's embedding module for its inputs. When
    ``copy`` is true, the output distribution is a mixture of the vocabulary
    softmax and the decoder's cross-attention weights scattered onto the
    source token ids.

    ``lm_head`` is only created by :meth:`resize_token_embeddings`, which must
    therefore be called before :meth:`forward` / :meth:`generate`.
    """

    def __init__(self, config, copy):
        """Build encoder and decoder.

        Args:
            config: object exposing ``model_name`` and ``cache_dir``. With
                ``copy=True`` the name is expected to carry a ``copy+`` prefix,
                which is stripped before loading.
            copy: whether to enable the copy mechanism.
        """
        super().__init__()

        self.copy = copy

        if not self.copy:
            self.encoder = XLMRobertaModel.from_pretrained(config.model_name, cache_dir=config.cache_dir)
        else:
            model_name = config.model_name.split('copy+', 1)[1]
            self.encoder = XLMRobertaModel.from_pretrained(model_name, cache_dir=config.cache_dir)

        self.d_model = self.encoder.config.hidden_size
        self.dim_feedforward = 1024
        self.n_layer = 8
        decoder_layer = MyTransformerDecoderLayer(d_model=self.d_model, nhead=8, dim_feedforward=self.dim_feedforward, dropout=0.1)
        decoder_norm = nn.LayerNorm(self.d_model)
        self.decoder = MyTransformerDecoder(decoder_layer, self.n_layer, decoder_norm)
        self._reset_parameters(self.decoder)

        if self.copy:
            self.linear_copy = nn.Linear(self.d_model, 1)
            self._reset_parameters(self.linear_copy)

        self.loss = nn.CrossEntropyLoss()

        # Minimal stand-in for a Hugging Face config: callers only read
        # ``config.decoder_start_token_id``. Use a real instance rather than
        # mutating the namedtuple class.
        decoder_config = namedtuple('config', field_names=['decoder_start_token_id'])
        self.config = decoder_config(decoder_start_token_id=self.encoder.config.bos_token_id)

    def _reset_parameters(self, m):
        """Xavier-uniform initialise every parameter of ``m`` with more than one dimension."""
        for p in m.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def resize_token_embeddings(self, size):
        """Resize the encoder embeddings to ``size`` and (re)create ``lm_head`` to match the new vocabulary."""
        self.encoder.resize_token_embeddings(size)
        self.lm_head = nn.Linear(self.d_model, self.encoder.config.vocab_size)
        self._reset_parameters(self.lm_head)

    def generate_square_subsequent_mask(self, sz, device=None):
        """Return a ``(sz, sz)`` float mask: 0.0 on and below the diagonal, ``-inf`` above it.

        Args:
            sz: sequence length.
            device: optional device for the mask (defaults to torch's default device).
        """
        allowed = (torch.triu(torch.ones(sz, sz, device=device)) == 1).transpose(0, 1)
        return torch.zeros(sz, sz, device=device).masked_fill(~allowed, float('-inf'))

    def decode(self, dec_input_embeds, hidden_states, input_ids, dec_mask=None, dec_pad_mask=None):
        """Run the decoder and project to vocabulary scores.

        Args:
            dec_input_embeds: decoder input embeddings, sequence-first.
            hidden_states: encoder outputs, sequence-first.
            input_ids: source token ids ``(batch, src_len)``; used only as
                scatter targets in copy mode.
            dec_mask: optional attention mask over decoder positions.
            dec_pad_mask: optional key-padding mask over decoder positions.

        Returns:
            ``(batch, tgt_len, vocab)`` tensor: raw logits without copy,
            ``log(prob + 1e-7)`` of the mixed distribution with copy.
        """
        outputs, wt = self.decoder(dec_input_embeds, hidden_states, tgt_mask=dec_mask, tgt_key_padding_mask=dec_pad_mask)
        outputs = outputs.transpose(0, 1)

        if not self.copy:
            return self.lm_head(outputs)

        # Mixture weight between generating from the vocabulary and copying.
        p_copy = torch.sigmoid(self.linear_copy(outputs))
        g_logits = self.lm_head(outputs)
        g_probs = torch.softmax(g_logits, dim=-1) * (1 - p_copy)

        # Add the attention mass of every source position to the vocabulary
        # entry of the token at that position.
        c_idxs = input_ids.unsqueeze(1).repeat(1, wt.size(1), 1)
        gc_probs = torch.scatter_add(g_probs, 2, c_idxs, wt * p_copy)

        eps = 1e-7
        return torch.log(gc_probs + eps)

    def forward(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, labels, return_dict):
        """Compute the cross-entropy training loss.

        ``return_dict`` is accepted for interface compatibility but unused; the
        result is always ``{"loss": loss}``. All tensors created here are
        placed on ``input_ids.device``.
        """
        device = input_ids.device

        hidden_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        hidden_states = hidden_outputs["last_hidden_state"].transpose(0, 1)

        dec_input_embeds = self.encoder.embeddings(input_ids=decoder_input_ids).transpose(0, 1)

        dec_mask = self.generate_square_subsequent_mask(decoder_attention_mask.size(1), device=device)
        # True marks decoder padding positions.
        dec_pad_mask = (decoder_attention_mask == 0).to(device)

        logits = self.decode(dec_input_embeds, hidden_states, input_ids, dec_mask, dec_pad_mask)

        labels = labels.contiguous().view(-1)
        logits = logits.contiguous().view(-1, logits.size(-1))
        loss = self.loss(logits, labels)

        return {"loss": loss}

    def generate(self, input_ids, attention_mask, num_beams=1, max_length=50, forced_bos_token_id=None):
        """Greedy decoding.

        ``num_beams`` and ``forced_bos_token_id`` are accepted for interface
        compatibility but unused. Decoding starts from token id 0 and produces
        at most ``max_length`` ids per row; everything after the first EOS is
        replaced by the pad id.

        Returns:
            numpy array of shape ``(batch, max_length)`` with generated ids.
        """
        device = input_ids.device
        eos_id = self.encoder.config.eos_token_id

        with torch.no_grad():
            hidden_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
            hidden_states = hidden_outputs["last_hidden_state"].transpose(0, 1)

            batch_size = input_ids.size(0)
            gen_idxs = torch.zeros((batch_size, max_length), dtype=torch.long, device=device)
            finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

            for i in range(1, max_length):
                gen_embs = self.encoder.embeddings(input_ids=gen_idxs[:, :i]).transpose(0, 1)
                gen_mask = self.generate_square_subsequent_mask(i, device=device)

                outputs = self.decode(gen_embs, hidden_states, input_ids, gen_mask)
                outputs = outputs[:, -1, :]

                _, idx = torch.max(outputs, 1)
                gen_idxs[:, i] = idx

                # Positions after the first EOS are overwritten with pad below,
                # so stopping once every row has emitted EOS gives the same result.
                if eos_id is not None:
                    finished |= idx == eos_id
                    if bool(finished.all()):
                        break

        gen_idxs = gen_idxs.cpu().numpy()

        for i in range(batch_size):
            eos_pos = np.where(gen_idxs[i] == eos_id)[0]
            if len(eos_pos) > 0:
                gen_idxs[i, eos_pos[0] + 1:] = self.encoder.config.pad_token_id

        return gen_idxs
    
def _get_clones(module, N):
    return ModuleList([copy.deepcopy(module) for i in range(N)])

def _get_activation_fn(activation):
    if activation == "relu":
        return nn.functional.relu

    elif activation == "gelu":
        return nn.functional.gelu
    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))

class MyTransformerDecoder(nn.Module):
    """Stack of decoder layers that also returns attention weights.

    Each layer is called as ``layer(output, memory, ...)`` and must return a
    pair ``(output, weights)``. The weights returned by the *last* layer are
    passed through to the caller.
    """

    __constants__ = ['norm']

    def __init__(self, decoder_layer, num_layers, norm=None):
        """
        Args:
            decoder_layer: layer module that is deep-copied ``num_layers`` times.
            num_layers: number of stacked layers.
            norm: optional module applied to the final output.
        """
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, tgt, memory, tgt_mask=None,
                memory_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        """Pass ``tgt`` through every layer in turn.

        Returns:
            tuple ``(output, wt)`` where ``wt`` is the second value returned by
            the last layer, or ``None`` when the stack has no layers.
        """
        output = tgt
        # Defined up front so an empty stack returns None instead of raising
        # UnboundLocalError.
        wt = None

        for mod in self.layers:
            output, wt = mod(output, memory, tgt_mask=tgt_mask,
                             memory_mask=memory_mask,
                             tgt_key_padding_mask=tgt_key_padding_mask,
                             memory_key_padding_mask=memory_key_padding_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output, wt

from torch.nn.modules.activation import MultiheadAttention
from torch.nn.modules.container import ModuleList
class MyTransformerDecoderLayer(nn.Module):
    """Post-norm Transformer decoder layer that also returns its cross-attention weights.

    Sub-layers, in order: self-attention, attention over ``memory``, and a
    two-layer feed-forward network. Each is followed by dropout, a residual
    add and a LayerNorm.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        """
        Args:
            d_model: feature size of inputs and outputs.
            nhead: number of attention heads.
            dim_feedforward: hidden size of the feed-forward network.
            dropout: dropout probability used throughout the layer.
            activation: ``"relu"`` or ``"gelu"``.
        """
        super().__init__()

        # Attention blocks
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)

        # Position-wise feed-forward network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # One LayerNorm and one residual dropout per sub-layer
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        """Restore state, defaulting ``activation`` to ReLU when the saved state lacks it."""
        state.setdefault('activation', nn.functional.relu)
        super().__setstate__(state)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """Apply the layer.

        Returns:
            tuple ``(output, wt)``: the transformed ``tgt`` and the attention
            weights returned by the attention over ``memory``.
        """
        # 1) self-attention
        self_attn_out = self.self_attn(
            tgt, tgt, tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )[0]
        hidden = self.norm1(tgt + self.dropout1(self_attn_out))

        # 2) attention over the encoder memory; keep the weights for the caller
        cross_attn_out, wt = self.multihead_attn(
            hidden, memory, memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        hidden = self.norm2(hidden + self.dropout2(cross_attn_out))

        # 3) feed-forward
        ff_out = self.linear2(self.dropout(self.activation(self.linear1(hidden))))
        hidden = self.norm3(hidden + self.dropout3(ff_out))

        return hidden, wt

class MT5Copy(MT5ForConditionalGeneration):
    """mT5 with a copy mechanism on top of the last decoder layer's cross-attention.

    The output distribution is ``(1 - p_copy) * softmax(lm_logits)`` plus
    ``p_copy`` times the head-averaged cross-attention, scattered onto the
    vocabulary entries of the source tokens. ``logits`` in the returned output
    are therefore log-probabilities (``log(p + 1e-7)``), and the loss is an
    ``NLLLoss`` on them.

    The decoder must return cross-attentions (load the model with
    ``output_attentions=True``). When ``forward`` is called without
    ``input_ids`` (as happens during generation once the encoder output is
    cached), the source ids are read from ``self._cache_input_ids``, which the
    caller has to set beforehand.
    """

    # Defined as a single-underscore class attribute: a double-underscore name
    # referenced inside the class body would be name-mangled.
    _HEAD_MASK_WARNING_MSG = (
        "The input argument `head_mask` was split into two arguments `head_mask` and "
        "`decoder_head_mask`. Currently, `decoder_head_mask` is set to copy `head_mask`, "
        "but this feature is deprecated and will be removed in future versions. If you do "
        "not want to use any `decoder_head_mask` now, please set "
        "`decoder_head_mask = torch.ones(num_layers, num_heads)`."
    )

    def __init__(self, config):
        super().__init__(config)
        # Scalar gate deciding, per decoder position, how much to copy.
        self.linear_copy = nn.Linear(self.model_dim, 1)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encoder-decoder forward pass with the copy distribution mixed in.

        Arguments mirror the parent class's ``forward``.

        Returns:
            ``Seq2SeqLMOutput`` (or a tuple when ``return_dict`` is false) whose
            ``logits`` are log-probabilities of the mixed distribution and whose
            ``loss`` (when ``labels`` is given) is the NLL with ``-100`` ignored.

        Raises:
            ValueError: if the batch size of the source ids does not match the
                batch size of the encoder hidden states.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                import warnings

                warnings.warn(self._HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # If decoding with past key value states, only the last tokens
        # should be given as an input
        if past_key_values is not None:
            assert labels is None, "Decoder should not use cached key value states when training."
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim ** -0.5)

        # During generation forward() is called without input_ids; fall back
        # to the source ids cached by the caller.
        if input_ids is None:
            input_ids = self._cache_input_ids  # batch x sequence_length
        if input_ids.size(0) != hidden_states.size(0):
            raise ValueError(
                "Copy mechanism: source ids batch size ({}) does not match encoder "
                "hidden states batch size ({})".format(input_ids.size(0), hidden_states.size(0))
            )

        lm_logits = self.lm_head(sequence_output)

        # Copy distribution: last decoder layer, averaged over heads
        cross_attentions = decoder_outputs['cross_attentions'][-1]  # batch x head x decoder_length x encoder_length
        cross_attentions = torch.mean(cross_attentions, dim=1)  # batch x decoder_length x encoder_length

        # Probability of copying
        p_copy = torch.sigmoid(self.linear_copy(sequence_output))

        # Merge distribution
        original_word_pro = torch.softmax(lm_logits, dim=-1) * (1 - p_copy)  # [batch, sequence_length, vocab_size]
        copy_words = input_ids.unsqueeze(1).repeat(1, cross_attentions.size(1), 1)  # (batch, target_length, encoder_length)
        lm_logits = torch.scatter_add(original_word_pro, 2, copy_words, cross_attentions * p_copy)

        # eps keeps the log finite for zero-probability entries
        eps = 1e-7
        lm_logits = torch.log(lm_logits + eps)

        loss = None
        if labels is not None:
            # lm_logits already holds log-probabilities, so NLLLoss (not CrossEntropyLoss) is used.
            loss_fct = NLLLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

class MBartCopy(MBartForConditionalGeneration):
    """mBART with a copy mechanism on top of the last decoder layer's cross-attention.

    The output distribution is ``(1 - p_copy) * softmax(lm_logits)`` plus
    ``p_copy`` times the head-averaged cross-attention, scattered onto the
    vocabulary entries of the source tokens. ``logits`` in the returned output
    are log-probabilities (``log(p + 1e-7)``), and the loss is an ``NLLLoss``
    on them.

    The inner model's outputs are accessed by key, so cross-attentions must be
    returned (load the model with ``output_attentions=True``). When
    ``forward`` is called without ``input_ids``, the source ids are read from
    ``self._cache_input_ids``, which the caller has to set beforehand.
    """

    def __init__(self, config):
        super().__init__(config)
        # Scalar gate deciding, per decoder position, how much to copy.
        self.linear_copy = nn.Linear(config.d_model, 1)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
            config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored
            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.

        Returns:
            ``Seq2SeqLMOutput`` (or a tuple when ``return_dict`` is false) whose ``logits`` are log-probabilities of
            the mixed generate/copy distribution.

        Raises:
            ValueError: if the batch size of the source ids does not match the batch size of the encoder hidden
            states.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if decoder_input_ids is None:
                decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias

        # During generation forward() is called without input_ids; fall back
        # to the source ids cached by the caller.
        if input_ids is None:
            input_ids = self._cache_input_ids  # batch x sequence_length
        encoder_batch_size = outputs['encoder_last_hidden_state'].size(0)
        if input_ids.size(0) != encoder_batch_size:
            raise ValueError(
                "Copy mechanism: source ids batch size ({}) does not match encoder "
                "hidden states batch size ({})".format(input_ids.size(0), encoder_batch_size)
            )

        # Copy distribution: last decoder layer, averaged over heads
        cross_attentions = outputs['cross_attentions'][-1]
        cross_attentions = torch.mean(cross_attentions, dim=1)  # batch x decoder_length x encoder_length

        # Probability of copying
        p_copy = torch.sigmoid(self.linear_copy(outputs['last_hidden_state']))

        # Merge distribution
        original_word_pro = torch.softmax(lm_logits, dim=-1) * (1 - p_copy)  # [batch, sequence_length, vocab_size]
        copy_words = input_ids.unsqueeze(1).repeat(1, cross_attentions.size(1), 1)  # (batch, target_length, encoder_length)
        lm_logits = torch.scatter_add(original_word_pro, 2, copy_words, cross_attentions * p_copy)

        # eps keeps the log finite for zero-probability entries
        eps = 1e-7
        lm_logits = torch.log(lm_logits + eps)

        masked_lm_loss = None
        if labels is not None:
            # lm_logits already holds log-probabilities, so NLLLoss (not CrossEntropyLoss) is used.
            loss_fct = NLLLoss(ignore_index=-100)
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
