# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
BART: Denoising Sequence-to-Sequence Pre-training for
Natural Language Generation, Translation, and Comprehension
"""

import logging
from fairseq.models import (
    register_model, register_model_architecture
)
from examples.summarization.models.summ_transformer import SummTransformerDecoder
from examples.summarization.models.bartSummAbs import BARTSummAbsModel, mbart_summ_large_architecture

logger = logging.getLogger(__name__)

@register_model('bart_summ_abs_bn')
class BARTSummBNModel(BARTSummAbsModel):
    """BART summarization model whose decoder is a ``SummTransformerDecoder``.

    On top of the options inherited from ``BARTSummAbsModel`` this model adds
    command-line switches that select decoder layers with batch-normalized
    (BN) cross-attention and that re-enable gradients for the cross-attention
    query/key projections of selected decoder layers (see ``postprocess``).
    """

    def __init__(self, args, encoder, decoder):
        super().__init__(args, encoder, decoder)
        # Kept on the instance because postprocess() reads the tuning flags
        # from ``self.args``.
        self.args = args

    @staticmethod
    def add_args(parser):
        """Add the parent's arguments plus the BN cross-attention options."""
        super(BARTSummBNModel, BARTSummBNModel).add_args(parser)
        parser.add_argument(
            '--decoder-layer-w-bn-cross-attn', nargs="+", type=int, default=[],
            help='decoder layers with cross-attention where query and key are batch normalized before calculating weights'
        )
        parser.add_argument(
            "--tune-query-weight", action="store_true",
            help="For decoder layers with BN cross-attention, whether tune W_query"
        )
        parser.add_argument(
            "--tune-query-weight-wo-bn", action="store_true",
            help="For decoder layers wo BN cross-attention, whether tune W_query"
        )
        parser.add_argument(
            "--tune-key-weight-wo-bn", action="store_true",
            help="For decoder layers wo BN cross-attention, whether tune W_key"
        )

    @property
    def supported_targets(self):
        """Set of target names supported by this model."""
        return {'self'}

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        """Build the ``SummTransformerDecoder`` used by this model.

        Cross-attention is disabled when ``args.no_cross_attention`` is set
        (defaults to ``False`` when the attribute is absent).
        """
        return SummTransformerDecoder(
            args,
            tgt_dict,
            embed_tokens,
            no_encoder_attn=getattr(args, "no_cross_attention", False),
        )

    def _bn_layer_indices(self):
        """Return the configured BN cross-attention layer indices as a list."""
        # Fall back to "no BN layers" when the option is missing or None.
        return list(getattr(self.args, "decoder_layer_w_bn_cross_attn", None) or [])

    def _non_bn_layer_indices(self, args):
        """Return indices of decoder layers that do NOT use BN cross-attention."""
        num_layers = args.decoder_layers
        # Normalise negative indices so that the membership test below agrees
        # with how ``self.decoder.layers[i]`` resolves them.
        bn_layers = {
            i + num_layers if i < 0 else i for i in self._bn_layer_indices()
        }
        return [i for i in range(num_layers) if i not in bn_layers]

    def postprocess(self, args, task):
        """Run the parent's post-processing, then enable extra gradients.

        Depending on the flags stored in ``self.args``:

        * ``tune_query_weight``: enable gradients for ``encoder_attn.q_proj``
          of every layer listed in ``decoder_layer_w_bn_cross_attn``.
        * ``tune_query_weight_wo_bn``: enable gradients for
          ``encoder_attn.q_proj`` of every decoder layer that is *not* listed
          in ``decoder_layer_w_bn_cross_attn``.
        * ``tune_key_weight_wo_bn``: same as above for ``encoder_attn.k_proj``.

        NOTE(review): presumably the parent ``postprocess`` freezes these
        projections beforehand -- confirm against ``BARTSummAbsModel``.
        """
        super().postprocess(args, task)

        # Layers are addressed by index instead of iterating
        # ``self.decoder.layers``: if that container is fairseq's
        # LayerDropModuleList, iterating it can skip layers in training mode.
        if self.args.tune_query_weight:
            for i in self._bn_layer_indices():
                self.decoder.layers[i].encoder_attn.q_proj.requires_grad_(True)

        if self.args.tune_query_weight_wo_bn:
            # Only layers without BN cross-attention, as the option documents;
            # BN layers are controlled by --tune-query-weight.
            for i in self._non_bn_layer_indices(args):
                self.decoder.layers[i].encoder_attn.q_proj.requires_grad_(True)

        if self.args.tune_key_weight_wo_bn:
            # Only layers without BN cross-attention, as the option documents.
            for i in self._non_bn_layer_indices(args):
                self.decoder.layers[i].encoder_attn.k_proj.requires_grad_(True)

@register_model_architecture('bart_summ_abs_bn', 'mbart_summ_abs_bn_large')
def mbart_summ_abs_bn_large_architecture(args):
    """Fill in defaults for the ``mbart_summ_abs_bn_large`` architecture.

    Applies ``mbart_summ_large_architecture`` first, then sets every
    BN-cross-attention option that is still missing from ``args`` to its
    default. Values already present on ``args`` are left untouched.

    Args:
        args: argument namespace; modified in place.
    """
    mbart_summ_large_architecture(args)
    # Same default as the ``--decoder-layer-w-bn-cross-attn`` command-line
    # option, so code reading this attribute works even when ``args`` was not
    # produced by the argument parser.
    args.decoder_layer_w_bn_cross_attn = getattr(args, "decoder_layer_w_bn_cross_attn", [])
    args.tune_query_weight = getattr(args, "tune_query_weight", False)
    args.tune_query_weight_wo_bn = getattr(args, "tune_query_weight_wo_bn", False)
    args.tune_key_weight_wo_bn = getattr(args, "tune_key_weight_wo_bn", False)
