import sys
import os

import json
import argparse
from src.data import load_data

def getParser():
    """Build the command-line parser for this script.

    Options:
        --dataset    dataset name (required)
        --data_path  base dataset directory (default ``./data``)
        --outdir     output directory (required)

    Returns:
        argparse.ArgumentParser configured to show defaults in ``--help``.
    """
    arg_parser = argparse.ArgumentParser(
        description="parser for arguments",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # (flag, extra keyword arguments) in registration order; every option is a string.
    option_specs = (
        ("--dataset", dict(help="dataset name", required=True)),
        ("--data_path", dict(help="base dataset directory", default="./data")),
        ("--outdir", dict(help="output directory", required=True)),
    )
    for flag, extra in option_specs:
        arg_parser.add_argument(flag, type=str, **extra)
    return arg_parser

def prepare(params, data=None):
    """Convert a dataset's triples into per-split JSON-lines sentence files.

    For every (head, rel, tail) id triple of the train/valid/test splits, the
    ids are mapped back to strings through ``data.id2ent``/``data.id2rel`` and
    one record is written as a JSON object per line to
    ``<params.outdir>/<split>_sentences.json``. Each record has:

    * ``mention_span``: the tail entity string;
    * ``left_context_token``: head string then relation string, each split
      on single spaces;
    * ``right_context_token``: always an empty list;
    * ``y_str``: the hard-coded label list ``['politician']``;
    * ``annot_id``: ``'reverb<k>'``, where ``k`` is a counter shared across
      all splits (so ids are unique over the whole output).

    Args:
        params: namespace with at least ``outdir``. It is also passed to
            ``load_data`` when ``data`` is not supplied.
        data: optional already-loaded dataset exposing ``train_trips``,
            ``valid_trips``, ``test_trips``, ``id2ent`` and ``id2rel``.
            Defaults to ``load_data(params)``.

    Returns:
        dict mapping ``'train'``, ``'test'`` and ``'valid'`` to the list of
        records written for that split.
    """
    if data is None:
        data = load_data(params)
    delim = ' '
    triples = {'train': data.train_trips, 'test': data.test_trips, 'valid': data.valid_trips}
    output = {'train': [], 'test': [], 'valid': []}
    # open() cannot create missing directories. An empty outdir means "cwd",
    # so only create when a directory is actually named.
    if params.outdir:
        os.makedirs(params.outdir, exist_ok=True)
    count = 0
    for sp in ['train', 'valid', 'test']:
        out_path = os.path.join(params.outdir, '%s_sentences.json' % sp)
        with open(out_path, 'w', encoding='utf-8') as fout:
            for head_id, rel_id, tail_id in triples[sp]:
                head = data.id2ent[head_id]
                rel = data.id2rel[rel_id]
                tail = data.id2ent[tail_id]
                testcase = {
                    'mention_span': tail,
                    'left_context_token': head.split(delim) + rel.split(delim),
                    'right_context_token': [],
                    # NOTE(review): fixed label for every record, presumably a
                    # dummy placeholder expected by the consumer; confirm.
                    'y_str': ['politician'],
                    'annot_id': 'reverb%d' % count,
                }
                count += 1
                output[sp].append(testcase)
                print(json.dumps(testcase), file=fout)
    return output

def getDataFiles(params):
    """Attach input file paths and fixed loader settings to ``params``.

    Sets ``params.data_files`` to a dict of paths under
    ``<data_path>/<dataset>/``:
    ent2id, rel2id, train/test/valid triples, gold and CESI NP clusters.
    The GloVe embeddings file is taken from ``<data_path>/glove/`` instead.
    Also sets the hard-coded options ``use_cuda=False``,
    ``model_variant='CaRe'``, ``batch_size=128`` and ``reverse=False``.
    These are presumably read by ``load_data``; confirm there.

    Args:
        params: namespace with ``data_path`` and ``dataset``; mutated in place.

    Returns:
        The same ``params`` object.
    """
    # os.path.join avoids doubled separators (e.g. a trailing '/' on
    # data_path) and uses the platform separator.
    dataset_dir = os.path.join(params.data_path, params.dataset)
    params.data_files = {
        'ent2id_path'       : os.path.join(dataset_dir, 'ent2id.txt'),
        'rel2id_path'       : os.path.join(dataset_dir, 'rel2id.txt'),
        'train_trip_path'   : os.path.join(dataset_dir, 'train_trip.txt'),
        'test_trip_path'    : os.path.join(dataset_dir, 'test_trip.txt'),
        'valid_trip_path'   : os.path.join(dataset_dir, 'valid_trip.txt'),
        'gold_npclust_path' : os.path.join(dataset_dir, 'gold_npclust.txt'),
        'cesi_npclust_path' : os.path.join(dataset_dir, 'cesi_npclust.txt'),
        'glove_path'        : os.path.join(params.data_path, 'glove', 'glove.6B.300d.txt'),
    }
    params.use_cuda = False
    params.model_variant = 'CaRe'
    params.batch_size = 128
    params.reverse = False
    return params


def main():
    """Script entry point: parse CLI args, resolve data files, write sentences.

    Exits with status 1 when argument parsing fails. ``--help`` exits
    successfully (status 0).
    """
    parser = getParser()
    try:
        params = parser.parse_args()
    except SystemExit as err:
        # argparse reports both a successful --help (code 0) and usage errors
        # via SystemExit. Let the success through untouched, and map errors to
        # the script's historical exit status 1. Catching only SystemExit also
        # lets KeyboardInterrupt propagate normally.
        if err.code in (0, None):
            raise
        sys.exit(1)
    params = getDataFiles(params)
    prepare(params)

# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
