# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn.functional as F

from fairseq import metrics, utils, global_var_manager
from fairseq.criterions import FairseqCriterion, register_criterion


def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
    """Compute the label-smoothed negative log-likelihood loss.

    Args:
        lprobs: log-probabilities; the last dimension is the one indexed
            by ``target``.
        target: index tensor. A trailing singleton dimension is appended
            when it has exactly one dimension fewer than ``lprobs``.
        epsilon: label-smoothing weight.
        ignore_index: target value whose positions are zeroed in both loss
            terms. When ``None`` the trailing singleton dimension of both
            terms is squeezed away instead.
        reduce: when true, both terms are summed to scalars.

    Returns:
        ``(loss, nll_loss)`` where
        ``loss = (1 - epsilon) * nll_loss + epsilon / V * smooth_loss``,
        ``V = lprobs.size(-1)`` and ``smooth_loss`` is the negated sum of
        ``lprobs`` over the last dimension.
    """
    if lprobs.dim() - target.dim() == 1:
        target = target.unsqueeze(-1)

    # Negative log-probability of the reference index, and of all indices.
    nll_loss = lprobs.gather(dim=-1, index=target).neg()
    smooth_loss = lprobs.sum(dim=-1, keepdim=True).neg()

    if ignore_index is None:
        nll_loss, smooth_loss = nll_loss.squeeze(-1), smooth_loss.squeeze(-1)
    else:
        padding = target.eq(ignore_index)
        nll_loss.masked_fill_(padding, 0.)
        smooth_loss.masked_fill_(padding, 0.)

    if reduce:
        nll_loss, smooth_loss = nll_loss.sum(), smooth_loss.sum()

    uniform_weight = epsilon / lprobs.size(-1)
    loss = (1. - epsilon) * nll_loss + uniform_weight * smooth_loss
    return loss, nll_loss


def at_nat_kl_loss(lprobs, target, nat_output, label_pred_match, label_pred_sim,
                   ignore_index=None, reduce=True, kl_strategy="uniform"):
    """Similarity-weighted KL term pulling ``exp(lprobs)`` towards ``nat_output``.

    Each element contributes
    ``scale * nat_output * (log(nat_output) - lprobs)`` where ``scale`` is a
    per-position weight: the positional ``factor`` (zeroed at the first and
    last position) multiplied by ``label_pred_sim`` thresholded at 0.5
    (values ``<= 0.5`` become 0).

    Entries of ``nat_output`` that are exactly 0 contribute 0 (the usual
    ``0 * log 0 = 0`` convention) instead of turning the loss into NaN.

    Args:
        lprobs: log-probabilities, (B, L, V).
        target: reference indices; only read to build the padding mask.
        nat_output: probabilities (not log-probabilities), (B, L, V).
        label_pred_match: boolean tensor, (B, L).
        label_pred_sim: per-position similarity, multiplied with the
            (1, L) positional factor -- presumably (B, L); confirm with caller.
        ignore_index: padding index used when building the mask.
        reduce: when true, return the sum over all elements.
        kl_strategy: ``"linear"`` weights position ``i`` by ``i / (L - 1)``;
            any other value uses a weight of 1.

    Returns:
        A scalar when ``reduce`` is true, otherwise the element-wise loss
        with the broadcast shape of the inputs.
    """
    _, len_x, _ = lprobs.size()
    if ignore_index is not None:
        non_pad_mask = target.ne(ignore_index)
        target_mask = label_pred_match & non_pad_mask
    else:
        target_mask = label_pred_match
    # target_mask -> (B, L)
    # NOTE(review): the mask is only used below as the dtype/device template
    # for `factor`; it is not applied to the loss, so padding positions are
    # weighted purely by `label_pred_sim`. Confirm this is intentional.
    target_mask = target_mask.float()

    if kl_strategy == "linear":
        # For L == 1 the division yields NaN, but that single position is
        # overwritten with 0 just below.
        factor = (torch.arange(0, len_x, dtype=torch.float) / (len_x - 1)).unsqueeze_(0).to(target_mask)
    else:  # kl_strategy == "uniform":
        factor = torch.ones(len_x, dtype=torch.float).unsqueeze_(0).to(target_mask)
    # The first and last positions never contribute.
    factor[:, 0] = 0
    factor[:, -1] = 0

    # kl_scaling_factor gains a trailing singleton dim to broadcast over V.
    kl_scaling_factor = (factor * F.threshold(label_pred_sim, 0.5, 0.0)).unsqueeze(-1).float()

    # log(0) = -inf would make 0 * (-inf) = NaN; take the log of 1 at exact
    # zeros so those entries contribute exactly 0. Non-zero entries are
    # untouched.
    zero_prob = nat_output.eq(0)
    log_nat_output = torch.log(nat_output.masked_fill(zero_prob, 1.0))

    kl_loss = kl_scaling_factor * nat_output * (log_nat_output - lprobs)
    if reduce:
        kl_loss = kl_loss.sum()
    return kl_loss


def at_cmlm_kl_loss(lprobs, target, nat_probs, target_mask, reduce=True, kl_strategy="uniform", label_pred_sim=None):
    """Strategy-weighted KL term pulling ``exp(lprobs)`` towards ``nat_probs``.

    Each element contributes
    ``scale * nat_probs * (log(nat_probs) - lprobs)`` with
    ``scale = target_mask * factor``. The weights and ``log(nat_probs)`` are
    computed under ``torch.no_grad()``.

    Entries of ``nat_probs`` that are exactly 0 contribute 0 (the usual
    ``0 * log 0 = 0`` convention) instead of turning the loss into NaN.

    Args:
        lprobs: log-probabilities, (B, L, V).
        target: unused; kept for interface compatibility.
        nat_probs: probabilities (not log-probabilities) with the shape of
            ``lprobs``.
        target_mask: (B, L) mask of positions that contribute.
        reduce: when true, return the sum over all elements.
        kl_strategy: how the per-position ``factor`` is built:
            ``"uniform"`` -- 1 everywhere;
            ``"linear"`` -- position ``i`` gets ``i / (L - 1)`` (0 when
            ``L == 1``);
            ``"at_emb_sim"`` -- ``(at_sim + 1) / 2``;
            ``"higher_golden_probs"`` / ``"higher_emb_sim"`` -- 1 where
            ``nat_sim > at_sim``, else 0;
            ``"emb_sim_diff"`` -- ``nat_sim - at_sim`` where positive, else 0;
            ``"nat_higher_at_nat_higher_threshold"`` -- 1 where
            ``nat_sim > at_sim`` and ``nat_sim > 0.7``, else 0;
            ``"diff_on_nat_argmax"`` -- first minus second element of
            ``label_pred_sim`` where positive, else 0.
        label_pred_sim: pair of tensors required by every strategy except
            ``"uniform"`` and ``"linear"`` -- ``(at_sim, nat_sim)``, or
            ``(nat_argmax_probs, at_prob_at_nat_argmax)`` for
            ``"diff_on_nat_argmax"``.

    Returns:
        A scalar when ``reduce`` is true, otherwise the element-wise loss.

    Raises:
        NotImplementedError: for an unknown ``kl_strategy``.
    """
    _, len_x, _ = lprobs.size()
    with torch.no_grad():
        # target_mask -> (B, L)
        target_mask = target_mask.float()
        if kl_strategy == "linear":
            # Guard the denominator: a length-1 sequence would otherwise
            # produce 0 / 0 = NaN weights.
            denom = max(len_x - 1, 1)
            factor = (torch.arange(0, len_x, dtype=torch.float) / denom).unsqueeze_(0).to(target_mask)
        elif kl_strategy == "at_emb_sim":
            # label_pred_sim -> (B, L)
            at_label_pred_sim, nat_label_pred_sim = label_pred_sim
            factor = (at_label_pred_sim + 1) / 2
        elif kl_strategy == "higher_golden_probs" or kl_strategy == "higher_emb_sim":
            at_label_pred_sim, nat_label_pred_sim = label_pred_sim
            factor = torch.gt(nat_label_pred_sim, at_label_pred_sim).float()
        elif kl_strategy == "emb_sim_diff":
            at_label_pred_sim, nat_label_pred_sim = label_pred_sim
            factor = F.threshold(nat_label_pred_sim - at_label_pred_sim, 0, 0)
        elif kl_strategy == "nat_higher_at_nat_higher_threshold":
            at_label_pred_sim, nat_label_pred_sim = label_pred_sim
            factor1 = torch.gt(nat_label_pred_sim, at_label_pred_sim).float()
            factor2 = torch.gt(nat_label_pred_sim, 0.7).float()
            factor = factor1 * factor2
        elif kl_strategy == "diff_on_nat_argmax":
            nat_argmax_probs, at_prob_at_nat_argmax = label_pred_sim
            factor = F.threshold(nat_argmax_probs - at_prob_at_nat_argmax, 0., 0.)
        elif kl_strategy == "uniform":
            factor = torch.ones(len_x, dtype=torch.float).unsqueeze_(0).to(target_mask)
        else:
            raise NotImplementedError(
                "unsupported kl_strategy: {!r}".format(kl_strategy)
            )

        # kl_scaling_factor -> (B, L, 1)
        kl_scaling_factor = (target_mask * factor).unsqueeze(-1).float()
        # log(0) = -inf would make 0 * (-inf) = NaN; take the log of 1 at
        # exact zeros so those entries contribute exactly 0. Non-zero
        # entries are untouched.
        zero_prob = nat_probs.eq(0)
        nat_lprobs = torch.log(nat_probs.masked_fill(zero_prob, 1.0))

    kl_loss = kl_scaling_factor * nat_probs * (nat_lprobs - lprobs)
    if reduce:
        kl_loss = kl_loss.sum()
    return kl_loss


def at_cmlm_hidden_mse_loss(lprobs, target, nat_probs, at_output_hiddens, nat_output_hiddens, target_mask, reduce=True, kl_strategy="uniform", label_pred_sim=None):
    """Strategy-weighted MSE between the last AT and NAT hidden states.

    For every position the squared error between the two hidden vectors is
    averaged over the hidden dimension and multiplied by
    ``target_mask * factor``, so positions outside the mask contribute
    nothing.

    Args:
        lprobs: unused; kept for interface compatibility.
        target: unused; kept for interface compatibility.
        nat_probs: unused; kept for interface compatibility.
        at_output_hiddens: sequence of hidden-state tensors; only the last
            one is used. It is transposed on dims 0 and 1 and then read as
            (B, L, H), i.e. stored as (L, B, H).
        nat_output_hiddens: same layout as ``at_output_hiddens``.
        target_mask: (B, L) mask of positions that contribute.
        reduce: when true, return the sum over all positions.
        kl_strategy: how the per-position ``factor`` is built:
            ``"uniform"`` -- 1 everywhere;
            ``"linear"`` -- position ``i`` gets ``i / (L - 1)`` (0 when
            ``L == 1``);
            ``"at_emb_sim"`` -- ``(at_sim + 1) / 2``;
            ``"higher_golden_probs"`` / ``"higher_emb_sim"`` -- 1 where
            ``nat_sim > at_sim``, else 0;
            ``"emb_sim_diff"`` -- ``nat_sim - at_sim`` where positive, else 0;
            ``"nat_higher_at_nat_higher_threshold"`` -- 1 where
            ``nat_sim > at_sim`` and ``nat_sim > 0.7``, else 0.
        label_pred_sim: ``(at_sim, nat_sim)`` pair required by every
            strategy except ``"uniform"`` and ``"linear"``.

    Returns:
        A scalar when ``reduce`` is true, otherwise a (B, L, 1) tensor of
        weighted per-position errors.

    Raises:
        NotImplementedError: for an unknown ``kl_strategy``.
    """
    at_hiddens = at_output_hiddens[-1].transpose(0, 1)
    nat_hiddens = nat_output_hiddens[-1].transpose(0, 1)
    _, len_x, _ = at_hiddens.size()
    with torch.no_grad():
        # target_mask -> (B, L)
        target_mask = target_mask.float()
        if kl_strategy == "linear":
            # Guard the denominator: a length-1 sequence would otherwise
            # produce 0 / 0 = NaN weights.
            denom = max(len_x - 1, 1)
            factor = (torch.arange(0, len_x, dtype=torch.float) / denom).unsqueeze_(0).to(target_mask)
        elif kl_strategy == "at_emb_sim":
            # label_pred_sim -> (B, L)
            at_label_pred_sim, nat_label_pred_sim = label_pred_sim
            factor = (at_label_pred_sim + 1) / 2
        elif kl_strategy == "higher_golden_probs" or kl_strategy == "higher_emb_sim":
            at_label_pred_sim, nat_label_pred_sim = label_pred_sim
            factor = torch.gt(nat_label_pred_sim, at_label_pred_sim).float()
        elif kl_strategy == "emb_sim_diff":
            at_label_pred_sim, nat_label_pred_sim = label_pred_sim
            factor = F.threshold(nat_label_pred_sim - at_label_pred_sim, 0, 0)
        elif kl_strategy == "nat_higher_at_nat_higher_threshold":
            at_label_pred_sim, nat_label_pred_sim = label_pred_sim
            factor1 = torch.gt(nat_label_pred_sim, at_label_pred_sim).float()
            factor2 = torch.gt(nat_label_pred_sim, 0.7).float()
            factor = factor1 * factor2
        elif kl_strategy == "uniform":
            factor = torch.ones(len_x, dtype=torch.float).unsqueeze_(0).to(target_mask)
        else:
            raise NotImplementedError(
                "unsupported kl_strategy: {!r}".format(kl_strategy)
            )

    # scaling_factor -> (B, L, 1)
    scaling_factor = (target_mask * factor).unsqueeze(-1).float()
    # F.mse_loss defaults to reduction='mean', which collapses every position
    # into one scalar and makes the mask a mere multiplier of a global
    # average. Keep the error per position and average over the hidden
    # dimension only, so the mask really selects positions.
    per_position_mse = F.mse_loss(at_hiddens, nat_hiddens, reduction='none').mean(dim=-1, keepdim=True)
    mse_loss = scaling_factor * per_position_mse
    if reduce:
        mse_loss = mse_loss.sum()
    return mse_loss


@register_criterion('my_at_criterion')
class MyATCriterion(FairseqCriterion):
    """Label-smoothed cross-entropy criterion with optional auxiliary terms.

    Depending on ``mode``, ``forward`` returns the plain label-smoothed loss,
    that loss plus a weighted KL term, or only a KL / hidden-state MSE term
    computed from outputs supplied by the caller.
    """

    def __init__(self, task, sentence_avg, label_smoothing):
        super().__init__(task)
        self.sentence_avg = sentence_avg
        self.eps = label_smoothing

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        # fmt: off
        parser.add_argument(
            '--label-smoothing',
            default=0.,
            type=float,
            metavar='D',
            help='epsilon for label smoothing, 0 means no label smoothing',
        )
        # fmt: on

    def forward(self, model, sample, reduce=True, mode="at",
                at_net_output=None, nat_output=None, nat_output_hidden=None, label_pred_match=None, kl_strategy='uniform', kl_loss_coef=1.0, label_pred_sim=None, target_mask=None):
        """Compute the loss for the given sample in the requested ``mode``.

        Modes:
            ``"at"``: run the model and return
                ``(loss, sample_size, logging_output)``.
            ``"at-return_output"``: as ``"at"``, with the model output
                appended as a fourth tuple element.
            ``"at-nat_smoothed"``: run the model and add
                ``kl_loss_coef`` times the KL term from
                ``compute_at_nat_kl_loss``; needs ``nat_output`` and
                ``label_pred_match``.
            ``"at-cmlm_kl"``: return only the KL term computed from the
                supplied ``at_net_output``; the model is not run.
            ``"at-cmlm_mse"``: return only the hidden-state MSE term
                computed from the supplied ``at_net_output``; the model is
                not run.

        In the three-/four-tuple modes the elements are:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training

        Raises:
            ValueError: for any other ``mode``.
        """
        # Auxiliary-only modes: a bare loss tensor computed from outputs the
        # caller already has.
        if mode == "at-cmlm_kl":
            assert at_net_output is not None and nat_output is not None and target_mask is not None
            return self.compute_at_cmlm_kl_loss(
                model, at_net_output, sample, nat_output=nat_output,
                target_mask=target_mask, reduce=reduce,
                kl_strategy=kl_strategy, label_pred_sim=label_pred_sim,
            )
        if mode == "at-cmlm_mse":
            assert at_net_output is not None and nat_output is not None and target_mask is not None and nat_output_hidden is not None
            return self.compute_at_cmlm_hidden_loss(
                model, at_net_output, sample, nat_output=nat_output,
                nat_output_hidden=nat_output_hidden,
                target_mask=target_mask, reduce=reduce,
                kl_strategy=kl_strategy, label_pred_sim=label_pred_sim,
            )

        with_kl = mode == "at-nat_smoothed"
        if not with_kl and mode not in ("at", "at-return_output"):
            raise ValueError("AT criterion mode error!!!")

        # Every remaining mode starts with a forward pass of the model.
        net_output = model(**sample['net_input'])
        if with_kl:
            assert nat_output is not None and label_pred_match is not None
            smoothed_loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
            kl_term = self.compute_at_nat_kl_loss(
                model, net_output, sample, nat_output=nat_output,
                label_pred_match=label_pred_match, reduce=reduce,
                kl_strategy=kl_strategy, label_pred_sim=label_pred_sim,
            )
            loss = smoothed_loss + kl_loss_coef * kl_term
        else:
            loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)

        sample_size = sample['target'].size(0) if self.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': loss.data,
            'nll_loss': nll_loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
        }

        if mode == "at-return_output":
            return loss, sample_size, logging_output, net_output
        return loss, sample_size, logging_output

    def compute_loss(self, model, net_output, sample, reduce=True):
        """Return ``(loss, nll_loss)`` of the label-smoothed NLL over flattened positions."""
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        flat_lprobs = lprobs.view(-1, lprobs.size(-1))
        flat_target = model.get_targets(sample, net_output).view(-1, 1)
        return label_smoothed_nll_loss(
            flat_lprobs,
            flat_target,
            self.eps,
            ignore_index=self.padding_idx,
            reduce=reduce,
        )

    def compute_at_nat_kl_loss(self, model, net_output, sample, nat_output, label_pred_match, label_pred_sim,
                               reduce=True, kl_strategy="uniform"):
        """Return the ``at_nat_kl_loss`` term for the model's log-probabilities."""
        #TODO
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        target = model.get_targets(sample, net_output)
        return at_nat_kl_loss(
            lprobs,
            target,
            nat_output,
            label_pred_match,
            label_pred_sim,
            ignore_index=self.padding_idx,
            reduce=reduce,
            kl_strategy=kl_strategy,
        )

    def compute_at_cmlm_kl_loss(self, model, at_net_output, sample, nat_output, target_mask, reduce=True,
                                kl_strategy="uniform", label_pred_sim=None):
        """Return the ``at_cmlm_kl_loss`` term using the log-softmax of ``at_net_output["logits"]``."""
        lprobs = F.log_softmax(at_net_output["logits"], dim=-1)
        target = sample["target"]
        return at_cmlm_kl_loss(
            lprobs,
            target,
            nat_output,
            target_mask,
            reduce=reduce,
            kl_strategy=kl_strategy,
            label_pred_sim=label_pred_sim,
        )

    def compute_at_cmlm_hidden_loss(self, model, at_net_output, sample, nat_output, nat_output_hidden, target_mask, reduce=True,
                                kl_strategy="uniform", label_pred_sim=None):
        """Return the ``at_cmlm_hidden_mse_loss`` term using ``at_net_output['output_hiddens']``."""
        target = sample["target"]
        lprobs = F.log_softmax(at_net_output["logits"], dim=-1)
        at_output_hiddens = at_net_output['output_hiddens']
        return at_cmlm_hidden_mse_loss(
            lprobs,
            target,
            nat_output,
            at_output_hiddens,
            nat_output_hidden,
            target_mask,
            reduce=reduce,
            kl_strategy=kl_strategy,
            label_pred_sim=label_pred_sim,
        )

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""

        def total(key):
            # Sum one logged field over all workers, treating a missing key as 0.
            return sum(entry.get(key, 0) for entry in logging_outputs)

        loss_sum = total('loss')
        nll_loss_sum = total('nll_loss')
        ntokens = total('ntokens')
        sample_size = total('sample_size')

        # Dividing by log(2) converts the summed losses from nats to bits.
        metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
        metrics.log_scalar('nll_loss', nll_loss_sum / ntokens / math.log(2), ntokens, round=3)
        metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['nll_loss'].avg))

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
