# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from bert import BertTokenizer
from . import BaseWrapperDataset, data_utils
import numpy as np

class Wav2BertDatasetDifferentTokensDifferentVocab(BaseWrapperDataset):
    """Attach character-level and masked wordpiece-level targets to a dataset.

    Letter labels are encoded with ``character_tokenizer`` (which has its own
    pad index), word labels with the BERT tokenizer using the
    ``'bert_bpe_piece'`` post-process. Unless ``no_mask`` is set, a random
    subset of the wordpieces is corrupted for every item that is fetched.

    Args:
        dataset: wrapped dataset; ``__getitem__``, ``size`` and ``collater``
            are delegated to it.
        wrd_labels: indexable labels fed to the BERT tokenizer.
        ltr_labels: indexable labels fed to ``character_tokenizer``.
        pad: pad index used when batching the wordpiece tensors.
        eos: stored on the instance; not read by this class.
        batch_targets: when true, ``collater`` pads targets into tensors.
        add_to_input: when true, ``'.'`` is appended to the word label and the
            wordpieces are wrapped with ``dictionary.cls()`` and
            ``dictionary.eos()``.
        bert_model_name: name passed to ``BertTokenizer.from_pretrained``.
        dictionary: vocabulary object; ``cls()``, ``eos()``, ``pad()``,
            ``mask_index``, ``nspecial`` and ``len()`` are used.
        no_mask: when true, items are returned without any masking.
        character_tokenizer: callable whose result supports ``.tolist()`` and
            which exposes ``dictionary.pad()``. Required in practice: the
            ``None`` default fails in ``__init__``.
    """

    def __init__(
        self,
        dataset,
        wrd_labels,
        ltr_labels,
        pad,
        eos,
        batch_targets,
        add_to_input=False,
        bert_model_name='bert_english',
        dictionary=None,
        no_mask=False,
        character_tokenizer=None,
    ):
        super().__init__(dataset)
        self.wrd_labels = wrd_labels
        self.ltr_labels = ltr_labels
        self.batch_targets = batch_targets
        self.pad = pad
        self.character_pad = character_tokenizer.dictionary.pad()
        self.eos = eos
        self.add_to_input = add_to_input
        self.berttokenizer = BertTokenizer.from_pretrained(bert_model_name)
        self.dictionary = dictionary
        self.mask_low_radio = 0
        self.no_mask = no_mask
        self.character_tokenizer = character_tokenizer

    def set_mask_low_radio(self, low_radio):
        """Set the minimum fraction of wordpieces selected for masking."""
        self.mask_low_radio = low_radio

    def get_label(self, index):
        """Return ``(character_ids, wordpiece_ids)`` for sample ``index``."""
        tgt_character_list = self.character_tokenizer(self.ltr_labels[index]).tolist()

        if self.add_to_input:
            # To match BERT's training data, add a final period plus the
            # CLS / SEP symbols around the sentence.
            tgt_wordpiece_list = self.berttokenizer.encode_line(
                self.wrd_labels[index] + '.', post_proces='bert_bpe_piece'
            )
            tgt_wordpiece_list.insert(0, self.dictionary.cls())
            tgt_wordpiece_list.append(self.dictionary.eos())
        else:
            tgt_wordpiece_list = self.berttokenizer.encode_line(
                self.wrd_labels[index], post_proces='bert_bpe_piece'
            )
        return tgt_character_list, tgt_wordpiece_list

    def mask_word(self, w, p=None):
        """Corrupt one token id BERT-style.

        With ``p`` drawn uniformly from [0, 1) when not supplied: ``p >= 0.2``
        returns the mask index, ``0.1 <= p < 0.2`` returns a random
        non-special id, and ``p < 0.1`` returns ``w`` unchanged.
        """
        voc = self.dictionary

        p = np.random.random() if p is None else p
        if p >= 0.2:
            return voc.mask_index  # [MASK] (id 103 according to the original note)
        elif p >= 0.1:
            return np.random.randint(voc.nspecial, len(voc))  # random replacement
        else:
            return w  # left unchanged

    def __getitem__(self, index):
        """Return the wrapped item extended with the four target tensors.

        Added keys: ``target`` (BERT decoder input), ``output`` (MLM loss
        target, pad at positions that are not scored), ``origin_target``
        (CTC loss target) and ``origin_bert_target`` (ground truth used when
        computing the BERT WER).
        """
        item = self.dataset[index]
        tgt_character_list, tgt_wordpiece_list = self.get_label(index)
        target = tgt_wordpiece_list.copy()
        output = tgt_wordpiece_list.copy()
        len_to_consider = len(tgt_wordpiece_list)

        # Without masking, both the BERT input and the MLM target are the
        # ground truth, so the copies are used as they are.
        if not self.no_mask:
            # Lower bound on the number of selected positions. The previous
            # ``min(1, ...)`` capped it at one, which made the configured
            # ratio ineffective; clamp to [0, len] so randint stays valid.
            low_bound = int(self.mask_low_radio * len_to_consider)
            low_bound = max(0, min(len_to_consider, low_bound))
            num_mask = np.random.randint(low_bound, len_to_consider + 1)

            # Shuffle all positions and keep the first ``num_mask`` of them;
            # a set gives O(1) membership tests in the loop below.
            id_mask = np.arange(len_to_consider)
            np.random.shuffle(id_mask)
            id_mask = set(id_mask[:num_mask].tolist())

            last = len_to_consider - 1
            for i, w in enumerate(tgt_wordpiece_list):
                # CLS and SEP must never be masked when they were added.
                is_boundary = self.add_to_input and (i == 0 or i == last)
                if i in id_mask and not is_boundary:
                    target[i] = self.mask_word(w)
                else:
                    # Pad in the MLM target so this position is not scored.
                    output[i] = self.dictionary.pad()
                # Debug note kept from the original: when every token is
                # masked, even a model with 30% WER cannot predict anything
                # useful during validation.

        item['target'] = torch.LongTensor(target)
        item['output'] = torch.LongTensor(output)
        item['origin_target'] = torch.LongTensor(tgt_character_list)
        item['origin_bert_target'] = torch.LongTensor(tgt_wordpiece_list)
        return item

    def size(self, index):
        """Return ``(wrapped_size, character_label_length)``.

        ``get_label`` returns a 2-tuple, so taking ``len`` of it directly
        always produced 2. The character-level length is reported because
        ``collater`` derives ``target_lengths`` from the same sequence.
        NOTE(review): confirm callers do not expect the wordpiece length.
        """
        sz = self.dataset.size(index)
        tgt_character_list, _ = self.get_label(index)
        own_sz = len(tgt_character_list)
        return (sz, own_sz)

    def collater(self, samples):
        """Collate ``samples`` with the wrapped collater and add the targets."""
        collated = self.dataset.collater(samples)
        if len(collated) == 0:
            return collated
        # Only keep samples the wrapped collater retained.
        indices = set(collated["id"].tolist())
        kept = [s for s in samples if s["id"] in indices]
        target = [s["target"] for s in kept]
        output = [s["output"] for s in kept]
        origin_target = [s["origin_target"] for s in kept]
        origin_bert_target = [s["origin_bert_target"] for s in kept]

        if self.batch_targets:
            # target_lengths refers to origin_target (the CTC target).
            collated["target_lengths"] = torch.LongTensor([len(t) for t in origin_target])
            target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
            collated["ntokens"] = collated["target_lengths"].sum().item()
            output = data_utils.collate_tokens(output, pad_idx=self.pad, left_pad=False)
            # Character targets come from a different vocabulary and use its pad.
            origin_target = data_utils.collate_tokens(
                origin_target, pad_idx=self.character_pad, left_pad=False
            )
            origin_bert_target = data_utils.collate_tokens(
                origin_bert_target, pad_idx=self.pad, left_pad=False
            )
        else:
            collated["ntokens"] = sum(len(t) for t in target)

        collated["nsentences"] = len(samples)
        collated["target"] = output  # ground truth for the BERT output
        collated["origin_target"] = origin_target  # ground truth for wav2vec CTC
        collated["origin_bert_target"] = origin_bert_target  # ground truth for BERT WER
        collated["net_input"]["prev_output_tokens"] = target  # masked BERT input
        return collated


class Wav2BertDatasetDifferentTokens(BaseWrapperDataset):
    """Attach letter-level and masked wordpiece-level targets to a dataset.

    Both label kinds are encoded with the same BERT tokenizer: letter labels
    with the ``'letter'`` post-process and word labels with
    ``'bert_bpe_piece'``. Unless ``no_mask`` is set, a random subset of the
    wordpieces is corrupted for every item that is fetched.

    Args:
        dataset: wrapped dataset; ``__getitem__``, ``size`` and ``collater``
            are delegated to it.
        wrd_labels: indexable labels encoded as wordpieces.
        ltr_labels: indexable labels encoded as letters.
        pad: pad index used when batching every target tensor.
        eos: stored on the instance; not read by this class.
        batch_targets: when true, ``collater`` pads targets into tensors.
        add_to_input: when true, ``'.'`` is appended to the word label and the
            wordpieces are wrapped with ``dictionary.cls()`` and
            ``dictionary.eos()``.
        bert_model_name: name passed to ``BertTokenizer.from_pretrained``.
        dictionary: vocabulary object; ``cls()``, ``eos()``, ``pad()``,
            ``mask_index``, ``nspecial`` and ``len()`` are used.
        no_mask: when true, items are returned without any masking.
    """

    def __init__(
        self,
        dataset,
        wrd_labels,
        ltr_labels,
        pad,
        eos,
        batch_targets,
        add_to_input=False,
        bert_model_name='bert_english',
        dictionary=None,
        no_mask=False,
    ):
        super().__init__(dataset)
        self.wrd_labels = wrd_labels
        self.ltr_labels = ltr_labels
        self.batch_targets = batch_targets
        self.pad = pad
        self.eos = eos
        self.add_to_input = add_to_input
        self.berttokenizer = BertTokenizer.from_pretrained(bert_model_name)
        self.dictionary = dictionary
        self.mask_low_radio = 0
        self.no_mask = no_mask

    def set_mask_low_radio(self, low_radio):
        """Set the minimum fraction of wordpieces selected for masking."""
        self.mask_low_radio = low_radio

    def get_label(self, index):
        """Return ``(letter_ids, wordpiece_ids)`` for sample ``index``."""
        tgt_character_list = self.berttokenizer.encode_line(
            self.ltr_labels[index], post_proces='letter'
        )

        if self.add_to_input:
            # To match BERT's training data, add a final period plus the
            # CLS / SEP symbols around the sentence.
            tgt_wordpiece_list = self.berttokenizer.encode_line(
                self.wrd_labels[index] + '.', post_proces='bert_bpe_piece'
            )
            tgt_wordpiece_list.insert(0, self.dictionary.cls())
            tgt_wordpiece_list.append(self.dictionary.eos())
        else:
            tgt_wordpiece_list = self.berttokenizer.encode_line(
                self.wrd_labels[index], post_proces='bert_bpe_piece'
            )
        return tgt_character_list, tgt_wordpiece_list

    def mask_word(self, w, p=None):
        """Corrupt one token id BERT-style.

        With ``p`` drawn uniformly from [0, 1) when not supplied: ``p >= 0.2``
        returns the mask index, ``0.1 <= p < 0.2`` returns a random
        non-special id, and ``p < 0.1`` returns ``w`` unchanged.
        """
        voc = self.dictionary

        p = np.random.random() if p is None else p
        if p >= 0.2:
            return voc.mask_index  # [MASK] (id 103 according to the original note)
        elif p >= 0.1:
            return np.random.randint(voc.nspecial, len(voc))  # random replacement
        else:
            return w  # left unchanged

    def __getitem__(self, index):
        """Return the wrapped item extended with the four target tensors.

        Added keys: ``target`` (BERT decoder input), ``output`` (MLM loss
        target, pad at positions that are not scored), ``origin_target``
        (CTC loss target) and ``origin_bert_target`` (ground truth used when
        computing the BERT WER).
        """
        item = self.dataset[index]
        tgt_character_list, tgt_wordpiece_list = self.get_label(index)
        target = tgt_wordpiece_list.copy()
        output = tgt_wordpiece_list.copy()
        len_to_consider = len(tgt_wordpiece_list)

        # Without masking, both the BERT input and the MLM target are the
        # ground truth, so the copies are used as they are.
        if not self.no_mask:
            # Lower bound on the number of selected positions. The previous
            # ``min(1, ...)`` capped it at one, which made the configured
            # ratio ineffective; clamp to [0, len] so randint stays valid.
            low_bound = int(self.mask_low_radio * len_to_consider)
            low_bound = max(0, min(len_to_consider, low_bound))
            num_mask = np.random.randint(low_bound, len_to_consider + 1)

            # Shuffle all positions and keep the first ``num_mask`` of them;
            # a set gives O(1) membership tests in the loop below.
            id_mask = np.arange(len_to_consider)
            np.random.shuffle(id_mask)
            id_mask = set(id_mask[:num_mask].tolist())

            last = len_to_consider - 1
            for i, w in enumerate(tgt_wordpiece_list):
                # CLS and SEP must never be masked when they were added.
                is_boundary = self.add_to_input and (i == 0 or i == last)
                if i in id_mask and not is_boundary:
                    target[i] = self.mask_word(w)
                else:
                    # Pad in the MLM target so this position is not scored.
                    output[i] = self.dictionary.pad()
                # Debug note kept from the original: when every token is
                # masked, even a model with 30% WER cannot predict anything
                # useful during validation.

        item['target'] = torch.LongTensor(target)
        item['output'] = torch.LongTensor(output)
        item['origin_target'] = torch.LongTensor(tgt_character_list)
        item['origin_bert_target'] = torch.LongTensor(tgt_wordpiece_list)
        return item

    def size(self, index):
        """Return ``(wrapped_size, letter_label_length)``.

        ``get_label`` returns a 2-tuple, so taking ``len`` of it directly
        always produced 2. The letter-level length is reported because
        ``collater`` derives ``target_lengths`` from the same sequence.
        NOTE(review): confirm callers do not expect the wordpiece length.
        """
        sz = self.dataset.size(index)
        tgt_character_list, _ = self.get_label(index)
        own_sz = len(tgt_character_list)
        return (sz, own_sz)

    def collater(self, samples):
        """Collate ``samples`` with the wrapped collater and add the targets."""
        collated = self.dataset.collater(samples)
        if len(collated) == 0:
            return collated
        # Only keep samples the wrapped collater retained.
        indices = set(collated["id"].tolist())
        kept = [s for s in samples if s["id"] in indices]
        target = [s["target"] for s in kept]
        output = [s["output"] for s in kept]
        origin_target = [s["origin_target"] for s in kept]
        origin_bert_target = [s["origin_bert_target"] for s in kept]

        if self.batch_targets:
            # target_lengths refers to origin_target (the CTC target).
            collated["target_lengths"] = torch.LongTensor([len(t) for t in origin_target])
            target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
            collated["ntokens"] = collated["target_lengths"].sum().item()
            output = data_utils.collate_tokens(output, pad_idx=self.pad, left_pad=False)
            origin_target = data_utils.collate_tokens(
                origin_target, pad_idx=self.pad, left_pad=False
            )
            origin_bert_target = data_utils.collate_tokens(
                origin_bert_target, pad_idx=self.pad, left_pad=False
            )
        else:
            collated["ntokens"] = sum(len(t) for t in target)

        collated["nsentences"] = len(samples)
        collated["target"] = output  # ground truth for the BERT output
        collated["origin_target"] = origin_target  # ground truth for wav2vec CTC
        collated["origin_bert_target"] = origin_bert_target  # ground truth for BERT WER
        collated["net_input"]["prev_output_tokens"] = target  # masked BERT input
        return collated


class Wav2BertDataset(BaseWrapperDataset):
    """Attach BERT-style masked targets, built from one label set, to a dataset.

    Labels are encoded with the BERT tokenizer using ``tokenizer_process``.
    Each fetched item gets a randomly masked copy of the tokens (BERT input),
    an MLM target that is pad everywhere except the selected positions, and
    unmasked copies used as CTC / WER references.
    """

    def __init__(
        self,
        dataset,
        labels,
        pad,
        eos,
        batch_targets,
        add_to_input=False,
        bert_model_name='bert_english',
        dictionary=None,
        tokenizer_process='bert_bpe_piece',
    ):
        super().__init__(dataset)
        self.labels = labels
        self.pad = pad
        self.eos = eos
        self.batch_targets = batch_targets
        self.add_to_input = add_to_input
        self.dictionary = dictionary
        self.tokenizer_process = tokenizer_process
        self.berttokenizer = BertTokenizer.from_pretrained(bert_model_name)

    def get_label(self, index):
        """Return the token ids for sample ``index``.

        With ``add_to_input`` a period is appended to the label and the ids
        are wrapped with ``dictionary.cls()`` / ``dictionary.eos()`` to match
        BERT's training data.
        """
        if not self.add_to_input:
            return self.berttokenizer.encode_line(
                self.labels[index], post_proces=self.tokenizer_process
            )

        pieces = self.berttokenizer.encode_line(
            self.labels[index] + '.', post_proces=self.tokenizer_process
        )
        pieces.insert(0, self.dictionary.cls())
        pieces.append(self.dictionary.eos())
        return pieces

    def mask_word(self, w, p=None):
        """Return the mask index, a random non-special id, or ``w`` itself.

        ``p`` defaults to a uniform draw from [0, 1): values of at least 0.2
        yield the mask index, values in [0.1, 0.2) a random id, anything
        else leaves ``w`` unchanged.
        """
        vocab = self.dictionary
        draw = np.random.random() if p is None else p

        if draw >= 0.2:
            return vocab.mask_index  # [MASK] (id 103 according to the original note)
        if draw >= 0.1:
            return np.random.randint(vocab.nspecial, len(vocab))  # random replacement
        return w  # kept as is

    def __getitem__(self, index):
        """Return the wrapped item plus ``target``, ``output``,
        ``origin_target`` and ``origin_bert_target`` tensors."""
        sample = self.dataset[index]
        tokens = self.get_label(index)
        reference = tokens.copy()
        bert_reference = tokens.copy()
        bert_input = tokens.copy()
        mlm_target = tokens.copy()
        n_tokens = len(tokens)

        # Pick how many positions to select (at least one), then which ones.
        n_selected = np.random.randint(1, n_tokens + 1)
        order = np.arange(n_tokens)
        np.random.shuffle(order)
        selected = set(order[:n_selected].tolist())

        final_pos = n_tokens - 1
        for pos, tok in enumerate(tokens):
            # When CLS / SEP were added they must stay untouched.
            on_boundary = self.add_to_input and (pos == 0 or pos == final_pos)
            if pos not in selected or on_boundary:
                # Pad in the MLM target means this position is not scored.
                mlm_target[pos] = self.dictionary.pad()
            else:
                bert_input[pos] = self.mask_word(tok)
            # Debug note kept from the original: when every token is masked,
            # even a model with 30% WER cannot predict anything useful
            # during validation.

        sample['target'] = torch.LongTensor(bert_input)  # BERT decoder input
        sample['output'] = torch.LongTensor(mlm_target)  # MLM loss target
        # CTC loss target: drop CLS, the added period and SEP when present.
        ctc_ids = reference[1:-2] if self.add_to_input else reference
        sample['origin_target'] = torch.LongTensor(ctc_ids)
        sample['origin_bert_target'] = torch.LongTensor(bert_reference)
        return sample

    def size(self, index):
        """Return ``(wrapped_size, label_length)`` for sample ``index``."""
        return (self.dataset.size(index), len(self.get_label(index)))

    def collater(self, samples):
        """Collate ``samples`` with the wrapped collater and add the targets."""
        batch = self.dataset.collater(samples)
        if len(batch) == 0:
            return batch

        # Restrict to the samples the wrapped collater actually kept.
        kept_ids = set(batch["id"].tolist())
        kept = [s for s in samples if s["id"] in kept_ids]
        bert_inputs = [s["target"] for s in kept]
        mlm_targets = [s["output"] for s in kept]
        ctc_targets = [s["origin_target"] for s in kept]
        bert_refs = [s["origin_bert_target"] for s in kept]

        if not self.batch_targets:
            batch["ntokens"] = sum(len(t) for t in bert_inputs)
        else:
            def pad_batch(seqs):
                return data_utils.collate_tokens(seqs, pad_idx=self.pad, left_pad=False)

            # target_lengths is mainly consumed by the CTC computation, so it
            # is measured on the unmasked origin targets.
            lengths = torch.LongTensor([len(t) for t in ctc_targets])
            batch["target_lengths"] = lengths
            bert_inputs = pad_batch(bert_inputs)
            batch["ntokens"] = lengths.sum().item()
            mlm_targets = pad_batch(mlm_targets)
            ctc_targets = pad_batch(ctc_targets)
            bert_refs = pad_batch(bert_refs)

        batch["nsentences"] = len(samples)
        batch["target"] = mlm_targets
        batch["origin_target"] = ctc_targets
        batch["net_input"]["prev_output_tokens"] = bert_inputs
        batch["origin_bert_target"] = bert_refs  # ground truth for BERT WER
        return batch

class Wav2BertDatasetV2WithoutMask(BaseWrapperDataset):
    """Attach character-level targets and an all-mask decoder input.

    Labels are encoded with ``character_tokenizer``. The decoder input
    (``prev_output_tokens``) has the same shape as the target but is filled
    entirely with ``dictionary.mask_index``.

    Args:
        dataset: wrapped dataset; ``__getitem__``, ``size`` and ``collater``
            are delegated to it.
        labels: indexable labels fed to ``character_tokenizer``.
        pad: pad index used when batching targets.
        eos: stored on the instance; not read by this class.
        batch_targets: when true, ``collater`` pads targets into one tensor.
        add_to_input: stored on the instance; not read by this class.
        bert_model_name: name passed to ``BertTokenizer.from_pretrained``.
        dictionary: vocabulary object; only ``mask_index`` is used.
        tokenizer_process: stored on the instance; not read by this class.
        character_tokenizer: callable whose result supports ``.tolist()`` and
            which exposes ``dictionary.pad()``. Required in practice: the
            ``None`` default fails in ``__init__``.
    """

    def __init__(
        self,
        dataset,
        labels,
        pad,
        eos,
        batch_targets,
        add_to_input=False,
        bert_model_name='bert_english',
        dictionary=None,
        tokenizer_process='bert_bpe_piece',
        character_tokenizer=None,
    ):
        super().__init__(dataset)
        self.labels = labels
        self.batch_targets = batch_targets
        self.pad = pad
        self.eos = eos
        self.add_to_input = add_to_input
        self.berttokenizer = BertTokenizer.from_pretrained(bert_model_name)
        self.dictionary = dictionary
        self.tokenizer_process = tokenizer_process
        self.character_pad = character_tokenizer.dictionary.pad()
        self.character_tokenizer = character_tokenizer

    def get_label(self, index):
        """Return the character-tokenizer ids for sample ``index`` as a list."""
        return self.character_tokenizer(self.labels[index]).tolist()

    def __getitem__(self, index):
        """Return the wrapped item with a ``target`` LongTensor added."""
        item = self.dataset[index]
        item['target'] = torch.LongTensor(self.get_label(index))
        return item

    def size(self, index):
        """Return ``(wrapped_size, label_length)`` for sample ``index``."""
        sz = self.dataset.size(index)
        own_sz = len(self.get_label(index))
        return (sz, own_sz)

    def collater(self, samples):
        """Collate ``samples`` and add targets plus an all-mask decoder input.

        With ``batch_targets`` the targets are padded into one tensor and
        ``prev_output_tokens`` is a tensor of the same shape. Otherwise both
        stay per-sample lists; previously this path raised because a list has
        no ``new`` / ``shape``.
        """
        collated = self.dataset.collater(samples)
        if len(collated) == 0:
            return collated
        # Only keep samples the wrapped collater retained.
        indices = set(collated["id"].tolist())
        target = [s["target"] for s in samples if s["id"] in indices]

        if self.batch_targets:
            collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
            # NOTE(review): targets come from character_tokenizer but are
            # padded with self.pad, while self.character_pad is never used —
            # confirm which pad index is intended.
            target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
            collated["ntokens"] = collated["target_lengths"].sum().item()
        else:
            collated["ntokens"] = sum(len(t) for t in target)

        collated["nsentences"] = len(samples)
        collated["target"] = target

        # The decoder input carries no label information: every position is
        # the mask index.
        mask_index = self.dictionary.mask_index
        if self.batch_targets:
            prev_output_tokens = torch.full_like(target, mask_index)
        else:
            prev_output_tokens = [torch.full_like(t, mask_index) for t in target]
        collated["net_input"]["prev_output_tokens"] = prev_output_tokens

        return collated

class Wav2BertDatasetWithoutMask(BaseWrapperDataset):
    """Attach BERT-tokenized targets and an all-mask decoder input.

    Labels are encoded with the BERT tokenizer using ``tokenizer_process``.
    The decoder input (``prev_output_tokens``) has the same shape as the
    target but is filled entirely with ``dictionary.mask_index``.

    Args:
        dataset: wrapped dataset; ``__getitem__``, ``size`` and ``collater``
            are delegated to it.
        labels: indexable labels fed to the BERT tokenizer.
        pad: pad index used when batching targets.
        eos: stored on the instance; not read by this class.
        batch_targets: when true, ``collater`` pads targets into one tensor.
        add_to_input: stored on the instance; not read by this class.
        bert_model_name: name passed to ``BertTokenizer.from_pretrained``.
        dictionary: vocabulary object; only ``mask_index`` is used.
        tokenizer_process: post-process name passed to ``encode_line``.
    """

    def __init__(
        self,
        dataset,
        labels,
        pad,
        eos,
        batch_targets,
        add_to_input=False,
        bert_model_name='bert_english',
        dictionary=None,
        tokenizer_process='bert_bpe_piece',
    ):
        super().__init__(dataset)
        self.labels = labels
        self.batch_targets = batch_targets
        self.pad = pad
        self.eos = eos
        self.add_to_input = add_to_input
        self.berttokenizer = BertTokenizer.from_pretrained(bert_model_name)
        self.dictionary = dictionary
        self.tokenizer_process = tokenizer_process

    def get_label(self, index):
        """Return the BERT-tokenizer ids for sample ``index``."""
        return self.berttokenizer.encode_line(
            self.labels[index], post_proces=self.tokenizer_process
        )

    def __getitem__(self, index):
        """Return the wrapped item with a ``target`` LongTensor added."""
        item = self.dataset[index]
        item['target'] = torch.LongTensor(self.get_label(index))
        return item

    def size(self, index):
        """Return ``(wrapped_size, label_length)`` for sample ``index``."""
        sz = self.dataset.size(index)
        own_sz = len(self.get_label(index))
        return (sz, own_sz)

    def collater(self, samples):
        """Collate ``samples`` and add targets plus an all-mask decoder input.

        With ``batch_targets`` the targets are padded into one tensor and
        ``prev_output_tokens`` is a tensor of the same shape. Otherwise both
        stay per-sample lists; previously this path raised because a list has
        no ``new`` / ``shape``.
        """
        collated = self.dataset.collater(samples)
        if len(collated) == 0:
            return collated
        # Only keep samples the wrapped collater retained.
        indices = set(collated["id"].tolist())
        target = [s["target"] for s in samples if s["id"] in indices]

        if self.batch_targets:
            collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
            target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
            collated["ntokens"] = collated["target_lengths"].sum().item()
        else:
            collated["ntokens"] = sum(len(t) for t in target)

        collated["nsentences"] = len(samples)
        collated["target"] = target

        # The decoder input carries no label information: every position is
        # the mask index.
        mask_index = self.dictionary.mask_index
        if self.batch_targets:
            prev_output_tokens = torch.full_like(target, mask_index)
        else:
            prev_output_tokens = [torch.full_like(t, mask_index) for t in target]
        collated["net_input"]["prev_output_tokens"] = prev_output_tokens

        return collated


class OnlyBertDataset(BaseWrapperDataset):
    """Attach BERT masked-LM inputs and targets to a wrapped dataset.

    Labels are encoded with the BERT tokenizer using ``tokenizer_process``.
    The number of positions selected for masking is 15% of the sequence for
    ``mask_type='origin'`` and ``min(4, 15%)`` for ``mask_type='new_paper'``.
    """

    def __init__(
        self,
        dataset,
        labels,
        pad,
        eos,
        batch_targets,
        add_to_input=False,
        bert_model_name='bert_english',
        dictionary=None,
        tokenizer_process='bert_bpe_piece',
        mask_type='origin'
    ):
        super().__init__(dataset)
        self.labels = labels
        self.pad = pad
        self.eos = eos
        self.batch_targets = batch_targets
        self.add_to_input = add_to_input
        self.dictionary = dictionary
        self.tokenizer_process = tokenizer_process
        self.mask_type = mask_type
        self.berttokenizer = BertTokenizer.from_pretrained(bert_model_name)

    def get_label(self, index):
        """Return the token ids for sample ``index``.

        With ``add_to_input`` a period is appended to the label, the ids are
        wrapped with ``dictionary.cls()`` / ``dictionary.eos()`` to match
        BERT's training data, and sequences longer than 512 ids are reported
        on stdout and cut to the first 512.
        """
        if not self.add_to_input:
            return self.berttokenizer.encode_line(
                self.labels[index], post_proces=self.tokenizer_process
            )

        pieces = self.berttokenizer.encode_line(
            self.labels[index] + '.', post_proces=self.tokenizer_process
        )
        pieces.insert(0, self.dictionary.cls())
        pieces.append(self.dictionary.eos())
        if len(pieces) > 512:
            print("err")
            print(pieces)
            pieces = pieces[:512]
        return pieces

    def mask_word(self, w, p=None):
        """Return the mask index, a random non-special id, or ``w`` itself.

        ``p`` defaults to a uniform draw from [0, 1): values of at least 0.2
        yield the mask index, values in [0.1, 0.2) a random id, anything
        else leaves ``w`` unchanged.
        """
        vocab = self.dictionary
        draw = np.random.random() if p is None else p

        if draw >= 0.2:
            return vocab.mask_index  # [MASK] (id 103 according to the original note)
        if draw >= 0.1:
            return np.random.randint(vocab.nspecial, len(vocab))  # random replacement
        return w  # kept as is

    def __getitem__(self, index):
        """Return the wrapped item plus ``target``, ``output`` and
        ``origin_target`` tensors.

        Raises:
            ValueError: if ``mask_type`` is neither ``'origin'`` nor
                ``'new_paper'``.
        """
        sample = self.dataset[index]
        tokens = self.get_label(index)
        reference = tokens.copy()
        bert_input = tokens.copy()
        mlm_target = tokens.copy()
        n_tokens = len(tokens)

        if self.mask_type == 'origin':
            # The original BERT paper selects 15% of the tokens.
            n_selected = int(0.15 * n_tokens)
        elif self.mask_type == 'new_paper':
            n_selected = min(4, int(0.15 * n_tokens))
        else:
            raise ValueError("{} not support".format(self.mask_type))

        # Shuffle all positions and keep the first ``n_selected`` of them.
        order = np.arange(n_tokens)
        np.random.shuffle(order)
        selected = set(order[:n_selected].tolist())

        final_pos = n_tokens - 1
        for pos, tok in enumerate(tokens):
            # When CLS / SEP were added they must stay untouched.
            on_boundary = self.add_to_input and (pos == 0 or pos == final_pos)
            if pos not in selected or on_boundary:
                # Pad in the MLM target means this position is not scored.
                mlm_target[pos] = self.dictionary.pad()
            else:
                bert_input[pos] = self.mask_word(tok)

        sample['target'] = torch.LongTensor(bert_input)  # BERT decoder input
        sample['output'] = torch.LongTensor(mlm_target)  # MLM loss target
        # The unmasked ids are stored in full whether or not add_to_input is set.
        sample['origin_target'] = torch.LongTensor(reference)
        return sample

    def size(self, index):
        """Return ``(wrapped_size, label_length)`` for sample ``index``."""
        return (self.dataset.size(index), len(self.get_label(index)))

    def collater(self, samples):
        """Collate ``samples`` with the wrapped collater and add the targets."""
        batch = self.dataset.collater(samples)
        if len(batch) == 0:
            return batch

        # Restrict to the samples the wrapped collater actually kept.
        kept_ids = set(batch["id"].tolist())
        kept = [s for s in samples if s["id"] in kept_ids]
        bert_inputs = [s["target"] for s in kept]
        mlm_targets = [s["output"] for s in kept]
        references = [s["origin_target"] for s in kept]

        if not self.batch_targets:
            batch["ntokens"] = sum(len(t) for t in bert_inputs)
        else:
            def pad_batch(seqs):
                return data_utils.collate_tokens(seqs, pad_idx=self.pad, left_pad=False)

            # target_lengths is mainly consumed by the CTC computation, so it
            # is measured on the unmasked origin targets.
            lengths = torch.LongTensor([len(t) for t in references])
            batch["target_lengths"] = lengths
            bert_inputs = pad_batch(bert_inputs)
            batch["ntokens"] = lengths.sum().item()
            mlm_targets = pad_batch(mlm_targets)
            references = pad_batch(references)

        batch["nsentences"] = len(samples)
        batch["target"] = mlm_targets
        batch["origin_target"] = references
        batch["net_input"]["prev_output_tokens"] = bert_inputs
        return batch


