import json
import os
from torch.utils.data import Dataset

# dataset for discriminative model
class language_dataset(Dataset):
    '''
    Dataset for discriminative models built from preprocessed JSON annotations.

    Each annotation file holds a JSON list of records shaped like:
    {'text1': xxx, 'text2': xxx (or None), 'label': 0/1(/2), 'task': xxx}
    example:
    {
        "text1": "How is the life of a math student? Could you describe your own experiences?",
        "text2": "Which level of prepration is enough for the exam jlpt5?",
        "label": 0,
        "task": "qqp"
    }
    Only 'text1', 'text2' and 'label' are read when indexing; 'text2' may be
    None or missing entirely for single-sentence tasks.
    '''
    def __init__(self, ann_file, file_root=''):
        '''
        Load and concatenate the records of every annotation file.

        Args:
            ann_file: a single path, or an iterable of paths, to JSON files
                that each contain a list of annotation records.
            file_root: directory joined in front of every path in ann_file.

        Raises:
            TypeError: if a file's top-level JSON value is not a list.
        '''
        # A bare path would otherwise be iterated character by character.
        if isinstance(ann_file, (str, os.PathLike)):
            ann_files = [ann_file]
        else:
            ann_files = ann_file

        # save all data
        self.ann = []
        for ann_f in ann_files:
            path = os.path.join(file_root, ann_f)
            # Explicit encoding: the platform default is not UTF-8 everywhere.
            with open(path, encoding='utf-8') as f:
                records = json.load(f)
            # `list += dict` would silently splice in the dict's keys.
            if not isinstance(records, list):
                raise TypeError(
                    f"{path}: expected a JSON list of annotations, "
                    f"got {type(records).__name__}"
                )
            self.ann.extend(records)

        print("[INFO] ann_file:", ann_file, "file_root:", file_root, "length:", len(self.ann))

    def __len__(self):
        '''Return the total number of loaded records.'''
        return len(self.ann)

    def __getitem__(self, index):
        '''
        Return (text1, text2, label) for the record at `index`.

        A missing or None 'text2' is returned as '' so that callers always
        receive strings.
        '''
        ann = self.ann[index]
        text1 = ann['text1']
        text2 = ann.get('text2')
        if text2 is None:
            text2 = ''

        label = ann['label']
        return text1, text2, label

if __name__ == '__main__':
    # Smoke test: load the CoLA training split and pull out its first example.
    dataset = language_dataset(['glue_data/CoLA/train.json'], '.')
    first_example = dataset[0]
    text1, text2, label = first_example
