# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import math
import numpy as np

import torch
import torch.nn.functional as F

from fairseq import search, utils, checkpoint_utils
from fairseq.data import data_utils

from bert import BertTokenizer
from bert import BertForPreTraining
from fairseq.logging.meters import safe_round
from fairseq.data.data_utils import post_process

def assign_single_value_byte(x, i, y):
    """Write ``y`` into every element of ``x`` selected by the mask ``i``.

    ``x`` is modified in place through a flat view; ``i`` is flattened the
    same way and its non-zero entries pick the positions to overwrite.
    """
    flat_target = x.view(-1)
    selected = i.view(-1).nonzero()
    flat_target[selected] = y


def assign_multi_value_byte(x, i, y):
    """Copy ``y`` into ``x`` at the positions selected by the mask ``i``.

    All three tensors are addressed through flat views, so position ``k`` of
    ``x`` receives position ``k`` of ``y`` wherever ``i`` is non-zero. ``x``
    is modified in place.
    """
    selected = i.view(-1).nonzero()
    replacement = y.view(-1)[selected]
    x.view(-1)[selected] = replacement


def assign_single_value_long(x, i, y):
    """Write ``y`` into the 2-D tensor ``x`` at per-row column indices ``i``.

    Row ``r`` of ``i`` lists column positions inside row ``r`` of ``x``. The
    indices are shifted by each row's offset into the flattened tensor and the
    assignment is done in place on that flat view.
    """
    num_rows, num_cols = x.size()
    row_starts = torch.arange(0, num_rows * num_cols, num_cols, device=i.device)
    flat_positions = (i + row_starts.unsqueeze(1)).view(-1)
    x.view(-1)[flat_positions] = y


def assign_multi_value_long(x, i, y):
    """Copy ``y`` into the 2-D tensor ``x`` at per-row column indices ``i``.

    Row ``r`` of ``i`` lists column positions inside row ``r``. After shifting
    by each row's offset into the flattened tensors, the selected elements of
    ``y`` overwrite the same positions of ``x`` in place.
    """
    num_rows, num_cols = x.size()
    row_starts = torch.arange(0, num_rows * num_cols, num_cols, device=i.device)
    flat_positions = (i + row_starts.unsqueeze(1)).view(-1)
    x.view(-1)[flat_positions] = y.view(-1)[flat_positions]


class SequenceGeneratorWithFusion(object):
    def __init__(
        self,
        tgt_dict,
        beam_size=1,
        max_len_a=0,
        max_len_b=200,
        min_len=1,
        stop_early=True,
        normalize_scores=True,
        len_penalty=1.,
        unk_penalty=0.,
        retain_dropout=False,
        sampling=False,
        sampling_topk=-1,
        temperature=1.,
        diverse_beam_groups=-1,
        diverse_beam_strength=0.5,
        match_source_len=False,
        no_repeat_ngram_size=0,
        mask_pred_iter=10,
        decode_use_adapter=False,
        args=None,
    ):
        """Configure a generator for a given target dictionary.

        Args:
            tgt_dict (~fairseq.data.Dictionary): target dictionary; its pad,
                unk, eos and mask indices and its size are cached here
            beam_size (int, optional): beam width, capped at
                ``len(tgt_dict) - 1`` (default: 1)
            max_len_a/b (int, optional): generate sequences of maximum length
                ax + b, where x is the source length
            min_len (int, optional): the minimum length of the generated output
                (not including end-of-sentence)
            stop_early (bool, optional): stop generation immediately after we
                finalize beam_size hypotheses, even though longer hypotheses
                might have better normalized scores (default: True)
            normalize_scores (bool, optional): normalize scores by the length
                of the output (default: True)
            len_penalty (float, optional): length penalty, where <1.0 favors
                shorter, >1.0 favors longer sentences (default: 1.0)
            unk_penalty (float, optional): unknown word penalty, where <0
                produces more unks, >0 produces fewer (default: 0.0)
            retain_dropout (bool, optional): use dropout when generating
                (default: False)
            sampling (bool, optional): sample outputs instead of beam search
                (default: False)
            sampling_topk (int, optional): only sample among the top-k choices
                at each step; requires ``sampling`` (default: -1)
            temperature (float, optional): temperature, where values
                >1.0 produce more uniform samples and values <1.0 produce
                sharper samples; must be > 0 (default: 1.0)
            diverse_beam_groups/strength (float, optional): parameters for
                Diverse Beam Search sampling
            match_source_len (bool, optional): outputs should match the source
                length (default: False)
            no_repeat_ngram_size (int, optional): stored on the instance
                (default: 0)
            mask_pred_iter (int, optional): number of mask-predict refinement
                iterations run by ``generate`` (default: 10)
            decode_use_adapter (bool, optional): stored on the instance
                (default: False)
            args (optional): object whose ``two_way_ctc`` / ``two_way_ce``
                attributes are read; both fall back to False when missing
        """
        # Dictionary handle and the special symbols derived from it.
        self.tgt_dict = tgt_dict
        self.pad = tgt_dict.pad()
        self.unk = tgt_dict.unk()
        self.eos = tgt_dict.eos()
        self.mask = tgt_dict.mask()
        self.vocab_size = len(tgt_dict)

        # the max beam size is the dictionary size - 1, since we never select pad
        self.beam_size = min(beam_size, self.vocab_size - 1)

        # Length / scoring options.
        self.max_len_a, self.max_len_b = max_len_a, max_len_b
        self.min_len = min_len
        self.match_source_len = match_source_len
        self.stop_early = stop_early
        self.normalize_scores = normalize_scores
        self.len_penalty = len_penalty
        self.unk_penalty = unk_penalty
        self.no_repeat_ngram_size = no_repeat_ngram_size

        # Decoding behaviour.
        self.retain_dropout = retain_dropout
        self.temperature = temperature
        self.mask_pred_iter = mask_pred_iter
        self.decode_use_adapter = decode_use_adapter

        assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling'
        assert temperature > 0, '--temperature must be greater than 0'

        # Choose the search strategy; the first matching option wins.
        if sampling:
            strategy = search.Sampling(tgt_dict, sampling_topk)
        elif diverse_beam_groups > 0:
            strategy = search.DiverseBeamSearch(
                tgt_dict, diverse_beam_groups, diverse_beam_strength,
            )
        elif match_source_len:
            strategy = search.LengthConstrainedBeamSearch(
                tgt_dict, min_len_a=1, min_len_b=0, max_len_a=1, max_len_b=0,
            )
        else:
            strategy = search.BeamSearch(tgt_dict)
        self.search = strategy

        # Optional switches carried by the argument namespace.
        self.two_way_ctc = getattr(args, "two_way_ctc", False)
        self.two_way_ce = getattr(args, "two_way_ce", False)

    def wer_computed_by_validate(self, decoder_out, target, prob=False):
        """Print unit- and word-level error rates of predictions vs. targets.

        Debugging aid: per-sample and corpus-level figures are printed and
        nothing is returned.

        Args:
            decoder_out: batch of predictions. With ``prob=False`` each row
                holds token ids. With ``prob=True`` each row holds
                per-position scores over the vocabulary and the arg-max over
                the last dimension is used as the predicted id.
            target: batch of reference token-id rows, iterated in lockstep
                with ``decoder_out``.
            prob (bool, optional): how to interpret ``decoder_out``
                (default: False).
        """
        # Third-party dependency that only this debugging helper needs.
        import editdistance

        pad = self.tgt_dict.pad()
        eos = self.tgt_dict.eos()
        bos = self.tgt_dict.bos()

        def _rate(errors, total):
            # An empty reference would otherwise raise ZeroDivisionError.
            return errors / total if total > 0 else float("nan")

        with torch.no_grad():
            if prob:
                decoder_out_t = decoder_out.float().contiguous().cpu()
            else:
                decoder_out_t = decoder_out.contiguous().cpu()

            c_len = 0
            w_len = 0
            dout_c_err = 0
            dout_w_errs = 0
            for dout, t in zip(decoder_out_t, target):
                targ = t[(t != pad) & (t != eos)]
                targ_units = self.tgt_dict.string(targ)
                targ_units_arr = targ.tolist()

                # Reduce scores to token ids BEFORE filtering special symbols;
                # comparing raw scores against pad/eos indices is meaningless
                # and would flatten the score matrix.
                dout_toks = dout.argmax(dim=-1) if prob else dout
                keep = (dout_toks != pad) & (dout_toks != eos) & (dout_toks != bos)
                dout_pred_units_arr = dout_toks[keep].tolist()

                # Unit (character) level distance.
                c_dist = editdistance.eval(dout_pred_units_arr, targ_units_arr)
                dout_c_err += c_dist
                c_len += len(targ_units_arr)
                print("single CER: ", _rate(c_dist, len(targ_units_arr)))

                # Word level distance after 'letter' post-processing.
                targ_words = post_process(targ_units, 'letter').split()
                dout_pred_units = self.tgt_dict.string(dout_pred_units_arr)
                dout_pred_words_raw = post_process(dout_pred_units, 'letter').split()

                dout_dist = editdistance.eval(dout_pred_words_raw, targ_words)
                dout_w_errs += dout_dist
                w_len += len(targ_words)
                print("single WER: ", _rate(dout_dist, len(targ_words)))
                print("predict WORDS: ", dout_pred_words_raw)
                print("target WORDS: ", targ_words)

            uer = safe_round(dout_c_err * 100.0 / c_len, 3) if c_len > 0 else float("nan")
            wer = safe_round(dout_w_errs * 100.0 / w_len, 3) if w_len > 0 else float("nan")
            # The same word distance feeds both figures here, so the "raw"
            # WER is identical to the WER.
            raw_wer = wer
            print("Avg UER: {}; Avg WER: {}; Avg: RAW_WER: {}.".format(uer, wer, raw_wer))
        return
    
    def getFirstInputByCTC(self, ctc_outputs, padding_mask, beam_size, copy_batches, device):
        """Greedy-decode CTC scores into collapsed tokens with confidences.

        For every sample the frame-wise arg-max path is taken, runs of
        identical consecutive symbols are merged and blanks are dropped
        (``self.tgt_dict.bos()`` serves as the blank symbol). Each surviving
        token is scored with the highest frame probability inside the run it
        was merged from.

        Args:
            ctc_outputs: un-normalised scores. Softmax is taken over the last
                dimension and dimensions 0 and 1 are swapped before iterating
                over samples -- presumably (T, B, C); TODO confirm with callers.
            padding_mask: boolean mask that is True on padded frames; it is
                inverted and summed over the last dimension to obtain the
                number of valid frames per sample. ``None`` means "no
                padding", i.e. every frame is valid.
            beam_size: number of copies of each row in the returned tensors.
            copy_batches: callable ``(tensor, num_copies)`` that repeats rows.
            device: forwarded to ``utils.move_to_cuda``.

        Returns:
            tuple: ``(tokens, probs, lengths)``. ``tokens`` is padded with
            ``self.pad`` and ``probs`` with 1.0, both repeated ``beam_size``
            times through ``copy_batches``; ``lengths`` holds the collapsed
            length per decoded sample and is NOT repeated.
        """
        blank = self.tgt_dict.bos()
        predicts = []
        predicts_probs = []

        ctc_outputs = F.softmax(ctc_outputs, dim=-1)
        ctc_outputs = ctc_outputs.transpose(0, 1).float().contiguous()
        if padding_mask is None:
            # No mask available: every frame of every sample is valid.
            input_lengths = [ctc_outputs.size(1)] * ctc_outputs.size(0)
        else:
            input_lengths = (~padding_mask).long().sum(-1)

        # NOTE: zip stops at the shorter of the two, so only as many samples
        # as there are length entries get decoded.
        for lp, inp_l in zip(ctc_outputs, input_lengths):
            frame_probs, frame_toks = lp[:inp_l].max(dim=-1)
            # One entry per run of identical consecutive symbols.
            toks, run_lengths = frame_toks.unique_consecutive(return_counts=True)
            if toks.numel() > 0:
                # Splitting by run length visits every frame once, instead of
                # re-scanning all frames for each run.
                run_probs = torch.stack(
                    [run.max() for run in frame_probs.split(run_lengths.tolist())]
                )
            else:
                run_probs = frame_probs.new_zeros(0)
            non_blank = toks != blank
            predicts.append(toks[non_blank].cpu())
            predicts_probs.append(run_probs[non_blank].cpu())

        predicts_len = torch.LongTensor([len(t) for t in predicts])
        predicts = data_utils.collate_tokens(predicts, pad_idx=self.pad, left_pad=False)
        # Padding positions get probability 1.0 so they are never the
        # lowest-confidence candidates.
        predicts_probs = data_utils.collate_tokens(predicts_probs, pad_idx=1.0, left_pad=False)

        predicts = utils.move_to_cuda(predicts, device=device)
        predicts_len = utils.move_to_cuda(predicts_len, device=device)
        predicts_probs = utils.move_to_cuda(predicts_probs, device=device)

        return copy_batches(predicts, beam_size), copy_batches(predicts_probs, beam_size), predicts_len
        
    @torch.no_grad()
    def generate(
        self,
        models,
        sample,
        prefix_tokens=None,
        bos_token=None,
        tgt_bert_encoder=None, 
        tgt_bert_tokenizer=None,
        cif=None,
        different_tokens=False,
        tgt_dictionary=None,
        **kwargs
    ):
        """Generate a batch of translations with iterative mask-predict decoding.

        Acoustic features are extracted with ``models[0]``, a first decoder
        pass produces tokens and per-token confidences through greedy CTC
        decoding, and then ``self.mask_pred_iter`` refinement rounds re-mask
        the lowest-confidence tokens and decode again.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models;
                feature extraction only uses ``models[0]``
            sample (dict): batch; ``sample['net_input']`` must provide
                ``'source'`` and ``'padding_mask'`` and ``sample['target']``
                is read as well
            prefix_tokens (torch.LongTensor, optional): force decoder to begin
                with these tokens (accepted but not used in this method)
            bos_token: accepted but not used in this method
            tgt_bert_encoder: accepted but not used in this method
            tgt_bert_tokenizer: tokenizer used to re-encode token sequences
                when ``different_tokens`` is true
            cif: when truthy, the CIF module and its projection are applied
                to the extracted features before decoding
            different_tokens (bool, optional): convert the character-level
                tokens to ``tgt_bert_tokenizer`` pieces before every decoder
                pass (default: False)
            tgt_dictionary: dictionary used to turn token ids into strings
                when ``different_tokens`` is true

        Returns:
            list: one list per input sentence, each holding a single
            hypothesis dict with keys ``tokens``, ``score``, ``attention``,
            ``alignment`` and ``positional_scores``.
        """
        model = EnsembleModel(models)
        if not self.retain_dropout:
            model.eval()

        # Everything in net_input except the decoder-side entries; only
        # 'source' and 'padding_mask' are read from it below.
        encoder_input = {
            k: v for k, v in sample['net_input'].items()
            if k != 'prev_output_tokens' and k != 'bert_input'
        }  # prev_output_tokens is not needed here

        src_tokens = encoder_input['source']
        src_lengths = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
        input_size = src_tokens.size()  # e.g. [32, 4]
        # batch dimension goes first followed by source lengths
        bsz = input_size[0]
        src_len = input_size[1]
        beam_size = self.beam_size  # e.g. 4

        finalized = [[] for i in range(bsz)]
        
        # non-autoregressive decoding

        # Wrap decoded token ids in the hypothesis dict format; the score
        # fields are constant placeholders.
        def get_hypo_nat(decoded_id):
            return {
                'tokens': decoded_id,
                'score': 0.0,
                'attention': None,  # src_len x tgt_len
                'alignment': None,
                'positional_scores': torch.Tensor([0.0]),
            }

        # Repeat every row of a 2-D or 3-D tensor num_copies times; the copies
        # of one row end up adjacent (row index = batch * num_copies + copy).
        # None is passed through unchanged.
        def copy_batches(tensor, num_copies):
            if tensor is None:
                return None
            x_size = tensor.size()
            tensor = tensor.contiguous().view(x_size[0], 1, -1)
            tensor = tensor.repeat(1, num_copies, 1)
            if len(x_size)==2:
                return tensor.view(-1, x_size[1])
            elif len(x_size)==3:
                return tensor.view(-1, x_size[1], x_size[2])
            else:
                raise NotImplementedError

        # For each row pick the indices of the num_mask[row] smallest
        # probabilities (at least one). Rows are right-filled with their first
        # picked index so that all rows have length seq_len and can be stacked.
        def select_worst(token_probs, num_mask):
            bsz, seq_len = token_probs.size()
            masks = [token_probs[batch, :].topk(max(1, num_mask[batch]), largest=False, sorted=False)[1] for batch in range(bsz)]
            masks = [torch.cat([mask, mask.new(seq_len - mask.size(0)).fill_(mask[0])], dim=0) for mask in masks]
            return torch.stack(masks, dim=0)

        w2v_args = {
            "source": src_tokens,  # B x src_len, e.g. [1, 259840]
            "padding_mask": encoder_input['padding_mask'],  # B x src_len, e.g. [1, 259840]
            "mask": False,
        }
        with torch.no_grad():
            x, padding_mask = model.models[0].w2v_encoder.w2v_model.extract_features(**w2v_args)  # B,T,C e.g. [1, 811, 768]
            ctc_output = x.transpose(0, 1)
            ctc_output = model.models[0].w2v_encoder.encoder_dropout(ctc_output)
            ctc_output = model.models[0].w2v_encoder.encoder_proj(ctc_output)

        if cif:
            print('cif...')
            cif_output, _, encoder_padding_mask = model.models[0].w2v_encoder.CIF(x, padding_mask)
            decoder_input = model.models[0].w2v_encoder.CIF_proj(cif_output)
            # B x T x C -> T x B x C
            decoder_input = decoder_input.permute(1,0,2).contiguous()  # e.g. [811, 1, 768]
            # add bert here
            encoder_outs = [{
                'encoder_out': decoder_input, # T x B x C, e.g. [811, 1, 768]
                'encoder_padding_mask': encoder_padding_mask, # B x T, e.g. [1, 811]
            }]
        else:
            print('w/o cif...')
            decoder_input = x.permute(1,0,2).contiguous()
            encoder_padding_mask = padding_mask
            encoder_outs = [{
                'encoder_out': decoder_input,
                'encoder_padding_mask': encoder_padding_mask,
            }]

        # Debug setting: use the CTC output as the first mask-predict input
        # and mask its low-probability tokens.
        print("Mask predict First input: ctc output..")
        tgt_tokens, token_probs, _ = self.getFirstInputByCTC(ctc_output, padding_mask, beam_size, copy_batches, decoder_input.device)
        
        # NOTE(review): this overwrites the CTC-derived tgt_tokens from the
        # call above with beam copies of the reference target, and token_probs
        # from that call is replaced before it is ever read -- confirm that
        # feeding sample['target'] to the first decoder pass is intentional.
        tgt_tokens = copy_batches(sample['target'], beam_size)
        
        # The first fusion pass is not masked and happens outside the loop
        # below, so no probabilities need to be remapped when converting
        # tokens here.
        if different_tokens:
            new_tgt_tokens = []
            for tgt_token in tgt_tokens:
                targ_units = tgt_dictionary.string(tgt_token)
                targ_words = post_process(targ_units, "letter")
                targ_words += '.'
                
                word_tgt_tokens = tgt_bert_tokenizer.encode_line(targ_words, post_proces='bert_bpe_piece')
                word_tgt_tokens.insert(0, tgt_bert_tokenizer.cls())
                word_tgt_tokens.append(tgt_bert_tokenizer.sep())
                
                new_tgt_tokens.append(torch.LongTensor(list(word_tgt_tokens)))
            new_tgt_tokens = data_utils.collate_tokens(new_tgt_tokens, pad_idx=self.pad, left_pad=False)
            new_tgt_tokens = utils.move_to_cuda(new_tgt_tokens, device=decoder_input.device)
            tgt_tokens = new_tgt_tokens
        
        pad_mask = tgt_tokens.eq(self.pad)  # True where tgt_tokens is padding
        seq_lens = tgt_tokens.size(1) - pad_mask.sum(dim=1)  # non-pad token count per row

        # Expand the encoder output to one copy per beam.
        encoder_outs_value = copy_batches(encoder_outs[0]['encoder_out'].transpose(0,1), beam_size)  # e.g. [32, 4, 768] -> [128, 4, 768]
        encoder_outs_value = encoder_outs_value.transpose(0,1)  # [src_len, B*beam_size, dim], e.g. [4, 128, 768]
        encoder_padding = copy_batches(encoder_outs[0]['encoder_padding_mask'], beam_size)  # e.g. [32, 4] -> [128, 4]

        encoder_outs = [{'encoder_out': encoder_outs_value, 'encoder_padding_mask': encoder_padding}]

        # This is effectively one decoding pass through the BERT decoder.
        # NOTE(review): the un-expanded encoder_padding_mask (one row per
        # sample) is passed to getFirstInputByCTC while the decoder output has
        # one row per sample *and beam*; the lockstep iteration there stops
        # after bsz rows, which only lines up with the samples when
        # beam_size == 1 -- confirm.
        tgt_tokens, token_probs, _ = self.getFirstInputByCTC(model.forward_decoder(
            tgt_tokens, encoder_outs,  # tgt_tokens is the decoder input, like prev_output_tokens during training
            src_tokens=None, 
            temperature=self.temperature,
            two_way_ctc=self.two_way_ctc,
            two_way_ce=self.two_way_ce,
        ), encoder_padding_mask, beam_size, copy_batches, decoder_input.device)
        pad_mask = tgt_tokens.eq(self.pad)  # True where tgt_tokens is padding
        seq_lens = tgt_tokens.size(1) - pad_mask.sum(dim=1)  # non-pad token count per row

        # Author's note: with mask_pred_iter=10 the WER was 98-100%; with 100
        # iterations it was 88% and still ended at 100%.
        for i in range(1, self.mask_pred_iter+1):
            if different_tokens:
                # Both the tokens and their probabilities have to be converted
                # from character level to tokenizer-piece level here.
                new_tgt_tokens = []
                new_tgt_probs = []
                for tgt_token,token_prob in zip(tgt_tokens, token_probs):
                    targ_units = tgt_dictionary.string(tgt_token) ## too slow
                    mask_num = torch.sum(tgt_token.eq(self.pad))
                    mask_num = int(mask_num.cpu())
                    mask_index = len(tgt_token) - mask_num
                       
                    # Some special characters are not visible here but are
                    # still present.
                    targ_words = post_process(targ_units, "letter")
                    
                    targ_words += '.'
                    
                    # During encoding the basic tokenizer's text cleaning
                    # removes those characters, so they no longer exist in the
                    # tokenized result.
                    word_tgt_units = tgt_bert_tokenizer.encode_line(targ_words, post_proces='bert_bpe_piece', not_id=True)
                    word_tgt_tokens = tgt_bert_tokenizer.encode_line(targ_words, post_proces='bert_bpe_piece')
                    word_tgt_tokens.insert(0, tgt_bert_tokenizer.cls())
                    word_tgt_tokens.append(tgt_bert_tokenizer.sep())
                    
                    # Mapping characters back to subwords one-to-one is hard
                    # (some Bengali characters get stripped). The approach
                    # taken is to align at word level first and then split
                    # within the word's subwords.
                    # Words obtained by joining the characters and splitting
                    # on the "|" separator:
                    tag_word_split = [w for w in ''.join(targ_units.split()).split("|") if w != ""]
                    # Positions of the "|" separators in the unit sequence:
                    split_index = [i for i, ltr in enumerate(targ_units.split()) if ltr == '|']
                    # Two cases need care: two consecutive separators, and a
                    # missing final separator. For the latter, append the last
                    # position as an end marker.
                    # NOTE(review): split_index[-1] raises IndexError when the
                    # sequence contains no "|" at all.
                    if split_index[-1] != len(targ_units.split()) - 1:
                        split_index.append(len(targ_units.split()) - 1)
                    
                    # Encoded pieces that start a word (no "##", no period).
                    encode_subword = [w for w in word_tgt_units if ("##" not in w) and ("." not in w)]
                    
                    # Map every subword back to its letters; a subword's
                    # probability is derived from its letters' probabilities.
                    pre_word_idx = 0
                    split_count = 0
                    subword_probs = [1] # probability for CLS
                    subword_index = -1
                    while subword_index + 1 < len(word_tgt_units) - 1: # leave out the trailing period
                        subword_index += 1
                        # A new word starts here: fetch its span of
                        # probabilities first.
                        current_word_idx = split_index[split_count]
                        split_count += 1
                        # Skip over two consecutive separators.
                        if current_word_idx == pre_word_idx: # pre_word_idx is already separator + 1, so equality is enough
                            current_word_idx = split_index[split_count]
                            split_count += 1
                        
                        word_prob = token_prob[pre_word_idx: current_word_idx]
                        word = targ_units.split()[pre_word_idx: current_word_idx]
                        
                        pre_word_idx = current_word_idx + 1
                        
                        current_subwords = [word_tgt_units[subword_index]]
                        # Then collect the continuation pieces that follow.
                        next_subword_index = subword_index + 1
                        while next_subword_index < len(word_tgt_units) and "#" in word_tgt_units[next_subword_index]:
                            subword_index += 1
                            current_subwords.append(word_tgt_units[next_subword_index][2:])# strip the "##" marker
                            next_subword_index += 1
                        
                        # Average the matching character probabilities, piece
                        # by piece.
                        character_index = 0 # index into word_prob, not into the full probability array
                        for subword in current_subwords:
                            subword_prob = 0
                            begin_indexs = []
                            for c in subword:
                                prob_get = False
                                temp_character_index = character_index
                                begin_indexs.append(temp_character_index)
                                while character_index < len(word):
                                    if c == word[character_index]:
                                        subword_prob += float(word_prob[character_index].cpu())
                                        prob_get = True
                                        character_index += 1
                                        break
                                    character_index += 1 # some characters may have been dropped during encoding; ignored for now
                                # Some symbols behave oddly: when no match is
                                # found, rewind and fall back to the word's
                                # mean probability for now.
                                if not prob_get:
                                    character_index = temp_character_index
                                    subword_prob += float(torch.mean(word_prob).cpu())
                                    
                            
                            subword_prob /= len(subword)
                            
                            # The conversion may yield [UNK], possibly because
                            # the CTC result was poorly predicted; give it
                            # probability 0.
                            if subword == '[UNK]':
                                subword_probs.append(0)
                            else:
                                subword_probs.append(subword_prob)
                    subword_probs.append(1) # period
                    subword_probs.append(1) # sep
                    
                    new_tgt_probs.append(torch.FloatTensor(subword_probs))
                    new_tgt_tokens.append(torch.LongTensor(list(word_tgt_tokens)))

                
                new_tgt_tokens = data_utils.collate_tokens(new_tgt_tokens, pad_idx=self.pad, left_pad=False)
                new_tgt_tokens = utils.move_to_cuda(new_tgt_tokens, device=decoder_input.device)

                new_tgt_probs = data_utils.collate_tokens(new_tgt_probs, pad_idx=1, left_pad=False)
                new_tgt_probs = utils.move_to_cuda(new_tgt_probs, device=decoder_input.device)


                tgt_tokens = new_tgt_tokens
                token_probs = new_tgt_probs

                pad_mask = tgt_tokens.eq(self.pad)
                seq_lens = tgt_tokens.size(1) - pad_mask.sum(dim=1)

            # The number of tokens to re-mask shrinks linearly with the
            # iteration, reaching 0 on the last one (select_worst still masks
            # at least one position per row).
            num_mask = (seq_lens.float()*(1.0-i/(self.mask_pred_iter))).long()
            print("num mask", num_mask)
            assign_single_value_byte(token_probs, pad_mask, 1.0)  # padding gets probability 1.0 so it is never selected
            mask_ind = select_worst(token_probs, num_mask)  # indices of the lowest-probability tokens
            assign_single_value_long(tgt_tokens, mask_ind, self.mask)  # put the mask symbol at those indices
            assign_single_value_byte(tgt_tokens, pad_mask, self.pad)  # restore padding

            new_tgt_tokens, new_token_probs, _ = self.getFirstInputByCTC(model.forward_decoder(
                tgt_tokens, encoder_outs, 
                src_tokens=None, 
                temperature=self.temperature,  # 1.0 by default
                two_way_ctc=self.two_way_ctc,
                two_way_ce=self.two_way_ce,
            ), encoder_padding_mask, beam_size, copy_batches, decoder_input.device)

            tgt_tokens = new_tgt_tokens
            token_probs = new_token_probs
            pad_mask = tgt_tokens.eq(self.pad)  # True where tgt_tokens is padding
            seq_lens = tgt_tokens.size(1) - pad_mask.sum(dim=1)  # non-pad token count per row
            # Debug hook: self.wer_computed_by_validate(tgt_tokens, sample['target'])
            
            
        
        # The tokens and probabilities of the last iteration are the result.
        max_len = tgt_tokens.size(1)
        lprobs = token_probs.log().sum(-1)  # summed log-probability of each row
        hypotheses = tgt_tokens.view(bsz, beam_size, max_len)
        lprobs = lprobs.view(bsz, beam_size)
        # Alternative used with other first inputs (length penalty), kept for
        # reference:
        # tgt_lengths = (1 - length_mask).sum(-1)  # predicted target length
        # length_penalty = ((5.0 + tgt_lengths.float()) ** self.len_penalty
        #                   / (6.0 ** self.len_penalty))
        # length_penalty = length_penalty.view((bsz, beam_size))
        # avg_log_prob = lprobs / length_penalty
        # best_lengths = avg_log_prob.max(-1)[1]  # index of the best beam
        # With the CTC output as input, the best beam is simply the one with
        # the highest summed log-probability:
        best_lengths = lprobs.max(-1)[1]

        hypotheses = torch.stack([hypotheses[b, l, :] for b, l in enumerate(best_lengths)], dim=0)  # best beam per sentence: B x tgt_max_len

        for i in range(bsz):
            finalized[i].append(get_hypo_nat(hypotheses[i]))
        return finalized


class EnsembleModel(torch.nn.Module):
    """A wrapper around an ensemble of models.

    The decoding helpers expect every model to expose a ``w2v_encoder`` with a
    ``bertdecoder`` callable plus the dropout / projection modules named in
    ``_decode_one`` and ``_decode_two_way_ctc``.
    """

    def __init__(self, models):
        super().__init__()
        self.models = torch.nn.ModuleList(models)
        # Never populated in this class; reorder_incremental_state is a no-op
        # while it stays None.
        self.incremental_states = None

    def has_encoder(self):
        """Always report an encoder (no per-model ``encoder`` check is made)."""
        return True

    def max_decoder_positions(self):
        """Return the smallest ``max_decoder_positions()`` across the models."""
        return min(m.max_decoder_positions() for m in self.models)

    @torch.no_grad()
    def forward_encoder(self, encoder_input):
        """Run ``model.encoder(**encoder_input)`` for every model.

        NOTE(review): relies on each model having an ``encoder`` attribute,
        whereas the decoding helpers use ``w2v_encoder`` -- confirm it exists
        before calling.
        """
        if not self.has_encoder():
            return None
        return [model.encoder(**encoder_input) for model in self.models]

    @torch.no_grad()
    def forward_decoder(self, tokens, encoder_outs, src_tokens=None, temperature=1., two_way_ctc=False, two_way_ce=False):
        """Run the fusion decoder and return per-position vocabulary scores.

        Args:
            tokens: decoder input tokens, handed to each model's decoder.
            encoder_outs (list): one encoder output per model.
            src_tokens: forwarded to the decoder (default: None).
            temperature (float): forwarded to the decode helpers.
            two_way_ctc (bool): use the CTC branch (``_decode_two_way_ctc``)
                instead of the default branch (``_decode_one``).
            two_way_ce (bool): accepted for interface compatibility; unused.

        Returns:
            A single tensor. With one model these are the raw projected
            scores. With several models it is the log of the models' averaged
            softmax distributions, so applying softmax over the last
            dimension downstream yields the ensemble-average probabilities.

        Raises:
            ValueError: if several models are given but the number of encoder
                outputs differs from the number of models.
        """
        decode = self._decode_two_way_ctc if two_way_ctc else self._decode_one

        if len(self.models) == 1:
            return decode(
                tokens,
                self.models[0],
                encoder_outs[0] if self.has_encoder() else None,
                log_probs=True,
                src_tokens=src_tokens,
                temperature=temperature,
            )

        if len(encoder_outs) != len(self.models):
            # zip() below would otherwise silently drop models while the
            # average still divides by the full ensemble size.
            raise ValueError(
                "expected one encoder output per model, got {} for {} models".format(
                    len(encoder_outs), len(self.models)
                )
            )

        # Normalise each model's scores before averaging in log space; the
        # decode helpers return unnormalised scores, not log-probabilities.
        per_model = [
            F.log_softmax(
                decode(
                    tokens,
                    model,
                    encoder_out,
                    log_probs=True,
                    src_tokens=src_tokens,
                    temperature=temperature,
                ).float(),
                dim=-1,
            )
            for model, encoder_out in zip(self.models, encoder_outs)
        ]
        return torch.logsumexp(torch.stack(per_model, dim=0), dim=0) - math.log(len(self.models))

    def _project_fusion(
        self, tokens, model, encoder_out, src_tokens, temperature, dropout_name, proj_name,
    ):
        """Decode once and project the fusion output (third decoder output)."""
        decoder_out = list(model.w2v_encoder.bertdecoder(tokens, src_tokens=src_tokens, encoder_out=encoder_out))
        if temperature != 1.:
            # NOTE(review): the temperature is applied in place to the first
            # decoder output, which is not what gets returned below.
            decoder_out[0].div_(temperature)

        fused = decoder_out[2].transpose(0, 1)
        fused = getattr(model.w2v_encoder, dropout_name)(fused)
        return getattr(model.w2v_encoder, proj_name)(fused)

    def _decode_two_way_ctc(
        self, tokens, model, encoder_out, log_probs, src_tokens=None, temperature=1.,
    ):
        """Fusion output passed through ``final_dropout_ctc`` and ``ctc_proj``.

        ``log_probs`` is accepted for signature compatibility and unused.
        """
        return self._project_fusion(
            tokens, model, encoder_out, src_tokens, temperature,
            "final_dropout_ctc", "ctc_proj",
        )

    def _decode_one(
        self, tokens, model, encoder_out, log_probs, src_tokens=None, temperature=1.,
    ):
        """Fusion output passed through ``final_dropout`` and ``proj``.

        ``log_probs`` is accepted for signature compatibility and unused.
        """
        return self._project_fusion(
            tokens, model, encoder_out, src_tokens, temperature,
            "final_dropout", "proj",
        )

    def reorder_encoder_out(self, encoder_outs, new_order):
        """Reorder every model's encoder output via ``model.encoder``.

        NOTE(review): like ``forward_encoder`` this needs an ``encoder``
        attribute on each model -- confirm it exists before calling.
        """
        if not self.has_encoder():
            return
        return [
            model.encoder.reorder_encoder_out(encoder_out, new_order)
            for model, encoder_out in zip(self.models, encoder_outs)
        ]

    def reorder_incremental_state(self, new_order):
        """Reorder cached decoder state; does nothing while no state is stored."""
        if self.incremental_states is None:
            return
        for model in self.models:
            model.w2v_encoder.bertdecoder.reorder_incremental_state(self.incremental_states[model], new_order)
