# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
BART: Denoising Sequence-to-Sequence Pre-training for
Natural Language Generation, Translation, and Comprehension
"""

import logging

import torch
import torch.nn as nn

from fairseq import utils
from fairseq.models import (
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import TransformerModel, transformer_vaswani_wmt_en_de_big
from fairseq.modules.transformer_sentence_encoder import init_bert_params


logger = logging.getLogger(__name__)


@register_model('adaptivesum')
class AdaptiveSumModel(TransformerModel):
    """Transformer encoder-decoder that can be warm-started from a pretrained checkpoint.

    Beyond :class:`TransformerModel`, it adds command-line options (see
    :meth:`add_args`) that :meth:`load_state_dict` uses to adapt a pretrained
    state dict to this model: drop or tile the learned positional embeddings,
    and copy selected pretrained encoder/decoder layers into chosen layer slots
    of this (possibly deeper) model.
    """

    # Rows at the top of a learned positional-embedding table that are copied
    # verbatim rather than tiled. Presumably the padding-related offset rows of
    # fairseq's learned positional embeddings (padding_idx + 1) -- TODO confirm.
    _NUM_RESERVED_POSITIONS = 2

    def __init__(self, args, encoder, decoder):
        """Build the model, then re-initialize every submodule with ``init_bert_params``."""
        super().__init__(args, encoder, decoder)

        # We follow BERT's random weight initialization
        self.apply(init_bert_params)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        super(AdaptiveSumModel, AdaptiveSumModel).add_args(parser)
        # fmt: off
        parser.add_argument('--drop-embed-position', action='store_true', help='drop embed position')
        parser.add_argument('--load-embed-position', action='store_true', help='load embed position')
        parser.add_argument('--load-enc-layer-from', type=str, help='load enc layer from the pretrain model, split by ,')
        parser.add_argument('--load-enc-layer-to', type=str, help='load enc layer to the new model, split by ,')
        parser.add_argument('--load-dec-layer-from', type=str, help='load dec layer from the pretrain model, split by ,')
        parser.add_argument('--load-dec-layer-to', type=str, help='load dec layer to the new model, split by ,')
        # fmt: on

    def forward(
        self, src_tokens, src_lengths, prev_output_tokens,
        features_only=False, **kwargs
    ):
        """Encode *src_tokens*, then decode *prev_output_tokens* against the encoder output.

        Extra keyword arguments are forwarded to both the encoder and the
        decoder. Returns whatever the decoder returns.
        """
        encoder_out = self.encoder(
            src_tokens,
            src_lengths=src_lengths,
            **kwargs,
        )
        decoder_out = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            **kwargs,
        )
        return decoder_out

    def load_state_dict(self, state_dict, strict=True, args=None):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints and
        adapts a pretrained *state_dict* to this model according to
        ``self.args``:

        * ``drop_embed_position``: ignore every ``embed_positions`` entry of the
          checkpoint (the model keeps its own values).
        * ``load_embed_position``: tile the checkpoint's encoder/decoder
          positional-embedding weights up to the size of the model's tables.
        * ``load_{enc,dec}_layer_{from,to}``: comma-separated layer indices;
          checkpoint layer ``from[i]`` is loaded into model layer ``to[i]``.
          Model layers that are not a destination keep their current weights.

        Every parameter not supplied by the adapted checkpoint keeps the model's
        current value, so missing keys never cause a ``strict`` failure.

        Args:
            state_dict: checkpoint state dict; passed through
                ``upgrade_state_dict`` first.
            strict: forwarded to the parent ``load_state_dict``.
            args: unused; kept for interface compatibility.

        Raises:
            ValueError: if a ``load_*_layer_from`` / ``load_*_layer_to`` pair
                lists a different number of layers.

        NOTE(review): the adaptation is driven by ``self.args`` and runs on
        every call, including when reloading a checkpoint of this model itself
        (resume / inference), where a non-identity layer remap would discard the
        fine-tuned layers -- confirm this is intended.
        """
        self.upgrade_state_dict(state_dict)

        model_dict = self.state_dict().copy()

        if getattr(self.args, "drop_embed_position", False):
            logger.info("Drop encoder and decoder embed position")
            load_dict = {k: v for k, v in state_dict.items() if "embed_positions" not in k}
        elif getattr(self.args, "load_embed_position", False):
            logger.info("Load encoder and decoder embed position from 300 to 1024")
            load_dict = {k: v for k, v in state_dict.items() if "embed_positions" not in k}
            for key in ("encoder.embed_positions.weight", "decoder.embed_positions.weight"):
                # Size the tiled table to the model's own table so the load cannot
                # fail on a shape mismatch; fall back to the historical 1024.
                target = model_dict.get(key)
                num_positions = (
                    target.size(0) - self._NUM_RESERVED_POSITIONS if target is not None else 1024
                )
                load_dict[key] = self._expand_embed_positions(state_dict[key], num_positions)
        else:
            load_dict = state_dict

        load_dict = self._apply_layer_mapping(load_dict, "enc", "encoder.layers")
        load_dict = self._apply_layer_mapping(load_dict, "dec", "decoder.layers")

        logger.info(load_dict.keys())
        model_dict.update(load_dict)

        return super().load_state_dict(model_dict, strict)

    @staticmethod
    def _expand_embed_positions(old_embed, num_positions=1024, num_reserved=_NUM_RESERVED_POSITIONS):
        """Tile a positional-embedding weight matrix to ``num_positions`` position rows.

        The first ``num_reserved`` rows are kept unchanged; the remaining rows
        are repeated end-to-end (as many times as needed) and truncated to
        ``num_positions``. Returns a tensor with ``num_reserved + num_positions``
        rows and the same number of columns as *old_embed*.

        Raises:
            ValueError: if *old_embed* has no rows beyond the reserved ones.
        """
        reserved = old_embed[:num_reserved, :]
        positions = old_embed[num_reserved:, :]
        num_old = positions.size(0)
        if num_old == 0:
            raise ValueError("positional embedding table has no position rows to expand")
        # Ceil division: enough copies to cover num_positions rows. For a
        # 300-row table and 1024 positions this is 4, as previously hard-coded.
        repeats = -(-num_positions // num_old)
        expanded = positions.repeat(repeats, 1)[:num_positions, :]
        return torch.cat((reserved, expanded), dim=0)

    def _apply_layer_mapping(self, load_dict, side, layer_prefix):
        """Apply the ``load_<side>_layer_from`` / ``load_<side>_layer_to`` options.

        Returns *load_dict* unchanged unless both options are set on
        ``self.args``; otherwise returns the remapped copy built by
        :meth:`_remap_layers`.

        Raises:
            ValueError: if the two comma-separated lists differ in length.
        """
        from_spec = getattr(self.args, "load_{}_layer_from".format(side), None)
        to_spec = getattr(self.args, "load_{}_layer_to".format(side), None)
        if not from_spec or not to_spec:
            if from_spec or to_spec:
                logger.warning(
                    "Ignoring load_%s_layer_from/load_%s_layer_to: both must be set", side, side
                )
            return load_dict

        logger.info("Load %s layer %s from pretrain model to %s in new model", side, from_spec, to_spec)
        # Strip so that specs like "0, 1, 2" still name layers "0", "1", "2".
        from_layer = [idx.strip() for idx in from_spec.split(",")]
        to_layer = [idx.strip() for idx in to_spec.split(",")]
        if len(from_layer) != len(to_layer):
            raise ValueError(
                "The length of load_{0}_layer_from must be the same with load_{0}_layer_to".format(side)
            )
        return AdaptiveSumModel._remap_layers(load_dict, layer_prefix, from_layer, to_layer)

    @staticmethod
    def _remap_layers(state_dict, layer_prefix, from_idx, to_idx):
        """Return a copy of *state_dict* with the ``layer_prefix`` layers remapped.

        Every key under ``<layer_prefix>.`` is dropped; then, for each pair
        ``(src, dst)`` of *from_idx* / *to_idx*, the entries of layer
        ``<layer_prefix>.<src>`` are re-inserted under ``<layer_prefix>.<dst>``.
        The same source may be listed several times to fill several slots.
        A warning is logged for a source layer with no entries.
        """
        prefix = layer_prefix + "."
        # Remove all layers under the prefix; only mapped ones are added back.
        new_dict = {k: v for k, v in state_dict.items() if not k.startswith(prefix)}
        for src, dst in zip(from_idx, to_idx):
            # Trailing "." so layer "1" does not also match layers "10", "11", ...
            src_prefix = "{}{}.".format(prefix, src)
            dst_prefix = "{}{}.".format(prefix, dst)
            matched = False
            for key, value in state_dict.items():
                if key.startswith(src_prefix):
                    new_dict[dst_prefix + key[len(src_prefix):]] = value
                    matched = True
            if not matched:
                logger.warning("No parameters found for layer %s%s in the state dict", prefix, src)
        return new_dict

@register_model_architecture("adaptivesum", "adaptivesum_base")
def adaptivesum_base(args):
    """Base AdaptiveSum architecture: defaults ``encoder_layers`` to 12.

    An explicitly provided ``encoder_layers`` is kept; all remaining defaults
    are delegated to ``transformer_vaswani_wmt_en_de_big``.
    """
    num_encoder_layers = getattr(args, "encoder_layers", 12)
    args.encoder_layers = num_encoder_layers
    transformer_vaswani_wmt_en_de_big(args)

@register_model_architecture("adaptivesum", "adaptivesum_enc_load_begin")
def adaptivesum_enc_load_begin(args):
    """12-layer encoder whose default layer map copies pretrained layers 0-5 to slots 0-5.

    Each option below is only filled in when not already present on *args*;
    the remaining defaults are delegated to ``transformer_vaswani_wmt_en_de_big``.
    """
    fallbacks = (
        ("encoder_layers", 12),
        ("load_enc_layer_from", '0,1,2,3,4,5'),
        ("load_enc_layer_to", '0,1,2,3,4,5'),
    )
    for attr, fallback in fallbacks:
        setattr(args, attr, getattr(args, attr, fallback))
    transformer_vaswani_wmt_en_de_big(args)

@register_model_architecture("adaptivesum", "adaptivesum_enc_load_end")
def adaptivesum_enc_load_end(args):
    """12-layer encoder whose default layer map copies pretrained layers 0-5 to slots 6-11.

    Each option below is only filled in when not already present on *args*;
    the remaining defaults are delegated to ``transformer_vaswani_wmt_en_de_big``.
    """
    fallbacks = (
        ("encoder_layers", 12),
        ("load_enc_layer_from", '0,1,2,3,4,5'),
        ("load_enc_layer_to", '6,7,8,9,10,11'),
    )
    for attr, fallback in fallbacks:
        setattr(args, attr, getattr(args, attr, fallback))
    transformer_vaswani_wmt_en_de_big(args)

@register_model_architecture("adaptivesum", "adaptivesum_dec_load_begin")
def adaptivesum_dec_load_begin(args):
    """12-layer decoder whose default layer map copies pretrained layers 0-5 to slots 0-5.

    Each option below is only filled in when not already present on *args*
    (so ``--decoder-layers`` and ``--load-dec-layer-from/-to`` can override it);
    the remaining defaults are delegated to ``transformer_vaswani_wmt_en_de_big``.
    """
    # These must be the *decoder* options read by the model's load_state_dict
    # (load_dec_layer_from / load_dec_layer_to), not the encoder ones.
    args.decoder_layers = getattr(args, "decoder_layers", 12)
    args.load_dec_layer_from = getattr(args, "load_dec_layer_from", '0,1,2,3,4,5')
    args.load_dec_layer_to = getattr(args, "load_dec_layer_to", '0,1,2,3,4,5')
    transformer_vaswani_wmt_en_de_big(args)