

from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy
import torch
from torch import nn
import torch.nn.functional as F

from torch.nn.modules.linear import Linear

from transformers import BertModel, BertConfig, BertTokenizer

from transformers.models.bert.modeling_bert import BertOnlyMLMHead

#from BERT_emberder import PretrainedBertEmbedder
#from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions, ModelOutput

try:
    from .utils import select_offsets_emb, get_text_field_mask, sequence_cross_entropy_with_logits
    from .metrics import CategoricalAccuracy
except:
    from utils import select_offsets_emb, get_text_field_mask, sequence_cross_entropy_with_logits
    from metrics import CategoricalAccuracy 



class BertBEM(nn.Module):
    """Bi-encoder scoring a target word in context against candidate definitions.

    ``sent_encoder`` encodes the context sentence; the hidden state at the
    target word's offset is compared, by temperature-scaled cosine similarity,
    with definition embeddings produced by ``def_encoder``.  Training combines a
    contrastive (cross-entropy) loss over those similarities with an auxiliary
    masked-LM loss computed by an MLM head on top of ``sent_encoder``.

    Args:
        sent_encoder_folder: model name or path for ``BertModel.from_pretrained``
            (context encoder; its config also sizes the MLM head).
        def_encoder_folder: model name or path for the definition encoder.
        cos_temp: temperature dividing every cosine similarity.  Cosine scores
            tend to be very close together, so a small value (0.05 by default)
            sharpens the softmax.
        mlm_weight: weight of the masked-LM loss in the total loss
            (0.5 was found to work better than 0.1).
        gloss_weight: weight of the contrastive loss (applied by ``forward``).
    """

    def __init__(self,
                 sent_encoder_folder='bert-base-uncased',
                 def_encoder_folder='bert-base-uncased',
                 cos_temp=0.05,
                 mlm_weight=0.5,
                 gloss_weight=1.0,
                 ):
        super().__init__()
        print(sent_encoder_folder)
        self.sent_encoder = BertModel.from_pretrained(sent_encoder_folder)
        self.def_encoder = BertModel.from_pretrained(def_encoder_folder)

        self.cos_temp = cos_temp
        self.mlm_weight = mlm_weight
        self.gloss_weight = gloss_weight
        self.cos_nn = nn.CosineSimilarity(dim=-1)

        self.cls = BertOnlyMLMHead(self.sent_encoder.config)
        self.ce_loss = nn.CrossEntropyLoss()
        self.metrics = {"label_accuracy": CategoricalAccuracy(), }
        # Token id treated as padding in MLM labels and definition inputs.
        # NOTE(review): 0 is [PAD] in the bert-base-uncased vocab; confirm for
        # any other tokenizer.
        self.pad_ids = 0

    def cos_sim(self, x, y):
        """Return element-wise cosine similarity over the last dim, divided by ``cos_temp``.

        ``x`` and ``y`` must broadcast against each other; the last dimension is
        reduced, so (bs, D) inputs give a (bs,) result.
        """
        return self.cos_nn(x, y) / self.cos_temp

    def Matrix_cos_sim(self, target_word_embs, target_def_embs, device=None):
        """Return the temperature-scaled pairwise cosine-similarity matrix.

        Args:
            target_word_embs: (n, D) tensor.
            target_def_embs: (m, D) tensor on the same device.
            device: unused; kept for backward compatibility.  The result always
                lives on the inputs' device.

        Returns:
            (n, m) tensor whose ``[i, j]`` entry is
            ``cos(target_word_embs[i], target_def_embs[j]) / cos_temp``.
        """
        eps = 1e-10  # keeps the division finite when a row has zero norm
        dots = torch.mm(target_word_embs, target_def_embs.T)
        word_norms = torch.norm(target_word_embs, dim=1).view(-1, 1)
        def_norms = torch.norm(target_def_embs, dim=1).view(1, -1)
        # A scalar eps broadcasts to any (n, m); the old (bs, bs) eps matrix only
        # worked when both inputs had the same number of rows.
        norms = torch.mm(word_norms, def_norms) + eps
        return dots / norms / self.cos_temp

    def _mlm_loss(self, sent_hidden, mlm_labels, mlm_offsets):
        """Return the token-averaged MLM loss, or a zero scalar when ``mlm_labels`` is None.

        Hidden states at ``mlm_offsets`` go through the MLM head; label entries
        equal to ``pad_ids`` get weight 0.
        """
        if mlm_labels is None:
            # Previously the loss variable was left unbound here and the caller
            # crashed when building the total loss.
            return sent_hidden.new_zeros(())
        mlm_logits = self.cls(select_offsets_emb(sent_hidden, mlm_offsets))
        weights = (mlm_labels != self.pad_ids).float()
        return sequence_cross_entropy_with_logits(
            logits=mlm_logits, targets=mlm_labels, weights=weights, average='token')

    @staticmethod
    def _masked_mean(hidden_states, mask):
        """Average ``hidden_states`` (N, L, D) over positions where ``mask`` (N, L) is nonzero.

        Rows whose mask is all zeros produce a zero vector instead of NaN.
        """
        mask = mask.float().unsqueeze(-1)  # (N, L, 1), broadcasts over D
        summed = (hidden_states * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1.0)
        return summed / counts

    def _def_logits(self, target_word_embs, def_embs):
        """Build contrastive logits for each target word.

        Args:
            target_word_embs: (bs, D) context embeddings of the target words.
            def_embs: (bs, num_sent, D) definition embeddings.  Index 0 along
                ``num_sent`` is each example's correct definition; index 1, if
                present, is a hard negative (further entries are not used).

        Returns:
            (bs, bs) logits, or (bs, bs + 1) when a hard negative is present.
            Column ``j < bs`` scores word ``i`` against example ``j``'s correct
            definition, so the other rows' definitions act as in-batch negatives.
            The optional last column scores word ``i`` against its own hard
            negative.  The correct class for row ``i`` is therefore ``i``.
        """
        logits = self.Matrix_cos_sim(target_word_embs, def_embs[:, 0, :])
        if def_embs.size(1) > 1:
            neg_logit = self.cos_sim(target_word_embs, def_embs[:, 1, :]).view(-1, 1)
            logits = torch.cat([logits, neg_logit], dim=1)
        return logits

    def _contrastive_loss(self, sent_hidden, target_word_offsets, def_embs):
        """Compute the contrastive logits and loss, and update the accuracy metric.

        Returns:
            (logits, loss) where logits are as described in ``_def_logits``.
        """
        bs = def_embs.size(0)
        target_word_embs = select_offsets_emb(sent_hidden, target_word_offsets.view(-1, 1))
        target_word_embs = target_word_embs.view(bs, -1)  # (bs, 1, D) -> (bs, D)

        logits = self._def_logits(target_word_embs, def_embs)
        # Use .device rather than get_device(): the latter returns -1 for CPU
        # tensors, which crashed when CUDA was available but inputs were on CPU.
        labels = torch.arange(bs, device=logits.device)
        loss = self.ce_loss(logits, labels)
        self.metrics["label_accuracy"](logits, labels)
        return logits, loss

    def ori_forward(self,
                    sent_input_ids,
                    target_word_offsets,
                    def_input_ids,
                    mlm_labels=None,
                    mlm_offsets=None,
                    **kwargs,
                    ):
        """Original training step: plain mean pooling over all definition tokens.

        Args:
            sent_input_ids: (bs, seq_len) context token ids.
            target_word_offsets: (bs,) position of the target word in each context.
            def_input_ids: (bs, num_sent, seq_len) definition token ids; index 0
                along ``num_sent`` is the answer, index 1 a hard negative.  No
                separate definition labels are needed because the answer is
                always first.
            mlm_labels: optional MLM target ids; when None the MLM loss is 0.
            mlm_offsets: positions (passed to ``select_offsets_emb``) fed to the
                MLM head.
            **kwargs: ignored.

        Returns:
            dict with ``logits``, ``loss`` (contrastive + mlm_weight * MLM; no
            ``gloss_weight`` in this legacy path), ``mlm_loss`` and ``wsd_loss``.
        """
        sent_hidden = self.sent_encoder(
            sent_input_ids, output_hidden_states=False, return_dict=True,
        ).last_hidden_state
        masked_lm_loss = self._mlm_loss(sent_hidden, mlm_labels, mlm_offsets)

        bs, num_sent, seq_len = def_input_ids.size()
        def_hidden = self.def_encoder(
            def_input_ids.view(-1, seq_len), output_hidden_states=False, return_dict=True,
        ).last_hidden_state
        # Mean over every position, padding included (mean pooling performed
        # best in Sentence-BERT experiments).
        def_embs = torch.mean(def_hidden, dim=1).view(bs, num_sent, -1)

        logits, contrastive_loss = self._contrastive_loss(sent_hidden, target_word_offsets, def_embs)
        total_loss = contrastive_loss + masked_lm_loss * self.mlm_weight
        return {
            "logits": logits,
            "loss": total_loss,
            "mlm_loss": masked_lm_loss,
            "wsd_loss": contrastive_loss,
        }

    def forward(self,
                sent_input_ids,
                target_word_offsets,
                def_input_ids,
                mlm_labels=None,
                mlm_offsets=None,
                **kwargs,
                ):
        """Training step using padding-aware ("dynamic") mean pooling of definitions.

        Args:
            sent_input_ids: (bs, seq_len) context token ids.
            target_word_offsets: (bs,) position of the target word in each context.
            def_input_ids: (bs, num_sent, seq_len) definition token ids; index 0
                along ``num_sent`` is the correct definition, index 1 (optional)
                a hard negative.
            mlm_labels: optional MLM target ids; entries equal to ``pad_ids`` are
                ignored.  When None the MLM loss is 0.
            mlm_offsets: positions (passed to ``select_offsets_emb``) fed to the
                MLM head.
            **kwargs: ignored.

        Returns:
            dict with ``logits`` (bs, bs[+1]), ``loss``
            (gloss_weight * contrastive + mlm_weight * MLM), ``mlm_loss`` and
            ``wsd_loss`` (the contrastive loss).
        """
        # Only the top layer is used, so intermediate hidden states are not requested.
        sent_hidden = self.sent_encoder(
            sent_input_ids, output_hidden_states=False, return_dict=True,
        ).last_hidden_state
        masked_lm_loss = self._mlm_loss(sent_hidden, mlm_labels, mlm_offsets)

        bs, num_sent, seq_len = def_input_ids.size()
        flat_def_ids = def_input_ids.view(-1, seq_len)
        # NOTE(review): no attention_mask is passed, so the encoder still attends
        # to padding even though pooling below ignores it. Passing
        # attention_mask=(flat_def_ids != self.pad_ids) would fix that, but it
        # changes the outputs of models trained with this code.
        def_hidden = self.def_encoder(
            flat_def_ids, output_hidden_states=False, return_dict=True,
        ).last_hidden_state
        def_embs = self._masked_mean(def_hidden, flat_def_ids != self.pad_ids)
        def_embs = def_embs.view(bs, num_sent, -1)

        logits, contrastive_loss = self._contrastive_loss(sent_hidden, target_word_offsets, def_embs)
        total_loss = self.gloss_weight * contrastive_loss + masked_lm_loss * self.mlm_weight
        return {
            "logits": logits,
            "loss": total_loss,
            "mlm_loss": masked_lm_loss,
            "wsd_loss": contrastive_loss,
        }

    def infer_forward(self,
                      input_ids,
                      return_emb=False,
                      **kwargs,
                      ):
        """Run only the context encoder.

        Args:
            input_ids: (bs, seq_len) token ids.
            return_emb: if True, return the last hidden states instead of logits.
            **kwargs: ignored.

        Returns:
            (bs, seq_len, D) last hidden states when ``return_emb`` is True,
            otherwise the MLM head's logits for every position.
        """
        sent_embs = self.sent_encoder(input_ids, output_hidden_states=False, return_dict=True,)
        sent_embs = sent_embs.last_hidden_state
        if return_emb:
            return sent_embs
        return self.cls(sent_embs)

    def get_metrics(self, reset: bool = False):
        """Return ``{"label_accuracy": ...}`` from the accuracy metric, optionally resetting it."""
        return {"label_accuracy": self.metrics["label_accuracy"].get_metric(reset), }

if __name__ == "__main__":
    # No command-line entry point is implemented for this module.
    ...