import json
import macropodus
import sys
import random


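# Builds typing instances for every occurrence of `argument` in `sent` whose start offset differs from arg_span[0] by at most `window_size` (default 5); every occurrence is used when arg_span is None.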
def construct_typing_instance(argument, sent, global_id, arg_span, prefix, window_size=5, refs=None, suppress_error=False):
	"""Create entity-typing instances for the occurrences of `argument` in `sent`.

	Occurrences are searched left to right without overlap. When `arg_span` is
	not None, an occurrence is kept only if its start offset differs from
	`arg_span[0]` by at most `window_size`. Skipped occurrences do not consume
	an id. Each kept occurrence consumes one id, `prefix + str(global_id)`.
	- If `refs` contains the key "<argument>::<begin>::<end>", the id is mapped
	  to that reference in `captured_refs`.
	- Otherwise a new instance dict is created.

	Args:
		argument: mention string to search for. An empty string is reported as
			not found.
		sent: sentence to search in.
		global_id: first integer id to assign.
		arg_span: [begin, ...] span hint, or None to accept every occurrence.
		prefix: string prepended to every mention id.
		window_size: maximum distance between an occurrence start and arg_span[0].
		refs: optional mapping from serialized occurrence to an existing reference.
		suppress_error: if True, the "Argument not found" warning is printed only
			for a random ~1/10000 sample of misses instead of always.

	Returns:
		(instances, global_id_lst, next_global_id) when refs is None, otherwise
		(instances, global_id_lst, next_global_id, captured_refs).
	"""
	prior_global_id = global_id
	start = 0
	instances = []
	captured_refs = {}
	global_id_lst = []

	while True:
		if argument == "":
			# str.find("") matches at every offset without advancing `start`,
			# which made this loop run forever; treat it as "not found" instead.
			begin_index = -1
		else:
			begin_index = sent.find(argument, start)
		if begin_index == -1:
			if start == 0:
				# Draw unconditionally so the global RNG sequence does not depend
				# on suppress_error.
				randnum = random.random()
				if not suppress_error or randnum < 0.0001:
					print(f"Argument not found! {argument}; {sent}", file=sys.stderr)
			break
		start = begin_index + len(argument)
		if arg_span is not None and abs(arg_span[0]-begin_index) > window_size:
			# Too far from the span hint; skipped without consuming an id.
			continue
		span = [begin_index, start]
		serialized_instance = f"{argument}::{begin_index}::{start}"
		mention_id = prefix + str(global_id)
		if refs is None or serialized_instance not in refs:
			instance = {'figer_types_first_dict': {},
						'figer_types_first_list': [],
						'figer_types_both_dict': {},
						'figer_types_both_list': [],
						'label_types': [],
						'general_type': [],
						'mention': argument,
						'span': span,
						'sentence': sent,
						'mention_id': mention_id}
			instances.append(instance)
			global_id_lst.append(mention_id)
		else:
			# Already-known occurrence: record the id -> existing reference link.
			captured_refs[mention_id] = refs[serialized_instance]
		global_id += 1
	# Every consumed id is either a new instance or a captured reference.
	assert len(instances) + len(captured_refs) == global_id - prior_global_id
	if refs is None:
		return instances, global_id_lst, global_id
	else:
		return instances, global_id_lst, global_id, captured_refs


def format_cfet_instance_for_hier(instance):
	"""Render one typing instance as a tab-separated line for the hierarchical model.

	The result is "<sentence>\t<begin>:<end>\t<types>".
	- Tabs inside the sentence are replaced with spaces.
	- The type field is built from the fixed placeholder list ['person'], which
	  renders as "/person". The instance's own type fields are never read.

	Raises:
		AssertionError: if the sentence is 500 characters or longer, or if a
			type string contains 'None' other than as a trailing '/None'.
	"""
	text = instance['sentence']
	if len(text) >= 500:
		print("Length of sentence exceeding 500! Stopping......")
		raise AssertionError

	def to_type_path(type_name):
		# A trailing "/None" is stripped; 'None' anywhere else is rejected.
		if 'None' in type_name:
			if not type_name.endswith('/None'):
				print(type_name)
				raise AssertionError
			type_name = type_name[:-len('/None')]
		return '/' + type_name

	cleaned = text.replace('\t', ' ')
	offsets = instance['span']
	labels = ' '.join([to_type_path(name) for name in ['person']])
	return f'{cleaned}\t{offsets[0]}:{offsets[1]}\t' + labels


# DEPRECATED
def format_net_input_for_hier(net_to_convert_path, net_converted_path, net_tsv_mapping_path, is_predict):
	"""Convert a JSON-lines file of typing entries into the hierarchical-typing TSV format.

	Each input line is a JSON object read for 'sentence', 'span',
	'figer_types_both_list' and 'mention_id'. Each kept entry becomes one
	output line "<sentence>\t<begin>:<end>\t<types>":
	- Tabs in the sentence are replaced with spaces.
	- Each type gets a leading '/', and a trailing '/None' is removed.

	Sentences of 500+ characters are truncated to 500 when both span offsets
	are < 500. Otherwise the entry is skipped.

	Args:
		net_to_convert_path: input JSON-lines path.
		net_converted_path: output TSV path (overwritten).
		net_tsv_mapping_path: where to write the JSON list of mention_ids, in
			output-line order; only written when is_predict is true.
		is_predict: whether to write the mention_id mapping file.

	Raises:
		AssertionError: if a type string contains 'None' other than as a
			trailing '/None'.
	"""

	def check_length_overflow(item):
		# Returns the (possibly truncated) item, or None when it must be skipped.
		if len(item['sentence']) < 500:
			return item
		# NOTE(review): if span ends are exclusive, span[1] == 500 would still fit
		# the truncated sentence; boundary kept as-is — confirm.
		elif item['span'][0] < 500 and item['span'][1] < 500:
			item['sentence'] = item['sentence'][:500]
			return item
		else:
			print("Found entry to skip!")
			return None

	def normalize_type(t_str):
		# "a/b/None" -> "/a/b"; any other occurrence of 'None' is an error.
		if 'None' in t_str:
			if t_str[-5:] != '/None':
				print(t_str)
				raise AssertionError
			t_str = t_str[:-5]
		return '/' + t_str

	with open(net_to_convert_path, 'r', encoding='utf8') as fp:
		net_entries = []
		for lidx, line in enumerate(fp):
			if lidx % 10000 == 0 and lidx > 0:
				print(lidx)
			net_entries.append(json.loads(line))

	line_mapping = []
	skipped_count = 0

	# Write through one handle for the whole run. Closing and reopening the
	# file after every entry was slow and leaked the handle if an entry raised.
	with open(net_converted_path, 'w', encoding='utf8') as out_fp:
		for iid, item in enumerate(net_entries):
			if iid % 1000 == 0 and iid > 0:
				print(f"Processing {iid}/{len(net_entries)}")

			item = check_length_overflow(item)
			if item is None:
				skipped_count += 1
				continue
			sentence = item['sentence'].replace('\t', ' ')
			span = item['span']
			type_list = [normalize_type(t_str) for t_str in item['figer_types_both_list']]

			out_fp.write(f'{sentence}\t{span[0]}:{span[1]}\t' + ' '.join(type_list) + '\n')
			# Position in line_mapping == line number in the TSV output.
			line_mapping.append(item['mention_id'])
	print(f"Entries from {net_to_convert_path} formatted and stored in {net_converted_path}")

	if is_predict:
		print(f"Storing entry mention_id to line number mapping file to {net_tsv_mapping_path}")
		with open(net_tsv_mapping_path, 'w', encoding='utf8') as fp:
			json.dump(line_mapping, fp, ensure_ascii=False)
		print("Done.")

	print(f"{skipped_count} entries were skipped because of max_length overflow (span index out of range).")
	print("Finished.")


class EvalPruner(object):
	"""Decide whether a (premise, hypothesis) pair is kept, counting the reasons.

	Premise and hypothesis are indexable triples. Slot 1 is compared as the
	predicate; slots 0 and 2 are the subject and object arguments.
	"""

	def __init__(self):
		# Pairs rejected because premise and hypothesis share the same predicate.
		self.same_pred_count = 0
		# Differing argument slots whose similarity reached the threshold.
		self.similar_count = 0
		# Differing argument slots whose similarity fell below the threshold.
		self.dissimilar_count = 0

	def _argument_matches(self, left, right, threshold):
		"""Return True if two argument values are equal or similar enough; update counters."""
		if left == right:
			return True
		if macropodus.sim(left, right) < threshold:
			self.dissimilar_count += 1
			return False
		self.similar_count += 1
		return True

	def eval_pruning(self, premise, hypo, threshold):
		"""Return False if the pair should be pruned, True if it should be kept.

		A pair is pruned when:
		- its predicates (slot 1) are identical, or
		- a differing subject (slot 0) or object (slot 2) has similarity below
		  `threshold`.
		The subject is checked first. If it fails, the object is not examined.
		"""
		if premise[1] == hypo[1]:
			self.same_pred_count += 1
			return False
		# `and` short-circuits, so the object is only compared when the subject passed.
		return self._argument_matches(premise[0], hypo[0], threshold) and self._argument_matches(premise[2], hypo[2], threshold)