import torch
from data_loader import MAX_NUM_VECTORS, get_labelWords
import logging
import numpy as np
import torch.nn.functional as F
from torch import nn

# Module-level logger for this file.
logger = logging.getLogger(__name__)


class MLM_model_inferencer:
    """Run a masked-LM over paired (main, dummy) prompts.

    :meth:`mask_filling` either returns training losses over both prompt
    batches or scores the predictions made at each example's mask slot.
    """

    def __init__(self, model, tokenizer, common_vocab, alpha):
        """
        Args:
            model: masked-LM called as ``model(input_ids=..., attention_mask=...,
                [labels=...], return_dict=False)``; must expose
                ``config.vocab_size``. Put into eval mode here.
            tokenizer: tokenizer used for token <-> id conversion.
            common_vocab: lookup indexed by the vocab ids in ``index_list``,
                giving the token string reported in top-k predictions.
            alpha: weight of the (negated) dummy-prompt loss in training.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.model.eval()
        # self.add_new_token()
        self.k = 5  # number of top-k predictions reported per example
        self.vocab = common_vocab
        self.alpha = alpha
        self.predict2labelWords = get_labelWords(tokenizer)

    def get_new_token(self, vid):
        """Return the placeholder token ``[V<vid>]``; ``vid`` must be in 1..MAX_NUM_VECTORS."""
        assert (vid > 0 and vid <= MAX_NUM_VECTORS)
        return '[V%d]' % (vid)

    # def add_new_token(self):
    #     new_tokens = [self.get_new_token(i + 1) for i in range(MAX_NUM_VECTORS)]
    #     self.tokenizer.add_tokens(new_tokens)
    #     ebd = self.model.resize_token_embeddings(len(self.tokenizer))
    #     logger.info('# vocab after adding new tokens: %d' % len(self.tokenizer))

    def assign_embedding(self, new_token, token_emb):
        """Overwrite the input-embedding row of ``new_token`` with ``token_emb``.

        NOTE(review): this writes in place into the embedding weight; if that
        weight is a leaf parameter with ``requires_grad=True``, PyTorch rejects
        the in-place op unless grad mode is disabled — confirm intended usage.
        """
        token_id = self.tokenizer.convert_tokens_to_ids([new_token])[0]
        self.model.embeddings.word_embeddings.weight[token_id] = token_emb

    def mask_filling(self, token_probs, attention_mask, label_ids, masked_indices,
                     dummy_token_probs, dummy_att_mask, dummy_masked_indices,
                     predict_list, mode, device,
                     filter_indices=None,
                     index_list=None):
        """Run the MLM on the requested prompts and return losses or scores.

        Args:
            token_probs: main-prompt inputs, passed to the model as
                ``input_ids``; dims 0 and 1 are used as batch size and
                sequence length. NOTE(review): presumably (batch, seq, vocab)
                soft-token probabilities for a model that accepts them — confirm.
            attention_mask: attention mask for ``token_probs``.
            label_ids: per-example gold labels; ``label_ids[i][0]`` is the gold
                vocab id of example ``i``.
            masked_indices: per-example mask position in the main prompt
                (a single position per example is assumed).
            dummy_token_probs, dummy_att_mask, dummy_masked_indices: the same
                three inputs for the dummy prompts.
            predict_list: unused; kept for interface compatibility.
            mode: ``"training"`` returns losses over both batches; ``"eval"``
                scores the main prompts; any other value scores the dummy prompts.
            device: device the label tensors are moved to.
            filter_indices: optional 1-D LongTensor of vocab ids that
                predictions are restricted to; must be on CPU because it
                indexes CPU log-probabilities.
            index_list: maps a position in the filtered vocab to a vocab id
                (presumably the same ids as ``filter_indices``); required
                whenever ``filter_indices`` is given.

        Returns:
            ``[loss, pred_loss, info_loss]`` in training mode; otherwise the
            tuple ``(loss, log_probs, cor, tot, preds, topk, common_vocab_loss)``.
        """
        batch_size, length = token_probs.size(0), token_probs.size(1)
        dummy_batch_size, dummy_len = dummy_token_probs.size(0), dummy_token_probs.size(1)

        # Labels are -100 (ignored by the loss) everywhere except the mask slot.
        mlm_labels = torch.full((batch_size, length), -100, dtype=torch.long)
        dummy_mlm_labels = torch.full((dummy_batch_size, dummy_len), -100, dtype=torch.long)
        # NOTE(review): assumes the dummy batch has the same size as the main batch.
        for i in range(batch_size):
            mlm_labels[i, masked_indices[i]] = label_ids[i]
            dummy_mlm_labels[i, dummy_masked_indices[i]] = label_ids[i]
        mlm_labels = mlm_labels.to(device)
        dummy_mlm_labels = dummy_mlm_labels.to(device)

        if mode == "training":
            input_ids = torch.cat([token_probs, dummy_token_probs], dim=0)
            attention_mask = torch.cat([attention_mask, dummy_att_mask], dim=0)
            outputs = self.model(input_ids=input_ids,
                                 attention_mask=attention_mask,
                                 return_dict=False)
            logits = outputs[0]
            vocab_size = self.model.config.vocab_size

            loss_fct = nn.CrossEntropyLoss()  # -100 index = padding token
            pred_loss = loss_fct(logits[:batch_size].view(-1, vocab_size),
                                 mlm_labels.view(-1))
            # Negated: minimising `loss` (alpha > 0) raises the dummy-prompt
            # cross-entropy on the gold label.
            info_loss = -loss_fct(logits[batch_size:].view(-1, vocab_size),
                                  dummy_mlm_labels.view(-1))
            loss = pred_loss + self.alpha * info_loss
            return [loss, pred_loss, info_loss]

        # Evaluation: score either the main prompts ("eval") or the dummy prompts.
        if mode == "eval":
            inputs, att_mask, labels, positions = (
                token_probs, attention_mask, mlm_labels, masked_indices)
        else:
            inputs, att_mask, labels, positions = (
                dummy_token_probs, dummy_att_mask, dummy_mlm_labels, dummy_masked_indices)

        with torch.no_grad():
            loss, logits = self.model(input_ids=inputs,
                                      attention_mask=att_mask,
                                      labels=labels,
                                      return_dict=False)
            log_probs = F.log_softmax(logits, dim=-1).cpu()

        cor, tot, preds, topk, common_vocab_loss = self._score_predictions(
            log_probs, logits, positions, label_ids, filter_indices, index_list)
        return loss, log_probs, cor, tot, preds, topk, common_vocab_loss

    def _score_predictions(self, log_probs, logits, positions, label_ids,
                           filter_indices, index_list):
        """Compute accuracy, top-k tokens and common-vocab NLL at each mask slot.

        ``positions[i]`` is the mask position of example ``i`` inside
        ``log_probs``/``logits``. Correctness is judged against the gold id
        ``label_ids[i][0]``, so main and dummy prompts are scored the same way.

        Returns:
            ``(cor, tot, preds, topk, common_vocab_loss)``.
        """
        vocab_to_common_vocab = None
        if index_list is not None:
            vocab_to_common_vocab = {idx: cid for cid, idx in enumerate(index_list)}

        tot = log_probs.shape[0]
        cor = 0
        preds = []
        topk = []
        common_vocab_loss = []

        for i in range(tot):
            masked_index = positions[i]
            log_prob = log_probs[i, masked_index]
            mlm_label = label_ids[i][0].item()

            if filter_indices is not None:
                # Restrict the distribution to the common vocab, then map back
                # to original vocab ids via index_list.
                log_prob = log_prob.index_select(dim=0, index=filter_indices)
                pred = index_list[int(torch.argmax(log_prob))]

                topk_log_prob, topk_ids = torch.topk(log_prob, self.k)
                topk.append([
                    {'token': self.vocab[index_list[int(idx)]], 'log_prob': lp.item()}
                    for lp, idx in zip(topk_log_prob, topk_ids)
                ])

                # Negative log-likelihood of the gold label over the common vocab.
                common_logits = logits[i][masked_index].cpu().index_select(dim=0, index=filter_indices)
                common_nll = -F.log_softmax(common_logits, dim=-1)
                common_label_id = vocab_to_common_vocab[mlm_label]
                common_vocab_loss.append(common_nll[common_label_id].item())
            else:
                pred = torch.argmax(log_prob)
                topk.append([])

            if int(pred) == mlm_label:
                cor += 1
                preds.append(1)
            else:
                preds.append(0)

        return cor, tot, preds, topk, common_vocab_loss
