import math

from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.criterions.my_at_criterion import MyATCriterion
from fairseq.criterions.cross_entropy import CrossEntropyCriterion
from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
from fairseq.criterions.nat_loss import LabelSmoothedDualImitationCriterion


@register_criterion('my_criterion')
class MyCriterion(FairseqCriterion):
    """Criterion that routes loss computation to an AT or a NAT sub-criterion.

    Two sub-criteria are built at construction time: ``at_criterion``
    (``MyATCriterion``) and ``nat_criterion``
    (``LabelSmoothedDualImitationCriterion``). Which one is used is decided
    either by the instance-level ``mode`` (see :attr:`criterion`) or by the
    per-call ``mode`` argument of :meth:`forward`.

    NOTE(review): "AT"/"NAT" presumably stand for autoregressive and
    non-autoregressive translation -- confirm against the model code.
    """

    # Instance-level modes served by each sub-criterion. Tuples are used so
    # that membership is decided with ``==``, like an explicit comparison chain.
    _AT_MODES = (
        'at',
        'joint-at',
        'joint-at-only',
        'joint2-at',
        'joint2-at-only',
    )
    _NAT_MODES = (
        'nat',
        'joint-nat',
        'joint-nat-only',
        'joint2-nat',  # was misspelled 'joint2-at', which made this mode an error
        'joint2-nat-only',
    )

    def __init__(self, task, sentence_avg, label_smoothing, mode):
        """Store the configuration and build both sub-criteria.

        Args:
            task: task object, passed to the base class and both sub-criteria.
            sentence_avg: passed through to ``MyATCriterion``.
            label_smoothing: passed through to both sub-criteria.
            mode: instance-level mode string; see :attr:`criterion` for the
                accepted values.
        """
        super().__init__(task)
        self.sentence_avg = sentence_avg
        self.label_smoothing = label_smoothing
        self.mode = mode
        self.at_criterion = MyATCriterion(task, sentence_avg, label_smoothing)
        self.nat_criterion = LabelSmoothedDualImitationCriterion(task, label_smoothing)

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        # fmt: off
        # NOTE(review): the label-smoothing option is presumably registered by
        # the NAT criterion below, which would be why the AT side only pulls in
        # the plain cross-entropy arguments -- confirm before changing.
        CrossEntropyCriterion.add_args(parser)
        LabelSmoothedDualImitationCriterion.add_args(parser)
        # fmt: on

    @property
    def criterion(self):
        """Return the sub-criterion selected by ``self.mode``.

        Returns:
            ``self.at_criterion`` for the AT modes, ``self.nat_criterion`` for
            the NAT modes.

        Raises:
            ValueError: if ``self.mode`` is not one of the known modes.
        """
        if self.mode in self._AT_MODES:
            return self.at_criterion
        if self.mode in self._NAT_MODES:
            return self.nat_criterion
        raise ValueError("Mode Error!!!")

    def forward(self, model, sample, reduce=True, mode=None, nat_reduction='sum',
                nat_output=None, label_pred_match=None, kl_strategy='uniform', kl_loss_coef=1.0, label_pred_sim=None, target_mask=None):
        """Compute the loss for ``sample`` with the criterion chosen by ``mode``.

        Args:
            model: the model. With ``mode=None`` it is handed to the selected
                criterion as is; with ``'at'`` / ``'at-nat_smoothed'`` its
                ``at_transformer`` attribute is used, with ``'nat'`` its
                ``nat_transformer`` attribute.
            sample: the batch, forwarded unchanged.
            reduce: forwarded unchanged to the selected criterion.
            mode: ``None`` (use :attr:`criterion`), ``'at'``, ``'nat'`` or
                ``'at-nat_smoothed'``.
            nat_reduction: forwarded to the NAT criterion in ``'nat'`` mode only.
            nat_output, label_pred_match: required (non-``None``) in
                ``'at-nat_smoothed'`` mode; forwarded to the AT criterion.
            kl_strategy, kl_loss_coef, label_pred_sim: forwarded to the AT
                criterion in ``'at-nat_smoothed'`` mode only.
            target_mask: accepted for call-site compatibility; not used here.

        Returns:
            The ``(loss, sample_size, logging_output)`` triple produced by the
            selected criterion.

        Raises:
            ValueError: if ``mode`` (or ``self.mode`` when ``mode`` is
                ``None``) is not recognised.
            AssertionError: if ``nat_output`` or ``label_pred_match`` is
                missing in ``'at-nat_smoothed'`` mode.
        """
        if mode is None:
            # Fall back to the instance-level mode.
            loss, sample_size, logging_output = self.criterion(model, sample, reduce)
        elif mode == 'at':
            loss, sample_size, logging_output = self.at_criterion(
                model.at_transformer, sample, reduce
            )
        elif mode == 'nat':
            loss, sample_size, logging_output = self.nat_criterion(
                model.nat_transformer, sample, reduce, nat_reduction
            )
        elif mode == 'at-nat_smoothed':
            assert nat_output is not None and label_pred_match is not None, \
                "'at-nat_smoothed' mode requires nat_output and label_pred_match"
            loss, sample_size, logging_output = self.at_criterion(
                model.at_transformer,
                sample,
                reduce,
                mode=mode,
                nat_output=nat_output,
                label_pred_match=label_pred_match,
                kl_strategy=kl_strategy,
                kl_loss_coef=kl_loss_coef,
                label_pred_sim=label_pred_sim,
            )
        else:
            raise ValueError("Mode Error!!!")

        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training.

        Sums the ``"loss"``, ``"ntokens"`` and ``"sample_size"`` entries
        (missing keys count as 0) and logs ``loss``, optionally ``nll_loss``,
        and a derived ``ppl``. There is no guard against a zero
        ``sample_size`` / ``ntokens`` total.
        """
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        # Dividing by ln(2) rescales a natural-log quantity to base 2.
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if sample_size != ntokens:
            # The loss was not normalised per token, so also log a per-token
            # value and derive perplexity from that one.
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
            )
        else:
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True