# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
BART: Denoising Sequence-to-Sequence Pre-training for
Natural Language Generation, Translation, and Comprehension
"""

import logging

import torch.nn as nn

from fairseq import utils
from fairseq.models import (
    register_model,
    register_model_architecture,
)

from fairseq.data import data_utils

from fairseq.models.fairseq_model import BaseFairseqModel
from fairseq.models.bart.model import BARTModel, mbart_large_architecture
from fairseq.models.transformer import base_architecture
from fairseq.modules.transformer_sentence_encoder import init_bert_params

logger = logging.getLogger(__name__)

# Fallback maximum sequence lengths, used by MultiBARTModel.build_model when
# args.max_source_positions / args.max_target_positions are missing or None.
DEFAULT_MAX_SOURCE_POSITIONS = DEFAULT_MAX_TARGET_POSITIONS = 1024


@register_model('multi_bart')
class MultiBARTModel(BARTModel):
    """BART variant with one shared encoder and two decoders.

    ``decoders[0]`` consumes the ordinary ``prev_output_tokens``, while
    ``decoders[1]`` consumes the source tokens shifted right by one position
    (see :meth:`_left_pad_src_token`). Both decoders attend to the same
    encoder output, and :meth:`forward` returns one ``(x, extra)`` pair per
    decoder.

    ``self.decoder`` is an alias of ``decoders[0]`` (presumably so inherited
    code that refers to ``self.decoder`` keeps working). Because assigning an
    ``nn.Module`` attribute registers it, ``state_dict()`` holds the primary
    decoder's tensors under both ``decoder.*`` and ``decoders.0.*``.
    """

    def __init__(self, args, encoder, decoders):
        """Assemble the model from already-built components.

        Args:
            args: parsed model hyper-parameters, stored on ``self.args``.
            encoder: the encoder shared by all decoders.
            decoders (nn.ModuleList): the decoders; index 0 is the primary
                one and is also exposed as ``self.decoder``.
        """
        # Bypass the BARTModel/TransformerModel constructors (presumably
        # because they take a single decoder) and only run the base setup.
        BaseFairseqModel.__init__(self)

        self.args = args
        self.supports_align_args = True

        self._is_generation_fast = False
        self.encoder = encoder
        self.decoders = decoders
        self.decoder = decoders[0]

        # We follow BERT's random weight initialization
        self.apply(init_bert_params)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser (same as BARTModel)."""
        super(MultiBARTModel, MultiBARTModel).add_args(parser)

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance with one encoder and two decoders."""
        # make sure all arguments are present in older models
        base_architecture(args)
        if args.encoder_layers_to_keep:
            args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
        if args.decoder_layers_to_keep:
            args.decoder_layers = len(args.decoder_layers_to_keep.split(","))

        if getattr(args, "max_source_positions", None) is None:
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise ValueError("--share-all-embeddings requires a joined dictionary")
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            encoder_embed_tokens = cls.build_embedding(
                args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = cls.build_embedding(
                args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = cls.build_embedding(
                args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
            )

        encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
        # Both decoders receive the same embedding module, so their input
        # embeddings are shared; forward() expects exactly two decoders.
        decoders = nn.ModuleList([cls.build_decoder(args, tgt_dict, decoder_embed_tokens) for _ in range(2)])

        return cls(args, encoder, decoders)

    def forward(
        self, src_tokens, src_lengths, prev_output_tokens=None,
        features_only=False, classification_head_name=None, **kwargs
    ):
        """Encode ``src_tokens`` once and run both decoders on the result.

        Args:
            src_tokens: source token ids; indexed as ``[:, j]`` by
                :meth:`_left_pad_src_token`, so presumably (batch, src_len).
            src_lengths: source lengths, forwarded to the encoder.
            prev_output_tokens: teacher-forcing input for ``decoders[0]``; its
                first column also seeds the input of ``decoders[1]``. Must not
                be None despite the default (it is indexed unconditionally).
            features_only: forwarded to both decoders.
            classification_head_name: if not None, only forces
                ``features_only=True``; NOTE: no classification head is
                applied in this method.
            **kwargs: forwarded to the encoder and to both decoders.

        Returns:
            ``((x1, extra1), (x2, extra2))``: the outputs of ``decoders[0]``
            and ``decoders[1]`` respectively.
        """
        if classification_head_name is not None:
            features_only = True

        encoder_out = self.encoder(
            src_tokens,
            src_lengths=src_lengths,
            **kwargs,
        )

        x1, extra1 = self.decoders[0](
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            **kwargs,
        )

        # The second decoder is teacher-forced on the (shifted) source tokens.
        prev_output_tokens = self._left_pad_src_token(src_tokens, prev_output_tokens)
        x2, extra2 = self.decoders[1](
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            **kwargs,
        )

        return (x1, extra1), (x2, extra2)

    def load_state_dict(self, state_dict, strict=True, args=None):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants, expanding single-decoder checkpoints if needed.

        Two checkpoint layouts are accepted:

        * a single-decoder checkpoint (e.g. pretrained (m)BART): every
          ``decoder.<name>`` entry is also copied to ``decoders.<i>.<name>``
          for each decoder, so all decoders start from the same weights;
        * a checkpoint saved from this model, which already contains
          ``decoders.<i>.*`` entries: it is used unchanged so that each
          decoder keeps its own trained weights.

        The checkpoint is merged over ``self.state_dict()``, so entries the
        checkpoint lacks keep their current values and ``strict`` cannot
        report missing keys (it can still report unexpected ones).

        Args:
            state_dict: mapping from parameter/buffer names to tensors.
            strict: forwarded to the parent ``load_state_dict``.
            args: forwarded to the parent ``load_state_dict``.
        """
        decoder_prefix = "decoder."
        # A checkpoint of this very model already has per-decoder entries;
        # remapping its ``decoder.*`` alias keys would overwrite decoders[1]
        # with decoders[0]'s weights.
        already_multi = any(key.startswith("decoders.") for key in state_dict)

        load_dict = dict(state_dict)
        if not already_multi:
            for key, value in state_dict.items():
                if not key.startswith(decoder_prefix):
                    continue
                component = key[len(decoder_prefix):]
                for index in range(len(self.decoders)):
                    load_dict["decoders.{}.{}".format(index, component)] = value

        logger.debug("MultiBART load_state_dict keys: %s", list(load_dict.keys()))

        model_dict = self.state_dict()
        model_dict.update(load_dict)

        return super().load_state_dict(model_dict, strict, args=args)

    def _left_pad_src_token(self, src_token, prev_output_tokens):
        """Build the teacher-forcing input of ``decoders[1]`` from the source.

        Returns a new tensor shaped like ``src_token`` whose first column is
        ``prev_output_tokens[:, 0]`` and whose remaining columns are
        ``src_token`` shifted right by one (the last source column is dropped).
        """
        # move eos to beginning and shift (column 0 of prev_output_tokens is
        # presumably the eos/start symbol; confirm against the task)
        dst = src_token.clone()
        dst[:, 0] = prev_output_tokens[:, 0]
        dst[:, 1:] = src_token[:, :-1]
        return dst

        
        


@register_model_architecture("multi_bart", "multi_mbart_large")
def multimbart_large_architecture(args):
    """Configure ``args`` for the ``multi_mbart_large`` architecture.

    The multi-decoder model reuses the mBART-large preset unchanged; all
    work is delegated to ``mbart_large_architecture``.
    """
    mbart_large_architecture(args)
