# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
BART: Denoising Sequence-to-Sequence Pre-training for
Natural Language Generation, Translation, and Comprehension
"""

import logging

import torch
import torch.nn as nn

from fairseq import utils
from fairseq.models import (
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import Embedding
from fairseq.models.bart import BARTModel, mbart_large_architecture
from examples.summarization.modules.utils import GradReverse
from examples.summarization.modules.superEncoder import superEncoder, superFlEncoder
from examples.summarization.modules.embeddedAdapter import embeddedAdapter, embeddedFlAdapter
from fairseq.models.transformer import TransformerModel

from examples.summarization.modules.TransformerDecoderwAdapter import TransformerDecoderWAdapter
from fairseq.models.transformer import TransformerDecoder

logger = logging.getLogger(__name__)

@register_model('bart_summ_abs')
class BARTSummAbsModel(BARTModel):
    def __init__(self, args, encoder, decoder):
        """Build the model, its optional auxiliary heads, and tuning overrides.

        Args:
            args: parsed options. Options read here: ``use_lang_classifier``,
                ``summ_langs``, ``unsupervised_langs``,
                ``use_nonstop_classifier``, ``use_stop_classifier``,
                ``encoder_embed_dim``, ``freeze_encoder``,
                ``tuned_encoder_layers``, ``tune_encoder_layer_norm``,
                ``freeze_decoder``, ``tuned_decoder_layers`` and
                ``tune_cross_attention``.
            encoder: encoder module; ``encoder.dictionary`` supplies language
                token ids and the vocabulary size of the (non)stop heads.
            decoder: decoder module.

        Note:
            This constructor never disables gradients on encoder/decoder
            layers; under ``freeze_encoder`` / ``freeze_decoder`` it only
            re-enables them for the selected sub-modules. Presumably the
            encoder/decoder freeze themselves when they are built -- confirm
            against their implementations.
        """
        super().__init__(args, encoder, decoder)
        self.args = args

        # NOTE: the "classifer" spelling is kept on purpose; these attribute
        # names are state-dict keys (see modify_state).
        if getattr(args, "use_lang_classifier", False):
            # One output per language, supervised and unsupervised together.
            num_langs = len(args.summ_langs) + len(args.unsupervised_langs)
            # Dictionary ids of the supervised (summarization) languages.
            self.summ_lang_ids = [
                encoder.dictionary.index(lang) for lang in args.summ_langs
            ]
            self.lang_classifer = nn.Linear(
                args.encoder_embed_dim,
                num_langs,
                bias=False,
            )

        if getattr(args, "use_nonstop_classifier", None):
            self.nonstop_fc = nn.Linear(
                args.encoder_embed_dim,
                args.encoder_embed_dim,
                bias=True,
            )
            self.nonstop_classifer = nn.Linear(
                args.encoder_embed_dim,
                len(encoder.dictionary),
                bias=False,
            )
            # The vocabulary projection is not trained; modify_state fills its
            # weight from "encoder.embed_tokens.weight".
            self.nonstop_classifer.requires_grad_(False)

        if getattr(args, "use_stop_classifier", None):
            self.stop_fc = nn.Linear(
                args.encoder_embed_dim,
                args.encoder_embed_dim,
                bias=True,
            )
            self.stop_classifer = nn.Linear(
                args.encoder_embed_dim,
                len(encoder.dictionary),
                bias=False,
            )
            # Same treatment as the non-stop head: frozen output projection.
            self.stop_classifer.requires_grad_(False)

        if getattr(args, "freeze_encoder", False):
            # Re-enable gradients for the explicitly listed encoder layers.
            if getattr(args, "tuned_encoder_layers", None):
                for i in args.tuned_encoder_layers:
                    self.encoder.layers[i].requires_grad_(True)
            # Guarded lookups: the encoder may have no final layer norm and
            # older args objects may not carry the flag at all.
            encoder_layer_norm = getattr(self.encoder, "layer_norm", None)
            if encoder_layer_norm is not None and getattr(
                args, "tune_encoder_layer_norm", False
            ):
                encoder_layer_norm.requires_grad_(True)
            encoder_posttuned_fn = getattr(self.encoder, "posttuned_fn", None)
            if callable(encoder_posttuned_fn):
                encoder_posttuned_fn(args)  # freeze / tune specific encoder modules

        if getattr(args, "freeze_decoder", False):
            # Re-enable gradients for the explicitly listed decoder layers.
            if getattr(args, "tuned_decoder_layers", None):
                for i in args.tuned_decoder_layers:
                    self.decoder.layers[i].requires_grad_(True)
            # Optionally keep every cross-attention block trainable.
            if getattr(args, "tune_cross_attention", False):
                for layer in self.decoder.layers:
                    layer.encoder_attn.requires_grad_(True)
            decoder_posttuned_fn = getattr(self.decoder, "posttuned_fn", None)
            if callable(decoder_posttuned_fn):
                decoder_posttuned_fn(args)  # freeze / tune specific decoder modules
            

    @staticmethod
    def add_args(parser):
        """Register the model-specific command-line options on ``parser``.

        The options of the parent model are added first, then the adapter,
        freezing/tuning, auxiliary-classifier and encoder/decoder-variant
        options defined by this model.

        Args:
            parser: an ``argparse``-style parser (anything exposing
                ``add_argument``).
        """
        super(BARTSummAbsModel, BARTSummAbsModel).add_args(parser)

        # --- checkpoint layer re-mapping (consumed by modify_state) ---
        parser.add_argument(
            '--adapter-layer-mapping', type=str, default=[], nargs="+",
            help="e.g. 10to0, 11to1"
        )
        parser.add_argument(
            '--encoder-layer-mapping', type=str, default=[], nargs="+",
            help="e.g. 10to12, 11to13"
        )
        parser.add_argument(
            '--adapter-copy-layernorm', action="store_true",
            help="If True, the layer norm of adapter will be initialized by that of the encoder"
        )
        parser.add_argument(
            '--wo-encoder-last-layernorm', action="store_true",
            help="If True, the last layernorm of encoder will be deleted"
        )
        parser.add_argument('--adapter-num-layer', type=int, default=0)
        parser.add_argument(
            '--lambd', type=float, default=1,
            help="coefficient passed to the gradient reversal layer"
        )

        # --- freezing / tuning ---
        parser.add_argument(
            '--freeze-adapter', action="store_true",
            help="If True, the adapter will be frozen"
        )
        parser.add_argument(
            "--tune-encoder-layer-norm", action="store_true",
            help="whether tune the final layer norm of the encoder"
        )
        # These lists name the layers that STAY trainable when the
        # encoder/decoder is frozen (see __init__), not the frozen ones.
        parser.add_argument(
            '--tuned-encoder-layers', nargs="+", type=int, default=[],
            help='indices of encoder layers kept trainable when the encoder is frozen'
        )
        parser.add_argument(
            '--tuned-decoder-layers', nargs="+", type=int, default=[],
            help='indices of decoder layers kept trainable when the decoder is frozen'
        )

        # --- encoder / adapter variants ---
        parser.add_argument(
            '--fuse-encoder-and-adapter', default=None,
            choices=["gated", "sum", "attn", "doc_gated"],
            type=str
        )
        parser.add_argument(
            '--doc-state', default="adapter",
            choices=["encoder", "adapter", "fused", "proj"], type=str
        )
        parser.add_argument(
            '--encoder-version', default="v2",
            choices=["v1", "v2", "v1_fused", "v2_fused"], type=str
        )
        parser.add_argument(
            "--v2-adapter-pre-layernorm", action="store_true",
            help="If True, layer normalization will be applied to the output of adapter; "
                 "Otherwise, the fusion results"
        )

        # --- auxiliary classifiers ---
        parser.add_argument(
            "--use-nonstop-classifier", action="store_true"
        )
        parser.add_argument(
            "--use-stop-classifier", action="store_true"
        )

        parser.add_argument("--adapter-wo-lang-hidden", action="store_true")
        parser.add_argument("--bn-encoder-output", action="store_true")
        parser.add_argument(
            "--component-config", type=str, default=None,
            help="a json file, '{tokenId (str): component_file name (str)}'"
        )
        parser.add_argument("--reconstruct-bn-encoder-output", action="store_true")
        parser.add_argument("--ln-after-proj", action="store_true")
        parser.add_argument("--tune-cross-attention", action="store_true")
        parser.add_argument("--freeze-ln-after-proj", action="store_true")
        parser.add_argument(
            "--proj-k",
            help="the number of top singular vector used for projection",
            type=int, default=6
        )
        parser.add_argument(
            "--fl-method",
            help="the method of fusing lang-agnostic and lang-specific hidden",
            default="gated", type=str,
            choices=["sum", "gated", "gated_w_lang_agnostic", "concat"]
        )
        parser.add_argument(
            "--remove-lang-fn",
            help="the method of removing lang components",
            default="subtract",
            choices=['subtract', 'concat', 'concat_wtask', 'attn_subtract']
        )

        parser.add_argument(
            "--lang-cls-input",
            help="inputs of lang classifier",
            default="adapter_output",
            choices=['adapter_input', 'adapter_output']
        )
        parser.add_argument(
            "--proj-ln-wo-affine",
            action="store_true"
        )
        parser.add_argument(
            "--postfix-tuning",
            default=None,
            type=str
        )

        # --- decoder variants (consumed by build_decoder) ---
        parser.add_argument(
            '--decoder-version', default="default",
            choices=["default", "w_adapter", "w_cross_adapter"], type=str
        )

        parser.add_argument(
            "--decoder-iadapter-config", type=str, default=None
        )

    @property
    def supported_targets(self):
        """Return the set of supported target names: only ``'self'``."""
        targets = {'self'}
        return targets

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        """Instantiate the encoder variant selected by ``args.encoder_version``.

        ``"v1"`` -> ``superEncoder``, ``"v2"`` -> ``embeddedAdapter``,
        ``"v1_fused"`` -> ``superFlEncoder``; any other value (including
        ``"v2_fused"``) -> ``embeddedFlAdapter``.
        """
        requested = args.encoder_version
        known_variants = (
            ("v1", superEncoder),
            ("v2", embeddedAdapter),
            ("v1_fused", superFlEncoder),
        )
        for version, encoder_cls in known_variants:
            if requested == version:
                return encoder_cls(args, src_dict, embed_tokens)
        # Fall-through branch: everything not matched above.
        return embeddedFlAdapter(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        """Instantiate the decoder variant selected by ``args.decoder_version``.

        A missing ``decoder_version`` is treated as ``"default"``
        (``TransformerDecoder``). ``"w_adapter"`` and ``"w_cross_adapter"``
        both build a ``TransformerDecoderWAdapter``; the latter additionally
        passes ``adapter_pos="cross"``. In every case ``no_encoder_attn`` is
        taken from ``args.no_cross_attention`` (``False`` when absent).

        Raises:
            NotImplementedError: for any other ``decoder_version``.
        """
        version = getattr(args, "decoder_version", "default")
        if version not in ("default", "w_adapter", "w_cross_adapter"):
            raise NotImplementedError

        without_cross_attn = getattr(args, "no_cross_attention", False)
        if version == "default":
            return TransformerDecoder(
                args, tgt_dict, embed_tokens, no_encoder_attn=without_cross_attn
            )

        # Both adapter variants share the class; only the cross variant
        # overrides the adapter position.
        adapter_kwargs = {"adapter_pos": "cross"} if version == "w_cross_adapter" else {}
        return TransformerDecoderWAdapter(
            args,
            tgt_dict,
            embed_tokens,
            no_encoder_attn=without_cross_attn,
            **adapter_kwargs
        )

    def forward(
        self, src_tokens, src_lengths, prev_output_tokens=None,
        features_only=False, classification_head_name=None,
        **kwargs
    ):
        """Run encoder and decoder and attach the auxiliary classifier outputs.

        Args:
            src_tokens: source token ids, forwarded to the encoder.
            src_lengths: source lengths, forwarded to the encoder.
            prev_output_tokens: decoder input tokens.
            features_only: forwarded to the decoder; forced to ``True`` when a
                classification head is requested.
            classification_head_name: if given, the matching entry of
                ``self.classification_heads`` is applied to the representation
                at the last EOS position of every sequence.
            **kwargs: forwarded unchanged to the decoder.

        Returns:
            ``(x, extra)`` where ``extra`` is the decoder's extra dict extended
            with ``encoder_out``, ``encoder_doc_out`` and, depending on the
            enabled options, ``lang_cls_out``, ``nonstop_cls_out`` and
            ``stop_cls_out``. When ``args.only_encoder_for_cls`` is set the
            decoder is skipped and ``(x, encoder_out)`` is returned instead.
        """
        if classification_head_name is not None:
            features_only = True
        encoder_out, extra = self.encoder(
            src_tokens, src_lengths
        )

        # Select the document-level state produced by the encoder. For
        # doc_state == "proj" nothing is selected and doc_state stays None.
        # NOTE(review): the auxiliary classifiers below would then receive
        # None -- confirm that combination is never configured.
        doc_state = None
        if self.args.doc_state in ["encoder", "adapter"]:
            doc_state = extra["{}_doc_state".format(self.args.doc_state)]
        elif self.args.doc_state == "fused":
            doc_state = extra["doc_state"]

        if getattr(self.args, "only_encoder_for_cls", False):
            x = encoder_out.encoder_out
            # presumably (time, batch, dim) -> (batch, time, dim); confirm
            # against the encoder's output layout.
            x = x.permute(1, 0, 2)
            if classification_head_name is not None:
                # Hidden state at the last EOS token of every sequence.
                sentence_representation = x[
                    src_tokens.eq(self.encoder.encoder.dictionary.eos()), :
                ].view(x.size(0), -1, x.size(-1))[:, -1, :]
                x = self.classification_heads[classification_head_name](
                    sentence_representation
                )
            return x, encoder_out

        x, extra = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            **kwargs,
        )

        extra['encoder_out'] = encoder_out.encoder_out
        extra['encoder_doc_out'] = doc_state

        # Option lookups use getattr so that args objects lacking a flag
        # behave as "disabled", consistent with __init__.
        if getattr(self.args, "use_lang_classifier", False):
            lang_cls_input = doc_state
            if getattr(self.args, "lang_cls_unsup_lang_no_grad", False):
                # Rows whose language token belongs to a supervised language
                # keep their gradient; all other rows are detached.
                # assumes the language id is the last source token -- TODO confirm
                lang_token = src_tokens[:, -1]
                sup_mask = torch.zeros(
                    doc_state.size(0), dtype=torch.bool, device=doc_state.device
                )
                for token in self.summ_lang_ids:
                    sup_mask = sup_mask | (lang_token == token)
                # Tensor.detach() returns a new tensor; the previous code
                # discarded that result, so no gradient was ever blocked.
                # Values are identical to doc_state, only gradient flow
                # differs. A separate variable is used so the other heads
                # below keep seeing doc_state with unchanged gradient flow.
                lang_cls_input = torch.where(
                    sup_mask.unsqueeze(-1),
                    doc_state,
                    doc_state.detach(),
                )

            # GradReverse is a project helper; presumably it reverses (and
            # scales by lambd) the gradient -- see its definition.
            reversed_encoder_last_hidden = GradReverse.apply(
                lang_cls_input, self.args.lambd
            )
            lang_cls_out = self.lang_classifer(
                reversed_encoder_last_hidden,
            )
            extra['lang_cls_out'] = lang_cls_out

        if getattr(self.args, "use_nonstop_classifier", False):
            # fc -> tanh -> frozen vocabulary projection
            classifier_input = self.nonstop_fc(doc_state)
            classifier_input = torch.tanh(classifier_input)
            nonstop_cls_out = self.nonstop_classifer(
                classifier_input,
            )
            extra['nonstop_cls_out'] = nonstop_cls_out

        if getattr(self.args, "use_stop_classifier", False):
            # Same head shape as above, but fed through GradReverse.
            reversed_doc_state = GradReverse.apply(doc_state, self.args.lambd)
            classifier_input = self.stop_fc(reversed_doc_state)
            classifier_input = torch.tanh(classifier_input)
            stop_cls_out = self.stop_classifer(
                classifier_input,
            )
            extra['stop_cls_out'] = stop_cls_out

        if classification_head_name is not None:
            # Decoder feature at the last EOS position of every sequence.
            sentence_representation = x[
                src_tokens.eq(self.encoder.encoder.dictionary.eos()), :
            ].view(x.size(0), -1, x.size(-1))[:, -1, :]
            x = self.classification_heads[classification_head_name](
                sentence_representation
            )
        return x, extra

    def postprocess(self, args, task):
        """Resize the shared token embedding after a checkpoint was loaded.

        A new embedding sized for ``task.source_dictionary`` is created, the
        rows of the current encoder embedding are copied into its leading
        rows, and the result is installed as the embedding of both encoder and
        decoder (and re-tied to the decoder output projection when
        ``share_input_output_embed`` is set).

        Args:
            args: parsed options; reads ``encoder_embed_dim``, ``tpu``,
                ``cpu``, ``fp16``, ``bf16``, ``freeze_encoder``,
                ``freeze_decoder``, ``freeze_embedding`` and
                ``freeze_position_embedding`` (missing flags count as False).
            task: task object providing ``source_dictionary``.
        """
        new_embed_tokens = Embedding(
            len(task.source_dictionary), args.encoder_embed_dim, task.source_dictionary.pad()
        )
        # Keep the already-trained rows; rows beyond the old vocabulary keep
        # the fresh initialisation from Embedding().
        old_weight = self.encoder.embed_tokens.weight.data
        old_dictionary_size = old_weight.size(0)
        new_embed_tokens.weight.data[:old_dictionary_size] = old_weight

        # Choose the target device from the options. The CUDA decision is kept
        # in a local: assigning it to ``self.cuda`` (as before) replaced the
        # inherited ``nn.Module.cuda()`` method with a bool on this instance.
        self.tpu = getattr(args, 'tpu', False)
        use_cuda = (
            torch.cuda.is_available()
            and not getattr(args, 'cpu', False)
            and not self.tpu
        )
        if use_cuda:
            self.device = torch.device('cuda')
        elif self.tpu:
            self.device = utils.get_tpu_device(args)
        else:
            self.device = torch.device('cpu')

        new_embed_tokens = new_embed_tokens.to(self.device)

        # Match the reduced-precision mode requested on the command line.
        if getattr(args, 'fp16', False):
            new_embed_tokens = new_embed_tokens.half()
        elif getattr(args, 'bf16', False):
            new_embed_tokens = new_embed_tokens.to(dtype=torch.bfloat16)

        # Encoder and decoder share one embedding module.
        self.encoder.embed_tokens = new_embed_tokens
        self.decoder.embed_tokens = self.encoder.embed_tokens

        # The embedding is shared, so freezing either side (or the embedding
        # explicitly) freezes it.
        freeze_flags = ("freeze_decoder", "freeze_encoder", "freeze_embedding")
        if any(getattr(args, flag, False) for flag in freeze_flags):
            self.encoder.embed_tokens.requires_grad_(False)

        if getattr(args, "freeze_position_embedding", False):
            self.encoder.embed_positions.requires_grad_(False)

        # Re-tie the output projection to the replaced embedding matrix.
        if self.decoder.share_input_output_embed:
            self.decoder.output_projection.weight = self.decoder.embed_tokens.weight

    def modify_state(self, model_states: dict):
        """Return a copy of ``model_states`` extended with derived parameters.

        The following entries are added (existing entries are kept as-is):

        * for every ``"<src>to<tgt>"`` item of ``args.adapter_layer_mapping``:
          ``encoder.layers.<src>.*`` copied to ``encoder.adapter.layers.<tgt>.*``;
        * for every item of ``args.encoder_layer_mapping``:
          ``encoder.layers.<src>.*`` copied to ``encoder.layers.<tgt>.*``;
        * if ``args.adapter_copy_layernorm`` and ``encoder_version == "v1"``:
          ``encoder.layer_norm.*`` copied to ``encoder.adapter.layer_norm.*``;
        * if the (non)stop classifiers are enabled:
          ``encoder.embed_tokens.weight`` copied to
          ``nonstop_classifer.weight`` / ``stop_classifer.weight``.

        Sources are always read from the original ``model_states``, so one
        mapping never feeds another.

        Args:
            model_states: mapping from parameter name to tensor.

        Returns:
            A new ``dict`` with the original and the added entries.

        Raises:
            KeyError: if a (non)stop classifier is enabled but
                ``encoder.embed_tokens.weight`` is missing.
        """
        result_states = dict(model_states)
        added_keys = []

        # (target root, list of "<src>to<tgt>" mappings), in the original order.
        layer_mappings = (
            ("encoder.adapter.layers.", getattr(self.args, "adapter_layer_mapping", None) or []),
            ("encoder.layers.", getattr(self.args, "encoder_layer_mapping", None) or []),
        )
        for tgt_root, mappings in layer_mappings:
            for mapping in mappings:
                src, tgt = mapping.split("to")
                # The trailing dot matters: without it a mapping from layer
                # "1" would also match layers "10", "11", ... and produce
                # malformed keys such as "encoder.adapter.layers.00.<...>".
                self._copy_prefixed_params(
                    model_states,
                    result_states,
                    "encoder.layers.{}.".format(src),
                    "{}{}.".format(tgt_root, tgt),
                    added_keys,
                )

        if getattr(self.args, "adapter_copy_layernorm", False) \
                and getattr(self.args, "encoder_version", None) == "v1":
            self._copy_prefixed_params(
                model_states,
                result_states,
                "encoder.layer_norm.",
                "encoder.adapter.layer_norm.",
                added_keys,
            )

        # The frozen output projections of the (non)stop classifiers start
        # from the encoder token embedding.
        classifier_targets = (
            ("use_nonstop_classifier", "nonstop_classifer.weight"),
            ("use_stop_classifier", "stop_classifer.weight"),
        )
        for flag, tgt_name in classifier_targets:
            if getattr(self.args, flag, False):
                src_param = model_states["encoder.embed_tokens.weight"]
                result_states[tgt_name] = src_param.detach().clone()
                added_keys.append(tgt_name)

        logger.info("added keys: {}\n".format(" | ".join(added_keys)))
        return result_states

    @staticmethod
    def _copy_prefixed_params(src_states, dst_states, src_prefix, tgt_prefix, added_keys):
        """Copy every ``src_prefix*`` entry of ``src_states`` to ``tgt_prefix*``.

        Each copied tensor is an independent clone (same dtype and device).
        The new names are written into ``dst_states`` and appended to
        ``added_keys``.
        """
        for key in src_states:
            if not key.startswith(src_prefix):
                continue
            tgt_name = tgt_prefix + key[len(src_prefix):]
            dst_states[tgt_name] = src_states[key].detach().clone()
            logger.info('src_name: {}, tgt_name: {}'.format(key, tgt_name))
            added_keys.append(tgt_name)

    def upgrade_state_dict_named(self, state_dict, name):
        """Adapt a loaded ``state_dict`` in place so it fits this model.

        Steps, in order:

        1. delegate to ``TransformerModel.upgrade_state_dict_named``;
        2. find the encoder embedding under ``encoder.`` or
           ``encoder.encoder.`` (asserting that one of them exists);
        3. register, or drop from ``state_dict``, classification heads found
           in the checkpoint, depending on ``args.load_checkpoint_heads`` and
           on whether name and dimensions match the current model;
        4. drop the last embedding row when the checkpoint has exactly one
           more row than the dictionary and the dictionary has no ``<mask>``;
        5. for the ``multilingual_denoising`` task with a larger dictionary,
           insert freshly initialised rows before the last (mask) row;
        6. add the current weights of classification heads that are missing
           from ``state_dict``.
        """
        TransformerModel.upgrade_state_dict_named(self, state_dict, name)

        # Locate the encoder embedding; the encoder may be nested one level.
        if "encoder.embed_tokens.weight" in state_dict:
            encoder_prefix = "encoder"
        elif "encoder.encoder.embed_tokens.weight" in state_dict:
            encoder_prefix = "encoder.encoder"
        else:
            encoder_prefix = None
        assert encoder_prefix is not None, "either encoder.embed_tokens.weight or encoder.encoder.embed_tokens.weight should be in state_dict"

        prefix = name + '.' if name != '' else ''
        heads_prefix = prefix + 'classification_heads.'
        if hasattr(self, 'classification_heads'):
            current_head_names = self.classification_heads.keys()
        else:
            current_head_names = []

        # Handle new classification heads present in the state dict.
        keys_to_delete = []
        for k in state_dict.keys():
            if not k.startswith(heads_prefix):
                continue

            head_name = k[len(heads_prefix):].split('.')[0]
            head_key = heads_prefix + head_name
            num_classes = state_dict[head_key + '.out_proj.weight'].size(0)
            inner_dim = state_dict[head_key + '.dense.weight'].size(0)

            if getattr(self.args, 'load_checkpoint_heads', False):
                # Create heads that exist only in the checkpoint.
                if head_name not in current_head_names:
                    self.register_classification_head(head_name, num_classes, inner_dim)
                continue

            if head_name not in current_head_names:
                logger.warning(
                    'deleting classification head ({}) from checkpoint '
                    'not present in current model: {}'.format(head_name, k)
                )
                keys_to_delete.append(k)
                continue

            if (
                num_classes != self.classification_heads[head_name].out_proj.out_features
                or inner_dim != self.classification_heads[head_name].dense.out_features
            ):
                logger.warning(
                    'deleting classification head ({}) from checkpoint '
                    'with different dimensions than current model: {}'.format(head_name, k)
                )
                keys_to_delete.append(k)
        for k in keys_to_delete:
            del state_dict[k]

        embed_key = '{}.embed_tokens.weight'.format(encoder_prefix)
        loaded_dict_size = state_dict[embed_key].size(0)
        current_dict_size = len(self.encoder.dictionary)

        # When finetuning on translation task, remove last row of
        # embedding matrix that corresponds to mask_idx token.
        if loaded_dict_size == current_dict_size + 1 and '<mask>' not in self.encoder.dictionary:
            vocab_sized_keys = (
                embed_key,
                'decoder.embed_tokens.weight',
                '{}.output_projection.weight'.format(encoder_prefix),
                'decoder.output_projection.weight',
            )
            for key in vocab_sized_keys:
                if key in state_dict:
                    state_dict[key] = state_dict[key][:-1, :]

        # When continued pretraining on new set of languages for mbart,
        # add extra lang embeddings at the end of embed_tokens.
        # Note: newly added languages are assumed to have been added at the end.
        if self.args.task == 'multilingual_denoising' and loaded_dict_size < current_dict_size:
            logger.info(
                "Adding extra language embeddings not found in pretrained model for "
                "continued pretraining of MBART on new set of languages."
            )
            loaded_embed = state_dict[embed_key]
            loaded_mask_token_embedding = loaded_embed[-1, :]

            num_langids_to_add = current_dict_size - loaded_dict_size
            embed_dim = loaded_embed.size(1)

            # New rows: normal init, then cast to the checkpoint's dtype.
            new_lang_embed_to_add = torch.zeros(num_langids_to_add, embed_dim)
            nn.init.normal_(
                new_lang_embed_to_add,
                mean=0,
                std=embed_dim ** -0.5
            )
            new_lang_embed_to_add = new_lang_embed_to_add.to(
                dtype=loaded_embed.dtype,
            )

            # Layout: old rows (without mask) + new language rows + mask row.
            mask_row = loaded_mask_token_embedding.unsqueeze(0)
            state_dict[embed_key] = torch.cat([
                loaded_embed[:loaded_dict_size-1, :],
                new_lang_embed_to_add,
                mask_row,
            ])
            state_dict['decoder.embed_tokens.weight'] = torch.cat([
                state_dict['decoder.embed_tokens.weight'][:loaded_dict_size-1, :],
                new_lang_embed_to_add,
                mask_row,
            ])

        # Copy any newly-added classification heads into the state dict
        # with their current weights.
        if hasattr(self, 'classification_heads'):
            for k, v in self.classification_heads.state_dict().items():
                full_key = heads_prefix + k
                if full_key not in state_dict:
                    logger.info('Overwriting ' + full_key)
                    state_dict[full_key] = v

@register_model_architecture('bart_summ_abs', 'mbart_summ_abs_large')
def mbart_summ_large_architecture(args):
    """Back-fill defaults for the ``bart_summ_abs`` options, then apply mBART-large.

    Every option already present on ``args`` keeps its value; only missing
    attributes receive the default listed below. Afterwards
    ``mbart_large_architecture`` fills in the base model hyper-parameters.

    Args:
        args: the parsed-options namespace, modified in place.
    """
    # Built inside the function so each call gets fresh list objects.
    defaults = (
        ("freeze_encoder", False),
        ("freeze_decoder", False),
        ("freeze_embedding", False),
        ("freezed_encoder_layers", []),
        ("freezed_decoder_layers", []),
        ("adapter_copy_layernorm", False),
        ("freeze_adapter", False),
        ("wo_encoder_last_layernorm", False),
        ("tuned_encoder_layers", []),
        ("tuned_decoder_layers", []),
        ("tune_encoder_layer_norm", False),
        ("fuse_encoder_and_adapter", None),
        ("doc_state", "adapter"),
        ("v2_adapter_pre_layernorm", False),
        ("use_nonstop_classifier", False),
        ("use_stop_classifier", False),
        ("adapter_wo_lang_hidden", False),
        ("bn_encoder_output", False),
        ("reconstruct_bn_encoder_output", False),
        ("component_config", None),
        ("ln_after_proj", None),
        ("tune_cross_attention", False),
        ("freeze_ln_after_proj", False),
        ("proj_k", 6),
        ("remove_lang_fn", "subtract"),
        ("lang_cls_input", "adapter_output"),
        ("proj_ln_wo_affine", False),
        ("postfix_tuning", False),
        # Options the model reads without a guard (build_encoder, forward,
        # modify_state) but which previously had no fallback here. Values
        # mirror the argparse defaults declared in add_args.
        ("adapter_layer_mapping", []),
        ("encoder_layer_mapping", []),
        ("adapter_num_layer", 0),
        ("lambd", 1),
        ("encoder_version", "v2"),
        ("decoder_version", "default"),
        ("decoder_iadapter_config", None),
    )
    for option, fallback in defaults:
        setattr(args, option, getattr(args, option, fallback))

    # The command-line flag is --fl-method, but the value is stored under a
    # different attribute name; it is always re-derived from ``fl_method``.
    args.fuse_lang_agnostic_method = getattr(args, "fl_method", "gated")

    mbart_large_architecture(args)
