# -*- coding: utf-8 -*-

# This code translates text between English and Simplified Chinese (in either direction) with the Baidu translation API.
# This code requires Python 3.6 or newer (it uses f-strings).
# You need to install `requests` to run this code: pip install requests
# Please refer to `https://api.fanyi.baidu.com/doc/21` for the complete API documentation.

import requests
import random
import json
from hashlib import md5
import time
import argparse
import sys
import os


endpoint = 'http://api.fanyi.baidu.com'
path = '/api/trans/vip/translate'
url = endpoint + path


# Generate salt and sign
def make_md5(s, encoding='utf-8'):
	return md5(s.encode(encoding)).hexdigest()


def query_block(appid, appkey, block_txt, back=False):
	salt = random.randint(32768, 65536)
	sign = make_md5(appid + block_txt + str(salt) + appkey)
	headers = {'Content-Type': 'application/x-www-form-urlencoded'}
	if back:
		from_lang = 'zh'
		to_lang = 'en'
	else:
		from_lang = 'en'
		to_lang = 'zh'
	payload = {'appid': appid, 'q': block_txt, 'from': from_lang, 'to': to_lang, 'salt': salt, 'sign': sign}
	cur_slept_time = 0
	res = {'error_code': 88888}

	timeout = 120.0
	for i in range(3):
		try:
			r = requests.post(url, params=payload, headers=headers, timeout=timeout)
			res = r.json()
			if 'error_code' in res and res['error_code'] == 52001:
				timeout *= 2
				continue
			else:
				break
		except Exception as e:
			print("posting exception while handling: ")
			print(block_txt)
			print("Exception: ")
			print(e)
			print("Retrying......")
			time.sleep(4)
			cur_slept_time += 4

	return res


def translate_levy_holts():
	parser = argparse.ArgumentParser()
	parser.add_argument('--root', type=str, default='implications')
	parser.add_argument('--slice_id', type=int, default=-1, help='the slice id to attend to, when set to -1 means translate all.')
	parser.add_argument('--num_slices', type=int, default=16)
	parser.add_argument('--back', type=int, default=0, help='whether or not to back-translate: 0 indicates false, 1 indicates true.')
	parser.add_argument('--block_size', type=int, default=100)
	args = parser.parse_args()

	# Set your own appid/appkey.
	appid = '20200507000442561'
	appkey = 'LL7vumyns4bGkZgNEyAI'
	path = './'+args.root+'/'+args.root+'_in_lines_'
	for sid in range(args.num_slices):
		first_line = True
		if 0 <= args.slice_id != sid:
			continue
		print(f'slice id: {sid}')
		if args.back == 0:
			in_path = path+f'{sid}.txt'
			out_path = path+f'translated_baidu_{sid}.txt'
		else:
			in_path = path+f'translated_baidu_{sid}.txt'
			out_path = path+f'backtranslated_baidu_{sid}.txt'
		back = True if args.back > 0 else False
		with open(in_path, 'r', encoding='utf8') as in_fp:
			in_lines = in_fp.readlines()
		out_fp = open(out_path, 'w', encoding='utf8')
		start_i = 0
		while start_i < len(in_lines):
			print(f"start id: {start_i}")
			end_i = min(start_i+args.block_size, len(in_lines))
			block = in_lines[start_i:end_i]
			block_txt = '\n'.join([item.strip() for item in block])
			res = query_block(appid, appkey, block_txt, back)
			if 'error_code' in res:
				print(res)
			for sent_res in res['trans_result']:
				out_line = sent_res['dst']
				if first_line:
					out_fp.write(out_line)
					first_line = False
				else:
					out_fp.write('\n'+out_line)
			start_i += args.block_size
			time.sleep(1)
		out_fp.close()
		print('')


def translate_webhose(input_fn, output_fn, last_out_fn, max_len, blankfilling=False):
	#parser = argparse.ArgumentParser()
	#parser.add_argument('--input', type=str, default='/Users/teddy/Files/Potential Corpus/WebHoses_Chinese_News_Articles/webhose_data_entries_with_corenlp_ner_and_parse.json')
	#parser.add_argument('--output', type=str, default='./webhose_data_entries_with_translations.jsonl')
	#parser.add_argument('--zh2en', type=int, default=1, help='whether or not to back-translate: 0 indicates false, 1 indicates true.')
	#parser.add_argument('--max_len', type=int, default=2000)
	#args = parser.parse_args()

	# Set your own appid/appkey.
	appid = '20200507000442561'
	appkey = 'LL7vumyns4bGkZgNEyAI'
	zh2en = True

	request_cnt = 0

	if zh2en:
		print("Translating from Chinese to English!")
	else:
		print("Translating from English to Chinese!")

	input_fp = open(input_fn, 'r', encoding='utf8')

	if os.path.exists(last_out_fn):
		ref_fp = open(last_out_fn, 'r', encoding='utf8')
	else:
		ref_fp = None

	open(output_fn, 'w', encoding='utf8').close()

	st = time.time()
	last_t = None
	total_sleep_time = 0
	total_resubmit_count = 0
	error_count = 0
	total_char_count = 0
	for iid, in_line in enumerate(input_fp):
		# if iid < 2543:
		# 	continue
		if iid % 200 == 0:
			ct = time.time()
			dur = ct - st
			dur_h = int(dur) // 3600
			dur_m = (int(dur) % 3600) // 60
			dur_s = int(dur) % 60
			print(f"{iid}; {request_cnt} requests; slept %.2f seconds; {total_resubmit_count} resubmissions; {error_count} errors; {total_char_count} chars; time lapsed: {dur_h} hours {dur_m} minutes {dur_s} seconds." % total_sleep_time)
		in_doc = json.loads(in_line)

		if ref_fp is not None:
			ref_line = ref_fp.readline()
			if len(ref_line) == 0:
				ref_fp.close()
				ref_fp = None
			else:
				ref_item = json.loads(ref_line)
				assert len(ref_item['splitted_text']) == len(in_doc['splitted_text'])
				assert len(ref_item['english_splitted_text']) == len(in_doc['splitted_text']) or ref_item['translation_mismatch']
				if (not ref_item['translation_mismatch']) or (not blankfilling):
					with open(output_fn, 'a', encoding='utf8') as ofp:
						ofp.write(ref_line.strip('\n') + '\n')
					continue

		in_doc['english_splitted_text'] = []
		in_doc['translation_mismatch'] = False
		sent_curid = 0
		while sent_curid < len(in_doc['splitted_text']):
			cur_accumulated_len = 0
			cur_sents = []
			# so that the total length of a query never exceeds max_len
			while sent_curid < len(in_doc['splitted_text']) and cur_accumulated_len+len(in_doc['splitted_text'][sent_curid]) < max_len:
				cur_sents.append(''.join(in_doc['splitted_text'][sent_curid]).replace('\n', ''))
				cur_accumulated_len += len(in_doc['splitted_text'][sent_curid])
				sent_curid += 1
			total_char_count += cur_accumulated_len
			assert len(cur_sents) > 0
			cur_block = '。 \n '.join(cur_sents)
			assert len(cur_block.split('\n')) == len(cur_sents)
			if last_t is not None:
				while time.time() - last_t < 1.3:
					time.sleep(0.2)
					total_sleep_time += 0.2
			last_t = time.time()
			res = query_block(appid, appkey, cur_block, zh2en)
			request_cnt += 1
			if 'error_code' in res:
				if int(res['error_code']) == 54003:
					time.sleep(2)
					total_sleep_time += 2
					total_resubmit_count += 1
					res = query_block(appid, appkey, cur_block, zh2en)
				if 'error_code' in res:
					print("Error code found!")
					print(res)
					in_doc['translation_mismatch'] = True
					error_count += 1
					continue
			cur_res_sents = []
			for sent_res in res['trans_result']:
				cur_res_sents += sent_res['dst'].strip('\n').split('\n')
			if len(cur_sents) != len(cur_res_sents):
				print(f"Length mismatch: {len(cur_sents)}; {len(cur_res_sents)}")
				print(cur_sents)
				print(cur_res_sents)
				in_doc['translation_mismatch'] = True
			in_doc['english_splitted_text'] += cur_res_sents
		if len(in_doc['english_splitted_text']) != len(in_doc['splitted_text']) and in_doc['translation_mismatch'] is False:
			print("ERROR!", file=sys.stderr)
			print(in_doc)
		with open(output_fn, 'a', encoding='utf8') as ofp:
			out_line = json.dumps(in_doc, ensure_ascii=False)
			ofp.write(out_line+'\n')

	input_fp.close()
	print("Finished.")


def translate_newsspike(input_fn, output_fn, ref_fn, max_len, blankfilling=False, placeholding=False):
	#parser = argparse.ArgumentParser()
	#parser.add_argument('--input', type=str,
	#					default='/Users/teddy/eclipse-workspace/entGraph_mod/downloaded/new_allRels.txt')
	#parser.add_argument('--output', type=str, default='./newsspike_all_rels_with_translations.jsonl')
	#parser.add_argument('--max_len', type=int, default=5400)
	#args = parser.parse_args()

	# Set your own appid/appkey.
	appid = '20210602000850735'
	appkey = 'Cbvpqfg49TfGBmruz9dS'
	zh2en = False

	total_num_lines = 0
	# total number of sentences in NewsSpike: 11955986!
	with open(input_fn, 'r', encoding='utf8') as fp:
		for vain_line in fp:
			total_num_lines += 1

	print(f"Translating a total of {total_num_lines} sentences!")

	input_fp = open(input_fn, 'r', encoding='utf8')

	open(output_fn, 'w', encoding='utf8').close()

	already_translated_ids = set()
	try:
		with open(ref_fn, 'r', encoding='utf8') as rfp:
			for lidx, line in enumerate(rfp):
				if lidx % 10000 == 0:
					print(lidx)
				already_translated_item = json.loads(line)
				# if we're filling in the blanks, those that have mismatches should not be considered 'already-translated',
				# thus should not be skipped.
				if already_translated_item['translation_mismatch'] is True and blankfilling:
					continue
				cur_id = f"{already_translated_item['articleId']}-{already_translated_item['lineId']}"
				assert cur_id not in already_translated_ids
				already_translated_ids.add(cur_id)
				with open(output_fn, 'a', encoding='utf8') as ofp:
					ofp.write(line.strip('\n') + '\n')
		already_translated_len = len(already_translated_ids)
		print(f"{already_translated_len} entries have already been translated and are here inherited!")
	except FileNotFoundError:
		print(f"No existing output file found!")

	st = time.time()
	input_lines_buffer = []
	empty_sent_ids_buffer = []
	accumulated_len = 0
	request_cnt = 0
	total_sleep_time = 0
	total_resubmit_count = 0
	error_count = 0
	total_char_count = 0
	inherited_count = 0
	handled_count = 0
	last_t = None
	for lidx in range(total_num_lines+1):
		if lidx % 10000 == 0:
			ct = time.time()
			dur = ct - st
			dur_h = int(dur) // 3600
			dur_m = (int(dur) % 3600) // 60
			dur_s = int(dur) % 60
			print(f"{lidx}/{total_num_lines} ({inherited_count} inherited); {handled_count} sents handled; {request_cnt} requests; slept %.2f seconds; {total_resubmit_count} resubmissions; {error_count} errors; {total_char_count} chars; time lapsed: {dur_h} hours {dur_m} minutes {dur_s} seconds." % total_sleep_time)

		if lidx < total_num_lines:
			in_line = input_fp.readline()
			item = json.loads(in_line)
			cur_id = f"{item['articleId']}-{item['lineId']}"
			if cur_id in already_translated_ids:
				inherited_count += 1
				continue  # the item would have already been in the output file, and thus need not to be translated again!
			elif placeholding:
				item['trans_s'] = '【占位符】'
				item['translation_mismatch'] = False
				with open(output_fn, 'a', encoding='utf8') as ofp:
					out_line = json.dumps(item, ensure_ascii=False)
					ofp.write(out_line + '\n')
				accumulated_len = 0
				continue
			else:
				handled_count += 1
			item['s'] = item['s'].replace('\n', '')
			item['trans_s'] = ''
			item['translation_mismatch'] = False

		if (lidx == total_num_lines and not placeholding) or (accumulated_len > 0 and accumulated_len+len(item['s']) >= max_len):
			if len(input_lines_buffer) != len(empty_sent_ids_buffer):
				block_txt = ' \n '.join([p['s'] for pid, p in enumerate(input_lines_buffer) if pid not in empty_sent_ids_buffer])
				assert len(block_txt.split('\n')) == len(input_lines_buffer)-len(empty_sent_ids_buffer)
				if last_t is not None:
					while time.time() - last_t < 1.3:
						time.sleep(0.2)
						total_sleep_time += 0.2
				last_t = time.time()
				res = query_block(appid, appkey, block_txt, zh2en)
				if 'error_code' in res:
					if int(res['error_code']) == 54003:
						time.sleep(2)
						total_sleep_time += 2
						total_resubmit_count += 1
						res = query_block(appid, appkey, block_txt, zh2en)
					if 'error_code' in res:
						print("Error code found!")
						print(res)
						error_count += 1
						for p in input_lines_buffer:
							p['translation_mismatch'] = True
				else:
					cur_res_sents = []
					for sent_res in res['trans_result']:
						cur_res_sents += sent_res['dst'].strip('\n').split('\n')
					if len(input_lines_buffer) != len(cur_res_sents)+len(empty_sent_ids_buffer):
						print(f"Length mismatch: {len(input_lines_buffer)}; {len(cur_res_sents)}")
						print(block_txt)
						print(cur_res_sents)
						for p in input_lines_buffer:
							p['translation_mismatch'] = True
					else:
						new_cur_res_sents = []
						while len(new_cur_res_sents) in empty_sent_ids_buffer:
							new_cur_res_sents.append('')
						for r in cur_res_sents:
							new_cur_res_sents.append(r)
							while len(new_cur_res_sents) in empty_sent_ids_buffer:
								new_cur_res_sents.append('')
						for p, t in zip(input_lines_buffer, new_cur_res_sents):
							p['trans_s'] = t.strip('\n')
				with open(output_fn, 'a', encoding='utf8') as ofp:
					for p in input_lines_buffer:
						out_line = json.dumps(p, ensure_ascii=False)
						ofp.write(out_line+'\n')

			input_lines_buffer = []
			empty_sent_ids_buffer = []
			total_char_count += accumulated_len
			accumulated_len = 0
			request_cnt += 1

		if lidx < total_num_lines:
			accumulated_len += len(item['s'])
			if item['s'].strip() == '':
				empty_sent_ids_buffer.append(len(input_lines_buffer))
			input_lines_buffer.append(item)
	print("Finished!")
	input_fp.close()


def translate_webhose_blankfilling(raw_fn, transed_fn, output_fn, max_len):
	#parser = argparse.ArgumentParser()
	#parser.add_argument('--input', type=str, default='/Users/teddy/Files/Potential Corpus/WebHoses_Chinese_News_Articles/webhose_data_entries_with_corenlp_ner_and_parse.json')
	#parser.add_argument('--output', type=str, default='./webhose_data_entries_with_translations.jsonl')
	#parser.add_argument('--zh2en', type=int, default=1, help='whether or not to back-translate: 0 indicates false, 1 indicates true.')
	#parser.add_argument('--max_len', type=int, default=2000)
	#args = parser.parse_args()

	# Set your own appid/appkey.
	appid = '20200507000442561'
	appkey = 'LL7vumyns4bGkZgNEyAI'
	zh2en = True

	request_cnt = 0

	if zh2en:
		print("Translating from Chinese to English!")
	else:
		print("Translating from English to Chinese!")

	raw_fp = open(raw_fn, 'r', encoding='utf8')
	transed_fp = open(transed_fn, 'r', encoding='utf8')
	open(output_fn, 'w', encoding='utf8').close()

	st = time.time()
	last_t = None
	total_sleep_time = 0
	total_resubmit_count = 0
	error_count = 0
	total_char_count = 0
	for iid, in_line in enumerate(raw_fp):
		# if iid < 2543:
		# 	continue
		if iid >= 2543:
			if iid % 1000 == 0:
				print(iid)
			transed_line = transed_fp.readline()
			transed_item = json.loads(transed_line)
			if transed_item['translation_mismatch'] is False:
				in_doc = json.loads(in_line)
				assert len(transed_item['english_splitted_text']) == len(in_doc['splitted_text'])
				with open(output_fn, 'a', encoding='utf8') as ofp:
					out_line = json.dumps(transed_item, ensure_ascii=False)
					ofp.write(out_line + '\n')
					continue
		if iid % 100 == 0:
			ct = time.time()
			dur = ct - st
			dur_h = int(dur) // 3600
			dur_m = (int(dur) % 3600) // 60
			dur_s = int(dur) % 60
			print(f"{iid}; {request_cnt} requests; slept %.2f seconds; {total_resubmit_count} resubmissions; {error_count} errors; {total_char_count} chars; time lapsed: {dur_h} hours {dur_m} minutes {dur_s} seconds." % total_sleep_time)
		in_doc = json.loads(in_line)

		if 'english_splitted_text' in in_doc and len(in_doc['english_splitted_text']) == len(in_doc['splitted_text']):
			with open(output_fn, 'a', encoding='utf8') as ofp:
				out_line = json.dumps(in_doc, ensure_ascii=False)
				ofp.write(out_line + '\n')
				continue
		in_doc['english_splitted_text'] = []
		in_doc['translation_mismatch'] = False
		sent_curid = 0
		while sent_curid < len(in_doc['splitted_text']):
			cur_accumulated_len = 0
			cur_sents = []
			# so that the total length of a query never exceeds max_len
			while sent_curid < len(in_doc['splitted_text']) and cur_accumulated_len+len(in_doc['splitted_text'][sent_curid]) < max_len:
				cur_sents.append(''.join(in_doc['splitted_text'][sent_curid]).replace('\n', ''))
				cur_accumulated_len += len(in_doc['splitted_text'][sent_curid])
				sent_curid += 1
			total_char_count += cur_accumulated_len
			assert len(cur_sents) > 0
			cur_block = '。 \n '.join(cur_sents)
			assert len(cur_block.split('\n')) == len(cur_sents)
			if last_t is not None:
				while time.time() - last_t < 1.3:
					time.sleep(0.2)
					total_sleep_time += 0.2
			last_t = time.time()
			res = query_block(appid, appkey, cur_block, zh2en)
			request_cnt += 1
			if 'error_code' in res:
				if int(res['error_code']) == 54003:
					time.sleep(2)
					total_sleep_time += 2
					total_resubmit_count += 1
					res = query_block(appid, appkey, cur_block, zh2en)
				if 'error_code' in res:
					print("Error code found!")
					print(res)
					in_doc['translation_mismatch'] = True
					error_count += 1
					continue
			cur_res_sents = []
			for sent_res in res['trans_result']:
				cur_res_sents += sent_res['dst'].strip('\n').split('\n')
			if len(cur_sents) != len(cur_res_sents):
				print(f"Length mismatch: {len(cur_sents)}; {len(cur_res_sents)}")
				print(cur_sents)
				print(cur_res_sents)
				in_doc['translation_mismatch'] = True
			in_doc['english_splitted_text'] += cur_res_sents
		if len(in_doc['english_splitted_text']) != len(in_doc['splitted_text']) and in_doc['translation_mismatch'] is False:
			print("ERROR!", file=sys.stderr)
			print(in_doc)
		with open(output_fn, 'a', encoding='utf8') as ofp:
			out_line = json.dumps(in_doc, ensure_ascii=False)
			ofp.write(out_line+'\n')

	raw_fp.close()
	transed_fp.close()
	print("Finished.")



if __name__ == '__main__':
	parser = argparse.ArgumentParser()
	parser.add_argument('--mode', type=str, default='both')
	parser.add_argument('--host', type=str, default='server')
	args = parser.parse_args()
	if args.mode == 'webhose':
		if args.host == 'local':
			input_fn = '/Users/teddy/Files/Potential Corpus/WebHoses_Chinese_News_Articles/webhose_data_entries_no_corenlp.jsonl'
		elif args.host == 'server':
			input_fn = 'webhose_data_entries_no_corenlp.jsonl'
		else:
			raise AssertionError
		output_fn = './webhose_data_entries_with_translations.jsonl'
		ref_fn = './webhose_data_entries_with_translations_last.jsonl'
		max_len = 1750
		translate_webhose(input_fn, output_fn, ref_fn, max_len, blankfilling=False)

	elif args.mode == 'newsspike':
		if args.host == 'local':
			input_fn = '/Users/teddy/eclipse-workspace/entGraph_mod/downloaded/news_gen8_p.json'
		elif args.host == 'server':
			input_fn = './news_gen8_p.json'
		else:
			raise AssertionError
		output_fn = './newsspike_gen8_with_translations.jsonl'
		max_len = 4000
		translate_newsspike(input_fn, output_fn, max_len)

	elif args.mode == 'both':
		if args.host == 'local':
			input_fn = '/Users/teddy/Files/Potential Corpus/WebHoses_Chinese_News_Articles/webhose_data_entries_no_corenlp.jsonl'
		elif args.host == 'server':
			input_fn = 'webhose_data_entries_no_corenlp.jsonl'
		else:
			raise AssertionError
		output_fn = './webhose_data_entries_with_translations.jsonl'
		ref_fn = './webhose_data_entries_with_translations_last.jsonl'
		max_len = 1750
		translate_webhose(input_fn, output_fn, ref_fn, max_len, blankfilling=False)

		print("Webhose Entries Translated!")

		if args.host == 'local':
			input_fn = '/Users/teddy/eclipse-workspace/entGraph_mod/downloaded/news_gen8_p.json'
		elif args.host == 'server':
			input_fn = './news_gen8_p.json'
		else:
			raise AssertionError
		output_fn = './newsspike_gen8_with_translations.jsonl'
		ref_fn = './newsspike_gen8_with_translations_last.jsonl'
		max_len = 4000
		translate_newsspike(input_fn, output_fn, ref_fn, max_len)

	elif args.mode == 'webhose_blanks':
		if args.host == 'local':
			raw_fn = '/Users/teddy/Files/Potential Corpus/WebHoses_Chinese_News_Articles/webhose_data_entries_no_corenlp.jsonl'
		elif args.host == 'server':
			raw_fn = 'webhose_data_entries_no_corenlp.jsonl'
		else:
			raise AssertionError
		transed_fn = './webhose_data_entries_with_translations.jsonl'
		output_fn = './webhose_data_entries_with_translations_blank_filled.jsonl'
		max_len = 1000
		translate_webhose(raw_fn, output_fn, transed_fn, max_len, blankfilling=True)

	elif args.mode == 'newsspike_blanks':
		if args.host == 'local':
			input_fn = '/Users/teddy/eclipse-workspace/entGraph_mod/downloaded/news_gen8_p.json'
		elif args.host == 'server':
			input_fn = './news_gen8_p.json'
		else:
			raise AssertionError
		output_fn = './newsspike_gen8_with_translations.jsonl'
		ref_fn = './newsspike_gen8_with_translations_before_blankfilling.jsonl'
		max_len = 0
		translate_newsspike(input_fn, output_fn, ref_fn, max_len, blankfilling=True, placeholding=False)

	elif args.mode == 'clue':
		if args.host == 'local':
			raw_fn = '/Users/teddy/Files/Potential Corpus/WebHoses_Chinese_News_Articles/clue_data_entries_no_corenlp.jsonl'
		elif args.host == 'server':
			raw_fn = './clue_data_entries_no_corenlp.jsonl'
		else:
			raise AssertionError
		output_fn = './clue_data_entries_with_translations.jsonl'
		ref_fn = './clue_data_entries_with_translations_last.jsonl'
		max_len = 1750
		translate_webhose(raw_fn, output_fn, ref_fn, max_len, blankfilling=False)

	elif args.mode == 'clue_blanks':
		if args.host == 'local':
			raw_fn = '/Users/teddy/Files/Potential Corpus/WebHoses_Chinese_News_Articles/clue_data_entries_no_corenlp.jsonl'
		elif args.host == 'server':
			raw_fn = './clue_data_entries_no_corenlp.jsonl'
		else:
			raise AssertionError
		output_fn = './clue_data_entries_with_translations_blank_filled.jsonl'
		ref_fn = './clue_data_entries_with_translations.jsonl'
		max_len = 1000
		translate_webhose(raw_fn, output_fn, ref_fn, max_len, blankfilling=False)

	elif args.mode == 'newscrawl':
		pass

	else:
		raise NotImplementedError
	# translate_levy_holts()
