import argparse
import json
from collections import Counter

# Command-line options. NOTE: argparse %-formats help strings, so none of the
# help texts below may contain a bare percent sign.
parser = argparse.ArgumentParser(
	description='Reduce relation triples in a JSON-lines file: drop relations with rare predicates or '
				'argument pairs, then drop predicates subsumed by a more populous predicate in the same sentence.')
parser.add_argument('-i', '--input_fn', default='clue_typed_triples_tacl.json', type=str,
					help='input JSON-lines file (one object with a "rels" list per line); must end in .json')
parser.add_argument('-o', '--output_suff', default='_reduced_%d_%.1f_%d_%d', type=str,
					help='suffix inserted before .json in the output file name; its four printf-style fields are '
						 'filled with (adaptive threshold, multiples, pred_minimum, args_minimum)')
parser.add_argument('-b', '--bucket_fn', default='clue_predicate_buckets.json', type=str,
					help='JSON file of predicate counts (written when --flag is nonzero, read otherwise)')
parser.add_argument('-a', '--arg_bucket_fn', default='clue_argument_buckets.json', type=str,
					help='JSON file of argument-pair counts (written when --flag is nonzero, read otherwise)')
parser.add_argument('-p', '--portion', default=0.05, type=float,
					help='fraction of the most frequent predicates used to pick the adaptive count threshold')
parser.add_argument('-m', '--multiples', default=3, type=float,
					help='a predicate may only be dropped in favour of a containing predicate when '
						 'count(contained) / count(containing) is below this value')
parser.add_argument('--pred_minimum', default=10, type=int,
					help='drop relations whose predicate occurs fewer times than this')
parser.add_argument('--args_minimum', default=5, type=int,
					help='drop relations whose argument pair occurs fewer times than this')
parser.add_argument('-f', '--flag', default=1, type=int, help='flag for whether or not to re-calculate the buckets.')
args = parser.parse_args()

# Validate with parser.error instead of assert: asserts vanish under `python -O`
# and would give the user no explanation.
if not args.input_fn.endswith('.json'):
	parser.error(f"--input_fn must end with '.json', got {args.input_fn!r}")
# Still a printf-style template here; the numeric fields are filled in once the
# adaptive threshold is known.
output_fn = args.input_fn[:-5] + args.output_suff + '.json'

if not args.flag:
	# Reuse the count tables saved by an earlier run of the branch below.
	with open(args.bucket_fn, 'r', encoding='utf8') as fp:
		pred_buckets = json.load(fp)
	with open(args.arg_bucket_fn, 'r', encoding='utf8') as fp:
		args_buckets = json.load(fp)
else:
	# First pass over the input: count how often each predicate and each
	# argument pair occurs across all relations, then save both tables.
	pred_buckets = Counter()
	args_buckets = Counter()
	max_num_rels_per_sent = 0
	total_rels = 0
	lidx = -1  # bound even for an empty input, so the summary reports 0 lines instead of raising NameError
	with open(args.input_fn, 'r', encoding='utf8') as fp:
		for lidx, line in enumerate(fp):
			if lidx % 100000 == 0:
				print(f"lidx: {lidx}; max number of rels per sentence: {max_num_rels_per_sent}; "
					  f"total number of rels: {total_rels}; number of unique predicates: {len(pred_buckets)}; "
					  f"number of unique argument pairs: {len(args_buckets)}")
			item = json.loads(line)
			rels = item["rels"]
			max_num_rels_per_sent = max(max_num_rels_per_sent, len(rels))
			total_rels += len(rels)
			for rel in rels:
				# Strip one character from each end of rel["r"]; the rest must be 8 '::'-separated fields.
				rel_elements = rel["r"][1:-1].split('::')
				assert len(rel_elements) == 8

				# the splitting scheme below only works for Chinese rels! (has to guarantee that the correct comma is in the middle)
				# Field 0 (minus its outer characters) must look like "<pred>.1<sep><pred>.2"
				# with a single separator character exactly in the middle.
				rel_predicate = rel_elements[0][1:-1]
				rel_args = rel_elements[1] + '::::' + rel_elements[2]
				assert len(rel_predicate) % 2 == 1
				splitting_point = len(rel_predicate) // 2
				p1 = rel_predicate[:splitting_point]
				p2 = rel_predicate[splitting_point + 1:]
				assert p1[-2:] == '.1' and p2[-2:] == '.2'
				p1 = p1[:-2]
				p2 = p2[:-2]
				assert p1 == p2

				pred_buckets[p1] += 1
				args_buckets[rel_args] += 1

		print(f"Total number of lines: {lidx+1}")
		print(f"Total number of predicates: {len(pred_buckets)}")
		print(f"Total number of argument pairs: {len(args_buckets)}")
		print(f"Total number of rels: {total_rels}")

	# Order both tables by descending count (ties keep first-seen order, exactly
	# like a stable sorted(..., reverse=True)); the threshold step reads
	# pred_buckets as a ranking. Converted back to plain dicts before saving.
	pred_buckets = dict(pred_buckets.most_common())
	args_buckets = dict(args_buckets.most_common())

	with open(args.bucket_fn, 'w', encoding='utf8') as fp:
		json.dump(pred_buckets, fp, ensure_ascii=False)
	print("Predicate buckets saved!")
	with open(args.arg_bucket_fn, 'w', encoding='utf8') as fp:
		json.dump(args_buckets, fp, ensure_ascii=False)
	print("Arguments buckets saved!")

# Adaptive "absolutely populous" threshold: the count of the first predicate
# ranked strictly past the top `args.portion` fraction of predicates. The counts
# are sorted explicitly, so a bucket file that was not saved pre-sorted still
# works. For buckets produced by this script the order is already descending,
# so the result is unchanged.
# NOTE(review): the strict `rank > cutoff` keeps one extra predicate in the top
# portion whenever cutoff is a whole number; kept as-is for reproducibility.
ranked_counts = sorted(pred_buckets.values(), reverse=True)
cutoff = len(ranked_counts) * args.portion
adaptive_threshold = next((count for rank, count in enumerate(ranked_counts) if rank > cutoff), None)
if adaptive_threshold is None:
	# Previously this surfaced as an opaque TypeError in the %d formatting below.
	raise SystemExit(f"Cannot derive an adaptive threshold: no predicate ranks beyond the top "
					 f"{args.portion*100}% of {len(ranked_counts)} predicates; lower --portion.")

print(f"Adaptive threshold at top {args.portion*100}% predicates: {adaptive_threshold}")

output_fn = output_fn % (adaptive_threshold, args.multiples, args.pred_minimum, args.args_minimum)
# Second pass: rewrite every input line, reducing its relations in two steps:
#   1. drop relations whose predicate count is below --pred_minimum or whose
#      argument-pair count is below --args_minimum;
#   2. among the survivors of one line, drop a relation whose predicate pred1 is
#      a proper substring of another predicate pred2 that is both absolutely
#      populous (count(pred2) > adaptive_threshold) and relatively populous
#      (count(pred1) / count(pred2) < --multiples).
sents_involving_reduce = 0
total_rels_reduced = 0
new_total_rels = 0
new_max_num_rels_per_sent = 0
inter_max_rels_per_sent = 0  # largest list entering the pairwise (quadratic) check in step 2
new_pred_buckets = Counter()  # counts of predicates that survive both steps

num_absolute_populous = 0
num_relative_populous = 0
num_rels_trimmed_by_minimum = 0

lidx = -1  # bound even for an empty input, so the final summary does not raise NameError

# `with` guarantees the output file is flushed and closed even if a line fails.
with open(args.input_fn, 'r', encoding='utf8') as fp, open(output_fn, 'w', encoding='utf8') as output_fp:
	for lidx, line in enumerate(fp):
		if lidx % 100000 == 0:
			print(f"lidx: {lidx}; sents involving reduce: {sents_involving_reduce}; "
				  f"total number of rels reduced: {total_rels_reduced}; total number of rels left: {new_total_rels}; "
				  f"num absolute populous: {num_absolute_populous}; num relative populous: {num_relative_populous}; "
				  f"number of relations trimmed by minimum: {num_rels_trimmed_by_minimum}")
		item = json.loads(line)

		# Step 1: parse each relation and keep it only if its predicate and its
		# argument pair both meet the frequency minimums.
		rels = []
		rel_predicates = []
		for rel in item["rels"]:
			rel_elements = rel["r"][1:-1].split('::')
			assert len(rel_elements) == 8

			# the splitting scheme below only works for Chinese rels! (has to guarantee that the correct comma is in the middle)
			rel_predicate = rel_elements[0][1:-1]
			rel_args = rel_elements[1] + '::::' + rel_elements[2]
			assert len(rel_predicate) % 2 == 1
			splitting_point = len(rel_predicate) // 2
			p1 = rel_predicate[:splitting_point]
			p2 = rel_predicate[splitting_point + 1:]
			assert p1[-2:] == '.1' and p2[-2:] == '.2'
			p1 = p1[:-2]
			p2 = p2[:-2]
			assert p1 == p2

			if pred_buckets[p1] < args.pred_minimum or args_buckets[rel_args] < args.args_minimum:
				num_rels_trimmed_by_minimum += 1
			else:
				rels.append(rel)
				rel_predicates.append(p1)

		if inter_max_rels_per_sent < len(rel_predicates):
			inter_max_rels_per_sent = len(rel_predicates)
			print(f"New max intermediate rels number: {inter_max_rels_per_sent}")

		# Step 2: keep_mask[i] is True when rels[i] is NOT subsumed and survives.
		keep_mask = []
		for pred1 in rel_predicates:
			subsumed = False
			has_absolute_populous = False
			has_relative_populous = False
			for pred2 in rel_predicates:
				if pred1 == pred2 or pred1 not in pred2:
					continue
				absolute_populous = pred_buckets[pred2] > adaptive_threshold
				relative_populous = pred_buckets[pred1] / float(pred_buckets[pred2]) < args.multiples
				if absolute_populous:
					has_absolute_populous = True
				if relative_populous:
					has_relative_populous = True
				if absolute_populous and relative_populous:
					subsumed = True
					break
			if has_absolute_populous:
				num_absolute_populous += 1
			if has_relative_populous:
				num_relative_populous += 1
			keep_mask.append(not subsumed)

		new_rels = [rel for rel, keep in zip(rels, keep_mask) if keep]
		new_pred_buckets.update(pred for pred, keep in zip(rel_predicates, keep_mask) if keep)

		new_total_rels += len(new_rels)
		new_max_num_rels_per_sent = max(new_max_num_rels_per_sent, len(new_rels))

		# Only step-2 removals count as a "reduce" here; step-1 trims are tallied separately.
		diff_len = len(rels) - len(new_rels)
		if diff_len > 0:
			sents_involving_reduce += 1
		total_rels_reduced += diff_len

		item["rels"] = new_rels
		output_fp.write(json.dumps(item, ensure_ascii=False) + '\n')

print(f"lidx: {lidx}; sents involving reduce: {sents_involving_reduce}; total number of rels reduced: {total_rels_reduced}; "
	  f"total number of rels left: {new_total_rels}; new maximum number of rels per sentence: {new_max_num_rels_per_sent}; "
	  f"num absolute populous: {num_absolute_populous}; num relative populous: {num_relative_populous}; "
	  f"number of relations trimmed by minimum: {num_rels_trimmed_by_minimum}; number of remaining unique predicates: {len(new_pred_buckets)}")


