import csv
import json
from nltk.stem import WordNetLemmatizer
import ddparser
import sys
import time
import os
import random
import copy
sys.path.append('/Users/teddy/Files/Potential Corpus/WebHoses_Chinese_News_Articles/')
from extract import FineGrainedInfo, CoarseGrainedInfo
from dudepparse import post_processing, Token_Normalizer

# Placeholder (subject, predicate, object) name triples. The line-construction functions in this file
# substitute them when no usable binary relation is found for the premise / hypothesis respectively.
PREM_PLACEHOLDER_REL = ('[UNK][UNK]占位符主语·条件', '[UNK][UNK]占位符谓语·条件', '[UNK][UNK]占位符宾语·条件')
HYPO_PLACEHOLDER_REL = ('[UNK][UNK]占位符主语·假设', '[UNK][UNK]占位符谓语·假设', '[UNK][UNK]占位符宾语·假设')


def merge_and_unique(rel_sets):
	"""Flatten several relation lists into one list without duplicates.

	:param rel_sets: iterable of relation lists; ``None`` entries are skipped.
	:return: deep copies of the distinct relations, in order of first appearance.

	Two relations count as duplicates when their first three elements render to
	the same ``a::b::c`` string. The original author's example of an element is
	``[["重庆日报报业集团", "授权", "华龙网"], "SVO", [1, 2, 3]]``.
	"""
	first_seen = {}
	for group in rel_sets:
		if group is None:
			continue
		for candidate in group:
			signature = f"{candidate[0]}::{candidate[1]}::{candidate[2]}"
			# setdefault keeps the earliest relation for each signature
			first_seen.setdefault(signature, candidate)
	return [copy.deepcopy(kept) for kept in first_seen.values()]


def convert_tsv_to_txt(input_fn, output_fn, mapping_fn, trans_fn, num_splits):
	"""Turn a 3-column TSV of comma-separated triples into plain sentences, split over several files.

	:param input_fn: path of the tab-separated input; every row must have exactly 3 columns, of which the
		first two are comma-separated texts (premise, hypothesis).
	:param output_fn: ``%``-pattern for the sentence files; formatted with the split index.
	:param mapping_fn: path of the JSON file receiving ``{original comma text: global line id}``.
	:param trans_fn: ``%``-pattern for files that are created empty, one per split.
	:param num_splits: number of output slices.

	Each comma text is lower-cased, its second element is lemmatized as a verb, and the parts are joined
	with spaces and terminated with a period. Identical comma texts are written only once.
	"""
	wordnet_lemmatizer = WordNetLemmatizer()

	def _to_sentence(comma_text):
		parts = [x.strip().lower() for x in comma_text.split(',')]
		parts[1] = wordnet_lemmatizer.lemmatize(parts[1], 'v')
		return ' '.join(parts) + '.'

	txt_dict = {}
	# NOTE(review): the input is opened with the platform default encoding, unlike the other files here.
	with open(input_fn, 'r') as input_fp:
		for instance in csv.reader(input_fp, delimiter="\t"):
			assert len(instance) == 3
			cst_p = instance[0]  # comma splitted text for premise
			cst_h = instance[1]  # comma splitted text for hypothesis
			premise_svo = _to_sentence(cst_p)
			hypo_svo = _to_sentence(cst_h)
			txt_dict[cst_p] = premise_svo
			txt_dict[cst_h] = hypo_svo

	txt_mapping = {}
	# at least 1, otherwise fewer texts than splits would divide by zero below
	slice_size = max(1, len(txt_dict) // num_splits)
	output_fps = []
	try:
		for x in range(num_splits):
			output_fps.append(open(output_fn % x, 'w', encoding='utf8'))
		for global_id, key in enumerate(txt_dict):
			# the last slice absorbs the remainder
			slice_id = min(global_id // slice_size, num_splits - 1)
			output_fps[slice_id].write(txt_dict[key] + '\n')
			txt_mapping[key] = global_id
	finally:
		for fp in output_fps:
			fp.close()

	with open(mapping_fn, 'w', encoding='utf8') as fp:
		json.dump(txt_mapping, fp, ensure_ascii=False)
	# create (or truncate) one empty file per split
	for i in range(num_splits):
		with open(trans_fn % i, 'w', encoding='utf8'):
			pass


def merge_txt_to_dict(trans_fn, mapping_fn, num_splits):
	"""Map every key of the JSON mapping to its line in the concatenated split files.

	:param trans_fn: ``%``-pattern of the split files, formatted with 0..num_splits-1.
	:param mapping_fn: JSON file holding ``{key: global line index}``.
	:param num_splits: number of split files to read, in order.
	:return: dict ``{key: stripped line at that global index}``.
	"""
	translated = []
	for split_id in range(num_splits):
		with open(trans_fn % split_id, 'r', encoding='utf8') as split_fp:
			translated.extend(raw.strip() for raw in split_fp)

	with open(mapping_fn, 'r', encoding='utf8') as mapping_fp:
		key_to_lineno = json.load(mapping_fp)

	return {key: translated[lineno] for key, lineno in key_to_lineno.items()}


def build_triples_from_rels(rels):
	"""Build a plain and a typed-template triple string from the first relation in ``rels``.

	:param rels: list of dicts; only ``rels[0]["r"]`` is read. It must look like
		``((pred.1,pred.2)::subj::obj::...)`` with six ``::``-separated fields inside the outer parentheses.
	:return: ``(plain, template)`` where ``plain`` is ``"subject,predicate,object"`` with ``_`` / ``.``
		turned into spaces, and ``template`` is ``"<pred field> <subj>::%s <obj>::%s"`` with two ``%s``
		slots left unfilled.

	For an empty ``rels`` both strings are built from three random ``PLACEHOLDER_<n>`` names.
	"""
	if len(rels) == 0:
		# draw order (subject, predicate, object) is kept so seeded runs stay reproducible
		subj_rn, pred_rn, obj_rn = [random.randrange(0, 1000000000) for _ in range(3)]
		plain = f'PLACEHOLDER_{subj_rn},PLACEHOLDER_{pred_rn},PLACEHOLDER_{obj_rn}'
		template = f'(PLACEHOLDER_{pred_rn}.1,PLACEHOLDER_{pred_rn}.2) PLACEHOLDER_{subj_rn}::%s PLACEHOLDER_{obj_rn}::%s'
		return plain, template

	typed_rel = rels[0]["r"]
	assert typed_rel[0] == '(' and typed_rel[-1] == ')'
	fields = typed_rel[1:-1].split('::')
	assert len(fields) == 6
	pred_field, subj_field, obj_field = fields[0], fields[1], fields[2]

	pred_halves = pred_field.split(',')
	assert len(pred_halves) == 2
	second_half = pred_halves[1]
	assert second_half[-3:] in ['.1)', '.2)', '.3)']
	# drop the ".N)" suffix, then turn the dotted predicate into words
	predicate = second_half[:-3].replace('.', ' ')
	subject = subj_field.replace('_', ' ')
	obj = obj_field.replace('_', ' ')

	return f'{subject},{predicate},{obj}', f'{pred_field} {subj_field}::%s {obj_field}::%s'


def merge_backtranslation(backtrans_rels_fn, mapping_fn):
	"""Attach back-translated relation triples to every key of the JSON mapping.

	:param backtrans_rels_fn: JSON-lines file; each usable line is an object with a ``lineId`` and a
		``rels`` list.
	:param mapping_fn: JSON file holding ``{key: line id}``.
	:return: dict ``{key: (built_rel, structured_rel)}`` as produced by ``build_triples_from_rels``.
		Keys whose line id has no entry are built from an empty relation list.

	Lines that fail to parse, lack ``lineId`` or repeat an already seen id are reported on stdout and
	skipped (best effort).
	"""
	with open(mapping_fn, 'r', encoding='utf8') as fp:
		mapping = json.load(fp)

	lines = {}
	# context manager so the handle is released even if iteration fails
	with open(backtrans_rels_fn, 'r', encoding='utf8') as input_fp:
		for line in input_fp:
			try:
				item = json.loads(line)
				line_id = int(item["lineId"])
				# explicit check instead of `assert`, which disappears under `python -O`
				if line_id in lines:
					raise ValueError(f"duplicate lineId {line_id}")
				lines[line_id] = item
			except Exception as e:
				print(e, line)

	mapped = {}
	for key, target_lineid in mapping.items():
		rels = lines[target_lineid]["rels"] if target_lineid in lines else []
		built_rel, structured_rel = build_triples_from_rels(rels)
		mapped[key] = (built_rel, structured_rel)
	return mapped


def convert_text_to_entries(trans_fn, entries_fn, num_splits):
	"""Write one JSON entry per line of the split files into a single JSON-lines file.

	:param trans_fn: ``%``-pattern of the split files, formatted with 0..num_splits-1.
	:param entries_fn: output path; each line is ``{"splitted_text": [<stripped input line>]}``.
	:param num_splits: number of split files to read, in order.
	"""
	# `with` guarantees the output is flushed and closed even if a split file is missing
	with open(entries_fn, 'w', encoding='utf8') as entries_fp:
		for i in range(num_splits):
			with open(trans_fn % i, 'r', encoding='utf8') as fp:
				for line in fp:
					entry = {'splitted_text': [line.strip()]}
					entries_fp.write(json.dumps(entry, ensure_ascii=False) + '\n')


def merge_entries_to_dicts(entries_fn, mapping_fn):
	"""Map every key of the JSON mapping to the entry stored at its line index.

	:param entries_fn: JSON-lines file; line ``i`` is the entry with index ``i``.
	:param mapping_fn: JSON file holding ``{key: line index}``.
	:return: dict ``{key: parsed entry}``.
	"""
	with open(entries_fn, 'r', encoding='utf8') as entries_fp:
		parsed = [json.loads(raw) for raw in entries_fp]

	with open(mapping_fn, 'r', encoding='utf8') as mapping_fp:
		index_of = json.load(mapping_fp)

	return {key: parsed[pos] for key, pos in index_of.items()}


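# if there are no binary relations, return None; if there is exactly one binary relation, return it;
# if there are multiple binary relations, keep those whose rel[2][1] token id bears a HED/COO/VV/IC deprel:
# in 'best' mode return a random one of them, in 'all' mode return all of them; if none qualifies, return None
# together with the names of all binary relations as fallback candidates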
def find_binary_from_rels(rels, ddp_sent_res, mode='best'):
	"""Select binary relations (subject and object both present) from ``rels``.

	:param rels: relations indexed as ``rel[0]`` = name triple and ``rel[2]`` = token ids; ``rel[2][1]``
		is compared against head token ids (presumably the predicate token — confirm with the extractor).
	:param ddp_sent_res: parse result with parallel ``'word'`` and ``'deprel'`` sequences.
	:param mode: ``'best'`` returns a single name triple, ``'all'`` returns a list of name triples.
	:return: ``(selection, fallbacks)``. ``selection`` is ``None`` when nothing qualifies; ``fallbacks``
		is non-empty only when several binaries exist but none sits on a HED/COO/VV/IC token, in which
		case it holds the names of all binaries.
	:raises AssertionError: for an unknown ``mode`` (only checked when at least one binary exists).
	"""
	candidates = [rel for rel in rels if rel[0][0] is not None and rel[0][2] is not None]
	if not candidates:
		return None, []

	if len(candidates) == 1:
		only_names = candidates[0][0]  # the names (indexed 0) of the single binary relation
		if mode == 'best':
			return only_names, []
		if mode == 'all':
			return [only_names], []
		raise AssertionError

	# tokens bearing deprel of HED, COO, VV or IC, tid starting from 0
	head_tids = [tid for tid in range(len(ddp_sent_res['word']))
				 if ddp_sent_res['deprel'][tid] in ['HED', 'COO', 'VV', 'IC']]
	preferred = [rel for rel in candidates if rel[2][1] in head_tids]

	if mode not in ('best', 'all'):
		raise AssertionError
	if not preferred:
		return None, [rel[0] for rel in candidates]
	if mode == 'best':
		return random.choice(preferred)[0], []
	return [rel[0] for rel in preferred], []


def random_chunk(string):
	"""Cut a sentence into three non-empty consecutive pieces at randomly jittered positions.

	:param string: sentence; leading/trailing '。', '？', '！', '；' are stripped, one mark after the other.
		At least two characters must remain.
	:return: tuple ``(head, middle, tail)`` whose concatenation is the stripped sentence. A
		two-character sentence gets '是' inserted between its characters first.

	The two cut points are drawn around one third and two thirds of the length (offsets -2..+2) and
	then clamped so that every piece keeps at least one character.
	"""
	text = string
	for mark in ('。', '？', '！', '；'):
		text = text.strip(mark)
	assert len(text) > 1
	if len(text) == 2:
		text = '是'.join(text)

	offsets = (-2, -1, 0, 1, 2)
	jitter_weights = [0.1, 0.2, 0.4, 0.2, 0.1]
	third = len(text) // 3
	last_idx = len(text) - 1

	first_draw = random.choices([third + off for off in offsets], weights=jitter_weights, k=1)[0]
	first_cut = min(last_idx - 1, max(1, first_draw))

	second_draw = random.choices([third * 2 + off for off in offsets], weights=jitter_weights, k=1)[0]
	assert first_cut + 1 <= last_idx
	second_cut = min(last_idx, max(second_draw, first_cut + 1))

	res = (text[:first_cut], text[first_cut:second_cut], text[second_cut:])
	assert len(res[0]) > 0 and len(res[1]) > 0 and len(res[2]) > 0
	return res


# for each chunk of a predicate, mask its first instance in the sentence
def mask_predicate(sent, pred):
	"""Replace the characters of each predicate chunk with ``[MASK]`` tokens.

	:param sent: the sentence to mask.
	:param pred: predicate whose chunks are separated by '·'; the chunk 'X' is skipped.
	:return: ``sent`` with every character of the first occurrence of each chunk replaced by '[MASK]'.

	Chunks that do not occur in ``sent`` are ignored. ``str.find`` returns -1 for them, and without the
	guard the range starting at -1 would mask the first ``len(chunk) - 1`` characters of the sentence.
	"""
	masked_idxs = set()
	for p in pred.split('·'):
		if p == 'X':
			continue
		p_span_start = sent.find(p)
		if p_span_start < 0:
			continue
		masked_idxs.update(range(p_span_start, p_span_start + len(p)))
	return ''.join('[MASK]' if idx in masked_idxs else tok for idx, tok in enumerate(sent))


def construct_binary_csv_line_from_sent_and_rels(trans_premise, trans_hypo, rels_bucket, ddp_res, label, add_crossed):
	"""Pick one binary relation for the premise and one for the hypothesis and format them as a line.

	:param trans_premise: premise sentence.
	:param trans_hypo: hypothesis sentence.
	:param rels_bucket: ``{'premise': {...}, 'hypo': {...}}``; each side holds the relation lists
		``fine``, ``fine_amend``, ``coarse``, ``coarse_amend``, ``crossed`` and ``crossed_amend``.
	:param ddp_res: ``[ddp_premise, ddp_hypo]`` parse results handed to ``find_binary_from_rels``.
	:param label: value written into the third column.
	:param add_crossed: when true, crossed relations are searched together with the coarse ones.
	:return: ``(line, raw_line, match_flag)``. ``line`` is
		``"<premise triple>\\t<hypothesis triple>\\t<label>"``, ``raw_line`` holds both sentences and
		their predicate-masked versions, and ``match_flag`` is False when either side fell back to its
		placeholder triple.

	NOTE(review): the fallback order differs between the two sides. The premise tries every "better"
	tier (fine amend, fine, coarse amend, coarse) before any "worse" candidates; the hypothesis tries
	the "worse" fine candidates before the coarse tiers. Left as is — confirm which order is intended.
	"""
	match_flag = True  # whether or not a usable relation is found in the relation extraction pipeline
	[ddp_premise, ddp_hypo] = ddp_res

	fine_prem_rels = rels_bucket['premise']['fine']
	fine_prem_rels_amend = rels_bucket['premise']['fine_amend']
	coarse_prem_rels = rels_bucket['premise']['coarse']
	coarse_prem_rels_amend = rels_bucket['premise']['coarse_amend']
	crossed_prem_rels = rels_bucket['premise']['crossed']
	crossed_prem_rels_amend = rels_bucket['premise']['crossed_amend']
	if add_crossed:
		# build new lists; `+=` would extend the lists owned by the caller's rels_bucket in place
		coarse_prem_rels = coarse_prem_rels + crossed_prem_rels
		coarse_prem_rels_amend = coarse_prem_rels_amend + crossed_prem_rels_amend

	premise_best_rel, premise_part_rels_amend_fine = find_binary_from_rels(fine_prem_rels_amend, ddp_premise)
	if premise_best_rel is None:
		premise_best_rel, premise_part_rels_fine = find_binary_from_rels(fine_prem_rels, ddp_premise)

	if premise_best_rel is None:
		premise_best_rel, premise_part_rels_amend_coarse = find_binary_from_rels(coarse_prem_rels_amend, ddp_premise)
	if premise_best_rel is None:
		premise_best_rel, premise_part_rels_coarse = find_binary_from_rels(coarse_prem_rels, ddp_premise)

	# reaching the fallbacks with None means all four searches above ran, so the *_part_* lists exist
	if premise_best_rel is None:
		worse_rels_fine = premise_part_rels_amend_fine + premise_part_rels_fine
		if len(worse_rels_fine) > 0:
			premise_best_rel = random.choice(worse_rels_fine)

	if premise_best_rel is None:
		worse_rels_coarse = premise_part_rels_amend_coarse + premise_part_rels_coarse
		if len(worse_rels_coarse) > 0:
			premise_best_rel = random.choice(worse_rels_coarse)

	if premise_best_rel is None:
		# premise_best_rel = random_chunk(trans_premise)
		premise_best_rel = PREM_PLACEHOLDER_REL
		match_flag = False

	fine_hypo_rels = rels_bucket['hypo']['fine']
	fine_hypo_rels_amend = rels_bucket['hypo']['fine_amend']
	coarse_hypo_rels = rels_bucket['hypo']['coarse']
	coarse_hypo_rels_amend = rels_bucket['hypo']['coarse_amend']
	crossed_hypo_rels = rels_bucket['hypo']['crossed']
	crossed_hypo_rels_amend = rels_bucket['hypo']['crossed_amend']
	if add_crossed:
		coarse_hypo_rels = coarse_hypo_rels + crossed_hypo_rels
		coarse_hypo_rels_amend = coarse_hypo_rels_amend + crossed_hypo_rels_amend

	hypo_best_rel, hypo_part_rels_amend_fine = find_binary_from_rels(fine_hypo_rels_amend, ddp_hypo)
	if hypo_best_rel is None:
		hypo_best_rel, hypo_part_rels_fine = find_binary_from_rels(fine_hypo_rels, ddp_hypo)
	if hypo_best_rel is None:
		worse_rels_fine = hypo_part_rels_amend_fine + hypo_part_rels_fine
		if len(worse_rels_fine) > 0:
			hypo_best_rel = random.choice(worse_rels_fine)

	if hypo_best_rel is None:
		hypo_best_rel, hypo_part_rels_amend_coarse = find_binary_from_rels(coarse_hypo_rels_amend, ddp_hypo)
	if hypo_best_rel is None:
		hypo_best_rel, hypo_part_rels_coarse = find_binary_from_rels(coarse_hypo_rels, ddp_hypo)
	if hypo_best_rel is None:
		worse_rels_coarse = hypo_part_rels_amend_coarse + hypo_part_rels_coarse
		if len(worse_rels_coarse) > 0:
			hypo_best_rel = random.choice(worse_rels_coarse)

	if hypo_best_rel is None:
		# hypo_best_rel = random_chunk(trans_hypo)
		hypo_best_rel = HYPO_PLACEHOLDER_REL
		match_flag = False

	assert premise_best_rel is not None and hypo_best_rel is not None
	premise_masked_sent = mask_predicate(trans_premise, premise_best_rel[1])
	hypo_masked_sent = mask_predicate(trans_hypo, hypo_best_rel[1])

	# subject, predicate and object must all be non-empty
	for item in premise_best_rel:
		assert len(item) > 0
	for item in hypo_best_rel:
		assert len(item) > 0

	line = f"{', '.join(premise_best_rel)}\t{', '.join(hypo_best_rel)}\t{label}"
	raw_line = f"{trans_premise}\t{trans_hypo}\t{premise_masked_sent}\t{hypo_masked_sent}"
	return line, raw_line, match_flag


# Include all "better" fine-grained relations, if none found, go for "better" coarse-grained relations, then "worse"
# fine-grained relations, then "worse" coarse-grained relations
def exhaust_binary_csv_lines_from_sent_and_rels(trans_premise, trans_hypo, rels_bucket, ddp_res, label, add_crossed):
	"""Build one line for every (premise relation, hypothesis relation) pair of the best available tier.

	:param trans_premise: premise sentence.
	:param trans_hypo: hypothesis sentence.
	:param rels_bucket: ``{'premise': {...}, 'hypo': {...}}``; each side holds the relation lists
		``fine``, ``fine_amend``, ``coarse``, ``coarse_amend``, ``crossed`` and ``crossed_amend``.
	:param ddp_res: ``[ddp_premise, ddp_hypo]`` parse results handed to ``find_binary_from_rels``.
	:param label: value written into the third column of every line.
	:param add_crossed: when true, crossed relations are searched together with the coarse ones.
	:return: ``(lines, raw_lines, counts_per_tier, match_found)``. ``counts_per_tier`` has 7 slots and
		is incremented once for the premise and once for the hypothesis at the tier that supplied
		their relations (0 fine-amend, 1 fine, 2 coarse-amend, 3 coarse, 4 worse fine, 5 worse coarse,
		6 placeholder). ``match_found`` is False when either side ended up with its placeholder.

	Both sides share one helper. This fixes the hypothesis side, which used the "worse" coarse-amend
	list where the "better" one was meant and tested tier 4 with the wrong polarity when counting
	placeholders.
	"""
	[ddp_premise, ddp_hypo] = ddp_res

	premise_best_rels, premise_tier = _exhaust_side_binaries(rels_bucket['premise'], ddp_premise,
															 PREM_PLACEHOLDER_REL, add_crossed)
	hypo_best_rels, hypo_tier = _exhaust_side_binaries(rels_bucket['hypo'], ddp_hypo,
													   HYPO_PLACEHOLDER_REL, add_crossed)

	counts_per_tier = [0, 0, 0, 0, 0, 0, 0]
	counts_per_tier[premise_tier] += 1
	counts_per_tier[hypo_tier] += 1
	match_found = premise_tier != 6 and hypo_tier != 6  # tier 6 means the placeholder was used

	assert len(premise_best_rels) > 0 and len(hypo_best_rels) > 0
	premise_masked_sents = [mask_predicate(trans_premise, prem_bin[1]) for prem_bin in premise_best_rels]
	hypo_masked_sents = [mask_predicate(trans_hypo, hypo_bin[1]) for hypo_bin in hypo_best_rels]

	lines = []
	raw_lines = []
	for prem_r, prem_msk in zip(premise_best_rels, premise_masked_sents):
		for hypo_r, hypo_msk in zip(hypo_best_rels, hypo_masked_sents):
			# subject, predicate and object must all be non-empty
			for item in prem_r:
				assert len(item) > 0
			for item in hypo_r:
				assert len(item) > 0
			lines.append(f"{', '.join(prem_r)}\t{', '.join(hypo_r)}\t{label}")
			raw_lines.append(f"{trans_premise}\t{trans_hypo}\t{prem_msk}\t{hypo_msk}")

	return lines, raw_lines, counts_per_tier, match_found


def _exhaust_side_binaries(side_rels, ddp_sent, placeholder_rel, add_crossed):
	"""Return ``(relations, tier_id)`` for one sentence: the first non-empty tier wins."""
	fine_rels = side_rels['fine']
	fine_rels_amend = side_rels['fine_amend']
	coarse_rels = side_rels['coarse']
	coarse_rels_amend = side_rels['coarse_amend']
	crossed_rels = side_rels['crossed']
	crossed_rels_amend = side_rels['crossed_amend']
	if add_crossed:
		# build new lists; `+=` would extend the lists owned by the caller's rels_bucket in place
		coarse_rels = coarse_rels + crossed_rels
		coarse_rels_amend = coarse_rels_amend + crossed_rels_amend

	# tiers 0-3: "better" binaries, searched lazily in order of preference
	worse_per_tier = []
	for tier_id, tier_rels in enumerate([fine_rels_amend, fine_rels, coarse_rels_amend, coarse_rels]):
		better, worse = find_binary_from_rels(tier_rels, ddp_sent, mode='all')
		worse_per_tier.append(worse)
		uniq_better = merge_and_unique([better])
		if len(uniq_better) > 0:
			return uniq_better, tier_id

	# tier 4: "worse" fine-grained binaries (amend first, then original)
	worse_rels_fine = worse_per_tier[0] + worse_per_tier[1]
	if len(worse_rels_fine) > 0:
		return worse_rels_fine, 4

	# tier 5: "worse" coarse-grained binaries (amend first, then original)
	worse_rels_coarse = worse_per_tier[2] + worse_per_tier[3]
	if len(worse_rels_coarse) > 0:
		return worse_rels_coarse, 5

	# tier 6: nothing usable was extracted
	return [placeholder_rel], 6


# In Levy/Holt's dataset, there are 62 sentences with multiple rels for Dev set, and 151 for test set
def construct_binary_csv_line_from_sent_and_rels_jia(trans_premise, trans_hypo, rels_bucket, label, exhaust):
	"""Format premise/hypothesis relation lines from pre-extracted fine-grained relations.

	:param trans_premise: premise sentence.
	:param trans_hypo: hypothesis sentence.
	:param rels_bucket: ``{'premise': {...}, 'hypo': {...}}``; for each side only element ``[0]`` of
		``fine``, ``fine_amend``, ``coarse`` and ``coarse_amend`` is read, and all but ``fine`` must be
		empty.
	:param label: value written into the third column.
	:param exhaust: when true, every (premise relation, hypothesis relation) pair is emitted; otherwise
		one relation per side is drawn at random.
	:return: ``(lines, raw_lines, match_flag)`` with two lists when ``exhaust`` is true, otherwise
		``(line, raw_line, match_flag)`` with two strings. ``match_flag`` is False when either side
		had no complete relation and its placeholder triple was used.
	"""
	fine_prem_rels = rels_bucket['premise']['fine'][0]
	fine_prem_rels_amend = rels_bucket['premise']['fine_amend'][0]
	coarse_prem_rels = rels_bucket['premise']['coarse'][0]
	coarse_prem_rels_amend = rels_bucket['premise']['coarse_amend'][0]
	assert len(fine_prem_rels_amend) == 0 and len(coarse_prem_rels) == 0 and len(coarse_prem_rels_amend) == 0
	premise_best_rel, premise_found = _jia_pick_binaries(fine_prem_rels, PREM_PLACEHOLDER_REL, exhaust)

	fine_hypo_rels = rels_bucket['hypo']['fine'][0]
	fine_hypo_rels_amend = rels_bucket['hypo']['fine_amend'][0]
	coarse_hypo_rels = rels_bucket['hypo']['coarse'][0]
	coarse_hypo_rels_amend = rels_bucket['hypo']['coarse_amend'][0]
	assert len(fine_hypo_rels_amend) == 0 and len(coarse_hypo_rels) == 0 and len(coarse_hypo_rels_amend) == 0
	hypo_best_rel, hypo_found = _jia_pick_binaries(fine_hypo_rels, HYPO_PLACEHOLDER_REL, exhaust)

	# whether or not a usable relation is found on both sides
	match_flag = premise_found and hypo_found

	assert premise_best_rel is not None and hypo_best_rel is not None
	if exhaust:
		out_lines = []
		out_raw_lines = []
		for p_rel in premise_best_rel:
			for h_rel in hypo_best_rel:
				premise_masked_sent = mask_predicate(trans_premise, p_rel[1])
				hypo_masked_sent = mask_predicate(trans_hypo, h_rel[1])

				# assert that subject, object and verb are all non-empty
				for wrd in p_rel:
					assert len(wrd) > 0
				for wrd in h_rel:
					assert len(wrd) > 0

				out_lines.append(f"{', '.join(p_rel)}\t{', '.join(h_rel)}\t{label}")
				out_raw_lines.append(f"{trans_premise}\t{trans_hypo}\t{premise_masked_sent}\t{hypo_masked_sent}")

		return out_lines, out_raw_lines, match_flag

	premise_masked_sent = mask_predicate(trans_premise, premise_best_rel[1])
	hypo_masked_sent = mask_predicate(trans_hypo, hypo_best_rel[1])

	# assert that subject, object and verb are all non-empty
	for item in premise_best_rel:
		assert len(item) > 0
	for item in hypo_best_rel:
		assert len(item) > 0

	line = f"{', '.join(premise_best_rel)}\t{', '.join(hypo_best_rel)}\t{label}"
	raw_line = f"{trans_premise}\t{trans_hypo}\t{premise_masked_sent}\t{hypo_masked_sent}"
	return line, raw_line, match_flag


def _jia_pick_binaries(rels, placeholder_rel, exhaust):
	"""Return ``(selection, found)``: complete 3-element relations, or the placeholder if none exist."""
	# length is checked first so that shorter relations are skipped instead of raising IndexError
	binaries = [rel for rel in rels
				if len(rel) == 3 and rel[0] is not None and rel[1] is not None and rel[2] is not None]
	if len(binaries) > 0:
		return (binaries if exhaust else random.choice(binaries)), True
	return ([placeholder_rel] if exhaust else placeholder_rel), False


def contruct_translated_doc_for_postag(mapped, input_fn, out_fn, nosame_out_fn):
	"""Write the translated premise/hypothesis pairs of a TSV as JSON-lines documents.

	:param mapped: dict from the original premise/hypothesis text to its translation.
	:param input_fn: tab-separated input; every row must have exactly 3 columns
		(premise, hypothesis, value).
	:param out_fn: output receiving one ``{"splitted_text": [premise, hypothesis], "value": ...}``
		object per input row.
	:param nosame_out_fn: output receiving only the rows whose two translations differ.

	Progress is printed every 1000 rows. The original never advanced its row counter, so that print
	could not fire.
	"""
	with open(input_fn, 'r', encoding='utf8') as input_fp, \
			open(out_fn, 'w', encoding='utf8') as out_fp, \
			open(nosame_out_fn, 'w', encoding='utf8') as nosame_out_fp:
		for iid, instance in enumerate(csv.reader(input_fp, delimiter="\t")):
			if iid % 1000 == 0 and iid > 0:
				print(iid)
			assert len(instance) == 3
			trans_premise = mapped[instance[0]]
			trans_hypo = mapped[instance[1]]
			doc = {'splitted_text': [trans_premise, trans_hypo], 'value': instance[2]}
			out_line = json.dumps(doc, ensure_ascii=False)
			out_fp.write(out_line + '\n')
			if trans_premise != trans_hypo:
				nosame_out_fp.write(out_line + '\n')


def construct_tsv_backtranslation(mapped, input_fn, out_fn, outrel_fn, refrel_fn):
	"""Write back-translated triples and their typed relations, aligned row by row with the input TSV.

	:param mapped: dict from the original premise/hypothesis text to ``(plain triple, relation template)``;
		the template carries two ``%s`` slots for the subject and object types.
	:param input_fn: tab-separated input; every row must have exactly 3 columns.
	:param out_fn: output receiving ``premise triple \\t hypothesis triple \\t value`` per row.
	:param outrel_fn: output receiving the two filled-in relation templates and the value per row.
	:param refrel_fn: reference relations, one line of 3 tab-separated fields per input row; the types
		are read from the ``::`` suffix of the 2nd and 3rd space-separated tokens of the first two fields.

	When a reference relation lacks the expected tokens, both of its types fall back to 'thing' and
	the problem is reported on stdout.
	"""
	ref_rels = []
	# read inside `with`: the original never closed this handle
	with open(refrel_fn, 'r', encoding='utf8') as refrel_fp:
		for line in refrel_fp:
			prr, hrr, _ = line.split('\t')
			ref_rels.append((prr.split(' '), hrr.split(' ')))

	with open(input_fn, 'r', encoding='utf8') as input_fp, \
			open(out_fn, 'w', encoding='utf8') as out_fp, \
			open(outrel_fn, 'w', encoding='utf8') as outrel_fp:
		for iid, instance in enumerate(csv.reader(input_fp, delimiter="\t")):
			if iid % 1000 == 0 and iid > 0:
				print(iid)
			assert len(instance) == 3
			backtransed_premise, btp_rel = mapped[instance[0]]
			backtransed_hypo, bth_rel = mapped[instance[1]]
			out_fp.write(backtransed_premise + '\t' + backtransed_hypo + '\t' + instance[2] + '\n')

			prr, hrr = ref_rels[iid]
			try:
				pst = prr[1].split('::')[1]  # premise subject type
				pot = prr[2].split('::')[1]  # premise object type
			except IndexError as e:
				print(e)
				print("P Error!")
				pst = 'thing'
				pot = 'thing'
			try:
				hst = hrr[1].split('::')[1]  # hypothesis subject type
				hot = hrr[2].split('::')[1]  # hypothesis object type
			except IndexError as e:
				print(e)
				print("H Error!")
				hst = 'thing'
				hot = 'thing'
			outrel_line = f'{btp_rel%(pst, pot)}\t{bth_rel%(hst, hot)}\t{instance[2]}'
			outrel_fp.write(outrel_line + '\n')


def construct_translated_tsv(input_postagged_fn, out_fn, raw_out_fn, rel_levy_mapping_fn, amend, fine_only, exhaust=False, add_crossed=False):
	"""Parse translated sentence pairs, extract relations and write the relation TSV files.

	:param input_postagged_fn: JSON-lines input; each document provides ``splitted_text``
		(premise, hypothesis), ``value`` and ``corenlp_pos_tags``.
	:param out_fn: output of relation lines (premise triple, hypothesis triple, value).
	:param raw_out_fn: output of raw lines (both sentences and their predicate-masked versions).
	:param rel_levy_mapping_fn: output with, for every written line, the index of its input document.
	:param amend: when true, the amended relation sets are made available to the selection.
	:param fine_only: when true, coarse and crossed relation sets are withheld from the selection.
	:param exhaust: when true, all relation pairs of the best tier are written per document instead
		of a single one.
	:param add_crossed: forwarded to the line-construction functions.

	Statistics are printed at the end. In exhaust mode the per-tier match counts are accumulated over
	all 7 tiers; the original summed only the first 5, so the last two slots always printed 0.
	"""
	iid = 0
	matched_instances_cnt = 0
	total_matches_per_tier = [0, 0, 0, 0, 0, 0, 0]
	num_lines_bucket = {}
	with open(input_postagged_fn, 'r', encoding='utf8') as input_fp, \
			open(raw_out_fn, 'w', encoding='utf8') as raw_out_fp, \
			open(out_fn, 'w', encoding='utf8') as out_fp, \
			open(rel_levy_mapping_fn, 'w', encoding='utf8') as rel_levy_mapping_fp:
		ddp = ddparser.DDParser(encoding_model='transformer')
		st = time.time()
		for line in input_fp:
			if iid % 500 == 0 and iid > 0:
				elapsed = int(time.time() - st)
				print(iid, 'time lapsed: %d hours %d minutes %d seconds' % (elapsed // 3600, (elapsed % 3600) // 60, elapsed % 60))
			doc = json.loads(line)
			trans_premise = doc['splitted_text'][0]
			trans_hypo = doc['splitted_text'][1]
			truth_value = doc['value']

			# NOTE(review): looks like a leftover debugging hook for one specific sentence
			if trans_premise == '布什总统被比尔·克林顿击败。':
				print("!")
			elif trans_hypo == '布什总统被比尔·克林顿击败。':
				print("!")

			ddp_premise, ddp_hypo = ddp.parse([trans_premise, trans_hypo])
			token_normalizer = Token_Normalizer(remove_from_args=False)

			coarse_info_premise = CoarseGrainedInfo(ddp_premise)
			fine_info_premise = FineGrainedInfo(ddp_premise)
			coarse_premise_rels = coarse_info_premise.parse()
			fine_premise_rels = fine_info_premise.parse()

			coarse_info_hypo = CoarseGrainedInfo(ddp_hypo)
			fine_info_hypo = FineGrainedInfo(ddp_hypo)
			coarse_hypo_rels = coarse_info_hypo.parse()
			fine_hypo_rels = fine_info_hypo.parse()

			# below the "possible_rels" are modification structures and are not true propositions for the
			# whole sentences. Thus, they can be ignored.
			(
				[fine_premise_rels, fine_hypo_rels],
				[coarse_premise_rels, coarse_hypo_rels],
				[fine_premise_rels_amend, fine_hypo_rels_amend],
				[coarse_premise_rels_amend, coarse_hypo_rels_amend],
				[crossed_premise_rels, crossed_hypo_rels],
				[crossed_premise_rels_amend, crossed_hypo_rels_amend],
				[possible_rels_premise, possible_rels_hypo],
				_, _, _, _, _, _,
			) = post_processing([fine_premise_rels, fine_hypo_rels], [coarse_premise_rels, coarse_hypo_rels],
								[ddp_premise, ddp_hypo], doc['corenlp_pos_tags'], token_normalizer, [coarse_info_premise, coarse_info_hypo],
								vcmp_bucket=None, fine_stop_word_count_bucket=None, fine_digit_excluded_count=0,
								MUST_INCLUDE_CHINESE_flag=False, coarse_stop_word_count_bucket=None,
								coarse_digit_excluded_count=0, KEEP_ONLY_SVO=True, DEBUG=False)

			# fine relations are always offered; amended sets need `amend`, coarse/crossed sets need
			# `not fine_only`, and their amended variants need both
			use_coarse = not fine_only
			use_coarse_amend = use_coarse and amend
			rels_bucket = {
				'premise': {
					'fine': fine_premise_rels,
					'fine_amend': fine_premise_rels_amend if amend else [],
					'coarse': coarse_premise_rels if use_coarse else [],
					'coarse_amend': coarse_premise_rels_amend if use_coarse_amend else [],
					'crossed': crossed_premise_rels if use_coarse else [],
					'crossed_amend': crossed_premise_rels_amend if use_coarse_amend else [],
				},
				'hypo': {
					'fine': fine_hypo_rels,
					'fine_amend': fine_hypo_rels_amend if amend else [],
					'coarse': coarse_hypo_rels if use_coarse else [],
					'coarse_amend': coarse_hypo_rels_amend if use_coarse_amend else [],
					'crossed': crossed_hypo_rels if use_coarse else [],
					'crossed_amend': crossed_hypo_rels_amend if use_coarse_amend else [],
				},
			}

			if not exhaust:
				instance_line, raw_line, match_flag = construct_binary_csv_line_from_sent_and_rels(
					trans_premise, trans_hypo, rels_bucket, [ddp_premise, ddp_hypo], truth_value, add_crossed=add_crossed)
				p, h, v = instance_line.split('\t')
				for p_c in p.split(','):
					assert len(p_c.strip()) > 0
				for h_c in h.split(','):
					assert len(h_c.strip()) > 0
				out_fp.write(instance_line + '\n')
				raw_out_fp.write(raw_line + '\n')  # stores raw output into corresponding file
				rel_levy_mapping_fp.write(str(iid) + '\n')
				if match_flag:
					matched_instances_cnt += 1
			else:  # if exhaust
				instance_lines, raw_lines, cur_matches_per_tier, match_found = exhaust_binary_csv_lines_from_sent_and_rels(
					trans_premise, trans_hypo, rels_bucket, [ddp_premise, ddp_hypo], truth_value, add_crossed=add_crossed)
				assert len(instance_lines) == len(raw_lines)
				num_lines_bucket[len(instance_lines)] = num_lines_bucket.get(len(instance_lines), 0) + 1
				# accumulate every tier, including "worse coarse" and "placeholder"
				for tier_id, tier_cnt in enumerate(cur_matches_per_tier):
					total_matches_per_tier[tier_id] += tier_cnt
				for ins_line, raw_line in zip(instance_lines, raw_lines):
					out_fp.write(ins_line + '\n')
					raw_out_fp.write(raw_line + '\n')
					rel_levy_mapping_fp.write(str(iid) + '\n')
				if match_found:
					matched_instances_cnt += 1
			iid += 1

	print(f"Out of the {iid} instances in dataset, {matched_instances_cnt} are successfully translated and parsed!")
	print(f"Total matches per tier: {total_matches_per_tier}")
	num_lines_bucket = {k: v for k, v in sorted(num_lines_bucket.items(), key=lambda item: item[1], reverse=True)}
	print("Number of lines distribution: ")
	print(num_lines_bucket)


def construct_translated_tsv_jia(mapped_entries, input_fn, out_fn, raw_out_fn, rel_levy_mapping_fn, exhaust):
	"""Write relation TSV files for a dataset whose relations were extracted beforehand.

	:param mapped_entries: dict from the original premise/hypothesis text to its entry; an entry
		provides ``splitted_text``, ``fine_rels``, ``amend_fine_rels``, ``coarse_rels`` and
		``amend_coarse_rels``.
	:param input_fn: tab-separated input; every row must have exactly 3 columns
		(premise, hypothesis, value).
	:param out_fn: output of relation lines (premise triple, hypothesis triple, value).
	:param raw_out_fn: output of raw lines (both sentences and their predicate-masked versions).
	:param rel_levy_mapping_fn: output with, for every written line, the index of its input row.
	:param exhaust: when true, all relation pairs are written per row instead of a single one.

	All files are handled by one ``with`` statement so they are closed even when a row fails.
	"""
	iid = 0
	matched_instances_cnt = 0
	with open(input_fn, 'r', encoding='utf8') as input_fp, \
			open(out_fn, 'w', encoding='utf8') as out_fp, \
			open(raw_out_fn, 'w', encoding='utf8') as raw_out_fp, \
			open(rel_levy_mapping_fn, 'w', encoding='utf8') as rel_levy_mapping_fp:
		for instance in csv.reader(input_fp, delimiter="\t"):
			if iid % 1000 == 0 and iid > 0:
				print(iid)
			assert len(instance) == 3
			trans_premise_entry = mapped_entries[instance[0]]
			trans_hypo_entry = mapped_entries[instance[1]]
			truth_value = instance[2]
			trans_premise = trans_premise_entry['splitted_text'][0]
			trans_hypo = trans_hypo_entry['splitted_text'][0]

			rels_bucket = {
				'premise': {'fine': trans_premise_entry['fine_rels'], 'fine_amend': trans_premise_entry['amend_fine_rels'],
							'coarse': trans_premise_entry['coarse_rels'], 'coarse_amend': trans_premise_entry['amend_coarse_rels']},
				'hypo': {'fine': trans_hypo_entry['fine_rels'], 'fine_amend': trans_hypo_entry['amend_fine_rels'],
						 'coarse': trans_hypo_entry['coarse_rels'], 'coarse_amend': trans_hypo_entry['amend_coarse_rels']}}

			instance_line, raw_line, match_flag = construct_binary_csv_line_from_sent_and_rels_jia(
				trans_premise, trans_hypo, rels_bucket, truth_value, exhaust)

			# in exhaust mode both results are lists; otherwise wrap the single strings
			if exhaust:
				line_pairs = zip(instance_line, raw_line)
			else:
				line_pairs = [(instance_line, raw_line)]
			for i_line, r_line in line_pairs:
				p, h, v = i_line.split('\t')
				for p_c in p.split(','):
					assert len(p_c.strip()) > 0
				for h_c in h.split(','):
					assert len(h_c.strip()) > 0
				out_fp.write(i_line + '\n')
				raw_out_fp.write(r_line + '\n')  # stores raw output into corresponding file
				rel_levy_mapping_fp.write(str(iid) + '\n')
			if match_flag:
				matched_instances_cnt += 1
			iid += 1
	print(f"Out of the {iid} instances in dataset, {matched_instances_cnt} are successfully translated and parsed!")


def split_eval_to_devtest(trans_fn, raw_fn, rels_fn, rellevy_mapping_fn, trans_ofn, raw_ofn, rels_ofn, rellevy_mapping_ofn, shuffle=False, split_mapping_fn=None):
	"""Split four line-aligned evaluation files into a 'dev' half and a 'test' half.

	Line ``i`` of every input file is treated as belonging to the same entry, and
	each entry is copied to either the dev or the test version of every output file.

	:param trans_fn: input file; the last tab-separated field of each line must be
		the string ``'True'`` or ``'False'`` (used for the per-split statistics).
	:param raw_fn: input file, line-aligned with ``trans_fn``.
	:param rels_fn: input file, line-aligned with ``trans_fn``.
	:param rellevy_mapping_fn: input file, line-aligned with ``trans_fn``.
	:param trans_ofn: output path template containing one ``%s``, which is filled
		with ``'dev'`` and ``'test'``; same for the three templates below.
	:param raw_ofn: output path template for the ``raw_fn`` lines.
	:param rels_ofn: output path template for the ``rels_fn`` lines.
	:param rellevy_mapping_ofn: output path template for the ``rellevy_mapping_fn`` lines.
	:param shuffle: if True, use the split stored in ``split_mapping_fn`` when that
		file exists, otherwise draw ``length // 2`` random dev indices. If False, the
		first ``length // 2`` entries go to dev and the rest to test (any stored
		split is ignored).
	:param split_mapping_fn: optional path of a JSON file holding
		``{"dev_idxs": [...], "test_idxs": [...]}``. It is read if it exists and
		written (with the split actually used) if it does not. ``None`` disables
		both loading and saving.
	:raises AssertionError: if an entry index is in neither split, or if a truth
		value is neither ``'True'`` nor ``'False'``.
	"""
	from contextlib import ExitStack  # function-scope import keeps the block self-contained

	# Re-use a previously saved split when one is available. The path is optional,
	# so guard against None before touching the filesystem.
	mapping = None
	mapping_preexists_flag = False
	if split_mapping_fn is not None and os.path.isfile(split_mapping_fn):
		with open(split_mapping_fn, 'r', encoding='utf8') as fp:
			mapping = json.load(fp)
		mapping_preexists_flag = True

	# First pass: count entries. zip() stops at the shortest file, so `length` is
	# the number of lines the three files have in common.
	with ExitStack() as stack:
		count_fps = [stack.enter_context(open(fn, 'r', encoding='utf8')) for fn in (trans_fn, raw_fn, rels_fn)]
		print(f"Reading from: {trans_fn}; {raw_fn}; {rels_fn}!")
		length = sum(1 for _ in zip(*count_fps))

	print(f"Length aligned! Number of entries: {length}!")

	truth_value_dict = {"dev": {"True": 0, "False": 0}, "test": {"True": 0, "False": 0}}

	# Second pass: every handle is registered on the ExitStack so that all of them
	# (including the rellevy mapping input) are closed even if the loop raises.
	with ExitStack() as stack:
		def _open(path, mode):
			return stack.enter_context(open(path, mode, encoding='utf8'))

		in_fps = [_open(fn, 'r') for fn in (trans_fn, raw_fn, rels_fn, rellevy_mapping_fn)]
		# Output handles are kept in the same order as `in_fps` so they can be zipped.
		out_fps = {
			split: [_open(ofn % split, 'w') for ofn in (trans_ofn, raw_ofn, rels_ofn, rellevy_mapping_ofn)]
			for split in ('dev', 'test')
		}

		print(f"Saving to: {trans_ofn}; {raw_ofn}; {rels_ofn}!")

		if shuffle:
			if mapping is None:
				dev_idxs = random.sample(range(length), k=length // 2)
				dev_lookup = set(dev_idxs)
				test_idxs = [idx for idx in range(length) if idx not in dev_lookup]
				mapping = {"dev_idxs": dev_idxs, "test_idxs": test_idxs}
			print("Shuffle before splitting!")
		else:
			mapping = {"dev_idxs": list(range(length // 2)), "test_idxs": list(range(length // 2, length))}
			print("NO shuffle before splitting!")

		# Sets give O(1) membership tests; the lists stay in `mapping` for JSON output.
		dev_set = set(mapping["dev_idxs"])
		test_set = set(mapping["test_idxs"])

		for idx, lines in enumerate(zip(*in_fps)):
			if idx in dev_set:
				split = 'dev'
			elif idx in test_set:
				split = 'test'
			else:
				raise AssertionError(f"Entry {idx} is in neither the dev nor the test split!")

			# Normalise the line ending so a final line without '\n' is still terminated.
			for out_fp, line in zip(out_fps[split], lines):
				out_fp.write(line.strip('\n') + '\n')

			truth_value = lines[0].strip().split('\t')[-1]
			if truth_value not in truth_value_dict[split]:
				raise AssertionError(f"Unexpected truth value {truth_value!r} in entry {idx}!")
			truth_value_dict[split][truth_value] += 1

	# Persist the split only when a path was given and no split was stored there before.
	if split_mapping_fn is not None and not mapping_preexists_flag:
		with open(split_mapping_fn, 'w', encoding='utf8') as fp:
			json.dump(mapping, fp, ensure_ascii=False)

	print("Truth value dict: ")
	print(truth_value_dict)

	print("Finished!")

