

from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy
import torch
from torch import nn
import torch.nn.functional as F

from torch.nn.modules.linear import Linear

from transformers import BertModel, BertConfig, BertTokenizer
from transformers.models.bert.modeling_bert import BertOnlyMLMHead

from transformers import RobertaModel, RobertaConfig 
from transformers.models.roberta.modeling_roberta import RobertaLMHead

#from BERT_emberder import PretrainedBertEmbedder
#from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions, ModelOutput

try:
    from .utils import select_offsets_emb, get_text_field_mask, sequence_cross_entropy_with_logits
    from .metrics import CategoricalAccuracy
except:
    from utils import select_offsets_emb, get_text_field_mask, sequence_cross_entropy_with_logits
    from metrics import CategoricalAccuracy 


"""

class BertLMPredictionHead(nn.Module):
    def __init__(self, config, bert_model_embedding_weights):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)
            ## dense -> activation -> layer_norm
        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
                                 bert_model_embedding_weights.size(0),
                                 bias=False)
        self.decoder.weight = bert_model_embedding_weights
        self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states) + self.bias
        return hidden_states

"""



class BertBEM(nn.Module):
    """Bi-encoder that matches a target word's contextual embedding to its gloss.

    :meth:`forward` combines two objectives:

    * a masked-LM loss on ``mlm_input_ids``, computed with ``sent_encoder``
      and, when ``single_encoder`` is False, also with a separate
      ``mlm_encoder`` (the two losses are averaged);
    * a contrastive loss that scores each target-word embedding against every
      in-batch positive gloss plus one extra negative gloss per example. The
      example's own positive gloss is the correct class.

    Args:
        sent_encoder_folder: model name/path for the sentence encoder (and for
            the optional separate MLM encoder).
        def_encoder_folder: model name/path for the definition (gloss) encoder.
        cos_temp: temperature that divides cosine similarities. A small value
            sharpens the softmax over otherwise very close similarities.
        mlm_weight: weight of the masked-LM loss in the total loss.
        gloss_weight: weight of the contrastive (gloss) loss in the total loss.
        use_roberta: load RoBERTa encoders/head instead of BERT.
        hs_baseline: if True, the target-word embedding is the clean sentence
            encoder output ``h_s``. Otherwise it is ``h_s - h_c``, where
            ``h_c`` is the output at the same position once the target word
            has been replaced by the mask token.
        single_encoder: if True, no separate ``mlm_encoder`` is created and
            ``sent_encoder`` is used everywhere.
    """

    def __init__(
        self,
        sent_encoder_folder='bert-base-uncased',
        def_encoder_folder='bert-base-uncased',
        cos_temp=0.05,
        mlm_weight=0.5,
        gloss_weight=1.0,
        use_roberta=False,
        hs_baseline=True,
        single_encoder=True,
    ):
        super().__init__()
        print(sent_encoder_folder)
        self.hs_baseline = hs_baseline
        self.single_encoder = single_encoder

        if use_roberta:
            encoder_cls, head_cls = RobertaModel, RobertaLMHead
        else:
            encoder_cls, head_cls = BertModel, BertOnlyMLMHead
        # Creation order (mlm -> sent -> def -> head) is kept so that submodule
        # registration order (and hence state_dict layout) is unchanged.
        if not self.single_encoder:
            self.mlm_encoder = encoder_cls.from_pretrained(sent_encoder_folder)
        self.sent_encoder = encoder_cls.from_pretrained(sent_encoder_folder)
        self.def_encoder = encoder_cls.from_pretrained(def_encoder_folder)
        # The MLM head is constructed directly from the config. Its weights are
        # not loaded by any from_pretrained call here.
        self.cls = head_cls(self.sent_encoder.config)

        self.cos_temp = cos_temp
        self.mlm_weight = mlm_weight
        self.gloss_weight = gloss_weight
        self.cos_nn = nn.CosineSimilarity(dim=-1)
        self.ce_loss = nn.CrossEntropyLoss()
        self.metrics = {"label_accuracy": CategoricalAccuracy(), }

        # Hard-coded special-token ids. Presumably these are <pad>/<mask> of
        # roberta-base and [PAD]/[MASK] of bert-base-uncased.
        # TODO confirm if other vocabularies are used.
        if use_roberta:
            self.pad_ids = 1
            self.mask_ids = 50264
        else:
            self.pad_ids = 0
            self.mask_ids = 103
        self.align_weight = 0.2  # not used by the active code path
        self.use_cos = True  # False switches similarities to raw dot products

        self.emb_size = self.sent_encoder.config.hidden_size

    def cos_sim(self, x, y):
        """Row-wise similarity between ``x`` and ``y``.

        With ``use_cos`` (the default), this is the cosine similarity over the
        last dim divided by ``cos_temp``. Otherwise it is the raw dot product
        summed over dim 1 and reshaped to a column ``(N, 1)``.
        """
        if not self.use_cos:
            return torch.sum(x * y, axis=1).view([-1, 1])
        return self.cos_nn(x, y) / self.cos_temp

    def Matrix_cos_sim(self, target_word_embs, target_def_embs, device=None):
        """Pairwise similarity matrix between two sets of embeddings.

        Args:
            target_word_embs: ``(N, D)`` tensor.
            target_def_embs: ``(M, D)`` tensor. ``M`` may differ from ``N``.
            device: device for the epsilon matrix. Defaults to the device of
                ``target_word_embs``.

        Returns:
            ``(N, M)`` tensor. With ``use_cos`` it holds cosine similarities
            divided by ``cos_temp``. Otherwise it holds the raw dot products.
        """
        matrix_dot = torch.mm(target_word_embs, target_def_embs.T)
        if not self.use_cos:
            return matrix_dot
        if device is None:
            # Unlike .get_device() (which is -1 for CPU tensors), .device is a
            # valid target whether or not CUDA is installed.
            device = target_word_embs.device
        eps = 1e-10
        norm_rows = torch.norm(target_word_embs, dim=1).view(-1, 1)
        norm_cols = torch.norm(target_def_embs, dim=1).view(1, -1)
        # eps keeps the division finite for zero vectors. It is shaped like the
        # output so that rectangular (N != M) inputs work.
        eps_m = torch.full(matrix_dot.shape, eps, device=device)
        norm = torch.mm(norm_rows, norm_cols) + eps_m
        return matrix_dot / norm / self.cos_temp

    def _mlm_loss(self, encoder, input_ids, mlm_offsets, mlm_labels):
        """Return the token-averaged masked-LM loss of ``encoder`` at ``mlm_offsets``.

        Label entries equal to ``pad_ids`` get zero weight.
        """
        hidden = encoder(input_ids, output_hidden_states=False, return_dict=True).last_hidden_state
        logits = self.cls(select_offsets_emb(hidden, mlm_offsets))
        weights = (mlm_labels != self.pad_ids).float()
        return sequence_cross_entropy_with_logits(
            logits=logits, targets=mlm_labels, weights=weights, average='token')

    def _masked_mean_pool(self, hidden_states, input_ids):
        """Average ``hidden_states`` over dim 1, skipping positions where ``input_ids == pad_ids``.

        ``hidden_states`` is ``(N, L, D)`` and ``input_ids`` is ``(N, L)``; the
        result is ``(N, D)``. A row made entirely of padding pools to zeros
        rather than NaN.
        """
        mask = (input_ids != self.pad_ids).float().unsqueeze(-1)
        summed = torch.sum(hidden_states * mask, dim=1)
        counts = torch.sum(mask, dim=1).clamp(min=1.0)
        return summed / counts

    def forward(self,
                mlm_input_ids,
                target_word_offsets,
                def_input_ids,
                clean_input_ids,
                mlm_labels=None,
                mlm_offsets=None,
                **kwargs,
                ):
        """Compute the joint MLM + gloss-contrastive loss for one batch.

        Args:
            mlm_input_ids: token ids fed to the MLM encoder(s). Presumably
                ``(bs, seq_len)`` with some positions masked; TODO confirm.
            target_word_offsets: one target-word position per example in
                ``clean_input_ids``. It is reshaped to ``(bs, 1)``.
            def_input_ids: ``(bs, num_sent, seq_len)`` gloss token ids. Index 0
                along ``num_sent`` is the positive gloss and index 1 is the
                negative. Further entries are encoded but not scored.
            clean_input_ids: unmasked sentence token ids, ``(bs, seq_len)``.
            mlm_labels: target ids at ``mlm_offsets``. Entries equal to
                ``pad_ids`` are ignored. If None, the MLM loss is skipped and
                reported as 0.
            mlm_offsets: positions selected from the MLM encoder output via
                ``select_offsets_emb``.
            **kwargs: ignored.

        Returns:
            A dict with the following keys:

            * ``logits``: ``(bs, bs + 1)``, the in-batch positive glosses
              followed by the example's own negative gloss.
            * ``loss``: the weighted total.
            * ``mlm_loss``: the masked-LM term.
            * ``wsd_loss``: the contrastive term.
        """
        # NOTE(review): no attention_mask is passed to any encoder call, so
        # padding tokens are attended to. Confirm this is intended.

        # 1: masked-LM loss (skipped entirely without labels).
        masked_lm_loss = None
        if mlm_labels is not None:
            masked_lm_loss = self._mlm_loss(
                self.sent_encoder, mlm_input_ids, mlm_offsets, mlm_labels)
            if not self.single_encoder:
                mlm_encoder_loss = self._mlm_loss(
                    self.mlm_encoder, mlm_input_ids, mlm_offsets, mlm_labels)
                masked_lm_loss = (masked_lm_loss + mlm_encoder_loss) / 2

        # 2: gloss embeddings via masked mean pooling over the def encoder.
        bs, num_sent, seq_len = def_input_ids.size()
        flat_def_ids = def_input_ids.view(-1, seq_len)
        def_hidden = self.def_encoder(
            flat_def_ids, output_hidden_states=False, return_dict=True).last_hidden_state
        def_embs = self._masked_mean_pool(def_hidden, flat_def_ids)
        def_embs = def_embs.view(bs, num_sent, self.emb_size)

        # 3: target-word embeddings, either h_s (baseline) or h_s - h_c. The
        # sentence input is kept clean (the target word is not masked).
        clean_sent_embs = self.sent_encoder(
            clean_input_ids, output_hidden_states=False, return_dict=True).last_hidden_state
        target_word_offsets = target_word_offsets.view(-1, 1)
        if self.hs_baseline:
            sent_embs = clean_sent_embs
        else:
            device = clean_sent_embs.device
            masked_ids = clean_input_ids.detach().clone().to(device)
            rows = torch.arange(masked_ids.size(0), dtype=torch.long, device=device).unsqueeze(1)
            # Replace each example's target word with the mask token.
            masked_ids[rows, target_word_offsets.to(device)] = self.mask_ids
            masking_encoder = self.sent_encoder if self.single_encoder else self.mlm_encoder
            masked_embs = masking_encoder(
                masked_ids, output_hidden_states=False, return_dict=True).last_hidden_state
            sent_embs = clean_sent_embs - masked_embs

        # 4: contrastive loss. Column i of the (bs, bs) block is example i's
        # positive gloss; the last column is each example's own negative gloss.
        target_word_embs = select_offsets_emb(sent_embs, target_word_offsets)
        target_word_embs = target_word_embs.view(bs, self.emb_size)
        target_def_embs = def_embs[:, 0, :]
        neg_def_embs = def_embs[:, 1, :]

        neg_def_logit = self.cos_sim(target_word_embs, neg_def_embs).view(-1, 1)
        target_def_logit = self.Matrix_cos_sim(target_word_embs, target_def_embs)
        all_def_logits = torch.cat([target_def_logit, neg_def_logit], 1)

        labels = torch.arange(0, bs, device=all_def_logits.device)
        contrasive_loss = self.ce_loss(all_def_logits, labels)

        if masked_lm_loss is None:
            masked_lm_loss = contrasive_loss.new_zeros(())
        total_loss = self.gloss_weight * contrasive_loss + masked_lm_loss * self.mlm_weight

        self.metrics["label_accuracy"](all_def_logits, labels)
        return {
            "logits": all_def_logits,
            "loss": total_loss,
            "mlm_loss": masked_lm_loss,
            "wsd_loss": contrasive_loss,
        }

    def infer_forward(self,
                input_ids,
                return_emb=False,
                 **kwargs,
                ):
        """Run the sentence encoder on ``input_ids``.

        Args:
            input_ids: token ids, presumably ``(bs, seq_len)``; TODO confirm.
            return_emb: if True, return the last hidden states instead of
                MLM logits.
            **kwargs: ignored.

        Returns:
            The encoder's last hidden state when ``return_emb`` is True.
            Otherwise, the MLM-head logits computed from it.
        """
        sent_embs = self.sent_encoder(input_ids, output_hidden_states=False, return_dict=True,)
        sent_embs = sent_embs.last_hidden_state
        if return_emb:
            return sent_embs
        return self.cls(sent_embs)

    def get_metrics(self, reset: bool = False):
        """Return ``{"label_accuracy": ...}`` from the tracked metric, optionally resetting it."""
        return {"label_accuracy": self.metrics["label_accuracy"].get_metric(reset), }

if __name__ == "__main__":
    pass