from transformers import BertTokenizer, BertModel, BertForMaskedLM, RobertaForMaskedLM
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import pdb
import torch
from torch import nn
import numpy as np
import json 
from tqdm import tqdm 
from model_exp2 import BertBEM

from models.tokenization import Tokenization as BertTokenization 
import sys 
GR_model_code_path="../GR_models/models"
    ### the path of folder where the model is defined

sys.path.append(GR_model_code_path)
from Roberta_tokenization import Tokenization as RobertaTokenization 
 


def text2model_input(use_roberta=False):
    """Build the tokenizer and the matching text-to-tensor indexer.

    Args:
        use_roberta: when true a ``RobertaTokenization`` is created, otherwise
            a ``BertTokenization(lower_case=True)``.

    Returns:
        A ``(tokenizer, input2ids)`` pair.  ``input2ids(word_lst,
        target_word_idx, candidates=None)`` tokenizes one word list and returns
        a dict whose values are ``torch.long`` tensors, each wrapped in one
        extra leading dimension of size 1.
    """
    if use_roberta:
        tokenizer = RobertaTokenization()
    else:
        tokenizer = BertTokenization(lower_case=True)

    def batch2tensor(dict_data):
        # Wrapping every value in a list adds the leading dimension of size 1.
        return {k: torch.tensor([v], dtype=torch.long) for k, v in dict_data.items()}

    def input2ids(word_lst, target_word_idx, candidates=None):
        # A fresh list per call replaces the former shared mutable default
        # (``candidates=[]``), so nothing can leak from one call to the next.
        if candidates is None:
            candidates = []
        if use_roberta:
            # The RoBERTa path does not take candidates.
            out = tokenizer.bert_tokenization(word_lst, target_word_idx)
            # Alias the target offsets under the 'offsets' key as well.
            out['offsets'] = out['target_word_offset']
        else:
            out = tokenizer(word_lst, target_word_idx, candidates, word_answer=None)
        return batch2tensor(out)

    return tokenizer, input2ids


class BERTwsd(nn.Module):
    """Wrap a ``BertBEM`` encoder and score substitutes for one target word.

    ``forward`` returns two score vectors: the masked-LM logits produced by the
    sentence encoder, and the cosine similarity between the target word's input
    embedding and every row of the word-embedding matrix, divided by ``temp``.
    """

    def __init__(self,
                 ckpt_fp=None,
                 temp=0.1,
                 use_roberta=False,
                 add_hc=True,
                 hs_baseline=False,
                 single_encoder=False,
                 ):
        """Create the encoder and optionally load weights.

        Args:
            ckpt_fp: ``None`` or ``'bert-base-uncased'`` loads nothing extra;
                ``'roberta-base'`` copies the weights of
                ``RobertaForMaskedLM.from_pretrained('roberta-base')`` into the
                encoder; any other value is treated as a path given to
                ``torch.load`` (mapped to CPU).
            temp: divisor applied to the embedding cosine similarities.  A
                smaller value yields larger-magnitude similarity logits.
                Values noted by the original authors: 0.1 for unsupervised
                BERT, 1/15 (~0.07) for wsd_bert, 0.25 for unsupervised RoBERTa.
            use_roberta: selects ``'roberta-base'`` instead of
                ``'bert-base-uncased'`` for both encoder folders.
            add_hc: accepted for compatibility; not used by this class.
            hs_baseline: forwarded to ``BertBEM``.
            single_encoder: forwarded to ``BertBEM``.

        Raises:
            RuntimeError: when loading ``'roberta-base'`` and some renamed
                parameter has no counterpart in the encoder.
        """
        super().__init__()

        # The sentence and definition encoders always share the same folder.
        encoder_folder = 'roberta-base' if use_roberta else 'bert-base-uncased'
        self.encoder = BertBEM(
            sent_encoder_folder=encoder_folder,
            def_encoder_folder=encoder_folder,
            use_roberta=use_roberta,
            hs_baseline=hs_baseline,
            single_encoder=single_encoder,
        )

        if ckpt_fp == 'roberta-base':
            self._load_roberta_pretrained(ckpt_fp)
        elif ckpt_fp is not None and ckpt_fp != 'bert-base-uncased':
            self.encoder.load_state_dict(torch.load(ckpt_fp, map_location=torch.device('cpu')))

        self.temp = temp
        self.sim = nn.CosineSimilarity(dim=1, eps=1e-6)

    def _load_roberta_pretrained(self, ckpt_fp):
        """Copy pretrained RoBERTa masked-LM weights into ``self.encoder``."""
        pretrained = RobertaForMaskedLM.from_pretrained(ckpt_fp).state_dict()
        model_dict = self.encoder.state_dict()
        # Rename the pretrained parameters to the names used by the encoder.
        renamed = {
            name.replace("roberta.", "sent_encoder.").replace("lm_head", "cls").strip(): tensor
            for name, tensor in pretrained.items()
        }
        unknown = [name for name in renamed if name not in model_dict]
        if unknown:
            # Fail with one clear message instead of stopping in a debugger;
            # a strict load_state_dict would reject these keys anyway.
            raise RuntimeError(
                "Pretrained parameters without a counterpart in the encoder: "
                + ", ".join(unknown)
            )
        model_dict.update(renamed)
        self.encoder.load_state_dict(model_dict)

    def forward(self, input_ids=None, offsets=None, word_ids=None, **params):
        """Return ``(mlm_logits, emb_dot_logits)`` for the first batch item.

        Args:
            input_ids: token ids passed to the sentence encoder.
            offsets: accepted for call compatibility; not used here.
            word_ids: vocabulary ids of the target word's sub-tokens; only the
                first one is used.
            **params: extra keyword arguments are ignored.

        Returns:
            ``mlm_logits``: output of ``encoder.cls`` for batch item 0.
            ``emb_dot_logits``: cosine similarity of the target embedding to
            every word embedding, divided by ``self.temp``.
        """
        encoded = self.encoder.sent_encoder(input_ids, output_hidden_states=False, return_dict=True)
        hidden_states = encoded.last_hidden_state

        # Keep only the first (and in this script, only) batch item.
        mlm_logits = self.encoder.cls(hidden_states)[0]

        emb_weights = self.encoder.sent_encoder.embeddings.word_embeddings.weight
        # The first sub-token embedding stands for the whole target word.
        target_word_emb = emb_weights[word_ids][0]
        emb_dot_logits = self.sim(target_word_emb.view([1, self.encoder.emb_size]), emb_weights)
        emb_dot_logits = emb_dot_logits / self.temp
        return mlm_logits, emb_dot_logits


def sents2candidates(use_roberta, use_emb=True, ckpt=None,
                     hs_baseline=False,
                     single_encoder=False):
    """Build the model once and return a function that proposes substitutes.

    Args:
        use_roberta: use the RoBERTa tokenizer/backbone instead of BERT.
        use_emb: when true, the context probabilities are multiplied by the
            softmax of the embedding-similarity scores before ranking.
        ckpt: checkpoint argument forwarded to ``BERTwsd`` as ``ckpt_fp``.
        hs_baseline: forwarded to ``BERTwsd``.
        single_encoder: forwarded to ``BERTwsd``.

    Returns:
        ``gen_emb(sent, target_word_idx=-1, top_k=20, filter_proba=0.05,
        word="")`` which returns ``(candidates, probabilities)``.  With the
        BERT tokenizer ``candidates`` is one space-joined string; with RoBERTa
        it is a list of stripped strings.  ``probabilities`` is a list of
        floats in ranking order.
    """
    tokenizer, indexer = text2model_input(use_roberta)
    # Both backbones are built with the same temperature of 1/15; a higher
    # temperature reduces the effect of the embedding term.
    model = BERTwsd(use_roberta=use_roberta, temp=1/15, ckpt_fp=ckpt,
                    hs_baseline=hs_baseline,
                    single_encoder=single_encoder)
    # This function is inference-only: switch off dropout and similar
    # training-time behaviour so predictions are deterministic.
    model.eval()

    def proba_list_to_words(probas, top_k):
        # Renormalise, then keep the top_k entries in descending order.
        probas = probas / torch.sum(probas, dim=-1)
        top_probas, top_ids = torch.topk(probas, top_k)
        top_probas, top_ids = top_probas.tolist(), top_ids.tolist()

        # Threshold of 0 retained from the original filter.
        kept_probas = [p for p in top_probas if p >= 0]
        top_ids = top_ids[:len(kept_probas)]
        if use_roberta:
            word_pred_text = [tokenizer.decode(token_id).strip() for token_id in top_ids]
        else:
            word_pred_text = " ".join([tokenizer.decode([token_id]) for token_id in top_ids])
        return word_pred_text, [float(p) for p in kept_probas]

    def gen_emb(sent, target_word_idx=-1,
                top_k=20,  # 50 was used in the experiments for fair comparison, but 20 is ok
                filter_proba=0.05,  # not used
                word=""):  # not used
        words = sent.strip().split(' ')
        inputs = indexer(words, target_word_idx)
        offsets = inputs['offsets'].tolist()[0]
        # Vocabulary ids of the target word's sub-tokens.
        inputs['word_ids'] = inputs['input_ids'][0, inputs['offsets'][0]]

        # No gradients are needed, so skip building the autograd graph.
        with torch.no_grad():
            logits, emb_dot_logits = model(**inputs)
            # Score the position of the target word's first sub-token.
            target_idx = offsets[0]
            target_probas = torch.nn.functional.softmax(logits[target_idx], dim=-1)
            if use_emb:
                proba_emb_dot = torch.nn.functional.softmax(emb_dot_logits, dim=0)
                target_probas = target_probas * proba_emb_dot
        return proba_list_to_words(target_probas, top_k)

    return gen_emb


def load_pickle(data_path):
    """Load and return the object stored in the pickle file at ``data_path``.

    The file is opened in a ``with`` block so the handle is closed as soon as
    loading finishes (or fails) instead of being left open.

    NOTE: unpickling can execute arbitrary code; only load trusted files.
    """
    import pickle
    with open(data_path, "rb") as handle:
        return pickle.load(handle)
def write_data(data, fp):
    """Write the strings in ``data`` to ``fp``, separated by newlines.

    An existing file at ``fp`` is removed first.  No trailing newline is
    written.  Returns ``None``.
    """
    import os

    file_already_there = os.path.exists(fp)
    if file_already_there:
        os.remove(fp)

    with open(fp, "w") as sink:
        joined = "\n".join(data)
        sink.write(joined)
    return None


def main(in_fp, out_fp,
         use_MASK=False,
         double_MASK=False,
         use_roberta=False,
         use_emb=True,
         ckpt=None,
         top_k=20,
         single_encoder=False,
         hs_baseline=False,

         remove_punct=False,
         force_lower=True,
         ):
    """Predict substitutes for every record of ``in_fp`` and write ``out_fp``.

    Each record's first five fields are unpacked as (word, pos, left context,
    target word, right context); a record of exactly six fields also carries a
    sixth value that is stored under ``"answer"``.  One JSON object per record
    is written, records being processed in sorted key order.

    Args:
        in_fp: path of the pickled input data.
        out_fp: path of the output file (JSON objects separated by newlines).
        use_MASK: replace the target word by ``"[MASK]"`` in the sentence.
        double_MASK: append ``" [SEP] "`` plus a copy of the sentence whose
            target is ``"[MASK]"``, and point the target index at that mask.
        use_roberta, use_emb, ckpt, single_encoder, hs_baseline: forwarded to
            ``sents2candidates``.
        top_k: number of candidates requested per record.
        remove_punct: apply ``remove_punc`` to the three text fields.
        force_lower: lower-case the three text fields.
    """
    predict = sents2candidates(use_roberta, use_emb, ckpt=ckpt,
                               single_encoder=single_encoder, hs_baseline=hs_baseline)
    records = load_pickle(in_fp)
    json_lines = []

    for idx in tqdm(sorted(records)):
        record = records[idx]
        word, pos, left, target, right = record[:5]
        gold = record[5] if len(record) == 6 else None

        # Normalise missing or blank context to the empty string.
        if left is None or not left.strip():
            left = ""
        if right is None:
            right = ""

        if remove_punct:
            # NOTE(review): `remove_punc` is neither defined nor imported in
            # the visible part of this file -- confirm where it comes from.
            left, target, right = [remove_punc(text) for text in (left, target, right)]
        if force_lower:
            left, target, right = left.lower(), target.lower(), right.lower()

        # The target's word index equals the number of words to its left.
        word_idx = len(left.strip().split(" ")) if left else 0

        # Snapshot taken before the BERT-only lower-casing and the masking.
        ori_sent_lst = [left, target, right]
        if not use_roberta:
            left, target, right = left.lower(), target.lower(), right.lower()

        if use_MASK:
            target = "[MASK]"
        left_ctx, right_ctx = left.strip(), right.strip()
        sent = " ".join([left_ctx, target.strip(), right_ctx]).strip()

        if double_MASK:
            # Skip the whole first sentence plus the [SEP] token.
            word_idx += len(sent.strip().split(" ")) + 1
            masked_copy = " ".join([left_ctx, "[MASK]", right_ctx]).strip()
            sent = sent + " [SEP] " + masked_copy

        out_str, out_probas = predict(sent, target_word_idx=word_idx,
                                      word=record[3].strip(), top_k=top_k)
        info = {
            "idx": idx,
            "word": word,
            "pos": pos,
            "sent_lst": ori_sent_lst,
            "sent": sent,
            "candidates": out_str,
            "candidate_probas": out_probas,
            "answer": gold,
        }
        json_lines.append(json.dumps(info))

    write_data(json_lines, out_fp)
    return

 
### Run configuration for the LS07 / LS14 data sets.

## Pickled LS07 or LS14 data set to read.
in_fp = "/XXX/lexsub_test.pk"
## Output file for the model predictions.
out_fp = "/XXX/test07._model_setting_.json"
## Checkpoint of the trained GR model.
baseline_ckpt = "XXX/XXX.ckpt"

## use_emb=False outputs the GR model predictions directly;
## use_emb=True applies the (+emb) post-processing.
## For more details about (+emb) see the paper
## "https://aclanthology.org/2020.coling-main.107/".
use_emb = False

## True uses the RoBERTa backbone, False uses BERT.
use_roberta = True

## Number of candidates to output; 50 was used in our experiments.
top_k = 20

## Parameters describing the model structure:
##   separate context-encoder structure: single_encoder=False, hs_baseline=False
##   multi-task GR model:                single_encoder=True,  hs_baseline=True
single_encoder = False
hs_baseline = False

## Whether to lower-case the input text.  The training data of the GR model is
## all lower-case, so this is set to True.
force_lower = True


main(in_fp, out_fp, use_roberta=use_roberta, use_emb=use_emb, ckpt=baseline_ckpt, top_k=top_k,
     single_encoder=single_encoder,
     hs_baseline=hs_baseline,
     # Pass the configured flag (previously a hard-coded True) so that editing
     # `force_lower` above actually takes effect.
     force_lower=force_lower,
     )



