# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import logging

import torch
import torch.nn.functional as F
from torch import nn
from fairseq import metrics, modules, utils
from fairseq.criterions import FairseqCriterion, register_criterion


logger = logging.getLogger(__name__)


@register_criterion("sparse_masked_lm")
class SparseMaskedLmLoss(FairseqCriterion):
    """
    Implementation for the loss used in sparse masked language model (MLM) training.

    The base term is the summed masked-LM cross entropy. Depending on the
    configured weights, `forward` adds further terms to it:

    * a sparsity penalty (`sparsity_weight`) pulling the measured sparsity
      towards a per-sentence target sampled by `sparsity_samples`;
    * a diagonal penalty (`diagonal_weight`) on the off-diagonal overlap
      between the masks of different sentences in the batch;
    * a knowledge distillation term (`kd_weight`, `kd_temperature`) and a
      cosine embedding term (`cos_weight`), both only when the task is
      configured with a teacher file.
    """

    def __init__(self, task, sparsity_weight, one4all, diagonal_weight, lang2group, kd_weight, kd_temperature, cos_weight, tpu=False):
        """Store the loss weights and build the language-id -> group-id table.

        Args:
            task: the fairseq task; `task.lang2id` is read when `lang2group`
                is given.
            sparsity_weight (float): weight of the sparsity penalty term.
            one4all (str or None): stringified list with one element (fixed
                target) or two elements (lower and upper bound of the sampled
                target); `None` means a fixed target of 0.
            diagonal_weight (float): weight of the diagonal penalty term.
            lang2group (str or None): stringified dict mapping each language
                to its group.
            kd_weight (float): weight of the knowledge distillation term.
            kd_temperature (float): temperature of the distillation term.
            cos_weight (float): weight of the cosine embedding loss term.
            tpu (bool): when True, all tokens are projected and logging
                values are kept as tensors.
        """
        super().__init__(task)
        self.sparsity_weight = sparsity_weight
        # NOTE(review): `eval` executes arbitrary Python taken from the command
        # line; kept for backward compatibility, but only pass trusted input.
        # `ast.literal_eval` would be the safe choice if only literals are needed.
        self.one4all = eval(one4all) if one4all is not None else [0.]
        self.diagonal_weight = diagonal_weight
        self.kd_weight = kd_weight
        self.kd_temperature = kd_temperature
        self.cos_weight = cos_weight
        self.tpu = tpu

        if lang2group is not None:
            # NOTE(review): same `eval` caveat as for --one4all above.
            lang2group = eval(lang2group)
            # A dict cannot hold duplicate keys, so "one language belongs to
            # one group" holds by construction and needs no separate check.
            assert len(lang2group.keys()) == len(task.lang2id.keys()), "Should specify group for each languages"
            # groups are numbered in sorted order so the mapping is deterministic
            group2id = {group: gid for gid, group in enumerate(sorted(set(lang2group.values())))}
            logger.info("Group to id mapping: {}".format(group2id))
            lid2gid = {task.lang2id[lang]: group2id[group] for lang, group in lang2group.items()}
            logger.info("Language id to group id mapping: {}".format(lid2gid))
            # position i holds the group of the i-th smallest language id, so
            # indexing with a language id assumes ids are 0..n-1 -- TODO confirm
            self.group = torch.tensor([lid2gid[lid] for lid in sorted(task.lang2id.values())])
        else:
            self.group = None

    def _group_mask(self, lang_id):
        """
        Args:
            lang_id (LongTensor): language id of shape `(batch,)`.

        Returns:
            LongTensor: a tensor of shape `(batch, batch)` with the value 0 or 1;
            entry `(i, j)` is 1 when sentences `i` and `j` are in different groups.
        """
        # Plain indexing replaces F.embedding here: `self.group` is 1-D, and
        # recent PyTorch releases require a 2-D embedding weight.
        group_id = self.group.to(lang_id.device)[lang_id].view(-1)
        # eliminate penalty terms that are in the same group
        return (group_id.unsqueeze(1) != group_id.unsqueeze(0)).long()

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        # fmt: off
        parser.add_argument(
            "--sparsity-weight",
            type=float,
            metavar="D",
            default=0.,
            help="the weight of the sparsity penalty term",
        )
        parser.add_argument(
            "--one4all",
            type=str,
            metavar="LIST",
            default=None,
            help="a stringify list that specifies the sparsity ratios we would like to train on, e.g., \"[0.1,]\"",
        )
        parser.add_argument(
            "--diagonal-weight",
            type=float,
            metavar="D",
            default=0.,
            help="the weight of the diagonal penalty term",
        )
        parser.add_argument(
            "--lang2group",
            type=str,
            metavar="DICT",
            default=None,
            help="a stringify dictionary that maps a language to a its family, e.g., \"{'gu': 'Indo-European'}\"",
        )
        parser.add_argument(
            "--kd-weight",
            type=float,
            metavar="D",
            default=0.,
            help="the weight of the knowledge distillation term",
        )
        parser.add_argument(
            "--kd-temperature",
            type=float,
            metavar="D",
            default=1.,
            help="the temperature of the knowledge distillation term",
        )
        parser.add_argument(
            "--cos-weight",
            type=float,
            metavar="D",
            default=0.,
            help="the weight of the cosine embedding loss term",
        )
        # fmt: on

    def sparsity_samples(self, bsz):
        """Return the target sparsity for each of the `bsz` sentences.

        Args:
            bsz (int): batch size.

        Returns:
            FloatTensor: a CPU tensor of shape `(bsz,)`. With a one-element
            `--one4all` every entry is that value; with two elements the
            entries are drawn uniformly between them and rounded to one
            decimal place.

        Raises:
            NotImplementedError: if `--one4all` has neither 1 nor 2 elements.
        """
        if len(self.one4all) == 1:
            # [target]
            return torch.ones(bsz) * self.one4all[0]
        if len(self.one4all) == 2:
            # TODO: fixed sparsity samples for evaluation
            # [lower_bound, upper_bound]
            lower, upper = self.one4all
            samples = torch.FloatTensor(bsz).uniform_(lower, upper)
            # TODO: sandwich rule (for gradient accumulation and distributed training)
            # round to one decimal place
            return torch.round(samples * 10) / 10
        # The original built this exception without raising it, which made the
        # method return None and fail later with an unrelated AttributeError.
        raise NotImplementedError(
            "Unable to parse --one4all: expected 1 or 2 elements, got {}".format(len(self.one4all))
        )

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        masked_tokens = sample["target"].ne(self.padding_idx)
        sample_size = masked_tokens.int().sum()
        # `masked_tokens` may be replaced by None below, so remember its device now
        device = masked_tokens.device

        # Rare: when all tokens are masked, project all tokens.
        # We use torch.where to avoid device-to-host transfers,
        # except on CPU where torch.where is not well supported
        # (see github.com/pytorch/pytorch/issues/26247).
        if self.tpu:
            masked_tokens = None  # always project all tokens on TPU
        elif masked_tokens.device == torch.device("cpu"):
            if not masked_tokens.any():
                masked_tokens = None
        else:
            masked_tokens = torch.where(
                masked_tokens.any(),
                masked_tokens,
                masked_tokens.new([True]),
            )

        use_teacher = self.task.args.teacher_file is not None

        bsz = sample["net_input"]["src_tokens"].size(0)
        target_sparsity = self.sparsity_samples(bsz).to(device)
        logits, extra = model(
            **sample["net_input"],
            masked_tokens=masked_tokens,
            lang_id=sample["lang_id"],
            target_sparsity=target_sparsity,
            # hidden states are only needed for the cosine term against the teacher
            return_all_hiddens=self.cos_weight > 0 and use_teacher,
        )
        targets = model.get_targets(sample, [logits])
        if masked_tokens is not None:
            targets = targets[masked_tokens]

        nll_loss = modules.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            reduction="sum",
            ignore_index=self.padding_idx,
        )
        logging_output = {
            "nll_loss": nll_loss if self.tpu else nll_loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
        }

        loss = nll_loss

        if self.sparsity_weight > 0:
            sparsity_names = ["rank_sparsity", "head_sparsity", "hidden_sparsity"]
            # we multiply 64 because 1 head contains 64 neurons
            # 4 because one head relates to 4 weights, q, k, v and o
            # 2 because one neuron relates to 2 weights, w1 and w2
            sparsity_weights = [1, 64 * 4, 2]
            active_sum = torch.zeros(bsz, device=nll_loss.device)
            overall_sum = torch.zeros(bsz, device=nll_loss.device)
            for name, weight in zip(sparsity_names, sparsity_weights):
                sparsity = extra.get(name, None)
                if sparsity is not None:
                    # cast to FP32 to avoid overflow
                    active_sum += torch.sum(sparsity.float().view(sparsity.size(0), -1), dim=-1) * weight
                    overall_sum += float(sparsity.size(1) * sparsity.size(2)) * weight
            penalty = torch.abs(active_sum / overall_sum)
            # we scale up the penalty term because the loss will later be normalized
            # we also have penalty for each language as long as it does not reach the targeted value
            loss = loss + (self.sparsity_weight * sample_size * (torch.abs(penalty - target_sparsity).mean())).type_as(nll_loss)
            logging_output["penalty"] = penalty.mean() * sample_size if self.tpu else penalty.mean().data * sample_size
            logging_output["target"] = target_sparsity.mean() * sample_size if self.tpu else target_sparsity.mean().data * sample_size

        if self.diagonal_weight > 0:
            rank_mask, head_masks, hidden_masks = extra["rank_mask"], extra["head_masks"], extra["hidden_masks"]

            # "Can We Gain More from Orthogonality Regularizations in Training Deep CNNs?" (Bansal et al., 2018)
            def diag_pen(mask):
                mask = mask.transpose(0, 1).float()  # b x l x h -> l x b x h
                dotprod = torch.matmul(mask, mask.transpose(1, 2))
                dotprod = dotprod.div(float(mask.size(-1)))
                # we have unnecessary sparsity penalty on the diagonal entries, so we exclude the diagonal entries
                select = torch.ones(dotprod.size(), dtype=dotprod.dtype, device=dotprod.device) - \
                    torch.eye(dotprod.size(-1), dtype=dotprod.dtype, device=dotprod.device).unsqueeze(0)
                # if group information is provided, then mask out entries between two languages within the same group
                if self.group is not None:
                    lang_id = sample["lang_id"]
                    select *= self._group_mask(lang_id).type_as(select)
                pen = dotprod[select.bool()]  # to give a correct mean
                pen = torch.abs(pen)
                pen = pen.mean().type_as(nll_loss)
                # the mean of an empty selection is NaN; treat it as no penalty
                pen = torch.zeros_like(pen) if torch.isnan(pen).any() else pen
                return pen

            # we impose the diagonal constraint to the masks per layer (except for head, otherwise overcomplete)
            diagonal = diag_pen(head_masks.view(head_masks.size(0), -1).unsqueeze(1)) + diag_pen(hidden_masks)
            if rank_mask is not None:
                diagonal += diag_pen(rank_mask)
                diagonal /= 3.
            else:
                diagonal /= 2.
            loss = loss + self.diagonal_weight * sample_size * diagonal
            logging_output["diagonal"] = diagonal * sample_size if self.tpu else diagonal.data * sample_size

        if use_teacher:
            soft_targets, t_extra = model.get_soft_targets(**sample["net_input"], masked_tokens=masked_tokens, return_all_hiddens=self.cos_weight > 0)

            if self.kd_weight > 0:
                # "Distilling the Knowledge in a Neural Network" (Hinton et al., 2014)
                def soft_cross_entropy(predicts, targets):
                    return F.kl_div(F.log_softmax(predicts, dim=-1), F.softmax(targets, dim=-1), reduction='sum')

                kd = soft_cross_entropy(logits.view(-1, logits.size(-1)) / self.kd_temperature,
                                        soft_targets.view(-1, soft_targets.size(-1)) / self.kd_temperature)
                # the T^2 factor follows the scaling used in the paper cited above
                loss = loss + self.kd_weight * self.kd_temperature * self.kd_temperature * kd
                logging_output["kd"] = kd.float() if self.tpu else kd.data.float()  # avoid inf in FP16 with a large batch

            if self.cos_weight > 0:
                # "DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter" (Sanh et al., 2019)
                s_embed = extra['inner_states'][0]
                t_embed = t_extra['inner_states'][0]
                dim = s_embed.size(-1)
                if masked_tokens is not None:
                    # the mask is transposed to line up with the first two
                    # dimensions of the hidden states
                    keep = masked_tokens.transpose(0, 1).unsqueeze(-1)
                    s_embed = torch.masked_select(s_embed, keep.expand_as(s_embed))
                    t_embed = torch.masked_select(t_embed, keep.expand_as(t_embed))
                # When masked_tokens is None (all tokens are projected) every
                # position is compared; previously this case crashed on
                # `None.transpose`. NOTE(review): confirm that including all
                # positions is the intended behaviour here.
                s_embed = s_embed.reshape(-1, dim)
                t_embed = t_embed.reshape(-1, dim)
                # a target of 1 asks the student and teacher vectors to be similar
                cos = nn.CosineEmbeddingLoss(reduction='sum')(s_embed, t_embed, s_embed.new_ones(s_embed.size(0)))
                loss = loss + self.cos_weight * cos
                logging_output["cos"] = cos.float() if self.tpu else cos.data.float()

        logging_output["loss"] = loss if self.tpu else loss.data
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        # dividing by log(2) converts the losses to base 2
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar(
            "nll_loss", nll_loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_derived(
            "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
        )

        # Optional terms: (key whose presence enables the group, keys to log).
        # "target" is gated on "penalty" because `forward` sets them together.
        optional_terms = (
            ("penalty", ("penalty", "target")),
            ("diagonal", ("diagonal",)),
            ("kd", ("kd",)),
            ("cos", ("cos",)),
        )
        for gate, names in optional_terms:
            if not any(gate in log for log in logging_outputs):
                continue
            for name in names:
                total = sum(log.get(name, 0) for log in logging_outputs)
                metrics.log_scalar(
                    name, total / sample_size, sample_size, round=3
                )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
