import json
import time
from contextlib import ExitStack

from utils.data_preperation_utils import format_cfet_instance_for_hier


def prepare_corpus_for_ner(ddparser_path, ner_input_path, raw_ner_mapping_path, num_slices):
	"""Convert a JSON-lines corpus into character-level NER input files.

	Each line of ``ddparser_path`` is a JSON object whose ``'splitted_text'`` list holds
	the document's sentences. Every non-whitespace character of each stripped sentence
	is written as ``"<char>\\tO"`` on its own line, followed by a blank line after each
	sentence. Documents are split over ``num_slices`` contiguous slice files (the last
	slice absorbs any remainder), and everything is also written to a ``'full'`` file.

	Args:
		ner_input_path: %-template with a single ``%s`` slot, filled with the slice id
			(``'0'``, ``'1'``, ...) for slice files and with ``'full'`` for the combined file.
		raw_ner_mapping_path: where to write the JSON mapping from global sentence id to
			``{'doc_id': ..., 'sent_id': ...}``.
		num_slices: number of slice files to produce; must be >= 1.

	Raises:
		ValueError: if ``num_slices`` < 1.
	"""
	if num_slices < 1:
		raise ValueError(f"num_slices must be >= 1, got {num_slices}")

	with open(ddparser_path, 'r', encoding='utf8') as fp:
		num_entries = sum(1 for _ in fp)
	print(f"Num entries: {num_entries}")

	# At least 1: with fewer documents than slices, integer division would give 0
	# and `doc_id // slice_size` below would raise ZeroDivisionError.
	slice_size = max(1, num_entries // num_slices)

	global_sent_id = 0
	sent_id_mapping = {}
	print("Saving input file to %s" % ner_input_path)
	with ExitStack() as stack:
		slice_fps = [stack.enter_context(open(ner_input_path % str(slice_id), 'w', encoding='utf8'))
					 for slice_id in range(num_slices)]
		full_fp = stack.enter_context(open(ner_input_path % 'full', 'w', encoding='utf8'))
		input_fp = stack.enter_context(open(ddparser_path, 'r', encoding='utf8'))
		for doc_id, document_line in enumerate(input_fp):
			document = json.loads(document_line)
			# Contiguous slices; trailing remainder documents go into the last slice.
			fp = slice_fps[min(doc_id // slice_size, num_slices - 1)]
			if doc_id % 1000 == 0 and doc_id > 0:
				print("Processed through to document %d / %d" % (doc_id, num_entries))
			for sent_id, sent in enumerate(document['splitted_text']):
				sent_id_mapping[global_sent_id] = {'doc_id': doc_id, 'sent_id': sent_id}
				global_sent_id += 1
				# Whitespace characters are skipped; a blank line terminates the sentence.
				sent_block = ''.join(c + '\tO\n' for c in sent.strip() if c.strip()) + '\n'
				fp.write(sent_block)
				full_fp.write(sent_block)
	print("Done.")

	print("Saving mapping file to %s" % raw_ner_mapping_path)
	with open(raw_ner_mapping_path, 'w', encoding='utf8') as fp:
		json.dump(sent_id_mapping, fp, ensure_ascii=False)
	print("Done.")
	print("Finished!")


# Merge sliced flat-NER results back into per-document entries, in a format suitable for NET inference.
def merge_flat_ner_results_in_doc(flatner_result_sliced_path, ner_input_path, raw_ner_mapping_path,
								  checkraw, ddparser_path, debug, flat_ner_result_indoc_path,
								  num_slices):

	num_entities_starting_with_I = 0
	num_unaligned_sentences = 0
	sents_in = []
	sents_out = []
	buffer_in = []
	for i in range(num_slices):
		print("Reading in ner results number %d" % i)
		with open(flatner_result_sliced_path%i, 'r', encoding='utf8') as fp_output:
			with open(ner_input_path%str(i), 'r', encoding='utf8') as fp_input:
				input_lines = fp_input.readlines()
				output_entries = json.load(fp_output)
				sents_out += output_entries

				for lidx in range(len(input_lines)):
					if lidx % 100000 == 0 and lidx > 0:
						print("Line number %d / %d" % (lidx, len(input_lines)))
					in_line = input_lines[lidx].strip().split('\t')
					# if EOS encountered
					if len(in_line) == 1 and len(buffer_in) > 0:
						buffer_out = sents_out[len(sents_in)]
						if len(buffer_in) != len(buffer_out):
							new_buffer_out = []
							buffer_out_id = 0
							for item in buffer_in:
								if item != '#':
									new_buffer_out.append(buffer_out[buffer_out_id])
									buffer_out_id += 1
								else:
									new_buffer_out.append('O')
							assert buffer_out_id == len(buffer_out)
							assert len(new_buffer_out) == len(buffer_in)
							sents_out[len(sents_in)] = new_buffer_out
							print(buffer_in)
							print(sents_out[len(sents_in)])
						sents_in.append(buffer_in)
						buffer_in = []
					# if token within a sentence
					elif len(in_line) == 2:
						if in_line[0] not in ['#']:
							buffer_in.append(in_line[0])
						else:
							print("!")
							buffer_in.append(in_line[0])  # don't see why this special token should be avoided.
					elif len(in_line) == 1 and len(buffer_in) == 0:
						continue
					else:
						print("in_line: ", in_line)
						print("buffer in: ", buffer_in)
						print("lidx: ", lidx)
						raise AssertionError

				if len(buffer_in) > 0:
					print(buffer_in)
					print(sents_out[len(sents_in)])
					sents_in.append(buffer_in)
					buffer_in = []

	print("NER input read in!")

	with open(raw_ner_mapping_path, 'r', encoding='utf8') as fp:
		mapping = json.load(fp)

	assert len(sents_in) == len(sents_out)
	print("Sentence mapping read in!")

	if len(mapping) != len(sents_in):
		print("Len sents: ", len(sents_out))
		print("Len mapping: ", len(mapping))
		print("Sents and mapping size mismatch!")
		raise AssertionError

	if checkraw is True:
		print("Checking sentence with raw DDParser output!")
		ddp_fp = open(ddparser_path, 'r', encoding='utf8')
		ddp_lidx = -1
		cur_doc = None
		for in_sent_id, sent_input in enumerate(sents_in):
			if in_sent_id % 20000 == 0:
				print(f"in sent id: {in_sent_id}/{len(sents_in)}")
			doc_id = mapping[str(in_sent_id)]['doc_id']
			local_sent_id = mapping[str(in_sent_id)]['sent_id']
			assert ddp_lidx <= doc_id
			while ddp_lidx < doc_id:
				raw_line = ddp_fp.readline()
				cur_doc = json.loads(raw_line)
				ddp_lidx += 1
			ori_sent = cur_doc['splitted_text'][local_sent_id]
			sent_input = ''.join(sent_input)  # bind the list of characters into a real sentence.
			if len(ori_sent) != len(sent_input):
				print(
					"!!! Length of sentence not equal between raw corpus and NER input at sentence_id %d!" % in_sent_id)
				print("Doc/Sent: %d/%d" % (doc_id, local_sent_id))
				num_unaligned_sentences += 1
				print(ori_sent)
				print(sent_input)
				print("")
		ddp_fp.close()

	print(f"A total number of {num_unaligned_sentences} sentences have lengths that do not match!")
	clean_entries = dict()

	print("Reading in NER results......")
	for global_sent_id, (sent_input, sent_labels) in enumerate(zip(sents_in, sents_out)):
		if global_sent_id % 20000 == 0 and global_sent_id > 0:
			print("Read in to sentence %d / %d" % (global_sent_id, len(sents_in)))

		doc_id = mapping[str(global_sent_id)]['doc_id']
		local_sent_id = mapping[str(global_sent_id)]['sent_id']

		if doc_id not in clean_entries:
			clean_entries[doc_id] = dict()
		clean_entry = clean_entries[doc_id]

		if 'ner_lbls' not in clean_entry:
			clean_entry['ner_lbls'] = []
		if 'flat_ner_spans' not in clean_entry:
			clean_entry['flat_ner_spans'] = []
		if 'splitted_text' not in clean_entry:
			clean_entry['splitted_text'] = []
		clean_entry['ner_lbls'].append(sent_labels)
		assert len(clean_entry['splitted_text']) == local_sent_id
		clean_entry['splitted_text'].append(''.join(sent_input))

		spans_buffer = []
		span_start_id = -1
		for token_id in range(len(sent_input)):
			if sent_labels[token_id][0] == 'B':
				if span_start_id >= 0:
					spans_buffer.append([span_start_id, token_id])
					span_start_id = -1
				assert span_start_id < 0
				span_start_id = token_id
			elif sent_labels[token_id][0] == 'O' and token_id > 0 and sent_labels[token_id-1][0] in ['B', 'I']:
				assert span_start_id >= 0
				spans_buffer.append([span_start_id, token_id])
				span_start_id = -1
			elif sent_labels[token_id][0] == 'O':
				assert span_start_id < 0
			elif sent_labels[token_id][0] == 'I':
				if span_start_id < 0:
					span_start_id = token_id
					print(f"Named entity starting with I! doc_id {doc_id}; sent_id {local_sent_id}")
					num_entities_starting_with_I += 1
			else:
				raise AssertionError

		# deal with edge cases
		if span_start_id >= 0:
			spans_buffer.append([span_start_id, len(sent_input)])
			span_start_id = -1

		clean_entry['flat_ner_spans'].append(spans_buffer)
		spans_buffer = []

		# This is iterating over all the ner result sentences, not iterating over the documents, so we can't write
		# things into the output file for each iteration.

	print(f"A total number of {num_entities_starting_with_I} named entities start with I-XXX!")
	print("Lines Processed! Saving NER results without raw entries to %s" % flat_ner_result_indoc_path)
	fp = open(flat_ner_result_indoc_path, 'w', encoding='utf8')
	for eid, ent_key in enumerate(clean_entries):
		if eid % 10000 == 0 and eid > 0:
			print(f'{eid}/{len(clean_entries)}')
		assert eid == int(ent_key)
		entry = clean_entries[ent_key]
		string = json.dumps(entry, ensure_ascii=False)
		fp.write(string+'\n')
		fp.flush()
		fp.close()
		fp = open(flat_ner_result_indoc_path, 'a', encoding='utf8')
	fp.close()
	print("Done.")


def prepare_indoc_ner_results_for_net(ner_result_indoc_path, nernet_sliced_input_path, nernet_input_path,
									  nernet_hier_sliced_input_path, nernet_hier_input_path, nernet_mapping_path,
									  arg_entity_pool_path, ner_source, num_slices):
	"""Build NET (entity-typing) input instances from per-document NER spans.

	Each line of ``ner_result_indoc_path`` is a JSON document with a ``'splitted_text'``
	sentence list. Depending on ``ner_source``, it also has a ``'flat_ner_spans'`` or
	``'corenlp_ner_spans'`` list holding ``[start, end)`` spans per sentence. Every span
	whose surface string appears in the entity pool from ``arg_entity_pool_path`` becomes
	one instance. The instance is written as JSON to the CFET outputs and through
	``format_cfet_instance_for_hier`` to the hierarchical outputs, in both the document's
	slice file and the full file. Mention ids are ``'RF<n>'`` (flat) or ``'RC<n>'``
	(corenlp), where ``n`` counts the kept mentions.

	Args:
		nernet_sliced_input_path: %-template (slice index) for sliced CFET outputs.
		nernet_hier_sliced_input_path: %-template (slice index) for sliced hierarchical outputs.
		nernet_mapping_path: JSON output mapping document id -> per-sentence mention-id lists.
		arg_entity_pool_path: JSON list of the entity strings that may be typed.
		ner_source: ``'flat'`` or ``'corenlp'``.
		num_slices: number of slice files; must be >= 1.

	Raises:
		AssertionError: for an unknown ``ner_source``, sentences of 501+ characters, or
			malformed spans.
		ValueError: if ``num_slices`` < 1.
	"""
	if ner_source == 'flat':
		spans_key, mention_prefix = 'flat_ner_spans', 'RF'
	elif ner_source == 'corenlp':
		spans_key, mention_prefix = 'corenlp_ner_spans', 'RC'
	else:
		raise AssertionError(f"Unknown ner_source: {ner_source}")
	if num_slices < 1:
		raise ValueError(f"num_slices must be >= 1, got {num_slices}")

	with open(ner_result_indoc_path, 'r', encoding='utf8') as fp:
		doc_count = sum(1 for _ in fp)

	print(f"Total number of entries: {doc_count}")

	print("Constructing NET input entries......")
	global_net_entry_id = 0
	raw_net_mapping = {}
	not_in_unique_entities_count = 0

	with open(arg_entity_pool_path, 'r', encoding='utf8') as fp:
		unique_entities = set(json.load(fp))

	# At least 1 so that fewer documents than slices does not divide by zero;
	# trailing remainder documents go into the last slice.
	slice_size = max(1, doc_count // num_slices)

	with ExitStack() as stack:
		input_fp = stack.enter_context(open(ner_result_indoc_path, 'r', encoding='utf8'))
		cfet_out_fps = [stack.enter_context(open(nernet_sliced_input_path % i, 'w', encoding='utf8'))
						for i in range(num_slices)]
		hier_out_fps = [stack.enter_context(open(nernet_hier_sliced_input_path % i, 'w', encoding='utf8'))
						for i in range(num_slices)]
		cfet_full_fp = stack.enter_context(open(nernet_input_path, 'w', encoding='utf8'))
		hier_full_fp = stack.enter_context(open(nernet_hier_input_path, 'w', encoding='utf8'))
		print(f"Writing to: {nernet_input_path}; {nernet_hier_input_path}")

		st = time.time()
		for doc_id, input_line in enumerate(input_fp):
			clean_entry = json.loads(input_line)
			slice_id = min(doc_id // slice_size, num_slices - 1)
			cfet_fp = cfet_out_fps[slice_id]
			hier_fp = hier_out_fps[slice_id]

			if doc_id % 1000 == 0 and doc_id > 0:
				elapsed = int(time.time() - st)
				print('time lapsed: %d hours %d minutes %d seconds'
					  % (elapsed // 3600, (elapsed % 3600) // 60, elapsed % 60))
				print(f"Constructed through to document {doc_id} / {doc_count}; currently found {global_net_entry_id} mentions for NET!")
			if doc_id == 0:
				print(clean_entry)

			doc_mapping = []
			for sent_id, sent in enumerate(clean_entry['splitted_text']):
				assert len(sent) < 501
				spans_buffer = clean_entry[spans_key][sent_id]
				sent_mapping = []
				for span in spans_buffer:
					assert len(span) == 2 and span[0] < span[1]
					mention = sent[span[0]:span[1]]
					# Only type extracted entities that are also in the argument entity pool
					# (presumably frequency-filtered upstream — see the summary message below).
					if mention not in unique_entities:
						not_in_unique_entities_count += 1
						continue
					instance = {'figer_types_first_dict': {},
								'figer_types_first_list': [],
								'figer_types_both_dict': {},
								'figer_types_both_list': [],
								'label_types': [],
								'general_type': [],
								'mention': mention,
								'span': span,
								'sentence': sent,
								'mention_id': mention_prefix + str(global_net_entry_id)}
					cfet_line = json.dumps(instance, ensure_ascii=False)
					hier_line = format_cfet_instance_for_hier(instance)
					cfet_fp.write(cfet_line + '\n')
					cfet_full_fp.write(cfet_line + '\n')
					hier_fp.write(hier_line + '\n')
					hier_full_fp.write(hier_line + '\n')
					global_net_entry_id += 1
					sent_mapping.append(instance['mention_id'])
				doc_mapping.append(sent_mapping)
			raw_net_mapping[doc_id] = doc_mapping

		print(f"Contructed {global_net_entry_id} lines of NET inputs in total!")
		print(f"Another {not_in_unique_entities_count} lines of NET inputs filtered out because they're not frequent enough in triples!")

	print("Saving NET input document mappings to %s" % nernet_mapping_path)
	with open(nernet_mapping_path, 'w', encoding='utf8') as fp:
		json.dump(raw_net_mapping, fp, ensure_ascii=False)
	print("Done.")

	print("Finished!")


def merge_flatandcorenlp_ner_results(flat_ner_indoc_path, corenlp_ner_indoc_path, both_ner_indoc_path):
	"""Merge flat-NER and CoreNLP spans per document into a combined JSON-lines file.

	Both inputs are JSON-lines files with one document per line, aligned by line number.
	Each output line is the CoreNLP document with the flat document's
	``'flat_ner_spans'`` copied in. It also gets ``'both_ner_spans'``, which holds, per
	sentence, the CoreNLP spans followed by the flat spans (plain concatenation, no
	deduplication).

	Raises:
		AssertionError: if the two input files have different line counts.
	"""
	with open(flat_ner_indoc_path, 'r', encoding='utf8') as fp:
		flat_lidx = sum(1 for _ in fp)
	with open(corenlp_ner_indoc_path, 'r', encoding='utf8') as fp:
		core_lidx = sum(1 for _ in fp)
	assert flat_lidx == core_lidx, f"Line count mismatch: flat {flat_lidx} vs corenlp {core_lidx}"

	with open(flat_ner_indoc_path, 'r', encoding='utf8') as flat_fp, \
			open(corenlp_ner_indoc_path, 'r', encoding='utf8') as core_fp, \
			open(both_ner_indoc_path, 'w', encoding='utf8') as out_fp:
		print("Reading in corenlp_entries and merging...")
		for lidx, (core_line, flat_line) in enumerate(zip(core_fp, flat_fp)):
			if lidx % 10000 == 0 and lidx > 0:
				print(lidx)
			core_entry = json.loads(core_line)
			flat_entry = json.loads(flat_line)
			core_entry['flat_ner_spans'] = flat_entry['flat_ner_spans']
			core_entry['both_ner_spans'] = [
				core_entry['corenlp_ner_spans'][sent_id] + core_entry['flat_ner_spans'][sent_id]
				for sent_id in range(len(core_entry['splitted_text']))
			]
			out_fp.write(json.dumps(core_entry, ensure_ascii=False) + '\n')

	print("Finished!")


def assign_corenlp_ner_as_final(corenlp_ner_indoc_path, both_ner_indoc_path):
	"""Copy each document's CoreNLP spans into ``'both_ner_spans'`` and write the result.

	Reads the JSON-lines file ``corenlp_ner_indoc_path`` (one document per line) and
	writes every document to ``both_ner_indoc_path`` with ``'both_ner_spans'`` set to its
	``'corenlp_ner_spans'``, so downstream code reading ``'both_ner_spans'`` sees only
	CoreNLP results.
	"""
	with open(corenlp_ner_indoc_path, 'r', encoding='utf8') as core_fp, \
			open(both_ner_indoc_path, 'w', encoding='utf8') as out_fp:
		print("Reading in corenlp_entries and copying to #both_ner_spans# ...")
		for lidx, core_line in enumerate(core_fp):
			if lidx % 10000 == 0 and lidx > 0:
				print(lidx)
			core_entry = json.loads(core_line)
			core_entry['both_ner_spans'] = core_entry['corenlp_ner_spans']
			out_fp.write(json.dumps(core_entry, ensure_ascii=False) + '\n')

	print("Finished!")
