import json
import os
import numpy as np
# Directory that this script's input and output file names are joined onto.
logiqa_data_dir = '../../logiqa-data/'
# Upper-case option letters A-E and the 0-based integer labels 0-4 that line
# up with them position by position (not referenced in this part of the file).
answer_list = [letter for letter in 'ABCDE']
answer_label_list = list(range(len(answer_list)))


def warp_json(datatype='train', sufix='_logi', data_dir=None):
    """Wrap a split's JSON list of records in a ``{'data': [...]}`` envelope.

    Reads ``<split><sufix>.json`` from the data directory and writes the same
    content to ``wrap_<split><sufix>.json`` as ``{'data': <loaded content>}``.
    ``<split>`` is ``train`` for ``datatype='train'``, ``val`` for
    ``datatype='dev'``, and ``test`` for any other value.

    Args:
        datatype: ``'train'``, ``'dev'``, or anything else (treated as test).
        sufix: Text inserted before ``.json`` in both file names.
        data_dir: Directory holding the input and output files. Defaults to
            the module-level ``logiqa_data_dir``.
    """
    if data_dir is None:
        data_dir = logiqa_data_dir

    # Unknown datatypes deliberately fall back to the test split.
    split = {'train': 'train', 'dev': 'val'}.get(datatype, 'test')
    lr_file_name = split + sufix + '.json'
    wrap_lr_file_name = 'wrap_' + lr_file_name

    # Explicit UTF-8 so reading/writing does not depend on the platform locale.
    with open(os.path.join(data_dir, lr_file_name), 'r', encoding='utf-8') as f:
        lines = json.load(f)

    with open(os.path.join(data_dir, wrap_lr_file_name), 'w', encoding='utf-8') as w_f:
        json.dump({'data': lines}, w_f)


answer_strs = ['a', 'b', 'c', 'd']


def _parse_logiqa_record(group, id_string):
    """Turn one 8-line raw LogiQA group into a record dict.

    Line 0 of the group is ignored. Line 1 must be one of ``answer_strs``.

    Raises:
        ValueError: if the answer line is not one of ``answer_strs``.
    """
    answer = group[1].strip()
    if answer not in answer_strs:
        raise ValueError('{}: expected an answer letter in {}, got {!r}'.format(
            id_string, answer_strs, group[1]))
    return {
        'id_string': id_string,
        'label': answer_strs.index(answer),
        'context': group[2].strip(),
        'question': group[3].strip(),
        # Drop the first two characters of each stripped option (e.g. an "A." prefix).
        'answers': [option.strip()[2:] for option in group[4:8]],
    }


def process_LogiQA_data(datatype='train', data_dir=None):
    """Convert a raw LogiQA text split into a JSON list of question records.

    The raw file is read as consecutive 8-line groups:

    * line 0 is ignored (presumably a blank separator; verify against the raw data);
    * line 1 is the answer letter, one of ``answer_strs``;
    * line 2 is the context;
    * line 3 is the question;
    * lines 4-7 are the four options, with their first two characters dropped.

    Each group becomes a dict with keys ``id_string``
    (``'<datatype>_<index>'``), ``label`` (index of the answer letter in
    ``answer_strs``), ``context``, ``question`` and ``answers``. Trailing blank
    lines after the last complete group are ignored. The function prints the
    raw line count, the record count and the last record, then writes the
    list as JSON.

    Args:
        datatype: ``'train'`` (Train.txt -> train_logi.json), ``'dev'``
            (Eval.txt -> val_logi.json); any other value selects
            Test.txt -> test_logi.json.
        data_dir: Directory holding the input and output files. Defaults to
            the module-level ``logiqa_data_dir``.

    Raises:
        ValueError: if an answer line is not one of ``answer_strs``, or the
            file ends with an incomplete group containing non-blank lines.
    """
    lines_per_record = 8
    if data_dir is None:
        data_dir = logiqa_data_dir

    if datatype == 'train':
        file_name, write_file_name = 'Train.txt', 'train_logi.json'
    elif datatype == 'dev':
        file_name, write_file_name = 'Eval.txt', 'val_logi.json'
    else:
        file_name, write_file_name = 'Test.txt', 'test_logi.json'

    with open(os.path.join(data_dir, file_name), 'r', encoding='utf-8') as r_f:
        lines = r_f.readlines()
    print(len(lines), len(lines) / lines_per_record)

    n_records = len(lines) // lines_per_record
    trailing = lines[n_records * lines_per_record:]
    # Blank trailing lines are harmless; anything else means a truncated record.
    if any(line.strip() for line in trailing):
        raise ValueError(
            '{}: {} line(s) left over after {} complete {}-line records; '
            'the file looks truncated'.format(
                file_name, len(trailing), n_records, lines_per_record))

    all_data = [
        _parse_logiqa_record(
            lines[k * lines_per_record:(k + 1) * lines_per_record],
            datatype + '_' + str(k))
        for k in range(n_records)
    ]

    print(len(all_data))
    if all_data:
        print(all_data[-1])

    with open(os.path.join(data_dir, write_file_name), 'w', encoding='utf-8') as w_f:
        json.dump(all_data, w_f)


def enrich_logiqa_json(datatype='train', data_dir=None):
    """Add extended and contrastive context fields to a wrapped LogiQA split.

    Loads ``wrap_<split>_logi.json`` (read as ``{'data': [...]}``) and three
    NumPy arrays for the same split. It adds per-record fields and writes the
    whole object to ``enrich_<split>_logi.json``. ``<split>`` is ``train`` for
    ``datatype='train'``, ``val`` for ``'dev'`` and ``test`` otherwise.

    For record ``i`` (using its own ``label``, ``context`` and ``answers``):

    * ``extend_contexts``: ``extend_contexts[i][0..3]`` from
      ``<split>_extended_context_cp_v6.npy``.
    * ``contras_contexts``: a two-element list pairing ``context`` with
      ``negative_contexts[i][0]`` (from ``<split>_negative_context_cp_v19.npy``).
    * ``contras_extend_context``: a two-element list pairing
      ``extend_contexts[i][label]`` with ``negative_extend_contexts[i][0]``
      (from ``<split>_extended_context_cp_v196.npy``), in the same order as
      ``contras_contexts``.
    * ``contras_label``: position of the record's own context in those pairs.
      It is 1 (negative first) when ``label < 2``, else 0 (record first).
    * ``contras_endings``: the correct answer ``answers[label]`` repeated twice.

    Args:
        datatype: ``'train'``, ``'dev'``, or anything else (treated as test).
        data_dir: Directory holding all input and output files. Defaults to
            the module-level ``logiqa_data_dir``.

    Raises:
        ValueError: if any of the arrays has fewer rows than there are records.
    """
    if data_dir is None:
        data_dir = logiqa_data_dir

    split = {'train': 'train', 'dev': 'val'}.get(datatype, 'test')
    wrap_lr_file_name = 'wrap_' + split + '_logi.json'
    enrich_lr_file_name = 'enrich_' + split + '_logi.json'
    extend_context_file = split + '_extended_context_cp_v6.npy'
    negative_context_file = split + '_negative_context_cp_v19.npy'
    negative_extend_context_file = split + '_extended_context_cp_v196.npy'

    # allow_pickle=True lets np.load unpickle object arrays, which can execute
    # arbitrary code: only point this at trusted, locally generated files.
    extend_contexts = np.load(os.path.join(data_dir, extend_context_file), allow_pickle=True)
    negative_contexts = np.load(os.path.join(data_dir, negative_context_file), allow_pickle=True)
    negative_extend_contexts = np.load(
        os.path.join(data_dir, negative_extend_context_file), allow_pickle=True)

    with open(os.path.join(data_dir, wrap_lr_file_name), 'r', encoding='utf-8') as f:
        load_data = json.load(f)
    records = load_data['data']

    print(len(records), len(extend_contexts), len(negative_contexts), len(negative_extend_contexts))

    # Fail up front with a clear message instead of an IndexError mid-loop.
    for file_name, array in ((extend_context_file, extend_contexts),
                             (negative_context_file, negative_contexts),
                             (negative_extend_context_file, negative_extend_contexts)):
        if len(array) < len(records):
            raise ValueError('{} has {} rows but {} has {} records'.format(
                file_name, len(array), wrap_lr_file_name, len(records)))

    for i, record in enumerate(records):
        label = record['label']
        extended = extend_contexts[i]
        positive = (record['context'], extended[label])
        negative = (negative_contexts[i][0], negative_extend_contexts[i][0])
        # The slot holding the record's own context depends on the label
        # (label < 2 -> slot 1, otherwise slot 0); contras_label records that slot.
        if label < 2:
            first, second, contras_label = negative, positive, 1
        else:
            first, second, contras_label = positive, negative, 0

        record['extend_contexts'] = [extended[k] for k in range(4)]
        record['contras_contexts'] = [first[0], second[0]]
        record['contras_label'] = contras_label
        record['contras_endings'] = [record['answers'][label]] * 2
        record['contras_extend_context'] = [first[1], second[1]]

    with open(os.path.join(data_dir, enrich_lr_file_name), 'w', encoding='utf-8') as w_f:
        json.dump(load_data, w_f)


if __name__ == '__main__':
    # Pipeline: raw text -> JSON list -> {'data': ...} envelope -> enriched JSON.
    # Each stage finishes all three splits before the next stage starts.
    splits = ('train', 'dev', 'test')
    for split in splits:
        process_LogiQA_data(datatype=split)
    for split in splits:
        warp_json(datatype=split, sufix='_logi')
    for split in splits:
        enrich_logiqa_json(datatype=split)
