#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Run inference for pre-processed data with a trained model.
"""

import logging
import math
import os
import sys
sys.path.append("/wav2bert/fairseq")

import editdistance
import numpy as np
import torch
from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.data.data_utils import post_process
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging.meters import StopwatchMeter, TimeMeter


# Configure logging once for the whole script.
# basicConfig() only installs a default handler when the root logger has none,
# so the root level is forced to INFO explicitly to cover the case where an
# imported library configured logging first.  (A second basicConfig() call
# would be a no-op at this point and has been dropped.)
logging.basicConfig()
logging.root.setLevel(logging.INFO)
logger = logging.getLogger(__name__)


def add_asr_eval_argument(parser):
    """Register the ASR-evaluation specific command-line options on *parser*.

    Args:
        parser: an ``argparse``-style parser exposing ``add_argument``.

    Returns:
        The same parser object, to allow chaining.
    """
    parser.add_argument("--lexicon", help="lexicon for w2l decoder")
    # ``--test`` defaults to True, so a plain ``store_true`` flag could never
    # switch it off.  ``--no-test`` writes False to the same destination while
    # the default (and ``--test`` itself) keep their previous meaning.
    parser.add_argument(
        "--test",
        dest="test",
        default=True,
        action="store_true",
        help="pass test_mode=True when loading the dataset (default)",
    )
    parser.add_argument(
        "--no-test",
        dest="test",
        action="store_false",
        help="pass test_mode=False when loading the dataset",
    )
    # Kept as a string ("True"/"False"); this function does not convert it.
    parser.add_argument(
        "--left-pad-source",
        default="True",
        type=str,
        metavar="BOOL",
        help="pad the source on the left",
    )
    return parser


def check_args(args):
    """Sanity-check option combinations before decoding starts.

    Raises:
        AssertionError: if ``--replace-unk`` is set without ``--raw-text``.
    """
    # NOTE: checks that --path / --results_path are set are intentionally
    # not enforced here (they were disabled in the original code).
    wants_unk_replacement = args.replace_unk is not None
    assert (
        not wants_unk_replacement or args.raw_text
    ), "--replace-unk requires a raw text dataset (--raw-text)"


def get_dataset_itr(args, task, models):
    """Build a non-shuffled epoch iterator over the ``args.gen_subset`` split.

    Args:
        args: parsed options supplying batching and sharding settings.
        task: task object providing ``dataset`` and ``get_batch_iterator``.
        models: unused here; accepted for call-site compatibility.

    Returns:
        Whatever ``next_epoch_itr(shuffle=False)`` returns for the batch
        iterator created by the task.
    """
    # No length cap is applied here: both position limits are sys.maxsize.
    unlimited_positions = (sys.maxsize, sys.maxsize)
    iterator_kwargs = dict(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.batch_size,
        max_positions=unlimited_positions,
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
        data_buffer_size=args.data_buffer_size,
    )
    batch_iterator = task.get_batch_iterator(**iterator_kwargs)
    return batch_iterator.next_epoch_itr(shuffle=False)


def process_predictions(
    args, hypos, sp, tgt_dict, bert_dict, target_tokens, res_files, speaker, id, filename=None
):
    """Record and score the top hypothesis of one utterance against its reference.

    Only the first hypothesis is used: it is written to the result files
    (when given) together with the reference, and the edit distance between
    the two is returned.  If there is no hypothesis at all, the utterance is
    scored as an empty transcript so the caller always receives a tuple.

    Args:
        args: parsed options; reads ``nbest``, ``post_process``, ``quiet`` and,
            if present, ``chinese_cer`` (treated as False when missing).
        hypos: list of hypothesis dicts for this utterance.  Each dict has a
            ``"tokens"`` tensor and may have a ``"words"`` list.
        sp: unused; kept for interface compatibility.
        tgt_dict: dictionary whose ``string`` method converts token ids to text.
        bert_dict: unused; kept for interface compatibility.
        target_tokens: reference token ids (already stripped of padding by the
            caller).
        res_files: mapping with the keys ``"hypo.units"``, ``"hypo.words"``,
            ``"ref.units"`` and ``"ref.words"`` to writable files, or None to
            skip writing.
        speaker: speaker tag used in the line suffix when *filename* is None.
        id: utterance id used in the line suffix when *filename* is None.
        filename: if given, appended to each written line instead of the
            ``(speaker-id)`` suffix.

    Returns:
        tuple: ``(edit_distance, reference_length)``.  Both are counted over
        whitespace-separated words, or over the characters of the
        post-processed strings when ``args.chinese_cer`` is truthy.
    """
    tgt_pieces = tgt_dict.string(target_tokens)
    tgt_words = post_process(tgt_pieces, args.post_process)

    # Only the top hypothesis is written and scored.
    candidates = hypos[: min(len(hypos), args.nbest)]
    if candidates:
        hypo = candidates[0]
        hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
        if "words" in hypo:
            hyp_words = " ".join(hypo["words"])
        else:
            hyp_words = post_process(hyp_pieces, args.post_process)
    else:
        # No hypothesis was produced: treat it as an empty transcript so the
        # whole reference counts as errors instead of returning None.
        hyp_pieces = ""
        hyp_words = ""

    if res_files is not None:
        # Lines are tagged with the file name when available, otherwise with
        # the "(speaker-id)" suffix.
        if filename is not None:
            tag = filename
        else:
            tag = "({}-{})".format(speaker, id)
        print("{} {}".format(hyp_pieces, tag), file=res_files["hypo.units"])
        print("{} {}".format(hyp_words, tag), file=res_files["hypo.words"])
        print("{} {}".format(tgt_pieces, tag), file=res_files["ref.units"])
        print("{} {}".format(tgt_words, tag), file=res_files["ref.words"])

    if not args.quiet:
        logger.debug("HYPO:" + hyp_words)
        logger.debug("TARGET:" + tgt_words)
        logger.debug("___________________")

    # For WER the text has already been through post_process, so the result
    # is correct as long as the post-process argument is right.  For CER the
    # separator has to be counted as well, to stay consistent with
    # validation, which is why the strings are compared character by
    # character without removing anything.
    if getattr(args, "chinese_cer", False):
        hyp_seq, ref_seq = hyp_words, tgt_words
    else:
        hyp_seq, ref_seq = hyp_words.split(), tgt_words.split()
    return editdistance.eval(hyp_seq, ref_seq), len(ref_seq)

def prepare_result_files(args):
    """Open the four transcript output files under ``args.results_path``.

    File names follow ``<prefix>-<checkpoint basename>-<subset>.txt``; when
    decoding is sharded (``args.num_shards > 1``) the prefix is additionally
    prefixed with ``<shard_id>_``.

    Args:
        args: parsed options; reads ``results_path``, ``num_shards``,
            ``shard_id``, ``path`` and ``gen_subset``.

    Returns:
        None when ``args.results_path`` is empty/None, otherwise a dict mapping
        ``"hypo.words"``, ``"hypo.units"``, ``"ref.words"`` and ``"ref.units"``
        to line-buffered text files opened for writing.  The caller owns the
        files and is responsible for closing them.
    """
    if not args.results_path:
        return None

    def get_res_file(file_prefix):
        if args.num_shards > 1:
            file_prefix = f"{args.shard_id}_{file_prefix}"
        path = os.path.join(
            args.results_path,
            "{}-{}-{}.txt".format(
                file_prefix, os.path.basename(args.path), args.gen_subset
            ),
        )
        # Transcripts may contain non-ASCII text, so do not rely on the
        # locale's default encoding.
        return open(path, "w", buffering=1, encoding="utf-8")

    # NOTE: the "*.words" keys intentionally map to "*.word" file prefixes;
    # the on-disk names are kept unchanged.
    return {
        "hypo.words": get_res_file("hypo.word"),
        "hypo.units": get_res_file("hypo.units"),
        "ref.words": get_res_file("ref.word"),
        "ref.units": get_res_file("ref.units"),
    }


def load_models_and_criterions(
    filenames, data_path, arg_overrides=None, task=None, model_state=None
):
    """Build models and criterions from checkpoint files or an in-memory state.

    Args:
        filenames: colon-separated checkpoint paths, or None to use
            *model_state* instead.
        data_path: data directory; stored in the overrides under ``"data"``
            and, when the task is created here and its config has a ``data``
            field, written to ``cfg.task.data``.
        arg_overrides: optional dict of overrides handed to the checkpoint
            loader.  It is copied, so the caller's dict is left untouched.
        task: existing task to build models with; when None a task is set up
            from the first checkpoint's config and reused for the rest.
        model_state: already-loaded checkpoint state dict.  Required when
            *filenames* is None; when given it is used for every entry instead
            of reading from disk.

    Returns:
        tuple: ``(models, criterions, task)`` with one model and one criterion
        per checkpoint entry.

    Raises:
        IOError: if a checkpoint path does not exist.
        AssertionError: if both *filenames* and *model_state* are None.
    """
    models = []
    criterions = []

    # Copy before adding our own keys: the dict belongs to the caller.
    arg_overrides = dict(arg_overrides) if arg_overrides is not None else {}
    arg_overrides["wer_args"] = None
    arg_overrides["data"] = data_path

    if filenames is None:
        assert model_state is not None
        # Single placeholder entry so the loop below runs exactly once.
        filenames = [0]
    else:
        filenames = filenames.split(":")

    for filename in filenames:
        if model_state is None:
            if not os.path.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            state = checkpoint_utils.load_checkpoint_to_cpu(filename, arg_overrides)
        else:
            state = model_state

        # Checkpoints carry either a ready config ("cfg") or a legacy
        # argparse namespace ("args") that has to be converted.
        if "cfg" in state:
            cfg = state["cfg"]
        else:
            cfg = convert_namespace_to_omegaconf(state["args"])

        if task is None:
            if hasattr(cfg.task, 'data'):
                cfg.task.data = data_path
            task = tasks.setup_task(cfg.task)
        model = task.build_model(cfg.model)
        model.load_state_dict(state["model"], strict=True)
        # Verified (by the original author) that the parameters do change
        # after loading the state dict.
        models.append(model)

        criterion = task.build_criterion(cfg.criterion)
        if "criterion" in state:
            criterion.load_state_dict(state["criterion"], strict=True)
        criterions.append(criterion)
    return models, criterions, task


def optimize_models(args, use_cuda, models):
    """Prepare every model of the ensemble for generation, in place.

    Each model is switched to its fast-generation mode, optionally converted
    to half precision (``args.fp16``) and optionally moved to the GPU
    (``use_cuda``).
    """
    beam_size = None if args.no_beamable_mm else args.beam
    for net in models:
        net.make_generation_fast_(
            beamable_mm_beam_size=beam_size,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            net.half()
        if use_cuda:
            net.cuda()


class ExistingEmissionsDecoder(object):
    """Decode from precomputed emissions instead of running the models.

    Args:
        decoder: object exposing ``decode(emissions)``.
        emissions: container indexable by a NumPy array of sample ids,
            yielding one emission array per id.
    """

    def __init__(self, decoder, emissions):
        self.decoder = decoder
        self.emissions = emissions

    def generate(self, models, sample, **unused):
        """Look up the stored emissions for ``sample["id"]`` and decode them.

        Args:
            models: ignored; present to match the generator interface.
            sample: batch dict whose ``"id"`` entry is a tensor of sample ids.

        Returns:
            Whatever ``self.decoder.decode`` returns for the stacked emissions.

        Raises:
            ValueError: if the selected emissions cannot be stacked into one
                array (e.g. their shapes differ); the shapes are included in
                the message.
        """
        ids = sample["id"].cpu().numpy()
        selected = self.emissions[ids]
        try:
            emissions = np.stack(selected)
        except ValueError as err:
            shapes = [getattr(x, "shape", None) for x in selected]
            raise ValueError("invalid sizes: {}".format(shapes)) from err
        emissions = torch.from_numpy(emissions)
        return self.decoder.decode(emissions)


def main(args, task=None, model_state=None):
    """Decode ``args.gen_subset`` with the loaded model(s) and report the error rate.

    Args:
        args: parsed command-line options.
        task: optional pre-built task.  When None, a task is set up from
            *args* and the dataset split is loaded here.
        model_state: optional in-memory checkpoint state passed through to
            ``load_models_and_criterions``.

    Returns:
        tuple: ``(task, wer)``.  ``wer`` is the accumulated edit distance as a
        percentage of the accumulated reference length, or None when no
        reference tokens were seen.
    """
    check_args(args)
    # ``chinese_cer`` is not registered by add_asr_eval_argument in this file;
    # default it to False when the parser did not provide it so the scoring
    # code can rely on the attribute.
    if not hasattr(args, "chinese_cer"):
        args.chinese_cer = False
    logger.info("args.chinese_cer: {}".format(args.chinese_cer))

    if args.max_tokens is None and args.batch_size is None:
        args.max_tokens = 4000000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    logger.info("| decoding with criterion {}".format(args.criterion))
    logger.info("| loading model(s) from {}".format(args.path))
    # SECURITY NOTE(review): ``eval`` executes whatever string was passed as
    # --model-overrides.  Only acceptable for trusted command lines; consider
    # ast.literal_eval if plain literals are all that is needed.
    models, _criterions, _ = load_models_and_criterions(
        args.path,
        data_path=args.data,
        arg_overrides=eval(args.model_overrides),  # noqa
        task=task,
        model_state=model_state,
    )
    optimize_models(args, use_cuda, models)

    if task is None:
        # Load dataset splits
        task = tasks.setup_task(args)

        task.args.labels = args.labels
        task.args.post_process = args.post_process
        task.tokenizer_process = args.post_process
        task.load_dataset(args.gen_subset, test_mode=args.test)
        logger.info(
            "| {} {} {} examples".format(
                args.data, args.gen_subset, len(task.dataset(args.gen_subset))
            )
        )

    # Prefer the character-level dictionary when the task provides one,
    # otherwise fall back to the regular target dictionary.
    try:
        tgt_dict = task.character_target_dictionary
    except (AttributeError, NotImplementedError):
        tgt_dict = task.target_dictionary

    bert_dict = task.target_dictionary

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task, models)

    # Initialize generator
    gen_timer = StopwatchMeter()

    # please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
    from examples.speech_recognition.wav2bert_decoder_different_token_v2 import Wav2BertDecoder

    generator = Wav2BertDecoder(tgt_dict, args)
    num_sentences = 0

    if args.results_path:
        os.makedirs(args.results_path, exist_ok=True)

    # get the result file name
    res_files = prepare_result_files(args)
    errs_t = 0
    lengths_t = 0
    try:
        with progress_bar.build_progress_bar(args, itr) as t:
            wps_meter = TimeMeter()
            for sample in t:
                sample = utils.move_to_cuda(sample) if use_cuda else sample
                if "net_input" not in sample:
                    continue

                prefix_tokens = None
                if args.prefix_size > 0:
                    prefix_tokens = sample["target"][:, : args.prefix_size]

                gen_timer.start()
                hypos = task.inference_step(
                    generator, models, sample, prefix_tokens, different_tokens_v2=True
                )
                num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
                gen_timer.stop(num_generated_tokens)

                # References come from "target_label" when the batch has it,
                # otherwise from "target".
                targets = (
                    sample["target_label"]
                    if "target_label" in sample
                    else sample["target"]
                )
                # File names are optional; without them the result lines
                # are tagged with the sample id instead.
                filenames = sample["filename"] if "filename" in sample else None

                for i, sample_id in enumerate(sample["id"].tolist()):
                    target_tokens = (
                        utils.strip_pad(targets[i, :], tgt_dict.pad()).int().cpu()
                    )
                    # Process top predictions
                    errs, length = process_predictions(
                        args,
                        hypos[i],
                        None,
                        tgt_dict,
                        bert_dict,
                        target_tokens,
                        res_files,
                        None,  # speaker is not available here
                        sample_id,
                        filenames[i] if filenames is not None else None,
                    )
                    errs_t += errs
                    lengths_t += length

                wps_meter.update(num_generated_tokens)
                t.log({"wps": round(wps_meter.avg)})
                num_sentences += (
                    sample["nsentences"]
                    if "nsentences" in sample
                    else sample["id"].numel()
                )
    finally:
        # The result files are opened by prepare_result_files and owned here.
        if res_files is not None:
            for res_file in res_files.values():
                res_file.close()

    wer = None

    if lengths_t > 0:
        wer = errs_t * 100.0 / lengths_t
        logger.info(f"WER: {wer} {errs_t} {lengths_t}")

    # Guard the throughput figures: nothing may have been decoded at all.
    sentences_per_sec = num_sentences / gen_timer.sum if gen_timer.sum > 0 else 0.0
    tokens_per_sec = 1.0 / gen_timer.avg if gen_timer.avg > 0 else 0.0
    logger.info(
        "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f}"
        "sentences/s, {:.2f} tokens/s)".format(
            num_sentences,
            gen_timer.n,
            gen_timer.sum,
            sentences_per_sec,
            tokens_per_sec,
        )
    )
    logger.info("| Generate {}".format(args.gen_subset))
    logger.info("| CKPT From {}...".format(args.path))
    return task, wer


def make_parser():
    """Return the generation argument parser extended with the ASR eval options."""
    return add_asr_eval_argument(options.get_generation_parser())


def cli_main():
    """Command-line entry point: parse the arguments and run ``main``."""
    cli_args = options.parse_args_and_arch(make_parser())
    main(cli_args)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    cli_main()
