import csv
import json
import random
from difflib import SequenceMatcher
from utils.data_preperation_utils import construct_typing_instance, format_cfet_instance_for_hier, EvalPruner


def fetch_blurry_span(sent, mention, rev=False):
	"""Locate `mention` inside `sent`, tolerating a mention that was altered.

	Lookup proceeds in three stages:

	1. Exact match: the first occurrence (`rev=False`) or the last occurrence
	   (`rev=True`) of `mention` in `sent`.
	2. Fuzzy match: for every split point of `mention`, find the prefix
	   (first/last occurrence according to `rev`) and the *last* occurrence of
	   the suffix; the widest span running from the prefix start to the suffix
	   end wins (ties go to the earliest split point).
	3. Fallback: a window at the start (`rev=False`) or the end (`rev=True`) of
	   `sent`. The window is one character wide when the mention contains the
	   placeholder marker '占位符', otherwise as wide as the mention. The window
	   is clamped to the sentence so the span always indexes into `sent`.

	Args:
		sent: The sentence to search in.
		mention: The argument string to look for.
		rev: Search from the right for the exact match and for the prefix.

	Returns:
		A `([start, end], text)` pair where `text == sent[start:end]` for the
		fuzzy and fallback stages, and `text == mention` for an exact match.
	"""
	locate = sent.rfind if rev else sent.find

	start = locate(mention)
	if start >= 0:
		return [start, start+len(mention)], mention

	cand_spans = []
	for i in range(1, len(mention)):
		prefix, suffix = mention[:i], mention[i:]
		start = locate(prefix)
		suffix_start = sent.rfind(suffix)
		# Test the raw rfind() result: once len(suffix) is added, a "not found"
		# (-1) can no longer be told apart from a genuine hit.
		if start < 0 or suffix_start < 0:
			continue
		end = suffix_start + len(suffix)
		if start < end:
			cand_spans.append([start, end])

	if cand_spans:
		# max() returns the first widest span, i.e. the earliest split point.
		best_sp = max(cand_spans, key=lambda sp: sp[1]-sp[0])
		return best_sp, sent[best_sp[0]:best_sp[1]]

	# Nothing matched: fall back to a window at the relevant end of the sentence,
	# clamped so that it never reaches outside of `sent`.
	window = 1 if '占位符' in mention else len(mention)
	window = min(window, len(sent))
	if rev:
		span = [len(sent)-window, len(sent)]
	else:
		span = [0, window]
	return span, sent[span[0]:span[1]]


def prepare_levy_for_typing(levy_input_path, levy_raw_path, levy_arguments_path_cfet, levy_arguments_path_hier, levy_arguments_mapping_path):
	"""Extract the four arguments of every entailment pair as entity-typing inputs.

	The two input TSV files are read in lockstep (iteration stops at the shorter
	one). Each row of `levy_input_path` must have three tab-separated columns;
	the first two are comma-separated `subject, predicate, object` triples for
	the premise and the hypothesis. The first two columns of the matching row in
	`levy_raw_path` are the premise and hypothesis sentences in which the
	arguments are located with `fetch_blurry_span` (subjects searched from the
	left, objects from the right).

	Every pair contributes four typing instances, in the order premise subject,
	premise object, hypothesis subject, hypothesis object, with consecutive ids.

	Args:
		levy_input_path: TSV file with the premise/hypothesis triples.
		levy_raw_path: TSV file with the premise/hypothesis sentences.
		levy_arguments_path_cfet: Output path; one JSON typing instance per line.
		levy_arguments_path_hier: Output path; one line per instance as returned
			by `format_cfet_instance_for_hier`.
		levy_arguments_mapping_path: Output path; JSON list holding, per pair,
			the four instance ids.

	Raises:
		AssertionError: If a row or triple does not have exactly three fields,
			or `construct_typing_instance` does not return exactly one instance.
	"""
	typings_dcts = []
	typings_hier = []
	mapping = []

	# Context managers make sure the inputs are closed even if a row is malformed.
	with open(levy_input_path, 'r', encoding='utf8') as input_fp, \
			open(levy_raw_path, 'r', encoding='utf8') as input_raw_fp:
		tsv_file = csv.reader(input_fp, delimiter="\t")
		tsv_raw_file = csv.reader(input_raw_fp, delimiter="\t")

		for iid, (instance, instance_raw) in enumerate(zip(tsv_file, tsv_raw_file)):
			if iid % 100 == 0 and iid > 0:
				print(iid)  # progress indicator
			assert len(instance) == 3

			premise = [x.strip() for x in instance[0].split(',')]
			assert len(premise) == 3
			premise_sent = instance_raw[0]
			# The argument text is replaced by whatever span was actually found.
			premise_subj_span, premise[0] = fetch_blurry_span(premise_sent, premise[0], rev=False)
			premise_obj_span, premise[2] = fetch_blurry_span(premise_sent, premise[2], rev=True)

			hypo = [x.strip() for x in instance[1].split(',')]
			assert len(hypo) == 3
			hypo_sent = instance_raw[1]
			hypo_subj_span, hypo[0] = fetch_blurry_span(hypo_sent, hypo[0], rev=False)
			hypo_obj_span, hypo[2] = fetch_blurry_span(hypo_sent, hypo[2], rev=True)

			base_id = len(typings_dcts)
			arguments = [
				(premise[0], premise_sent, premise_subj_span),
				(premise[2], premise_sent, premise_obj_span),
				(hypo[0], hypo_sent, hypo_subj_span),
				(hypo[2], hypo_sent, hypo_obj_span),
			]

			built_instances = []
			for offset, (mention, sent, span) in enumerate(arguments):
				# NOTE(review): the meaning of the trailing 'L' and 0 arguments is
				# defined by construct_typing_instance, which is not visible here.
				built, _, _ = construct_typing_instance(mention, sent, base_id+offset, span, 'L', 0)
				built_instances.append(built)
			assert all(len(built) == 1 for built in built_instances)

			instances = [built[0] for built in built_instances]
			hier_lines = [format_cfet_instance_for_hier(ins) for ins in instances]

			typings_dcts.extend(instances)
			typings_hier.extend(hier_lines)
			mapping.append((base_id, base_id+1, base_id+2, base_id+3))

	with open(levy_arguments_path_cfet, 'w', encoding='utf8') as fp:
		for item in typings_dcts:
			fp.write(json.dumps(item, ensure_ascii=False)+'\n')

	with open(levy_arguments_path_hier, 'w', encoding='utf8') as fp:
		for line in typings_hier:
			fp.write(line+'\n')

	with open(levy_arguments_mapping_path, 'w', encoding='utf8') as fp:
		json.dump(mapping, fp, ensure_ascii=False)

	print("Finished!")


def _levy_pick_argument(prem_name, hypo_name, prem_types, hypo_types, shared_types):
	"""Pick the surface form and the candidate types of one argument slot.

	`shared_types` is the preliminary type list for the slot; it is only replaced
	when it is empty. Returns a `(name, types)` pair.
	"""
	if prem_name == hypo_name:
		if len(shared_types) == 0:
			shared_types = list(prem_types.union(hypo_types))
		return prem_name, shared_types

	# The two sides disagree: never pick a placeholder, otherwise pick at random.
	if '占位符' in prem_name:
		name = hypo_name
	elif '占位符' in hypo_name:
		name = prem_name
	else:
		name = random.choice([prem_name, hypo_name])
	if len(shared_types) == 0:
		# Fall back to the types of whichever side supplied the name.
		shared_types = list(prem_types) if name == prem_name else list(hypo_types)
	return name, shared_types


def _levy_typed_rel_line(premise_pred, hypo_pred, subj, obj, value, is_reversed):
	"""Render one tab-separated `premise <TAB> hypothesis <TAB> label` line.

	`subj` and `obj` are already formatted as `name::type`. When `is_reversed`
	is true the hypothesis lists the object before the subject.
	"""
	prem_side = f"({premise_pred}.1,{premise_pred}.2) {subj} {obj}"
	hypo_args = f"{obj} {subj}" if is_reversed else f"{subj} {obj}"
	return f"{prem_side}\t({hypo_pred}.1,{hypo_pred}.2) {hypo_args}\t{value}"


def format_levy_for_evaluation(levy_input_path, levy_typing_output_path, levy_arguments_mapping_path, levy_output_path,
		rellevy_mapping_path, typedrellevy_mapping_path, translation_consistency_threshold,
		use_all_types):
	"""Turn entailment pairs plus predicted argument types into typed relation pairs.

	Args:
		levy_input_path: TSV file; columns are the premise triple, the hypothesis
			triple (both comma-separated `subject, predicate, object`) and the label.
		levy_typing_output_path: JSON-lines file of typing results; only the
			'ins_types' list of each entry is used.
		levy_arguments_mapping_path: JSON list giving, per input row, the indices
			of its four arguments (premise subject/object, hypothesis
			subject/object) in the typing results.
		levy_output_path: %-format template with one placeholder; it is filled
			with 'all' to obtain the output file name.
		rellevy_mapping_path: Text file with one integer id per input row.
		typedrellevy_mapping_path: Output path; JSON list with, for every written
			line, the id of the input row it came from.
		translation_consistency_threshold: Forwarded to `EvalPruner.eval_pruning`.
		use_all_types: If true, write one line per (subject type, object type)
			combination; otherwise write a single line with randomly drawn types.

	Raises:
		AssertionError: If an argument ends up without any candidate type, or the
			number of input rows differs from the number of ids in
			`rellevy_mapping_path`.
	"""
	with open(levy_typing_output_path, 'r', encoding='utf8') as fp:
		typed_arguments = [json.loads(line) for line in fp]

	# First-layer types per argument, duplicates removed and empty type strings
	# skipped. NOTE(review): indexing [1] assumes every non-empty type looks like
	# '/first/...'; a type without '/' would raise IndexError - confirm upstream.
	typed_arguments_types_fl = []
	for item in typed_arguments:
		first_layer = [t.split('/')[1] for t in item['ins_types'] if len(t) != 0]
		typed_arguments_types_fl.append(list(set(first_layer)))

	with open(levy_arguments_mapping_path, 'r', encoding='utf8') as fp:
		mapping = json.load(fp)

	with open(rellevy_mapping_path, 'r', encoding='utf8') as fp:
		rellevy_mapping = [int(line.strip()) for line in fp]

	out_lines = []  # one input row may produce several lines when use_all_types is set
	typed_rel_levy_mapping = []
	reversed_order_count = 0
	pruner = EvalPruner()
	iid = 0

	# The context manager closes the input file, which used to be left open.
	with open(levy_input_path, 'r', encoding='utf8') as input_fp:
		for instance in csv.reader(input_fp, delimiter="\t"):
			ins_mapping = mapping[iid]
			premise = [x.strip() for x in instance[0].split(',')]
			hypo = [x.strip() for x in instance[1].split(',')]
			value = instance[2].strip()
			levy_id = rellevy_mapping[iid]

			# Entries are kept even when they fail this examination; the call is
			# still made so that the pruner's counters (printed below) are updated.
			pruner.eval_pruning(premise, hypo, translation_consistency_threshold)

			is_reversed = check_reversed(premise, hypo)
			prem_subj_types = set(typed_arguments_types_fl[ins_mapping[0]])
			prem_obj_types = set(typed_arguments_types_fl[ins_mapping[1]])
			hypo_subj_types = set(typed_arguments_types_fl[ins_mapping[2]])
			hypo_obj_types = set(typed_arguments_types_fl[ins_mapping[3]])

			prem_0 = premise[0]
			prem_2 = premise[2]
			if is_reversed:
				# Align the hypothesis arguments with the premise order.
				hypo_0, hypo_2 = hypo[2], hypo[0]
				hypo_t_0, hypo_t_2 = hypo_obj_types, hypo_subj_types
				reversed_order_count += 1
			else:
				hypo_0, hypo_2 = hypo[0], hypo[2]
				hypo_t_0, hypo_t_2 = hypo_subj_types, hypo_obj_types

			# A side whose subject is a placeholder contributes no type evidence.
			# NOTE(review): the second test looks at hypo[0], not the re-ordered
			# hypo_0 - confirm this is intended for reversed pairs.
			if '占位符' in premise[0]:
				subj_types = list(hypo_t_0)
				obj_types = list(hypo_t_2)
			elif '占位符' in hypo[0]:
				subj_types = list(prem_subj_types)
				obj_types = list(prem_obj_types)
			else:
				subj_types = list(prem_subj_types.intersection(hypo_t_0))
				obj_types = list(prem_obj_types.intersection(hypo_t_2))

			# Subject first, then object: the order fixes the random draws.
			subj_name, subj_types = _levy_pick_argument(prem_0, hypo_0, prem_subj_types, hypo_t_0, subj_types)
			obj_name, obj_types = _levy_pick_argument(prem_2, hypo_2, prem_obj_types, hypo_t_2, obj_types)
			assert len(subj_types) > 0 and len(obj_types) > 0

			# These two draws are overwritten below in both branches; they are kept
			# so that seeded runs consume the same random numbers as before.
			subj_t = random.choice(subj_types)
			obj_t = random.choice(obj_types)
			if use_all_types:
				type_pairs = [(s_t, o_t) for s_t in subj_types for o_t in obj_types]
			else:
				subj_t = random.choice(subj_types)
				obj_t = random.choice(obj_types)
				type_pairs = [(subj_t, obj_t)]

			for subj_t, obj_t in type_pairs:
				out_lines.append(_levy_typed_rel_line(
					premise[1], hypo[1], f"{subj_name}::{subj_t}", f"{obj_name}::{obj_t}", value, is_reversed))
				typed_rel_levy_mapping.append(levy_id)

			iid += 1

	assert iid == len(rellevy_mapping)
	included_raw_instance_count = iid  # no row is dropped, so both counts agree
	out_line_id = len(out_lines)

	# The output keeps the input order; no shuffling and no dev/test split.
	print("Shuffling is turned off!")

	with open(levy_output_path%'all', 'w', encoding='utf8') as fp:
		for line in out_lines:
			fp.write(line+'\n')

	with open(typedrellevy_mapping_path, 'w', encoding='utf8') as fp:
		json.dump(typed_rel_levy_mapping, fp, indent=4, ensure_ascii=False)

	print(f"Evaluation data are formatted properly and stored in {levy_output_path}!")

	print(f"{pruner.same_pred_count} entries pruned due to same predicates; {pruner.similar_count+pruner.dissimilar_count} "
		f"unmatched argument pairs, of which {pruner.similar_count} are considered similar, {pruner.dissimilar_count} "
		f"are considered dissimilar!")
	print(f"Number of raw Evaluation Dataset entries included: {included_raw_instance_count}!")
	print(f"Reversed order count: {reversed_order_count}")
	print(f"{iid} rel pair entries, converted into {out_line_id} typed rel pair entries!")
	print("Finished!")


def align_raw_with_typed_rels(levy_input_path, levy_raw_path, typedrellevy_mapping_path):
	"""Repeat input/raw rows so they line up one-to-one with the typed output lines.

	`typedrellevy_mapping_path` holds a JSON list with one id per typed line.
	Walking that list, every id larger than all ids seen so far advances both
	input files by one line; any other id re-emits the most recently read lines.
	The results are written next to the inputs, with the '.tsv' suffix replaced
	by '_typed.tsv'.

	Args:
		levy_input_path: Path of the input TSV file; must end in '.tsv'.
		levy_raw_path: Path of the raw TSV file; must end in '.tsv'.
		typedrellevy_mapping_path: Path of the JSON id list.

	Raises:
		AssertionError: If a path does not end in '.tsv', or the id list starts
			with an id that does not exceed -1 (nothing has been read yet).
	"""
	suffix = '.tsv'
	with open(levy_input_path, 'r', encoding='utf8') as input_fp, \
			open(levy_raw_path, 'r', encoding='utf8') as raw_fp:
		with open(typedrellevy_mapping_path, 'r', encoding='utf8') as t_mapping_fp:
			mapping_list = json.load(t_mapping_fp)

		# Slice the suffix off: str.strip('.tsv') would remove any run of the
		# characters '.', 't', 's', 'v' from BOTH ends of the path.
		assert levy_input_path[-4:] == suffix
		typed_input_path = levy_input_path[:-len(suffix)] + '_typed.tsv'
		assert levy_raw_path[-4:] == suffix
		typed_raw_path = levy_raw_path[:-len(suffix)] + '_typed.tsv'

		with open(typed_input_path, 'w', encoding='utf8') as typed_input_fp, \
				open(typed_raw_path, 'w', encoding='utf8') as typed_raw_fp:
			last_id = -1
			last_inp = None
			last_raw = None
			for item in mapping_list:
				if item > last_id:
					# New raw instance: move on to the next line of both files.
					last_id = item
					last_inp = input_fp.readline()
					last_raw = raw_fp.readline()
				else:
					# Same instance as before: reuse the lines already read.
					assert last_inp is not None
					assert last_raw is not None

				typed_input_fp.write(last_inp.strip('\n') + '\n')
				typed_raw_fp.write(last_raw.strip('\n') + '\n')


def get_placeholder_types(levy_arguments_path_hier, levy_typing_output_path):
	"""Write dummy typing results for every line of a hier-formatted argument file.

	Each input line must consist of exactly three tab-separated fields: the
	sentence, the span as `start:end`, and a third field that is ignored. Every
	line yields one JSON object with the fixed types `["", "/person"]`, the fixed
	type ids `[0, 37]` and 152 scores; the scores are random values in
	[-0.5, 0.5] rounded to five decimals, except for the fixed type ids, which
	are set to 2.33.

	Args:
		levy_arguments_path_hier: Path of the hier-formatted input file.
		levy_typing_output_path: Path of the JSON-lines file to write.

	Raises:
		ValueError: If a line does not have three tab-separated fields, or its
			span is not two integers separated by ':'.
	"""
	print(f"Building placeholder FET output files for {levy_arguments_path_hier}, storing to {levy_typing_output_path}!")

	idx = 0
	# Both files are closed even when a malformed line aborts the loop.
	with open(levy_arguments_path_hier, 'r', encoding='utf8') as hier_input_fp, \
			open(levy_typing_output_path, 'w', encoding='utf8') as levy_output_fp:
		for line in hier_input_fp:
			sent, span, _ = line.split('\t')
			span_st, span_ed = [int(x) for x in span.split(':')]
			ins_types = ["", "/person"]
			ins_tids = [0, 37]
			ins_score = [round(random.uniform(-0.5, 0.5), 5) for _ in range(152)]
			for tid in ins_tids:
				ins_score[tid] = 2.33  # make the fixed types clearly stand out
			entry = {
				"ins_types": ins_types,
				"ins_sent": list(sent),  # the sentence is split into characters
				"ins_span": [span_st, span_ed],
				"ins_span_text": list(sent[span_st:span_ed]),
				"ins_score": ins_score,
				"ins_tids": ins_tids
			}
			levy_output_fp.write(json.dumps(entry, ensure_ascii=False)+'\n')
			idx += 1

	print(f"Total number of entries: {idx}")
	print("Finished!")


def check_reversed(premise, hypo):
	"""Tell whether the hypothesis lists its two arguments in the opposite order.

	`premise` and `hypo` are indexable triples whose elements 0 and 2 are the
	subject and object strings.

	* If one argument matches its crosswise counterpart exactly, the pair is
	  reversed only when neither argument matches its same-position counterpart.
	* Otherwise a crosswise pair counts as a match when its longest common
	  substring covers more than 75% of both strings; the pair is reversed when
	  that substring is also strictly longer than both same-position matches.
	  The object/subject cross is tried first, and only if it does not pass the
	  75% test is the subject/object cross considered.

	An empty argument can never pass the 75% test (it used to raise
	ZeroDivisionError). The fuzzy matches that are judged reversed are printed.

	Returns:
		True if the arguments are judged to be swapped, else False.
	"""
	# Proportion of BOTH arguments that the longest common substring must exceed.
	threshold = 0.75

	def longest_match_size(a, b):
		return SequenceMatcher(None, a, b).find_longest_match(0, len(a), 0, len(b)).size

	def covers_both(size, a, b):
		if len(a) == 0 or len(b) == 0:
			return False
		return float(size)/float(len(a)) > threshold and float(size)/float(len(b)) > threshold

	if premise[2] == hypo[0] or premise[0] == hypo[2]:
		return premise[0] != hypo[0] and premise[2] != hypo[2]

	match_20 = longest_match_size(premise[2], hypo[0])
	match_02 = longest_match_size(premise[0], hypo[2])
	match_00 = longest_match_size(premise[0], hypo[0])
	match_22 = longest_match_size(premise[2], hypo[2])

	if covers_both(match_20, premise[2], hypo[0]):
		is_reversed = match_20 > match_22 and match_20 > match_00
	elif covers_both(match_02, premise[0], hypo[2]):
		is_reversed = match_02 > match_00 and match_02 > match_22
	else:
		return False

	if is_reversed:
		print("prem: ", premise)
		print("hypo: ", hypo)
	return is_reversed

