import json
import os
from torch.utils.data import Dataset

# dataset for generative model
class language_text2text_dataset(Dataset):
    '''
    Text-to-text dataset backed by preprocessed JSON annotation files.

    Every annotation file must contain a JSON array of records:
    {'input': xxx, 'target': xxx, 'task': xxx}
    example:
    {
        "input": "qqp quertion1: How is the life of a math student? Could you describe your own experiences? quertion2: Which level of prepration is enough for the exam jlpt5?",
        "target": "not duplicates",
        "task": "qqp"
    }

    Args:
        ann_file: a single annotation file name, or an iterable of file
            names. Each name is joined onto ``file_root``.
        file_root: directory that the annotation file names are relative to.
        shot: unused, kept only so existing callers keep working. Few-shot
            training is done by pointing ``ann_file`` at a different
            dataset file instead of sub-sampling here.
        debug: when true, only the first 100 records of each file are kept.

    Raises:
        TypeError: if an annotation file's top-level JSON value is not a list.

    Indexing the dataset returns the ``(input, target)`` pair of one record.
    '''
    def __init__(self, ann_file, file_root, shot=-1, debug=False):
        # A bare string is itself iterable; without this guard a single file
        # name would be walked character by character.
        if isinstance(ann_file, (str, os.PathLike)):
            ann_paths = [ann_file]
        else:
            ann_paths = ann_file

        self.ann = []
        for ann_f in ann_paths:
            # JSON text is UTF-8; do not depend on the platform's default
            # locale encoding when decoding it.
            with open(os.path.join(file_root, ann_f), encoding="utf-8") as f:
                records = json.load(f)
            if not isinstance(records, list):
                # Extending with a dict would silently add its keys and only
                # fail much later in __getitem__.
                raise TypeError(
                    "annotation file %r must contain a JSON list, got %s"
                    % (ann_f, type(records).__name__)
                )
            self.ann += records[:100] if debug else records

        print("[INFO] ann_file:", ann_file, "length:", len(self.ann))

    def __len__(self):
        '''Return the total number of loaded records.'''
        return len(self.ann)

    def __getitem__(self, index):
        '''Return the ``(input, target)`` strings of the record at ``index``.'''
        ann = self.ann[index]
        return ann['input'], ann['target']