# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import torch
import torch.nn.functional as F

from fairseq import utils
from fairseq.sequence_generator import SequenceGenerator

from . import FairseqCriterion, register_criterion


@register_criterion('f1')
class F1Criterion(FairseqCriterion):
    """Token-level F-beta score between beam-search output and the reference.

    For every sentence in a batch, the top hypothesis produced by a
    ``SequenceGenerator`` is compared position-by-position with the target:
    precision is ``hits / non-pad output tokens``, recall is
    ``hits / non-pad target tokens``, and the two are combined into F-beta
    (``f_beta`` arg, default 1.0, i.e. F1).

    NOTE(review): the scores are computed from integer token ids and rebuilt
    into a fresh tensor, so they carry no gradient; as written this looks
    usable as an evaluation metric rather than a trainable loss.
    """

    def __init__(self, args, task):
        super().__init__(args, task)
        # TODO: add beam size to args, think about other args
        self.seq_gen = SequenceGenerator(
            task.target_dictionary,
            beam_size=6,
            diverse_beam_groups=2,
            diverse_beam_strength=0.,
            max_len_b=10,
        )
        # Default to the target dictionary's pad index so padding detection
        # agrees with the dictionary the generator decodes with.
        self._pad_id = getattr(args, 'pad_id', task.target_dictionary.pad())
        self._f_beta = getattr(args, 'f_beta', 1.0)

    def forward(self, model, sample, reduce=False):
        """Compute the F-beta score for the given sample.

        Args:
            model: model used to generate hypotheses with beam search.
            sample: batch dict; must contain 'net_input', 'target',
                'ntokens' and 'nsentences'.
            reduce: if True, average the per-sentence scores into a scalar.

        Returns a tuple with three elements:
        1) the per-sentence F-beta scores (their mean if ``reduce``)
        2) the sample size (number of sentences), used as the denominator
           for the gradient
        3) logging outputs to display while training
        """
        from torch.nn.utils.rnn import pad_sequence

        assert 'net_input' in sample and 'target' in sample
        gen_output = self.seq_gen.generate([model], sample)

        # The top hypotheses of different sentences are not guaranteed to
        # share a length (torch.stack would then fail), so right-pad them
        # with the pad id; equal-length inputs give the same result as stack.
        tokens_out = pad_sequence(
            [hypos[0]['tokens'] for hypos in gen_output],
            batch_first=True,
            padding_value=self._pad_id,
        )
        target = sample['target']
        assert tokens_out.size(0) == target.size(0)
        batch_f1_scores = self._measure_f_score(target, tokens_out)
        if reduce:
            batch_f1_scores = batch_f1_scores.mean()

        logging_output = {
            'f_score': utils.item(batch_f1_scores) if reduce else batch_f1_scores,
            'ntokens': sample['ntokens'],
            'nsentences': sample['nsentences'],
            'sample_size': sample['nsentences'],
        }
        return batch_f1_scores, sample['nsentences'], logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training.

        Concatenates every worker's 'f_score' into one 1-D CPU float tensor.
        Both per-sentence tensors (``reduce=False``) and scalar floats
        (``reduce=True``) are accepted; logs without 'f_score' are skipped.
        Note that with ``reduce=True`` each entry is already a batch mean.

        Returns:
            dict with 'f_score' (1-D tensor) and 'bsz' (summed sample sizes).
        """
        f_scores = [
            torch.as_tensor(log['f_score'], dtype=torch.float).reshape(-1).cpu()
            for log in logging_outputs
            if 'f_score' in log
        ]
        f_score_agg = torch.cat(f_scores) if f_scores else torch.zeros(0)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
        return {
            'f_score': f_score_agg,
            'bsz': sample_size,
        }

    def _measure_f_score(self, target, tokens_out):
        """Return a CPU float tensor holding one F-beta score per batch row."""
        scores = [float(self._measure_one(pair)) for pair in zip(target, tokens_out)]
        return torch.tensor(scores, dtype=torch.float)

    def _measure_one(self, in_tuple):
        """Return the F-beta score of one ``(target, output)`` row pair.

        The result is a 0-d float tensor. An empty (all-pad) output or target
        yields 0 instead of the NaN a 0/0 division would produce.
        """
        tgt, out = in_tuple
        hits, recall_denom, prec_denom = self._get_hits_and_denoms(out, tgt)
        if prec_denom == 0 or recall_denom == 0:
            return torch.zeros((), dtype=torch.float, device=hits.device)

        prec = hits / prec_denom.to(torch.float)
        recall = hits / recall_denom.to(torch.float)
        beta_sq = self._f_beta ** 2
        numerator = (1.0 + beta_sq) * (prec * recall)
        denominator = (beta_sq * prec) + recall
        # Small epsilon keeps the division finite when prec == recall == 0.
        return numerator / (denominator + 1e-5)

    def _get_hits_and_denoms(self, out, tgt):
        """Count position-wise token matches between ``out`` and ``tgt``.

        Every pad-id token is treated as padding, and the first
        ``non-pad length`` positions are taken as the real sequence; this
        assumes right padding. NOTE(review): confirm targets are not left-padded.

        Returns:
            (hits, recall_denom, prec_denom): the number of aligned positions
            where the tokens agree (float tensor), the number of non-pad target
            tokens and the number of non-pad output tokens (integer tensors).
        """
        recall_denom = tgt.ne(self._pad_id).sum()
        prec_denom = out.ne(self._pad_id).sum()
        # Compare only the prefix covered by both non-pad lengths so the two
        # slices always have the same size.
        overlap = min(int(recall_denom), int(prec_denom))
        hits = tgt[:overlap].eq(out[:overlap]).sum()
        return hits.to(torch.float), recall_denom, prec_denom

