import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers.modeling_outputs import BaseModelOutput

class Similarity(nn.Module):
    """
    Temperature-scaled cosine similarity.

    ``forward(x, y)`` returns ``cos(x, y) / temp``, where the cosine is taken
    along the last dimension (leading dimensions broadcast the way
    ``nn.CosineSimilarity`` broadcasts them).
    """

    def __init__(self, temp=1.0):
        super().__init__()
        self.temp = temp
        self.cos = nn.CosineSimilarity(dim=-1)

    def forward(self, x, y):
        cosine = self.cos(x, y)
        scaled = cosine / self.temp
        return scaled

class MLPLayer(nn.Module):
    """
    Head for getting sentence representations over RoBERTa/BERT's CLS representation.

    A single ``hidden_size -> hidden_size`` linear layer followed by a
    non-linearity. Weights are Xavier-uniform initialised and the bias is zero.

    :param hidden_size: input and output feature size.
    :param activation: ``'relu'`` or ``'gelu'`` (case-insensitive).
    :raises ValueError: if ``activation`` is not one of the supported names.
    """

    # Supported activation names (lower-case) -> module class.
    _ACTIVATIONS = {
        'relu': nn.ReLU,
        'gelu': nn.GELU,
    }

    def __init__(self, hidden_size, activation='relu'):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        try:
            activation_cls = self._ACTIVATIONS[activation.lower()]
        except KeyError:
            # Raise instead of print()+exit(): library code must not terminate
            # the whole process, and callers can catch/report a ValueError.
            raise ValueError(
                f"error activation: unsupported activation {activation!r}; "
                f"expected one of {sorted(self._ACTIVATIONS)}"
            ) from None
        self.activation = activation_cls()

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform init for the dense weight, zero for its bias."""
        nn.init.xavier_uniform_(self.dense.weight)
        nn.init.constant_(self.dense.bias, 0.)

    def forward(self, features):
        """Apply ``activation(dense(features))``; the last dim must be ``hidden_size``."""
        x = self.dense(features)
        x = self.activation(x)
        return x


class FFN(nn.Module):
    """
    Two stacked residual sub-blocks, each computing ``x + LayerNorm(MLP(x))``.

    The MLPs are ``MLPLayer(hidden_size, activation)``; the feature size is
    unchanged from input to output.
    """

    def __init__(self, hidden_size, activation='relu'):
        super().__init__()
        # Creation order (mlp1, ln1, mlp2, ln2) is kept so parameter
        # initialisation consumes the RNG in the same sequence.
        self.mlp1 = MLPLayer(hidden_size, activation=activation)
        self.ln1 = nn.LayerNorm(hidden_size)
        self.mlp2 = MLPLayer(hidden_size, activation=activation)
        self.ln2 = nn.LayerNorm(hidden_size)

    def forward(self, x):
        sub_blocks = ((self.mlp1, self.ln1), (self.mlp2, self.ln2))
        for mlp, layer_norm in sub_blocks:
            x = x + layer_norm(mlp(x))
        return x



def get_symm_kl(noised_logits, input_logits):
    """
    Symmetric KL divergence between two sets of logits.

    Both logits are turned into distributions with a float32 softmax over the
    last dimension. The result is
    ``(KL(p_input || p_noised) + KL(p_noised || p_input)) / noised_logits.size(0)``,
    where each KL is summed over every element before the division.
    """
    def _summed_kl(log_side_logits, prob_side_logits):
        # F.kl_div(log q, p) computes sum p * (log p - log q) = KL(p || q).
        return F.kl_div(
            F.log_softmax(log_side_logits, dim=-1, dtype=torch.float32),
            F.softmax(prob_side_logits, dim=-1, dtype=torch.float32),
            reduction="sum",
        )

    kl_input_vs_noised = _summed_kl(noised_logits, input_logits)
    kl_noised_vs_input = _summed_kl(input_logits, noised_logits)
    return (kl_input_vs_noised + kl_noised_vs_input) / noised_logits.size(0)


def get_symm_kl_masked(noised_logits, input_logits, mask):
    """
    Masked symmetric KL divergence, averaged over the active mask positions.

    Both logits are turned into distributions with a float32 softmax over the
    last dimension. Each position's contribution to
    ``KL(p_input || p_noised) + KL(p_noised || p_input)`` is weighted by
    ``mask`` (broadcast over the last, class dimension), summed, and divided
    by ``mask.sum()``.

    :param mask: weights broadcastable to the logits' shape without the last
        dimension; positions with weight 0 contribute nothing.
    :return: scalar tensor. If the mask has no active position the result is
        0 instead of NaN.
    """
    weight = mask.unsqueeze(-1).float()

    def _masked_summed_kl(log_side_logits, prob_side_logits):
        # F.kl_div(log q, p) sums p * (log p - log q). Scaling p by the mask
        # zeroes masked positions, because kl_div treats a 0 target as 0.
        return F.kl_div(
            F.log_softmax(log_side_logits, dim=-1, dtype=torch.float32),
            F.softmax(prob_side_logits, dim=-1, dtype=torch.float32) * weight,
            reduction="sum",
        )

    total = (
        _masked_summed_kl(noised_logits, input_logits)
        + _masked_summed_kl(input_logits, noised_logits)
    )
    num_active = torch.sum(mask).float()
    # An all-zero mask would give 0 / 0 = NaN and poison the loss. The
    # numerator is already 0 then, so dividing by 1 returns a clean 0.
    denom = torch.where(num_active > 0, num_active, torch.ones_like(num_active))
    return total / denom


def detach(x):
    """
    Return ``x`` cut from its autograd graph as a new leaf with ``requires_grad=True``.

    The result shares ``x``'s storage. Gradients computed with respect to it
    stop there and do not flow back into whatever produced ``x``.
    """
    return x.detach().requires_grad_(True)


def norm(x):
    """
    L2-normalise ``x`` along its last dimension.

    The norm is offset by 1e-12, so an all-zero vector maps to zeros instead
    of NaN.
    """
    length = torch.norm(x, dim=-1, keepdim=True)
    return x / (length + 1e-12)

def one_hot(indice, num_classes):
    """
    One-hot encode integer indices.

    Each index selects a row of the ``num_classes x num_classes`` identity
    matrix. The output has shape ``indice.shape + (num_classes,)``, uses the
    default floating dtype, and lives on ``indice``'s device.
    """
    identity = torch.eye(num_classes, device=indice.device)
    return identity[indice]



class Hidden2Discrete(nn.Module):
    """
    Linear projection to the logits of a ``k_size``-way categorical variable.

    :param input_size: size of the incoming feature vectors.
    :param k_size: number of categories (the projection's output size).
    :param has_bias: whether the projection has a bias term.
    """

    def __init__(self, input_size, k_size, has_bias=True):
        super(Hidden2Discrete, self).__init__()
        self.k_size = k_size
        self.p_h = nn.Linear(input_size, k_size, bias=has_bias)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform weight; zero bias when the layer has one."""
        nn.init.xavier_uniform_(self.p_h.weight)
        # With has_bias=False the Linear bias is None. Initialising it
        # unconditionally used to crash the constructor.
        if self.p_h.bias is not None:
            nn.init.constant_(self.p_h.bias, 0.)

    def forward(self, inputs):
        """
        :param inputs: batch_size x input_size
        :return: ``(logits, log_qy)``: the raw projection and its log-softmax
            over dim 1, each batch_size x k_size.
        """
        logits = self.p_h(inputs)
        # NOTE(review): dim=1 is the category axis only for 2-D inputs.
        log_qy = F.log_softmax(logits, dim=1)
        return logits, log_qy


class Hidden2Gaussian(nn.Module):
    # # https://github.com/snakeztc/NeuralDialog-LaRL/blob/master/latent_dialog/nn_lib.py#L53
    """
    Project features to the mean and log-variance of a diagonal Gaussian.

    :param input_size: size of the incoming feature vectors.
    :param output_size: dimensionality of the Gaussian.
    :param has_bias: whether both projections have a bias term.
    :param mu_factor: stored as ``self.mu_factor``; ``forward`` does not use it.
    :param logvar_factor: stored as ``self.logvar_factor``; ``forward`` does not use it.
    """

    def __init__(self, input_size, output_size, has_bias=True, mu_factor=1.0, logvar_factor=1.0):
        super(Hidden2Gaussian, self).__init__()

        self.mu = nn.Linear(input_size, output_size, bias=has_bias)
        self.logvar = nn.Linear(input_size, output_size, bias=has_bias)

        # forward() in this class does not use these, but they are kept so
        # existing attribute access and the module structure stay the same.
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

        self.mu_factor = mu_factor
        self.logvar_factor = logvar_factor

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform weights; zero biases when the layers have them."""
        # With has_bias=False each Linear bias is None. Initialising it
        # unconditionally used to crash the constructor.
        for layer in (self.mu, self.logvar):
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0.)

    def forward(self, inputs):
        """
        :param inputs: features whose last dimension is ``input_size``.
        :return: ``(mu, logvar)``, each with last dimension ``output_size``.
            ``logvar`` is the raw linear output and is not clamped or squashed.
        """
        mu = self.mu(inputs)

        logvar = self.logvar(inputs)
        return mu, logvar


def gaussian_logprob(mu, logvar, sample_z):
    """
    Element-wise log-density of ``sample_z`` under a diagonal Gaussian.

    :param mu: mean.
    :param logvar: natural-log variance.
    :param sample_z: points at which the density is evaluated.
    :return: ``-0.5*log(2*pi) - 0.5*logvar - (mu - sample_z)**2 / (2*exp(logvar))``,
        broadcast over the inputs, with no reduction.
    """
    variance = torch.exp(logvar)
    log_norm_const = float(-0.5 * np.log(2 * np.pi))
    squared_dist = torch.pow((mu - sample_z), 2)
    return log_norm_const - 0.5 * logvar - squared_dist / (2.0 * variance)


def replaceMatrixEye(sim_matrix, pos_sim):
    """
    Return a copy of a square similarity matrix with its diagonal replaced.

    :param sim_matrix: ``b x b`` matrix whose off-diagonal entries are kept.
    :param pos_sim: values for the diagonal, presumably shape ``[b]``
        (confirm against caller); entry ``i`` goes to ``(i, i)``.
    :return: new ``b x b`` matrix. The inputs are not modified.
    """
    n = sim_matrix.size(0)
    eye = torch.eye(n, device=pos_sim.device)
    # Scaling rows of the identity by pos_sim puts pos_sim[i] at (i, i).
    diagonal_part = eye * pos_sim.unsqueeze(-1)
    off_diagonal_part = sim_matrix.masked_fill(eye == 1, 0)
    return diagonal_part + off_diagonal_part
    


class LabelSmoothing(nn.Module):
    # https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Classification/ConvNets/image_classification/smoothing.py
    """
    NLL loss with label smoothing.

    Per position: ``confidence * (-log p[target]) + smoothing * mean_c(-log p[c])``
    with ``confidence = 1 - smoothing``. The result is averaged over all
    positions.
    """
    def __init__(self, smoothing=0.0):
        """
        Constructor for the LabelSmoothing module.
        :param smoothing: label smoothing factor
        """
        super(LabelSmoothing, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, x, target):
        """
        :param x: logits with classes on the last dimension, shape ``[..., C]``.
        :param target: integer class indices, shape ``x.shape[:-1]``.
        :return: scalar loss averaged over every leading position.
        """
        logprobs = torch.nn.functional.log_softmax(x, dim=-1)

        # Index along the class (last) axis. unsqueeze(1)/squeeze(1) only lined
        # up for 2-D [N, C] logits; with [B, T, C] the old code silently read
        # timestep 0 for every position.
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(-1))
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()


def sample(input_ids, attention_mask, perturbed_enc, model, max_sample_len, sample_seq_nums=1, do_sample=True, min_length=5):
    """
    Decode replies from an encoder-decoder ``model`` given precomputed encoder states.

    Instead of calling ``model.generate``, this wraps ``perturbed_enc`` as the
    encoder output and drives the transformers generation helpers directly:
    ``_prepare_decoder_input_ids_for_generation``, ``_get_stopping_criteria``,
    ``_get_logits_processor``, ``_get_logits_warper``,
    ``_expand_inputs_for_generation`` and ``sample`` / ``greedy_search``.
    NOTE(review): these are private, version-specific transformers APIs whose
    signatures have changed across releases. Pin and confirm the version.

    :param input_ids: handed to ``_prepare_decoder_input_ids_for_generation``.
        The decoder start ids it returns replace it for the rest of the function.
    :param attention_mask: forwarded to decoding as a model kwarg (presumably
        the encoder mask, ``[b, t]``; confirm against caller).
    :param perturbed_enc: used as ``last_hidden_state`` of the encoder output.
    :param model: encoder-decoder model exposing the helper methods above.
    :param max_sample_len: ``max_length`` for both the stopping criteria and
        the logits processor.
    :param sample_seq_nums: expansion factor applied to the inputs. It is used
        only on the ``do_sample`` branch; the greedy branch does not expand.
    :param do_sample: if true, sample with temperature 0.8, top-k 1000 and
        top-p 1.0; otherwise decode greedily.
    :param min_length: ``min_length`` passed to the logits processor.
    :return: ``(next_reply_ids, word_scores, dec_hiddens)``:
        ``next_reply_ids`` is ``output.sequences``, with the start token at
        position 0. ``word_scores`` is the log-softmax of each step's score,
        gathered at the token generated at that step. ``dec_hiddens`` is the
        last decoder layer's hidden state from every step, concatenated along
        dim 1.
    """
    # https://huggingface.co/transformers/_modules/transformers/generation_utils.html#GenerationMixin.generate

    # Precomputed (possibly perturbed) encoder states stand in for running the encoder.
    encoder_outputs = BaseModelOutput(
        last_hidden_state=perturbed_enc
    )

    # Extra keyword arguments forwarded to the decoding loop below.
    model_kwargs = {
        'attention_mask': attention_mask,
        'encoder_outputs': encoder_outputs
    }
    '''
    {
        attention_mask: [b, t],
        encoder_outputs: {
            last_hidden_state: [b, t, d],
            hidden_states: None,
            attentions: None
        }
    }
    '''
    ###

    # Hard-coded special-token ids. NOTE(review): confirm these match the
    # model/tokenizer config; decoder start and bos share id 1 here.
    pad_token_id = 0
    decoder_start_token_id = 1
    bos_token_id = 1
    eos_token_id = 2

    # From here on input_ids holds the decoder start ids, not the encoder input.
    input_ids = model._prepare_decoder_input_ids_for_generation(
        input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id
    )


    # Stopping is measured from the current decoder length.
    cur_len = input_ids.shape[-1]
    stopping_criteria = model._get_stopping_criteria(
        max_length=max_sample_len, 
        max_time=None,
        max_new_tokens=None,
        start_length=cur_len
        )
    # stopping_criteria = None

    # # get distribution pre_processing samplers
    # NOTE(review): encoder_input_ids receives the decoder ids (input_ids was
    # reassigned above). This is presumably inert because
    # encoder_no_repeat_ngram_size=0; verify before enabling that option.
    logits_processor = model._get_logits_processor(
        repetition_penalty=1.0,
        no_repeat_ngram_size=0,
        encoder_no_repeat_ngram_size=0,
        encoder_input_ids=input_ids,
        bad_words_ids=None,
        min_length=min_length,
        max_length=max_sample_len,
        eos_token_id=eos_token_id,
        forced_bos_token_id=None,
        forced_eos_token_id=None,
        prefix_allowed_tokens_fn=None,
        num_beams=1,
        num_beam_groups=1,
        diversity_penalty=0.0,
        remove_invalid_values=None,
    )
    # logits_processor = None


    if do_sample:
        # get probability distribution warper
        logits_warper = model._get_logits_warper(
            top_k=1000, top_p=1.0, temperature=0.8, num_beams=1
        )

        # logits_warper = model._get_logits_warper(
        #     top_k=100, top_p=1.0, temperature=0.8, num_beams=1
        # )

        # it may prepare for future multiple sample 
        # Repeats the decoder ids and model_kwargs sample_seq_nums times.
        input_ids, model_kwargs = model._expand_inputs_for_generation(
            input_ids,
            expand_size=sample_seq_nums,
            is_encoder_decoder=True,
            **model_kwargs,
        )

        output = model.sample(
                        input_ids,
                        logits_processor=logits_processor,
                        logits_warper=logits_warper,
                        stopping_criteria=stopping_criteria,
                        pad_token_id=pad_token_id,
                        eos_token_id=eos_token_id,
                        output_scores=True,
                        return_dict_in_generate=True,
                        synced_gpus=False,
                        output_hidden_states=True,
                        **model_kwargs,
                    )
    else:
        # Greedy path: no warper and no input expansion (sample_seq_nums is ignored).
        output = model.greedy_search(
            input_ids,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            output_scores=True,
            return_dict_in_generate=True,
            synced_gpus=False,
            output_hidden_states=True,
            **model_kwargs,
        )

    next_reply_ids = output.sequences # [b, t`+1] index0 is the bos

    # NOTE(review): output.scores are presumably the processed/warped scores
    # (after min_length, temperature and top-k), not the raw model logits, so
    # these log-probs reflect that distribution. Confirm for the pinned version.
    output_scores = torch.stack(output.scores) # [t`, b, v]
    lsf = nn.LogSoftmax(dim=-1)
    output_scores = lsf(output_scores)

    # Drop the start token so each step's scores line up with the token it produced.
    word_scores = torch.gather(output_scores.transpose(0, 1), dim=-1, index=next_reply_ids[:,1:].unsqueeze(-1)).squeeze(-1)  # [b, t`]

    # hs is one decoding step's per-layer tuple; hs[-1] is the last layer
    # (presumably [b, 1, d] per step; confirm).
    last_hiddens = [hs[-1]  for hs in output.decoder_hidden_states]
    dec_hiddens = torch.cat(last_hiddens, dim=1) # [b, t`, d]
    return next_reply_ids, word_scores, dec_hiddens