# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
BART: Denoising Sequence-to-Sequence Pre-training for
Natural Language Generation, Translation, and Comprehension
"""

import logging

import torch
import torch.nn as nn

from fairseq import utils
from fairseq.models import (
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import Embedding
from fairseq.models.bart import BARTModel, bart_large_architecture
from examples.summarization.models.summ_transformer import SummTransformerDecoder
from examples.summarization.modules.utils import GradReverse

logger = logging.getLogger(__name__)

@register_model('bart_summ')
class BARTSummModel(BARTModel):
    """BART variant whose decoder receives a learned per-token saliency scale.

    On top of :class:`BARTModel` this class owns

    * ``saliency_model``: a linear layer mapping every encoder state to one
      scalar; its sigmoid is passed to the decoder as
      ``encdec_attn_weights_scale``.
    * ``lang_classifer`` (only when ``args.use_lang_classifier`` is truthy):
      a bias-free linear layer applied to the encoder states after they went
      through ``GradReverse``; its output is returned in
      ``extra['lang_cls_out']``.
    """

    def __init__(self, args, encoder, decoder):
        super().__init__(args, encoder, decoder)
        self.args = args
        if self.args.use_lang_classifier:
            # One output unit per language listed in either language set.
            num_langs = len(args.summ_langs) + len(args.denoise_langs)
            # The attribute must be called ``lang_classifer`` because that is
            # the name ``forward`` reads; storing it under any other name
            # makes ``forward`` fail with AttributeError.
            self.lang_classifer = nn.Linear(
                args.encoder_embed_dim,
                num_langs,
                bias=False,
            )

        # Scores every encoder position with a single scalar.
        self.saliency_model = nn.Linear(
            args.encoder_embed_dim,
            1,
            bias=True
        )

    @staticmethod
    def add_args(parser):
        """Register the parent model's command-line arguments; none are added here."""
        super(BARTSummModel, BARTSummModel).add_args(parser)

    @property
    def supported_targets(self):
        return {'self'}

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        """Build the :class:`SummTransformerDecoder` used by this model."""
        return SummTransformerDecoder(
            args,
            tgt_dict,
            embed_tokens,
            no_encoder_attn=getattr(args, "no_cross_attention", False),
        )

    def _apply_classification_head(self, features, src_tokens, classification_head_name):
        """Feed the state at each row's last EOS position to the named head.

        The ``view`` below only succeeds when every row of ``src_tokens``
        contains the same number of EOS tokens.
        """
        eos_positions = src_tokens.eq(self.encoder.dictionary.eos())
        sentence_representation = features[eos_positions, :].view(
            features.size(0), -1, features.size(-1)
        )[:, -1, :]
        return self.classification_heads[classification_head_name](
            sentence_representation
        )

    def forward(
        self, src_tokens, src_lengths, prev_output_tokens=None,
        features_only=False, classification_head_name=None, **kwargs
    ):
        """Encode the source, scale cross-attention by saliency and decode.

        Args:
            src_tokens: source token ids, indexed as ``[B x T]``.
            src_lengths: source lengths, forwarded to the encoder.
            prev_output_tokens: decoder input tokens.
            features_only: forwarded to the decoder; forced to ``True`` when
                a classification head is requested.
            classification_head_name: if given, the matching entry of
                ``self.classification_heads`` is applied to the output.
            **kwargs: forwarded to both the encoder and the decoder.

        Returns:
            ``(x, extra)`` where ``x`` is the decoder output (or the
            classification head output) and ``extra`` additionally holds
            ``'encoder_out'`` and, when the language classifier is enabled,
            ``'lang_cls_out'``.  When ``args.only_encoder_for_cls`` is set the
            decoder is skipped and ``(x, encoder_out)`` is returned instead,
            with ``x`` derived from the encoder states.
        """
        if classification_head_name is not None:
            features_only = True

        encoder_out = self.encoder(
            src_tokens,
            src_lengths=src_lengths,
            **kwargs
        )

        if getattr(self.args, "only_encoder_for_cls", False):
            # Encoder-only path: the saliency head and decoder are not needed.
            x = encoder_out.encoder_out.permute(1, 0, 2)  # [T x B x C] -> [B x T x C]
            if classification_head_name is not None:
                x = self._apply_classification_head(
                    x, src_tokens, classification_head_name
                )
            return x, encoder_out

        encoder_last_hidden = encoder_out.encoder_out.transpose(0, 1)  # [T x B x C] -> [B x T x C]
        saliency = torch.sigmoid(
            self.saliency_model(encoder_last_hidden).squeeze(-1)
        )  # [B x T]

        # Per-sample flag: does the sentence end with one of the summarization
        # language tokens?  The last *column* is inspected (one value per
        # sentence); indexing ``src_tokens[-1]`` would instead pick the last
        # sentence of the batch.
        # NOTE(review): assumes the language token occupies the final source
        # position of every row (e.g. left-padded input) -- confirm with the task.
        last_tokens = src_tokens[:, -1]  # [B]
        # Starting from all-False keeps an empty ``summ_langs`` list valid.
        saliency_mask = torch.zeros_like(last_tokens, dtype=torch.bool)
        for summ_lang in self.args.summ_langs:
            saliency_mask = saliency_mask | last_tokens.eq(
                self.encoder.dictionary.index(summ_lang)
            )

        # Samples that are not flagged get a neutral scale of 1.
        masked_saliency_logit = torch.where(
            saliency_mask.unsqueeze(-1),  # [B x 1], broadcast over T
            saliency,
            torch.ones_like(saliency),
        )

        x, extra = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            encdec_attn_weights_scale=masked_saliency_logit,
            **kwargs,
        )

        extra['encoder_out'] = encoder_out.encoder_out
        if self.args.use_lang_classifier:
            reversed_encoder_last_hidden = GradReverse.apply(encoder_last_hidden)
            extra['lang_cls_out'] = self.lang_classifer(
                reversed_encoder_last_hidden,
            )

        if classification_head_name is not None:
            x = self._apply_classification_head(
                x, src_tokens, classification_head_name
            )
        return x, extra

    def postprocess(self, args, task):
        """
        modify model after loading the existing checkpoint

        Replaces the token embedding with one sized for
        ``task.source_dictionary``, copying the existing rows into its first
        entries, shares it between encoder and decoder, and applies the
        freezing options found on ``args``.
        """
        old_weight = self.encoder.embed_tokens.weight.data
        new_embed_tokens = Embedding(
            len(task.source_dictionary), args.encoder_embed_dim, task.source_dictionary.pad()
        )
        old_dictionary_size = old_weight.size(0)
        # Slice assignment copies the values, so no explicit clone is needed.
        new_embed_tokens.weight.data[:old_dictionary_size] = old_weight

        self.tpu = getattr(args, 'tpu', False)
        # Kept in a local variable on purpose: assigning a bool to
        # ``self.cuda`` would shadow ``nn.Module.cuda()`` on this instance.
        use_cuda = torch.cuda.is_available() and not args.cpu and not self.tpu
        if use_cuda:
            self.device = torch.device('cuda')
        elif self.tpu:
            self.device = utils.get_tpu_device(args)
        else:
            self.device = torch.device('cpu')

        new_embed_tokens = new_embed_tokens.to(self.device)

        if args.fp16:
            new_embed_tokens = new_embed_tokens.half()
        elif args.bf16:
            new_embed_tokens = new_embed_tokens.to(dtype=torch.bfloat16)

        # Encoder and decoder share one embedding module.
        self.encoder.embed_tokens = new_embed_tokens
        self.decoder.embed_tokens = self.encoder.embed_tokens

        # Any of these three options freezes the (shared) token embedding.
        if any(
            getattr(args, option, False)
            for option in ("freeze_decoder", "freeze_encoder", "freeze_embedding")
        ):
            self.encoder.embed_tokens.requires_grad_(False)

        if (
            getattr(args, "freeze_position_embedding", False)
            and self.encoder.embed_positions is not None
        ):
            self.encoder.embed_positions.requires_grad_(False)

        # Re-tie the output projection to the freshly created embedding.
        if self.decoder.share_input_output_embed:
            self.decoder.output_projection.weight = self.decoder.embed_tokens.weight


@register_model_architecture('bart_summ', 'mbart_summ_large')
def bart_summ_large_architecture(args):
    """Fill in missing freezing options on ``args``, then apply ``bart_large_architecture``.

    Options already present on ``args`` keep their value; absent ones get the
    fallback listed below.
    """
    # Built on every call so the list fallbacks are fresh objects each time.
    fallbacks = (
        ("freeze_encoder", False),
        ("freeze_decoder", False),
        ("freeze_embedding", False),
        ("freezed_encoder_layers", []),
        ("freezed_decoder_layers", []),
    )
    for option, fallback in fallbacks:
        setattr(args, option, getattr(args, option, fallback))
    bart_large_architecture(args)
