import json
import time
import sys
import copy
import random


# convert format as 2016-10-07T08:00:00.000+03:00 into format as Feb 12, 2013 12:00:00 AM
def convert_datetime_format(input_str):
	"""Convert an ISO-8601-like timestamp into a 12-hour "Mon D, YYYY H:MM:SS AM/PM" string.

	Example: "2016-10-07T08:00:00.000+03:00" -> "Oct 7, 2016 8:00:00 AM".

	The UTC designator / offset ("Z", "+03:00", "-05:00") and fractional seconds are
	discarded; the wall-clock time is kept as written (no time-zone conversion).

	:param input_str: timestamp of the form YYYY-MM-DDTHH:MM:SS[.fff][offset].
	:return: the reformatted timestamp string.
	:raises ValueError: if the date or time part does not have exactly three fields.
	:raises IndexError: if the month number is outside 0-12.
	"""
	MONTH_MAP = [None, 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
	_date, _time = input_str.split('T')
	_year, _month, _day = _date.split('-')
	_month = MONTH_MAP[int(_month)]
	_day = str(int(_day))  # remove the beginning 0 if applicable
	# Drop the time-zone part, whatever its sign; we don't care about hours that much anyway.
	for tz_marker in ('+', '-', 'Z'):
		_time = _time.split(tz_marker)[0]
	_hour, _min, _sec = _time.split(':')
	hour_24 = int(_hour)
	_suffix = 'PM' if hour_24 > 11 else 'AM'
	# 12-hour clock: 0 -> 12 (midnight), 13..23 -> 1..11, 12 stays 12.
	_hour = hour_24 % 12 or 12
	_sec = _sec.split('.')[0]  # strip fractional seconds
	output_str = f"{_month} {_day}, {_year} {_hour}:{_min}:{_sec} {_suffix}"
	return output_str


# Converts the per-document structured format (one JSON object per line):
#   {doc_id: {'rels': doc_typed_rels, 'splitted_text': data_entry['splitted_text']}}
#   where each rel is: (names, 'SVO', dep_idxs, types, features, [GG/GE/EG/EE])
# into the per-sentence TACL format (one JSON object per line):
#   {"s": sentence, "date": "Feb 12, 2013 12:00:00 AM", "articleId": doc_id, "lineId": line_id,
#    "rels": [{"r": "((pred.1,pred.2)::arg1::arg2::[EE/GE/EG/GG]::0::?::type1::type2)"}]}
def format_parsed_triples_to_tacl(stct_fname, tacl_fname, retrieve_unary, reformat_time):
	"""Rewrite a per-document typed-relation file as a per-sentence TACL-style file.

	Each input line is a JSON object with exactly one key (the document id) whose value
	holds 'rels' (one list of relations per sentence), 'splitted_text' (the sentences) and
	a date under 'published' or 'time'. One JSON line is written per sentence, carrying one
	"r" record per (subject type, object type) pair of every retained relation.

	'#' and ':' are stripped from predicate and argument names because ':' pairs delimit
	the fields of a record and '#' joins the type-pair keys used for the statistics.

	:param stct_fname: path of the input JSON-lines file.
	:param tacl_fname: path of the output JSON-lines file (overwritten).
	:param retrieve_unary: when equal to 0, relations whose feature list marks them 'UNY' are skipped.
	:param reformat_time: when truthy, dates go through convert_datetime_format.
	:return: None; counters and type statistics are printed to stdout.
	"""
	sharps_deleted_count = 0
	colons_deleted_count = 0
	typepair_triple_bucket = {}
	type_argument_bucket = {}

	def strip_delimiters(name):
		"""Remove '#' and ':' from a name, counting each kind of removal once per name."""
		nonlocal sharps_deleted_count, colons_deleted_count
		if '#' in name:
			name = name.replace('#', '')
			sharps_deleted_count += 1
		# Checked independently of '#': a name holding both characters must lose its colons
		# too, otherwise the '::'-delimited record below becomes ambiguous.
		if ':' in name:
			name = name.replace(':', '')
			colons_deleted_count += 1
		return name

	# Context managers guarantee both files are closed even if a line fails to parse.
	with open(stct_fname, 'r', encoding='utf8') as stct_fp, \
			open(tacl_fname, 'w', encoding='utf8') as tacl_fp:
		for lidx, line in enumerate(stct_fp):
			if lidx % 10000 == 0:
				print(lidx)
			item = json.loads(line)
			assert len(item) == 1
			(articleId, doc), = item.items()
			articleDate = doc['published'] if 'published' in doc else doc['time']  # Example: 2016-10-07T08:00:00.000+03:00
			if reformat_time:
				articleDate = convert_datetime_format(articleDate)
			stct_rels = doc['rels']
			for sent_id, sent in enumerate(doc['splitted_text']):
				tacl_dict = {"s": sent,
							 "date": articleDate,
							 "articleId": articleId,
							 "lineId": sent_id,
							 "rels": []}
				for rel in stct_rels[sent_id]:
					# skip if we don't want to retrieve the unaries
					bu_status = rel[4][1]  # BNY/UNY
					if bu_status == 'UNY' and retrieve_unary == 0:
						continue

					subj_name = rel[0][0]
					obj_name = rel[0][2]
					pred_name = rel[0][1]
					pred_dep_idx = rel[2][1]
					subj_types = rel[3]['subj']
					obj_types = rel[3]['obj']
					# A missing argument (unary relation) gets a placeholder name and the type "thing".
					if subj_name is None:
						subj_name = "空空空空空"
						subj_types = ['thing']
					if obj_name is None:
						obj_name = "空空空空空"
						obj_types = ['thing']
					GEEG = rel[5]
					pred_name = strip_delimiters(pred_name)
					subj_name = strip_delimiters(subj_name)
					obj_name = strip_delimiters(obj_name)
					assert '#' not in GEEG
					assert ':' not in GEEG

					# construct a relation record for each type pair.
					for t_subj in subj_types:
						for t_obj in obj_types:
							assert '#' not in t_subj and ':' not in t_subj
							assert '#' not in t_obj and ':' not in t_obj
							rel_str = f"(({pred_name}.1,{pred_name}.2)::{subj_name}::{obj_name}::{GEEG}::0::{pred_dep_idx}::{t_subj}::{t_obj})"
							tacl_dict['rels'].append({"r": rel_str})
							pair_key = t_subj + '#' + t_obj
							typepair_triple_bucket[pair_key] = typepair_triple_bucket.get(pair_key, 0) + 1
							type_argument_bucket[t_subj] = type_argument_bucket.get(t_subj, 0) + 1
							type_argument_bucket[t_obj] = type_argument_bucket.get(t_obj, 0) + 1

				tacl_fp.write(json.dumps(tacl_dict, ensure_ascii=False) + '\n')

	print(f"sharps_deleted_count: {sharps_deleted_count};")
	print(f"colons_deleted_count: {colons_deleted_count}.")

	# Report the buckets from most to least frequent.
	typepair_triple_bucket = dict(sorted(typepair_triple_bucket.items(), key=lambda kv: kv[1], reverse=True))
	type_argument_bucket = dict(sorted(type_argument_bucket.items(), key=lambda kv: kv[1], reverse=True))

	print(f"typepair_triple_bucket: {typepair_triple_bucket}")
	print(f"type_argument_bucket: {type_argument_bucket}")

	print("Finished!")


def _unique(rels):  # [["重庆日报报业集团", "授权", "华龙网"], "SVO", [1, 2, 3]]
	rels_dct = {}
	for rel in rels:
		rel_ser = f"{rel[0][0]}::{rel[0][1]}::{rel[0][2]}::::{rel[1]}::::{rel[2][0]}::{rel[2][1]}::{rel[2][2]}"
		if rel_ser not in rels_dct:
			rels_dct[rel_ser] = rel
	rels_uniq = [copy.deepcopy(rels_dct[_key]) for _key in rels_dct]
	return rels_uniq


# final routine
def type_acquired_relations(both_ner_result_path, argnet_documented_path, ddparser_path, typed_triples_documented_path,
							typed_triples_documented_excluded_path, typed_triples_stats_path, amend=True, debug=False,
							is_jia_baseline=False, add_crossed=False, add_coarse=True):
	"""Attach argument types, feature flags and named-entity flags to parsed SVO relations.

	The three input files are JSON-lines files read in lockstep, one document per line; they
	must have the same number of lines. For every 'SVO' relation of every sentence the
	first-layer types of its arguments are looked up in the argument-typing results, each
	argument is flagged 'E' (matches a NER span) or 'G' (does not), and the relation is
	extended to (names, 'SVO', dep_idxs, types, features, [GG/GE/EG/EE]), where features is
	[negation, unary, tense, modal]. Relations that are not eligible (see `eligible`) are dropped.

	:param both_ner_result_path: input; each line holds 'both_ner_spans', a list of spans per sentence.
	:param argnet_documented_path: input; each line is {str(doc_id): [per-sentence {arg: [(span, predictions), ...]}]}.
	:param ddparser_path: input; each line holds 'splitted_text', 'fine_rels', 'published' or 'time',
		and, depending on the flags, the coarse/crossed/amend/possible relation lists and 'ddp_lbls'.
	:param typed_triples_documented_path: output; one line per document with the accepted relations.
	:param typed_triples_documented_excluded_path: output; one line per document with relations
		excluded for being conjunctions.
	:param typed_triples_stats_path: output; JSON dump of the global counters.
	:param amend: also use the 'amend_*' relation lists and the 'possible_rels' candidates.
	:param debug: send log messages to stdout instead of the log file and print extra diagnostics.
	:param is_jia_baseline: suppress the error message about missing 'ddp_lbls'.
	:param add_crossed: also use the 'crossed_rels' lists.
	:param add_coarse: also use the 'coarse_rels' lists.
	:raises AssertionError: if the input files disagree on line or sentence counts.

	NOTE(review): the log file path is hard-coded relative to the working directory.
	"""
	log_filename = '../TPR_logfile_V3_2.log'

	# Global counters. The key order is the order of the dumped statistics file.
	stats = {
		'global_typed_triple_count': 0,
		'excluded_global_typed_triple_count': 0,
		'global_negation_count': 0,
		'global_unary_count': 0,
		'global_past_count': 0,
		'global_future_count': 0,
		'global_conj_count': 0,
		'global_copular_count': 0,
		'global_uneligible_count': 0,
		'global_GG_triple_count': 0,
		'arg_type_not_found_count': 0,
		'have_to_randomly_pick_arg_mention_count': 0,
		'poss_rel_accepted_count': 0,
		'poss_rel_rejected_count': 0
	}

	affirmative_copulas = ['是', '也是', '就是', '而是', '正是', '才是', '都是', '仍是', '既是', '又是', '却是', '只是', '算是', '竟是',
						   '便是', '无疑是', '乃是', '并且是']

	def count_lines(path):
		"""Count the lines of a text file."""
		with open(path, 'r', encoding='utf8') as fp:
			return sum(1 for _ in fp)

	def fetch_arg_types(arg, arg_span, argnets_sent, record_stats=True):
		"""Return the distinct first-layer types of `arg`, or [] if it is None or untyped.

		argnets_sent: {arg_1: [(span, predictions), ...], arg_2: [(span, predictions), ...]}
		With several mentions, the one whose span is closest to `arg_span` is used; without
		a span a random mention is picked. `record_stats=False` leaves the counters untouched.
		"""
		if arg is None:
			return []
		if arg not in argnets_sent:
			if record_stats:
				stats['arg_type_not_found_count'] += 1
			return []
		mentions = argnets_sent[arg]
		assert len(mentions) > 0
		if len(mentions) == 1:
			types = mentions[0][1]
		elif arg_span is None:
			types = random.choice(mentions)[1]
			if record_stats:
				stats['have_to_randomly_pick_arg_mention_count'] += 1
		else:
			types = None
			min_dist = float('inf')
			for cand_span, preds in mentions:
				cur_dist = abs(arg_span[0] - cand_span[0]) + abs(arg_span[1] - cand_span[1])
				if cur_dist < min_dist:  # strict: the first of equally close mentions wins
					types = preds
					min_dist = cur_dist
			assert types is not None
		new_types = []
		for t in types:
			if len(t) == 0:
				continue
			# record only the first layer types; assumes paths like '/first/second' — TODO confirm
			t_first_layer = t.split('/')[1]
			if t_first_layer not in new_types:
				new_types.append(t_first_layer)
		return new_types

	def match_NE_loose(arg, arg_span, ners_sent, raw_sent):
		"""Return 'E' if the text under any NER span equals `arg`, else 'G'."""
		if any(raw_sent[cand_span[0]:cand_span[1]] == arg for cand_span in ners_sent):
			return 'E'
		return 'G'

	def match_NE_strict(arg, arg_span, ners_sent, raw_sent):
		"""Return 'E' if `arg_span` coincides with a NER span, else 'G'; loose match without a span."""
		if arg_span is None:
			return match_NE_loose(arg, arg_span, ners_sent, raw_sent)
		for cand_span in ners_sent:
			if arg_span[0] == cand_span[0] and arg_span[1] == cand_span[1]:
				return 'E'
		return 'G'

	# at least one argument is not None, and all arguments that are not None have at least one type.
	def eligible(subj, subj_types, obj, obj_types):
		if subj is not None and len(subj_types) == 0:
			return False
		if obj is not None and len(obj_types) == 0:
			return False
		return not (subj is None and obj is None)

	# This function, and the other checks, have been superseded by those in dudepparse.py, as part of the
	# predicate lexicon rather than features. Therefore the checking functions below and their corresponding
	# features are all just placeholders.
	def check_negation(cur_rel, sent_rels):
		assert cur_rel[1] == 'SVO'
		return False

	def check_unary(cur_rel):
		return cur_rel[0][0] is None or cur_rel[0][2] is None

	def check_tense(cur_rel, sent_rels, dep_structure):
		if debug:
			print("Check tense not yet implemented!", file=sys.stderr)
		return ''

	def check_modal(cur_rel, sent_rels, dep_structure):
		if debug:
			print("Check modal not yet implemented!", file=sys.stderr)
		return False

	def check_conjunction(cur_rel, sent_rels, dep_structure):
		if debug:
			print("Check conjunction not yet implemented!", file=sys.stderr)
		return False

	def check_copular(cur_rel):
		"""Return 'is' / 'isnot' for a copular predicate that has an object, otherwise False."""
		pred = cur_rel[0][1]
		has_obj = cur_rel[0][2] is not None
		if pred in affirmative_copulas:
			return 'is' if has_obj else False
		if pred == '不是':
			return 'isnot' if has_obj else False
		if '是' in pred and pred != '是不是':
			# LOGFILE is bound in the enclosing scope before any document is processed.
			print(f"Possibly Copular construction: {pred}", file=LOGFILE)
		return False

	match_NE = match_NE_strict

	def gather_sentence_rels(data_entry, sent_id, sent_argnet_res):
		"""Collect the candidate relations of one sentence according to the amend/coarse/crossed flags."""
		# each rel looks like [["重庆日报报业集团", "授权", "华龙网"], "SVO", [1, 2, 3]]
		if amend:
			sent_rels = data_entry['fine_rels'][sent_id] + data_entry['amend_fine_rels'][sent_id]
			if add_coarse:
				sent_rels += data_entry['coarse_rels'][sent_id] + data_entry['amend_coarse_rels'][sent_id]
			if add_crossed:
				sent_rels += data_entry['crossed_rels'][sent_id] + data_entry['amend_crossed_rels'][sent_id]
			sent_possible_rels = data_entry['possible_rels'][sent_id]
		else:
			# Copy, so that the in-place extensions below do not modify the parsed input.
			sent_rels = list(data_entry['fine_rels'][sent_id])
			if add_coarse:
				sent_rels += data_entry['coarse_rels'][sent_id]
			if add_crossed:
				sent_rels += data_entry['crossed_rels'][sent_id]
			sent_possible_rels = []

		# A "possible" relation is accepted only when its subject is typed 'person' and the last
		# chunk of its predicate is typed 'person' or 'title'.
		for poss_rel in sent_possible_rels:
			subj = poss_rel[0][0]
			pred_last_chunk = poss_rel[0][1].strip().split('·')[-1]
			subj_types = fetch_arg_types(subj, None, sent_argnet_res, record_stats=False)
			pred_types = fetch_arg_types(pred_last_chunk, None, sent_argnet_res, record_stats=False)
			if 'person' in subj_types and ('person' in pred_types or 'title' in pred_types):
				if debug:
					print(poss_rel)
				stats['poss_rel_accepted_count'] += 1
				sent_rels.append(poss_rel)
			else:
				stats['poss_rel_rejected_count'] += 1
		return sent_rels

	def token_span(tkn_idxs, tkn_idx):
		"""Return the [start, end] character span of a dependency token, or None if unavailable."""
		if tkn_idxs is None or tkn_idx is None:
			return None
		return [tkn_idxs[tkn_idx], tkn_idxs[tkn_idx + 1]]

	def extract_features(rel, sent_rels, sent_ddps):
		"""Return ([negation, unary, tense, modal], conjunction_flag) for one relation."""
		feat_list = []

		if check_negation(rel, sent_rels):
			feat_list.append('NEG')
			stats['global_negation_count'] += 1
		else:
			feat_list.append('NULL')

		if check_unary(rel):
			feat_list.append('UNY')
			stats['global_unary_count'] += 1
		else:
			feat_list.append('BNY')

		tense_flag = check_tense(rel, sent_rels, sent_ddps)
		feat_list.append(tense_flag)
		if tense_flag == 'PAST':
			stats['global_past_count'] += 1
		elif tense_flag == 'FUTR':
			stats['global_future_count'] += 1
		elif tense_flag not in ('PRST', ''):
			raise AssertionError

		feat_list.append('MODAL' if check_modal(rel, sent_rels, sent_ddps) else 'NULL')

		# relation triples with conjunction are currently excluded because it is unclear whether
		# the conjunction is distributive or not:
		# John and Mary met in the library.
		# John and Mary bought a book.
		conjunction_flag = check_conjunction(rel, sent_rels, sent_ddps)
		if conjunction_flag:
			stats['global_conj_count'] += 1

		copular_flag = check_copular(rel)
		if copular_flag == 'is':
			stats['global_copular_count'] += 1
		elif copular_flag == 'isnot':
			feat_list[0] = 'NEG'  # marked as negated, but not counted in global_negation_count
			stats['global_copular_count'] += 1
		return feat_list, conjunction_flag

	def type_document(data_entry, argnet_stct, ner_stct):
		"""Return (accepted, excluded) typed relations of one document, one list per sentence."""
		doc_typed_rels = []
		excluded_doc_typed_rels = []
		for sent_id, raw_sent in enumerate(data_entry['splitted_text']):
			sent_argnet_res = argnet_stct[sent_id]  # {arg_1: [(span, predictions), ...], ...}
			sent_rels = _unique(gather_sentence_rels(data_entry, sent_id, sent_argnet_res))

			if 'ddp_lbls' in data_entry:
				sent_ddps = data_entry['ddp_lbls'][sent_id]
				# map dependency tokens to their character offsets in the text; the last entry
				# is a pseudo token representing the end of the sentence.
				ddp_tkn_idxs = [0]
				for tkn in sent_ddps['word']:
					ddp_tkn_idxs.append(ddp_tkn_idxs[-1] + len(tkn))
			else:
				sent_ddps = None
				ddp_tkn_idxs = None
				if not is_jia_baseline:
					print("DDP labels missing when it's not Jia et al baseline! Serious Error!!!", file=sys.stderr)
			sent_ner_spans = ner_stct['both_ner_spans'][sent_id]  # List[(span_0, span_1)]

			sent_typed_rels = []
			excluded_sent_typed_rels = []
			for rel in sent_rels:
				if rel[1] != 'SVO':
					continue
				subj = rel[0][0]
				obj = rel[0][2]
				subj_span = token_span(ddp_tkn_idxs, rel[2][0])
				obj_span = token_span(ddp_tkn_idxs, rel[2][2])
				subj_types = fetch_arg_types(subj, subj_span, sent_argnet_res)
				obj_types = fetch_arg_types(obj, obj_span, sent_argnet_res)
				subj_GE = match_NE(subj, subj_span, sent_ner_spans, raw_sent)  # either 'G' (general) or 'E' (named entity)
				obj_GE = match_NE(obj, obj_span, sent_ner_spans, raw_sent)  # either 'G' (general) or 'E' (named entity)

				if not eligible(subj, subj_types, obj, obj_types):
					print("Filtered out due to being not eligible!", file=LOGFILE)
					stats['global_uneligible_count'] += 1
					continue

				feat_list, conjunction_flag = extract_features(rel, sent_rels, sent_ddps)
				rel.append({'subj': subj_types, 'obj': obj_types})
				rel.append(feat_list)
				rel.append(subj_GE + obj_GE)
				assert len(rel) == 6  # (names, 'SVO', dep_idxs, types, features, [GG/GE/EG/EE])
				if not conjunction_flag:
					sent_typed_rels.append(copy.deepcopy(rel))
					stats['global_typed_triple_count'] += 1
					if subj_GE == 'G' and obj_GE == 'G':
						stats['global_GG_triple_count'] += 1
				else:
					excluded_sent_typed_rels.append(copy.deepcopy(rel))
					stats['excluded_global_typed_triple_count'] += 1
			doc_typed_rels.append(sent_typed_rels)
			excluded_doc_typed_rels.append(excluded_sent_typed_rels)
		return doc_typed_rels, excluded_doc_typed_rels

	# Every file is opened through a context manager so that it is closed on any exit path.
	with open(log_filename, 'w', encoding='utf8') as log_fp:
		LOGFILE = sys.stdout if debug else log_fp

		argnet_line_count = count_lines(argnet_documented_path)
		ddp_line_count = count_lines(ddparser_path)
		ner_line_count = count_lines(both_ner_result_path)
		print(f"argnet_line_count: {argnet_line_count}")
		print(f"ddp_line_count: {ddp_line_count}")
		print(f"ner_line_count: {ner_line_count}")
		if argnet_line_count != ddp_line_count or ner_line_count != ddp_line_count:
			raise AssertionError("The argnet, ddparser and NER files must have the same number of lines.")

		# The output files are opened (and truncated) only after the inputs passed the check above.
		with open(argnet_documented_path, 'r', encoding='utf8') as argnet_fp, \
				open(ddparser_path, 'r', encoding='utf8') as ddp_fp, \
				open(both_ner_result_path, 'r', encoding='utf8') as ner_fp, \
				open(typed_triples_documented_path, 'w', encoding='utf8') as acc_stct_fp, \
				open(typed_triples_documented_excluded_path, 'w', encoding='utf8') as rej_stct_fp:
			st = time.time()
			for doc_id, (argnet_line, ddp_line, ner_line) in enumerate(zip(argnet_fp, ddp_fp, ner_fp)):
				# Indexing with str(doc_id) implicitly checks that the argnet line is aligned with doc_id.
				argnet_stct = json.loads(argnet_line)[str(doc_id)]  # [{arg_i: [(span, predictions), ...]}, {}, ...]
				data_entry = json.loads(ddp_line)
				ner_stct = json.loads(ner_line)

				if doc_id % 1000 == 0 and doc_id > 0:
					dur_h, dur_rest = divmod(int(time.time() - st), 3600)
					dur_m, dur_s = divmod(dur_rest, 60)
					print(
						f"{doc_id} / {ddp_line_count}; time lapsed: {dur_h} hours {dur_m} minutes {dur_s} seconds.")

				assert len(argnet_stct) == len(data_entry['splitted_text']) and len(argnet_stct) == len(ner_stct['both_ner_spans'])

				doc_typed_rels, excluded_doc_typed_rels = type_document(data_entry, argnet_stct, ner_stct)
				assert len(doc_typed_rels) == len(data_entry['splitted_text'])

				published = data_entry['published'] if 'published' in data_entry else data_entry['time']
				acc_item = {doc_id: {'rels': doc_typed_rels, 'splitted_text': data_entry['splitted_text'], 'published': published}}
				rej_item = {doc_id: {'rels': excluded_doc_typed_rels, 'splitted_text': data_entry['splitted_text'], 'published': published}}
				acc_stct_fp.write(json.dumps(acc_item, ensure_ascii=False) + '\n')
				rej_stct_fp.write(json.dumps(rej_item, ensure_ascii=False) + '\n')

	print(f"Retrieved {stats['global_typed_triple_count']} typed relation triples.")
	print(f"another {stats['excluded_global_typed_triple_count']} relation triples are retrieved but excluded.")
	print(f"{stats['global_negation_count']} relations were detected as being negations.")
	print(f"{stats['global_unary_count']} relations were detected as being unaries.")
	print(f"{stats['global_past_count']} relations were detected as being past-tense.")
	print(f"{stats['global_future_count']} relations were detected as being future-tense.")
	print(f"{stats['global_conj_count']} relations were detected as being conjunctions.")
	print(f"{stats['global_copular_count']} relations were detected as being copular constructions.")
	print(f"{stats['global_uneligible_count']} relations were dumped because of no feasible types.")
	print(f"{stats['global_GG_triple_count']} relations are General-Genaral relations without named entities!")

	print(f"{stats['arg_type_not_found_count']} arguments are not found in typing results (most probably due to discontinuous spans!)")
	print(f"Have to randomly pick argument count: {stats['have_to_randomly_pick_arg_mention_count']}")

	print(f"{stats['poss_rel_accepted_count']} possible rels are accepted.")
	print(f"{stats['poss_rel_rejected_count']} possible rels are rejected.")

	with open(typed_triples_stats_path, 'w', encoding='utf8') as fp:
		json.dump(stats, fp, ensure_ascii=False, indent=4)

	print("Done!")
