import torch
import torch.nn.functional as F

import math
from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import (
    LabelSmoothedCrossEntropyCriterion,
    label_smoothed_nll_loss
)

@register_criterion('cross_entropy_and_ranking_document')
class CrossEntropyAndRankingDocument(LabelSmoothedCrossEntropyCriterion):
    """Label-smoothed cross entropy plus a candidate-ranking loss.

    The total loss is the label-smoothed NLL of the model output plus a
    softmax cross entropy over cosine similarities between the last decoder
    inner state (the "anchor") and a list of candidate decoder outputs.
    """

    def __init__(self, task, ranking_head_name, sentence_avg,
                 label_smoothing, ranking_loss_weight,
                 ranking_loss_margin, ranking_loss_temperature,
                 ranking_loss_reduction):
        """Store the ranking hyper-parameters next to the base criterion state.

        Args:
            task: forwarded to the parent criterion.
            ranking_head_name: key looked up in ``model.classification_heads``.
            sentence_avg: if true, the sample size is the number of target rows
                instead of the number of tokens.
            label_smoothing: epsilon passed to ``label_smoothed_nll_loss``.
            ranking_loss_weight: multiplier for the ranking loss; ``-1`` means
                "use ``sample['ntokens']``".
            ranking_loss_margin: stored only; no method of this class reads it.
            ranking_loss_temperature: factor the similarities are multiplied by
                before the softmax.
            ranking_loss_reduction: ``reduction`` argument handed to
                ``F.cross_entropy``.
        """
        super().__init__(task, sentence_avg, label_smoothing)
        self.ranking_head_name = ranking_head_name
        self.sentence_avg = sentence_avg
        self.eps = label_smoothing
        self.ranking_loss_weight = ranking_loss_weight
        self.ranking_loss_margin = ranking_loss_margin
        self.ranking_loss_temperature = ranking_loss_temperature
        self.ranking_loss_reduction = ranking_loss_reduction

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--label-smoothing', default=0., type=float, metavar='D',
                            help='epsilon for label smoothing, 0 means no label smoothing')
        parser.add_argument('--ranking-head-name', type=str, default=None, help='use ranking during training')
        # Parsed as float so fractional weights are accepted; the -1 sentinel
        # still compares equal to -1.0 in forward().
        parser.add_argument('--ranking-loss-weight', type=float, default=-1, help='ranking loss during training')
        parser.add_argument('--ranking-loss-margin', type=float, default=1.0, help='ranking margin during training')
        parser.add_argument('--ranking-loss-temperature', type=float, default=1.0, help='ranking temperature during training')
        parser.add_argument('--ranking-loss-reduction', type=str, default='mean',
                            help="ranking loss reduction method, one of ['mean', 'sum']")
        # fmt: on

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss (label-smoothed cross entropy + ranking loss)
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        assert (
            hasattr(model, 'classification_heads')
            and self.ranking_head_name in model.classification_heads
        ), 'model must provide sentence ranking head for --criterion=cross_entropy_and_ranking_document'

        # -1 is the sentinel for "scale the ranking term by the token count".
        if self.ranking_loss_weight != -1:
            weight = self.ranking_loss_weight
        else:
            weight = sample['ntokens']

        summ, extra = model(**sample['net_input'], classification_head_name=self.ranking_head_name)
        loss, nll_loss = self.compute_loss(model, (summ,), sample['target'], reduce=reduce)
        rank_loss = self.compute_rank_loss(
            extra['inner_states'][-1],
            tuple(extra['candidate_decoder_outs']),
            sample['label'],
            temperature=self.ranking_loss_temperature,
            weight=weight,
            reduce=reduce,
        )

        sample_size = sample['target'].size(0) if self.sentence_avg else sample['ntokens']
        loss = loss + rank_loss
        logging_output = {
            'loss': loss.data,
            'rank_loss': rank_loss.data,
            'nll_loss': nll_loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output

    def compute_loss(self, model, net_output, target, reduce=True):
        """Return ``(loss, nll_loss)`` from ``label_smoothed_nll_loss``.

        ``net_output`` is turned into log-probabilities by the model and
        flattened to two dimensions; ``target`` is flattened to a column.
        Positions equal to ``self.padding_idx`` are ignored.
        """
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        target = target.view(-1, 1)  # sample["target"]
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=reduce,
        )
        return loss, nll_loss

    @staticmethod
    def _l2_normalize(x, dim=2, eps=1e-12):
        """Divide ``x`` by its L2 norm along ``dim``, guarding against zero norms."""
        # The norm is accumulated in float32 regardless of the input dtype.
        norm = torch.sqrt(torch.sum(x.to(torch.float32).pow(2), dim=dim, keepdim=True))
        # Without the clamp an all-zero vector yields 0/0 = NaN and poisons the loss.
        return x / norm.clamp(min=eps)

    def compute_rank_loss(self, anchor, candidate_decoder_outs, target, temperature=1.0, weight=1, reduce=True):
        """Cross entropy over cosine similarities between anchor and candidates.

        Args:
            anchor: tensor of shape ``[tgt_len, bs, hidden_dim]``.
            candidate_decoder_outs: iterable of tensors, each
                ``[bs, tgt_len, hidden_dim]``; they are stacked on a new last
                axis.
            target: one candidate index per batch element (``bs`` entries in
                total); it is expanded to every target position.
            temperature: factor multiplied into the similarities.
            weight: factor multiplied into the reduced loss.
            reduce: accepted for signature compatibility; not used here.

        Returns:
            The weighted ranking loss. With ``'sum'`` reduction it is
            additionally divided by the number of candidates.
        """
        candidates = torch.stack(list(candidate_decoder_outs), dim=3)  # [bs, tgt_len, hidden_dim, num]
        num = candidates.size(3)
        tgt_len, bs, _ = anchor.size()

        # [bs, tgt_len, hidden_dim, 1]; broadcasting against the candidates
        # replaces materialising `num` copies of the anchor.
        anchor = anchor.transpose(1, 0).unsqueeze(3)

        unit_anchor = self._l2_normalize(anchor, dim=2)
        unit_candidates = self._l2_normalize(candidates, dim=2)

        similarity = torch.sum(unit_anchor * unit_candidates, dim=2)  # [bs, tgt_len, num]
        similarity = similarity.reshape(-1, num) * temperature

        # Rows of `similarity` are batch-major (index = b * tgt_len + t), so each
        # sample's label must be repeated tgt_len times consecutively. Tiling a
        # [bs] or [bs, 1] label with repeat() would give time-major order and
        # pair rows with the wrong sample's label.
        target = target.reshape(bs, 1).expand(bs, tgt_len).reshape(-1)  # sample["label"]

        loss = F.cross_entropy(similarity, target, reduction=self.ranking_loss_reduction) * weight
        if self.ranking_loss_reduction == 'sum':
            loss = loss / num
        return loss

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
        rank_loss_sum = sum(log.get('rank_loss', 0) for log in logging_outputs)
        nll_loss_sum = sum(log.get('nll_loss', 0) for log in logging_outputs)
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)

        # Each scalar is weighted by the same quantity it was normalised by, so
        # that averaging across updates reproduces the global ratio.
        metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
        metrics.log_scalar('nll_loss', nll_loss_sum / ntokens / math.log(2), ntokens, round=3)
        metrics.log_scalar('rank_loss_sum', rank_loss_sum / sample_size / math.log(2), sample_size, round=3)
        metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['nll_loss'].avg))
