import json
import os
from torch.utils.data import Dataset

# dataset for discriminative model
class language_multitask_dataset(Dataset):
    '''
    Map-style dataset over one or more preprocessed JSON annotation files.

    Every annotation file is expected to hold a JSON list of records; the
    records of all files are concatenated in the order the files are given.
    Each record is a dict with the keys read by ``__getitem__``:
    {'text1': xxx, 'text2': xxx or null, 'label': xxx, 'task': xxx}
    (the task name may be stored under 'tasks' instead of 'task').
    example (illustrative values; fields are returned exactly as stored):
    {
        "text1": "How is the life of a math student? Could you describe your own experiences?",
        "text2": "Which level of preparation is enough for the exam jlpt5?",
        "label": 0,
        "task": "qqp"
    }
    '''
    def __init__(self, ann_file, file_root=''):
        '''
        Load and concatenate the annotation files.

        Args:
            ann_file: iterable of annotation file paths, or a single path.
            file_root: directory each path is joined onto. Defaults to '',
                which leaves the given paths unchanged.
        '''
        # A bare path would otherwise be iterated character by character.
        if isinstance(ann_file, (str, os.PathLike)):
            ann_paths = [ann_file]
        else:
            ann_paths = ann_file

        self.ann = []
        for ann_f in ann_paths:
            with open(os.path.join(file_root, ann_f), encoding='utf-8') as f:
                self.ann.extend(json.load(f))

        print("[INFO] ann_file:", ann_file)
        print("[INFO] length:", len(self.ann))

    def __len__(self):
        '''Return the total number of records loaded from all files.'''
        return len(self.ann)

    def __getitem__(self, index):
        '''
        Return the record at ``index`` as ``(text1, text2, label, task)``.

        ``text2`` is '' when the record stores null for it or has no
        'text2' key. Raises KeyError when 'text1', 'label', or both
        'tasks' and 'task' are missing.
        '''
        data = self.ann[index]
        text1 = data['text1']
        text2 = data.get('text2')
        if text2 is None:
            text2 = ''
        label = data['label']
        # Prefer 'tasks' when present, otherwise fall back to 'task'.
        task = data['tasks'] if 'tasks' in data else data['task']
        # return task information
        return text1, text2, label, task


if __name__ == '__main__':
    # Smoke test: load the GLUE training splits and print one sample.
    # The paths below are relative, so they are resolved against the current
    # working directory; file_root must be passed explicitly because the
    # constructor takes it as its second argument.
    datasets = language_multitask_dataset(['glue_data/CoLA/train.json',
                                           'glue_data/MNLI/train.json',
                                           'glue_data/MRPC/train.json',
                                           'glue_data/QNLI/train.json',
                                           'glue_data/QQP/train.json',
                                           'glue_data/RTE/train.json',
                                           'glue_data/SST-2/train.json',
                                           'glue_data/STS-B/train.json',
                                           'glue_data/WNLI/train.json', ],
                                          file_root='.')
    # Negative indices count from the end of the concatenated record list.
    print(datasets[-10])

