import argparse
import json
import os
import random

from qaeval_utils import parse_rel, upred2bow, reconstruct_sent_from_svo

def get_first_half(string):
	"""Return the first two tab-separated columns of an output line.

	Output lines built by this script have exactly four tab-separated
	columns (sentence, truth string, label, language tag); the first two
	columns form the key used to de-duplicate examples across splits.

	Raises:
		AssertionError: if the line does not have exactly four columns.
	"""
	columns = string.split('\t')
	assert len(columns) == 4
	first, second = columns[0], columns[1]
	return f"{first}\t{second}"

parser = argparse.ArgumentParser()
# '%s' in --input_fn is filled with the split name ('dev' / 'test').
parser.add_argument('--input_fn', type=str,
					default='nc_final_samples_15_30_1000_triple_doc_disjoint_30_0_40000_2_lexic_wordnet_0.0_%s.json')
# Quotas on the number of positive ('True') examples written to train / dev.
parser.add_argument('--train_num_posis', type=int, default=1600)
parser.add_argument('--dev_num_posis', type=int, default=400)
# First '%s' is filled with --arguments, second with the split name ('train' / 'dev' / 'test').
parser.add_argument('--output_fn', type=str,
					default='../multilingual-lexical-inference/datasets/data_qaeval_15_30_1000_30_0_0/hypoonly_%sarg_lhsize/%s.txt')
# Render relations with argument names ('name') or argument types ('type').
# argparse `choices` validates this even under `python -O`, unlike an assert.
parser.add_argument('--arguments', type=str, default='type', choices=['name', 'type'])
parser.add_argument('--type_translation_path', type=str,
					default='../../cfet/u2figer/type_translation_layer1.jsonl')
parser.add_argument('--lang', type=str, default='en', choices=['en', 'zh'])
parser.add_argument('--seed', type=int, default=42,
					help='seed for the random dev/test assignment of examples present in both input files')

args = parser.parse_args()

# Seed the module-level RNG, which is consumed when examples shared by the dev
# and test inputs are assigned to one split, so reruns give the same split.
random.seed(args.seed)

# Format each output path first, then create its parent directory. Taking the
# dirname of the formatted path works wherever the two '%s' placeholders sit.
# It also skips makedirs('') when --output_fn has no directory component.
output_paths = {split: args.output_fn % (args.arguments, split) for split in ('train', 'dev', 'test')}
for output_path in output_paths.values():
	output_directory = os.path.dirname(output_path)
	if output_directory:
		os.makedirs(output_directory, exist_ok=True)

train_ofp = open(output_paths['train'], 'w', encoding='utf8')
dev_ofp = open(output_paths['dev'], 'w', encoding='utf8')
test_ofp = open(output_paths['test'], 'w', encoding='utf8')

# Running counts of positive ('True') examples written to each output split.
train_actual_num_posis = dev_actual_num_posis = test_actual_num_posis = 0

# Per-language configuration:
#   type_translations: table indexed by argument type when lang == 'zh'
#     (None otherwise).
#   TRUE_STR: text placed in the second output column.
if args.lang == 'en':
	type_translations = None
	TRUE_STR = 'true'
elif args.lang == 'zh':
	# NOTE(review): the default path ends in '.jsonl' but is parsed with
	# json.load as a single JSON document; confirm the file's actual format.
	with open(args.type_translation_path, 'r', encoding='utf8') as translation_fp:
		type_translations = json.load(translation_fp)
	TRUE_STR = '正确'
else:
	raise AssertionError

def _load_split(split):
	"""Read one input split and render every example as an output line.

	Opens ``args.input_fn % split`` and parses each non-blank line as JSON.
	Each relation is extracted with ``parse_rel`` and rendered as a sentence
	with ``reconstruct_sent_from_svo``. The rendering uses either the argument
	names or the argument types, depending on ``--arguments``. For
	``--lang zh``, types are first mapped through ``type_translations``.

	Returns:
		(keys, out_lines):
		- keys: set of ``"<sentence>,,<TAB><TRUE_STR>,,"`` prefixes, used to
		  de-duplicate examples across splits.
		- out_lines: the full tab-separated output lines
		  (prefix, ``True``/``False`` label, upper-cased language tag,
		  trailing newline), in file order.

	Raises:
		AssertionError: if ``--arguments`` is unknown or an item's ``label``
			is not a bool.
	"""
	keys = set()
	out_lines = []
	with open(args.input_fn % split, 'r', encoding='utf8') as ifp:
		for line in ifp:
			# Skip blank / whitespace-only lines (e.g. a trailing "\n" or "\r\n").
			if not line.strip():
				continue
			item = json.loads(line)
			upred, subj, obj, tsubj, tobj = parse_rel(item)
			if args.lang == 'zh':
				tsubj = type_translations[tsubj]
				tobj = type_translations[tobj]

			if args.arguments == 'name':
				arg1, arg2 = subj, obj
			elif args.arguments == 'type':
				arg1, arg2 = tsubj, tobj
			else:
				raise AssertionError
			# NOTE(review): the meaning of the constant 500 is defined by
			# reconstruct_sent_from_svo in qaeval_utils; confirm there.
			reconstructed_sent = reconstruct_sent_from_svo(upred, arg1, arg2, 500, args.lang)

			label = item['label']
			# Only real booleans are accepted: isinstance(x, bool) is exactly
			# `x is True or x is False`. str() then yields 'True' / 'False'.
			if not isinstance(label, bool):
				raise AssertionError
			label_str = str(label)

			key = f"{reconstructed_sent},,\t{TRUE_STR},,"
			keys.add(key)
			out_lines.append(f"{key}\t{label_str}\t{args.lang.upper()}\n")
	return keys, out_lines


in_dev_lines_bucket, in_dev_out_lines = _load_split('dev')
in_test_lines_bucket, in_test_out_lines = _load_split('test')

# Keys that occur in both the dev and the test input are removed from both
# buckets. Each one is then given back to exactly one bucket at random, so the
# dev-file splits (train/dev) and the test split never share a key.
intersaction_set = in_dev_lines_bucket.intersection(in_test_lines_bucket)
in_dev_lines_bucket = in_dev_lines_bucket - intersaction_set
in_test_lines_bucket = in_test_lines_bucket - intersaction_set

# Number of dev / test lines later skipped because their key went to the other split.
dev_seen_exc_cnt = 0
test_seen_exc_cnt = 0

# Iterate in sorted order. Iteration order of a set of str depends on hash
# randomization (PYTHONHASHSEED), so iterating the set directly would give a
# different assignment on every run even with a seeded `random`.
for line in sorted(intersaction_set):
	assert line not in in_dev_lines_bucket and line not in in_test_lines_bucket
	if random.random() < 0.5:
		in_dev_lines_bucket.add(line)
	else:
		in_test_lines_bucket.add(line)

# Split the surviving dev-file examples into train and dev, in file order:
# - train until `train_num_posis` positives have been written;
# - then dev until `dev_num_posis` positives have been written;
# - everything after that is dropped.
in_train_lines_bucket = set()
dev_in_train_exc_cnt = 0

for indev_out_line in in_dev_out_lines:
	key = get_first_half(indev_out_line)
	if key not in in_dev_lines_bucket:
		# Key was shared with the test input and reassigned to the test split.
		dev_seen_exc_cnt += 1
		continue
	label_str = indev_out_line.split('\t')[2]
	# Use `<`, not `<=`: the counters hold positives already written, so `<=`
	# would admit one positive beyond the requested quota.
	if train_actual_num_posis < args.train_num_posis:
		in_train_lines_bucket.add(key)
		train_ofp.write(indev_out_line)
		if label_str == 'True':
			train_actual_num_posis += 1
		else:
			assert label_str == 'False'
	elif dev_actual_num_posis < args.dev_num_posis:
		# Keep dev disjoint from train: skip keys already written to train.
		if key in in_train_lines_bucket:
			dev_in_train_exc_cnt += 1
			continue
		dev_ofp.write(indev_out_line)
		if label_str == 'True':
			dev_actual_num_posis += 1
		else:
			assert label_str == 'False'

# Write every test-file example whose key was kept for the test split and count
# its positives. Keys reassigned to dev are skipped and tallied instead.
for test_line in in_test_out_lines:
	key = get_first_half(test_line)
	if key not in in_test_lines_bucket:
		test_seen_exc_cnt += 1
		continue
	test_ofp.write(test_line)
	test_label = test_line.split('\t')[2]
	if test_label != 'True':
		assert test_label == 'False'
		continue
	test_actual_num_posis += 1

# Flush and close the three output files.
for ofp in (train_ofp, dev_ofp, test_ofp):
	ofp.close()

# Report per-split positive counts and how many lines each filter skipped.
report = (
	("train_actual_num_posis", train_actual_num_posis),
	("dev_actual_num_posis", dev_actual_num_posis),
	("test_actual_num_posis", test_actual_num_posis),
	("dev_in_train_exc_cnt", dev_in_train_exc_cnt),
	("Dev seen exc cnt", dev_seen_exc_cnt),
	("Test seen exc cnt", test_seen_exc_cnt),
)
for counter_name, counter_value in report:
	print(f"{counter_name}: {counter_value}")

print("Finished!")
