# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import copy
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils, tasks, utils
from fairseq.models import (
    BaseFairseqModel,
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    FairseqIncrementalDecoder,
    register_model,
    register_model_architecture,
)
from fairseq.modules import LayerNorm, PositionalEmbedding, TransformerDecoderLayer, MultiheadAttention
from bert import BertTokenizer
from bert import BertModelWithAdapter
from bert.modeling import BertEmbeddings, BertAttention, BertIntermediate, BertOutput, BertPreTrainedModel, BertOnlyMLMHead

# Fallback position limits: Wav2BertEncoder.__init__ copies these onto ``args``
# when ``max_source_positions`` / ``max_target_positions`` are not set there.
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024


import logging
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

def add_common_args(parser):
    """Register the wav2bert options shared by the models in this file.

    Each option is passed to ``parser.add_argument`` in the order listed in
    the table below, so the parser sees them in that same order.

    Args:
        parser: object exposing an argparse-style ``add_argument`` method.
    """
    selection_modes = ("static", "uniform", "normal", "poisson")

    # (flag names, keyword arguments) pairs, in registration order.
    option_table = [
        (("--w2v-path",), dict(help="path to wav2bert 2.0 model")),
        (
            ("--no-pretrained-weights",),
            dict(
                action="store_true",
                help="if true, does not load pretrained weights",
            ),
        ),
        (
            ("--dropout-input",),
            dict(
                type=float,
                metavar="D",
                help="dropout to apply to the input (after feat extr)",
            ),
        ),
        (
            ("--final-dropout",),
            dict(
                type=float,
                metavar="D",
                help="dropout after transformer and before final projection",
            ),
        ),
        (
            ("--apply-mask",),
            dict(action="store_true", help="apply masking during fine-tuning"),
        ),
        (
            ("--dropout",),
            dict(
                type=float,
                metavar="D",
                help="dropout probability inside wav2bert 2.0 model",
            ),
        ),
        (
            ("--attention-dropout",),
            dict(
                type=float,
                metavar="D",
                help="dropout probability for attention weights inside wav2bert 2.0 model",
            ),
        ),
        (
            # "--relu-dropout" is an alias; the destination is activation_dropout.
            ("--activation-dropout", "--relu-dropout"),
            dict(
                type=float,
                metavar="D",
                help="dropout probability after activation in FFN inside wav2bert 2.0 model",
            ),
        ),
        (
            ("--mask-length",),
            dict(type=int, help="repeat the mask indices multiple times"),
        ),
        (
            ("--mask-prob",),
            dict(type=float, help="probability of replacing a token with mask"),
        ),
        (
            ("--mask-selection",),
            dict(
                type=str,
                choices=list(selection_modes),
                help="how to choose masks",
            ),
        ),
        (
            ("--mask-other",),
            dict(
                type=float,
                help="stdev of the mask length in case of 'normal' selection strategy",
            ),
        ),
        (
            ("--no-mask-overlap",),
            dict(action="store_true", help="whether to allow masks to overlap"),
        ),
        (
            ("--mask-channel-length",),
            dict(type=int, help="repeat the mask indices multiple times"),
        ),
        (
            ("--mask-channel-prob",),
            dict(type=float, help="probability of replacing a token with mask"),
        ),
        (
            ("--mask-channel-selection",),
            dict(
                type=str,
                choices=list(selection_modes),
                help="how to choose masks",
            ),
        ),
        (
            ("--mask-channel-other",),
            dict(
                type=float,
                help="stdev of the mask length in case of 'normal' selection strategy",
            ),
        ),
        (
            ("--no-mask-channel-overlap",),
            dict(action="store_true", help="whether to allow masks to overlap"),
        ),
        (
            ("--freeze-finetune-updates",),
            dict(
                default=0,
                type=int,
                help="dont finetune wav2bert for this many updates",
            ),
        ),
        (
            ("--feature-grad-mult",),
            dict(
                default=None,
                type=float,
                help="reset feature grad mult in wav2bert 2.0 to this",
            ),
        ),
        (
            ("--layerdrop",),
            dict(
                default=0.0,
                type=float,
                help="probability of dropping a layer in wav2bert 2.0",
            ),
        ),
        (("--freeze-bert",), dict(action="store_true")),
    ]

    for flag_names, settings in option_table:
        parser.add_argument(*flag_names, **settings)


@register_model("wav2bert_masked_predict_fusion_ctc_gate2")
class Wav2BertCtcMlm(BaseFairseqModel):
    """Model registered as ``wav2bert_masked_predict_fusion_ctc_gate2``.

    Thin wrapper around a ``Wav2BertEncoder``: ``forward`` delegates to the
    encoder, and the ``get_normalized_*`` helpers apply a (log-)softmax to one
    of the logit tensors found in the encoder's output dictionary.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        add_common_args(parser)
        parser.add_argument('--adapter-dimension', default=2048, type=int)

    def __init__(self, w2v_encoder, args):
        """Keep references to the wrapped encoder and the argument namespace."""
        super().__init__()
        print("Construct wav2bert_masked_predict_fusion_ctc model...")
        self.w2v_encoder = w2v_encoder
        self.args = args

    def upgrade_state_dict_named(self, state_dict, name):
        """Run the parent upgrade hook, then hand back ``state_dict``."""
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        # Fill in any missing hyper-parameters before the encoder reads them.
        base_architecture(args)
        encoder = Wav2BertEncoder(args, task.target_dictionary)
        return cls(encoder, args)

    @staticmethod
    def _normalize_last_dim(logits, log_probs):
        """Cast ``logits`` to float and (log-)softmax them over the last dim."""
        normalizer = utils.log_softmax if log_probs else utils.softmax
        return normalizer(logits.float(), dim=-1)

    def get_normalized_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from ``net_output["encoder_out"]``."""
        return self._normalize_last_dim(net_output["encoder_out"], log_probs)

    def get_normalized_ctc_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from ``net_output["encoder_ctc"]``."""
        return self._normalize_last_dim(net_output["encoder_ctc"], log_probs)

    def get_normalized_decoder_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from ``net_output["decoder_out"]``."""
        return self._normalize_last_dim(net_output["decoder_out"], log_probs)

    def forward(self, **kwargs):
        """Pass all keyword arguments to the wrapped encoder and return its output."""
        return self.w2v_encoder(**kwargs)

    # def max_positions(self):
    #     return None

class Wav2BertEncoder(FairseqEncoder):
    """Pretrained speech encoder combined with a BERT decoder and fusion adapter.

    The constructor rebuilds the pretrained model described by ``args.w2v_path``
    (or ``args.w2v_args``), creates a ``BertAdapterDecoderFull`` and two linear
    heads: ``encoder_proj`` for the raw encoder features and ``proj`` for the
    fused features.
    """

    def __init__(self, args, tgt_dict=None):
        """Build the encoder.

        Args:
            args: namespace with the options registered by ``add_common_args``
                plus the defaults filled in by ``base_architecture``.
            tgt_dict: optional target dictionary; when given, both projection
                heads map the encoder width to ``len(tgt_dict)``.
        """
        # Values forwarded to the checkpoint loader so the pretrained model is
        # rebuilt with the fine-tuning hyper-parameters.
        arg_overrides = {
            "dropout": args.dropout,
            "activation_dropout": args.activation_dropout,
            "dropout_input": args.dropout_input,
            "attention_dropout": args.attention_dropout,
            "mask_length": args.mask_length,
            "mask_prob": args.mask_prob,
            "mask_selection": args.mask_selection,
            "mask_other": args.mask_other,
            "no_mask_overlap": args.no_mask_overlap,
            "mask_channel_length": args.mask_channel_length,
            "mask_channel_prob": args.mask_channel_prob,
            "mask_channel_selection": args.mask_channel_selection,
            "mask_channel_other": args.mask_channel_other,
            "no_mask_channel_overlap": args.no_mask_channel_overlap,
            "encoder_layerdrop": args.layerdrop,
            "feature_grad_mult": args.feature_grad_mult,
        }

        if getattr(args, "w2v_args", None) is None:
            state = checkpoint_utils.load_checkpoint_to_cpu(
                args.w2v_path, arg_overrides
            )
            w2v_args = state["args"]
        else:
            state = None
            w2v_args = args.w2v_args

        assert (
            args.normalize == w2v_args.normalize
        ), "Fine-tuning works best when data normalization is the same"

        w2v_args.data = args.data
        task = tasks.setup_task(w2v_args)
        model = task.build_model(w2v_args)

        if state is not None and not args.no_pretrained_weights:
            model.load_state_dict(state["model"], strict=True)

        model.remove_pretraining_modules()

        super().__init__(task.source_dictionary)

        self.apply_mask = args.apply_mask

        d = w2v_args.encoder_embed_dim

        self.w2v_model = model
        # Make sure the decoder-related defaults exist before building BERT.
        base_architecture(args)
        self.encoder_dropout = nn.Dropout(args.final_dropout)

        if not hasattr(args, 'max_source_positions'):
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if not hasattr(args, 'max_target_positions'):
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS
        self.bertdecoder = BertAdapterDecoderFull.from_pretrained(args.decoder_bert_model_name, args, from_scratch=args.train_from_scratch)
        self.final_dropout = nn.Dropout(args.final_dropout)
        self.freeze_finetune_updates = args.freeze_finetune_updates
        self.num_updates = 0

        if tgt_dict is not None:
            self.proj = Linear(d, len(tgt_dict))
            self.encoder_proj = Linear(d, len(tgt_dict))
            print("d:", d)
        elif getattr(args, "decoder_embed_dim", d) != d:
            self.proj = Linear(d, args.decoder_embed_dim)
            self.encoder_proj = Linear(d, args.decoder_embed_dim)
        else:
            # Both heads must exist as attributes: forward() reads both of them.
            self.proj = None
            self.encoder_proj = None

    def set_num_updates(self, num_updates):
        """Set the number of parameters updates."""
        super().set_num_updates(num_updates)
        self.num_updates = num_updates

    def forward(self, source, padding_mask, prev_output_tokens=None, tbc=True, **kwargs):
        """Run the speech encoder, the BERT decoder and the fusion adapter.

        Args:
            source: input passed to ``w2v_model.extract_features`` as ``source``.
            padding_mask: padding mask passed alongside ``source``.
            prev_output_tokens: target-side tokens handed to the BERT decoder
                (presumably B x tgt_len; the decoder treats id 0 as padding).
            tbc: when true, the CTC features and the fused features are
                transposed on their first two dims (B x T x C -> T x B x C)
                before dropout/projection; when false they are left as is.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            dict with keys ``encoder_out`` (projected fused features),
            ``encoder_padding_mask`` and ``padding_mask`` (mask returned by
            ``extract_features``), ``decoder_out`` (BERT prediction scores),
            ``encoder_ctc`` (projected raw encoder features),
            ``align_target`` (raw encoder features) and ``align_fusion``
            (attention output of the fusion adapter).
        """
        w2v_args = {
            "source": source,
            "padding_mask": padding_mask,
            "mask": self.apply_mask and self.training,
        }

        # The pretrained encoder stays frozen for the first
        # ``freeze_finetune_updates`` updates.
        ft = self.freeze_finetune_updates <= self.num_updates

        with torch.no_grad() if not ft else contextlib.ExitStack():
            # x is B x T x C according to the original author's notes.
            x, padding_mask = self.w2v_model.extract_features(**w2v_args)

        encoder_feature = x

        with torch.no_grad() if not ft else contextlib.ExitStack():
            # Always bind encoder_ctc; only the layout depends on ``tbc``.
            encoder_ctc = x.transpose(0, 1) if tbc else x
        encoder_ctc = self.encoder_dropout(encoder_ctc)

        if self.encoder_proj is not None:
            encoder_ctc = self.encoder_proj(encoder_ctc)

        # The BERT decoder / fusion adapter consume the features as T x B x C.
        decoder_input = x.permute(1, 0, 2).contiguous()
        encoder_out = {
            'encoder_out': decoder_input,
            'encoder_padding_mask': padding_mask,
        }
        decoder_out, decoder_hidden, fusion_out, align_result = self.bertdecoder(prev_output_tokens, encoder_out=encoder_out, padding_idx=0)
        # Fusion v2 (author's note): the encoder features and the last BERT
        # hidden state go through a separate fusion module; its output is
        # projected by one linear layer and used for the CTC loss.
        x = fusion_out
        if tbc:
            # B x T x C -> T x B x C
            x = x.transpose(0, 1)
        x = self.final_dropout(x)

        if self.proj is not None:
            x = self.proj(x)
        return {
            "encoder_out": x,
            "encoder_padding_mask": padding_mask,
            "decoder_out": decoder_out,
            "padding_mask": padding_mask,
            "encoder_ctc": encoder_ctc,
            "align_target": encoder_feature,
            "align_fusion": align_result,  # attention output, not an alignment
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder ``encoder_out`` (dim 1) and its padding mask (dim 0) in place."""
        if encoder_out["encoder_out"] is not None:
            encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
                1, new_order
            )
        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
            ].index_select(0, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return None

    def upgrade_state_dict_named(self, state_dict, name):
        """Return ``state_dict`` unchanged."""
        return state_dict


class BertAdapterDecoderFull(BertPreTrainedModel):
    """BERT decoder stack with an MLM head and a fusion adapter on top."""

    def __init__(self, config, args):
        """Create the BERT stack, the MLM head tied to its word embeddings, and the adapter."""
        super().__init__(config)
        self.bert = BertDecoderAssemble(config, args)
        word_embedding_weight = self.bert.embeddings.word_embeddings.weight
        self.cls = BertOnlyMLMHead(config, word_embedding_weight)
        # BERT-style init is applied before the fusion adapter exists, so the
        # adapter keeps the initialisation done in its own constructor.
        self.apply(self.init_bert_weights)
        self.fusion_adapter = FusionAdapter(args)
        self.freeze_bert = getattr(args, 'freeze_bert', False)
        self.onnx_trace = False

    def forward(self, prev_output_tokens, src_tokens=None, encoder_out=None, padding_idx=0, **kwargs):
        """Run BERT, the MLM head and the fusion adapter.

        Returns:
            Tuple ``(prediction_scores, sequence_output, fusion_out,
            align_result)``: MLM head output, BERT hidden states, and the two
            values returned by the fusion adapter.
        """
        # With ``freeze_bert`` the BERT stack runs without gradient tracking.
        grad_scope = torch.no_grad() if self.freeze_bert else contextlib.ExitStack()
        with grad_scope:
            sequence_output, targets_padding = self.bert(prev_output_tokens, encoder_out, padding_idx)

        # Hidden states -> vocabulary scores (768 -> 30522 per the author's note).
        prediction_scores = self.cls(sequence_output)
        fusion_out, align_result = self.fusion_adapter(
            encoder_out['encoder_out'], sequence_output, targets_padding
        )
        return prediction_scores, sequence_output, fusion_out, align_result

    def prepare_for_onnx_export_(self):
        """Flag the module so the softmax helpers use their ONNX-trace path."""
        self.onnx_trace = True

    def get_normalized_probs(self, net_output, log_probs, sample):
        """Get normalized probabilities (or log probs) from a net's output."""
        normalizer = utils.log_softmax if log_probs else utils.softmax
        return normalizer(net_output[0], dim=-1, onnx_trace=self.onnx_trace)

class BertDecoderAssemble(BertPreTrainedModel):
    """BERT embeddings followed by the ``BertDecoder`` layer stack."""

    def __init__(self, config, args):
        """Create embeddings and decoder stack, then apply the BERT weight init."""
        super().__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertDecoder(config, args)
        self.apply(self.init_bert_weights)
        self.hidden_size = config.hidden_size

    def forward(self, prev_output_tokens, encoder_out=None, padding_idx=0):
        """Embed ``prev_output_tokens`` and run them through the decoder stack.

        Returns:
            Tuple ``(last_layer_output, targets_padding)`` where
            ``targets_padding`` is the boolean padding mask of the tokens.
        """
        # True wherever the token equals ``padding_idx``; same shape as the tokens.
        targets_padding = prev_output_tokens.eq(padding_idx)

        target_len = prev_output_tokens.size(1)
        position_ids = (
            torch.arange(target_len, dtype=torch.long, device=prev_output_tokens.device)
            .unsqueeze(0)
            .expand_as(prev_output_tokens)
        )
        # Author's note: this lookup fails when a position id exceeds what the
        # BERT position table accepts (512 per the note), so samples with
        # longer labels must be filtered out beforehand.
        positions = self.embeddings.position_embeddings(position_ids).transpose(0, 1)
        token_type_ids = torch.zeros_like(prev_output_tokens)

        # Additive self-attention mask: -10000 at padded positions, 0 elsewhere,
        # cast to the parameter dtype for fp16 compatibility.
        param_dtype = next(self.parameters()).dtype
        extended_attention_mask = (
            targets_padding.unsqueeze(1).unsqueeze(2).float().to(dtype=param_dtype)
        )
        extended_attention_mask = extended_attention_mask * -10000.0

        embedding_output = self.embeddings(prev_output_tokens, token_type_ids)

        if encoder_out is None:
            encoder_states, encoder_padding = None, None
        else:
            encoder_states = encoder_out['encoder_out']
            encoder_padding = encoder_out['encoder_padding_mask']

        encoded_layers = self.encoder(
            embedding_output,
            extended_attention_mask,
            output_all_encoded_layers=False,
            encoder_out=encoder_states,
            encoder_padding_mask=encoder_padding,
            position_embedding=positions,
            targets_padding=targets_padding,
        )
        return encoded_layers[-1], targets_padding


class BertDecoder(nn.Module):
    """Sequential stack of ``config.num_hidden_layers`` ``BertDecoderLayer`` modules."""

    def __init__(self, config, args):
        """Instantiate one decoder layer per hidden layer in ``config``."""
        super().__init__()
        self.num_layers = config.num_hidden_layers
        stacked = []
        for layer_index in range(config.num_hidden_layers):
            stacked.append(copy.deepcopy(BertDecoderLayer(config, args, layer_index)))
        self.layer = nn.ModuleList(stacked)

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True,
                encoder_out=None, encoder_padding_mask=None, position_embedding=None, targets_padding=None):
        """Feed ``hidden_states`` through every layer in order.

        Returns:
            A list holding each layer's output when
            ``output_all_encoded_layers`` is true, otherwise a one-element
            list holding only the final hidden states.
        """
        final_index = self.num_layers - 1
        per_layer_outputs = []
        for index in range(self.num_layers):
            hidden_states = self.layer[index](
                hidden_states,
                encoder_out=encoder_out,
                encoder_padding_mask=encoder_padding_mask,
                self_attn_mask=attention_mask,
                position_embedding=position_embedding,
                targets_padding=targets_padding,
                layer_num=index,
                is_last_layer=(index == final_index),
            )
            if output_all_encoded_layers:
                per_layer_outputs.append(hidden_states)

        if output_all_encoded_layers:
            return per_layer_outputs
        # Only the last layer's output was requested.
        return [hidden_states]

class BertDecoderLayer(nn.Module):
    """One BERT layer: self-attention followed by the intermediate/output FFN.

    Several constructor attributes and ``forward`` parameters are stored or
    accepted but not used by ``forward``; they are kept so callers and
    configuration code keep working.
    """

    def __init__(self, config, args, layer_num):
        """Build the layer.

        Args:
            config: BERT configuration passed to the BERT sub-modules.
            args: namespace providing ``decoder_embed_dim``, ``dropout``,
                ``fusion_v2``, ``decoder_normalize_before`` and optional
                activation settings.
            layer_num: index of this layer in the stack (not used here).
        """
        super(BertDecoderLayer, self).__init__()
        self.embed_dim = args.decoder_embed_dim
        self.dropout = args.dropout
        self.fusion_v2 = args.fusion_v2
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu')
        )
        self.activation_dropout = getattr(args, 'activation_dropout', 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, 'relu_dropout', 0)
        self.normalize_before = args.decoder_normalize_before

        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)
        self.top_layer_adapter = getattr(args, 'top_layer_adapter', -1)

        # Give the flags toggled by make_generation_fast_ /
        # prepare_for_onnx_export_ a defined value from the start.
        self.need_attn = False
        self.onnx_trace = False

    def forward(
        self,
        x,
        encoder_out=None,
        encoder_padding_mask=None,
        self_attn_mask=None,
        position_embedding=None,
        targets_padding=None,
        layer_num=-1,
        is_last_layer=False,
    ):
        """Apply self-attention and the feed-forward block to ``x``.

        Only ``x`` and ``self_attn_mask`` are used; the remaining parameters
        are accepted for interface compatibility with the decoder stack.
        """
        x = self.attention(x, self_attn_mask)

        intermediate_output = self.intermediate(x)
        x = self.output(intermediate_output, x)

        return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        """Record whether attention weights are wanted."""
        self.need_attn = need_attn

    def prepare_for_onnx_export_(self):
        """Mark the layer as being traced for ONNX export."""
        self.onnx_trace = True

    def without_self_mask(self, tensor):
        """Return a square additive mask that blocks each position from itself.

        The mask has side ``tensor.size(0)``, ``-inf`` on the diagonal and
        zeros elsewhere. It is created on ``tensor``'s device, so it also works
        on CPU-only hosts instead of unconditionally requiring CUDA.
        """
        dim = tensor.size(0)
        eye_matrix = torch.eye(dim, device=tensor.device)
        eye_matrix[eye_matrix == 1.0] = float('-inf')
        return eye_matrix

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        """Apply ``layer_norm`` in the pre- or post-norm slot selected by ``normalize_before``.

        Exactly one of ``before`` / ``after`` must be true.
        """
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

class FusionAdapter(nn.Module):
    """Gated cross-attention block fusing encoder features with BERT states.

    The encoder features act as the attention query and the BERT hidden states
    as key/value. The attention output is mixed with the query through a
    learned sigmoid gate and then passed through a feed-forward block.
    """

    def __init__(self, args):
        """Build the adapter.

        Args:
            args: namespace providing ``decoder_embed_dim`` (key/value width),
                ``encoder_embed_dim`` (query width), ``decoder_attention_heads``,
                ``decoder_ffn_embed_dim``, ``dropout``, ``attention_dropout``
                and ``decoder_normalize_before``.
        """
        super(FusionAdapter, self).__init__()
        self.embed_dim = args.decoder_embed_dim
        self.dropout = args.dropout
        self.normalize_before = args.decoder_normalize_before
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu')
        )
        self.activation_dropout = getattr(args, 'activation_dropout', 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, 'relu_dropout', 0)
        export = getattr(args, 'char_inputs', False)

        # Width of the query / residual stream (the encoder features).
        encoder_dim = getattr(args, 'encoder_embed_dim', None)

        self.encoder_attn = MultiheadAttention(
            encoder_dim, args.decoder_attention_heads,
            kdim=self.embed_dim,
            vdim=self.embed_dim,
            dropout=args.attention_dropout, encoder_decoder_attention=True
        )
        self.encoder_attn_layer_norm = LayerNorm(encoder_dim, export=export)

        self.encoder_attn_fc1 = Linear(encoder_dim, args.decoder_ffn_embed_dim)
        self.encoder_attn_fc2 = Linear(args.decoder_ffn_embed_dim, encoder_dim)
        self.encoder_attn_final_layer_norm = LayerNorm(encoder_dim, export=export)
        self.need_attn = False

        # The gate reads cat([residual, attention output]) and scales the
        # residual, all of which live in the encoder width. Sizing it from the
        # decoder width only worked when both widths were equal; for equal
        # widths the parameter shapes are unchanged.
        self.gate = Linear(2 * encoder_dim, encoder_dim, bias=True)

    def forward(self, encoder_out, bert_hidden, targets_padding):
        """Fuse ``encoder_out`` with ``bert_hidden``.

        Args:
            encoder_out: query features; the attention call expects T x B x C
                per the original author's comments.
            bert_hidden: BERT hidden states; dims 0 and 1 are swapped before
                they are used as key/value.
            targets_padding: key padding mask for the BERT positions.

        Returns:
            Tuple ``(layer_output, align_result)``: the fused features and the
            raw attention output, both with dims 0 and 1 swapped back.
        """
        key_value = bert_hidden.transpose(0, 1)
        residual = encoder_out

        x = self.maybe_layer_norm(self.encoder_attn_layer_norm, encoder_out, before=True)
        # Cross-attention: encoder features attend over the BERT hidden states.
        x, _ = self.encoder_attn(
            query=x,
            key=key_value,
            value=key_value,
            key_padding_mask=targets_padding,
            static_kv=True,
            need_weights=(not self.training and self.need_attn),
        )
        align_result = x.transpose(0, 1)

        # Gated mix instead of a plain residual sum.
        gate_weight = torch.sigmoid(self.gate(torch.cat([residual, x], dim=-1)))
        x = residual * gate_weight + x * (1 - gate_weight)

        x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)

        # Position-wise feed-forward block with its own residual connection.
        residual = x
        x = self.maybe_layer_norm(self.encoder_attn_final_layer_norm, x, before=True)
        x = self.activation_fn(self.encoder_attn_fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.encoder_attn_fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        layer_output = self.maybe_layer_norm(self.encoder_attn_final_layer_norm, x, after=True)
        layer_output = layer_output.transpose(0, 1)
        return layer_output, align_result

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        """Apply ``layer_norm`` in the pre- or post-norm slot selected by ``normalize_before``.

        Exactly one of ``before`` / ``after`` must be true.
        """
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x


def Embedding(num_embeddings, embedding_dim, padding_idx):
    """Create an ``nn.Embedding`` with scaled-normal weights and a zero padding row.

    Args:
        num_embeddings: number of rows in the table.
        embedding_dim: width of each embedding vector.
        padding_idx: row to keep at zero, or ``None`` for no padding row.

    Returns:
        The initialised ``nn.Embedding``.
    """
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    # Indexing the weight with None yields a view of the whole matrix, so
    # zeroing it would wipe every embedding; only touch a real padding row.
    if padding_idx is not None:
        nn.init.constant_(m.weight[padding_idx], 0)
    return m


def Linear(in_features, out_features, bias=True):
    """Create an ``nn.Linear`` with Xavier-uniform weights and, if present, a zero bias."""
    layer = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(layer.weight)
    if not bias:
        # No bias parameter exists, so there is nothing more to initialise.
        return layer
    nn.init.constant_(layer.bias, 0.0)
    return layer


@register_model_architecture("wav2bert_masked_predict_fusion_ctc_gate2", "wav2bert_masked_predict_fusion_ctc_gate2")
def base_architecture(args):
    """Fill in every hyper-parameter that is not already present on ``args``.

    An attribute that already exists on ``args`` keeps its value (even when
    that value is ``None``); otherwise the fallback below is assigned.
    """
    # (attribute, fallback) pairs in assignment order. Where the original code
    # assigned the same attribute twice, only the first fallback could ever
    # apply, so the later, unreachable ones (dropout 0.1, decoder FFN 512,
    # 4 heads, 5 layers) are not listed.
    fallbacks = (
        # speech-encoder fine-tuning options
        ("no_pretrained_weights", False),
        ("dropout_input", 0),
        ("final_dropout", 0),
        ("apply_mask", False),
        ("dropout", 0),
        ("attention_dropout", 0),
        ("activation_dropout", 0),
        ("mask_length", 10),
        ("mask_prob", 0.5),
        ("mask_selection", "static"),
        ("mask_other", 0),
        ("no_mask_overlap", False),
        ("mask_channel_length", 10),
        ("mask_channel_prob", 0.5),
        ("mask_channel_selection", "static"),
        ("mask_channel_other", 0),
        ("no_mask_channel_overlap", False),
        ("freeze_finetune_updates", 0),
        ("feature_grad_mult", 0),
        ("layerdrop", 0.0),
        # args from transformer_nat_ymask_bert_two_adapter
        ("decoder_embed_path", None),
        ("fusion_v2", None),
        ("fusion_v3", None),
        ("decoder_normalize_before", False),
        ("decoder_learned_pos", False),
        ("activation_fn", 'relu'),
        ("adaptive_softmax_cutoff", None),
        ("adaptive_softmax_dropout", 0),
        ("share_decoder_input_output_embed", False),
        ("share_all_embeddings", False),
        ("no_token_positional_embeddings", False),
        ("adaptive_input", False),
        ("decoder_embed_dim", 768),
        ("decoder_ffn_embed_dim", 2048),
        ("decoder_attention_heads", 8),
        ("decoder_layers", 6),
        # args from transformer_nat_ymask_bert_two_adapter_deep_small
        ("encoder_embed_dim", 768),
    )
    for attr_name, fallback in fallbacks:
        setattr(args, attr_name, getattr(args, attr_name, fallback))

    # These two fall back to whatever decoder_embed_dim ended up being.
    for attr_name in ("decoder_output_dim", "decoder_input_dim"):
        setattr(args, attr_name, getattr(args, attr_name, args.decoder_embed_dim))

    for attr_name in ("finetune_whole_encoder", "train_from_scratch"):
        setattr(args, attr_name, getattr(args, attr_name, False))