# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
BART: Denoising Sequence-to-Sequence Pre-training for
Natural Language Generation, Translation, and Comprehension
"""

import logging

import torch
import torch.nn as nn

from fairseq import utils
from fairseq.models import (
    register_model,
    register_model_architecture,
)
from examples.summarization.models.bartSummAbs import BARTSummAbsModel, mbart_summ_large_architecture

logger = logging.getLogger(__name__)

@register_model('bart_summ_abs_lang')
class BARTSummAbsLangModel(BARTSummAbsModel):
    """Summarization BART model whose forward pass also exposes language-classifier outputs.

    Extends ``BARTSummAbsModel``: ``forward`` adds the encoder output, the
    selected document state and (when ``args.use_lang_classifier`` is set)
    two language-classifier predictions to the decoder's ``extra`` dict.
    The classifier itself (``self.lang_classifer``) is not created here;
    it is expected to be provided by the parent class.
    """

    def __init__(self, args, encoder, decoder):
        super().__init__(args, encoder, decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser (only the parent's arguments)."""
        super(BARTSummAbsLangModel, BARTSummAbsLangModel).add_args(parser)

    @property
    def supported_targets(self):
        return {'self'}

    def _classify_from_last_eos(self, x, src_tokens, classification_head_name):
        """Apply a classification head to the features at each row's last EOS token.

        ``x`` is indexed with the EOS mask of ``src_tokens`` and reshaped to
        ``(x.size(0), -1, x.size(-1))``, so every row must contain the same
        number of EOS tokens for the ``view`` to succeed.
        """
        eos_mask = src_tokens.eq(self.encoder.encoder.dictionary.eos())
        sentence_representation = x[eos_mask, :].view(
            x.size(0), -1, x.size(-1)
        )[:, -1, :]
        return self.classification_heads[classification_head_name](
            sentence_representation
        )

    def _lang_classifier_outputs(self, lang_cls_input):
        """Run the language classifier twice and return ``(plain_out, adversarial_out)``.

        * ``plain_out`` is computed on a detached copy of the input, so its
          gradient reaches only the classifier parameters.
        * ``adversarial_out`` is computed on the attached input while the
          classifier parameters are temporarily frozen, so its gradient
          reaches only whatever produced ``lang_cls_input``.
        """
        classifier = self.lang_classifer
        plain_out = classifier(lang_cls_input.detach())

        # Remember each parameter's flag so that parameters which were frozen
        # on purpose stay frozen, and restore in ``finally`` so an exception in
        # the forward call cannot leave the classifier permanently frozen.
        params = list(classifier.parameters())
        saved_flags = [p.requires_grad for p in params]
        classifier.requires_grad_(False)
        try:
            adversarial_out = classifier(lang_cls_input)
        finally:
            for param, flag in zip(params, saved_flags):
                param.requires_grad_(flag)
        return plain_out, adversarial_out

    def forward(
        self, src_tokens, src_lengths, prev_output_tokens=None,
        features_only=False, classification_head_name=None,
        **kwargs
    ):
        """Encode the source, optionally decode, and collect auxiliary outputs.

        Args:
            src_tokens: source token ids passed to the encoder.
            src_lengths: source lengths passed to the encoder.
            prev_output_tokens: decoder input tokens (unused when
                ``args.only_encoder_for_cls`` is set).
            features_only: forwarded to the decoder; forced to ``True`` when
                a classification head is requested.
            classification_head_name: if given, the named head in
                ``self.classification_heads`` is applied to the features at
                the last EOS position of each source row.
            **kwargs: forwarded unchanged to the decoder.

        Returns:
            ``(x, encoder_out)`` when ``args.only_encoder_for_cls`` is set,
            where ``x`` is the permuted encoder output or the classification
            head output. Otherwise ``(x, extra)`` where ``extra`` is the
            decoder's extra dict augmented with ``'encoder_out'``,
            ``'encoder_doc_out'`` and, if ``args.use_lang_classifier`` is
            set, ``'lang_cls_out'`` and ``'lang_cls_out_adv'``.
        """
        if classification_head_name is not None:
            features_only = True
        encoder_out, encoder_extra = self.encoder(
            src_tokens, src_lengths
        )

        # Pick the document-level state matching the configured mode; any
        # other value of args.doc_state leaves it as None.
        doc_state = None
        if self.args.doc_state in ["encoder", "adapter"]:
            doc_state = encoder_extra["{}_doc_state".format(self.args.doc_state)]
        elif self.args.doc_state == "fused":
            doc_state = encoder_extra["doc_state"]

        if getattr(self.args, "only_encoder_for_cls", False):
            # Encoder-only path: the decoder is skipped entirely.
            # Presumably (T, B, C) -> (B, T, C); TODO confirm the encoder layout.
            x = encoder_out.encoder_out.permute(1, 0, 2)
            if classification_head_name is not None:
                x = self._classify_from_last_eos(
                    x, src_tokens, classification_head_name
                )
            return x, encoder_out

        x, extra = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            **kwargs,
        )

        extra['encoder_out'] = encoder_out.encoder_out
        extra['encoder_doc_out'] = doc_state
        if self.args.use_lang_classifier:
            plain_out, adversarial_out = self._lang_classifier_outputs(
                encoder_extra['lang_cls_input']
            )
            extra['lang_cls_out'] = plain_out
            extra['lang_cls_out_adv'] = adversarial_out

        if classification_head_name is not None:
            x = self._classify_from_last_eos(
                x, src_tokens, classification_head_name
            )
        return x, extra

@register_model_architecture('bart_summ_abs_lang', 'mbart_summ_abs_lang_large')
def mbart_summ_large_lang_architecture(args):
    """Architecture ``mbart_summ_abs_lang_large`` for the ``bart_summ_abs_lang`` model.

    Delegates entirely to ``mbart_summ_large_architecture`` (imported from
    ``bartSummAbs``), which is called with ``args`` unchanged; no values are
    set or overridden here.
    """
    mbart_summ_large_architecture(args)
