# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import os
import sys
import torch
from fairseq.data import Dictionary, FileAudioDataset, Wav2BertDataset, Wav2BertDatasetWithoutMask, Wav2BertDatasetDifferentTokens, OnlyBertDataset, Wav2BertDatasetDifferentTokensDifferentVocab

from . import LegacyFairseqTask, register_task
from fairseq.data.masked_lm_dictionary import BertWordpieceDictionary
from bert import BertTokenizer



import logging
# Module-level logger named after this module (e.g. for mask-ratio schedule debugging).
logger = logging.getLogger(name=__name__)


class LabelEncoder:
    """Callable that turns a label string into token ids with a fairseq-style dictionary.

    The wrapped ``dictionary`` must provide ``encode_line`` and ``string``.
    """

    def __init__(self, dictionary):
        self.dictionary = dictionary

    def __call__(self, label):
        """Encode ``label`` without appending EOS and without growing the dictionary."""
        vocab = self.dictionary
        return vocab.encode_line(label, add_if_not_exist=False, append_eos=False)

    def convert_ids_to_tokens(self, ids):
        """Map ``ids`` back to a token string via ``dictionary.string``.

        NOTE: this may raise if ``ids`` is not a tensor-like type accepted by
        ``dictionary.string``.
        """
        decoded = self.dictionary.string(ids)
        return decoded

@register_task("wav2bert_task")
class Wav2BertTask(LegacyFairseqTask):
    """Speech-to-text task pairing a wav2vec-style audio encoder with BERT.

    Audio comes from ``{data}/{split}.tsv`` manifests via :class:`FileAudioDataset`.
    When ``--labels`` is set, the audio dataset is wrapped in one of the
    Wav2Bert* / OnlyBert dataset classes, chosen by the task options. The
    target dictionary is a :class:`BertWordpieceDictionary`.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument("data", help="path to data directory")
        parser.add_argument(
            "--sample-rate",
            default=16000,
            type=int,
            help="target sample rate. audio files will be up/down sampled to this rate",
        )
        parser.add_argument(
            "--normalize",
            action="store_true",
            help="if set, normalizes input to have 0 mean and unit variance",
        )
        parser.add_argument(
            "--max-sample-size",
            default=None,
            type=int,
            help="max sample size to crop to for batching. default = min sample length",
        )
        parser.add_argument(
            "--min-sample-size",
            default=None,
            type=int,
            help="min sample size to crop to for batching. default = same as --max-sample-size",
        )

        parser.add_argument(
            "--enable-padding",
            action="store_true",
            help="pad shorter samples instead of cropping",
        )

        parser.add_argument(
            "--labels",
            type=str,
            default=None,
            help="extension of the label file to load, if any",
        )
        parser.add_argument('--decoder-bert-model-name', type=str, default='bert-base-uncased')
        parser.add_argument('--different-tokens', action="store_true", help="BERT uses word piece token and wav2vec uses character tokens")
        parser.add_argument('--fusion-v2', action='store_true',
                            help='wav2bert BERT fusion mode')
        parser.add_argument('--fusion-v3', action='store_true',
                            help='wav2bert BERT fusion mode')
        parser.add_argument('--add-input', action='store_true',
                            help='add special token to input')
        parser.add_argument('--mask-radio-range', type=str,
                            help='python-literal "(start, end)" range of the lower bound of the '
                                 'mlm mask ratio; interpolated linearly over --mask-step-range')
        parser.add_argument('--mask-step-range', type=str,
                            help='python-literal "(start, end)" range of update steps over which '
                                 'the mlm mask ratio moves across --mask-radio-range')
        parser.add_argument('--no-mask', action="store_true",
                            help='dataset no mask')

        parser.add_argument('--only-bert', action="store_true",
                            help='only finetune the bert')
        parser.add_argument('--bert-mask-type', type=str, default='origin',
                            help='decide the bert mask num')

        parser.add_argument('--different-tokens-v2', action="store_true",
                            help="like --different-tokens, but wav2vec character targets use a "
                                 "separate character dictionary (dict.ltr.txt)")

    def __init__(self, args, source_dictionary=None, target_dictionary=None):
        super().__init__(args)
        self._target_dictionary = target_dictionary
        self._source_dictionary = source_dictionary
        # At test time, a non-CTC criterion makes the dataset add special tokens to the input.
        self.is_ctc = args.criterion == "ctc"
        self.need_prev_output = True
        self.bert_model_name = args.decoder_bert_model_name
        self.different_tokens = getattr(args, 'different_tokens', None)
        self.different_tokens_v2 = getattr(args, 'different_tokens_v2', None)
        # The generic --post-process option is forwarded to the datasets as
        # their tokenizer post-processing mode.
        self.tokenizer_process = getattr(args, 'post_process', None)

        self.add_to_input = getattr(args, 'add_input', False)

        # BERT tokenizer used by inference_step; loaded on first use, then reused.
        self._bert_tokenizer = None

        mask_step_range = getattr(args, "mask_step_range", None)
        mask_radio_range = getattr(args, "mask_radio_range", None)
        if mask_step_range is not None and mask_radio_range is not None:
            # NOTE(review): eval() of command-line strings; if only literal
            # tuples such as "(0.1, 0.5)" are ever passed, ast.literal_eval
            # would be a safer drop-in.
            self.mask_step_range = eval(mask_step_range)
            self.mask_radio_range = eval(mask_radio_range)
        else:
            self.mask_step_range = None
            self.mask_radio_range = None

        self.no_mask = getattr(args, 'no_mask', False)
        self.bert_mask_type = getattr(args, 'bert_mask_type', 'origin')

        if self.different_tokens_v2:
            # v2 decodes wav2vec characters against a separate letter dictionary.
            character_dict_path = os.path.join(args.data, "dict.ltr.txt")

            self.character_target_dictionary = Dictionary.load(character_dict_path)
            self.character_tokenizer = LabelEncoder(self.character_target_dictionary)

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        The target dictionary is loaded only when ``--labels`` is set. It comes
        from ``--lexicon`` if given, otherwise from
        ``{decoder_bert_model_name}/dict.en.txt``.

        Args:
            args: parsed command-line arguments
        """
        if args.labels:
            # The dict path is determined mainly by the `lexicon` argument.
            if getattr(args, 'lexicon', None) is not None:
                dict_path = args.lexicon
            else:
                dict_path = os.path.join(args.decoder_bert_model_name, 'dict.en.txt')
            print(dict_path)
            target_dictionary = BertWordpieceDictionary.load(dict_path)
            print('| [{}] dictionary: {} types'.format(dict_path, len(target_dictionary)))
        else:
            target_dictionary = None

        return cls(args, target_dictionary=target_dictionary)

    @staticmethod
    def _read_lines(path):
        """Return every line of the text file at ``path``, trailing newlines kept."""
        with open(path, "r") as f:
            return list(f)

    def _read_labels(self, split):
        """Load ``{split}.{--labels}`` label lines."""
        label_path = os.path.join(self.args.data, f"{split}.{self.args.labels}")
        print("Loading labels: {} with process {}".format(label_path, self.tokenizer_process))
        return self._read_lines(label_path)

    def _read_word_and_letter_labels(self, split):
        """Load ``{split}.wrd`` and ``{split}.ltr`` label lines, in that order."""
        loaded = []
        for ext in ("wrd", "ltr"):
            label_path = os.path.join(self.args.data, f"{split}.{ext}")
            print("Loading labels: ", label_path)
            loaded.append(self._read_lines(label_path))
        return loaded[0], loaded[1]

    def _label_dataset_kwargs(self):
        """Return the keyword arguments shared by every labelled dataset wrapper."""
        return dict(
            pad=self.target_dictionary.pad(),
            eos=self.target_dictionary.eos(),
            batch_targets=True,
            bert_model_name=self.bert_model_name,
            dictionary=self._target_dictionary,
        )

    def load_dataset(self, split, test_mode=False, **kwargs):
        """Load a given dataset split.

        Dataset wrapper selection, first match wins: ``--only-bert`` ->
        OnlyBertDataset; ``test_mode`` -> Wav2BertDatasetWithoutMask;
        ``--different-tokens`` -> Wav2BertDatasetDifferentTokens;
        ``--different-tokens-v2`` -> Wav2BertDatasetDifferentTokensDifferentVocab;
        otherwise Wav2BertDataset. Only the label files the chosen wrapper
        needs are read.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            test_mode (bool): build the un-masked evaluation dataset
        """
        manifest = os.path.join(self.args.data, "{}.tsv".format(split))
        self.datasets[split] = FileAudioDataset(
            manifest,
            sample_rate=self.args.sample_rate,
            max_sample_size=self.args.max_sample_size,
            # NOTE(review): min_sample_size receives --max-sample-size and
            # min_length receives --min-sample-size. This appears to mirror
            # upstream fairseq's audio_pretraining task; confirm before changing.
            min_sample_size=self.args.max_sample_size,
            min_length=self.args.min_sample_size,
            pad=self.args.labels is not None or self.args.enable_padding,
            normalize=self.args.normalize,
        )

        if not self.args.labels:
            return

        audio_dataset = self.datasets[split]

        if getattr(self.args, "only_bert", False):
            labels = self._read_labels(split)
            print("Building Dataset with OnlyBertDataset...")
            self.datasets[split] = OnlyBertDataset(
                audio_dataset,
                labels,
                add_to_input=self.add_to_input,
                tokenizer_process=self.tokenizer_process,
                mask_type=self.bert_mask_type,
                **self._label_dataset_kwargs(),
            )
        elif test_mode:
            labels = self._read_labels(split)
            print("In testing...")
            print("Building Dataset with Wav2BertDatasetWithoutMask...")
            self.datasets[split] = Wav2BertDatasetWithoutMask(
                audio_dataset,
                labels,
                add_to_input=not self.is_ctc,
                tokenizer_process=self.tokenizer_process,
                **self._label_dataset_kwargs(),
            )
        elif self.different_tokens or self.different_tokens_v2:
            # Word-level (.wrd) labels serve BERT wordpieces; letter-level
            # (.ltr) labels serve the wav2vec character targets.
            wrd_labels, ltr_labels = self._read_word_and_letter_labels(split)
            if self.different_tokens:
                print("Building Dataset with Wav2BertDatasetDifferentTokens...")
                self.datasets[split] = Wav2BertDatasetDifferentTokens(
                    audio_dataset,
                    wrd_labels, ltr_labels,
                    add_to_input=self.add_to_input,
                    no_mask=self.no_mask,
                    **self._label_dataset_kwargs(),
                )
            else:
                print("Building Dataset with Wav2BertDatasetDifferentTokensDifferentVocab")
                self.datasets[split] = Wav2BertDatasetDifferentTokensDifferentVocab(
                    audio_dataset,
                    wrd_labels, ltr_labels,
                    add_to_input=self.add_to_input,
                    no_mask=self.no_mask,
                    character_tokenizer=self.character_tokenizer,
                    **self._label_dataset_kwargs(),
                )
        else:
            labels = self._read_labels(split)
            print("Building Dataset with Wav2BertDataset...")
            self.datasets[split] = Wav2BertDataset(
                audio_dataset,
                labels,
                add_to_input=self.add_to_input,
                tokenizer_process=self.tokenizer_process,
                **self._label_dataset_kwargs(),
            )

    @property
    def source_dictionary(self):
        return self._source_dictionary

    @property
    def target_dictionary(self):
        """Return the :class:`~fairseq.data.Dictionary` for the language
        model."""
        return self._target_dictionary

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return (sys.maxsize, sys.maxsize)

    def filter_indices_by_size(
        self,
        indices,
        dataset,
        max_positions=None,
        ignore_invalid_inputs=False,
    ):
        # we do not need to filter by size in this task as dataloaders take care of this
        return indices

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        """Update the mask-ratio schedule (if configured), then run the regular train step.

        After update ``start`` of ``mask_step_range``, the lower bound of the
        mask ratio moves linearly from ``mask_radio_range[0]`` to
        ``mask_radio_range[1]``, reaching the end value at update ``end``.
        """
        if self.mask_radio_range is not None and update_num > self.mask_step_range[0]:
            start_step, end_step = self.mask_step_range
            start_ratio, end_ratio = self.mask_radio_range
            if end_step > start_step:
                progress = min(1, (update_num - start_step) / (end_step - start_step))
            else:
                # Degenerate schedule (end <= start): jump to the final ratio
                # instead of dividing by zero or extrapolating backwards.
                progress = 1
            current_radio = progress * (end_ratio - start_ratio) + start_ratio
            # This ratio takes effect only when the dataloader is rebuilt each
            # epoch; it changes the lower bound of the random mask ratio.
            self.datasets['train'].set_mask_low_radio(current_radio)

        # NOTE(review): super(LegacyFairseqTask, self) starts the lookup *after*
        # LegacyFairseqTask, skipping any train_step it defines. In upstream
        # fairseq it defines none, so this equals super(); confirm intent.
        return super(LegacyFairseqTask, self).train_step(sample, model, criterion, optimizer, update_num, ignore_grad)

    def _get_bert_tokenizer(self):
        """Return the BertTokenizer for ``--decoder-bert-model-name``, loading it once."""
        tokenizer = getattr(self, "_bert_tokenizer", None)
        if tokenizer is None:
            tokenizer = BertTokenizer.from_pretrained(self.bert_model_name)
            self._bert_tokenizer = tokenizer
        return tokenizer

    def inference_step(self, generator, models, sample, prefix_tokens=None, only_wav2vec=False, tgt_bert_encoder=None, tgt_bert_tokenizer=None, cif=None, ce=False, different_tokens_v2=False):
        """Run ``generator.generate`` for one batch under ``torch.no_grad()``.

        Dispatch, first match wins:

        * ``different_tokens_v2``: uses the caller's ``tgt_bert_tokenizer`` and
          the character dictionary loaded for ``--different-tokens-v2``.
        * ``only_wav2vec`` (and not ``ce``): plain wav2vec decoding.
        * ``ce`` or ``--different-tokens``: uses the task's BERT tokenizer and
          the wordpiece target dictionary.
        * otherwise: uses the task's BERT tokenizer only.

        Outside the ``different_tokens_v2`` path, the ``tgt_bert_tokenizer``
        argument is replaced by the task's own tokenizer, which is cached.
        """
        with torch.no_grad():
            if different_tokens_v2:
                # self.character_target_dictionary exists only when the task was
                # built with --different-tokens-v2.
                return generator.generate(
                    models, sample, prefix_tokens=prefix_tokens,
                    tgt_bert_encoder=tgt_bert_encoder,
                    tgt_bert_tokenizer=tgt_bert_tokenizer, cif=cif,
                    different_tokens=self.different_tokens,
                    tgt_dictionary=self.character_target_dictionary,
                )

            if only_wav2vec and not ce:
                return generator.generate(
                    models, sample, need_prev_output=True,
                    prefix_tokens=prefix_tokens, constraints=None,
                )

            tgt_bert_tokenizer = self._get_bert_tokenizer()
            if ce or self.different_tokens:
                return generator.generate(
                    models, sample, prefix_tokens=prefix_tokens,
                    tgt_bert_encoder=tgt_bert_encoder,
                    tgt_bert_tokenizer=tgt_bert_tokenizer, cif=cif,
                    different_tokens=self.different_tokens,
                    tgt_dictionary=self.target_dictionary,
                )
            return generator.generate(
                models, sample, prefix_tokens=prefix_tokens,
                tgt_bert_encoder=tgt_bert_encoder,
                tgt_bert_tokenizer=tgt_bert_tokenizer, cif=cif,
            )

    @staticmethod
    def _sequence_generator_kwargs(args):
        """Return the keyword arguments shared by all SequenceGenerator* variants."""
        return dict(
            beam_size=getattr(args, 'beam', 5),
            max_len_a=getattr(args, 'max_len_a', 0),
            max_len_b=getattr(args, 'max_len_b', 200),
            min_len=getattr(args, 'min_len', 1),
            stop_early=(not getattr(args, 'no_early_stop', False)),
            normalize_scores=(not getattr(args, 'unnormalized', False)),
            len_penalty=getattr(args, 'lenpen', 1),
            unk_penalty=getattr(args, 'unkpen', 0),
            sampling=getattr(args, 'sampling', False),
            sampling_topk=getattr(args, 'sampling_topk', -1),
            temperature=getattr(args, 'temperature', 1.),
            diverse_beam_groups=getattr(args, 'diverse_beam_groups', -1),
            diverse_beam_strength=getattr(args, 'diverse_beam_strength', 0.5),
            match_source_len=getattr(args, 'match_source_len', False),
            no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0),
            mask_pred_iter=getattr(args, 'mask_pred_iter', 10),
            decode_use_adapter=getattr(args, 'decode_use_adapter', False),
            args=args,
        )

    def build_generator(self, args):
        """Build the decoder/generator selected by the generation flags.

        Priority: ``only_wav2vec_beam`` > ``only_wav2vec`` > ``ce`` >
        ``fusion_v3``/``fusion_v2`` > default BERT sequence generator. Missing
        flags are treated as False.
        """
        if getattr(args, 'only_wav2vec_beam', False):
            print("Importing W2lDecoderBertRescore Generator...")
            from examples.speech_recognition.w2l_decoder import W2lDecoderBertRescore
            return W2lDecoderBertRescore(args, self.target_dictionary)
        if getattr(args, 'only_wav2vec', False):
            print("Importing W2lGreedyDecoder Generator...")
            from examples.speech_recognition.w2l_decoder import W2lGreedyDecoder
            return W2lGreedyDecoder(args, self.target_dictionary)

        if getattr(args, 'ce', False):
            print("Importing SequenceGeneratorWithFusionCE Generator...")
            from fairseq.sequence_generator_with_fusion_ce import SequenceGeneratorWithFusionCE as generator_cls
        elif getattr(args, 'fusion_v3', False) or getattr(args, 'fusion_v2', False):
            print("Importing SequenceGeneratorWithFusion Generator...")
            from fairseq.sequence_generator_with_fusion import SequenceGeneratorWithFusion as generator_cls
        else:
            print("Importing SequenceGeneratorWithBert Generator...")
            from fairseq.sequence_generator_with_bert import SequenceGeneratorWithBert as generator_cls
        return generator_cls(self.target_dictionary, **self._sequence_generator_kwargs(args))