# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import os
import sys
from bert import BertTokenizer
from fairseq.data import AddTargetDataset, Dictionary, FileAudioDataset, SpmDictionary
from fairseq.data.masked_lm_dictionary import BertWordpieceDictionary
import torch
from . import LegacyFairseqTask, register_task

import sentencepiece as spm

import logging
logger = logging.getLogger(__name__)

class LabelEncoder(object):
    """Map a raw label line to token ids through a fairseq-style dictionary.

    The dictionary's ``encode_line`` is called with ``append_eos=False`` and
    ``add_if_not_exist=False``: no EOS is appended and unknown symbols do not
    grow the dictionary.
    """

    def __init__(self, dictionary):
        # Any object exposing ``encode_line(label, append_eos=..., add_if_not_exist=...)``.
        self.dictionary = dictionary

    def __call__(self, label):
        encode = self.dictionary.encode_line
        return encode(label, append_eos=False, add_if_not_exist=False)

class BertLabelEncoder(object):
    """Encode label text with a BERT tokenizer's ``encode_line``.

    Args:
        berttokenizer: tokenizer exposing ``encode_line(label, post_proces=..., check=...)``.
        post_proces: post-processing mode forwarded verbatim to the tokenizer.
        check: when True, the tokenizer output is returned untouched (it is
            not converted to a tensor); when False it is wrapped in a
            ``torch.LongTensor``.
    """

    def __init__(self, berttokenizer, post_proces='bert_bpe_piece', check=False):
        self.berttokenizer = berttokenizer
        self.post_proces = post_proces
        self.check = check

    def __call__(self, label):
        checking = self.check
        encoded = self.berttokenizer.encode_line(
            label, post_proces=self.post_proces, check=checking,
        )
        if checking:
            # Raw tokenizer output (whatever shape check mode produces).
            return encoded
        return torch.LongTensor(encoded)

    def close_check(self):
        """Leave check mode so subsequent calls return ``LongTensor`` ids."""
        self.check = False



class SpmLabelEncoder(object):
    """Encode label text into SentencePiece ids, returned as a ``torch.LongTensor``.

    Args:
        model_path: path to a trained SentencePiece model file.
    """

    def __init__(self, model_path):
        # Load the model once; it is reused for every label.
        self.spm_model = spm.SentencePieceProcessor(model_file=model_path)

    def __call__(self, label):
        return torch.LongTensor(self.spm_model.encode(label))

@register_task("audio_pretraining")
class AudioPretrainingTask(LegacyFairseqTask):
    """Audio task reading ``{split}.tsv`` manifests through ``FileAudioDataset``.

    When ``--labels`` is set, one transcript line per utterance is read from
    ``{data}/{split}.{labels}`` and turned into targets by a label encoder
    chosen from ``args.line_tokenizer`` (stored as ``self.post_process``):

    * ``None``   -> ``LabelEncoder`` over the target dictionary
    * ``'spm'``  -> ``SpmLabelEncoder`` using ``--spm-model-path``
    * ``'word'`` -> ``BertLabelEncoder`` with ``{bert_model_name}/word_vocab.txt``
    * otherwise  -> ``BertLabelEncoder`` with ``{bert_model_name}/vocab.txt``

    ``line_tokenizer`` is read with ``getattr`` and is not registered in
    ``add_args`` here; presumably another component defines it — verify.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument("data", help="path to data directory")
        parser.add_argument(
            "--sample-rate",
            default=16000,
            type=int,
            help="target sample rate. audio files will be up/down sampled to this rate",
        )
        parser.add_argument(
            "--normalize",
            action="store_true",
            help="if set, normalizes input to have 0 mean and unit variance",
        )
        parser.add_argument(
            "--max-sample-size",
            default=None,
            type=int,
            help="max sample size to crop to for batching. default = min sample length",
        )
        parser.add_argument(
            "--min-sample-size",
            default=None,
            type=int,
            help="min sample size to crop to for batching. default = same as --max-sample-size",
        )

        parser.add_argument(
            "--enable-padding",
            action="store_true",
            help="pad shorter samples instead of cropping",
        )

        parser.add_argument(
            "--labels",
            type=str,
            default=None,
            help="extension of the label file to load, if any",
        )
        parser.add_argument('--need-prev-output', action="store_true", help="add prev output when load data")
        parser.add_argument(
            '--decoder-bert-model-name',
            type=str,
            default=None,
            help="directory holding the BERT vocab files (vocab.txt / word_vocab.txt) used to encode labels",
        )
        parser.add_argument(
            '--spm-model-path',
            type=str,
            default='/wav2bert/train_clean_1000.model',
            help="SentencePiece model used to encode labels when line_tokenizer is 'spm'",
        )

    def __init__(self, args, source_dictionary=None, target_dictionary=None, post_process=None):
        super().__init__(args)
        self._target_dictionary = target_dictionary
        self._source_dictionary = source_dictionary
        self.is_ctc = args.criterion == "ctc"
        # Honour --need-prev-output; getattr keeps args lacking the field working.
        self.need_prev_output = getattr(args, "need_prev_output", False)
        # Value of args.line_tokenizer; selects the label encoder in load_dataset.
        self.post_process = post_process
        self.bert_model_name = getattr(args, 'decoder_bert_model_name', None)
        self.spm_model_path = getattr(args, 'spm_model_path', None)

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Set up the task (e.g., load the target dictionary).

        The dictionary path is ``args.lexicon`` when set, otherwise
        ``{data}/dict.{labels}.txt``. ``args.line_tokenizer`` picks the
        dictionary type: ``None`` -> ``Dictionary``, ``'spm'`` ->
        ``SpmDictionary``, anything else -> ``BertWordpieceDictionary``.
        No dictionary is loaded when ``--labels`` is unset.

        Args:
            args: parsed command-line arguments
        """
        for k, v in kwargs.items():
            logger.info("Optional argument %s (kwargs) in setup_task: %s", k, v)

        line_tokenizer = getattr(args, 'line_tokenizer', None)
        target_dictionary = None
        if args.labels:
            # The dictionary path is determined primarily by --lexicon.
            dict_path = getattr(args, 'lexicon', None)
            if dict_path is None:
                dict_path = os.path.join(args.data, f"dict.{args.labels}.txt")
            logger.info("Loading dict from lexicon: %s", dict_path)

            # line_tokenizer decides between the classic Dictionary and the SPM/BERT ones.
            if line_tokenizer is None:
                target_dictionary = Dictionary.load(dict_path)
                logger.info("set up Dictionary...")
            elif line_tokenizer == 'spm':
                target_dictionary = SpmDictionary.load(
                    dict_path, model_path=getattr(args, 'spm_model_path', None)
                )
            else:
                target_dictionary = BertWordpieceDictionary.load(dict_path)
                logger.info("| [%s] dictionary: %d types", dict_path, len(target_dictionary))

        return cls(args, target_dictionary=target_dictionary, post_process=line_tokenizer)

    def load_dataset(self, split, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        manifest = os.path.join(self.args.data, "{}.tsv".format(split))
        self.datasets[split] = FileAudioDataset(
            manifest,
            sample_rate=self.args.sample_rate,
            max_sample_size=self.args.max_sample_size,
            # NOTE(review): min_sample_size is fed from --max-sample-size while
            # --min-sample-size becomes min_length; kept as-is — confirm intended.
            min_sample_size=self.args.max_sample_size,
            min_length=self.args.min_sample_size,
            pad=self.args.labels is not None or self.args.enable_padding,
            normalize=self.args.normalize,
        )

        if not self.args.labels:
            return

        check_tokens = False
        label_path = os.path.join(self.args.data, f"{split}.{self.args.labels}")
        # One label line per manifest entry; lines keep their trailing newline.
        with open(label_path, "r") as f:
            labels = list(f)

        process_label = self._build_label_encoder(check_tokens)

        self.datasets[split] = AddTargetDataset(
            self.datasets[split],
            labels,
            pad=self.target_dictionary.pad(),
            eos=self.target_dictionary.eos(),
            batch_targets=True,
            process_label=process_label,
            add_to_input=not self.is_ctc,
            need_prev_output=self.need_prev_output,
            check_tokens=check_tokens,
        )

    def _build_label_encoder(self, check_tokens):
        """Return the callable that turns one label line into target ids."""
        logger.info("self.post_process: %s", self.post_process)
        if self.post_process is None:
            logger.info("initial load LabelEncoder...")
            return LabelEncoder(self.target_dictionary)
        if self.post_process == 'spm':
            return SpmLabelEncoder(model_path=self.spm_model_path)
        vocab_name = 'word_vocab.txt' if self.post_process == 'word' else 'vocab.txt'
        return self._build_bert_label_encoder(vocab_name, check_tokens)

    def _build_bert_label_encoder(self, vocab_name, check_tokens):
        """Load ``BertTokenizer`` from ``{bert_model_name}/{vocab_name}`` and wrap it.

        Also stores the tokenizer on ``self.berttokenizer``.

        Raises:
            ValueError: if ``--decoder-bert-model-name`` was not provided.
        """
        if self.bert_model_name is None:
            raise ValueError(
                "--decoder-bert-model-name is required when line_tokenizer is "
                "{!r}".format(self.post_process)
            )
        load_vocab_file = os.path.join(self.bert_model_name, vocab_name)
        logger.info("load_vocab_file: %s", load_vocab_file)
        self.berttokenizer = BertTokenizer.from_pretrained(load_vocab_file)
        logger.info("initial load BERT tokenizer with %s...", self.post_process)
        return BertLabelEncoder(
            self.berttokenizer, post_proces=self.post_process, check=check_tokens
        )

    @property
    def source_dictionary(self):
        """Source dictionary passed to the constructor (``None`` when built via ``setup_task``)."""
        return self._source_dictionary

    @property
    def target_dictionary(self):
        """Return the :class:`~fairseq.data.Dictionary` for the language
        model."""
        return self._target_dictionary

    def max_positions(self):
        """Maximum input length supported by the encoder (effectively unbounded)."""
        return (sys.maxsize, sys.maxsize)

    def filter_indices_by_size(
        self,
        indices,
        dataset,
        max_positions=None,
        ignore_invalid_inputs=False,
    ):
        """Return ``indices`` unchanged; no size filtering is done here."""
        # we do not need to filter by size in this task as dataloaders take care of this
        return indices
