
import pandas as pd
import random
import sys
import numpy as np
import Levenshtein
from matplotlib import pyplot as plt

# Fallback maximum Levenshtein distance used by levenstein_threshold when it is called without a threshold.
LEV_DISTANCE_THRESH = 10
# %-style path template for a cognate-pair CSV, filled with two language names; get_data tries both orders.
COGNATES_FILE_FORMAT = "../cognates_files/cognates_%s_%s.csv"

def get_data(lang1, lang2):
	"""Load the word lists of two languages and their known cognate pairs.

	Reads ``../{lang}.csv`` for each language (only the ``word`` column is
	used) and the cognate CSV named by ``COGNATES_FILE_FORMAT``, trying the
	``(lang1, lang2)`` file name first and ``(lang2, lang1)`` if that file
	does not exist.

	Returns:
		words1, words2: lists of the non-missing words of each language, as strings.
		cognates: set of ``(word_lang1, word_lang2)`` tuples taken from the
			``word_{lang1}`` / ``word_{lang2}`` columns of the cognate file.

	Raises:
		FileNotFoundError: if a word file, or both orderings of the cognate
			file, are missing.
	"""
	dtypes = {'word': str, 'source_language': str, 'etymon': str, 'raw_etymon': str}

	def _read_words(lang):
		# Word column of one language file, in row order, without missing entries
		# (missing values stringify to 'nan').
		df = pd.read_csv(f'../{lang}.csv', dtype=dtypes)
		return [str(w) for w in df['word'] if str(w) != 'nan']

	words1 = _read_words(lang1)
	words2 = _read_words(lang2)

	try:
		df_cognates = pd.read_csv(COGNATES_FILE_FORMAT % (lang1, lang2))
	except FileNotFoundError:
		# The file may be stored with the languages in the other order. Columns are
		# selected by name below, so the resulting tuples are always (lang1, lang2).
		# Only a missing file triggers the fallback; parse or permission errors propagate.
		df_cognates = pd.read_csv(COGNATES_FILE_FORMAT % (lang2, lang1))

	pairs = df_cognates[[f'word_{lang1}', f'word_{lang2}']].values
	cognates = {tuple(p) for p in pairs}

	return words1, words2, cognates

def get_headwords(lang):
	"""Return the headwords listed one per line in ``../{lang}-headwords-clean.txt``.

	The file is decoded as UTF-8, which is the same default ``pd.read_csv`` uses
	for the other data files. Blank lines, including the empty entry a trailing
	newline would otherwise produce, are skipped, so they can never be sampled
	as headwords.
	"""
	with open(f'../{lang}-headwords-clean.txt', encoding='utf-8') as f:
		lines = f.read().split('\n')
	return [line for line in lines if line.strip()]

def generate_random(lang1, lang2, condition_f=None, ratio=1.2, patience=20):
	"""Sample random headword pairs to use as negative (non-cognate) examples.

	Candidate pairs ``(w1, w2)`` are drawn uniformly from the headword lists of
	``lang1`` and ``lang2`` (see ``get_headwords``). A candidate is rejected if
	it is a known cognate pair (from ``get_data``), if it was already sampled,
	or, when ``condition_f`` is given, if ``condition_f(w1, w2, thresh)`` is
	falsy. Here ``thresh`` is the mean Levenshtein distance of the cognate pairs.

	Args:
		lang1, lang2: language names used to locate the data files.
		condition_f: optional predicate ``(w1, w2, thresh) -> bool``.
		ratio: number of negatives to aim for, relative to the number of
			cognates. The default of 1.2 means 20% more negatives than cognates.
		patience: sampling stops once there have been more than
			``patience * n_cognates`` consecutive rejected candidates, on the
			assumption that few valid pairs remain.

	Returns:
		set of ``(word_lang1, word_lang2)`` tuples.

	Raises:
		ValueError: if either headword list is empty.
	"""
	import math

	_, _, cognates = get_data(lang1, lang2)
	words1 = get_headwords(lang1)
	words2 = get_headwords(lang2)
	negative_pairs = set()
	negative_distances = []
	n = len(cognates)
	# For an integer count, `count < ceil(x)` is equivalent to `count < x`,
	# so this is the same stopping point as `n * ratio`, expressed as an integer.
	target = math.ceil(n * ratio)
	max_failures = n * patience
	unsuccessful = 0

	# Mean cognate distance, passed to condition_f. It uses its own local name
	# rather than shadowing the module-level LEV_DISTANCE_THRESH.
	lev_thresh = get_cognate_distance_mode(cognates, 0)

	if not words1 or not words2:
		raise ValueError(f'No headwords for {lang1 if not words1 else lang2}; cannot sample pairs.')

	while len(negative_pairs) < target:
		w1 = random.choice(words1)
		w2 = random.choice(words2)
		if unsuccessful > max_failures:  # too many rejections in a row: valid pairs are probably exhausted
			print(f'Could only generate {len(negative_pairs)} pairs out of {target}. Stopping...')
			break
		pair = (w1, w2)
		if pair in cognates or pair in negative_pairs:
			unsuccessful += 1
			continue
		if condition_f is not None and not condition_f(w1, w2, lev_thresh):
			unsuccessful += 1
			continue
		negative_pairs.add(pair)
		negative_distances.append(levenstein_distance(w1, w2))
		unsuccessful = 0

	print(f'All theoretical pairs: {len(words1)*len(words2)}; nr cognates: {n}')
	# No pair may have been accepted (e.g. condition_f rejected everything), so guard the mean.
	avg_distance = sum(negative_distances)/len(negative_distances) if negative_distances else float('nan')
	print(f'nr negative pairs: {len(negative_pairs)}, average Levenstein: {avg_distance}')

	return negative_pairs

def levenstein_threshold(w1, w2, thresh=None):
	"""Return True if the Levenshtein distance between ``w1`` and ``w2`` is at most ``thresh``.

	``thresh`` falls back to the module-level ``LEV_DISTANCE_THRESH`` only when
	it is ``None``. A threshold of 0 (for example, a mean cognate distance of
	0.0) is therefore honoured. Pairs whose distance cannot be computed
	(``levenstein_distance`` returns ``None`` for non-string words) are rejected.
	"""
	if thresh is None:
		thresh = LEV_DISTANCE_THRESH
	distance = levenstein_distance(w1, w2)
	if distance is None:
		# Comparing None with a number would raise TypeError.
		return False
	return distance <= thresh

def levenstein_distance(w1, w2):
	"""Return the Levenshtein edit distance between two words.

	Returns ``None`` when either argument is not a string (e.g. a NaN read from
	a CSV), so callers can filter such pairs out. ``str`` subclasses such as
	``numpy.str_`` count as strings.
	"""
	if not isinstance(w1, str) or not isinstance(w2, str):
		return None
	return Levenshtein.distance(w1, w2)

def get_cognate_distance_mode(cognates, n=50):
	"""Summarise the Levenshtein distances between the words of each cognate pair.

	Prints the mean, the median and the ``n``-th percentile of the distances.
	Pairs for which ``levenstein_distance`` returns ``None`` (non-string words,
	e.g. NaN read from a CSV) are ignored.

	Args:
		cognates: iterable of ``(w1, w2)`` word pairs.
		n: percentile to return, in ``(0, 100]``. The default of 50 is the
			median. ``n == 0`` returns the mean instead. Percentiles are linearly
			interpolated, exactly like the 25%/50%/75% rows of
			``pandas.Series.describe``.

	Returns:
		The mean distance when ``n == 0``; otherwise the ``n``-th percentile,
		or ``None`` when ``n`` is outside ``[0, 100]``.

	Raises:
		ValueError: if no pair has a computable distance, since the mean would
			be undefined.
	"""
	distances = [levenstein_distance(w1, w2) for (w1, w2) in cognates]
	distances = [d for d in distances if d is not None]
	if not distances:
		raise ValueError('No cognate pair with a computable Levenshtein distance.')
	avg_distance = sum(distances)/len(distances)
	# Builtin float: the np.float alias was removed in NumPy 1.24.
	median_distance = np.median(np.array(distances, dtype=float))
	if 0 < n <= 100:
		# Same linear interpolation that describe() uses for its 25%/50%/75% rows.
		mode_distance = pd.Series(distances).quantile(n / 100)
	else:
		mode_distance = None
	print(f'Avg cognate distance is {avg_distance}; median cognate distance is {median_distance}; cognate distance quantile {n} is {mode_distance}.')
	if n == 0:
		return avg_distance
	return mode_distance

def plot_all_distance_distributions(languages=None):
	"""Collect and plot cognate Levenshtein distances for every pair of languages.

	For each unordered pair ``(lang1, lang2)``, with ``lang1`` appearing earlier
	in ``languages``, the distances of its cognate pairs (from ``get_data``)
	become one column named ``'{lang1}-{lang2}'``. The table is written to
	``../cognate_distances.csv``, drawn as one histogram per column (10 bins
	over [0, 10]) and summarised on stdout. Call ``plt.show()`` afterwards to
	display the histograms.

	Args:
		languages: languages to compare. Defaults to romanian, italian,
			spanish, portuguese and french.

	Returns:
		The DataFrame of distances, one column per language pair. Shorter
		columns are padded with NaN.

	Raises:
		ValueError: if fewer than two languages are given.
	"""
	if languages is None:
		languages = ['romanian', 'italian', 'spanish', 'portuguese', 'french']
	if len(languages) < 2:
		raise ValueError('Need at least two languages to compare.')

	columns = {}
	for i, lang1 in enumerate(languages):
		for lang2 in languages[i+1:]:
			_, _, cognates = get_data(lang1, lang2)
			distances = [levenstein_distance(w1, w2) for (w1, w2) in cognates]
			columns[f'{lang1}-{lang2}'] = pd.Series(distances)

	# Build the frame in one go so columns are aligned on the union of their
	# indexes. Assigning Series one by one to a DataFrame would truncate every
	# later column to the length of the first one.
	all_cognate_distances_df = pd.DataFrame(columns)
	all_cognate_distances_df.to_csv("../cognate_distances.csv")
	all_cognate_distances_df.hist(figsize=(15,10), bins=10, range=[0,10])
	print(all_cognate_distances_df.describe())
	return all_cognate_distances_df

if __name__=="__main__":
	# Usage: python <script> LANG1 LANG2 [levenshtein|random]
	#   levenshtein (default): only pairs within the mean cognate distance
	#   random: unconstrained random pairs
	# Output: ../data_negative/<mode>_<LANG1>_<LANG2>_extra.csv
	if len(sys.argv) not in (3, 4):
		sys.exit(f'usage: {sys.argv[0]} LANG1 LANG2 [levenshtein|random]')

	lang1 = sys.argv[1]
	lang2 = sys.argv[2]
	mode = sys.argv[3] if len(sys.argv) == 4 else 'levenshtein'

	conditions = {'levenshtein': levenstein_threshold, 'random': None}
	if mode not in conditions:
		sys.exit(f'unknown mode {mode!r}; expected one of {sorted(conditions)}')

	negative_pairs = generate_random(lang1, lang2, conditions[mode])
	# Explicit columns keep the CSV header even when no pair was generated.
	df = pd.DataFrame(list(negative_pairs), columns=[f'word_{lang1}', f'word_{lang2}'])
	df.to_csv(f'../data_negative/{mode}_{lang1}_{lang2}_extra.csv')
	

