import json
import os
from torch.utils.data import Dataset

# dataset for generative model
class language_text2text_multitask_dataset(Dataset):
    '''
    Multi-task text-to-text dataset backed by preprocessed JSON annotation files.

    Each annotation file must contain a JSON list of records of the form
    {'input': xxx, 'target': xxx, 'task': xxx}
    example:
    {
        "input": "qqp quertion1: How is the life of a math student? Could you describe your own experiences? quertion2: Which level of prepration is enough for the exam jlpt5?",
        "target": "not duplicates",
        "task": "qqp"
    }

    Args:
        ann_file: a single annotation file name, or an iterable of file names,
            each resolved relative to ``file_root``. Records from all files are
            concatenated in the given order.
        file_root: directory that contains the annotation files.
        debug: if True, keep only the first ``debug_size`` records of each file.
        debug_size: per-file record cap applied when ``debug`` is True
            (defaults to 100).

    Raises:
        ValueError: if an annotation file's top-level JSON value is not a list.

    Items are ``(input, target, task)`` tuples taken from each record, with
    hyphens removed from the task name (e.g. ``sts-b`` -> ``stsb``).
    '''
    def __init__(self, ann_file, file_root, debug=False, debug_size=100):
        # Accept a single path as well as a list of paths; iterating a bare
        # string would otherwise try to open one "file" per character.
        if isinstance(ann_file, (str, os.PathLike)):
            ann_files = [ann_file]
        else:
            ann_files = ann_file

        self.ann = []
        for ann_f in ann_files:
            path = os.path.join(file_root, ann_f)
            # JSON text is UTF-8; don't depend on the platform's locale encoding.
            with open(path, encoding="utf-8") as f:
                records = json.load(f)
            # `list += dict` would silently append only the dict's keys, so
            # reject non-list files here instead of failing later in __getitem__.
            if not isinstance(records, list):
                raise ValueError(
                    f"expected a JSON list of records in {path}, "
                    f"got {type(records).__name__}"
                )
            if debug:
                records = records[:debug_size]
            self.ann.extend(records)

        print("[INFO] ann_file:", ann_file, "length:", len(self.ann))

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        ann = self.ann[index]
        text1 = ann['input']
        text2 = ann['target']
        # Normalize task names such as "sts-b" -> "stsb".
        task = ann['task'].replace('-', '')
        return text1, text2, task