# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import copy
import logging
import math
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import utils
from fairseq.file_io import PathManager
from fairseq.iterative_refinement_generator import DecoderOut
from fairseq.models import (
    register_model,
    register_model_architecture, BaseFairseqModel,
)
from fairseq.models.my_transformer import MyTransformerModel
from fairseq.models.nat import CMLMNATransformerModel
from fairseq.models.transformer import TransformerModel, Embedding, Linear, TransformerDecoder
from torch import Tensor

DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024

logger = logging.getLogger("fairseq.models.my_at_nat_transformer")
@register_model("my_at_nat_transformer")
class MyATNATTransformerModel(BaseFairseqModel):

    def __init__(self, base_args, at_args, nat_args, at_transformer, nat_transformer, mode):
        super().__init__()
        self.base_args = base_args
        self.at_args = at_args
        self.nat_args = nat_args
        self.at_transformer = at_transformer
        self.nat_transformer = nat_transformer
        self.mode = mode

        logger.info("Type of AT model: {}".format(type(self.at_transformer)))
        logger.info("Type of NAT model: {}".format(type(self.nat_transformer)))
        logger.info("Self.mode: {}".format(self.mode))


    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--activation-fn',
                            choices=utils.get_available_activation_fns(),
                            help='activation function to use')
        parser.add_argument('--dropout', type=float, metavar='D',
                            help='dropout probability')
        parser.add_argument('--attention-dropout', type=float, metavar='D',
                            help='dropout probability for attention weights')
        parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
                            help='dropout probability after activation in FFN.')
        parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained encoder embedding')
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')
        parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension for FFN')
        parser.add_argument('--encoder-layers', type=int, metavar='N',
                            help='num encoder layers')
        parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
                            help='num encoder attention heads')
        parser.add_argument('--encoder-normalize-before', action='store_true',
                            help='apply layernorm before each encoder block')
        parser.add_argument('--encoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the encoder')
        parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained decoder embedding')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--share-decoder-input-output-embed', action='store_true',
                            help='share decoder input and output embeddings')
        parser.add_argument('--share-all-embeddings', action='store_true',
                            help='share encoder, decoder and output embeddings'
                                 ' (requires shared dictionary and embed dim)')
        parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
                            help='if set, disables positional embeddings (outside self attention)')
        parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                            help='comma separated list of adaptive softmax cutoff points. '
                                 'Must be used with adaptive_loss criterion'),
        parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
                            help='sets adaptive softmax dropout for the tail projections')
        parser.add_argument('--layernorm-embedding', action='store_true',
                            help='add layernorm to embedding')
        parser.add_argument('--no-scale-embedding', action='store_true',
                            help='if True, dont scale embeddings')
        parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0,
                            help='LayerDrop probability for encoder')
        parser.add_argument('--encoder-layers-to-keep', default=None,
                            help='which layers to *keep* when pruning as a comma-separated list')
        # args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)
        parser.add_argument('--quant-noise-pq', type=float, metavar='D', default=0,
                            help='iterative PQ quantization noise at training time')
        parser.add_argument('--quant-noise-pq-block-size', type=int, metavar='D', default=8,
                            help='block size of quantization noise at training time')
        parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0,
                            help='scalar quantization noise and scalar quantization at training time')
        # args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019)
        parser.add_argument('--no-cross-attention', default=False, action='store_true',
                            help='do not perform cross-attention')
        parser.add_argument('--cross-self-attention', default=False, action='store_true',
                            help='perform cross+self-attention')
        #####################
        ### args for AT parts
        ## args for AT decoder
        parser.add_argument('--at-decoder-embed-dim', type=int, metavar='N')
        parser.add_argument('--at-decoder-ffn-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension for FFN')
        parser.add_argument('--at-decoder-layers', type=int, metavar='N',
                            help='num decoder layers')
        parser.add_argument('--at-decoder-attention-heads', type=int, metavar='N',
                            help='num decoder attention heads')
        parser.add_argument('--at-decoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the decoder')
        parser.add_argument('--at-decoder-normalize-before', action='store_true',
                            help='apply layernorm before each decoder block')
        parser.add_argument('--at-decoder-output-dim', type=int, metavar='N',
                            help='decoder output dimension (extra linear layer '
                                 'if different from decoder embed dim')
        # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
        parser.add_argument('--at-decoder-layerdrop', type=float, metavar='D', default=0,
                            help='LayerDrop probability for decoder')

        parser.add_argument('--at-decoder-layers-to-keep', default=None,
                            help='which layers to *keep* when pruning as a comma-separated list')

        # fmt: on
        ### args for NAT part
        # length prediction
        parser.add_argument("--src-embedding-copy", action="store_true",
                            help="copy encoder word embeddings as the initial input of the decoder")
        parser.add_argument("--pred-length-offset", action="store_true",
                            help="predicting the length difference between the target and source sentences")
        parser.add_argument("--sg-length-pred", action="store_true",
                            help="stop the gradients back-propagated from the length predictor")
        parser.add_argument("--length-loss-factor", type=float,
                            help="weights on the length prediction loss")
        parser.add_argument(
            "--apply-bert-init",
            action="store_true",
            help="use custom param initialization for BERT",
        )
        ## args for NAT decoder
        parser.add_argument('--nat-decoder-embed-dim', type=int, metavar='N')
        parser.add_argument('--nat-decoder-ffn-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension for FFN')
        parser.add_argument('--nat-decoder-layers', type=int, metavar='N',
                            help='num decoder layers')
        parser.add_argument('--nat-decoder-attention-heads', type=int, metavar='N',
                            help='num decoder attention heads')
        parser.add_argument('--nat-decoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the decoder')
        parser.add_argument('--nat-decoder-normalize-before', action='store_true',
                            help='apply layernorm before each decoder block')
        parser.add_argument('--nat-decoder-output-dim', type=int, metavar='N',
                            help='decoder output dimension (extra linear layer '
                                 'if different from decoder embed dim')
        # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
        parser.add_argument('--nat-decoder-layerdrop', type=float, metavar='D', default=0,
                            help='LayerDrop probability for decoder')
        parser.add_argument('--nat-decoder-layers-to-keep', default=None,
                            help='which layers to *keep* when pruning as a comma-separated list')

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        my_at_nat_transformer_arch(args)

        if args.encoder_layers_to_keep:
            args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
        if args.at_decoder_layers_to_keep:
            args.at_decoder_layers = len(args.at_decoder_layers_to_keep.split(","))
        if args.nat_decoder_layers_to_keep:
            args.nat_decoder_layers = len(args.nat_decoder_layers_to_keep.split(","))

        if getattr(args, "max_source_positions", None) is None:
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
        at_args = args_for_at_decoder(args)
        nat_args = args_for_nat_decoder(args)

        at_nat_share_encoder = args.at_nat_share_encoder
        at_nat_share_emb = args.at_nat_share_emb
        src_tgt_share_emb = args.src_tgt_share_emb
        if at_nat_share_encoder:
            if args.share_all_embeddings:
                embed_tokens = cls.build_embedding(
                    args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
                )
                at_decoder_embed_tokens = embed_tokens
                nat_decoder_embed_tokens = embed_tokens
                encoder = MyTransformerModel.build_encoder(at_args, task, embed_tokens)
                at_encoder = encoder
                nat_encoder = encoder
            elif at_nat_share_emb and not src_tgt_share_emb:
                encoder_embed_tokens = cls.build_embedding(
                    args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
                )
                decoder_embed_tokens = cls.build_embedding(
                    args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
                )
                at_decoder_embed_tokens = decoder_embed_tokens
                nat_decoder_embed_tokens = decoder_embed_tokens
                encoder = MyTransformerModel.build_encoder(at_args, task, encoder_embed_tokens)
                at_encoder = encoder
                nat_encoder = encoder
            else:
                raise NotImplementedError
        else:
            at_encoder_embed_tokens = cls.build_embedding(
                args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            nat_encoder_embed_tokens = cls.build_embedding(
                args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            at_encoder = MyTransformerModel.build_encoder(at_args, task, at_encoder_embed_tokens)
            nat_encoder = MyTransformerModel.build_encoder(nat_args, task, nat_encoder_embed_tokens)
            if src_tgt_share_emb:
                at_decoder_embed_tokens = at_encoder_embed_tokens
                nat_decoder_embed_tokens = nat_encoder_embed_tokens
            else:
                at_decoder_embed_tokens = cls.build_embedding(
                    args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
                )
                nat_decoder_embed_tokens = cls.build_embedding(
                    args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
                )

        at_decoder = MyTransformerModel.build_decoder(at_args, tgt_dict, at_decoder_embed_tokens)
        nat_decoder = CMLMNATransformerModel.build_decoder(nat_args, tgt_dict, nat_decoder_embed_tokens)
        at_transformer = MyTransformerModel(at_args, at_encoder, at_decoder)
        nat_transformer = CMLMNATransformerModel(nat_args, nat_encoder, nat_decoder)

        default_model_mode = getattr(args, "default_model_mode", None)
        model_mode = default_model_mode if default_model_mode is not None else getattr(args, "mode", "at")

        return cls(args, at_args, nat_args, at_transformer, nat_transformer, model_mode)

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()

        emb = Embedding(num_embeddings, embed_dim, padding_idx)
        # if provided, load from preloaded dictionaries
        if path:
            embed_dict = utils.parse_embedding(path)
            utils.load_embedding(embed_dict, dictionary, emb)
        return emb

    @property
    def args(self):
        return self.base_args, self.at_args, self.nat_args

    @property
    def encoder(self):
        if self.mode == 'at':
            return self.at_transformer.encoder
        elif self.mode == 'nat':
            return self.nat_transformer.encoder
        else:
            raise ValueError("Mode Error!!!")

    @property
    def decoder(self):
        if self.mode == 'at':
            return self.at_transformer.decoder
        elif self.mode == 'nat':
            return self.nat_transformer.decoder
        else:
            raise ValueError("Mode Error!!!")

    @property
    def tgt_dict(self, mode=None):
        return self.decoder.dictionary

    @property
    def bos(self, mode=None):
        return self.decoder.dictionary.bos()

    @property
    def eos(self, mode=None):
        return self.decoder.dictionary.eos()

    @property
    def pad(self, mode=None):
        return self.decoder.dictionary.pad()

    @property
    def unk(self, mode=None):
        return self.decoder.dictionary.unk()

    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        return_all_hiddens: bool = True,
        features_only: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        tgt_tokens=None,
        **kwargs
    ):
        if self.mode == 'at' or self.mode == 'joint-at':
            at_transformer_out = self.at_transformer(
                src_tokens,
                src_lengths,
                prev_output_tokens,
                return_all_hiddens,
                features_only,
                alignment_layer,
                alignment_heads
            )
            return at_transformer_out
        elif self.mode == 'nat' or self.mode == 'joint-nat':
            nat_transformer_out = self.nat_transformer(
                src_tokens,
                src_lengths,
                prev_output_tokens,
                tgt_tokens,
                **kwargs
            )
            return nat_transformer_out
        else:
            raise ValueError("Mode Error!!!")

    def get_normalized_probs_scriptable(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        if hasattr(self, "decoder"):
            return self.decoder.get_normalized_probs(net_output, log_probs, sample)
        elif torch.is_tensor(net_output):
            logits = net_output.float()
            if log_probs:
                return F.log_softmax(logits, dim=-1)
            else:
                return F.softmax(logits, dim=-1)
        raise NotImplementedError

    def max_decoder_positions(self):
        return self.decoder.max_positions()

    def forward_encoder(self, encoder_inputs):
        return self.encoder(*encoder_inputs)

    def forward_decoder(self, *args, **kwargs):
        if self.mode == 'at' or self.mode == 'joint-at':
            prev_output_tokens = args[0]
            return self.at_transformer.forward_decoder(prev_output_tokens, **kwargs)
        elif self.mode == 'nat' or self.mode == 'joint-nat':
            decoder_out = args[0]
            encoder_out = args[1]
            return self.nat_transformer.forward_decoder(decoder_out, encoder_out, **kwargs)
        else:
            raise ValueError("Mode Error!!!")

    def extract_features(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        features = self.decoder.extract_features(
            prev_output_tokens, encoder_out=encoder_out, **kwargs
        )
        return features

    def output_layer(self, features, **kwargs):
        """Project features to the default output size (typically vocabulary size)."""
        return self.decoder.output_layer(features, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model."""
        return (self.encoder.max_positions(), self.decoder.max_positions())

    @property
    def allow_length_beam(self):
        return False

    def initialize_output_tokens(self, encoder_out, src_tokens):
        # used in NAT decoding
        # length prediction
        length_tgt = self.decoder.forward_length_prediction(
            self.decoder.forward_length(normalize=True, encoder_out=encoder_out),
            encoder_out=encoder_out
        )

        max_length = length_tgt.clamp_(min=2).max()
        idx_length = utils.new_arange(src_tokens, max_length)

        initial_output_tokens = src_tokens.new_zeros(
            src_tokens.size(0), max_length
        ).fill_(self.pad)
        initial_output_tokens.masked_fill_(
            idx_length[None, :] < length_tgt[:, None], self.unk
        )
        initial_output_tokens[:, 0] = self.bos
        initial_output_tokens.scatter_(1, length_tgt[:, None] - 1, self.eos)

        initial_output_scores = initial_output_tokens.new_zeros(
            *initial_output_tokens.size()
        ).type_as(encoder_out.encoder_out)

        return DecoderOut(
            output_tokens=initial_output_tokens,
            output_scores=initial_output_scores,
            attn=None,
            step=0,
            max_step=0,
            history=None
        )

    def regenerate_length_beam(self, decoder_out, beam_size):
        output_tokens = decoder_out.output_tokens
        length_tgt = output_tokens.ne(self.pad).sum(1)
        length_tgt = length_tgt[:, None] + utils.new_arange(length_tgt, 1, beam_size) - beam_size // 2
        length_tgt = length_tgt.view(-1).clamp_(min=2)
        max_length = length_tgt.max()
        idx_length = utils.new_arange(length_tgt, max_length)

        initial_output_tokens = output_tokens.new_zeros(
            length_tgt.size(0), max_length
        ).fill_(self.pad)
        initial_output_tokens.masked_fill_(
            idx_length[None, :] < length_tgt[:, None], self.unk
        )
        initial_output_tokens[:, 0] = self.bos
        initial_output_tokens.scatter_(1, length_tgt[:, None] - 1, self.eos)

        initial_output_scores = initial_output_tokens.new_zeros(
            *initial_output_tokens.size()
        ).type_as(decoder_out.output_scores)

        return decoder_out._replace(
            output_tokens=initial_output_tokens,
            output_scores=initial_output_scores
        )

    def set_mode(self, mode):
        assert mode in ('at', 'nat')
        self.mode = mode


class SelectiveDecoder(TransformerDecoder):
    def __init__(self):
        pass


@register_model_architecture("my_at_nat_transformer", "my_at_nat_transformer")
def my_at_nat_transformer_arch(args):
    args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
    args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.dropout = getattr(args, "dropout", 0.1)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.no_cross_attention = getattr(args, "no-_cross_attention", False)
    args.cross_self_attention = getattr(args, "cross_self_attention", False)
    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
    args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
    ############################
    ### arguments for AT decoder
    args.at_decoder_ffn_embed_dim = getattr(
        args, "at_decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
    )
    args.at_decoder_layers = getattr(args, "at_decoder_layers", 6)
    args.at_decoder_attention_heads = getattr(args, "at_decoder_attention_heads", 8)
    args.at_decoder_normalize_before = getattr(args, "at_decoder_normalize_before", False)
    args.at_decoder_learned_pos = getattr(args, "at_decoder_learned_pos", False)


    args.at_decoder_output_dim = getattr(
        args, "at_decoder_output_dim", args.decoder_embed_dim
    )
    args.at_decoder_input_dim = getattr(args, "at_decoder_input_dim", args.decoder_embed_dim)
    #############################
    ### arguments for NAT decoder


def args_for_at_decoder(args):
    ## args for AT decoder
    at_args = copy.deepcopy(args)
    at_args.decoder_embed_dim = getattr(args, "at_decoder_embed_dim", args.encoder_embed_dim)
    at_args.decoder_ffn_embed_dim = getattr(args, "at_decoder_ffn_embed_dim", args.encoder_ffn_embed_dim)
    at_args.decoder_layers = getattr(args, "at_decoder_layers", 6)
    at_args.decoder_attention_heads = getattr(args, "at_decoder_attention_heads", 8)
    at_args.decoder_learned_pos = getattr(args, "at_decoder_learned_pos", False)
    at_args.decoder_normalize_before = getattr(args, "at_decoder_normalize_before", False)
    at_args.decoder_output_dim = getattr(args, "at_decoder_output_dim", args.decoder_embed_dim)
    at_args.decoder_layerdrop = getattr(args, "at_decoder_layerdrop", 0)
    at_args.decoder_layers_to_keep = getattr(args, "at_decoder_layers_to_keep", None)

    return at_args


def args_for_nat_decoder(args):
    ## args for NAT decoder
    nat_args = copy.deepcopy(args)
    nat_args.decoder_embed_dim = getattr(args, "nat_decoder_embed_dim", args.encoder_embed_dim)
    nat_args.decoder_ffn_embed_dim = getattr(args, "nat_decoder_ffn_embed_dim", args.encoder_ffn_embed_dim)
    nat_args.decoder_layers = getattr(args, "nat_decoder_layers", 6)
    nat_args.decoder_attention_heads = getattr(args, "nat_decoder_attention_heads", 8)
    nat_args.decoder_learned_pos = getattr(args, "nat_decoder_learned_pos", False)
    nat_args.decoder_normalize_before = getattr(args, "nat_decoder_normalize_before", False)
    nat_args.decoder_output_dim = getattr(args, "nat_decoder_output_dim", args.decoder_embed_dim)
    nat_args.decoder_layerdrop = getattr(args, "nat_decoder_layerdrop", 0)
    nat_args.decoder_layers_to_keep = getattr(args, "nat_decoder_layers_to_keep", None)

    return nat_args
