# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import copy
import math
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils, tasks, utils
from fairseq.models import (
    BaseFairseqModel,
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    FairseqIncrementalDecoder,
    register_model,
    register_model_architecture,
)
from fairseq.data import data_utils
from fairseq.modules import LayerNorm, PositionalEmbedding, TransformerDecoderLayer, MultiheadAttention
from bert import BertTokenizer
from bert import BertModelWithAdapter
from bert.modeling import BertEmbeddings, BertAttention, BertIntermediate, BertOutput, BertPreTrainedModel, BertOnlyMLMHead
from fairseq.data.data_utils import post_process

# Fallback position limits: Wav2BertEncoder.__init__ copies these onto ``args``
# when ``max_source_positions`` / ``max_target_positions`` are not already set.
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024

import logging
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

def add_common_args(parser):
    """Register the command-line options shared by the wav2bert models.

    Every option is described once in a table and then registered in table
    order, so ``--help`` lists them in the same sequence as before.

    Args:
        parser: an ``argparse``-style parser (anything exposing
            ``add_argument``). It is modified in place; nothing is returned.
    """
    mask_strategies = ["static", "uniform", "normal", "poisson"]

    # (option strings, keyword arguments for ``parser.add_argument``)
    # NOTE(review): several help texts below ("simulate wer range" on
    # --period-index, "gt and ctc loss weight changing" on --load-*) look
    # copy-pasted; they are kept verbatim here.
    option_table = [
        (("--w2v-path",), dict(help="path to wav2bert 2.0 model")),
        (
            ("--no-pretrained-weights",),
            dict(action="store_true", help="if true, does not load pretrained weights"),
        ),
        (
            ("--dropout-input",),
            dict(
                type=float,
                metavar="D",
                help="dropout to apply to the input (after feat extr)",
            ),
        ),
        (
            ("--final-dropout",),
            dict(
                type=float,
                metavar="D",
                help="dropout after transformer and before final projection",
            ),
        ),
        (
            ("--apply-mask",),
            dict(action="store_true", help="apply masking during fine-tuning"),
        ),
        (
            ("--dropout",),
            dict(
                type=float,
                metavar="D",
                help="dropout probability inside wav2bert 2.0 model",
            ),
        ),
        (
            ("--attention-dropout",),
            dict(
                type=float,
                metavar="D",
                help="dropout probability for attention weights inside wav2bert 2.0 model",
            ),
        ),
        (
            ("--activation-dropout", "--relu-dropout"),
            dict(
                type=float,
                metavar="D",
                help="dropout probability after activation in FFN inside wav2bert 2.0 model",
            ),
        ),
        # Time-axis masking.
        (
            ("--mask-length",),
            dict(type=int, help="repeat the mask indices multiple times"),
        ),
        (
            ("--mask-prob",),
            dict(type=float, help="probability of replacing a token with mask"),
        ),
        (
            ("--mask-selection",),
            dict(type=str, choices=mask_strategies, help="how to choose masks"),
        ),
        (
            ("--mask-other",),
            dict(
                type=float,
                help="stdev of the mask length in case of 'normal' selection strategy",
            ),
        ),
        (
            ("--no-mask-overlap",),
            dict(action="store_true", help="whether to allow masks to overlap"),
        ),
        # Channel-axis masking.
        (
            ("--mask-channel-length",),
            dict(type=int, help="repeat the mask indices multiple times"),
        ),
        (
            ("--mask-channel-prob",),
            dict(type=float, help="probability of replacing a token with mask"),
        ),
        (
            ("--mask-channel-selection",),
            dict(type=str, choices=mask_strategies, help="how to choose masks"),
        ),
        (
            ("--mask-channel-other",),
            dict(
                type=float,
                help="stdev of the mask length in case of 'normal' selection strategy",
            ),
        ),
        (
            ("--no-mask-channel-overlap",),
            dict(action="store_true", help="whether to allow masks to overlap"),
        ),
        # Fine-tuning controls.
        (
            ("--freeze-finetune-updates",),
            dict(
                default=0,
                type=int,
                help="dont finetune wav2bert for this many updates",
            ),
        ),
        (
            ("--feature-grad-mult",),
            dict(
                default=None,
                type=float,
                help="reset feature grad mult in wav2bert 2.0 to this",
            ),
        ),
        (
            ("--layerdrop",),
            dict(
                default=0.0,
                type=float,
                help="probability of dropping a layer in wav2bert 2.0",
            ),
        ),
        (("--freeze-bert",), dict(action="store_true")),
        # CTC / ground-truth mixing schedules (parsed from strings later on).
        (("--mix-ctc-deocde-prob-range",), dict(default=None, type=str)),
        (("--mix-ctc-step-range",), dict(default=None, type=str)),
        (("--batch-mix",), dict(action="store_true")),
        (("--fuse-input",), dict(action="store_true")),
        (("--fuse-input-add",), dict(action="store_true")),
        (("--fuse-input-gate",), dict(action="store_true")),
        (("--random-mix-input",), dict(action="store_true")),
        (("--gold-rate-range",), dict(type=str, help="gold-rate-range")),
        (("--gold-rate-steps",), dict(type=str, help="gold-rate-steps")),
        (("--simulate-wer-range",), dict(type=str, help="simulate wer range")),
        (("--simulate-wer-step-range",), dict(type=str, help="simulate wer range")),
        (
            ("--period-index",),
            dict(type=int, default=119, help="simulate wer range"),
        ),
        (
            ("--gc-weight",),
            dict(action="store_true", help="gt and ctc loss weight changing"),
        ),
        (
            ("--load-pretrain-w2v",),
            dict(action="store_true", help="gt and ctc loss weight changing"),
        ),
        (
            ("--load-bert-fc",),
            dict(action="store_true", help="gt and ctc loss weight changing"),
        ),
    ]

    for option_strings, settings in option_table:
        parser.add_argument(*option_strings, **settings)


    



@register_model("wav2bert_masked_predict_fusion_two_way_ctc_to_bert")
class Wav2BertCtcMlm(BaseFairseqModel):
    """Thin model wrapper around :class:`Wav2BertEncoder`.

    ``forward`` delegates to the wrapped encoder, whose output dictionary
    carries several logit tensors. The ``get_normalized_*`` methods each pick
    one of them and normalise it over the last dimension.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        add_common_args(parser)
        parser.add_argument('--adapter-dimension', default=2048, type=int)

    def __init__(self, w2v_encoder, args):
        super().__init__()
        print("Construct wav2bert_masked_predict_fusion_ctc model...")
        self.w2v_encoder = w2v_encoder
        self.args = args

    def upgrade_state_dict_named(self, state_dict, name):
        """Run the parent upgrade hook and hand the state dict back."""
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        base_architecture(args)
        w2v_encoder = Wav2BertEncoder(args, task.target_dictionary)
        return cls(w2v_encoder, args)

    @staticmethod
    def _normalize_logits(logits, log_probs):
        """Return (log-)softmax of ``logits`` over the last dimension, in float32."""
        scores = logits.float()
        if log_probs:
            return utils.log_softmax(scores, dim=-1)
        return utils.softmax(scores, dim=-1)

    def get_normalized_fusion_ctc_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from ``net_output["fusion_ctc_out"]``."""
        return self._normalize_logits(net_output["fusion_ctc_out"], log_probs)

    def get_normalized_fusion_ce_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from ``net_output["fusion_ce_out"]``."""
        # The softmax dimension is passed explicitly on both paths;
        # ``utils.softmax`` has no default for ``dim``.
        return self._normalize_logits(net_output["fusion_ce_out"], log_probs)

    def get_normalized_ctc_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from ``net_output["encoder_ctc"]``."""
        return self._normalize_logits(net_output["encoder_ctc"], log_probs)

    def get_normalized_decoder_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from ``net_output["decoder_out"]``."""
        return self._normalize_logits(net_output["decoder_out"], log_probs)

    def get_normalized_ctc_decode_fusion_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from ``net_output["ctc_decode_fusion_out"]``.

        That entry is ``None`` unless the encoder ran with ``gc_weight``
        enabled, hence the assertion.
        """
        logits = net_output['ctc_decode_fusion_out']
        # Identity check: ``!=`` on a tensor is an element-wise comparison.
        assert logits is not None, "ctc_decode_fusion_out is missing from net_output"
        return self._normalize_logits(logits, log_probs)

    def get_num_updates(self):
        """Return the update counter tracked by the wrapped encoder."""
        return self.w2v_encoder.num_updates

    def forward(self, **kwargs):
        """Forward all keyword arguments to the wrapped encoder and return its output."""
        return self.w2v_encoder(**kwargs)

    # def max_positions(self):
    #     return None

class Wav2BertEncoder(FairseqEncoder):
    def __init__(self, args, tgt_dict=None):
        """Assemble the pretrained acoustic model, the BERT decoder and the output heads.

        Args:
            args: parsed model options. Read here: the masking / dropout
                overrides, ``w2v_path`` or ``w2v_args``, ``normalize``,
                ``data``, ``final_dropout``, ``decoder_bert_model_name``,
                ``train_from_scratch``, ``freeze_finetune_updates`` and the
                optional mixing / WER-simulation schedules.
            tgt_dict: target dictionary; its length is the output size of the
                projection heads. When it is ``None`` (and
                ``different_tokens_v2`` is off) the heads are set to ``None``.
        """
        self.apply_mask = args.apply_mask

        # Handed to ``load_checkpoint_to_cpu`` as argument overrides.
        arg_overrides = {
            "dropout": args.dropout,
            "activation_dropout": args.activation_dropout,
            "dropout_input": args.dropout_input,
            "attention_dropout": args.attention_dropout,
            "mask_length": args.mask_length,
            "mask_prob": args.mask_prob,
            "mask_selection": args.mask_selection,
            "mask_other": args.mask_other,
            "no_mask_overlap": args.no_mask_overlap,
            "mask_channel_length": args.mask_channel_length,
            "mask_channel_prob": args.mask_channel_prob,
            "mask_channel_selection": args.mask_channel_selection,
            "mask_channel_other": args.mask_channel_other,
            "no_mask_channel_overlap": args.no_mask_channel_overlap,
            "encoder_layerdrop": args.layerdrop,
            "feature_grad_mult": args.feature_grad_mult,
        }

        if getattr(args, "w2v_args", None) is None:
            state = checkpoint_utils.load_checkpoint_to_cpu(
                args.w2v_path, arg_overrides
            )
            w2v_args = state["args"]
            if w2v_args is None:
                # Fallback for checkpoints that keep the model config under
                # state["cfg"]["model"] (added to load a pretrained Chinese model).
                w2v_args = state["cfg"]["model"]
        else:
            state = None
            w2v_args = args.w2v_args

        assert (
            args.normalize == w2v_args.normalize
        ), "Fine-tuning works best when data normalization is the same"

        w2v_args.data = args.data
        # The task is built from the pretrained model's own args; it is only
        # needed to instantiate the matching model.
        task = tasks.setup_task(w2v_args)
        model = task.build_model(w2v_args)

        if state is not None and not args.no_pretrained_weights:
            model.load_state_dict(state["model"], strict=True)

        model.remove_pretraining_modules()

        super().__init__(task.source_dictionary)

        d = w2v_args.encoder_embed_dim

        self.w2v_model = model

        # BERT decoder side.
        base_architecture(args)
        self.encoder_dropout = nn.Dropout(args.final_dropout)

        if not hasattr(args, 'max_source_positions'):
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if not hasattr(args, 'max_target_positions'):
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS
        self.bertdecoder = BertAdapterDecoderFull.from_pretrained(args.decoder_bert_model_name, args, from_scratch=args.train_from_scratch)

        self.final_dropout_ctc = nn.Dropout(args.final_dropout)
        self.final_dropout_ce = nn.Dropout(args.final_dropout)
        self.freeze_finetune_updates = args.freeze_finetune_updates
        self.num_updates = 0

        self.bert_tokenizer = BertTokenizer.from_pretrained(getattr(args, "decoder_bert_model_name", None))

        self.different_token = getattr(args, 'different_tokens', False)
        self.different_tokens_v2 = getattr(args, 'different_tokens_v2', False)

        if self.different_tokens_v2:
            from fairseq.data import Dictionary
            from fairseq.tasks.wav2bert_task import LabelEncoder
            import os

            character_dict_path = os.path.join(args.data, "dict.ltr.txt")
            self.character_target_dictionary = Dictionary.load(character_dict_path)
            self.character_tokenizer = LabelEncoder(self.character_target_dictionary)

            self.proj = Linear(d, len(self.character_target_dictionary))
            self.encoder_proj = Linear(d, len(self.character_target_dictionary))
            # The CE branch cannot be computed over the small character vocab.
            self.ce_proj = Linear(d, len(tgt_dict))
        elif tgt_dict is not None:
            self.proj = Linear(d, len(tgt_dict))
            self.encoder_proj = Linear(d, len(tgt_dict))
            self.ce_proj = Linear(d, len(tgt_dict))
            print("d:", d)
        else:
            # Always define the heads so later ``self.encoder_proj`` /
            # ``self.proj`` lookups do not raise AttributeError.
            self.proj = None
            self.encoder_proj = None
            self.ce_proj = None

        # NOTE(review): the schedule options below are command-line strings
        # parsed with eval(); only trusted input must reach this point.
        # ast.literal_eval would be the safer parser if every value is a
        # plain literal -- confirm before switching.
        mix_prob_range = getattr(args, "mix_ctc_deocde_prob_range", None)
        mix_step_range = getattr(args, "mix_ctc_step_range", None)
        if mix_prob_range is not None and mix_step_range is not None:
            self.mix_ctc_deocde_prob_range = eval(mix_prob_range)
            self.mix_ctc_step_range = eval(mix_step_range)
        else:
            self.mix_ctc_deocde_prob_range = None
            self.mix_ctc_step_range = None

        # Both attributes are always defined: forward() tests
        # ``simulate_wer_step_range is None`` to pick the WER schedule.
        simulate_wer_range = getattr(args, "simulate_wer_range", None)
        self.simulate_wer_range = (
            eval(simulate_wer_range) if simulate_wer_range is not None else None
        )
        simulate_wer_step_range = getattr(args, "simulate_wer_step_range", None)
        self.simulate_wer_step_range = (
            eval(simulate_wer_step_range) if simulate_wer_step_range is not None else None
        )

        self.batch_mix = getattr(args, 'batch_mix', False)

        # The tokenizer's CLS id doubles as the CTC blank symbol.
        self.blank_idx = self.bert_tokenizer.cls()
        self.pad = self.bert_tokenizer.pad()
        self.add_input = getattr(args, 'add_input', False)
        self.no_mask = getattr(args, 'no_mask', False)
        self.period_idx = getattr(args, 'period_index', 119)

        self.gc_weight = getattr(args, 'gc_weight', False)

        self.load_pretrain_w2v = getattr(args, 'load_pretrain_w2v', False)

    def set_num_updates(self, num_updates):
        """Set the number of parameters updates.

        The value is passed to the parent class first and then cached on
        ``self.num_updates``, which ``forward`` compares against the freeze
        and mixing schedules.
        """
        super().set_num_updates(num_updates)
        self.num_updates = num_updates

    def decode_seperate(self, encoder_ctc, padding_mask, current_prob, tgt_tokens):
        """Choose, per sample, between the given targets and a greedy CTC decode.

        Args:
            encoder_ctc: CTC scores; transposed on (0, 1) before the per-sample
                loop (presumably T x B x V -- confirm against the caller).
            padding_mask: boolean mask whose ``True`` entries are padding; the
                count of ``False`` entries per row is used as the frame length.
            current_prob: probability of replacing a sample by its CTC decode.
            tgt_tokens: one target token tensor per sample.

        Returns:
            ``(tokens, used_ctc)``: ``tokens`` is the re-padded batch, cut to
            at most 512 columns and moved to ``encoder_ctc``'s device;
            ``used_ctc`` is truthy as soon as one sample was replaced by its
            CTC decode (the caller uses it to skip the MLM loss).
        """
        used_ctc = False
        with torch.no_grad():
            frame_scores = encoder_ctc.transpose(0, 1).float().contiguous().cpu()
            input_lengths = (~padding_mask).long().sum(-1)
            mixed_tokens = []
            for scores, n_frames, gold in zip(frame_scores, input_lengths, tgt_tokens):
                take_ctc = torch.rand(1) < current_prob
                used_ctc = used_ctc or take_ctc
                if not take_ctc:
                    # Keep the ground truth untouched (padding included); the
                    # batch is re-collated below anyway.
                    mixed_tokens.append(gold)
                    continue

                # Greedy CTC: best label per frame, merge repeats, drop blanks.
                scores = scores[:n_frames].unsqueeze(0)
                collapsed = scores.argmax(dim=-1).unique_consecutive()
                pred_ids = collapsed[collapsed != self.blank_idx].tolist()

                if self.different_token or self.different_tokens_v2:
                    # The CTC units differ from the BERT input units: turn the
                    # ids into text, then re-tokenize that text for BERT.
                    # ``different_tokens_v2`` decodes ids with the character
                    # tokenizer; re-encoding always goes through the BERT one.
                    id_decoder = (
                        self.bert_tokenizer
                        if self.different_token
                        else self.character_tokenizer
                    )
                    units = id_decoder.convert_ids_to_tokens(pred_ids)
                    # NOTE(review): ``self.pad`` comes from
                    # ``bert_tokenizer.pad()``; if that is an integer id this
                    # filter never drops a (string) token -- confirm.
                    units = ' '.join(u for u in units if u != self.pad)
                    words = post_process(units, "letter")

                    if self.add_input:
                        words += '.'

                    word_ids = self.bert_tokenizer.encode_line(words, post_proces='bert_bpe_piece')
                    if self.add_input:
                        word_ids.insert(0, self.bert_tokenizer.cls())
                        word_ids.append(self.bert_tokenizer.sep())
                else:
                    # Same vocabulary on both sides: use the blank-free ids.
                    word_ids = pred_ids
                    if self.add_input:
                        word_ids.insert(0, self.bert_tokenizer.cls())
                        # Id of the period token (set for the multilingual vocab).
                        word_ids.append(self.period_idx)
                        word_ids.append(self.bert_tokenizer.sep())

                mixed_tokens.append(torch.LongTensor(list(word_ids)))

            batch = data_utils.collate_tokens(mixed_tokens, pad_idx=self.pad, left_pad=False)
            # Cap the length at 512 positions.
            batch = utils.move_to_cuda(batch[:, :512], device=encoder_ctc.device)

            return batch, used_ctc

    
    def decode_ctc(self, encoder_ctc, padding_mask):
        """Greedy-decode the CTC scores of every sample into decoder input tokens.

        Args:
            encoder_ctc: CTC scores; transposed on (0, 1) before the per-sample
                loop (presumably T x B x V -- confirm against the caller).
            padding_mask: boolean mask whose ``True`` entries are padding; the
                count of ``False`` entries per row is used as the frame length.

        Returns:
            ``(tokens, lengths)``: the padded token batch, cut to at most 512
            columns and moved to ``encoder_ctc``'s device, plus a list with
            each sample's token count capped at 512.
        """
        with torch.no_grad():
            per_sample_scores = encoder_ctc.transpose(0, 1).float().contiguous().cpu()
            frame_counts = (~padding_mask).long().sum(-1)

            decoded = []
            decoded_lengths = []
            for scores, n_frames in zip(per_sample_scores, frame_counts):
                # Best label per frame, merge repeats, drop blanks.
                best_path = scores[:n_frames].unsqueeze(0).argmax(dim=-1)
                merged = best_path.unique_consecutive()
                pred_units_arr = merged[merged != self.blank_idx].tolist()

                if self.different_token or self.different_tokens_v2:
                    # CTC units differ from BERT units: ids -> text -> BERT ids.
                    # Only the id-to-token step depends on which flag is set.
                    unit_source = (
                        self.bert_tokenizer
                        if self.different_token
                        else self.character_tokenizer
                    )
                    unit_tokens = unit_source.convert_ids_to_tokens(pred_units_arr)
                    unit_text = ' '.join(tok for tok in unit_tokens if tok != self.pad)
                    sentence = post_process(unit_text, "letter")

                    if self.add_input:
                        sentence += '.'

                    word_tgt_tokens = self.bert_tokenizer.encode_line(sentence, post_proces='bert_bpe_piece')
                    if self.add_input:
                        word_tgt_tokens.insert(0, self.bert_tokenizer.cls())
                        word_tgt_tokens.append(self.bert_tokenizer.sep())
                else:
                    # Shared vocabulary: the blank-free ids are used directly.
                    word_tgt_tokens = pred_units_arr
                    if self.add_input:
                        word_tgt_tokens.insert(0, self.bert_tokenizer.cls())
                        # Id of the period token (set for the multilingual vocab).
                        word_tgt_tokens.append(self.period_idx)
                        word_tgt_tokens.append(self.bert_tokenizer.sep())

                decoded.append(torch.LongTensor(list(word_tgt_tokens)))
                decoded_lengths.append(min(512, len(word_tgt_tokens)))

            padded = data_utils.collate_tokens(decoded, pad_idx=self.pad, left_pad=False)
            # Cap the length at 512 positions.
            padded = utils.move_to_cuda(padded[:, :512], device=encoder_ctc.device)

            return padded, decoded_lengths

    def simulate_wer(self, prev_output_tokens, wer):
        """Corrupt target sequences with random insert / delete / substitute errors.

        For every sample, ``int(len * wer)`` positions are drawn at random; the
        first and last token are never altered. At each remaining drawn
        position one error type is picked uniformly. Inserted and substituted
        tokens are sampled from a random row of the same batch (with its
        first token and last two tokens excluded), so that only sub-words
        occurring in the current batch are used. The result is therefore only
        an approximation of the requested word error rate.

        Args:
            prev_output_tokens: 2-D tensor of token ids, padded with ``self.pad``.
            wer: fraction of positions to corrupt.

        Returns:
            ``(tokens, lengths)``: the corrupted, re-padded batch, cut to at
            most 512 columns and moved to ``prev_output_tokens``'s device, and
            a list holding one (512-capped) length per sample.
        """
        max_len = 512
        error_kinds = ['I', 'D', 'S']  # insert, delete, substitute

        with torch.no_grad():
            noisy_batch = []
            # Created once per call so that every sample contributes a length.
            noisy_lengths = []
            for row in prev_output_tokens:
                # Strip padding only; special tokens stay and are protected
                # by the boundary check below.
                tokens = row[row != self.pad].tolist()
                token_len = len(tokens)
                err_num = int(token_len * wer)

                # Random subset of positions that receive an error.
                order = np.arange(token_len)
                np.random.shuffle(order)
                err_positions = set(order[:err_num].tolist())

                noisy = []
                for i, tok in enumerate(tokens):
                    at_boundary = i == 0 or i == token_len - 1
                    if at_boundary or i not in err_positions:
                        noisy.append(tok)
                        continue

                    kind = random.choice(error_kinds)
                    if kind == 'D':
                        continue

                    # Candidate replacements: a random row of the batch
                    # without padding, its first token and its last two.
                    donor = random.choice(prev_output_tokens)
                    pool = donor[donor != self.pad][1:-2].tolist()

                    if kind == 'I':
                        noisy.append(tok)
                    if pool:
                        noisy.append(random.choice(pool))
                    elif kind == 'S':
                        # Donor row too short to offer a token: keep the
                        # original instead of failing on an empty choice.
                        noisy.append(tok)

                noisy_batch.append(torch.LongTensor(noisy))
                # Lengths follow the truncation applied to the batch below.
                noisy_lengths.append(min(max_len, len(noisy)))

            padded = data_utils.collate_tokens(noisy_batch, pad_idx=self.pad, left_pad=False)
            padded = utils.move_to_cuda(padded[:, :max_len], device=prev_output_tokens.device)
            return padded, noisy_lengths


    def forward(self, source, padding_mask, prev_output_tokens=None, tbc=True, **kwargs):
        """Run the acoustic encoder, pick the decoder input and fuse both streams.

        Decoder input selection:
          * evaluation: always the greedy CTC decode of the encoder output;
          * training, past ``mix_ctc_step_range[0]``: CTC decode with a
            probability that ramps linearly over the step range, either for
            the whole batch or per sample (``batch_mix``); with ``no_mask``
            the batches that keep the ground truth get simulated errors;
          * training, before that point, with ``no_mask``: simulated errors;
          * otherwise: ``prev_output_tokens`` unchanged.

        Args:
            source: raw input batch handed to the acoustic model.
            padding_mask: padding mask for ``source``; replaced by the mask
                that ``extract_features`` returns.
            prev_output_tokens: decoder input tokens (ground truth side).
            tbc: when true, encoder / fusion outputs are transposed on (0, 1)
                (B x T x C -> T x B x C according to the original notes).

        Returns:
            dict with the keys ``fusion_ctc_out``, ``fusion_ce_out``,
            ``encoder_padding_mask``, ``decoder_out``, ``padding_mask``,
            ``encoder_ctc``, ``align_target``, ``align_fusion``, ``mix_ctc``,
            ``ctc_decode_fusion_out`` (``None`` unless ``gc_weight``) and
            ``ctc_decode_lengths``.
        """
        w2v_args = {
            "source": source,
            "padding_mask": padding_mask,
            "mask": self.apply_mask and self.training,
        }

        # The acoustic model runs without gradients until
        # ``freeze_finetune_updates`` updates have elapsed.
        ft = self.freeze_finetune_updates <= self.num_updates

        with torch.no_grad() if not ft else contextlib.ExitStack():
            x, padding_mask = self.w2v_model.extract_features(**w2v_args)
            # Defined for both layouts so ``tbc=False`` has a bound value.
            encoder_ctc = x.transpose(0, 1) if tbc else x

        encoder_feature = x
        encoder_ctc = self.encoder_dropout(encoder_ctc)

        encoder_proj = getattr(self, "encoder_proj", None)
        if encoder_proj is not None:
            # With ``load_pretrain_w2v`` the CTC head is frozen together with
            # the acoustic model; otherwise it always receives gradients.
            freeze_proj = self.load_pretrain_w2v and not ft
            with torch.no_grad() if freeze_proj else contextlib.ExitStack():
                encoder_ctc = encoder_proj(encoder_ctc)

        # decode_ctc / decode_seperate transpose (0, 1) themselves, i.e. they
        # expect the time-major layout regardless of ``tbc``.
        ctc_time_major = encoder_ctc if tbc else encoder_ctc.transpose(0, 1)

        # Decoding CTC output happens per sample on the CPU and may be slow.
        # ``mix_ctc`` tells the caller that the decoder input is no longer the
        # plain target, so the MLM loss should be skipped.
        new_tgt_tokens_lengths = None
        mix_ctc = False
        mix_step_range = self.mix_ctc_step_range
        if not self.training:
            prev_output_tokens, new_tgt_tokens_lengths = self.decode_ctc(ctc_time_major, padding_mask)
            mix_ctc = True
        elif mix_step_range is not None and self.num_updates > mix_step_range[0]:
            s, e = mix_step_range
            sp, ep = self.mix_ctc_deocde_prob_range
            current_prob = min(1, (self.num_updates - s) / (e - s)) * (ep - sp) + sp

            if self.batch_mix:
                # Every sample decides on its own.
                prev_output_tokens, mix_ctc = self.decode_seperate(ctc_time_major, padding_mask, current_prob, prev_output_tokens)
            else:
                # One decision for the whole batch.
                mix_ctc = torch.rand(1) < current_prob
                if mix_ctc:
                    prev_output_tokens, new_tgt_tokens_lengths = self.decode_ctc(ctc_time_major, padding_mask)
                elif self.no_mask:
                    mix_ctc = True
                    prev_output_tokens, new_tgt_tokens_lengths = self.simulate_wer(
                        prev_output_tokens, self._current_simulated_wer()
                    )
        elif self.no_mask:
            # Before the mixing phase: simulated errors stand in for CTC output.
            mix_ctc = True
            prev_output_tokens, new_tgt_tokens_lengths = self.simulate_wer(
                prev_output_tokens, self._current_simulated_wer()
            )

        decoder_input = x.permute(1, 0, 2).contiguous()
        encoder_out = {
            'encoder_out': decoder_input,
            'encoder_padding_mask': padding_mask,
        }

        ctc_fusion_out = None
        if self.gc_weight:
            # Extra decoder pass fed with the CTC decode. decode_ctc returns
            # (tokens, lengths) and the decoder returns six values; only the
            # CTC fusion output (third value) is used here.
            ctc_prev_output_tokens, _ = self.decode_ctc(ctc_time_major, padding_mask)
            _, _, ctc_fusion_out, _, _, _ = self.bertdecoder(ctc_prev_output_tokens, encoder_out=encoder_out, padding_idx=0, num_updates=self.num_updates)
            if tbc:
                ctc_fusion_out = ctc_fusion_out.transpose(0, 1)
            # Same dropout and projection head as the main CTC fusion path.
            ctc_fusion_out = self.final_dropout_ctc(ctc_fusion_out)
            ctc_fusion_out = self.proj(ctc_fusion_out)

        decoder_out, decoder_hidden, fusion_out_ctc, align_result_ctc, fusion_out_ce, align_result_ce = self.bertdecoder(prev_output_tokens, encoder_out=encoder_out, padding_idx=0, num_updates=self.num_updates)
        # Fusion v2: the encoder output and the decoder's last hidden state go
        # through separate fusion modules; a linear layer on top of each
        # yields the CTC and CE logits.
        if tbc:
            fusion_out_ctc = fusion_out_ctc.transpose(0, 1)
            fusion_out_ce = fusion_out_ce.transpose(0, 1)
        fusion_out_ctc = self.final_dropout_ctc(fusion_out_ctc)
        fusion_out_ce = self.final_dropout_ce(fusion_out_ce)

        ctc_result = self.proj(fusion_out_ctc)
        ce_result = self.ce_proj(fusion_out_ce)

        return {
            "fusion_ctc_out": ctc_result,
            "fusion_ce_out": ce_result,
            "encoder_padding_mask": padding_mask,
            "decoder_out": decoder_out,
            "padding_mask": padding_mask,
            "encoder_ctc": encoder_ctc,
            "align_target": encoder_feature,
            "align_fusion": align_result_ctc,  # not an alignment despite the key
            "mix_ctc": mix_ctc,
            "ctc_decode_fusion_out": ctc_fusion_out,
            "ctc_decode_lengths": new_tgt_tokens_lengths,
        }

    def _current_simulated_wer(self):
        """Return the error rate to simulate at the current update count.

        The rate moves linearly from ``simulate_wer_range[0]`` to
        ``simulate_wer_range[1]``. Without ``simulate_wer_step_range`` the
        ramp spans updates ``0 .. mix_ctc_step_range[0]``; with it, the ramp
        spans that range and the rate is 0 before the range starts.
        """
        wer_range = getattr(self, "simulate_wer_range", None)
        if wer_range is None:
            raise ValueError("--simulate-wer-range is required when WER simulation is active")
        wer_s, wer_e = wer_range

        step_range = getattr(self, "simulate_wer_step_range", None)
        if step_range is None:
            s, e = 0, self.mix_ctc_step_range[0]
        else:
            s, e = step_range
            if self.num_updates < s:
                return 0
        return min(1, (self.num_updates - s) / (e - s)) * (wer_e - wer_s) + wer_s

    def reorder_encoder_out(self, encoder_out, new_order):
        if encoder_out["encoder_out"] is not None:
            encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
                1, new_order
            )
        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
            ].index_select(0, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder.

        Always ``None``: this encoder does not report a length limit.
        """
        return None

    def upgrade_state_dict_named(self, state_dict, name):
        """Return ``state_dict`` unchanged; ``name`` is not used."""
        return state_dict


class BertAdapterDecoderFull(BertPreTrainedModel):
    """BERT decoder with an MLM head and two fusion adapters.

    The decoder (``BertDecoderAssemble``) produces hidden states for the
    target tokens. Those states are projected to vocabulary scores by the MLM
    head and are also combined with the encoder output by two adapters:
    ``FusionAdapter2CTC`` (encoder states attend over BERT states) and
    ``FusionAdapter2CE`` (BERT states attend over encoder states).
    """

    def __init__(self, config, args):
        super().__init__(config)
        self.bert = BertDecoderAssemble(config, args)
        # The MLM head is given the decoder's word-embedding weight.
        word_embedding_weight = self.bert.embeddings.word_embeddings.weight
        self.cls = BertOnlyMLMHead(config, word_embedding_weight)
        # The BERT initialiser runs before the adapters are created, so it
        # only visits the modules registered above.
        self.apply(self.init_bert_weights)
        self.fusion_adapter_ctc = FusionAdapter2CTC(args)
        self.fusion_adapter_ce = FusionAdapter2CE(args)
        self.freeze_bert = getattr(args, 'freeze_bert', False)
        self.onnx_trace = False

    def forward(self, prev_output_tokens, src_tokens=None, encoder_out=None, padding_idx=0, num_updates=None, **kwargs):
        """Run the decoder, the MLM head and both fusion adapters.

        Args:
            prev_output_tokens: target-side token ids.
            src_tokens: accepted for interface compatibility; not used.
            encoder_out (dict): must provide ``'encoder_out'`` and
                ``'encoder_padding_mask'``.
            padding_idx (int): token id treated as padding.
            num_updates: forwarded to the decoder.

        Returns:
            tuple: ``(prediction_scores, sequence_output, fusion_out_ctc,
            align_result_ctc, fusion_out_ce, align_result_ce)``.
        """
        # Gradients through the BERT stack are disabled when it is frozen;
        # the MLM head and the adapters below always run with gradients.
        grad_context = torch.no_grad() if self.freeze_bert else contextlib.ExitStack()
        with grad_context:
            sequence_output, targets_padding = self.bert(
                prev_output_tokens, encoder_out, padding_idx, num_updates
            )

        prediction_scores = self.cls(sequence_output)

        encoder_states = encoder_out['encoder_out']
        fusion_out_ctc, align_result_ctc = self.fusion_adapter_ctc(
            encoder_states, sequence_output, targets_padding
        )
        fusion_out_ce, align_result_ce = self.fusion_adapter_ce(
            encoder_states, sequence_output, encoder_out['encoder_padding_mask']
        )
        return (
            prediction_scores,
            sequence_output,
            fusion_out_ctc,
            align_result_ctc,
            fusion_out_ce,
            align_result_ce,
        )

    def prepare_for_onnx_export_(self):
        """Switch the softmax helpers to their ONNX-traceable variant."""
        self.onnx_trace = True

    def get_normalized_probs(self, net_output, log_probs, sample):
        """Get normalized probabilities (or log probs) from a net's output.

        The logits are taken from ``net_output[0]``; ``sample`` is not used.
        """
        normalize = utils.log_softmax if log_probs else utils.softmax
        return normalize(net_output[0], dim=-1, onnx_trace=self.onnx_trace)

class BertDecoderAssemble(BertPreTrainedModel):
    """BERT embeddings followed by the ``BertDecoder`` layer stack."""

    def __init__(self, config, args):
        super().__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertDecoder(config, args)
        self.apply(self.init_bert_weights)
        self.hidden_size = config.hidden_size

    def forward(self, prev_output_tokens, encoder_out=None, padding_idx=0, num_updates=None):
        """Embed the target tokens and run them through the decoder stack.

        Args:
            prev_output_tokens: target-side token ids (the original comment
                gives the shape as B x tgt_len).
            encoder_out (dict, optional): when given, its ``'encoder_out'``
                and ``'encoder_padding_mask'`` entries are passed on.
            padding_idx (int): token id treated as padding.
            num_updates: forwarded to the decoder stack.

        Returns:
            tuple: the last decoder output and the target padding mask.
        """
        # True at padded target positions, False elsewhere.
        targets_padding = prev_output_tokens.eq(padding_idx)

        target_len = prev_output_tokens.size(1)
        position_ids = torch.arange(
            target_len, dtype=torch.long, device=prev_output_tokens.device
        )
        position_ids = position_ids.unsqueeze(0).expand_as(prev_output_tokens)
        # Per the original author's note: this lookup fails for position ids
        # above 512 because the BERT position table only covers 512
        # positions, so samples with longer labels have to be skipped.
        positions = self.embeddings.position_embeddings(position_ids).transpose(0, 1)
        token_type_ids = torch.zeros_like(prev_output_tokens)

        # Additive mask: -10000 at padded keys, 0 elsewhere, cast to the
        # parameter dtype for fp16 compatibility.
        parameter_dtype = next(self.parameters()).dtype
        extended_attention_mask = targets_padding.unsqueeze(1).unsqueeze(2).float()
        extended_attention_mask = extended_attention_mask.to(dtype=parameter_dtype)
        extended_attention_mask *= -10000.0

        embedding_output = self.embeddings(prev_output_tokens, token_type_ids)

        if encoder_out is None:
            encoder_states = None
            encoder_padding_mask = None
        else:
            encoder_states = encoder_out['encoder_out']
            encoder_padding_mask = encoder_out['encoder_padding_mask']

        # output_all_encoded_layers=False: only the final output is returned.
        encoded_layers = self.encoder(
            embedding_output,
            extended_attention_mask,
            output_all_encoded_layers=False,
            encoder_out=encoder_states,
            encoder_padding_mask=encoder_padding_mask,
            position_embedding=positions,
            targets_padding=targets_padding,
            num_updates=num_updates,
        )
        return encoded_layers[-1], targets_padding


class BertDecoder(nn.Module):
    """Stack of ``BertDecoderLayer`` modules with optional acoustic fusion.

    When ``args.fuse_input`` is set, the output of the first layer attends
    over the encoder states through ``AlignAttention`` and the attended result
    is merged into the hidden states that feed the remaining layers:

    * ``fuse_input_add``: element-wise sum of both;
    * ``fuse_input_gate``: sigmoid-gated interpolation of both;
    * neither: the attended result replaces the hidden states.

    With ``args.random_mix_input`` the merged states replace only a random
    subset of positions during training (the keep-probability of the original
    states comes from ``get_gold_rate``); in eval mode they replace every
    position.
    """

    def __init__(self, config, args):
        super(BertDecoder, self).__init__()
        self.num_layers = config.num_hidden_layers
        # Each layer is constructed from scratch, so no extra copy is needed.
        self.layer = nn.ModuleList(
            [BertDecoderLayer(config, args, i) for i in range(config.num_hidden_layers)]
        )

        self.random_mix = getattr(args, 'random_mix_input', False)

        if getattr(args, "gold_rate_range", None) is not None and \
                getattr(args, "gold_rate_steps", None) is not None:
            # Both options are strings holding a Python expression that
            # evaluates to a pair, e.g. "(0.9, 0.1)".
            # SECURITY NOTE: eval() executes arbitrary code; these values must
            # only ever come from trusted configuration.
            self.gold_rate_range = eval(args.gold_rate_range)
            self.gold_rate_steps = eval(args.gold_rate_steps)
        else:
            self.gold_rate_range = None
            self.gold_rate_steps = None

        self.fuse_input = getattr(args, 'fuse_input', False)
        self.fuse_input_add = getattr(args, 'fuse_input_add', False)
        self.fuse_input_gate = getattr(args, 'fuse_input_gate', False)

        # Build the alignment attention once, even when both flags are set.
        if self.fuse_input or self.fuse_input_gate:
            self.align_attention = AlignAttention(args)
        if self.fuse_input_gate:
            self.embed_dim = args.decoder_embed_dim
            # The gate consumes cat([hidden_states, align_result]); both halves
            # are decoder_embed_dim wide (align_result is the output of the
            # alignment attention), so the input width is 2 * embed_dim
            # regardless of the encoder dimension.
            self.fuse_gate = Linear(2 * self.embed_dim, self.embed_dim, bias=True)

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True,
                encoder_out=None, encoder_padding_mask=None, position_embedding=None, targets_padding=None, num_updates=None):
        """Run all layers, fusing encoder information after the first one.

        Args:
            hidden_states: input embeddings (presumably B x T x C; the fusion
                path unpacks a rank-3 shape).
            attention_mask: additive self-attention mask passed to each layer.
            output_all_encoded_layers (bool): collect the output of every
                layer instead of only the final one.
            encoder_out: encoder states used as attention keys/values.
            encoder_padding_mask: padding mask for ``encoder_out``.
            position_embedding, targets_padding: forwarded to every layer.
            num_updates: training step, needed for random mixing.

        Returns:
            list: the output of every layer (taken before fusion for the first
            layer), or a one-element list holding the final hidden states.
        """
        all_decoder_layers = []
        last_index = self.num_layers - 1
        for i in range(self.num_layers):
            hidden_states = self.layer[i](
                hidden_states,
                encoder_out=encoder_out,
                encoder_padding_mask=encoder_padding_mask,
                self_attn_mask=attention_mask,
                position_embedding=position_embedding,
                targets_padding=targets_padding,
                layer_num=i,
                is_last_layer=(i == last_index),
            )
            if output_all_encoded_layers:
                all_decoder_layers.append(hidden_states)

            if self.fuse_input and i == 0:
                hidden_states = self._fuse_first_layer(
                    hidden_states, encoder_out, encoder_padding_mask, num_updates
                )

        if not output_all_encoded_layers:
            # Only the output of the last layer is returned.
            all_decoder_layers.append(hidden_states)
        return all_decoder_layers

    def _fuse_first_layer(self, hidden_states, encoder_out, encoder_padding_mask, num_updates):
        """Merge the first layer's output with its alignment over the encoder states."""
        # The attention module works on time-major tensors, hence the
        # transposes around the call.
        align_result = self.align_attention(
            key_value=encoder_out,
            query=hidden_states.transpose(0, 1),
            key_padding_mask=encoder_padding_mask,
        )
        align_result = align_result.transpose(0, 1)

        if self.fuse_input_add:
            fused = align_result + hidden_states
        elif self.fuse_input_gate:
            gate_weight = torch.sigmoid(
                self.fuse_gate(torch.cat([hidden_states, align_result], dim=-1))
            )
            fused = hidden_states * gate_weight + align_result * (1 - gate_weight)
        else:
            fused = align_result

        if not (self.random_mix and self.training):
            return fused

        # Random mixing: each position keeps the original hidden state with
        # probability gold_rate and takes the fused state otherwise. Padded
        # positions are mixed as well.
        batch_size, seq_len, _ = fused.shape
        gold_rate = self.get_gold_rate(num_updates)
        pred_mask = torch.rand((batch_size, seq_len), device=hidden_states.device) > gold_rate
        return torch.where(
            pred_mask.unsqueeze(-1).expand_as(hidden_states), fused, hidden_states
        )

    def get_gold_rate(self, num_updates):
        """Return the probability of keeping the original hidden state.

        With ``s, e = gold_rate_range`` and ``s1, s2 = gold_rate_steps`` the
        value is ``max((1 - progress) * (s - e), 0) + e`` where ``progress``
        is ``max(num_updates - s1, 0) / (s2 - s1)``. When ``s1 == s2`` the
        schedule is treated as a step that switches once ``num_updates``
        exceeds ``s1``.

        Raises:
            ValueError: if the schedule or ``num_updates`` is missing.
        """
        assert self.random_mix

        if self.gold_rate_range is None or self.gold_rate_steps is None:
            raise ValueError(
                "random mixing requires both gold_rate_range and gold_rate_steps"
            )
        if num_updates is None:
            raise ValueError("random mixing requires num_updates")

        s, e = self.gold_rate_range
        s1, s2 = self.gold_rate_steps
        elapsed = max(num_updates - s1, 0)
        span = s2 - s1
        if span == 0:
            progress = 1 if elapsed > 0 else 0
        else:
            progress = elapsed / span
        return max((1 - progress) * (s - e), 0) + e
    

class BertDecoderLayer(nn.Module):
    """One BERT layer of the decoder: self-attention followed by feed-forward.

    The layer is assembled from the ``BertAttention``, ``BertIntermediate``
    and ``BertOutput`` blocks. Several options are read from ``args`` and
    stored on the instance but are not consulted by ``forward``.
    """

    def __init__(self, config, args, layer_num):
        super(BertDecoderLayer, self).__init__()
        self.embed_dim = args.decoder_embed_dim
        self.dropout = args.dropout
        self.fusion_v2 = args.fusion_v2
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu')
        )
        self.activation_dropout = getattr(args, 'activation_dropout', 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, 'relu_dropout', 0)
        self.normalize_before = args.decoder_normalize_before

        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)
        self.top_layer_adapter = getattr(args, 'top_layer_adapter', -1)

        # Flags toggled by make_generation_fast_ / prepare_for_onnx_export_;
        # defined up front so they always exist.
        self.need_attn = False
        self.onnx_trace = False

    def forward(
        self,
        x,
        encoder_out=None,
        encoder_padding_mask=None,
        self_attn_mask=None,
        position_embedding=None,
        targets_padding=None,
        layer_num=-1,
        is_last_layer=False,
    ):
        """Apply self-attention and the feed-forward block to ``x``.

        Only ``x`` and ``self_attn_mask`` are used; the remaining arguments
        are accepted so that callers can use one uniform call signature.
        """
        x = self.attention(x, self_attn_mask)

        intermediate_output = self.intermediate(x)
        x = self.output(intermediate_output, x)

        return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        """Record whether attention weights are wanted."""
        self.need_attn = need_attn

    def prepare_for_onnx_export_(self):
        """Mark the layer as being traced for ONNX export."""
        self.onnx_trace = True

    def without_self_mask(self, tensor):
        """Build a square mask with ``-inf`` on the diagonal and 0 elsewhere.

        The mask has side ``tensor.size(0)`` and is created on the device of
        ``tensor`` rather than unconditionally on CUDA, so it also works on
        CPU-only hosts and stays on the same device as the input.
        """
        dim = tensor.size(0)
        eye_matrix = torch.eye(dim, device=tensor.device)
        eye_matrix[eye_matrix == 1.0] = float('-inf')
        return eye_matrix

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        """Apply ``layer_norm`` on the side selected by ``normalize_before``.

        Exactly one of ``before`` / ``after`` must be true. The norm runs for
        ``before=True`` when ``normalize_before`` is set, and for
        ``after=True`` when it is not.
        """
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

class AlignAttention(nn.Module):
    """Cross-attention from ``query`` onto ``key_value``, then layer norm.

    The attention has ``decoder_embed_dim`` as its model width and
    ``encoder_embed_dim`` as its key/value width.

    NOTE: a class with this same name and an identical body is defined again
    later in this file; that later definition is the one bound to the name
    once the module has been imported.
    """

    def __init__(self, args):
        super().__init__()
        self.decoder_embed_dim = args.decoder_embed_dim
        encoder_dim = getattr(args, 'encoder_embed_dim', None)
        self.align_attn = MultiheadAttention(
            self.decoder_embed_dim,
            args.decoder_attention_heads,
            kdim=encoder_dim,
            vdim=encoder_dim,
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
        )

        self.align_attn_layer_norm = LayerNorm(self.decoder_embed_dim, export=False)
        self.need_attn = False

    def forward(self, key_value, query, key_padding_mask):
        """Attend from ``query`` over ``key_value`` and normalise the result.

        Args:
            key_value: tensor used as both keys and values (the original
                comments ask for T x B x C).
            query: tensor used as queries (same layout).
            key_padding_mask: padding mask for ``key_value``.

        Returns:
            The layer-normalised attention output; attention weights are
            discarded.
        """
        # Weights are only requested outside training when need_attn is set.
        want_weights = not self.training and self.need_attn
        attended, _ = self.align_attn(
            query=query,
            key=key_value,
            value=key_value,
            key_padding_mask=key_padding_mask,
            static_kv=True,
            need_weights=want_weights,
        )
        return self.align_attn_layer_norm(attended)

class FusionAdapter2CTC(nn.Module):
    """Adapter in which the encoder states attend over the BERT states.

    The block is a cross-attention sub-layer followed by a two-layer
    feed-forward sub-layer, each with a residual connection and a layer norm
    whose position depends on ``normalize_before``. All widths follow
    ``encoder_embed_dim`` except the attention keys/values, which use
    ``decoder_embed_dim``.
    """

    def __init__(self, args):
        super().__init__()
        encoder_dim = getattr(args, 'encoder_embed_dim', None)

        self.embed_dim = args.decoder_embed_dim
        self.dropout = args.dropout
        self.normalize_before = args.decoder_normalize_before
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu')
        )
        activation_dropout = getattr(args, 'activation_dropout', 0)
        if activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            activation_dropout = getattr(args, 'relu_dropout', 0)
        self.activation_dropout = activation_dropout

        export = getattr(args, 'char_inputs', False)
        self.encoder_attn = MultiheadAttention(
            encoder_dim,
            args.decoder_attention_heads,
            kdim=self.embed_dim,
            vdim=self.embed_dim,
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
        )
        self.encoder_attn_layer_norm = LayerNorm(encoder_dim, export=export)

        self.encoder_attn_fc1 = Linear(encoder_dim, args.decoder_ffn_embed_dim)
        self.encoder_attn_fc2 = Linear(args.decoder_ffn_embed_dim, encoder_dim)
        self.encoder_attn_final_layer_norm = LayerNorm(encoder_dim, export=export)
        self.need_attn = False

    def forward(self, encoder_out, bert_hidden, targets_padding):
        """Fuse BERT states into the encoder states.

        Args:
            encoder_out: encoder states, used as attention queries (the
                original comments ask for T x B x C).
            bert_hidden: BERT hidden states; transposed on dims 0/1 before
                being used as keys and values.
            targets_padding: padding mask for the BERT states.

        Returns:
            tuple: ``(layer_output, align_result)``, both transposed on dims
            0/1 relative to the query layout. ``align_result`` is the raw
            attention output before the residual connection.
        """
        bert_states = bert_hidden.transpose(0, 1)

        # --- cross-attention sub-layer: encoder queries, BERT keys/values ---
        shortcut = encoder_out
        attn_input = self.maybe_layer_norm(self.encoder_attn_layer_norm, encoder_out, before=True)
        attn_output, _ = self.encoder_attn(
            query=attn_input,
            key=bert_states,
            value=bert_states,
            key_padding_mask=targets_padding,
            static_kv=True,
            need_weights=(not self.training and self.need_attn),
        )
        align_result = attn_output.transpose(0, 1)
        x = shortcut + attn_output
        x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)

        # --- feed-forward sub-layer ---
        shortcut = x
        x = self.maybe_layer_norm(self.encoder_attn_final_layer_norm, x, before=True)
        x = self.encoder_attn_fc1(x)
        x = self.activation_fn(x)
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.encoder_attn_fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = shortcut + x
        x = self.maybe_layer_norm(self.encoder_attn_final_layer_norm, x, after=True)

        return x.transpose(0, 1), align_result

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        """Apply ``layer_norm`` on the side selected by ``normalize_before``.

        Exactly one of ``before`` / ``after`` must be true.
        """
        assert before ^ after
        should_normalize = after ^ self.normalize_before
        return layer_norm(x) if should_normalize else x


def Embedding(num_embeddings, embedding_dim, padding_idx):
    """Create an ``nn.Embedding`` with scaled-normal initialisation.

    Weights are drawn from ``N(0, embedding_dim ** -0.5)`` and the row of
    ``padding_idx`` is zeroed.

    Args:
        num_embeddings (int): vocabulary size.
        embedding_dim (int): width of each embedding vector.
        padding_idx (int or None): index of the padding row; ``None`` means
            there is no padding row.

    Returns:
        nn.Embedding: the initialised module.
    """
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    # Indexing with None would select the whole matrix and wipe every row,
    # so the padding row is only zeroed when one actually exists.
    if padding_idx is not None:
        nn.init.constant_(m.weight[padding_idx], 0)
    return m


def Linear(in_features, out_features, bias=True):
    """Create an ``nn.Linear`` with Xavier-uniform weights and a zero bias.

    Args:
        in_features (int): input width.
        out_features (int): output width.
        bias: whether the layer has a bias term.

    Returns:
        nn.Linear: the initialised layer.
    """
    layer = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(layer.weight)
    if not bias:
        return layer
    nn.init.constant_(layer.bias, 0.0)
    return layer


class FusionAdapter2CE(nn.Module):
    """Adapter in which the BERT states attend over the encoder states.

    Mirror image of ``FusionAdapter2CTC``: queries come from the BERT hidden
    states (width ``decoder_embed_dim``) and keys/values from the encoder
    output (width ``encoder_embed_dim``). A cross-attention sub-layer is
    followed by a feed-forward sub-layer, each with a residual connection and
    a layer norm placed according to ``normalize_before``.
    """

    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.dropout = args.dropout
        self.normalize_before = args.decoder_normalize_before
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu')
        )
        activation_dropout = getattr(args, 'activation_dropout', 0)
        if activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            activation_dropout = getattr(args, 'relu_dropout', 0)
        self.activation_dropout = activation_dropout

        export = getattr(args, 'char_inputs', False)
        encoder_dim = getattr(args, 'encoder_embed_dim', None)
        self.encoder_attn = MultiheadAttention(
            self.embed_dim,
            args.decoder_attention_heads,
            kdim=encoder_dim,
            vdim=encoder_dim,
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
        )
        self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        ffn_dim = args.decoder_ffn_embed_dim
        self.encoder_attn_fc1 = Linear(self.embed_dim, ffn_dim)
        self.encoder_attn_fc2 = Linear(ffn_dim, self.embed_dim)
        self.encoder_attn_final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = False

    def forward(self, encoder_out, bert_hidden, encoder_padding_mask):
        """Fuse encoder states into the BERT states.

        Args:
            encoder_out: encoder states, used as attention keys and values
                (the original comments ask for T x B x C).
            bert_hidden: BERT hidden states; transposed on dims 0/1 before
                being used as queries.
            encoder_padding_mask: padding mask for the encoder states.

        Returns:
            tuple: ``(layer_output, align_result)``. Both are transposed back
            on dims 0/1, so they follow the layout of ``bert_hidden``.
            ``align_result`` is the raw attention output before the residual
            connection.
        """
        queries = bert_hidden.transpose(0, 1)

        # --- cross-attention sub-layer: BERT queries, encoder keys/values ---
        normed_queries = self.maybe_layer_norm(self.encoder_attn_layer_norm, queries, before=True)
        attended, _ = self.encoder_attn(
            query=normed_queries,
            key=encoder_out,
            value=encoder_out,
            key_padding_mask=encoder_padding_mask,
            static_kv=True,
            need_weights=(not self.training and self.need_attn),
        )
        align_result = attended.transpose(0, 1)
        attn_block_out = self.maybe_layer_norm(
            self.encoder_attn_layer_norm, queries + attended, after=True
        )

        # --- feed-forward sub-layer ---
        ffn = self.maybe_layer_norm(
            self.encoder_attn_final_layer_norm, attn_block_out, before=True
        )
        ffn = self.activation_fn(self.encoder_attn_fc1(ffn))
        ffn = F.dropout(ffn, p=self.activation_dropout, training=self.training)
        ffn = self.encoder_attn_fc2(ffn)
        ffn = F.dropout(ffn, p=self.dropout, training=self.training)
        layer_output = self.maybe_layer_norm(
            self.encoder_attn_final_layer_norm, attn_block_out + ffn, after=True
        )

        # The output always has the shape of the query, i.e. of the BERT states.
        return layer_output.transpose(0, 1), align_result

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        """Apply ``layer_norm`` on the side selected by ``normalize_before``.

        Exactly one of ``before`` / ``after`` must be true.
        """
        assert before ^ after
        if not (after ^ self.normalize_before):
            return x
        return layer_norm(x)

class AlignAttention(nn.Module):
    """Layer-normalised cross-attention used to align two sequences.

    Queries are ``decoder_embed_dim`` wide; keys and values are
    ``encoder_embed_dim`` wide.

    NOTE: an identical class with this name is already defined earlier in
    this file; this later definition replaces it when the module is imported.
    """

    def __init__(self, args):
        super().__init__()
        embed_dim = args.decoder_embed_dim
        self.decoder_embed_dim = embed_dim

        attention_kwargs = dict(
            kdim=getattr(args, 'encoder_embed_dim', None),
            vdim=getattr(args, 'encoder_embed_dim', None),
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
        )
        self.align_attn = MultiheadAttention(
            embed_dim, args.decoder_attention_heads, **attention_kwargs
        )

        self.align_attn_layer_norm = LayerNorm(embed_dim, export=False)
        self.need_attn = False

    def forward(self, key_value, query, key_padding_mask):
        """Return the normalised attention of ``query`` over ``key_value``.

        ``key_value`` serves as both keys and values and ``key_padding_mask``
        masks its padded positions. The original comments ask for T x B x C
        inputs. Attention weights are computed only outside training when
        ``need_attn`` is set, and are discarded either way.
        """
        outputs = self.align_attn(
            query=query,
            key=key_value,
            value=key_value,
            key_padding_mask=key_padding_mask,
            static_kv=True,
            need_weights=(not self.training and self.need_attn),
        )
        context, _weights = outputs
        normalized = self.align_attn_layer_norm(context)
        return normalized

class BertInitProj(nn.Module):
    """Output projection initialised from a pretrained BERT model.

    The projection weight is the word-embedding matrix of the loaded model
    (the parameter object itself is reused, not copied); a separate
    zero-initialised bias is added on top.
    """

    def __init__(self, decoder_bert_model_name, in_features, vocab_size):
        super().__init__()
        # NOTE(review): BertDecoderFull is not defined in the part of the
        # file visible here — confirm it exists elsewhere.
        pretrained = BertDecoderFull.from_pretrained(decoder_bert_model_name, None)
        embedding_weight = pretrained.bert.embeddings.word_embeddings.weight

        self.final_proj = nn.Linear(in_features, vocab_size, bias=False)
        # Replace the freshly initialised weight with the embedding matrix.
        self.final_proj.weight = embedding_weight

        bias_size = embedding_weight.size(0)
        self.bias = nn.Parameter(torch.zeros(bias_size))

    def forward(self, x: torch.Tensor, padding_mask):
        """Project ``x`` and add the bias; ``padding_mask`` is passed through."""
        projected = self.final_proj(x)
        return projected + self.bias, padding_mask



@register_model_architecture("wav2bert_masked_predict_fusion_two_way_ctc_to_bert", "wav2bert_masked_predict_fusion_two_way_ctc_to_bert")
def base_architecture(args):
    """Fill in default values for every architecture option missing on ``args``.

    Each line keeps an already-present attribute and otherwise assigns the
    listed default, so values supplied by the user always win. ``args`` is
    modified in place.

    NOTE: several options are assigned more than once below. Because the
    first assignment always creates the attribute, the defaults given by the
    later assignments can never take effect.
    """
    args.no_pretrained_weights = getattr(args, "no_pretrained_weights", False)
    args.dropout_input = getattr(args, "dropout_input", 0)
    args.final_dropout = getattr(args, "final_dropout", 0)
    args.apply_mask = getattr(args, "apply_mask", False)
    args.dropout = getattr(args, "dropout", 0)
    args.attention_dropout = getattr(args, "attention_dropout", 0)
    args.activation_dropout = getattr(args, "activation_dropout", 0)

    # time-step and channel masking options
    args.mask_length = getattr(args, "mask_length", 10)
    args.mask_prob = getattr(args, "mask_prob", 0.5)
    args.mask_selection = getattr(args, "mask_selection", "static")
    args.mask_other = getattr(args, "mask_other", 0)
    args.no_mask_overlap = getattr(args, "no_mask_overlap", False)
    args.mask_channel_length = getattr(args, "mask_channel_length", 10)
    args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.5)
    args.mask_channel_selection = getattr(args, "mask_channel_selection", "static")
    args.mask_channel_other = getattr(args, "mask_channel_other", 0)
    args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False)

    args.freeze_finetune_updates = getattr(args, "freeze_finetune_updates", 0)
    args.feature_grad_mult = getattr(args, "feature_grad_mult", 0)
    args.layerdrop = getattr(args, "layerdrop", 0.0)

    # args from transformer_nat_ymask_bert_two_adapter
    args.decoder_embed_path = getattr(args, 'decoder_embed_path', None)
    args.fusion_v2 = getattr(args, 'fusion_v2', None)
    args.fusion_v3 = getattr(args, 'fusion_v3', None)
    args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', False)
    args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False)
    # attention_dropout, activation_dropout and dropout were already set at
    # the top of this function, so the three defaults below (including the
    # 0.1 for dropout) are never used.
    args.attention_dropout = getattr(args, 'attention_dropout', 0.)
    args.activation_dropout = getattr(args, 'activation_dropout', 0.)
    args.activation_fn = getattr(args, 'activation_fn', 'relu')
    args.dropout = getattr(args, 'dropout', 0.1)
    args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
    args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)
    args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False)
    args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
    args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)
    args.adaptive_input = getattr(args, 'adaptive_input', False)

    
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 768)
    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 2048)
    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
    args.decoder_layers = getattr(args, 'decoder_layers', 6)


    # args from transformer_nat_ymask_bert_two_adapter_deep_small
    # The four decoder_* options were set just above, so the smaller defaults
    # in this group (512 / 4 / 5) never apply; only encoder_embed_dim is new.
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 768)
    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 512)
    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4)
    args.decoder_layers = getattr(args, 'decoder_layers', 5)
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768)


    # input/output widths default to the decoder embedding width
    args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
    args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)
    args.finetune_whole_encoder = getattr(args, 'finetune_whole_encoder', False)
    args.train_from_scratch = getattr(args, 'train_from_scratch', False)