# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import torch

from fairseq import metrics, utils
from fairseq.criterions import register_criterion 
from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion

def getMaskedPos(source, target, raw_target, mask_idx=250026):
    """Flag the target tokens that fill in the ``<mask>`` spans of ``source``.

    For every occurrence of ``mask_idx`` in ``source``, the token right before
    it and the token right after it act as span delimiters. ``target`` is
    scanned left to right: once the start delimiter is found, every following
    token is flagged until the end delimiter is reached. The scan for the next
    mask resumes *at* that end delimiter, so one token may close a span and
    open the next one.

    Args:
        source: 1-D integer tensor of source token ids.
        target: 1-D integer tensor of target token ids to scan.
        raw_target: 1-D tensor; the result is prefixed with
            ``raw_target.size(0) - target.size(0)`` ``False`` entries.
        mask_idx: id of the ``<mask>`` token. Defaults to the value that was
            previously hard-coded in this function.

    Returns:
        A 1-D bool tensor on ``source``'s device; ``True`` marks a target
        position that was replaced by ``<mask>`` in the source.
    """
    # Work on plain Python ints: comparing tensor elements one by one inside a
    # Python loop creates a tensor (and a device sync) per comparison.
    src_tokens = source.tolist()
    tgt_tokens = target.tolist()
    src_len, tgt_len = len(src_tokens), len(tgt_tokens)

    flags = [0] * tgt_len
    pt = 0
    for pos, tok in enumerate(src_tokens):
        if tok != mask_idx:
            continue
        if pos > 0:
            # Advance to the start delimiter; the span begins right after it.
            start_tok = src_tokens[pos - 1]
            while pt < tgt_len and tgt_tokens[pt] != start_tok:
                pt += 1
            if pt >= tgt_len:
                # Delimiter never found: no later mask can match either.
                break
            pt += 1
        # A mask at the very end of the source has no end delimiter; ``None``
        # never equals a token id, so the span then runs to the end of target.
        end_tok = src_tokens[pos + 1] if pos + 1 < src_len else None
        # Bounded scan: a missing end delimiter must not index past the end.
        while pt < tgt_len and tgt_tokens[pt] != end_tok:
            flags[pt] = 1
            pt += 1

    # Left-pad so the mask lines up with the untrimmed target.
    prefix = [0] * (raw_target.size(0) - tgt_len)
    return torch.tensor(prefix + flags, dtype=torch.bool, device=source.device)


def label_smoothed_nll_loss_with_mask(
    lprobs, target, epsilon, mask,
    ignore_index=None, reduce=True
):
    """Compute label-smoothed NLL loss over all tokens and over masked tokens.

    Args:
        lprobs: log-probabilities; the last dimension is the vocabulary.
        target: gold token ids, either with the same number of dimensions as
            ``lprobs`` (trailing dimension of size 1) or one fewer.
        epsilon: label-smoothing weight.
        mask: per-token selector (bool or numeric); non-zero entries pick the
            tokens that contribute to the masked losses.
        ignore_index: if given, tokens whose target equals this id contribute
            zero to every returned loss.
        reduce: if True, sum each loss over all tokens.

    Returns:
        Tuple ``(loss, nll_loss, masked_loss, masked_nll_loss)``.
    """
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.)
        smooth_loss.masked_fill_(pad_mask, 0.)
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)

    # The losses are (N, 1) when ignore_index is set and (N,) otherwise, while
    # the mask may arrive as (N,) or (N, 1). Multiplying mismatched layouts
    # would broadcast to (N, N) and inflate the sums, so align the mask with
    # the loss layout whenever it holds one weight per token.
    weights = mask.to(nll_loss)
    if weights.numel() == nll_loss.numel():
        weights = weights.reshape(nll_loss.shape)
    masked_nll_loss = nll_loss * weights
    masked_smooth_loss = smooth_loss * weights

    if reduce:
        nll_loss = nll_loss.sum()
        smooth_loss = smooth_loss.sum()
        masked_nll_loss = masked_nll_loss.sum()
        masked_smooth_loss = masked_smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    loss = (1. - epsilon) * nll_loss + eps_i * smooth_loss
    masked_loss = (1. - epsilon) * masked_nll_loss + eps_i * masked_smooth_loss
    return loss, nll_loss, masked_loss, masked_nll_loss


@register_criterion('masked_label_smoothed_cross_entropy')
class MaskedLabelSmoothedCrossEntropyCriterion(LabelSmoothedCrossEntropyCriterion):
    """Label-smoothed cross entropy that also reports the loss on masked spans.

    The loss returned by ``forward`` is the label-smoothed loss over all
    tokens. The loss restricted to the positions flagged by ``getMaskedPos``
    is computed alongside it and only placed in the logging output.
    """

    def __init__(self, task, sentence_avg, label_smoothing):
        super().__init__(task, sentence_avg, label_smoothing)

    @staticmethod
    def _first_index_of(tokens, value):
        """Return the index of the first ``value`` in 1-D ``tokens``, or 0 if absent."""
        hits = (tokens == value).nonzero(as_tuple=True)[0]
        return int(hits[0]) if hits.numel() > 0 else 0

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample['net_input'])
        loss, nll_loss, masked_loss, masked_nll_loss, num_masked_token = self.compute_loss(
            model, net_output, sample, reduce=reduce
        )
        sample_size = sample['target'].size(0) if self.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': loss.data,
            'nll_loss': nll_loss.data,
            'masked_loss': masked_loss.data,
            'masked_nll_loss': masked_nll_loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
            'n_masked_tokens': num_masked_token
        }
        return loss, sample_size, logging_output

    def compute_loss(self, model, net_output, sample, reduce=True):
        """Compute the full and masked-span label-smoothed losses.

        Returns:
            Tuple ``(loss, nll_loss, masked_loss, masked_nll_loss,
            num_masked_token)`` where ``num_masked_token`` is the number of
            target positions flagged as masked in this batch.
        """
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        targets = model.get_targets(sample, net_output)
        target = targets.view(-1, 1)

        # Build one mask row per sentence, then flatten to line up with target.
        sources = sample['net_input']['src_tokens']
        masks = []
        for src, tgt in zip(sources, targets):
            # Trim everything before token id 0 (presumably the
            # beginning-of-sentence id; confirm against the task dictionary).
            # Take the first occurrence so that a sequence with several or no
            # such tokens does not break the slice.
            src_start = self._first_index_of(src, 0)
            tgt_start = self._first_index_of(tgt, 0)
            masks.append(getMaskedPos(src[src_start:], tgt[tgt_start:], tgt))
        mask = torch.cat(masks).view(-1, 1)
        num_masked_token = mask.view(-1).sum()

        loss, nll_loss, masked_loss, masked_nll_loss = label_smoothed_nll_loss_with_mask(
            lprobs, target, self.eps, mask,
            ignore_index=self.padding_idx, reduce=reduce,
        )
        return loss, nll_loss, masked_loss, masked_nll_loss, num_masked_token

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
        nll_loss_sum = sum(log.get('nll_loss', 0) for log in logging_outputs)
        masked_loss_sum = sum(log.get('masked_loss', 0) for log in logging_outputs)
        masked_nll_loss_sum = sum(log.get('masked_nll_loss', 0) for log in logging_outputs)
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        n_masked_tokens = sum(log.get('n_masked_tokens', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)

        # Batches without any masked token would otherwise divide by zero.
        # The weight passed to log_scalar stays n_masked_tokens (0 here), and
        # the masked meters are still logged so that 'masked_ppl' can always
        # look up 'masked_nll_loss'.
        masked_denom = n_masked_tokens if n_masked_tokens > 0 else 1

        metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
        metrics.log_scalar('nll_loss', nll_loss_sum / ntokens / math.log(2), ntokens, round=3)
        metrics.log_scalar('masked_loss', masked_loss_sum / masked_denom / math.log(2), n_masked_tokens, round=3)
        metrics.log_scalar('masked_nll_loss', masked_nll_loss_sum / masked_denom / math.log(2), n_masked_tokens, round=3)
        metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['nll_loss'].avg))
        metrics.log_derived('masked_ppl', lambda meters: utils.get_perplexity(meters['masked_nll_loss'].avg))
