# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import math
import numpy as np

import torch
import torch.nn.functional as F

from fairseq import search, utils, checkpoint_utils
from fairseq.data import data_utils

from bert import BertTokenizer
from bert import BertForPreTraining
from fairseq.logging.meters import safe_round
from fairseq.data.data_utils import post_process

def assign_single_value_byte(x, i, y):
    """Overwrite, in place, every element of ``x`` selected by mask ``i`` with ``y``.

    Args:
        x: tensor to modify; it is flattened with ``view(-1)`` so it must be
            viewable as 1-D.
        i: mask with the same number of elements as ``x``; non-zero entries
            mark the positions to overwrite.
        y: value written to each selected position.
    """
    flat_target = x.view(-1)
    # nonzero() yields the flat positions where the mask is set.
    selected = torch.nonzero(i.view(-1))
    flat_target[selected] = y


def assign_multi_value_byte(x, i, y):
    """Copy, in place, the elements of ``y`` selected by mask ``i`` into ``x``.

    Args:
        x: destination tensor; flattened with ``view(-1)``.
        i: mask with the same number of elements as ``x``; non-zero entries
            mark the positions to copy.
        y: source tensor with the same number of elements as ``x``.
    """
    selected = torch.nonzero(i.view(-1))
    replacement = y.view(-1)[selected]
    x.view(-1)[selected] = replacement


def assign_single_value_long(x, i, y):
    """Write ``y``, in place, at the per-row column indices ``i`` of 2-D ``x``.

    Args:
        x: 2-D tensor to modify; flattened with ``view(-1)``.
        i: integer tensor with one row of column indices per row of ``x``.
        y: value written to each addressed position.
    """
    rows, cols = x.size()
    # Offset of each row's first element in the flattened tensor.
    row_starts = torch.arange(0, rows * cols, cols, device=i.device).unsqueeze(1)
    flat_index = (i + row_starts).view(-1)
    x.view(-1)[flat_index] = y


def assign_multi_value_long(x, i, y):
    """Copy, in place, the elements of ``y`` at per-row column indices ``i`` into ``x``.

    Args:
        x: 2-D destination tensor; flattened with ``view(-1)``.
        i: integer tensor with one row of column indices per row of ``x``.
        y: source tensor with the same number of elements as ``x``.
    """
    rows, cols = x.size()
    # Offset of each row's first element in the flattened tensor.
    row_starts = torch.arange(0, rows * cols, cols, device=i.device).unsqueeze(1)
    flat_index = (i + row_starts).view(-1)
    picked = y.view(-1)[flat_index]
    x.view(-1)[flat_index] = picked


class SequenceGeneratorWithBert(object):
    """Mask-predict style (non-autoregressive) generator.

    ``generate`` runs the wav2vec encoder of the first model, takes the greedy
    CTC output as the initial target sequence and then iteratively re-masks
    and re-predicts the least confident tokens with the BERT decoder.
    """

    def __init__(
        self,
        tgt_dict,
        beam_size=1,
        max_len_a=0,
        max_len_b=200,
        min_len=1,
        stop_early=True,
        normalize_scores=True,
        len_penalty=1.,
        unk_penalty=0.,
        retain_dropout=False,
        sampling=False,
        sampling_topk=-1,
        temperature=1.,
        diverse_beam_groups=-1,
        diverse_beam_strength=0.5,
        match_source_len=False,
        no_repeat_ngram_size=0,
        mask_pred_iter=10,
        decode_use_adapter=False,
        args=None,
    ):
        """Generates translations of a given source sentence.

        Args:
            tgt_dict (~fairseq.data.Dictionary): target dictionary; must also
                provide ``mask()``
            beam_size (int, optional): beam width (default: 1)
            max_len_a/b (int, optional): generate sequences of maximum length
                ax + b, where x is the source length
            min_len (int, optional): the minimum length of the generated output
                (not including end-of-sentence)
            stop_early (bool, optional): stop generation immediately after we
                finalize beam_size hypotheses, even though longer hypotheses
                might have better normalized scores (default: True)
            normalize_scores (bool, optional): normalize scores by the length
                of the output (default: True)
            len_penalty (float, optional): length penalty, where <1.0 favors
                shorter, >1.0 favors longer sentences (default: 1.0)
            unk_penalty (float, optional): unknown word penalty, where <0
                produces more unks, >0 produces fewer (default: 0.0)
            retain_dropout (bool, optional): use dropout when generating
                (default: False)
            sampling (bool, optional): sample outputs instead of beam search
                (default: False)
            sampling_topk (int, optional): only sample among the top-k choices
                at each step (default: -1)
            temperature (float, optional): temperature, where values
                >1.0 produce more uniform samples and values <1.0 produce
                sharper samples (default: 1.0)
            diverse_beam_groups/strength (float, optional): parameters for
                Diverse Beam Search sampling
            match_source_len (bool, optional): outputs should match the source
                length (default: False)
            no_repeat_ngram_size (int, optional): stored on the instance; not
                read by the methods of this class (default: 0)
            mask_pred_iter (int, optional): number of re-mask / re-predict
                refinement iterations run by ``generate`` (default: 10)
            decode_use_adapter (bool, optional): stored on the instance; not
                read by the methods of this class (default: False)
            args (optional): accepted for interface compatibility; unused
        """
        self.pad = tgt_dict.pad()
        self.unk = tgt_dict.unk()
        self.eos = tgt_dict.eos()
        self.mask = tgt_dict.mask()
        self.vocab_size = len(tgt_dict)
        # the max beam size is the dictionary size - 1, since we never select pad
        self.beam_size = min(beam_size, self.vocab_size - 1)
        self.max_len_a = max_len_a
        self.max_len_b = max_len_b
        self.min_len = min_len
        self.stop_early = stop_early
        self.normalize_scores = normalize_scores
        self.len_penalty = len_penalty
        self.unk_penalty = unk_penalty
        self.retain_dropout = retain_dropout
        self.temperature = temperature
        self.match_source_len = match_source_len
        self.no_repeat_ngram_size = no_repeat_ngram_size
        self.mask_pred_iter = mask_pred_iter

        self.tgt_dict = tgt_dict
        self.decode_use_adapter = decode_use_adapter

        assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling'
        assert temperature > 0, '--temperature must be greater than 0'

        if sampling:
            self.search = search.Sampling(tgt_dict, sampling_topk)
        elif diverse_beam_groups > 0:
            self.search = search.DiverseBeamSearch(tgt_dict, diverse_beam_groups, diverse_beam_strength)
        elif match_source_len:
            self.search = search.LengthConstrainedBeamSearch(
                tgt_dict, min_len_a=1, min_len_b=0, max_len_a=1, max_len_b=0,
            )
        else:
            self.search = search.BeamSearch(tgt_dict)

    def wer_computed_by_validate(self, decoder_out, target, prob=False):
        """Debug helper: print per-utterance and average unit/word error rates.

        Args:
            decoder_out: decoder predictions. When ``prob`` is True the last
                dimension is arg-maxed to obtain token ids; otherwise the
                values are used as token ids directly.
            target: iterable of reference token tensors, one per utterance.
            prob (bool, optional): see ``decoder_out`` (default: False)

        Returns:
            None. All results are printed.
        """
        import editdistance
        with torch.no_grad():
            if prob:
                decoder_out_t = decoder_out.float().contiguous().cpu()
            else:
                decoder_out_t = decoder_out.contiguous().cpu()

            c_len = 0
            w_len = 0
            dout_c_err = 0
            dout_w_errs = 0
            dout_wv_errs = 0
            for dout, t in zip(decoder_out_t, target):
                # Positions of the reference that are neither pad nor eos.
                p = (t != self.tgt_dict.pad()) & (
                    t != self.tgt_dict.eos()
                )
                targ = t[p]
                targ_units = self.tgt_dict.string(targ)
                targ_units_arr = targ.tolist()

                # Boolean indexing along dim 0 drops the pad/eos positions;
                # any remaining dimensions are kept.
                dout = dout[p].unsqueeze(0)
                if prob:
                    dout_toks = dout.argmax(dim=-1)
                else:
                    # NOTE: collapsing repeats (unique_consecutive) would
                    # presumably only make sense for CTC outputs.
                    dout_toks = dout
                dout_pred_units_arr = dout_toks[dout_toks != self.tgt_dict.bos()].tolist()

                c_len += len(targ_units_arr)
                c_dist = editdistance.eval(dout_pred_units_arr, targ_units_arr)
                dout_c_err += c_dist
                # Denominators are floored at 1 so an empty reference does not
                # raise ZeroDivisionError.
                print("single CER: ", c_dist / max(len(targ_units_arr), 1))
                targ_words = post_process(targ_units, 'letter').split()

                dout_pred_units = self.tgt_dict.string(dout_pred_units_arr)
                dout_pred_words_raw = post_process(dout_pred_units, 'letter').split()

                dout_dist = editdistance.eval(dout_pred_words_raw, targ_words)
                dout_w_errs += dout_dist
                dout_wv_errs += dout_dist
                w_len += len(targ_words)
                print("single WER: ", dout_dist / max(len(targ_words), 1))
                print("predict WORDS: ", dout_pred_words_raw)
                print("target WORDS: ", targ_words)
            uer = safe_round(
                    dout_c_err * 100.0 / max(c_len, 1), 3
                )
            wer = safe_round(
                    dout_w_errs * 100.0 / max(w_len, 1), 3
                )
            raw_wer = safe_round(
                    dout_wv_errs * 100.0 / max(w_len, 1), 3
                )
            print("Avg UER: {}; Avg WER: {}; Avg: RAW_WER: {}.".format(uer, wer, raw_wer))
        return

    def getFirstInputByCTC(self, ctc_outputs, padding_mask, beam_size, copy_batches):
        """Build the initial decoder input from a greedy CTC decode.

        Args:
            ctc_outputs: CTC logits; dims 0 and 1 are swapped before decoding,
                so after the swap each entry along dim 0 is one utterance
                (assumes time-major input, as produced in ``generate``).
            padding_mask: boolean mask whose True entries are padded frames,
                or None when no frame is padded.
            beam_size (int): number of copies made of every utterance.
            copy_batches (callable): ``copy_batches(tensor, n)`` replicating
                each row ``n`` times.

        Returns:
            Token tensor, right-padded with ``self.pad``, replicated
            ``beam_size`` times per utterance.
        """
        ctc_outputs = F.softmax(ctc_outputs, dim=-1)
        ctc_outputs = ctc_outputs.transpose(0, 1).float().contiguous()
        if padding_mask is None:
            # No padding information: every frame of every utterance is valid.
            input_lengths = [ctc_outputs.size(1)] * ctc_outputs.size(0)
        else:
            input_lengths = (~padding_mask).long().sum(-1)

        # Tokens equal to bos are dropped after collapsing repeats
        # (presumably bos doubles as the CTC blank -- confirm against training).
        drop_idx = self.tgt_dict.bos()
        predicts = []
        for lp, inp_l in zip(ctc_outputs, input_lengths):
            toks = lp[:inp_l].argmax(dim=-1).unique_consecutive()
            pred_units_arr = toks[toks != drop_idx].tolist()
            predicts.append(torch.LongTensor(pred_units_arr))
        predicts = data_utils.collate_tokens(predicts, pad_idx=self.pad, left_pad=False)
        return copy_batches(predicts, beam_size)

    @torch.no_grad()
    def generate(
        self,
        models,
        sample,
        prefix_tokens=None,
        bos_token=None,
        tgt_bert_encoder=None,
        tgt_bert_tokenizer=None,
        **kwargs
    ):
        """Generate a batch of translations.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models;
                only ``models[0]``'s wav2vec encoder is run here
            sample (dict): batch; reads ``sample['net_input']['source']``,
                ``sample['net_input']['padding_mask']`` and
                ``sample['target_lengths']``. The batch is not modified.
            prefix_tokens, bos_token, tgt_bert_encoder, tgt_bert_tokenizer:
                accepted for interface compatibility; unused

        Returns:
            list with one entry per batch element, each a one-element list
            holding a dict with keys ``tokens``, ``score``, ``attention``,
            ``alignment`` and ``positional_scores``.
        """
        model = EnsembleModel(models)
        if not self.retain_dropout:
            model.eval()

        # NOTE (debugging history): scoring a plain forward pass with
        # ``wer_computed_by_validate`` gave good results, which suggests that
        # validation benefits from mostly-unmasked decoder inputs, whereas
        # decoding from heavily masked inputs produced very high WER.

        # model.forward normally channels prev_output_tokens into the decoder
        # separately, but SequenceGenerator directly calls model.encoder
        encoder_input = {
            k: v for k, v in sample['net_input'].items()
            if k != 'prev_output_tokens' and k != 'bert_input'
        }

        src_tokens = encoder_input['source']
        # batch dimension goes first
        bsz = src_tokens.size(0)
        beam_size = self.beam_size

        finalized = [[] for _ in range(bsz)]

        # non-autoregressive decoding

        def get_hypo_nat(decoded_id):
            return {
                'tokens': decoded_id,
                'score': 0.0,
                'attention': None,  # src_len x tgt_len
                'alignment': None,
                'positional_scores': torch.Tensor([0.0]),
            }

        def copy_batches(tensor, num_copies):
            # Repeat every batch row ``num_copies`` times, keeping copies adjacent.
            if tensor is None:
                return None
            x_size = tensor.size()
            tensor = tensor.contiguous().view(x_size[0], 1, -1)
            tensor = tensor.repeat(1, num_copies, 1)
            if len(x_size)==2:
                return tensor.view(-1, x_size[1])
            elif len(x_size)==3:
                return tensor.view(-1, x_size[1], x_size[2])
            else:
                raise NotImplementedError

        def select_worst(token_probs, num_mask):
            # Per row, pick the indices of the ``num_mask`` lowest-probability
            # tokens (at least one), then pad each index list to ``seq_len``
            # by repeating its first index so the rows can be stacked.
            bsz, seq_len = token_probs.size()
            ks = num_mask.clamp(min=1).tolist()
            masks = [token_probs[batch, :].topk(ks[batch], largest=False, sorted=False)[1] for batch in range(bsz)]
            masks = [torch.cat([mask, mask.new(seq_len - mask.size(0)).fill_(mask[0])], dim=0) for mask in masks]
            return torch.stack(masks, dim=0)

        w2v_args = {
            "source": src_tokens,
            "padding_mask": encoder_input['padding_mask'],
            "mask": False,
        }
        # The method is already wrapped in torch.no_grad().
        w2v_encoder = model.models[0].w2v_encoder
        x, padding_mask = w2v_encoder.w2v_model.extract_features(**w2v_args)
        ctc_output = x.transpose(0, 1)
        ctc_output = w2v_encoder.final_dropout(ctc_output)
        ctc_output = w2v_encoder.proj(ctc_output)

        predicted_lengths = sample["target_lengths"]
        bert_encoder_out = x.permute(1,0,2).contiguous()

        # One predicted length per (batch, beam) entry, floored at 2.
        # clamp() is out-of-place on purpose: writing in place into the
        # expanded view would write through to sample["target_lengths"] and
        # touches overlapping memory.
        beam = predicted_lengths.clamp(min=2).unsqueeze(1).expand(bsz, beam_size)

        max_len = beam.max().item()
        # Row ``k`` of this strictly-upper-triangular matrix is the pad mask
        # of a sequence of length ``k + 1`` (0 = token, 1 = pad), e.g. for
        # max_len = 3: [[0, 1, 1], [0, 0, 1], [0, 0, 0]].
        length_mask = torch.triu(src_tokens.new(max_len, max_len).fill_(1).long(), 1)
        # (bsz, beam_size, max_len): the mask row for each predicted length.
        length_mask = torch.stack([length_mask[beam[batch] - 1] for batch in range(bsz)], dim=0)

        # Initial decoder input: the greedy CTC output of the encoder.
        # (Alternatives used while debugging were the ground-truth target or
        # an all-[MASK] sequence shaped by ``length_mask``.)
        print("ctc input..")
        tgt_tokens = self.getFirstInputByCTC(ctc_output, padding_mask, beam_size, copy_batches)
        tgt_tokens = utils.move_to_cuda(
            tgt_tokens, device=bert_encoder_out.device
        )
        max_len = tgt_tokens.size(1)

        pad_mask = tgt_tokens.eq(self.pad)  # True at padded positions
        seq_lens = tgt_tokens.size(1) - pad_mask.sum(dim=1)  # non-pad length per row

        # Replicate the encoder output per beam: the batch-first layout is
        # copied, then switched back so dim 1 indexes (batch * beam).
        encoder_outs_value = copy_batches(bert_encoder_out.transpose(0,1), beam_size)
        encoder_outs_value = encoder_outs_value.transpose(0,1)
        encoder_padding = copy_batches(padding_mask, beam_size)
        encoder_outs = [{'encoder_out': encoder_outs_value, 'encoder_padding_mask': encoder_padding}]

        # First pass: tgt_tokens plays the role of prev_output_tokens.
        tgt_tokens, token_probs, _ = model.forward_decoder(
            tgt_tokens, encoder_outs,
            src_tokens=None,
            temperature=self.temperature,
        )
        assign_single_value_byte(tgt_tokens, pad_mask, self.pad)  # restore pads
        assign_single_value_byte(token_probs, pad_mask, 1.0)  # pads are never "worst"
        for step in range(1, self.mask_pred_iter+1):
            # The number of re-masked tokens decays linearly to 0 over the
            # iterations (select_worst still re-masks at least one token).
            num_mask = (seq_lens.float()*(1.0-step/self.mask_pred_iter)).long()
            assign_single_value_byte(token_probs, pad_mask, 1.0)
            mask_ind = select_worst(token_probs, num_mask)
            assign_single_value_long(tgt_tokens, mask_ind, self.mask)
            assign_single_value_byte(tgt_tokens, pad_mask, self.pad)

            new_tgt_tokens, new_token_probs, _ = model.forward_decoder(
                tgt_tokens, encoder_outs,
                src_tokens=None,
                temperature=self.temperature,
            )
            # Only the re-masked positions take the new probabilities/tokens.
            assign_multi_value_long(token_probs, mask_ind, new_token_probs)
            assign_single_value_byte(token_probs, pad_mask, 1.0)

            assign_multi_value_long(tgt_tokens, mask_ind, new_tgt_tokens)
            assign_single_value_byte(tgt_tokens, pad_mask, self.pad)

        # Sequence score: sum of per-token log-probabilities.
        lprobs = token_probs.log().sum(-1)
        hypotheses = tgt_tokens.view(bsz, beam_size, max_len)
        lprobs = lprobs.view(bsz, beam_size)
        # Length penalty based on the predicted target lengths.
        # NOTE: with CTC output as the initial input, ``lprobs.max(-1)[1]``
        # (no length penalty) was considered as an alternative.
        tgt_lengths = (1 - length_mask).sum(-1)
        length_penalty = ((5.0 + tgt_lengths.float()) ** self.len_penalty
                          / (6.0 ** self.len_penalty))
        length_penalty = length_penalty.view((bsz, beam_size))
        avg_log_prob = lprobs / length_penalty
        best_lengths = avg_log_prob.max(-1)[1]  # best beam index per batch element

        hypotheses = torch.stack([hypotheses[b, l, :] for b, l in enumerate(best_lengths)], dim=0)

        for i in range(bsz):
            finalized[i].append(get_hypo_nat(hypotheses[i]))
        return finalized


class EnsembleModel(torch.nn.Module):
    """A wrapper around an ensemble of models."""

    def __init__(self, models):
        super().__init__()
        self.models = torch.nn.ModuleList(models)
        self.incremental_states = None

    def has_encoder(self):
        """Always True: every wrapped model is treated as having an encoder."""
        return True

    def max_decoder_positions(self):
        """Return the smallest ``max_decoder_positions()`` over all models."""
        return min(m.max_decoder_positions() for m in self.models)

    @torch.no_grad()
    def forward_encoder(self, encoder_input):
        """Run ``model.encoder(**encoder_input)`` for every model."""
        if not self.has_encoder():
            return None
        return [model.encoder(**encoder_input) for model in self.models]

    @torch.no_grad()
    def forward_decoder(self, tokens, encoder_outs, src_tokens=None, temperature=1.):
        """Decode ``tokens`` once and return the most likely token per position.

        Args:
            tokens: decoder input tokens.
            encoder_outs (list): one encoder output per model.
            src_tokens (optional): forwarded to each decoder.
            temperature (float, optional): logits are divided by this value
                before the softmax (default: 1.0)

        Returns:
            tuple ``(idx, max_probs, probs)``: the arg-max token ids, their
            probabilities, and the full distribution over the last dimension.
            For several models the distribution is the mean of the per-model
            distributions.

        Raises:
            ValueError: if several models are wrapped and ``encoder_outs``
                does not hold exactly one entry per model.
        """
        if len(self.models) == 1:
            return self._decode_one(
                tokens,
                self.models[0],
                encoder_outs[0] if self.has_encoder() else None,
                log_probs=True,
                src_tokens=src_tokens,
                temperature=temperature,
            )

        if len(encoder_outs) != len(self.models):
            # zip() would silently drop models without an encoder output.
            raise ValueError(
                'expected one encoder output per model ({}), got {}'.format(
                    len(self.models), len(encoder_outs)
                )
            )

        # _decode_one returns plain probabilities (not log-probabilities), so
        # the ensemble distribution is their arithmetic mean.
        per_model_probs = []
        for model, encoder_out in zip(self.models, encoder_outs):
            _, _, probs = self._decode_one(
                tokens,
                model,
                encoder_out,
                log_probs=True,
                src_tokens=src_tokens,
                temperature=temperature,
            )
            per_model_probs.append(probs)
        avg_probs = torch.stack(per_model_probs, dim=0).mean(dim=0)
        max_probs, idx = avg_probs.max(dim=-1)
        return idx, max_probs, avg_probs

    def _decode_one(
        self, tokens, model, encoder_out, log_probs, src_tokens=None, temperature=1.,
    ):
        """Run one model's BERT decoder and arg-max its softmax output.

        ``log_probs`` is accepted for interface compatibility but is not
        read: plain probabilities are always returned.

        Returns:
            tuple ``(idx, max_probs, probs)``.
        """
        decoder_out = list(model.w2v_encoder.bertdecoder(tokens, src_tokens=src_tokens, encoder_out=encoder_out))
        if temperature != 1.:
            decoder_out[0].div_(temperature)

        probs = F.softmax(decoder_out[0], dim=-1)
        # Most likely token and its probability at every position.
        max_probs, idx = probs.max(dim=-1)
        return idx, max_probs, probs

    def reorder_encoder_out(self, encoder_outs, new_order):
        """Reorder every model's encoder output according to ``new_order``."""
        if not self.has_encoder():
            return
        return [
            model.encoder.reorder_encoder_out(encoder_out, new_order)
            for model, encoder_out in zip(self.models, encoder_outs)
        ]

    def reorder_incremental_state(self, new_order):
        """Reorder each decoder's incremental state; no-op when none is kept."""
        if self.incremental_states is None:
            return
        for model in self.models:
            model.w2v_encoder.bertdecoder.reorder_incremental_state(self.incremental_states[model], new_order)
