import json
import time
import random
from utils.data_preperation_utils import format_cfet_instance_for_hier, construct_typing_instance
import sys
import os


def prepare_parsed_arguments_for_typing(ddparser_path, arg_typing_input_path, arg_typing_input_sliced_path,
										arg_hier_input_path, arg_hier_input_sliced_path,
										argnet_mapping_path, arg_entity_pool_path, arg_prev_res_path,
										found_from_ref_path, num_slices, jia_baseline=0, get_bucket_flag=True):
	"""Turn the arguments of parsed relations into typing-input files.

	Reads ``ddparser_path`` (one JSON document per line). For every sentence the subjects and objects of
	its 'SVO' relations are collected (each distinct argument once per sentence) and handed to
	``construct_typing_instance``. Every returned instance is written as a JSON line to the full CFET
	file and to the CFET slice file of the current document, and, converted by
	``format_cfet_instance_for_hier``, to the full and sliced hierarchical files.

	:param ddparser_path: input file, one JSON document per line.
	:param arg_typing_input_path: output path of the full CFET instance file.
	:param arg_typing_input_sliced_path: ``%``-format template (formatted with the slice index) for the
		per-slice CFET files.
	:param arg_hier_input_path: output path of the full hierarchical instance file.
	:param arg_hier_input_sliced_path: ``%``-format template for the per-slice hierarchical files.
	:param argnet_mapping_path: output JSON file mapping each document index to a per-sentence list of
		the global ids returned by ``construct_typing_instance``.
	:param arg_entity_pool_path: output JSON file with the argument pool (written only when
		``get_bucket_flag`` is true).
	:param arg_prev_res_path: optional file of previous results, one JSON line per document. It is used
		only if it exists and has exactly as many lines as ``ddparser_path``; ``None`` is accepted.
	:param found_from_ref_path: optional output file receiving, per document, the reference entries that
		``construct_typing_instance`` reported as found. Must be given whenever the reference is used.
	:param num_slices: number of slice files to distribute the documents over.
	:param jia_baseline: a positive value silences the warning printed for sentences of documents
		without a 'ddp_lbls' field.
	:param get_bucket_flag: when true, a first pass counts argument occurrences, only counted arguments
		are kept, and the argument pool is dumped.
	"""
	from contextlib import ExitStack

	THRESHOLD = 3

	with open(ddparser_path, 'r', encoding='utf8') as count_fp:
		entry_count = sum(1 for _ in count_fp)
	print("entry count: ", entry_count)
	# With fewer documents than slices the integer division yields 0, which would make the
	# slice-index computation below divide by zero; use at least one document per slice.
	chunk_size = max(1, entry_count // num_slices)

	# The reference is only usable when it is line-aligned with the input.
	if arg_prev_res_path is not None and os.path.exists(arg_prev_res_path):
		with open(arg_prev_res_path, 'r', encoding='utf8') as ref_count_fp:
			refent_count = sum(1 for _ in ref_count_fp)
		print("Ref entry count: ", refent_count)
		if refent_count != entry_count:
			arg_prev_res_path = None
	else:
		arg_prev_res_path = None
	assert arg_prev_res_path is None or (found_from_ref_path is not None)

	cfet_output_fns = [arg_typing_input_sliced_path % i for i in range(num_slices)]
	hier_output_fns = [arg_hier_input_sliced_path % i for i in range(num_slices)]

	# Find out how many times each entity appears in the triples, rule out the entities appearing less than THRESHOLD
	# times from the entity_bucket.
	entity_bucket = None
	if get_bucket_flag:
		entity_bucket = {}
		st = time.time()
		with open(ddparser_path, 'r', encoding='utf8') as input_fp:
			for doc_idx, input_line in enumerate(input_fp):
				if doc_idx % 1000 == 0 and doc_idx > 0:
					dur = int(time.time() - st)
					dur_h = dur // 3600
					dur_m = (dur % 3600) // 60
					dur_s = dur % 60
					print(f"{doc_idx}; time lapsed: {dur_h} hours {dur_m} minutes {dur_s} seconds; found {len(entity_bucket)} entities!")
				doc = json.loads(input_line)
				for sent_id in range(len(doc['splitted_text'])):
					rels = doc['coarse_rels'][sent_id] + doc['fine_rels'][sent_id] + doc['amend_coarse_rels'][sent_id] \
						+ doc['amend_fine_rels'][sent_id] + doc['crossed_rels'][sent_id] + doc['amend_crossed_rels'][sent_id] \
						+ doc['possible_rels'][sent_id]
					for rel in rels:
						# subject first, then object, so the pool keeps first-seen order.
						for arg in (rel[0][0], rel[0][2]):
							if arg is not None:
								entity_bucket[arg] = entity_bucket.get(arg, 0) + 1
		print(f"Found {len(entity_bucket)} unique arguments in total!")

		# Only reported: the thresholded count is printed, but entity_bucket itself is left unpruned.
		pruned_count = sum(1 for occurrences in entity_bucket.values() if occurrences >= THRESHOLD)
		print(f"{pruned_count} unique arguments passed the argument occurance test! (nothing is pruned!)")

	global_id = 0
	all_mappings = {}
	instance_count = 0
	refs_total_count = 0
	refs_not_used_count = 0

	# ExitStack closes every handle registered on it, also when an exception interrupts the pass.
	with ExitStack() as stack:
		cfet_full_out_fp = stack.enter_context(open(arg_typing_input_path, 'w', encoding='utf8'))
		hier_full_out_fp = stack.enter_context(open(arg_hier_input_path, 'w', encoding='utf8'))
		input_fp = stack.enter_context(open(ddparser_path, 'r', encoding='utf8'))
		ref_fp = stack.enter_context(open(arg_prev_res_path, 'r', encoding='utf8')) \
			if arg_prev_res_path is not None else None
		intermediate_fp = stack.enter_context(open(found_from_ref_path, 'w', encoding='utf8')) \
			if found_from_ref_path is not None else None

		st = time.time()
		cur_slice = -1
		cfet_out_fp = None
		hier_out_fp = None
		for doc_idx, input_line in enumerate(input_fp):
			if doc_idx % 1000 == 0 or doc_idx == entry_count-1:
				assert instance_count + refs_total_count == global_id
				dur = int(time.time() - st)
				dur_h = dur // 3600
				dur_m = (dur % 3600) // 60
				dur_s = dur % 60
				print('time lapsed: %d hours %d minutes %d seconds' % (dur_h, dur_m, dur_s))
				print(f"doc_idx: {doc_idx}/{entry_count}; num of instances: {global_id}!")
				print(f"Number of argument types found from reference: {refs_total_count}!")
				print(f"Number of reference argument type instances unused: {refs_not_used_count}!")

			# Documents are assigned to slices in contiguous chunks; the last slice takes the remainder.
			slice_idx = min(doc_idx // chunk_size, num_slices - 1)
			if slice_idx != cur_slice:
				if cfet_out_fp is not None:
					cfet_out_fp.close()
					hier_out_fp.close()
				cfet_out_fp = stack.enter_context(open(cfet_output_fns[slice_idx], 'w', encoding='utf8'))
				hier_out_fp = stack.enter_context(open(hier_output_fns[slice_idx], 'w', encoding='utf8'))
				cur_slice = slice_idx
				print(cfet_output_fns[slice_idx])
				print(hier_output_fns[slice_idx])

			# Each reference line is a one-key JSON object whose value is the per-sentence list.
			ref_json = json.loads(ref_fp.readline()) if ref_fp is not None else None
			if ref_json is not None:
				assert len(ref_json) == 1
				ref_list = []
				for ref_key in ref_json:
					ref_list = ref_json[ref_key]
			else:
				ref_list = None
			ref_found_perdoc = []

			cur_doc_mapping = []
			doc = json.loads(input_line)

			assert ref_list is None or len(ref_list) == len(doc['splitted_text'])
			for sent_id, sent in enumerate(doc['splitted_text']):
				sent_ref_dict = ref_list[sent_id] if ref_list is not None else {}
				sent_ref_map = {}
				for ref_token in sent_ref_dict:
					for ref_mention in sent_ref_dict[ref_token]:
						ser = f"{ref_token}::{ref_mention[0][0]}::{ref_mention[0][1]}"
						sent_ref_map[ser] = [ref_token, ref_mention]

				if 'ddp_lbls' in doc:
					# map dependency tokens to their spans in the text.
					ddp_tkn_idxs = []
					_tail = 0
					for tkn in doc['ddp_lbls'][sent_id]['word']:
						ddp_tkn_idxs.append(_tail)
						_tail += len(tkn)
					ddp_tkn_idxs.append(_tail)  # add a pseudo token representing EOF.
				else:
					ddp_tkn_idxs = None
					if jia_baseline <= 0:
						print("Careful! Are you using the Jia et al baseline? If not, serious error!", file=sys.stderr)

				c_rels = doc['coarse_rels'][sent_id] + doc['amend_coarse_rels'][sent_id] + doc['crossed_rels'][sent_id] + doc['amend_crossed_rels'][sent_id]
				f_rels = doc['fine_rels'][sent_id] + doc['amend_fine_rels'][sent_id] + doc['possible_rels'][sent_id]
				sent_args = []
				sent_arg_idxs = []
				sent_mapping = []

				# Arguments of coarse-grained relations carry no token span.
				for rel in c_rels:
					if rel[1] != 'SVO':
						continue
					for arg in (rel[0][0], rel[0][2]):
						# an argument missing from entity_bucket was never counted in the first pass.
						if arg is not None and (entity_bucket is None or arg in entity_bucket) and arg not in sent_args:
							sent_args.append(arg)
							sent_arg_idxs.append(None)

				# for entity pool, only consider arguments of fine-grained relations, because the coarse-grained ones
				# contain too many long mentions that are irrelevant.
				for rel in f_rels:
					if rel[1] != 'SVO':
						continue
					if ddp_tkn_idxs is not None:
						subj_span = [ddp_tkn_idxs[rel[2][0]], ddp_tkn_idxs[rel[2][0] + 1]] if rel[2][0] is not None else None
						obj_span = [ddp_tkn_idxs[rel[2][2]], ddp_tkn_idxs[rel[2][2] + 1]] if rel[2][2] is not None else None
					else:
						subj_span = None
						obj_span = None
					for arg, span in ((rel[0][0], subj_span), (rel[0][2], obj_span)):
						if arg is not None and (entity_bucket is None or arg in entity_bucket) and arg not in sent_args:
							sent_args.append(arg)
							sent_arg_idxs.append(span)

				assert len(sent_args) == len(sent_arg_idxs)
				ref_found_persent = {}  # {'Txxxx': [ref_token, [[ref_span_st, ref_span_ed], [type_0, ...]]]}
				for argument_, arg_span in zip(sent_args, sent_arg_idxs):
					# return a list of instances and global-ids, meanwhile updates the global-id to current end-of-list
					instances, global_id_lst, global_id, ref_found_dct = construct_typing_instance(
						argument_, sent, global_id, arg_span, 'T', refs=sent_ref_map, suppress_error=True)
					instance_count += len(instances)
					refs_total_count += len(ref_found_dct)
					for ins in instances:
						cfet_line = json.dumps(ins, ensure_ascii=False)
						cfet_out_fp.write(cfet_line+'\n')
						cfet_full_out_fp.write(cfet_line+'\n')
						hier_line = format_cfet_instance_for_hier(ins)
						hier_out_fp.write(hier_line+'\n')
						hier_full_out_fp.write(hier_line+'\n')
					sent_mapping += global_id_lst

					for key in ref_found_dct:
						assert key not in ref_found_persent  # "refs_total_count += len(ref_found_dct)" depends on this!
						ref_found_persent[key] = ref_found_dct[key]

				ref_found_perdoc.append(ref_found_persent)
				cur_doc_mapping.append(sent_mapping)
				# counts sentences in which not every reference entry was matched.
				if len(ref_found_persent) != len(sent_ref_map):
					refs_not_used_count += 1

			all_mappings[doc_idx] = cur_doc_mapping
			if intermediate_fp is not None:
				intermediate_fp.write(json.dumps({doc_idx: ref_found_perdoc})+'\n')

		# Slices that received no document (empty input, or fewer documents than slices) are still
		# created as empty files so that every slice path exists afterwards.
		for slice_idx in range(cur_slice + 1, num_slices):
			for slice_fn in (cfet_output_fns[slice_idx], hier_output_fns[slice_idx]):
				open(slice_fn, 'w', encoding='utf8').close()

	if get_bucket_flag:
		unique_entities = list(entity_bucket.keys())
		# nothing is pruned, so both lists hold the same entries.
		pruned_unique_entities = list(entity_bucket.keys())
		with open(arg_entity_pool_path, 'w', encoding='utf8') as fp:
			json.dump({'ent_list': unique_entities, 'pruned_ent_list': pruned_unique_entities}, fp, indent=4, ensure_ascii=False)

	with open(argnet_mapping_path, 'w', encoding='utf8') as fp:
		json.dump(all_mappings, fp, ensure_ascii=False)

	print("Done.")
	print("prepare_parsed_arguments_for_typing function finished successfully!")


# Applies the score threshold to the ## first-layer ## named entity typing results of the arguments read from the
# sliced NET result files, and (optionally) stores all the instances in one merged file.
def merge_argnet_results(argnet_thresholded_instances_path, arg_typing_result_sliced_path,
						 threshold, num_slices):
	"""Threshold the sliced typing results and optionally merge them into one file.

	Each line of every slice file is a JSON instance with the fields 'ins_types', 'ins_tids' and
	'ins_score'. Type ids below 50 are treated as first-layer types. When an instance has exactly one
	first-layer type id and ``ins_score`` indexed by that id is below ``threshold``, its 'ins_types'
	and 'ins_tids' are emptied. 'ins_score' is always removed before the instance is written.

	:param argnet_thresholded_instances_path: output path of the merged file, or ``None`` to only
		collect the statistics.
	:param arg_typing_result_sliced_path: ``%``-format template, formatted with the slice index as a
		string, locating the slice result files.
	:param threshold: minimum score for a sole first-layer prediction to be kept.
	:param num_slices: number of slice files to read.
	"""
	random.seed(time.time())

	sample_results_instances = []
	# one entity mention might have up to 5 positive type predictions; the usual counts are pre-seeded
	# so they are always reported, and any larger count is added on demand instead of raising KeyError.
	num_pred_bucket = {i: 0 for i in range(5)}

	global_thresholded_out_count = 0

	print(f"Dumping full_results_instances to {argnet_thresholded_instances_path}......")
	if argnet_thresholded_instances_path is not None:
		merged_fp = open(argnet_thresholded_instances_path, 'w', encoding='utf8')
	else:
		merged_fp = None

	st = time.time()
	try:
		for i in range(num_slices):
			print("Reading arg_net result number %s" % str(i))
			with open(arg_typing_result_sliced_path % str(i), 'r', encoding='utf8') as fp:
				for lidx, line in enumerate(fp):
					if lidx % 100000 == 0 and lidx > 0:
						dur = int(time.time() - st)
						dur_h = dur // 3600
						dur_m = (dur % 3600) // 60
						dur_s = dur % 60
						print(f"{lidx}; time lapsed: {dur_h} hours {dur_m} minutes {dur_s} seconds")
					instance = json.loads(line)
					instance_predictions = instance['ins_types']
					instance_scores = instance['ins_score']
					instance_first_layer_idxs = [pred_id for pred_id in instance['ins_tids'] if pred_id < 50]
					if len(instance_first_layer_idxs) == 1 and instance_scores[instance_first_layer_idxs[0]] < threshold:
						instance_predictions = []
						instance['ins_types'] = []
						instance['ins_tids'] = []
						global_thresholded_out_count += 1
					# the scores are dropped before the instance is written or sampled.
					del instance['ins_score']
					if merged_fp is not None:
						merged_fp.write(json.dumps(instance)+'\n')
					# keep a very sparse random sample for eyeballing at the end.
					if random.random() < 0.000001:
						sample_results_instances.append(instance)
					num_preds = len(instance_predictions)
					num_pred_bucket[num_preds] = num_pred_bucket.get(num_preds, 0) + 1
		print(f"A total number of {global_thresholded_out_count} instances were thresholded out!")
	finally:
		# close the merged file even when a slice is missing or malformed.
		if merged_fp is not None:
			merged_fp.close()
	print("Done.")

	print("Number of predictions bucket: ")
	print(num_pred_bucket)

	for ins in sample_results_instances:
		print(ins)

	# print("Loading ddparser_entries......")
	# ddparser_entries = []
	# with open(ddparser_path, 'r', encoding='utf8') as fp:
	# 	lidx = 0
	# 	for line in fp:
	# 		if lidx % 10000 == 0 and lidx > 0:
	# 			print(lidx)
	# 		item = json.loads(line)
	# 		ddparser_entries.append({'splitted_text': item['splitted_text']})
	# 		lidx += 1
	# print("DDParser results loaded!")


def create_indoc_file_argnet_results(argnet_thresholded_instances_path, argnet_mapping_path, ddparser_path,
									 argnet_intermediate_path, argnet_documented_path):
	"""Regroup the merged typing results by document and sentence.

	Walks the id mapping document by document, in step with the lines of ``ddparser_path`` and, when
	given, of the intermediate reference file. For every sentence it builds a dict
	``{mention_text: [(span, types), ...]}`` from the reference entries of that sentence plus the
	merged instances listed in the mapping, and writes one JSON line ``{doc_index: [sentence dicts]}``
	per document to ``argnet_documented_path``.

	Instances whose sentence text or span text disagrees with the document are skipped and logged to
	'len_mismatch_log.txt' / 'tok_mismatch_log.txt' in the current working directory. Instances with
	an empty type list are skipped and counted.

	:param argnet_thresholded_instances_path: merged instance file, one JSON instance per line.
	:param argnet_mapping_path: JSON file ``{doc_index: [[mention_id, ...] per sentence]}``; mention ids
		are a number wrapped in 'T' characters. The documents must be listed in line order.
	:param ddparser_path: parsed documents, one JSON object with a 'splitted_text' list per line.
	:param argnet_intermediate_path: optional file with the reference entries per document, or ``None``.
	:param argnet_documented_path: output path of the per-document result file.
	"""
	from contextlib import ExitStack

	global_thresholded_out_count = 0
	with open(argnet_mapping_path, 'r', encoding='utf8') as fp:
		argnet_mapping = json.load(fp)
	print("ARGNET mapping loaded!")

	print("Creating per-document file......")
	# {doc_i: [{arg_i:[(span, predictions), ...]}, {}, ...]}
	len_mismatch_trashed_cnt = 0
	tok_mismatch_trashed_cnt = 0
	fet_line_idx = -1
	cur_instance = None
	# ids copied from the reference; these have no line of their own in the merged instance file.
	copied_ent_ids = []

	# ExitStack closes every handle, including the intermediate file, on success and on error.
	with ExitStack() as stack:
		len_fp = stack.enter_context(open('len_mismatch_log.txt', 'w', encoding='utf8'))
		tok_fp = stack.enter_context(open('tok_mismatch_log.txt', 'w', encoding='utf8'))
		doced_fp = stack.enter_context(open(argnet_documented_path, 'w', encoding='utf8'))
		intermediate_fp = stack.enter_context(open(argnet_intermediate_path, 'r', encoding='utf8')) \
			if argnet_intermediate_path is not None else None
		ddp_fp = stack.enter_context(open(ddparser_path, 'r', encoding='utf8'))
		merged_fet_fp = stack.enter_context(open(argnet_thresholded_instances_path, 'r', encoding='utf8'))

		for mapping_cnt, doc_key in enumerate(argnet_mapping):
			if mapping_cnt % 1000 == 0:
				print(mapping_cnt, ' / ', len(argnet_mapping))
			doc_mapping = argnet_mapping[doc_key]
			doc_result = []
			doc_key = int(doc_key)
			# the mapping must list the documents in the same order as the ddparser file.
			assert doc_key == mapping_cnt
			ddp_entry = json.loads(ddp_fp.readline())

			# per-sentence reference dicts of this document:
			# [{'Txxxx': [ref_token, [[ref_span_st, ref_span_ed], [type_0, ...]]], ...}, ...]
			doc_refs = None
			if intermediate_fp is not None:
				intermediate_entry = json.loads(intermediate_fp.readline())
				if intermediate_entry is not None:
					assert len(intermediate_entry) == 1
					for key in intermediate_entry:
						assert int(key) == int(doc_key)
						doc_refs = intermediate_entry[key]

			splitted_text = ddp_entry['splitted_text']
			assert doc_refs is None or len(doc_refs) == len(splitted_text)
			for sent_id, sent_mapping in enumerate(doc_mapping):
				sent_result = {}
				sent_refs = doc_refs[sent_id] if doc_refs is not None else {}
				for ref_key in sent_refs:
					copied_ent_ids.append(int(ref_key.strip('T')))
					ref_tok = sent_refs[ref_key][0]
					ref_span = sent_refs[ref_key][1][0]
					ref_types = sent_refs[ref_key][1][1]
					sent_result.setdefault(ref_tok, []).append((ref_span, ref_types))

				canonical_sent = splitted_text[sent_id]
				for mention_id in sent_mapping:
					res_id = int(mention_id.strip('T'))
					assert res_id > fet_line_idx
					# Advance the id cursor up to this mention: a line is read for every id passed that
					# was not copied from the reference; copied ids are only ticked off.
					first_increment = True
					while res_id > fet_line_idx:
						if fet_line_idx not in copied_ent_ids:
							cur_instance = merged_fet_fp.readline()
							if not first_increment:
								# more than one line consumed for a single mention.
								print("!!")
							else:
								first_increment = False
						else:
							copied_ent_ids.remove(fet_line_idx)
						fet_line_idx += 1
					cur_instance = json.loads(cur_instance)
					mention_name = ''.join(cur_instance['ins_span_text'])
					mention_preds = cur_instance['ins_types']
					mention_sent = ''.join(cur_instance['ins_sent'])
					mention_span = cur_instance['ins_span']
					if mention_sent != canonical_sent:
						print("sentence mismatch: ", file=len_fp)
						print(f"mention: {mention_name}", file=len_fp)
						print(f"mention span: {mention_span}", file=len_fp)
						print(f"canonical sentence: {canonical_sent}", file=len_fp)
						print(f"mention sent: {mention_sent}", file=len_fp)
						print("", file=len_fp)
						len_mismatch_trashed_cnt += 1
						continue
					elif mention_name != canonical_sent[mention_span[0]:mention_span[1]]:
						print("sentence mismatch: ", file=tok_fp)
						print(f"mention: {mention_name}", file=tok_fp)
						print(f"mention span: {mention_span}", file=tok_fp)
						print(f"canonical sentence: {canonical_sent}", file=tok_fp)
						print("", file=tok_fp)
						tok_mismatch_trashed_cnt += 1
						continue
					if len(mention_preds) > 0:  # currently all mentions are predicted at least one type, unless thresholded out (since it's [top, other])
						sent_result.setdefault(mention_name, []).append((mention_span, mention_preds))
					else:
						global_thresholded_out_count += 1
				doc_result.append(sent_result)
			dump_line = json.dumps({doc_key: doc_result}, ensure_ascii=False)
			doced_fp.write(dump_line+'\n')

	print(f"global_thresholded_out_count at final: {global_thresholded_out_count}")
	print(f"len_mismatch_trashed_cnt: {len_mismatch_trashed_cnt}")
	print(f"tok_mismatch_trashed_cnt: {tok_mismatch_trashed_cnt}")
	print("Done!")


# DEPRECATED
def merge_nernet_results(nernet_sliced_result_path, nernet_mapping_path, ddparser_path, debug,
						 nernet_documented_path, num_slices):
	"""Regroup the sliced NER typing results by document and sentence.

	All slice result files are concatenated into one list, so that the number in a mention id
	('R<number>') indexes the instance it refers to. For every document in the mapping a list of
	per-sentence dicts ``{mention: predictions}`` is built (mentions without predictions are left
	out), and the whole result is dumped as a single JSON object.

	:param nernet_sliced_result_path: ``%``-format template, formatted with the slice index as a
		string, locating the slice result files (one JSON instance with 'mention' and 'preds' per line).
	:param nernet_mapping_path: JSON file ``{doc_key: [[mention_id, ...] per sentence]}``.
	:param ddparser_path: parsed documents, one JSON object per line; their 'splitted_text' is only
		used for the debug printout.
	:param debug: when positive, print every sentence with its result, pausing 3 seconds after each.
	:param nernet_documented_path: output path of the per-document JSON file.
	:param num_slices: number of slice files to read.
	"""
	# No histogram of prediction counts is kept: nothing reads one, and sizing it by the number of
	# slices would raise KeyError for any instance with at least that many predictions.
	full_results_instances = []
	for i in range(num_slices):
		with open(nernet_sliced_result_path % str(i), 'r', encoding='utf8') as fp:
			full_results_instances.extend(json.loads(line) for line in fp)

	with open(nernet_mapping_path, 'r', encoding='utf8') as fp:
		nernet_mapping = json.load(fp)

	ddparser_results = []
	with open(ddparser_path, 'r', encoding='utf8') as fp:
		for lidx, line in enumerate(fp):
			if lidx % 10000 == 0 and lidx > 0:
				print(lidx)
			ddparser_results.append(json.loads(line))

	nernet_documented = {}
	for mapping_cnt, doc_key in enumerate(nernet_mapping):
		if mapping_cnt % 1000 == 0 and mapping_cnt > 0:
			print("%d / %d" % (mapping_cnt, len(nernet_mapping)))
		doc_result = []
		for sent_mapping in nernet_mapping[str(doc_key)]:
			sent_result = {}
			for mention_id in sent_mapping:
				assert mention_id[0] == 'R'
				instance = full_results_instances[int(mention_id.strip('R'))]
				mention_preds = instance['preds']
				if len(mention_preds) > 0:
					# a later mention with the same text replaces the earlier one.
					sent_result[instance['mention']] = mention_preds
			doc_result.append(sent_result)
		nernet_documented[doc_key] = doc_result
		if debug > 0:
			for sent_id, sent_result in enumerate(doc_result):
				print(ddparser_results[int(doc_key)]['splitted_text'][sent_id])
				print(sent_result)
				time.sleep(3)

	with open(nernet_documented_path, 'w', encoding='utf8') as fp:
		# this is a standalone file without texts or triples, not big enough to be dumped in lines
		json.dump(nernet_documented, fp, ensure_ascii=False)
	print("Done!")