# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
BART: Denoising Sequence-to-Sequence Pre-training for
Natural Language Generation, Translation, and Comprehension
"""

import logging

import torch

from fairseq.models import (
    register_model,
    register_model_architecture,
)
from examples.summarization.modules.utils import GradReverse
from examples.summarization.modules.intergratedAdapter import intergratedAdapter
from examples.summarization.models.bartSummAbs import mbart_summ_large_architecture, BARTSummAbsModel

from examples.summarization.modules.TransformerDecoderwAdapter import TransformerDecoderWAdapter
from fairseq.models.transformer import TransformerDecoder


logger = logging.getLogger(__name__)

@register_model('bart_summ_abs_w_adapter')
class BARTSummAbsWAdapterModel(BARTSummAbsModel):
    """``BARTSummAbsModel`` whose encoder output is routed through an
    ``intergratedAdapter`` before it reaches the decoder and the auxiliary
    classifiers.

    The encoder, decoder, classification heads and the language / stop /
    non-stop classifiers used in :meth:`forward` are expected to be created
    by the parent class; only ``self.iadapter`` is built here.
    """

    def __init__(self, args, encoder, decoder):
        """Build the parent model, then attach the integrated adapter.

        Args:
            args: parsed model arguments. ``lang_adapter_double_input`` and
                ``freeze_iadapter`` are read here (both default to False
                when missing).
            encoder: encoder module, forwarded to the parent constructor.
            decoder: decoder module, forwarded to the parent constructor.
        """
        super().__init__(args, encoder, decoder)
        double_input = getattr(self.args, "lang_adapter_double_input", False)
        self.iadapter = intergratedAdapter(
            self.args,
            double_input=double_input
        )
        # Read defensively, like lang_adapter_double_input above, so that
        # args objects lacking the flag still construct the model.
        if getattr(args, "freeze_iadapter", False):
            self.iadapter.requires_grad_(False)

    @staticmethod
    def add_args(parser):
        """Register the parent's CLI options plus the adapter-specific ones."""
        # Zero-argument super() is unavailable in a staticmethod, hence the
        # explicit two-argument form.
        super(BARTSummAbsWAdapterModel, BARTSummAbsWAdapterModel).add_args(parser)
        # NOTE(review): the config/dir options below are not read in this
        # class; presumably intergratedAdapter or the training task consumes
        # them through args -- confirm.
        parser.add_argument(
            "--iadapter-config", type=str
        )
        parser.add_argument(
            "--iadapter-dir", type=str
        )
        parser.add_argument(
            "--trained-iadapter-dir", type=str, default=None
        )
        # When set, __init__ disables gradients on self.iadapter.
        parser.add_argument(
            "--freeze-iadapter", action="store_true"
        )
        parser.add_argument(
            "--train-iadapter", action="store_true"
        )
        parser.add_argument(
            "--lang-adapter-double-input", action="store_true",
            help="feed the output of task adapter and subtracted component as the input of lang adapter"
        )

    def _classify_from_eos(self, x, src_tokens, classification_head_name):
        """Pool ``x`` at EOS positions and apply a named classification head.

        The rows of ``x`` selected by ``src_tokens == eos`` are regrouped per
        first-dimension entry of ``x`` and the last selected row of each group
        is passed to ``self.classification_heads[classification_head_name]``.
        """
        eos_mask = src_tokens.eq(self.encoder.encoder.dictionary.eos())
        sentence_representation = x[eos_mask, :].view(
            x.size(0), -1, x.size(-1)
        )[:, -1, :]
        return self.classification_heads[classification_head_name](
            sentence_representation
        )

    def _detach_unsupervised_doc_state(self, src_tokens, doc_state):
        """Return ``doc_state`` with gradients blocked for selected rows.

        A row keeps its gradient only when the last token of the matching
        ``src_tokens`` row is one of ``self.summ_lang_ids``; every other row
        is taken from a detached copy. Forward values are unchanged.

        ``doc_state`` must be 2-D (the unpacking below enforces it).
        """
        batch_size, _ = doc_state.size()
        sup_mask = torch.zeros(
            batch_size, dtype=torch.bool, device=doc_state.device
        )
        # NOTE(review): reading the language id from the last position
        # presumably relies on left-padded sources -- confirm with the task.
        lang_token = src_tokens[:, -1]
        for token in self.summ_lang_ids:
            sup_mask = sup_mask | (lang_token == token)
        # Tensor.detach() returns a new tensor and does not modify its
        # receiver, so the detached tensor must be the one that is selected.
        return torch.where(
            sup_mask.unsqueeze(-1),
            doc_state,
            doc_state.detach()
        )

    def forward(
        self, src_tokens, src_lengths, prev_output_tokens=None,
        features_only=False, classification_head_name=None,
        **kwargs
    ):
        """Encode, adapt, optionally decode, and run the auxiliary classifiers.

        Args:
            src_tokens: source token ids, passed to the encoder.
            src_lengths: source lengths, passed to the encoder.
            prev_output_tokens: decoder input tokens; only used when the
                decoder runs.
            features_only: when True the decoder is skipped and ``x`` is None.
            classification_head_name: name of an entry in
                ``self.classification_heads``; setting it forces
                ``features_only`` to True.
            **kwargs: forwarded to the decoder.

        Returns:
            ``(x, encoder_out)`` when ``args.only_encoder_for_cls`` is set,
            where ``x`` is the permuted adapter output or, with a
            classification head, that head's output. Otherwise ``(x, extra)``
            where ``extra`` holds ``encoder_out``, ``encoder_doc_out``,
            ``fused_doc_state`` and, depending on the enabled flags,
            ``lang_cls_out``, ``nonstop_cls_out`` and ``stop_cls_out``.
        """
        if classification_head_name is not None:
            features_only = True
        encoder_out_tuple, extra_dict = self.encoder(
            src_tokens, src_lengths
        )
        encoder_out_tensor = encoder_out_tuple.encoder_out

        # Repeat the encoder's 'subtract' entry along a new leading dimension
        # sized like the first dimension of the encoder output.
        subtract = extra_dict['subtract'].unsqueeze(0).expand(
            encoder_out_tensor.size(0), -1, -1
        )
        # From here on the adapter output replaces the raw encoder output.
        encoder_out = self.iadapter(
            src_tokens,
            encoder_out_tuple,
            encoder_out_tensor,
            subtract
        )
        doc_state = self.encoder.select_doc_state(
            src_tokens,
            encoder_out.encoder_out
        )

        if getattr(self.args, "only_encoder_for_cls", False):
            # Encoder-only path: no decoder and no auxiliary classifiers.
            x = encoder_out.encoder_out.permute(1, 0, 2)
            if classification_head_name is not None:
                x = self._classify_from_eos(
                    x, src_tokens, classification_head_name
                )
            return x, encoder_out

        if not features_only:
            x, extra = self.decoder(
                prev_output_tokens,
                encoder_out=encoder_out,
                features_only=features_only,
                **kwargs,
            )
        else:
            x = None
            extra = dict()

        extra['encoder_out'] = encoder_out.encoder_out
        # Both keys expose the doc state as selected above, before any
        # detaching done for the classifiers below.
        extra['encoder_doc_out'] = doc_state
        extra['fused_doc_state'] = doc_state

        if self.args.use_lang_classifier:
            if self.args.lang_cls_unsup_lang_no_grad:
                # doc_state is rebound, so the stop / non-stop classifiers
                # below see the same partially detached tensor.
                doc_state = self._detach_unsupervised_doc_state(
                    src_tokens, doc_state
                )

            # GradReverse is presumably a gradient-reversal op scaled by
            # args.lambd -- see examples.summarization.modules.utils.
            reversed_encoder_last_hidden = GradReverse.apply(doc_state, self.args.lambd)
            lang_cls_out = self.lang_classifer(
                reversed_encoder_last_hidden,
            )
            extra['lang_cls_out'] = lang_cls_out

        if self.args.use_nonstop_classifier:
            classifier_input = self.nonstop_fc(doc_state)
            classifier_input = torch.tanh(classifier_input)
            nonstop_cls_out = self.nonstop_classifer(
                classifier_input,
            )
            extra['nonstop_cls_out'] = nonstop_cls_out

        if self.args.use_stop_classifier:
            reversed_doc_state = GradReverse.apply(doc_state, self.args.lambd)
            classifier_input = self.stop_fc(reversed_doc_state)
            classifier_input = torch.tanh(classifier_input)
            stop_cls_out = self.stop_classifer(
                classifier_input,
            )
            extra['stop_cls_out'] = stop_cls_out

        if classification_head_name is not None:
            # NOTE(review): a classification head forces features_only=True,
            # so the decoder was skipped and x is None here; this lookup then
            # raises TypeError unless only_encoder_for_cls is enabled.
            # Confirm the intended behaviour before relying on this path.
            x = self._classify_from_eos(
                x, src_tokens, classification_head_name
            )
        return x, extra

@register_model_architecture('bart_summ_abs_w_adapter', 'mbart_summ_abs_w_adapter_large')
def mbart_summ_ads_w_adapter_large_architecture(args):
    """Fill in defaults for the 'mbart_summ_abs_w_adapter_large' architecture.

    Applies ``mbart_summ_large_architecture`` first, then sets every
    adapter-related attribute that ``args`` does not already define.
    Attributes already present on ``args`` keep their values.
    """
    mbart_summ_large_architecture(args)

    # (attribute name, value used when the attribute is absent)
    adapter_defaults = (
        ("trained_iadapter_dir", None),
        ("freeze_iadapter", False),
        ("train_iadapter", False),
        ("lang_adapter_double_input", False),
        ("decoder_version", "default"),
        ("decoder_iadapter_config", None),
    )
    for attr_name, fallback in adapter_defaults:
        setattr(args, attr_name, getattr(args, attr_name, fallback))
