# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
BART: Denoising Sequence-to-Sequence Pre-training for
Natural Language Generation, Translation, and Comprehension
"""

import logging

import torch.nn as nn

from fairseq import utils
from fairseq.models import (
    register_model,
    register_model_architecture,
)

from fairseq.data import data_utils

from fairseq.models.fairseq_model import BaseFairseqModel
from fairseq.models.bart.model import BARTModel, mbart_large_architecture
from fairseq.models.transformer import base_architecture
from fairseq.models.bart import BARTClassificationHead
from fairseq.modules.transformer_sentence_encoder import init_bert_params

from fairseq.models.fairseq_encoder import EncoderOut

logger = logging.getLogger(__name__)

DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024


@register_model('rank_summ_gen')
class RankSummGen(BARTModel):
    """BART model that can additionally decode against masked encoder outputs.

    Besides the regular encoder/decoder pass, ``forward`` can re-run the
    decoder once per candidate mask in ``margin``; for each candidate the
    encoder states at the masked source positions are replaced by zeros.
    """

    def __init__(self, args, encoder, decoder):
        super().__init__(args, encoder, decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser (delegates to the parent class)."""
        super(RankSummGen, RankSummGen).add_args(parser)

    @staticmethod
    def _masked_encoder_out(encoder_out, candidate_mask):
        """Return a copy of ``encoder_out`` with masked source positions zeroed.

        Args:
            encoder_out: ``EncoderOut`` whose ``encoder_out`` field is laid
                out as T x B x C.
            candidate_mask: B x T mask; truthy entries mark the source
                positions whose encoder states are replaced by zeros.

        Returns:
            A new ``EncoderOut`` whose ``encoder_out`` is still T x B x C.
            The remaining tensor fields are cloned (not masked);
            ``src_tokens`` and ``src_lengths`` are set to ``None``.
        """
        # Build the mask directly in the T x B x 1 layout of the encoder
        # states so the result keeps the time-major layout that the
        # ``encoder_out`` field is documented to have.
        time_major_mask = candidate_mask.bool().transpose(0, 1).unsqueeze(-1)
        # Out-of-place fill: the shared encoder output is left untouched.
        masked_states = encoder_out.encoder_out.masked_fill(time_major_mask, 0)

        padding_mask = encoder_out.encoder_padding_mask
        embedding = encoder_out.encoder_embedding
        all_states = encoder_out.encoder_states

        return EncoderOut(
            encoder_out=masked_states,  # T x B x C
            encoder_padding_mask=(
                padding_mask.clone() if padding_mask is not None else None
            ),  # B x T
            encoder_embedding=(
                embedding.clone() if embedding is not None else None
            ),  # B x T x C
            # encoder_states is a list of tensors, so clone element-wise;
            # a list has no ``.clone()`` method.
            encoder_states=(
                [state.clone() for state in all_states]
                if all_states is not None else None
            ),  # List[T x B x C]
            src_tokens=None,
            src_lengths=None,
        )

    def forward(
        self, src_tokens, src_lengths, margin, prev_output_tokens=None,
        features_only=False, classification_head_name=None, **kwargs
    ):
        """Run the encoder/decoder and, optionally, one decoder pass per candidate.

        Args:
            src_tokens: source tokens, forwarded to the encoder.
            src_lengths: source lengths, forwarded to the encoder.
            margin: candidate masks, indexed as ``[bs, num, seqlen]``; only
                read when ``classification_head_name`` is given and the
                encoder-only path is not taken.
            prev_output_tokens: decoder input tokens.
            features_only: forwarded to the per-candidate decoder passes;
                forced to ``True`` when ``classification_head_name`` is set.
            classification_head_name: when not ``None``, enables the
                per-candidate decoder passes (or, on the encoder-only path,
                selects the classification head to apply).
            **kwargs: forwarded to the encoder and to every decoder call.

        Returns:
            ``(x, extra)``. On the encoder-only path (``args.only_encoder_for_cls``),
            ``x`` is the batch-first encoder output (or the classification
            head output) and ``extra`` is the raw encoder output. Otherwise
            ``x`` is the main decoder output and ``extra`` is the decoder's
            extra dict with ``'encoder_out'`` added and, when
            ``classification_head_name`` is set, ``'candidate_decoder_outs'``
            (a list with one decoder output per candidate).
        """
        if classification_head_name is not None:
            features_only = True

        encoder_out = self.encoder(
            src_tokens,
            src_lengths=src_lengths,
            **kwargs,
        )

        if getattr(self.args, "only_encoder_for_cls", False):
            x = encoder_out.encoder_out.permute(1, 0, 2)  # T x B x C -> B x T x C
            if classification_head_name is not None:
                # Pick the state at the last EOS position of every sentence.
                sentence_representation = x[
                    src_tokens.eq(self.encoder.dictionary.eos()), :
                ].view(x.size(0), -1, x.size(-1))[:, -1, :]
                x = self.classification_heads[classification_head_name](
                    sentence_representation
                )
            return x, encoder_out

        # The main decoder pass always runs with features_only=False,
        # independent of the caller's flag; only the candidate passes below
        # honor ``features_only``.
        x, extra = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=False,
            **kwargs,
        )
        extra['encoder_out'] = encoder_out.encoder_out

        if classification_head_name is not None:
            seq_len, bs, _ = encoder_out.encoder_out.size()
            candidate_decoder_outs = []
            for candidate_idx in range(margin.size(1)):
                # reshape also validates that the mask covers bs * seq_len positions.
                candidate_mask = margin[:, candidate_idx].reshape(bs, seq_len)
                candidate_encoder_out = self._masked_encoder_out(
                    encoder_out, candidate_mask
                )
                candidate_decoder_out, _ = self.decoder(
                    prev_output_tokens,
                    encoder_out=candidate_encoder_out,
                    features_only=features_only,
                    **kwargs,
                )
                candidate_decoder_outs.append(candidate_decoder_out)
            extra['candidate_decoder_outs'] = candidate_decoder_outs

        return x, extra


@register_model_architecture('rank_summ_gen', 'rank_summ_gen_large')
def rank_summ_large_architecture(args):
    """Architecture ``rank_summ_gen_large`` for the ``rank_summ_gen`` model.

    Delegates entirely to ``mbart_large_architecture``, so ``args`` ends up
    with whatever settings that function applies.
    """
    mbart_large_architecture(args)
