import logging
import math
import os
import sys

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers.modeling_bert import *
import numpy as np

class SpellBert(BertPreTrainedModel):
    """BERT encoder with a linear head that scores every vocabulary token at each position.

    ``forward`` takes a dict-like ``batch`` with keys ``'src_idx'`` (input token ids),
    ``'masks'`` (attention mask), ``'loss_masks'`` (positions counted by the loss, where
    the value equals 1) and optionally ``'tgt_idx'`` (target token ids).
    NOTE(review): id tensors are presumably (batch, seq_len); confirm against the data loader.
    """

    def __init__(self, config):
        super(SpellBert, self).__init__(config)

        self.vocab_size = config.vocab_size
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Projects each hidden state to one logit per vocabulary entry.
        self.classifier = nn.Linear(config.hidden_size, config.vocab_size)

        self.init_weights()

    def tie_cls_weight(self):
        """Make the classifier reuse the word-embedding weight Parameter (bias stays separate)."""
        shared_embedding = self.bert.embeddings.word_embeddings.weight
        self.classifier.weight = shared_embedding

    @staticmethod
    def build_batch(batch, tokenizer):
        """Identity hook: the batch is passed through unchanged; ``tokenizer`` is unused."""
        return batch

    def forward(self, batch):
        """Score every position and, when targets are given, prepend the masked CE loss.

        Returns ``(logits, *extras)`` or ``(loss, logits, *extras)``, where ``extras`` are
        the BERT outputs after the first two (hidden states / attentions if enabled).
        """
        src_ids = batch['src_idx']
        attn_mask = batch['masks']
        token_loss_mask = batch['loss_masks']
        tgt_ids = None
        if 'tgt_idx' in batch:
            tgt_ids = batch['tgt_idx']

        bert_out = self.bert(src_ids, attention_mask=attn_mask)
        hidden = self.dropout(bert_out[0])
        logits = self.classifier(hidden)

        # Drop the pooled output (index 1) but keep any optional extras BERT returned.
        results = (logits, *bert_out[2:])
        if tgt_ids is None:
            return results

        # Flatten positions and keep only those whose loss mask is exactly 1.
        selected = token_loss_mask.view(-1).eq(1)
        flat_logits = logits.view(-1, self.vocab_size)
        flat_targets = tgt_ids.view(-1)
        token_loss = CrossEntropyLoss()(flat_logits[selected], flat_targets[selected])
        return (token_loss, *results)

class ECOPOBert(BertPreTrainedModel):
    """BERT vocabulary classifier trained with masked cross-entropy plus a contrastive
    probability optimization (ECOPO) term.

    ``forward`` takes a dict-like ``batch`` with keys ``'src_idx'`` (input token ids),
    ``'masks'`` (attention mask), ``'loss_masks'`` (positions counted by the CE loss,
    where the value equals 1) and optionally ``'tgt_idx'`` (target token ids).
    The contrastive term indexes ``[row, col]`` pairs, so the id tensors must be 2-D
    (presumably (batch, seq_len) — confirm against the data loader).
    """

    # Number of highest-probability tokens inspected per position by the contrastive loss.
    ecopo_top_k = 5

    def __init__(self, config):
        super(ECOPOBert, self).__init__(config)

        self.vocab_size = config.vocab_size
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Projects each hidden state to one logit per vocabulary entry.
        self.classifier = nn.Linear(config.hidden_size, config.vocab_size)

        self.init_weights()

    def tie_cls_weight(self):
        """Make the classifier reuse the word-embedding weight Parameter (bias stays separate)."""
        self.classifier.weight = self.bert.embeddings.word_embeddings.weight

    @staticmethod
    def build_batch(batch, tokenizer):
        """Identity hook: the batch is passed through unchanged; ``tokenizer`` is unused."""
        return batch

    def _contrastive_probability_loss(self, logits, input_ids, label_ids):
        """Return the mean ECOPO loss over source positions whose token differs from the target.

        For each such position whose top-1 prediction is not the target, the target
        probability and the highest probabilities of other tokens among the top
        ``ecopo_top_k`` are re-normalized with a softmax. The position's loss is the mean of
        ``(negative - positive)`` over those negatives. Positions already predicted
        correctly contribute nothing. Returns a zero scalar (same dtype/device as
        ``logits``) when no position contributes, instead of failing on an empty stack.
        NOTE(review): positions are not filtered by ``loss_masks``; confirm that padding
        and special positions always have ``src_idx == tgt_idx``.
        """
        per_position_losses = []
        for row, col in torch.nonzero(input_ids != label_ids).tolist():
            probs = torch.softmax(logits[row, col], dim=0)
            target = label_ids[row, col].item()
            top_probs, top_index = torch.topk(probs, self.ecopo_top_k)
            top_index = top_index.tolist()
            if top_index[0] == target:
                continue
            # Stack the probability tensors themselves: building a new tensor from them
            # (e.g. torch.tensor(...)) would detach them and make this loss gradient-free.
            candidates = [probs[target]] + [
                top_probs[i] for i, token in enumerate(top_index) if token != target
            ]
            renormalized = torch.softmax(torch.stack(candidates), dim=0)
            # Index 0 is the positive; every other entry is a competing negative.
            per_position_losses.append((renormalized[1:] - renormalized[0]).mean())

        if not per_position_losses:
            return logits.new_zeros(())
        return torch.stack(per_position_losses).mean()

    def forward(self, batch):
        """Score every position and, when targets are given, prepend CE + ECOPO loss.

        Returns ``(logits, *extras)`` or ``(loss, logits, *extras)``, where ``extras`` are
        the BERT outputs after the first two (hidden states / attentions if enabled).
        """
        input_ids = batch['src_idx']
        attention_mask = batch['masks']
        loss_mask = batch['loss_masks']
        label_ids = batch['tgt_idx'] if 'tgt_idx' in batch else None

        outputs = self.bert(input_ids, attention_mask=attention_mask)
        sequence_output = self.dropout(outputs[0])
        logits = self.classifier(sequence_output)

        # Drop the pooled output (index 1) but keep any optional extras BERT returned.
        outputs = (logits,) + outputs[2:]

        if label_ids is not None:
            # Cross-entropy restricted to positions whose loss mask is exactly 1.
            active_loss = loss_mask.view(-1) == 1
            active_logits = logits.view(-1, self.vocab_size)[active_loss]
            active_labels = label_ids.view(-1)[active_loss]
            loss1 = CrossEntropyLoss()(active_logits, active_labels)

            # Contrastive Probability Optimization loss.
            loss2 = self._contrastive_probability_loss(logits, input_ids, label_ids)

            outputs = (loss1 + loss2,) + outputs

        return outputs