# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import torch
import numpy as np

from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import (
    LabelSmoothedCrossEntropyCriterion,
    label_smoothed_nll_loss
)

@register_criterion('multi_loss_label_smoothed_cross_entropy')
class MultiLossLabelSmoothedCrossEntropyCriterion(LabelSmoothedCrossEntropyCriterion):
    """Label-smoothed cross entropy computed per task and combined by weight.

    Every target sequence is assigned to a task by the token found at column
    ``task_token_position`` of the target tensor; that token is compared with
    the ``src_dict`` index of ``'[<task_token>]'``. Each task has its own
    label-smoothing epsilon and its own loss weight. The weights are either
    the static ``task_loss_weights`` or, when ``loss_k`` is given, a
    sigmoid schedule driven by the training update number.
    """

    def __init__(self, task, sentence_avg, label_smoothing: list, task_tokens: list,
                 task_loss_weights: list, task_token_position: int,
                 loss_k=None, loss_t0=None):
        """Store the per-task configuration and validate its consistency.

        Args:
            task: the task object; its ``src_dict`` is used to look up the
                task tokens.
            sentence_avg: if true, the sample size is the number of sentences
                instead of the number of tokens.
            label_smoothing: one smoothing epsilon per entry of ``task_tokens``.
            task_tokens: task names; a sequence belongs to task ``t`` when its
                task token equals the dictionary entry ``'[t]'``.
            task_loss_weights: static weight per task, used when ``loss_k`` is
                ``None``.
            task_token_position: column of the target tensor holding the task
                token.
            loss_k: steepness of the dynamic weight schedule, or ``None`` to
                use the static weights.
            loss_t0: update number at which the dynamic schedule starts to
                move; required whenever ``loss_k`` is given.

        Raises:
            AssertionError: if the list lengths disagree, or ``loss_k`` is
                given without ``loss_t0``.
        """
        # The parent only supports a single epsilon; it is overridden with the
        # full per-task list right below.
        super().__init__(task, sentence_avg, label_smoothing[0])
        self.sentence_avg = sentence_avg
        self.eps = label_smoothing
        self.task_tokens = task_tokens
        self.task_loss_weights = task_loss_weights
        self.task_token_position = task_token_position
        self.loss_k = loss_k
        self.loss_t0 = loss_t0

        assert len(label_smoothing) == len(task_tokens), \
            "label_smoothing and task_tokens must have same length"
        if self.loss_k is None:
            assert len(task_loss_weights) == len(task_tokens), \
                "task_loss_weights and task_tokens must have same length"
        else:
            assert self.loss_t0 is not None, "loss_k and loss_t0 must be specified simultaneously"

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--label-smoothing', default=[0.], type=float, nargs='+',
                            help='epsilon for label smoothing, 0 means no label smoothing')
        parser.add_argument('--task-tokens', default=['summ'], type=str, nargs='+')
        parser.add_argument('--task-loss-weights', default=[1.], type=float, nargs='+')
        parser.add_argument('--task-token-position', type=int, default=1,
                            help="the index of task token in the target token")
        parser.add_argument('--loss-k', type=float, default=None,
                            help="k for dynamic weight, control decay speed")
        parser.add_argument('--loss-t0', type=int, default=None,
                            help="t0 for dynamic weight, control the starting step of decay")
        # fmt: on

    def forward(self, model, sample, reduce=True, update_num=None):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training

        In addition to the aggregate entries, the logging output carries
        ``<task>_loss``, ``<task>_nll_loss`` and ``<task>_w`` for every task.
        """
        net_output = model(**sample['net_input'])
        loss, nll_loss, losses, nll_losses, loss_weights = self.compute_loss(
            model, net_output, sample, reduce=reduce, update_num=update_num)
        sample_size = sample['target'].size(0) if self.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': loss.data,
            'nll_loss': nll_loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size
        }
        for (task_token, l, nl, w) in zip(self.task_tokens, losses, nll_losses, loss_weights):
            logging_output["{}_loss".format(task_token)] = l.data
            logging_output["{}_nll_loss".format(task_token)] = nl.data
            logging_output["{}_w".format(task_token)] = w

        return loss, sample_size, logging_output

    def _get_loss_weights(self, update_num=None):
        """Return the per-task loss weights for the given update number.

        * ``loss_k`` is ``None``: the static ``task_loss_weights``.
        * ``loss_k`` is set but ``update_num`` is unknown: ``[1, 1]``.
        * otherwise: ``[1 - s, s]`` with
          ``s = sigmoid(loss_k * max(0, update_num - loss_t0))``, i.e. both
          weights are 0.5 up to ``loss_t0`` and then move towards ``[0, 1]``
          for positive ``loss_k``.

        The dynamic schedule only produces two weights, so it covers the
        first two entries of ``task_tokens``.
        """
        if self.loss_k is None:
            return self.task_loss_weights
        if update_num is None:
            return [1, 1]
        interval = max(0, update_num - self.loss_t0)
        weight2 = 1 / (1 + np.exp(-self.loss_k * interval))
        weight1 = 1.0 - weight2
        return [weight1, weight2]

    def compute_loss(self, model, net_output, sample, reduce=True, update_num=None):
        """Compute the weighted sum of the per-task label-smoothed losses.

        Returns a tuple ``(loss, nll_loss, losses, nll_losses, loss_weights)``
        where ``loss``/``nll_loss`` are the weighted sums over tasks,
        ``losses``/``nll_losses`` are the unweighted per-task values in the
        order of ``task_tokens``, and ``loss_weights`` are the weights used.
        """
        raw_lprobs = model.get_normalized_probs(net_output, log_probs=True)
        raw_target = model.get_targets(sample, net_output)
        loss_weights = self._get_loss_weights(update_num)

        # The task-token column does not depend on the task being processed.
        sample_task_tokens = raw_target[:, self.task_token_position]

        losses, nll_losses = [], []
        loss_sum, nll_loss_sum = None, None
        for (eps, task_token, weight) in zip(self.eps, self.task_tokens, loss_weights):
            task_token_idx = self.task.src_dict.index(
                '[{}]'.format(task_token))
            # Boolean mask over the first dimension selecting this task's sequences.
            sample_idxs = (sample_task_tokens == task_token_idx)
            sample_lprobs = raw_lprobs[sample_idxs].view(-1, raw_lprobs.size(-1))
            sample_target = raw_target[sample_idxs].view(-1, 1)
            sample_loss, sample_nll_loss = label_smoothed_nll_loss(
                sample_lprobs, sample_target, eps, ignore_index=self.padding_idx, reduce=reduce
            )
            losses.append(sample_loss)
            nll_losses.append(sample_nll_loss)
            if loss_sum is None:
                loss_sum = sample_loss * weight
                nll_loss_sum = sample_nll_loss * weight
            else:
                loss_sum = loss_sum + sample_loss * weight
                nll_loss_sum = nll_loss_sum + sample_nll_loss * weight
        return loss_sum, nll_loss_sum, losses, nll_losses, loss_weights

    def reduce_metrics(self, logging_outputs):
        """Aggregate logging outputs from data parallel training.

        Logs the aggregate ``loss``, ``nll_loss`` and ``ppl`` and, for every
        task, ``<task>_loss``, ``<task>_nll_loss`` and ``<task>_w``. Losses
        are divided by ``log(2)`` before logging.
        """
        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
        nll_loss_sum = sum(log.get('nll_loss', 0) for log in logging_outputs)
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)

        metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
        metrics.log_scalar('nll_loss', nll_loss_sum / ntokens / math.log(2), ntokens, round=3)
        metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['nll_loss'].avg))

        for task_token in self.task_tokens:
            l_key = "{}_loss".format(task_token)
            nl_key = "{}_nll_loss".format(task_token)
            w_key = "{}_w".format(task_token)
            l_sum = sum(log.get(l_key, 0) for log in logging_outputs)
            nl_sum = sum(log.get(nl_key, 0) for log in logging_outputs)
            # The weight is a per-step value, not an additive statistic, so it
            # is read from the first entry instead of being summed.
            # NOTE(review): logging_outputs_can_be_summed() returns True, so
            # this entry may already be a sum over workers -- confirm.
            w = logging_outputs[0].get(w_key, 0)

            # Per-task losses are normalised by the totals of the whole batch,
            # not by the task's own token/sentence counts.
            metrics.log_scalar(l_key, l_sum / sample_size / math.log(2), sample_size, round=3)
            metrics.log_scalar(nl_key, nl_sum / ntokens / math.log(2), ntokens, round=3)
            metrics.log_scalar(w_key, w, 0, round=2)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True improves distributed training speed.
        """
        return True
