from LAC import LAC
import time
import json
import argparse


# Sentence-final punctuation marks for the English side. create_out_line checks
# the last character of a sentence against this list and, on a match, separates
# it from the preceding text with a space.
en_endpunc_list = ['.', ';', '?', '!', '"']
# Counterpart list for the Chinese side (full-width marks).
# NOTE(review): the last entry appears to be the same ASCII double quote as in
# the English list -- confirm whether a full-width closing quote was intended.
zh_endpunc_list= ['。', '；', '？', '！', '"']


def create_out_line(line_en, line_zh, lac, args):
	"""Build one ``<english> ||| <chinese>`` line for a sentence pair.

	The Chinese sentence is segmented with ``lac.run`` and its tokens are
	joined with single spaces. The English sentence is split on single spaces
	only to count its tokens; its text is otherwise left untouched. On both
	sides, if the text is longer than one character, ends with a mark from the
	matching end-punctuation list, and the character before that mark is not
	a space, a space is inserted before the final mark.

	Args:
		line_en: English sentence.
		line_zh: Chinese sentence (unsegmented).
		lac: Segmenter whose ``run(text)`` is expected to return the sequence
			of tokens for ``text`` (callers in this file pass
			``LAC(mode='seg')``).
		args: Namespace providing ``cutoff`` (truthy enables the length check)
			and ``cutoff_thres`` (maximum number of tokens per side).

	Returns:
		The formatted line, without a trailing newline.

	Raises:
		AssertionError: If segmentation yields no tokens.
		NotImplementedError: If ``args.cutoff`` is set and either side has
			more than ``args.cutoff_thres`` tokens.
	"""
	en_tokens = line_en.split(' ')
	zh_tokens = lac.run(line_zh)

	# Raised explicitly (rather than via `assert`) so the check is not stripped
	# when Python runs with -O; the exception type is unchanged.
	# NOTE: str.split(' ') always returns at least one element, so in practice
	# only an empty Chinese segmentation can trigger this.
	if not (len(en_tokens) > 0 and len(zh_tokens) > 0):
		raise AssertionError(
			f"empty token list (en tokens: {len(en_tokens)}, zh tokens: {len(zh_tokens)})"
		)

	too_long = len(en_tokens) > args.cutoff_thres or len(zh_tokens) > args.cutoff_thres
	if args.cutoff and too_long:
		raise NotImplementedError("cutting long sentences into shorter ones is not implemented")

	# Re-joining the English tokens with ' ' would reproduce line_en exactly,
	# so the English text is used as-is; only the Chinese side is rebuilt.
	line_zh = ' '.join(zh_tokens)

	# Detach a sentence-final punctuation mark that is glued to the last token.
	if len(line_en) > 1 and line_en[-1] in en_endpunc_list and line_en[-2] != ' ':
		line_en = f"{line_en[:-1]} {line_en[-1]}"
	if len(line_zh) > 1 and line_zh[-1] in zh_endpunc_list and line_zh[-2] != ' ':
		line_zh = f"{line_zh[:-1]} {line_zh[-1]}"

	return f"{line_en} ||| {line_zh}"


def webhose_reformat(input_fn, output_fn, index_fn, args):
	"""Convert a webhose JSONL file into a parallel-text file plus an index.

	Each input line is a JSON object whose ``splitted_text`` (Chinese) and
	``english_splitted_text`` (English) lists are paired up with ``zip``. For
	every pair, one formatted line is written to ``output_fn`` and one JSON
	line ``{"doc_id": <input line number>, "sent_id": <pair position>}`` is
	written to ``index_fn``, so the two output files align line by line.

	Args:
		input_fn: Path of the JSONL input file.
		output_fn: Path of the parallel-text file to write (overwritten).
		index_fn: Path of the alignment index file to write (overwritten).
		args: Namespace forwarded to ``create_out_line`` (``cutoff``,
			``cutoff_thres``).

	Raises:
		AssertionError: If an entry's ``translation_mismatch`` is not ``False``.
		Whatever ``create_out_line`` raises for a sentence pair.
	"""
	# `with` guarantees all three files are closed even if a line fails to
	# parse or a sentence pair is rejected part-way through the run.
	with open(input_fn, 'r', encoding='utf8') as input_fp, \
			open(output_fn, 'w', encoding='utf8') as output_fp, \
			open(index_fn, 'w', encoding='utf8') as index_fp:
		lac = LAC(mode='seg')
		st = time.time()
		for lidx, line in enumerate(input_fp):
			# Progress report every 1000 input lines.
			if lidx % 1000 == 0:
				elapsed = int(time.time() - st)
				print(f"{lidx}; {elapsed // 60} minutes; {elapsed % 60} seconds")

			item = json.loads(line)
			zh_sents = item['splitted_text']
			en_sents = item['english_splitted_text']
			# Raised explicitly (rather than via `assert`) so the check is not
			# stripped when Python runs with -O.
			if item['translation_mismatch'] is not False:
				raise AssertionError(f"translation_mismatch is not False for input line {lidx}")

			# NOTE: zip stops at the shorter list; extra sentences on either
			# side are dropped without notice.
			for sid, (line_zh, line_en) in enumerate(zip(zh_sents, en_sents)):
				out_line = create_out_line(line_en=line_en, line_zh=line_zh, lac=lac, args=args)
				output_fp.write(out_line + '\n')

				idx_line = json.dumps({'doc_id': lidx, 'sent_id': sid}, ensure_ascii=False)
				index_fp.write(idx_line + '\n')


# Reformatting for newsspike writes no index file: each entry contains exactly one sentence,
# so the .parallel file is expected to align line by line with the json file.
def newsspike_reformat(input_fn, output_fn, args):
	"""Convert a newsspike JSONL file into a parallel-text file.

	Each input line is a JSON object with an English sentence under ``s`` and
	its Chinese translation under ``trans_s``. Entries whose translation is
	empty or equal to the placeholder string are skipped; every other entry
	produces exactly one line in ``output_fn``.

	NOTE: skipped entries write nothing, so the output aligns line by line
	only with the input entries that were not skipped.

	Args:
		input_fn: Path of the JSONL input file.
		output_fn: Path of the parallel-text file to write (overwritten).
		args: Namespace forwarded to ``create_out_line`` (``cutoff``,
			``cutoff_thres``).

	Raises:
		AssertionError: If a kept entry's ``translation_mismatch`` is not
			``False``.
		Whatever ``create_out_line`` raises for a sentence pair.
	"""
	empty_skipped_count = 0

	# `with` guarantees both files are closed even if a line fails to parse or
	# a sentence pair is rejected part-way through the run.
	with open(input_fn, 'r', encoding='utf8') as input_fp, \
			open(output_fn, 'w', encoding='utf8') as output_fp:
		lac = LAC(mode='seg')
		st = time.time()

		for lidx, line in enumerate(input_fp):
			# Progress report every 1000 input lines.
			if lidx % 1000 == 0:
				elapsed = int(time.time() - st)
				print(f"{lidx}; {elapsed // 60} minutes; {elapsed % 60} seconds")

			item = json.loads(line)
			line_zh = item['trans_s']
			line_en = item['s']

			if len(line_zh) == 0:
				# A missing translation for a non-blank English sentence is
				# unexpected, so echo the raw entry for inspection.
				if len(line_en.strip()) > 0:
					print(line)
				empty_skipped_count += 1
				continue
			if line_zh == '【占位符】':
				# Placeholder translation: treated like an empty one.
				empty_skipped_count += 1
				continue

			# Raised explicitly (rather than via `assert`) so the check is not
			# stripped when Python runs with -O.
			if item['translation_mismatch'] is not False:
				raise AssertionError(f"translation_mismatch is not False for input line {lidx}")

			out_line = create_out_line(line_en=line_en, line_zh=line_zh, lac=lac, args=args)
			output_fp.write(out_line + '\n')

	# The counter was previously maintained but never surfaced.
	print(f"Skipped {empty_skipped_count} entries with empty or placeholder translations.")



if __name__ == '__main__':
	# Command-line entry point: choose the dataset with --mode and run the
	# matching reformatter on its fixed input/output paths.
	cli = argparse.ArgumentParser()
	cli.add_argument('--mode', type=str, default='webhose', help='[webhose/newsspike/(clue/news_crawl)]')
	cli.add_argument('--cutoff', type=int, default=0, help='whether or not to cut long sentences into short ones.')
	cli.add_argument('--cutoff_thres', type=int, default=200, help='the threshold number of tokens for cutoff.')
	args = cli.parse_args()

	# The flag is given as an integer; any positive value turns the cutoff on.
	args.cutoff = args.cutoff > 0

	# Modes that are recognised but have no reformatter yet.
	if args.mode in ('clue', 'news_crawl'):
		raise NotImplementedError

	if args.mode == 'webhose':
		webhose_reformat(
			input_fn='./webhose_data_entries_with_translations_blank_filled.jsonl',
			output_fn='./webhose.parallel',
			index_fn='./webhose_alignment_idxs.json',
			args=args,
		)
	elif args.mode == 'newsspike':
		newsspike_reformat(
			input_fn='./newsspike_gen8_with_translations.jsonl',
			output_fn='./newsspike.parallel',
			args=args,
		)
	else:
		# Unknown mode.
		raise AssertionError

	print("Done.")
