# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
BART: Denoising Sequence-to-Sequence Pre-training for
Natural Language Generation, Translation, and Comprehension
"""

import logging
import copy

import torch
import torch.nn.functional as F
import torch.nn as nn

from fairseq import utils
from fairseq.models import (
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import TransformerModel, TransformerDecoder
from fairseq.modules.transformer_sentence_encoder import init_bert_params

from .hub_interface import BARTHubInterface


def insert_lang_code(res, left_pad, lang_code):
    """Insert a column filled with ``lang_code`` into a batch of token ids.

    Args:
        res: 2-D tensor of token ids, ``(batch, seq_len)``.
        left_pad: for sequences longer than one token, ``True`` inserts the
            code just before the last column and ``False`` just after the
            first column.
        lang_code: value written into every row of the new column.

    Returns:
        A new ``(batch, seq_len + 1)`` tensor with ``res``'s dtype and device.

    Sequences of length 0 or 1 are special-cased: the code is appended when
    ``left_pad`` is true and prepended otherwise.
    NOTE(review): for length 1 this is the opposite placement from the
    general rule above; preserved as-is -- confirm it is intended.
    """
    batch_size, seq_len = res.size(0), res.size(1)
    tag_words = res.new_full((batch_size, 1), lang_code)
    if seq_len <= 1:
        # Too short to split around a boundary column; this also avoids
        # indexing res[:, -1] on an empty sequence.
        pieces = [res, tag_words] if left_pad else [tag_words, res]
    elif left_pad:
        pieces = [res[:, :-1], tag_words, res[:, -1:]]
    else:
        pieces = [res[:, :1], tag_words, res[:, 1:]]
    return torch.cat(pieces, dim=1)
# Logger scoped to this module's import name.
logger = logging.getLogger(name=__name__)


@register_model('bart')
class BARTModel(TransformerModel):
    """BART encoder-decoder extended with a style discriminator.

    On top of the fairseq ``TransformerModel`` this class carries:

    * a classification head registered under ``'formal'`` in ``__init__``
      (a ``BARTClassificationHead``), and
    * two decoder-only ``TransformerDecoder`` language models, ``lm_src`` and
      ``lm_tgt``, created by :meth:`add_disc_arch` and used by
      :meth:`forward` to score token sequences.
    """

    @classmethod
    def hub_models(cls):
        return {
            'bart.large': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz',
            'bart.large.mnli': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz',
            'bart.large.cnn': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz',
        }

    def __init__(self, args, encoder, decoder):
        super().__init__(args, encoder, decoder)

        # We follow BERT's random weight initialization
        self.apply(init_bert_params)

        self.classification_heads = nn.ModuleDict()
        self.register_classification_head('formal', num_classes=2, inner_dim=1024)

    def add_disc_arch(self, args, target_dict):
        """Create the discriminator language models ``lm_src`` and ``lm_tgt``.

        Both are ``TransformerDecoder`` instances without encoder attention,
        built from ``args``/``target_dict`` and given the main decoder's
        ``embed_tokens`` module (presumably shared rather than copied --
        confirm against ``TransformerDecoder``).  Every parameter they expose
        is marked trainable.

        The new modules are placed on the device of those embeddings, so the
        model is not split across devices and CPU-only hosts work.
        """
        embed_tokens = self.decoder.embed_tokens
        device = embed_tokens.weight.device
        self.lm_src = TransformerDecoder(
            args, target_dict, embed_tokens, no_encoder_attn=True,
        ).to(device)
        self.lm_tgt = TransformerDecoder(
            args, target_dict, embed_tokens, no_encoder_attn=True,
        ).to(device)
        for p in self.lm_src.parameters():
            p.requires_grad = True
        for p in self.lm_tgt.parameters():
            p.requires_grad = True

    @staticmethod
    def add_args(parser):
        super(BARTModel, BARTModel).add_args(parser)
        parser.add_argument(
            '--pooler-dropout', type=float, metavar='D',
            help='dropout probability in the masked_lm pooler layers'
        )
        parser.add_argument(
            '--pooler-activation-fn',
            choices=utils.get_available_activation_fns(),
            help='activation function to use for pooler layer'
        )

    @property
    def supported_targets(self):
        return {'self'}

    def fix_discriminator(self):
        """Freeze the ``'formal'`` head and both discriminator LMs.

        Requires :meth:`add_disc_arch` to have been called.
        NOTE(review): if the LMs hold the shared ``embed_tokens`` module, this
        also freezes the main model's token embeddings -- confirm intended.
        """
        classification_head_name = 'formal'
        for p in self.classification_heads[classification_head_name].parameters():
            p.requires_grad = False
        for p in self.lm_src.parameters():
            p.requires_grad = False
        for p in self.lm_tgt.parameters():
            p.requires_grad = False

    def _lm_scores(self, prev_output_tokens, tokens):
        """Score ``tokens`` under ``lm_src`` and ``lm_tgt``.

        Both LMs run on ``prev_output_tokens``.  At every position, the output
        entry indexed by the matching element of ``tokens`` is gathered along
        dim 2, and the gathered values are multiplied over dim 1.  Returns
        ``torch.cat((src_score, tgt_score), 1)`` -- ``(batch, 2)`` when the LM
        outputs are ``(batch, time, vocab)``.

        NOTE(review): the LM outputs are presumably unnormalized logits, so
        this product is not a sequence probability -- confirm intended.
        """
        # lm_tgt runs before lm_src to keep the original call (dropout RNG) order.
        tgt_out = self.lm_tgt(prev_output_tokens)[0]
        src_out = self.lm_src(prev_output_tokens)[0]
        index = tokens.unsqueeze(2)
        tgt_score = torch.prod(torch.gather(tgt_out, 2, index), 1)
        src_score = torch.prod(torch.gather(src_out, 2, index), 1)
        return torch.cat((src_score, tgt_score), 1)

    def forward(
        self, src_tokens, src_lengths, prev_output_tokens,
        tgt_lang=1, features_only=False, classification_head_name=None, **kwargs
    ):
        """Run the generator and/or score tokens with the discriminator LMs.

        Args:
            src_tokens: encoder input ids.  When ``features_only`` is set they
                are instead the tokens to score; they are gathered against
                the LM outputs, so (TODO confirm) they must have the same
                ``(batch, len)`` shape as ``prev_output_tokens``.
            src_lengths: forwarded to the encoder.
            prev_output_tokens: decoder input, fed to the decoder and to both
                LMs.
            tgt_lang: ``1`` inserts token 50262 into the source, any other
                value inserts 50263.
            features_only: skip the encoder/decoder and only score
                ``src_tokens``.
            classification_head_name: accepted for interface compatibility;
                ignored (no classification head is applied here).
            **kwargs: forwarded to the encoder and the decoder.

        Returns:
            ``(y, extra, scores)``.  ``y`` is the decoder output-layer result
            and ``extra`` the decoder's second return value; both are ``None``
            when ``features_only``.  ``scores`` comes from
            :meth:`_lm_scores`.
        """
        if features_only:
            # Score the given tokens directly; encoder/decoder are skipped.
            return None, None, self._lm_scores(prev_output_tokens, src_tokens)

        # Tag the source with the requested output style before encoding.
        # NOTE(review): 50262/50263 are presumably extra symbols appended to
        # the dictionary for this model -- confirm against the task dictionary.
        add_word = 50262 if tgt_lang == 1 else 50263
        encoder_out = self.encoder(
            insert_lang_code(src_tokens, True, add_word),
            src_lengths=src_lengths,
            **kwargs,
        )
        x, extra = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=True,
            **kwargs,
        )
        y = self.decoder.output_layer(x)
        # The argmax tokens of the generator output are what get scored.
        hyp_tokens = torch.argmax(y, 2)
        return y, extra, self._lm_scores(prev_output_tokens, hyp_tokens)

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file='model.pt',
        data_name_or_path='.',
        bpe='gpt2',
        **kwargs,
    ):
        from fairseq import hub_utils
        x = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            bpe=bpe,
            load_checkpoint_heads=True,
            **kwargs,
        )
        return BARTHubInterface(x['args'], x['task'], x['models'][0])

    def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
        """Register (or replace) a classification head called ``name``.

        The head is a ``BARTClassificationHead`` built from
        ``args.encoder_embed_dim``, ``inner_dim`` (falling back to
        ``encoder_embed_dim``), ``num_classes`` and the pooler args.
        Re-registering an existing name logs a warning when ``num_classes``
        differs from the existing head's output size.
        """
        logger.info("Registering classification head: {0}".format(name))
        if name in self.classification_heads:
            # BARTClassificationHead has no out_proj/dense layers; its output
            # layer is the last module of disc_fc.
            prev_num_classes = self.classification_heads[name].disc_fc[-1].out_features
            if num_classes != prev_num_classes:
                logger.warning(
                    're-registering head "{}" with num_classes {} (prev: {})'.format(
                        name, num_classes, prev_num_classes
                    )
                )
        self.classification_heads[name] = BARTClassificationHead(
            self.args.encoder_embed_dim,
            inner_dim or self.args.encoder_embed_dim,
            num_classes,
            self.args.pooler_activation_fn,
            self.args.pooler_dropout,
        )

    def upgrade_state_dict_named(self, state_dict, name):
        """Reconcile a checkpoint ``state_dict`` with this model, in place.

        * Checkpoint heads missing from the model are registered (when
          ``args.load_checkpoint_heads``) or dropped with a warning.
        * Checkpoint heads whose class count differs from the model's head of
          the same name are dropped with a warning.
        * A trailing embedding row is trimmed when the checkpoint has one more
          row than the dictionary and the dictionary has no ``<mask>``.
        * Model heads missing from the checkpoint are filled in with the
          model's current weights.
        """
        super().upgrade_state_dict_named(state_dict, name)

        prefix = name + '.' if name != '' else ''
        head_prefix = prefix + 'classification_heads.'
        current_head_names = [] if not hasattr(self, 'classification_heads') else \
            self.classification_heads.keys()

        # Handle new classification heads present in the state dict.
        keys_to_delete = []
        for k in state_dict.keys():
            if not k.startswith(head_prefix):
                continue

            head_name = k[len(head_prefix):].split('.')[0]
            # The head's output layer is disc_fc[1]; fall back to 2 classes
            # when the checkpoint does not contain it.
            out_weight = state_dict.get(head_prefix + head_name + '.disc_fc.1.weight')
            num_classes = out_weight.size(0) if out_weight is not None else 2
            # BARTClassificationHead has no inner projection; this value only
            # fills register_classification_head's argument.
            inner_dim = 1024

            if getattr(self.args, 'load_checkpoint_heads', False):
                if head_name not in current_head_names:
                    self.register_classification_head(head_name, num_classes, inner_dim)
            else:
                if head_name not in current_head_names:
                    logger.warning(
                        'deleting classification head ({}) from checkpoint '
                        'not present in current model: {}'.format(head_name, k)
                    )
                    keys_to_delete.append(k)
                elif num_classes != self.classification_heads[head_name].disc_fc[-1].out_features:
                    logger.warning(
                        'deleting classification head ({}) from checkpoint '
                        'with different dimensions than current model: {}'.format(head_name, k)
                    )
                    keys_to_delete.append(k)
        for k in keys_to_delete:
            del state_dict[k]

        # When finetuning on translation task, remove last row of
        # embedding matrix that corresponds to mask_idx token.
        loaded_dict_size = state_dict['encoder.embed_tokens.weight'].size(0)
        if loaded_dict_size == len(self.encoder.dictionary) + 1 and '<mask>' not in self.encoder.dictionary:
            state_dict['encoder.embed_tokens.weight'] = state_dict['encoder.embed_tokens.weight'][:loaded_dict_size-1, :]
            state_dict['decoder.embed_tokens.weight'] = state_dict['decoder.embed_tokens.weight'][:loaded_dict_size-1, :]

        # Copy any newly-added classification heads into the state dict
        # with their current weights.
        if hasattr(self, 'classification_heads'):
            cur_state = self.classification_heads.state_dict()
            for k, v in cur_state.items():
                if head_prefix + k not in state_dict:
                    state_dict[head_prefix + k] = v


class BARTClassificationHead(nn.Module):
    """CNN sentence classifier used as the style-discriminator head.

    Three ``Conv2d`` layers produce 3 feature maps each.  Their kernels span
    3, 4 and 5 time steps and the full feature width.  The maps are
    ReLU-activated, max-pooled over time and concatenated, then passed
    through dropout (p=0.1) and a linear layer to ``num_classes`` logits.

    Args:
        input_dim: feature width of the incoming ``features``, used as the
            convolution kernel width; ``None`` keeps the former fixed 1024.
        inner_dim: accepted for interface compatibility; unused.
        num_classes: number of output logits; ``None`` means 2.
        activation_fn: accepted for interface compatibility; unused (ReLU is
            always applied).
        pooler_dropout: accepted for interface compatibility; unused (the
            dropout rate is fixed at 0.1).
    """

    # Widest convolution kernel along time; shorter inputs are zero-padded.
    _MIN_LEN = 5

    def __init__(
        self,
        input_dim,
        inner_dim,
        num_classes,
        activation_fn,
        pooler_dropout,
    ):
        super().__init__()
        if input_dim is None:
            input_dim = 1024
        if num_classes is None:
            num_classes = 2
        # Attribute names (conv3/conv4/conv5/disc_fc) are kept so existing
        # checkpoints load unchanged.
        self.conv3 = nn.Conv2d(1, 3, (3, input_dim))
        self.conv4 = nn.Conv2d(1, 3, (4, input_dim))
        self.conv5 = nn.Conv2d(1, 3, (5, input_dim))

        self.disc_fc = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(3 * 3, num_classes),
        )

    def forward(self, features, **kwargs):
        """Return class logits for ``features``.

        Args:
            features: ``(batch, seq_len, input_dim)`` tensor.  Sequences
                shorter than 5 steps are zero-padded at the end of the time
                axis, so every convolution has at least one window (they
                previously raised an error).

        Returns:
            ``(batch, num_classes)`` logits.
        """
        seq_len = features.size(1)
        if seq_len < self._MIN_LEN:
            features = F.pad(features, (0, 0, 0, self._MIN_LEN - seq_len))
        inputs = features.unsqueeze(1)  # batch x 1 x seq_len x input_dim

        pooled = []
        for conv in (self.conv3, self.conv4, self.conv5):
            # batch x 3 x windows x 1  ->  batch x 3 x windows
            maps = F.relu(conv(inputs)).squeeze(dim=3)
            # Max-over-time pool: batch x 3
            pooled.append(F.max_pool1d(maps, maps.size(2)).squeeze(dim=2))

        return self.disc_fc(torch.cat(pooled, dim=1))



@register_model_architecture('bart', 'bart_large')
def bart_large_architecture(args):
    """Fill in BART-large hyper-parameters that ``args`` does not already set.

    Every option keeps a value already present on ``args``; otherwise it gets
    the default listed below.  Options are resolved top to bottom, so defaults
    that copy another option (e.g. decoder sizes from encoder sizes) see that
    option's already-resolved value.
    """
    def _default(option, value):
        # Keep an existing value, else install the default.
        setattr(args, option, getattr(args, option, value))

    # Encoder
    _default('encoder_embed_path', None)
    _default('encoder_embed_dim', 1024)
    _default('encoder_ffn_embed_dim', 4 * 1024)
    _default('encoder_layers', 12)
    _default('encoder_attention_heads', 16)
    _default('encoder_normalize_before', False)
    _default('encoder_learned_pos', True)

    # Decoder (sizes follow the encoder unless overridden)
    _default('decoder_embed_path', None)
    _default('decoder_embed_dim', args.encoder_embed_dim)
    _default('decoder_ffn_embed_dim', args.encoder_ffn_embed_dim)
    _default('decoder_layers', 12)
    _default('decoder_attention_heads', 16)
    _default('decoder_normalize_before', False)
    _default('decoder_learned_pos', True)

    # Dropout
    _default('attention_dropout', 0.)
    _default('relu_dropout', 0.)
    _default('dropout', 0.1)

    # Positions, adaptive softmax and embedding sharing
    _default('max_target_positions', 1024)
    _default('max_source_positions', 1024)
    _default('adaptive_softmax_cutoff', None)
    _default('adaptive_softmax_dropout', 0)
    _default('share_decoder_input_output_embed', True)
    _default('share_all_embeddings', True)

    # Decoder I/O widths follow the decoder embedding size
    _default('decoder_output_dim', args.decoder_embed_dim)
    _default('decoder_input_dim', args.decoder_embed_dim)

    _default('no_scale_embedding', True)
    _default('layernorm_embedding', True)

    # Activations and pooler
    _default('activation_fn', 'gelu')
    _default('pooler_activation_fn', 'tanh')
    _default('pooler_dropout', 0.0)


@register_model_architecture('bart', 'mbart_large')
def mbart_large_architecture(args):
    """BART-large defaults, but ``no_scale_embedding`` defaults to False."""
    if not hasattr(args, 'no_scale_embedding'):
        args.no_scale_embedding = False
    bart_large_architecture(args)
