# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import math
from argparse import ArgumentError, Namespace

import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import LegacyFairseqCriterion, register_criterion
from fairseq.data.data_utils import post_process
from fairseq.logging.meters import safe_round


@register_criterion("ctc_mlm_ce")
class CtcMlmCriterion(LegacyFairseqCriterion):
    """Multi-task criterion: fusion CE + conditional MLM + encoder CTC (+ quantity).

    The returned loss is a weighted sum of

    * ``ce_loss`` -- NLL of ``sample["origin_target"]`` under
      ``model.get_normalized_probs`` (weight ``--ctc-weight``, stored as
      ``self.ce_weight``);
    * ``mlm_loss`` -- label-smoothed NLL of ``sample["target"]`` under
      ``model.get_normalized_decoder_probs`` (weight ``--mlm-weight``);
    * ``ctc_loss_encoder_ctc`` -- CTC loss of ``sample["origin_target"]``
      (pad/eos removed) under ``model.get_normalized_ctc_probs``
      (weight ``--encoder-ctc-weight``);
    * ``quantity_loss`` -- only with ``--quantity-loss``: L1/L2 distance
      between the summed ``net_output["cif_weights"]`` and
      ``sample["target_lengths"]`` (weight ``--quantity-weight``).

    Unit- and word-level error counts for the three output heads are added to
    the logging output on every forward call (training and validation).
    """

    def __init__(self, args, task):
        super().__init__(args, task)
        # The dictionary's BOS symbol doubles as the CTC blank.
        self.blank_idx = task.target_dictionary.bos()
        self.pad_idx = task.target_dictionary.pad()
        self.eos_idx = task.target_dictionary.eos()
        self.post_process = args.post_process if args.post_process else "letter"

        if args.wer_args is not None:
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            # NOTE(security): ``eval`` executes arbitrary Python taken from the
            # command line; only pass trusted values to ``--wer-args``.
            wer_compute_kenlm, wer_lexicon, lm_w, ws_w = eval(args.wer_args)

            dec_args = Namespace()
            dec_args.nbest = 1
            dec_args.criterion = "ctc"
            dec_args.kenlm_model = wer_compute_kenlm
            dec_args.lexicon = wer_lexicon
            dec_args.beam = 50
            dec_args.beam_size_token = min(50, len(task.target_dictionary))
            dec_args.beam_threshold = min(50, len(task.target_dictionary))
            dec_args.lm_weight = lm_w
            dec_args.word_score = ws_w
            dec_args.unk_weight = -math.inf
            dec_args.sil_weight = 0

            self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary)
        else:
            self.w2l_decoder = None

        self.zero_infinity = args.zero_infinity
        self.sentence_avg = args.sentence_avg
        # NOTE(review): this fallback (0.1) differs from the --label-smoothing
        # default (0.0); it only applies when ``args`` lacks the attribute.
        self.eps = getattr(args, 'label_smoothing', 0.1)
        # Despite its name, ``--ctc-weight`` scales the fusion-module CE loss.
        self.ce_weight = args.ctc_weight
        self.encoder_ctc_weight = args.encoder_ctc_weight
        self.mlm_weight = args.mlm_weight
        print("ce_weight:", self.ce_weight)
        print("encoder_ctc_weight:", self.encoder_ctc_weight)
        print("mlm_weight:", self.mlm_weight)
        self.quantity_weight = args.quantity_weight
        print("quantity_weight:", self.quantity_weight)
        self.quantity_loss = getattr(args, 'quantity_loss', None)
        print("quantity_loss:", self.quantity_loss)
        if self.quantity_loss not in (None, "l1", "l2"):
            # Fail at construction instead of with a NameError inside forward().
            raise ValueError(
                f"--quantity-loss must be 'l1' or 'l2', got {self.quantity_loss!r}"
            )
        self.chinese_cer = getattr(args, 'chinese_cer', False)

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        parser.add_argument(
            "--zero-infinity", action="store_true", help="zero inf loss"
        )
        try:
            parser.add_argument(
                "--post-process",
                "--remove-bpe",
                default="letter",
                help="remove BPE tokens before scoring (can be set to sentencepiece, letter, and more)",
            )
        except ArgumentError:
            pass  # this option might have been added from eval args
        parser.add_argument(
            "--wer-args",
            type=str,
            default=None,
            help="options for wer computation on valid set using 4 gram lm. this should be a tuple of 4 elements: path to 4-gram lm, \
            path to lexicon, lm score, word score",
        )
        parser.add_argument('--label-smoothing', default=0., type=float, metavar='D',
                            help='epsilon for label smoothing, 0 means no label smoothing')
        parser.add_argument('--ctc-weight', default=0.5, type=float,
                            help='weight of the fusion-module cross-entropy loss '
                                 '(despite the name, this does not weight a CTC loss)')
        parser.add_argument('--encoder-ctc-weight', default=0.5, type=float,
                            help='weight of the CTC loss on the encoder outputs')
        parser.add_argument('--mlm-weight', default=0.5, type=float,
                            help='weight of the conditional masked-LM decoder loss')
        parser.add_argument('--quantity-weight', default=0.5, type=float,
                            help='weight of the quantity loss (used only with --quantity-loss)')
        parser.add_argument('--quantity-loss', default=None, type=str, choices=['l1', 'l2'],
                            help='type of quantity loss (l1 or l2)')
        parser.add_argument("--chinese-cer", action="store_true", help="compute cer for chinese when wer")

    def forward(self, model, sample, reduce=True):
        """Compute the weighted multi-task loss for ``sample``.

        Returns:
            tuple: ``(total_loss, sample_size, logging_output)``. ``sample_size``
            is the number of sentences with ``--sentence-avg`` and the number
            of target tokens otherwise.
        """
        if self.quantity_loss is not None:
            # The model receives the reference lengths when the quantity loss
            # is enabled (it returns the ``cif_weights`` consumed below).
            sample["net_input"]["target_lengths"] = sample["target_lengths"]
        net_output = model(**sample["net_input"])

        #### BERT conditional MLM ####
        decoder_out, mlm_loss, nll_loss = self.compute_mlm_loss(
            model, net_output, sample, reduce=reduce
        )

        #### fusion-module CE loss ####
        lprobs = model.get_normalized_probs(
            net_output, log_probs=True
        ).contiguous()  # (T, B, C)
        # ``origin_target`` is (B, T): move the batch axis first so both tensors
        # flatten in the same order. Flattening (T, B, C) directly pairs
        # predictions with the wrong targets whenever B > 1.
        ce_loss = self.compute_ce_loss(
            lprobs.transpose(0, 1), sample["origin_target"], reduce=reduce
        )

        #### encoder CTC loss ####
        ctc_lprobs = model.get_normalized_ctc_probs(
            net_output, log_probs=True
        ).contiguous()  # (T, B, C), the layout F.ctc_loss expects
        if "src_lengths" in sample["net_input"]:
            input_lengths = sample["net_input"]["src_lengths"]
        else:
            input_lengths = self._lengths_from_padding_mask(
                net_output["padding_mask"], ctc_lprobs
            )
        pad_mask = (sample["origin_target"] != self.pad_idx) & (
            sample["origin_target"] != self.eos_idx
        )
        targets_flat = sample["origin_target"].masked_select(pad_mask)
        target_lengths = sample["target_lengths"]
        with torch.backends.cudnn.flags(enabled=False):
            ctc_loss_encoder_ctc = F.ctc_loss(
                ctc_lprobs,
                targets_flat,
                input_lengths,
                target_lengths,
                blank=self.blank_idx,
                reduction="sum",
                zero_infinity=self.zero_infinity,
            )

        ntokens = (
            sample["ntokens"] if "ntokens" in sample else target_lengths.sum().item()
        )

        total_loss = (
            self.ce_weight * ce_loss
            + self.mlm_weight * mlm_loss
            + self.encoder_ctc_weight * ctc_loss_encoder_ctc
        )
        quantity_loss = None
        if self.quantity_loss is not None:
            quantity_loss = self._compute_quantity_loss(net_output, target_lengths)
            total_loss = total_loss + self.quantity_weight * quantity_loss

        sample_size = sample["origin_target"].size(0) if self.sentence_avg else ntokens
        logging_output = {
            'total_loss': utils.item(total_loss.data),
            "ce_loss": utils.item(ce_loss.data),
            "ctc_loss_encoder_ctc": utils.item(ctc_loss_encoder_ctc.data),
            'mlm_loss': utils.item(mlm_loss.data),
            'nll_loss': utils.item(nll_loss.data),
            # 0 when the quantity loss is disabled (previously an UnboundLocalError).
            'quantity_loss': (
                utils.item(quantity_loss.data) if quantity_loss is not None else 0.0
            ),
            "ntokens": ntokens,
            "nsentences": sample["id"].numel(),
            "sample_size": sample_size,
        }

        # Error statistics are gathered in training as well as validation.
        logging_output.update(
            self._compute_error_stats(
                sample, net_output, lprobs, ctc_lprobs, decoder_out, input_lengths
            )
        )
        return total_loss, sample_size, logging_output

    @staticmethod
    def _lengths_from_padding_mask(padding_mask, like):
        """Return per-sample lengths (B,) from a (B, T) boolean padding mask.

        When ``padding_mask`` is None (no padding in the batch), every sample
        spans the full time axis of ``like``, a (T, B, C) tensor.
        """
        if padding_mask is None:
            return like.new_full((like.size(1),), like.size(0), dtype=torch.long)
        return (~padding_mask).long().sum(-1)

    def _compute_quantity_loss(self, net_output, target_lengths):
        """Mean L1/L2 distance between summed ``cif_weights`` and target lengths."""
        scale_weights_sum = torch.sum(net_output['cif_weights'], -1)
        # The regression target must be floating point with the input's dtype.
        target = target_lengths.to(scale_weights_sum.dtype)
        if self.quantity_loss == 'l2':
            return F.mse_loss(scale_weights_sum, target)
        if self.quantity_loss == 'l1':
            return F.l1_loss(scale_weights_sum, target)
        raise ValueError(f"unsupported quantity loss {self.quantity_loss!r}")

    def _w2l_decode(self, lp):
        """Best hypothesis of the optional KenLM decoder for ``lp``, or None."""
        if self.w2l_decoder is None:
            return None
        decoded = self.w2l_decoder.decode(lp)
        if len(decoded) < 1 or len(decoded[0]) < 1:
            return None
        return decoded[0][0]

    def _compute_error_stats(
        self, sample, net_output, lprobs, ctc_lprobs, decoder_out, input_lengths
    ):
        """Count unit/word errors of the CE, encoder-CTC and MLM-decoder heads.

        Unit errors compare token ids, so with word pieces they are word-piece
        error rates (for Chinese characters they equal the CER). Word errors
        compare post-processed strings; with ``--chinese-cer`` the strings are
        not split, so the "word" distance runs over characters.

        Returns:
            dict: integer counters to merge into the logging output.
        """
        import editdistance

        tgt_dict = self.task.target_dictionary
        stats = dict.fromkeys(
            (
                "wv_errors", "w_errors", "ctc_w_errors", "dout_wv_errors",
                "dout_w_errors", "w_total", "w_total_bert", "c_errors",
                "ctc_c_errors", "c_total", "c_total_bert", "dout_c_errors",
            ),
            0,
        )

        different_tokens = 'origin_bert_target' in sample
        targets = (
            sample["target_label"] if "target_label" in sample
            else sample["origin_target"]
        )
        bert_targets = (
            sample['origin_bert_target'] if different_tokens
            else sample["origin_target"]
        )
        if different_tokens:
            tgt_mode, bert_mode, split_words = "letter", "bert_bpe_piece", True
        else:
            tgt_mode = bert_mode = self.post_process
            split_words = not self.chinese_cer

        def to_words(units, mode):
            text = post_process(units, mode)
            return text.split() if split_words else text

        with torch.no_grad():
            # (T, B, C) -> (B, T, C) so that zip() below iterates over the batch.
            lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
            ctc_lprobs_t = ctc_lprobs.transpose(0, 1).float().contiguous().cpu()
            decoder_out_t = decoder_out.float().contiguous().cpu()  # batch-first
            encoder_input_lengths = self._lengths_from_padding_mask(
                net_output["encoder_padding_mask"], lprobs
            )

            for lp, ctc_lp, dout, t, t_bert, inp_l, ctc_inp_l in zip(
                lprobs_t,
                ctc_lprobs_t,
                decoder_out_t,
                targets,
                bert_targets,
                encoder_input_lengths,
                input_lengths,
            ):
                lp = lp[:inp_l].unsqueeze(0)
                ctc_lp = ctc_lp[:ctc_inp_l].unsqueeze(0)
                decoded = self._w2l_decode(lp)

                keep = (t != tgt_dict.pad()) & (t != tgt_dict.eos())
                targ = t[keep]
                targ_units = tgt_dict.string(targ)
                targ_units_arr = targ.tolist()

                keep_bert = (t_bert != tgt_dict.pad()) & (t_bert != tgt_dict.eos())
                targ_bert = t_bert[keep_bert]
                targ_units_bert = tgt_dict.string(targ_bert)
                targ_units_arr_bert = targ_bert.tolist()

                # Drop decoder positions that correspond to pad/eos BERT targets.
                dout = dout[keep_bert].unsqueeze(0)

                toks = lp.argmax(dim=-1)
                pred_units_arr = toks[toks != self.blank_idx].tolist()
                ctc_toks = ctc_lp.argmax(dim=-1).unique_consecutive()
                ctc_pred_units_arr = ctc_toks[ctc_toks != self.blank_idx].tolist()
                dout_toks = dout.argmax(dim=-1)
                dout_pred_units_arr = dout_toks[dout_toks != self.blank_idx].tolist()

                stats["c_errors"] += editdistance.eval(pred_units_arr, targ_units_arr)
                stats["ctc_c_errors"] += editdistance.eval(ctc_pred_units_arr, targ_units_arr)
                stats["c_total"] += len(targ_units_arr)
                stats["dout_c_errors"] += editdistance.eval(
                    dout_pred_units_arr, targ_units_arr_bert
                )
                stats["c_total_bert"] += len(targ_units_arr_bert)

                targ_words = to_words(targ_units, tgt_mode)
                targ_words_bert = to_words(targ_units_bert, bert_mode)
                pred_words_raw = to_words(tgt_dict.string(pred_units_arr), tgt_mode)
                ctc_pred_words_raw = to_words(
                    tgt_dict.string(ctc_pred_units_arr), tgt_mode
                )
                dout_pred_words_raw = to_words(
                    tgt_dict.string(dout_pred_units_arr), bert_mode
                )

                if decoded is not None and "words" in decoded:
                    stats["w_errors"] += editdistance.eval(decoded["words"], targ_words)
                    stats["wv_errors"] += editdistance.eval(pred_words_raw, targ_words)
                else:
                    dist = editdistance.eval(pred_words_raw, targ_words)
                    stats["w_errors"] += dist
                    stats["wv_errors"] += dist
                # The LM decoder only rescores the CE head; the CTC and decoder
                # heads are always scored on their raw greedy output.
                stats["ctc_w_errors"] += editdistance.eval(ctc_pred_words_raw, targ_words)
                dout_dist = editdistance.eval(dout_pred_words_raw, targ_words_bert)
                stats["dout_w_errors"] += dout_dist
                stats["dout_wv_errors"] += dout_dist

                stats["w_total"] += len(targ_words)
                stats["w_total_bert"] += len(targ_words_bert)

        return stats

    def compute_ce_loss(self, lprobs, targets, reduce=True):
        """NLL of ``targets`` under ``lprobs``, ignoring padding positions.

        ``lprobs`` must share the leading layout of ``targets`` (e.g. (B, T, C)
        for (B, T) targets); both are flattened position-wise, so
        non-contiguous inputs are accepted.

        Returns:
            Tensor: summed loss if ``reduce`` else per-position losses.
        """
        lprobs = lprobs.reshape(-1, lprobs.size(-1))
        target = targets.reshape(-1)
        loss = F.nll_loss(
            lprobs,
            target,
            ignore_index=self.padding_idx,
            reduction="sum" if reduce else "none",
        )
        return loss

    def compute_mlm_loss(self, model, net_output, sample, reduce=True):
        """Label-smoothed NLL of the conditional-MLM decoder.

        Only positions where ``sample['target']`` is not padding contribute.
        Per the original author's notes, the pipeline stores the original
        token at masked positions and padding elsewhere (TODO confirm).

        Returns:
            tuple: ``(decoder_out, loss, nll_loss)`` -- the decoder
            log-probabilities, the smoothed loss (epsilon ``self.eps``) and the
            unsmoothed NLL. Both losses are summed when ``reduce`` is true.
        """
        decoder_out = model.get_normalized_decoder_probs(net_output, log_probs=True)
        lprobs = decoder_out.reshape(-1, decoder_out.size(-1))  # (N, C)
        # Original behaviour: record the padding index on the sample
        # (its consumers are not visible in this file).
        sample['padding_idx'] = self.padding_idx
        target = sample['target'].reshape(-1, 1)  # (N, 1)
        non_pad_mask = target.ne(self.padding_idx)

        nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask]
        # Sum of log-probs over the vocabulary, for uniform label smoothing.
        smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask]
        if reduce:
            nll_loss = nll_loss.sum()
            smooth_loss = smooth_loss.sum()
        eps_i = self.eps / lprobs.size(-1)
        loss = (1. - self.eps) * nll_loss + eps_i * smooth_loss
        return decoder_out, loss, nll_loss

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training.

        Losses are logged in base 2 (divided by ln 2). ``loss``/``nll_loss``
        repeat the CE loss so runs can be plotted next to plain wav2vec
        experiments. Error rates are percentages derived from the summed
        ``_*`` counters.
        """

        def total(key):
            return sum(log.get(key, 0) for log in logging_outputs)

        def log_loss(name, value, weight):
            metrics.log_scalar(name, value / weight / math.log(2), weight, round=3)

        def log_rate(name, errors_key, total_key):
            # A helper function so each closure binds its own keys.
            metrics.log_derived(
                name,
                lambda meters: safe_round(
                    meters[errors_key].sum * 100.0 / meters[total_key].sum, 3
                )
                if meters[total_key].sum > 0
                else float("nan"),
            )

        loss_sum = utils.item(total("total_loss"))
        ntokens = utils.item(total("ntokens"))
        nsentences = utils.item(total("nsentences"))
        sample_size = utils.item(total("sample_size"))
        report_per_token = sample_size != ntokens

        log_loss("total_loss", loss_sum, sample_size)
        metrics.log_scalar("ntokens", ntokens)
        metrics.log_scalar("nsentences", nsentences)
        if report_per_token:
            log_loss("total_nll_loss", loss_sum, ntokens)

        ce_loss = utils.item(total("ce_loss"))
        log_loss("ce_loss", ce_loss, sample_size)
        if report_per_token:
            log_loss("ce_nll_loss", ce_loss, ntokens)
        log_loss(
            "ctc_loss_encoder_ctc", utils.item(total("ctc_loss_encoder_ctc")), sample_size
        )
        log_loss("quantity_loss", utils.item(total("quantity_loss")), sample_size)
        # Same value as ce_loss, under the names wav2vec experiments use.
        log_loss("loss", ce_loss, sample_size)
        if report_per_token:
            log_loss("nll_loss", ce_loss, ntokens)

        mlm_loss = utils.item(total("mlm_loss"))
        nll_loss = utils.item(total("nll_loss"))
        log_loss("mlm_loss", mlm_loss, sample_size)
        if report_per_token:
            log_loss("mlm_nll_loss", nll_loss, ntokens)

        counters = {
            key: total(key)
            for key in (
                "c_errors", "ctc_c_errors", "c_total", "c_total_bert",
                "w_errors", "wv_errors", "ctc_w_errors", "w_total", "w_total_bert",
                "dout_c_errors", "dout_w_errors", "dout_wv_errors",
            )
        }
        for key, value in counters.items():
            metrics.log_scalar("_" + key, value)

        if counters["c_total"] > 0:
            log_rate("uer", "_c_errors", "_c_total")
            log_rate("uer_encoder_ctc", "_ctc_c_errors", "_c_total")
            log_rate("dout_uer", "_dout_c_errors", "_c_total_bert")
        if counters["w_total"] > 0:
            log_rate("wer", "_w_errors", "_w_total")
            log_rate("wer_encoder_ctc", "_ctc_w_errors", "_w_total")
            log_rate("raw_wer", "_wv_errors", "_w_total")
            log_rate("dout_wer", "_dout_w_errors", "_w_total_bert")
            log_rate("dout_raw_wer", "_dout_wv_errors", "_w_total_bert")

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improves distributed training speed.
        """
        return True
