import ujson

from datasets import Dataset, DatasetDict


def load_data(filenames: list, data_dir: str = './data'):
    """Load JSON files into a ``DatasetDict`` keyed by file name.

    For every entry in ``filenames`` the file ``{data_dir}/{name}.json`` is
    opened as UTF-8 and parsed with ``ujson``. The parsed content is iterated
    record by record, and the ``'text'`` and ``'label'`` values of each record
    are collected into two parallel columns. Any other fields in a record are
    ignored.

    Args:
        filenames: File names without the ``.json`` extension (for example
            ``['train', 'test']``). Each name also becomes the key of the
            corresponding split in the returned ``DatasetDict``.
        data_dir: Directory that holds the JSON files. The default
            ``'./data'`` is relative to the current working directory, so
            pass an explicit directory when running from elsewhere.

    Returns:
        A ``DatasetDict`` mapping each file name to a ``Dataset`` with the
        columns ``'text'`` and ``'label'``.

    Raises:
        OSError: If a file cannot be opened (e.g. ``FileNotFoundError``).
        KeyError: If a record is a mapping that lacks ``'text'`` or
            ``'label'``.
    """
    splits = {}
    for filename in filenames:
        with open(f'{data_dir}/{filename}.json', 'r', encoding='utf-8') as f:
            records = ujson.load(f)

        # Column-oriented layout is what Dataset.from_dict expects: one list
        # per column, all of the same length.
        splits[filename] = Dataset.from_dict({
            'text': [record['text'] for record in records],
            'label': [record['label'] for record in records],
        })

    return DatasetDict(splits)


if __name__ == '__main__':
    # Manual smoke test. Run this file from the root folder:
    #     python ./utils/dataloader.py
    splits = load_data(filenames=['train', 'test'])
    # Show the record at index 1 of the 'test' split.
    test_split = splits['test']
    print(test_split[1])
