# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import math
from argparse import Namespace

import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import LegacyFairseqCriterion, register_criterion
from fairseq.data.data_utils import post_process
from fairseq.logging.meters import safe_round

import logging
# Module-level logger for this criterion module.
logger = logging.getLogger(__name__)

@register_criterion("ctc_att")
class CtcCriterionAtt(LegacyFairseqCriterion):
    """Joint CTC + attention-decoder criterion (registered as ``ctc_att``).

    The optimized loss is::

        ctc_weight * ctc_loss + (1 - ctc_weight) * att_loss

    ``ctc_loss`` is ``F.ctc_loss`` with ``reduction="sum"`` (summed over the
    batch). ``att_loss`` is the decoder cross entropy summed per utterance and
    then averaged over the batch, so the two terms are reduced differently.

    In evaluation mode the criterion also greedy-decodes the CTC and attention
    outputs and logs unit/word error counts (``uer``, ``wer``, ``raw_wer``,
    ``att_wer``).
    """

    def __init__(self, args, task):
        """
        Args:
            args: parsed namespace; reads ``post_process``, ``tensorboard_logdir``,
                ``wer_args``, ``zero_infinity``, ``sentence_avg``, ``ctc_weight``
                and (optionally) ``chinese_cer``.
            task: fairseq task whose ``target_dictionary`` defines the symbols.
        """
        super().__init__(args, task)
        # The dictionary's BOS index doubles as the CTC blank symbol.
        self.blank_idx = task.target_dictionary.bos()
        self.pad_idx = task.target_dictionary.pad()
        self.eos_idx = task.target_dictionary.eos()
        self.post_process = args.post_process if args.post_process else "letter"

        from tensorboardX import SummaryWriter

        # Receives periodic attention-map images from forward().
        self.tf_writer = SummaryWriter(args.tensorboard_logdir)

        if args.wer_args is not None:
            import ast

            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            # --wer-args is a literal 4-tuple (lm path, lexicon path, lm weight,
            # word score); literal_eval parses it without executing code.
            wer_compute_kenlm, wer_lexicon, lm_w, ws_w = ast.literal_eval(
                args.wer_args
            )

            dec_args = Namespace()
            dec_args.nbest = 1
            dec_args.criterion = "ctc"
            dec_args.kenlm_model = wer_compute_kenlm
            dec_args.lexicon = wer_lexicon
            dec_args.beam = 50
            dec_args.beam_size_token = min(50, len(task.target_dictionary))
            dec_args.beam_threshold = min(50, len(task.target_dictionary))
            dec_args.lm_weight = lm_w
            dec_args.word_score = ws_w
            dec_args.unk_weight = -math.inf
            dec_args.sil_weight = 0

            self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary)
        else:
            self.w2l_decoder = None

        self.zero_infinity = args.zero_infinity
        self.sentence_avg = args.sentence_avg
        self.chinese_cer = getattr(args, "chinese_cer", False)

        # Weight of the CTC term; the attention term gets (1 - ctc_weight).
        self.ctc_weight = args.ctc_weight

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        from argparse import ArgumentError

        parser.add_argument(
            "--zero-infinity", action="store_true", help="zero inf loss"
        )
        parser.add_argument(
            "--ctc-weight", type=float, help="the weight of ctc loss "
        )
        try:
            parser.add_argument(
                "--post-process",
                "--remove-bpe",
                default="letter",
                help="remove BPE tokens before scoring (can be set to sentencepiece, letter, and more)",
            )
        except ArgumentError:
            pass  # this option might have been added from eval args
        parser.add_argument(
            "--wer-args",
            type=str,
            default=None,
            help="options for wer computation on valid set using 4 gram lm. this should be a tuple of 4 elements: path to 4-gram lm, \
            path to lexicon, lm score, word score",
        )
        parser.add_argument("--chinese-cer", action="store_true", help="compute cer for chinese when wer")

    def forward(self, model, sample, reduce=True):
        """Compute the joint CTC/attention loss for one batch.

        Args:
            model: model called with ``source``, ``padding_mask`` and ``target``;
                its output must contain ``att_out`` and ``att_map_out``.
            sample: fairseq batch dict (``net_input``, ``target``, ``id``, ...).
            reduce: accepted for interface compatibility; unused.

        Returns:
            tuple: ``(loss, sample_size, logging_output)``.
        """
        net_output = model(
            source=sample["net_input"]["source"],
            padding_mask=sample["net_input"]["padding_mask"],
            target=sample["target"],
        )
        lprobs = model.get_normalized_probs(
            net_output, log_probs=True
        ).contiguous()  # (T, B, C) from the encoder

        if "src_lengths" in sample["net_input"]:
            input_lengths = sample["net_input"]["src_lengths"]
        else:
            non_padding_mask = ~net_output["padding_mask"]
            input_lengths = non_padding_mask.long().sum(-1)

        # CTC targets exclude padding and EOS. Their lengths are derived from
        # the same mask so they always agree with targets_flat
        # (sample["target_lengths"] may count the EOS token removed here).
        pad_mask = (sample["target"] != self.pad_idx) & (
            sample["target"] != self.eos_idx
        )
        targets_flat = sample["target"].masked_select(pad_mask)
        ctc_target_lengths = pad_mask.long().sum(-1)

        with torch.backends.cudnn.flags(enabled=False):
            loss_ctc = F.ctc_loss(
                lprobs,
                targets_flat,
                input_lengths,
                ctc_target_lengths,
                blank=self.blank_idx,
                reduction="sum",
                zero_infinity=self.zero_infinity,
            )

        att_out = net_output["att_out"]
        att_loss = self._attention_loss(att_out, sample["target"])

        loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * att_loss

        if "ntokens" in sample:
            ntokens = sample["ntokens"]
        else:
            ntokens = sample["target_lengths"].sum().item()

        sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
        logging_output = {
            "loss": utils.item(loss.data),  # * sample['ntokens'],
            "ctc_loss": utils.item(loss_ctc.data),
            "att_loss": utils.item(att_loss.data),
            "ntokens": ntokens,
            "nsentences": sample["id"].numel(),
            "sample_size": sample_size,
        }

        num_updates = model.w2v_encoder.num_updates
        if num_updates % 500 == 0:
            self._log_attention_map(
                net_output["att_map_out"], att_out.size(1), num_updates
            )

        if not model.training:
            logging_output.update(
                self._error_stats(lprobs, att_out, sample, input_lengths)
            )

        return loss, sample_size, logging_output

    def _attention_loss(self, att_out, target):
        """Return the decoder cross entropy, summed per utterance, averaged over the batch.

        ``att_out`` is indexed here as (batch, decoder_steps, vocab). The decoder
        may run for more steps than the padded target width (e.g. during
        validation), so logits are truncated to ``target.size(1)`` to stay
        aligned with the targets. Padded positions are ignored.
        """
        bsz, _, vocab = att_out.shape
        tgt_len = target.size(1)
        logits = att_out[:, :tgt_len, :].reshape(-1, vocab)
        per_token = F.cross_entropy(
            logits,
            target.reshape(-1),
            reduction="none",
            ignore_index=self.pad_idx,
        )
        # Sum each utterance, then mean over the batch.
        return per_token.view(bsz, -1).sum(dim=-1).mean()

    def _log_attention_map(self, att_maps, num_steps, step):
        """Write the first utterance's attention map to TensorBoard as a 3-channel image.

        Assumes each entry of ``att_maps`` is 2-D with decoder steps on the
        first axis (TODO confirm); at most ``num_steps`` rows are kept.
        """
        if len(att_maps) == 0:
            return
        att = att_maps[0].detach().cpu()
        image = torch.stack([att, att, att], dim=0)[:, :num_steps, :]
        self.tf_writer.add_image("att_weight", image, step)

    def _error_stats(self, lprobs, att_out, sample, input_lengths):
        """Greedy-decode CTC and attention outputs and count edit errors.

        Returns a dict with the unit-level (``c_*``) and word-level (``w_*``)
        counters consumed by ``reduce_metrics``. With ``--chinese-cer`` the
        post-processed strings are compared directly instead of being split
        into words.
        """
        import editdistance

        dictionary = self.task.target_dictionary
        c_err = c_len = 0
        w_errs = wv_errs = w_errs_att = w_len = 0

        with torch.no_grad():
            lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
            att_out = att_out.float().contiguous().cpu()
            targets = (
                sample["target_label"] if "target_label" in sample else sample["target"]
            )

            for lp, ap, tgt, inp_l in zip(lprobs_t, att_out, targets, input_lengths):
                lp = lp[:inp_l].unsqueeze(0)

                decoded = None
                if self.w2l_decoder is not None:
                    hyps = self.w2l_decoder.decode(lp)
                    if len(hyps) > 0 and len(hyps[0]) > 0:
                        decoded = hyps[0][0]

                # Reference without padding / EOS.
                keep = (tgt != self.pad_idx) & (tgt != self.eos_idx)
                targ = tgt[keep]
                targ_units = dictionary.string(targ)
                targ_units_arr = targ.tolist()

                # Greedy CTC decoding: collapse repeats, then drop blanks.
                toks = lp.argmax(dim=-1).unique_consecutive()
                pred_units_arr = toks[toks != self.blank_idx].tolist()

                # Greedy attention decoding, truncated at the first EOS.
                att_pred_units_arr = ap.argmax(dim=-1).tolist()
                if self.eos_idx in att_pred_units_arr:
                    eos_pos = att_pred_units_arr.index(self.eos_idx)
                    att_pred_units_arr = att_pred_units_arr[:eos_pos]

                c_err += editdistance.eval(pred_units_arr, targ_units_arr)
                c_len += len(targ_units_arr)

                pred_units = dictionary.string(pred_units_arr)
                att_pred_units = dictionary.string(att_pred_units_arr)

                targ_words = post_process(targ_units, self.post_process)
                pred_words_raw = post_process(pred_units, self.post_process)
                att_pred_words_raw = post_process(att_pred_units, self.post_process)
                if not self.chinese_cer:
                    targ_words = targ_words.split()
                    pred_words_raw = pred_words_raw.split()
                    att_pred_words_raw = att_pred_words_raw.split()

                # The attention hypothesis is scored whether or not an LM decoder is used.
                w_errs_att += editdistance.eval(att_pred_words_raw, targ_words)

                raw_dist = editdistance.eval(pred_words_raw, targ_words)
                wv_errs += raw_dist
                if decoded is not None and "words" in decoded:
                    w_errs += editdistance.eval(decoded["words"], targ_words)
                else:
                    w_errs += raw_dist

                w_len += len(targ_words)

        return {
            "w_att_errors": w_errs_att,
            "wv_errors": wv_errs,
            "w_errors": w_errs,
            "w_total": w_len,
            "c_errors": c_err,
            "c_total": c_len,
        }

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training.

        Losses are reported in base 2 and normalized by ``sample_size``; the
        error counters are summed and turned into percentage rates.
        """

        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        loss_ctc_sum = utils.item(sum(log.get("ctc_loss", 0) for log in logging_outputs))
        loss_att_sum = utils.item(sum(log.get("att_loss", 0) for log in logging_outputs))
        ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
        nsentences = utils.item(
            sum(log.get("nsentences", 0) for log in logging_outputs)
        )
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs)
        )

        metrics.log_scalar(
            "ctc_loss", loss_ctc_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar(
            "att_loss", loss_att_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar("ntokens", ntokens)
        metrics.log_scalar("nsentences", nsentences)
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )

        w_errors_att = sum(log.get("w_att_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_w_att_errors", w_errors_att)
        c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_c_errors", c_errors)
        c_total = sum(log.get("c_total", 0) for log in logging_outputs)
        metrics.log_scalar("_c_total", c_total)
        w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_w_errors", w_errors)
        wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_wv_errors", wv_errors)
        w_total = sum(log.get("w_total", 0) for log in logging_outputs)
        metrics.log_scalar("_w_total", w_total)

        if c_total > 0:
            metrics.log_derived(
                "uer",
                lambda meters: safe_round(
                    meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
                )
                if meters["_c_total"].sum > 0
                else float("nan"),
            )
        if w_total > 0:
            metrics.log_derived(
                "wer",
                lambda meters: safe_round(
                    meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                )
                if meters["_w_total"].sum > 0
                else float("nan"),
            )
            metrics.log_derived(
                "raw_wer",
                lambda meters: safe_round(
                    meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                )
                if meters["_w_total"].sum > 0
                else float("nan"),
            )
            metrics.log_derived(
                "att_wer",
                lambda meters: safe_round(
                    meters["_w_att_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                )
                if meters["_w_total"].sum > 0
                else float("nan"),
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improves distributed training speed.
        """
        return True
