import argparse
import csv
import json
import re

import spacy

from translate_utils import merge_txt_to_dict

# Command-line options.
#   --root        name of the data directory; also the prefix of every file in it
#   --num_slices  number of numbered slice files to read (indices 0..num_slices-1)
#   --service     translation service tag; selects which translated/back-translated
#                 files are read and is embedded in the output file names
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, default='implications')
parser.add_argument('--num_slices', type=int, default=16)
# `choices` lets argparse reject an unsupported service with a usage error.
# Unlike an `assert`, this check is not stripped when running under `python -O`.
parser.add_argument('--service', type=str, default='google', choices=['google', 'baidu'])
args = parser.parse_args()

# Every input and output lives in ./<root>/ and its file name starts with <root>.
# Templates containing %d are filled with a slice index; the %s/%s template is
# filled with (service, match-kind tag).
_prefix = f"./{args.root}/{args.root}"

orig_fn = f"{_prefix}_in_lines_%d.txt"

# Defensive guard: only the two known services have a translated-file naming scheme.
if args.service not in ('google', 'baidu'):
	raise AssertionError
# Google translations carry no service tag in the file name; Baidu ones do.
_trans_tag = '' if args.service == 'google' else '_baidu'
trans_fn = f"{_prefix}_in_lines_translated{_trans_tag}_%d.txt"

back_fn = f"{_prefix}_in_lines_backtranslated_{args.service}_%d.txt"
mapping_fn = f"{_prefix}_mapping.json"
input_fn = f"{_prefix}.tsv"
filtered_pos_in_fn = f"{_prefix}_pos_input_%s_%s.json"


# A standalone article ("the", "The", "A", "a") followed by a single space.
# The leading \b prevents matches inside longer words: a plain substring
# replace of 'a ' or 'the ' would turn "Obama is" into "Obamis" and
# "bathe in" into "bain".
_ARTICLE_RE = re.compile(r'\b(?:the|The|A|a) ')


def prune_sent(sent):
	"""Return ``sent`` with standalone articles removed.

	Removes each occurrence of "the", "The", "A" or "a" that starts at a word
	boundary and is followed by a space, together with that space. An article
	at the very end of the string (no trailing space) is left untouched.

	Args:
		sent: The sentence to prune.

	Returns:
		The pruned sentence.
	"""
	return _ARTICLE_RE.sub('', sent)


def acquire_accepted_pairs(input_fn, mapped, en_mapped, pos_in_fn, label, acceptable_sents):
	"""Write the premise/hypothesis pairs whose sentences were both accepted.

	Reads the tab-separated file ``input_fn``, where every row must have exactly
	three columns: premise, hypothesis and a value. A row is kept when both
	``en_mapped[premise]`` and ``en_mapped[hypothesis]`` occur in
	``acceptable_sents``. Each kept row is written to ``pos_in_fn`` as one JSON
	object per line::

		{"splitted_text": [mapped[premise], mapped[hypothesis]], "value": value}

	A summary line with the accepted/total counts is printed at the end.

	Args:
		input_fn: Path of the tab-separated input file.
		mapped: Mapping from a premise/hypothesis to the text written to the output.
		en_mapped: Mapping from a premise/hypothesis to the sentence that is
			looked up in ``acceptable_sents``.
		pos_in_fn: Path of the JSON-lines output file (overwritten).
		label: Name of the acceptance criterion, used only in the summary line.
		acceptable_sents: Iterable of accepted sentences. Assumed to hold
			hashable items (strings) -- TODO confirm against ``merge_txt_to_dict``.

	Raises:
		AssertionError: If a row does not have exactly three columns.
		KeyError: If a premise or hypothesis is missing from ``en_mapped``, or
			an accepted one is missing from ``mapped``.
	"""
	# Build the lookup once: testing membership in a list is O(n) per row, which
	# makes the whole pass quadratic in the number of accepted sentences.
	accepted_lookup = frozenset(acceptable_sents)
	accepted_count = 0
	total_count = 0
	# `with` closes both files even if a row fails validation or a key is missing.
	# newline='' is what the csv module asks for when reading from a file object.
	with open(input_fn, 'r', encoding='utf8', newline='') as input_fp, \
			open(pos_in_fn, 'w', encoding='utf8') as out_fp:
		for instance in csv.reader(input_fp, delimiter="\t"):
			assert len(instance) == 3, f"expected 3 columns, got {len(instance)}: {instance!r}"
			en_premise, en_hypo, value = instance
			en_prem_in_sent = en_mapped[en_premise]
			en_hypo_in_sent = en_mapped[en_hypo]
			if en_prem_in_sent in accepted_lookup and en_hypo_in_sent in accepted_lookup:
				doc = {'splitted_text': [mapped[en_premise], mapped[en_hypo]], 'value': value}
				# ensure_ascii=False keeps non-ASCII text readable in the output.
				out_fp.write(json.dumps(doc, ensure_ascii=False) + '\n')
				accepted_count += 1
			total_count += 1
	print(f"Number of accepted proposition pairs for {label}: {accepted_count}/{total_count}!")


# Accepted original sentences, one list per matching criterion
# (exact string / lemma, each with and without article pruning).
exact_match_ls, lemma_match_ls = [], []
pruned_exact_match_ls, pruned_lemma_match_ls = [], []

# Running totals for the same four criteria.
exact_match_count = lemma_match_count = 0
pruned_exact_match_count = pruned_lemma_match_count = 0

# Lines gathered from all slices: originals, translations, back-translations.
orig_lines, trans_lines, back_lines = [], [], []

# English spaCy pipeline, used for lemmatisation.
nlp = spacy.load('en_core_web_sm')

# Concatenate the non-blank, stripped lines of every slice, in slice order.
# NOTE(review): blank lines are dropped from each file independently, so the
# length check below only catches misalignment that changes the line counts.
for i in range(args.num_slices):
	# `with` guarantees all three handles are closed even if a read fails.
	with open(orig_fn % i, 'r', encoding='utf8') as orig_fp, \
			open(trans_fn % i, 'r', encoding='utf8') as trans_fp, \
			open(back_fn % i, 'r', encoding='utf8') as back_fp:
		orig_lines += [line.strip() for line in orig_fp if line.strip()]
		trans_lines += [line.strip() for line in trans_fp if line.strip()]
		back_lines += [line.strip() for line in back_fp if line.strip()]

# The three line lists are consumed in lockstep, so they must be the same length.
assert len(orig_lines) == len(back_lines) and len(orig_lines) == len(trans_lines), (
	f"line count mismatch: {len(orig_lines)} original, "
	f"{len(trans_lines)} translated, {len(back_lines)} back-translated"
)

def _lemmas_match(left, right):
	"""Return True when both sentences produce the same sequence of lemmas.

	Whole sequences are compared. Zipping the two docs token by token would
	stop at the shorter one and wrongly accept a sentence that is only a
	prefix of the other.
	"""
	left_lemmas = [tok.lemma_ for tok in nlp(left)]
	right_lemmas = [tok.lemma_ for tok in nlp(right)]
	return left_lemmas == right_lemmas


def _collect_matches(normalise=None):
	"""Compare each original line with its back-translation.

	Args:
		normalise: Optional function applied to both sentences before they are
			compared; ``None`` compares the sentences as they are.

	Returns:
		A pair ``(exact, lemma)`` of lists of ORIGINAL (un-normalised) lines:
		``exact`` holds those whose compared strings are identical, ``lemma``
		holds those plus the ones whose lemma sequences are identical.
	"""
	exact, lemma = [], []
	for lid, (ol, bl) in enumerate(zip(orig_lines, back_lines)):
		# Progress indicator every 1000 lines.
		if lid % 1000 == 0 and lid > 0:
			print(lid)
		if normalise is None:
			cmp_ol, cmp_bl = ol, bl
		else:
			cmp_ol, cmp_bl = normalise(ol), normalise(bl)
		if cmp_ol == cmp_bl:
			# Identical strings trivially lemma-match; skip the spaCy call.
			exact.append(ol)
			lemma.append(ol)
		elif _lemmas_match(cmp_ol, cmp_bl):
			lemma.append(ol)
	return exact, lemma


print("Begins!")

# Pass 1: compare the sentences as they are.
_exact, _lemma = _collect_matches()
exact_match_ls.extend(_exact)
lemma_match_ls.extend(_lemma)
exact_match_count += len(_exact)
lemma_match_count += len(_lemma)

# Pass 2: compare after removing articles from both sentences.
_exact, _lemma = _collect_matches(prune_sent)
pruned_exact_match_ls.extend(_exact)
pruned_lemma_match_ls.extend(_lemma)
pruned_exact_match_count += len(_exact)
pruned_lemma_match_count += len(_lemma)


# Report how many propositions survived each matching criterion.
_summary = [
	f"Stats for #{args.root}#: ",
	f"Out of the {len(orig_lines)} propositions: ",
	f"{exact_match_count} propositions have exact string match after back translation;",
	f"{lemma_match_count} propositions have lemma match after back translation;",
	f"{pruned_exact_match_count} propositions have pruned exact string match after back translation;",
	f"{pruned_lemma_match_count} propositions have pruned lemma match after back translation!",
]
for _summary_line in _summary:
	print(_summary_line)

# Lookup tables built from the translated and the original slice files.
# NOTE(review): the exact key/value layout is defined by merge_txt_to_dict
# in translate_utils -- confirm there.
mapped = merge_txt_to_dict(trans_fn, mapping_fn, args.num_slices)
en_mapped = merge_txt_to_dict(orig_fn, mapping_fn, args.num_slices)

# One filtered output file per criterion: (file-name tag, label, accepted sentences).
_criteria = (
	('EM', 'exact match', exact_match_ls),
	('LM', 'lemma match', lemma_match_ls),
	('PEM', 'pruned exact match', pruned_exact_match_ls),
	('PLM', 'pruned lemma match', pruned_lemma_match_ls),
)
for _tag, _label, _accepted in _criteria:
	_out_fn = filtered_pos_in_fn % (args.service, _tag)
	acquire_accepted_pairs(input_fn, mapped, en_mapped, _out_fn, _label, _accepted)

print("Finished!")
