# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
BART: Denoising Sequence-to-Sequence Pre-training for
Natural Language Generation, Translation, and Comprehension
"""

import logging

from fairseq.models import (
    register_model,
    register_model_architecture,
)
from examples.summarization.models.bartSummAbsWAdapter import BARTSummAbsWAdapterModel, mbart_summ_ads_w_adapter_large_architecture
from examples.summarization.modules.intergratedAdapter import intergratedAdapter

logger = logging.getLogger(__name__)

@register_model('bart_summ_abs_lang_w_adapter')
class BARTSummAbsLangWAdapterModel(BARTSummAbsWAdapterModel):
    """BART summarization model that passes encoder states through an integrated adapter.

    Extends ``BARTSummAbsWAdapterModel``. The encoder output is passed through
    ``intergratedAdapter`` (``self.iadapter``). The adapter output is what the
    decoder, document-state selection and classification heads consume.
    """

    def __init__(self, args, encoder, decoder):
        super().__init__(args, encoder, decoder)
        # Forwarded to intergratedAdapter as ``double_input``; its exact
        # meaning is defined by that module.
        double_input = getattr(self.args, "lang_adapter_double_input", False)
        self.iadapter = intergratedAdapter(
            self.args,
            double_input=double_input
        )
        # getattr with a default keeps configs that predate this flag usable,
        # consistent with the other optional flags read in this class.
        if getattr(args, "freeze_iadapter", False):
            self.iadapter.requires_grad_(False)

    def forward(
        self, src_tokens, src_lengths, prev_output_tokens=None,
        features_only=False, classification_head_name=None,
        **kwargs
    ):
        """Run encoder -> integrated adapter -> (decoder | classification head).

        Args:
            src_tokens: source token ids; also used to locate EOS positions
                for sentence-level classification.
            src_lengths: source lengths, forwarded to the encoder.
            prev_output_tokens: decoder input tokens. Not used when the
                decoder is skipped (``features_only`` without a classification
                head, or the ``only_encoder_for_cls`` option).
            features_only: when True and no classification head is requested,
                the decoder is skipped and the returned ``x`` is ``None``.
            classification_head_name: key into ``self.classification_heads``;
                the head is applied to the feature at each row's last EOS.
            **kwargs: forwarded to the decoder.

        Returns:
            ``(x, extra)``. Normally ``extra`` is a dict holding the decoder
            extras plus ``'encoder_out'``, ``'encoder_doc_out'`` and, if
            ``use_lang_classifier`` is set, ``'lang_cls_out'`` and
            ``'lang_cls_out_adv'``. With ``only_encoder_for_cls`` set,
            ``extra`` is the adapter output object itself.

        Raises:
            ValueError: if a classification head is requested and the rows of
                ``src_tokens`` contain different numbers of EOS tokens.
        """
        if classification_head_name is not None:
            features_only = True
        encoder_out_tuple, encoder_extra = self.encoder(
            src_tokens, src_lengths
        )
        encoder_out_tensor = encoder_out_tuple.encoder_out

        # A new leading dim is added to the encoder's 'subtract' tensor and
        # expanded (a view, no copy) to dim 0 of the encoder states.
        # NOTE(review): assumes 'subtract' is (B, C) so the result is (T, B, C).
        subtract = encoder_extra['subtract'].unsqueeze(0).expand(
            encoder_out_tensor.size(0), -1, -1
        )
        adapter_out = self.iadapter(
            src_tokens,
            encoder_out_tuple,
            encoder_out_tensor,
            subtract,
        )
        doc_state = self.encoder.select_doc_state(
            src_tokens,
            adapter_out.encoder_out
        )
        encoder_out = adapter_out

        if getattr(self.args, "only_encoder_for_cls", False):
            # Swap the first two axes so positions line up with src_tokens
            # for the EOS mask lookup.
            x = encoder_out.encoder_out.permute(1, 0, 2)
            if classification_head_name is not None:
                x = self._classify_at_last_eos(
                    x, src_tokens, classification_head_name
                )
            return x, encoder_out

        # A classification head needs decoder features. The decoder must
        # therefore run even though features_only was forced to True above;
        # skipping it would leave x = None for the head lookup below.
        if not features_only or classification_head_name is not None:
            x, extra = self.decoder(
                prev_output_tokens,
                encoder_out=encoder_out,
                features_only=features_only,
                **kwargs,
            )
        else:
            x = None
            extra = dict()

        extra['encoder_out'] = encoder_out.encoder_out
        extra['encoder_doc_out'] = doc_state
        if getattr(self.args, "use_lang_classifier", False):
            lang_cls_out, lang_cls_out_adv = self._lang_classifier_outputs(
                encoder_extra['lang_cls_input']
            )
            extra['lang_cls_out'] = lang_cls_out
            extra['lang_cls_out_adv'] = lang_cls_out_adv

        if classification_head_name is not None:
            x = self._classify_at_last_eos(
                x, src_tokens, classification_head_name
            )
        return x, extra

    def _classify_at_last_eos(self, x, src_tokens, classification_head_name):
        """Apply a classification head to the feature at each row's last EOS.

        ``x`` must be batch-first and position-aligned with ``src_tokens`` so
        that the EOS mask selects the matching features.

        Raises:
            ValueError: if rows contain different numbers of EOS tokens.
                The ``view`` below would otherwise mix features across rows.
        """
        eos_mask = src_tokens.eq(self.encoder.encoder.dictionary.eos())
        if eos_mask.sum(1).unique().numel() > 1:
            raise ValueError(
                "All examples must have the same number of <eos> tokens."
            )
        sentence_representation = x[eos_mask, :].view(
            x.size(0), -1, x.size(-1)
        )[:, -1, :]
        return self.classification_heads[classification_head_name](
            sentence_representation
        )

    def _lang_classifier_outputs(self, lang_cls_input):
        """Return ``(lang_cls_out, lang_cls_out_adv)`` from ``self.lang_classifer``.

        ``lang_cls_out`` uses a detached input, so its gradients reach only
        the classifier's parameters. ``lang_cls_out_adv`` is computed while
        the classifier's parameters have ``requires_grad`` disabled, so its
        gradients flow only into ``lang_cls_input``. Each parameter's previous
        ``requires_grad`` flag is restored afterwards, even if the call raises,
        so an already-frozen classifier stays frozen.
        """
        # NOTE(review): attribute name 'lang_classifer' (sic) is presumably
        # defined by the parent class; kept as-is.
        classifier = self.lang_classifer
        lang_cls_out = classifier(lang_cls_input.detach())
        saved_flags = [p.requires_grad for p in classifier.parameters()]
        classifier.requires_grad_(False)
        try:
            lang_cls_out_adv = classifier(lang_cls_input)
        finally:
            for param, flag in zip(classifier.parameters(), saved_flags):
                param.requires_grad_(flag)
        return lang_cls_out, lang_cls_out_adv

@register_model_architecture('bart_summ_abs_lang_w_adapter', 'mbart_summ_abs_lang_w_adapter_large')
def mbart_summ_ads_lang_w_adapter_large_architecture(args):
    """Configure ``args`` for the ``mbart_summ_abs_lang_w_adapter_large`` architecture.

    This variant defines no hyperparameters of its own. It delegates entirely
    to the base adapter architecture function and returns ``None``.
    """
    apply_base_defaults = mbart_summ_ads_w_adapter_large_architecture
    apply_base_defaults(args)
