import math
import torch
import re

from torch.nn import CrossEntropyLoss
from fairseq.criterions import register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import (
    LabelSmoothedCrossEntropyCriterion,
    label_smoothed_nll_loss
)
from fairseq import logging, metrics, utils

@register_criterion('label_smoothed_cross_entropy_w_lang_classifier')
class LabelSmoothedCrossEntropyWLangClassiferCriterion(LabelSmoothedCrossEntropyCriterion):
    """Label-smoothed cross entropy plus a source-language classification loss.

    The label-smoothed NLL term is computed only over sentences whose language
    tag (read from the last source token) is a candidate language: the
    ``summ_langs`` while the model is training, ``summ_langs`` plus
    ``unsupervised_langs`` otherwise. A summed cross-entropy loss over the
    model's ``lang_cls_out`` output is added for every sentence, and
    per-language classification accuracy is logged.
    """

    def __init__(self, task, sentence_avg, label_smoothing):
        super().__init__(task, sentence_avg, label_smoothing)
        self.sentence_avg = sentence_avg
        self.eps = label_smoothing

        # Summed over sentences; reduce_metrics divides by nsentences.
        self.lang_loss = CrossEntropyLoss(reduction='sum')
        # Dictionary symbols of the language tags, formatted as "[<lang>]".
        self.summ_langs = ["[{}]".format(lang) for lang in task.args.summ_langs]
        # Classifier class index i corresponds to self.langs[i].
        self.langs = ["[{}]".format(lang) for lang in task.args.summ_langs + task.args.unsupervised_langs]

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--label-smoothing', default=0., type=float, metavar='D',
                            help='epsilon for label smoothing, 0 means no label smoothing')
        # fmt: on

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample['net_input'])
        loss, nll_loss, lang_loss, summ_idxs, lang_ncorrect, lang_n = self.compute_loss(
            model, net_output, sample, reduce=reduce
        )
        # Non-padding target tokens of the sentences that contributed to the
        # NLL term. This is the true count (possibly 0); reduce_metrics guards
        # the division, so no phantom tokens are added per logging output.
        summ_ntokens = self._count_summ_tokens(sample['target'], summ_idxs)
        sample_size = sample['target'].size(0) if self.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': loss.data,
            'nll_loss': nll_loss.data,
            'lang_loss': lang_loss.data,
            'ntokens': sample['ntokens'],
            'summ_ntokens': summ_ntokens,
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size
        }
        for lang, ncorrect in lang_ncorrect.items():
            logging_output["lang_ncorrect_{}".format(lang)] = ncorrect
            logging_output["lang_n_{}".format(lang)] = lang_n[lang]
        return loss, sample_size, logging_output

    def _count_summ_tokens(self, target, summ_idxs):
        """Return the number of non-padding tokens in rows ``summ_idxs`` of ``target``."""
        if len(summ_idxs) == 0:
            return 0
        if not torch.is_tensor(summ_idxs):
            summ_idxs = torch.tensor(summ_idxs, dtype=torch.long)
        rows = target.index_select(0, summ_idxs.to(target.device))
        return int(rows.ne(self.padding_idx).sum().item())

    def compute_loss(self, model, net_output, sample, reduce=True):
        """Compute the combined loss and the language-classifier statistics.

        Returns a 6-tuple:
        1) the total loss (label-smoothed loss on selected sentences + lang loss)
        2) the NLL loss on the selected sentences
        3) the summed language-classification loss
        4) a 1-D LongTensor with the batch indices of the selected sentences
        5) dict: language tag -> number of correctly classified sentences
        6) dict: language tag -> number of sentences carrying that tag
        """
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        target = model.get_targets(sample, net_output)  # sample["target"]
        dictionary = model.encoder.dictionary

        src_tokens = sample['net_input']['src_tokens']
        # The language tag is read from the last source position of each sentence.
        lang_tokens = src_tokens[:, -1].view(-1)

        # Select the sentences whose tag is a candidate language (vectorized
        # membership test instead of a per-element Python scan).
        candidate_langs = self.summ_langs if model.training else self.langs
        candidate_lang_ids = torch.tensor(
            [dictionary.index(lang) for lang in candidate_langs],
            dtype=lang_tokens.dtype, device=lang_tokens.device,
        )
        selected_mask = lang_tokens.unsqueeze(-1).eq(candidate_lang_ids).any(dim=-1)
        selected_idx = selected_mask.nonzero(as_tuple=False).squeeze(-1).to(lprobs.device)

        if selected_idx.numel() > 0:
            selected_lprobs = lprobs.index_select(0, selected_idx).view(-1, lprobs.size(-1))
            selected_target = target.index_select(0, selected_idx).view(-1, 1)
            loss, nll_loss = label_smoothed_nll_loss(
                selected_lprobs, selected_target, self.eps, ignore_index=self.padding_idx, reduce=reduce,
            )
        else:
            # No sentence in this batch contributes to the NLL term.
            loss = torch.tensor(0.0, device=src_tokens.device)
            nll_loss = torch.tensor(0.0, device=src_tokens.device)

        # Classifier labels: label i marks sentences tagged self.langs[i]; later
        # matches overwrite earlier ones. Sentences whose tag matches none of
        # self.langs keep label 0.
        lang_labels = torch.zeros(lang_tokens.size(0), dtype=torch.long, device=lang_tokens.device)
        lang_indices = dict()
        for li, lang in enumerate(self.langs):
            lang_mask = lang_tokens.eq(dictionary.index(lang))
            lang_labels.masked_fill_(lang_mask, li)
            lang_indices[lang] = lang_mask.nonzero(as_tuple=False).squeeze(-1)

        # Assumes lang_cls_out holds one score per entry of self.langs for each
        # sentence — TODO confirm against the model. The same flattened logits
        # feed both the loss and argmax so predictions align with lang_labels.
        lang_logits = net_output[1]['lang_cls_out'].view(-1, len(self.langs))
        lang_loss = self.lang_loss(lang_logits, lang_labels)
        correct = lang_logits.argmax(dim=-1).eq(lang_labels)
        lang_ncorrect = {lang: correct.index_select(0, idx).sum() for lang, idx in lang_indices.items()}
        lang_n = {lang: idx.numel() for lang, idx in lang_indices.items()}

        loss = loss + lang_loss
        return loss, nll_loss, lang_loss, selected_idx, lang_ncorrect, lang_n

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        def _sum(key):
            return sum(log.get(key, 0) for log in logging_outputs)

        loss_sum = _sum('loss')
        nll_loss_sum = _sum('nll_loss')
        lang_loss_sum = _sum('lang_loss')
        summ_ntokens = _sum('summ_ntokens')
        sample_size = _sum('sample_size')
        nsentences = _sum('nsentences')

        metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
        # nll_loss only covers tokens of the selected sentences, so it is both
        # normalized and weighted by summ_ntokens (guarded against zero).
        metrics.log_scalar(
            'nll_loss', nll_loss_sum / max(summ_ntokens, 1) / math.log(2), summ_ntokens, round=3
        )
        metrics.log_scalar('lang_loss', lang_loss_sum / nsentences / math.log(2), nsentences, round=3)
        metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['nll_loss'].avg))

        if not logging_outputs:
            return
        prefix = "lang_ncorrect_"
        for key in logging_outputs[0]:
            if not key.startswith(prefix):
                continue
            lang = key[len(prefix):]
            lang_ncorrect = _sum(key)
            lang_n = _sum("lang_n_{}".format(lang))
            # "[en_XX]" -> "en"; tags without "_" (e.g. "[en]") no longer raise.
            showed_lang = lang.strip("[]").split("_")[0]
            metrics.log_scalar(
                'lang_acc_{}'.format(showed_lang), 100.0 * lang_ncorrect / (lang_n + 1e-3), lang_n, round=1
            )