# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
BART: Denoising Sequence-to-Sequence Pre-training for
Natural Language Generation, Translation, and Comprehension
"""

import torch
import logging
import torch.nn as nn

from fairseq.models import (
    register_model,
    register_model_architecture,
)

from fairseq.models.bart.model import BARTModel, mbart_large_architecture, BARTClassificationHead
from examples.summarization.modules.utils import GradReverse

logger = logging.getLogger(__name__)

# Default sequence-length limits for source and target.
# NOTE(review): neither constant is referenced by the code in this module;
# presumably kept for parity with fairseq's BART defaults — confirm before removing.
DEFAULT_MAX_SOURCE_POSITIONS = DEFAULT_MAX_TARGET_POSITIONS = 1024


@register_model('ext_summ')
class ExtSumm(BARTModel):
    """BART-based model that scores masked spans ("margins") of a source document.

    Only the encoder is run. For every document, each of its ``num`` margin
    masks selects a set of source positions; the encoder states at those
    positions are summed and passed through a registered classification head
    (presumably one margin per candidate sentence — confirm with the task).
    When ``use_lang_classifier`` is enabled, the padding-aware mean of the last
    encoder layer is also fed, through ``GradReverse``, to a linear language
    classifier.
    """

    def __init__(self, args, encoder, decoder):
        super().__init__(args, encoder, decoder)
        # use_lang_classifier / summ_langs / denoise_langs are not declared in
        # add_args below, so read the flag defensively (disabled when absent).
        if getattr(self.args, 'use_lang_classifier', False):
            num_langs = len(args.summ_langs) + len(args.denoise_langs)
            # NOTE: the misspelt attribute name is intentional to keep: renaming
            # it would change state-dict keys and break existing checkpoints.
            self.lang_classifer = nn.Linear(
                args.encoder_embed_dim,
                num_langs,
                bias=False,
            )

    @staticmethod
    def add_args(parser):
        """Add ExtSumm-specific arguments on top of the parent model's."""
        super(ExtSumm, ExtSumm).add_args(parser)
        parser.add_argument(
            '--reverse-lambda', type=float, default=-1.0,
            help='coefficient passed to GradReverse on the language-classifier path',
        )
        parser.add_argument(
            '--encoder-output-layer', nargs='+', default=None, type=int,
            help='hidden states of which layers are selected as the encoder output. '
                 'By default, the last layer will be selected. '
                 'If multiple layers are given, output of the given layers will be concatenated',
        )

    def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
        """Register (or replace) the classification head ``name``.

        The head's input width is ``encoder_embed_dim`` multiplied by the number
        of layers in ``--encoder-output-layer`` (their states are concatenated
        in ``forward``), or just ``encoder_embed_dim`` when that option is unset.

        Args:
            name: key in ``self.classification_heads``.
            num_classes: output size of the head.
            inner_dim: hidden size of the head; ``None`` means ``encoder_embed_dim``.
            **kwargs: accepted for signature compatibility; ignored.
        """
        logger.info("Registering classification head: %s", name)
        # Resolve the default first so re-registering with inner_dim=None does not
        # warn spuriously when the existing head already uses the default width.
        inner_dim = inner_dim or self.args.encoder_embed_dim
        if name in self.classification_heads:
            prev_num_classes = self.classification_heads[name].out_proj.out_features
            prev_inner_dim = self.classification_heads[name].dense.out_features
            if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
                logger.warning(
                    're-registering head "{}" with num_classes {} (prev: {}) '
                    'and inner_dim {} (prev: {})'.format(
                        name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
                    )
                )

        input_dim = self.args.encoder_embed_dim
        output_layers = getattr(self.args, 'encoder_output_layer', None)
        if output_layers is not None:
            input_dim *= len(output_layers)

        self.classification_heads[name] = BARTClassificationHead(
            input_dim,
            inner_dim,
            num_classes,
            self.args.pooler_activation_fn,
            self.args.pooler_dropout,
        )

    def forward(
        self, src_tokens, src_lengths,
        margin=None, prev_output_tokens=None,
        features_only=False, classification_head_name=None, **kwargs
    ):
        """Score every margin of every document with a classification head.

        Only the encoder is run; ``prev_output_tokens`` and ``features_only``
        are accepted for interface compatibility and ignored.

        Args:
            src_tokens: source token ids, ``[bs, seq_len]``.
            src_lengths: forwarded to the encoder.
            margin: mask viewable as ``[bs, num, seq_len]``; non-zero/``True``
                marks positions EXCLUDED from that margin. Required.
            classification_head_name: name of a head registered with
                ``register_classification_head``. Required.
            **kwargs: forwarded to the encoder.

        Returns:
            ``(x, extra)``: ``x`` is the head output per margin
            (``[bs, num, num_classes]`` before a trailing size-1 dim is squeezed);
            ``extra`` holds ``encoder_out``, ``margin_states``
            (``[bs, num, hidden]``) and, when the language classifier is
            enabled, ``lang_cls_out``.

        Raises:
            ValueError: if ``classification_head_name`` or ``margin`` is missing.
        """
        # Explicit raises instead of assert: asserts vanish under `python -O`,
        # after which this method would silently return None.
        if classification_head_name is None:
            raise ValueError("classification_head_name is required for extSumm.py")
        if margin is None:
            raise ValueError("margin is required for extSumm.py")

        output_layers = getattr(self.args, 'encoder_output_layer', None)
        encoder_out = self.encoder(
            src_tokens,
            src_lengths=src_lengths,
            # Intermediate layers are only needed when specific layers are selected.
            return_all_hiddens=output_layers is not None,
            **kwargs
        )
        extra = {'encoder_out': encoder_out.encoder_out}

        document_states = self._select_document_states(encoder_out, output_layers)
        margin_states = self._pool_margin_states(document_states, margin)  # [bs, num, hidden]
        x = self.classification_heads[classification_head_name](margin_states).squeeze(-1)
        extra['margin_states'] = margin_states

        if getattr(self.args, 'use_lang_classifier', False):
            extra['lang_cls_out'] = self._lang_classifier_logits(
                encoder_out.encoder_out, src_tokens
            )
        return x, extra

    @staticmethod
    def _select_document_states(encoder_out, output_layers):
        """Return the states used for margin pooling, ``[seq_len, bs, hidden]``."""
        if output_layers is None:
            return encoder_out.encoder_out
        encoder_states = encoder_out.encoder_states
        return torch.cat([encoder_states[layer] for layer in output_layers], dim=-1)

    @staticmethod
    def _pool_margin_states(document_states, margin):
        """Sum the states each margin keeps: ``[seq_len, bs, H]`` -> ``[bs, num, H]``.

        Mathematically equal to repeating the document once per margin, zeroing
        the positions where ``margin`` is set and summing over the sequence, but
        done as one batched matmul so no ``[bs, num, seq_len, H]`` copy is built.
        """
        seq_len, bs, _ = document_states.size()
        # logical_not accepts bool or uint8 masks; 1.0 = position kept.
        keep = margin.reshape(bs, -1, seq_len).logical_not().to(document_states.dtype)
        return torch.bmm(keep, document_states.transpose(0, 1))

    def _lang_classifier_logits(self, encoder_last_hidden, src_tokens):
        """Apply the language classifier to gradient-reversed, mean-pooled encoder states."""
        hidden = encoder_last_hidden.transpose(0, 1)  # [T x B x C] -> [B x T x C]
        pad_mask = src_tokens.eq(self.encoder.dictionary.pad())  # [B x T]
        summed = hidden.masked_fill(pad_mask.unsqueeze(-1), 0).sum(dim=1)  # [B x C]
        # Divide by the number of real tokens, not T, so padding does not dilute
        # the mean; clamp guards against an all-padding row.
        num_tokens = pad_mask.logical_not().sum(dim=1, keepdim=True).clamp(min=1)
        pooled = summed / num_tokens.type_as(summed)

        reversed_pooled = GradReverse.apply(pooled, self.args.reverse_lambda)
        return self.lang_classifer(reversed_pooled)


@register_model_architecture('ext_summ', 'ext_summ_large')
def ext_summ_large_architecture(args):
    """Configure ``args`` for the ``ext_summ_large`` architecture.

    Delegates entirely to ``mbart_large_architecture``; this architecture adds
    no settings of its own.
    """
    mbart_large_architecture(args)
