import json, os, argparse
import sys
# from settings import parse_args
import random
import shutil
# Seed the RNG so the random.sample() draw of labeled examples further down
# is deterministic from one run to the next.
random.seed(0)

# Number of examples sampled into the labeled split of every task; it is also
# embedded in the output directory name as 'label_<K>'.
K=100
# Alternative task lists kept for reference; swap one in to process those tasks.
# task_names = ['ag','dbpedia','yelp','yahoo','amazon']
# task_names = ['srl','sst']
# task_names = ['squad']
# task_names = ['woz.en']
task_names = ['wikisql','woz.en','srl','sst']
# Directory the '<task>_to_squad-train-v2.0.json' input files are read from
# (relative to the current working directory).
ori_dir = 'lamol_data'
# Root directory under which the per-task output files are written
# (relative to the current working directory).
out_dir = '../../DATA/LAMOL_DIVIDED'

def data_split(data, datatype='train'):
    """Flatten SQuAD-style records into one record per question.

    Every question found under ``data[*]['paragraphs'][*]['qas']`` becomes its
    own ``{'paragraphs': [{'context': ..., 'qas': [<one qa>]}]}`` entry, so the
    result can be sampled at single-question granularity.

    All answer texts of a question are merged into a single string. When
    ``datatype == 'train'`` they are separated by the literal ``'<pad>'``
    token; for any other ``datatype`` they are concatenated with no separator.

    Args:
        data: List of dicts, each holding a ``'paragraphs'`` list whose items
            carry a ``'context'`` value and a ``'qas'`` list. Each qa needs the
            keys ``'question'``, ``'answers'`` (list of dicts with ``'text'``)
            and ``'id'``.
        datatype: Split name; only the value ``'train'`` changes behavior
            (it enables the ``'<pad>'`` separator).

    Returns:
        A new list with one single-question record per input question, in
        input order. The input structure is not modified.

    Raises:
        KeyError: If one of the expected keys is missing from the input.
    """
    # Only training data keeps the individual answers distinguishable.
    separator = '<pad>' if datatype == 'train' else ''

    newdata = []
    for article in data:
        for paragraph in article['paragraphs']:
            raw_context = paragraph['context']
            for qa in paragraph['qas']:
                question = qa['question']
                raw_answers = qa['answers']
                qa_id = qa['id']

                # A question without answers yields an empty answer string.
                answer = separator.join(ans['text'] for ans in raw_answers)

                flat_qa = {
                    'question': question,
                    'answers': [{'text': answer}],
                    'id': qa_id,
                }
                newdata.append(
                    {'paragraphs': [{'context': raw_context, 'qas': [flat_qa]}]}
                )
    return newdata

# For every task: flatten its training file to one question per record, draw
# K records at random as the labeled split, and write the labeled and
# unlabeled splits as JSON under <out_dir>/decaNLP/label_<K>/<task>/.
for task in task_names:
    dir_path = os.path.join(out_dir, 'decaNLP', 'label_' + str(K), task)
    # dir_path = os.path.join(out_dir,'TC','label_'+str(K),task)
    # Create the output directory itself. The check must be on dir_path, not on
    # a path named after the task in the current directory; exist_ok makes
    # re-runs safe.
    os.makedirs(dir_path, exist_ok=True)

    #* Convert train-data
    train_file = os.path.join(ori_dir, task + '_to_squad-train-v2.0.json')
    # train_file = os.path.join(ori_dir,task+'-train-v1.1.json') # For SQuAD dataset
    with open(train_file, 'r', encoding='utf-8') as f:
        data = json.load(f)['data']  # list of dicts
    data = data_split(data, datatype='train')  #! For squad and woz.en datasets
    print(f'Number of all splited data is {len(data)}')

    train_label_outfile = os.path.join(dir_path, 'label_train.json')
    train_unlabel_outfile = os.path.join(dir_path, 'unlabel_train.json')

    # random.sample raises ValueError when asked for more items than exist, so
    # cap the request at the number of available records.
    num_labeled = min(K, len(data))
    if num_labeled < K:
        print(f'Warning: task {task} has only {len(data)} samples (fewer than K={K}); '
              'all of them are used as labeled data.')
    # A set gives O(1) membership tests for the two partition passes below.
    label_idx_set = set(random.sample(range(len(data)), num_labeled))

    # Both splits keep the records in their original file order.
    label_data = [sample for idx, sample in enumerate(data) if idx in label_idx_set]
    print(f'Number of labeled data is {len(label_data)}')

    unlabel_data = [sample for idx, sample in enumerate(data) if idx not in label_idx_set]
    print(f'Number of unlabeled data is {len(unlabel_data)}')

    with open(train_label_outfile, 'w', encoding='utf-8') as fw:
        label_res = {'data': label_data}
        print(json.dumps(label_res, ensure_ascii=False), file=fw)

    with open(train_unlabel_outfile, 'w', encoding='utf-8') as fw:
        unlabel_res = {'data': unlabel_data}
        print(json.dumps(unlabel_res, ensure_ascii=False), file=fw)

    #* Convert test-data (currently disabled; kept as a template)
    # test_file = os.path.join(ori_dir,task+'-dev-v1.1.json') # For SQuAD dataset
    # test_file = os.path.join(ori_dir,task+'_to_squad-test-v2.0.json')
    # test_outfile = os.path.join(dir_path,'test.json')
    # with open(test_file,'r', encoding='utf-8') as f:
    #     test_data = json.load(f)['data'] # list of dicts
    # test_data = data_split(test_data,datatype='test') 
    # with open(test_outfile,'w', encoding='utf-8') as fw:
    #     test_res = {'data':test_data}
    #     print(json.dumps(test_res, ensure_ascii=False), file=fw)
    # print(f'Number of test data is {len(test_data)}')
    # shutil.copyfile(test_file, test_outfile)
    # * Only for WOZ data of goal oriented dialogue systems.
    # test_ans_file = os.path.join(ori_dir,task+'_answers.json')
    # test_ans_outfile = os.path.join(dir_path, 'answers.json')
    # shutil.copyfile(test_ans_file, test_ans_outfile)

    print('Finishing dealing with ', task, flush=True)

