import os
import json
import sys

# Number of labeled examples per task; selects the `label_<K>` split directory.
K = 100

# Active configuration — semi-supervised learning: read the divided LAMOL data
# from `ori_dir` and write the converted copies under `mydir`.
ori_dir = f'../../DATA/LAMOL_DIVIDED/decaNLP/label_{K}'
mydir = f'../../DATA/MYDATA_DIVIDED/decaNLP/label_{K}'

# Alternative configuration — supervised-learning upper bound (full data):
# ori_dir = '../../DATA/LAMOL_DATA/'
# mydir = '../../DATA/MYDATA/decaNLP/'

# Alternative configuration — text-classification (TC) benchmark:
# ori_dir = f'../../DATA/LAMOL_DIVIDED/TC/label_{K}'
# mydir = f'../../DATA/MYDATA_DIVIDED/TC/label_{K}'
# task_names = ['ag', 'dbpedia', 'yelp', 'yahoo', 'amazon']

# Tasks to convert, processed in this order. Single-task variants for debugging:
# ['woz.en'], ['wikisql'], ['squad'].
task_names = ['woz.en', 'srl', 'sst', 'wikisql', 'squad']

def sample_convert(sample, datatype='train'):
    """Flatten one SQuAD-style article into a list of per-question records.

    Args:
        sample: dict with a 'paragraphs' list. Each paragraph has a 'context'
            string and a 'qas' list. Each QA entry is a dict with 'question',
            'answers' (a list of dicts carrying a 'text' key) and, optionally,
            'id'.
        datatype: when 'train', the answer texts of one question are joined
            into a single list separated by '<pad>' tokens. Any other value
            keeps just the list of answer texts.

    Returns:
        A list with one dict per QA entry, in document order. Each dict has the
        keys 'input' (context + ' ' + question), 'output' (list of answer
        strings), 'question' and 'text_wo_question' (the bare context). It also
        has 'id' when the QA entry provides one.
    """
    res_list = []
    for paragraph in sample['paragraphs']:
        raw_context = paragraph['context']
        for qa in paragraph['qas']:
            question = qa['question']
            texts = [ans['text'] for ans in qa['answers']]

            if datatype == 'train':
                # Interleave '<pad>' between consecutive answers (none trailing).
                answer = []
                for i, text in enumerate(texts):
                    if i:
                        answer.append('<pad>')
                    answer.append(text)
            else:
                answer = texts

            # Key order matters: json.dumps writes keys in insertion order.
            sample_dict = {
                'input': raw_context + ' ' + question,
                'output': answer,
                'question': question,
                'text_wo_question': raw_context,
            }
            # The previous condition `'woz' or ... in task_names[0]` was always
            # True, because a non-empty string is truthy. It also inspected a
            # global rather than the current task. Copy the id whenever the
            # entry has one, instead of raising KeyError when it is absent.
            if 'id' in qa:
                sample_dict['id'] = qa['id']
            res_list.append(sample_dict)
    return res_list

for task in task_names:
    dir_path = os.path.join(mydir, task)
    # The previous guard tested os.path.exists(task), i.e. the bare task name
    # relative to the CWD, so a stray file or directory with that name skipped
    # directory creation. exist_ok=True creates dir_path only if it is missing.
    os.makedirs(dir_path, exist_ok=True)

    #* Convert labeled and unlabeled train-data (same file name in and out).
    for split_file in ('label_train.json', 'unlabel_train.json'):
        in_file = os.path.join(ori_dir, task, split_file)
        with open(in_file, 'r', encoding='utf-8') as f:
            split_data = json.load(f)['data']

        out_file = os.path.join(dir_path, split_file)
        with open(out_file, 'w', encoding='utf-8') as fw:
            for sample in split_data:
                # Only the first QA record of each article is written.
                # NOTE(review): presumably the divided data holds exactly one QA
                # per article — confirm; any extra QAs are silently dropped, and
                # an article with no QAs raises IndexError.
                sample_dict = sample_convert(sample)[0]
                print(json.dumps(sample_dict, ensure_ascii=False), file=fw)

    # * Convert all train-data
    # train_file = os.path.join(ori_dir,task+'_to_squad-train-v2.0.json')
    # train_file = os.path.join(ori_dir,task+'-train-v1.1.json') # For SQuAD dataset
    # with open(train_file,'r', encoding='utf-8') as f:
    #     data = json.load(f)['data'] # list of dicts
    # train_outfile = os.path.join(dir_path,'train.json')
    # with open(train_outfile, 'w', encoding='utf-8') as fw:
    #     train_dict_list = []
    #     for sample in data:
    #         train_dict_list += sample_convert(sample)
    #     print(f'Number of train data is {len(train_dict_list)}')
    #     for sample_dict in train_dict_list:
    #         print(json.dumps(sample_dict, ensure_ascii=False), file=fw)

    #* Convert test-data
    # test_file = os.path.join(ori_dir,task,'test.json')
    # test_file = os.path.join(ori_dir,task+'-dev-v1.1.json') # For SQuAD dataset
    # test_file = os.path.join(ori_dir,task+'_to_squad-test-v2.0.json')
    # with open(test_file,'r', encoding='utf-8') as f:
    #     test_data = json.load(f)['data']
    # print(f'Number of test data is {len(test_data)}')

    # test_outfile = os.path.join(dir_path,'test.json')
    # with open(test_outfile,'w', encoding='utf-8') as fw:
    #     all_samples = []
    #     for sample in test_data:
    #         all_samples += sample_convert(sample, datatype='test')
    #     print(f'Number of test data is {len(all_samples)}')
    #     for sample_dict in all_samples:
    #         print(json.dumps(sample_dict, ensure_ascii=False), file=fw)
    print('Finishing dealing with ', task, flush=True)

