import CONFIG as C 
import plotly.graph_objects as go
import pandas as pd
import numpy as np
from pycaret.regression import *
import pandas.io.formats.style
# Column-name prefixes of features that are all zero once the dataset is
# reduced; every column whose name starts with one of these is dropped
# before training.
_FEATURES_ALL_0_WHEN_REDUCED = (
	'PassiveActiveRatio_answer_',
	'AgentlessPassiveCount_answer_', 'WDT_answer_',
	'WRB_answer_', 'RBS_answer_',
	'JJR_answer_', 'JJS_answer_',
	'EX_answer_', '#_answer_',
	'LS_answer_', 'SYM_question_', 'ProportionPassiveVPs_answer_',
	"TO_answer_",
	"ConditionalClausesCount_answer_", "ConditionalClausesIncidence_answer_",
)

# Feature file -> display name written to the "Representation" column of the
# results. Commented-out entries are representations that are not currently run.
_REPRESENTATION_TO_NAME = {
	# Question Features as focus
	"hotpot_java_extracted_features_fixed.csv": "Question Features (Java Only)",
	# "hotpot_java_extracted_features_BERT.csv": "Question Features (Java + Question BERT Embedding)",
	# "hotpot_java_extracted_features_BERT_answer.csv": "Question Features (Java + Answer BERT Embedding)",
	# "hotpot_java_extracted_features_manual_features.csv": "Question Features (Java + Manual)",
	# "hotpot_java_extracted_features_manual_features_BERT.csv": "Question Features (Java + Manual + Question BERT Embedding)",
	# "hotpot_java_extracted_features_manual_features_BERT_answer.csv": "Question Features (Java + Manual + Answer BERT Embedding)",

	# Answer Features as focus
	"hotpot_java_extracted_features_answer_fixed.csv": "Answer Features (Java Only)",
	# "hotpot_java_extracted_features_answer_BERT.csv": "Answer Features (Java + Question BERT Embedding",
	# "hotpot_java_extracted_features_answer_BERT_answer.csv": "Answer Features (Java + Answer BERT Embedding)",
	# "hotpot_java_extracted_features_answer_manual_features.csv": "Answer Features (Java + Manual)",
	# "hotpot_java_extracted_features_answer_manual_features_BERT.csv": "Answer Features (Java + Manual + Question BERT Embedding)",
	# "hotpot_java_extracted_features_answer_manual_features_BERT_answer.csv": "Answer Features (Java + Manual + Answer BERT Embedding)",

	# Neutral Features (Manual Features)
	"fixed_manual_features_only.csv": "Question Features (Manual Only)",
	# "hotpot_manual_features_BERT.csv": "Question Features (Manual + Question BERT Embedding)",
	# "hotpot_manual_features_BERT_answer.csv": "Question Features (Manual + Answer BERT Embedding)",

	# Part of speech features as focus
	"pos_tags_question_only_fixed.csv": "Question Features (P.O.S counts)",
	"pos_tags_answer_only_fixed.csv": "Answer Features (P.O.S counts)",
	# "pos_tags_for_supp_fact.csv": "Focused Context Features (P.O.S counts)",
	# "pos_tags_for_distractor_context.csv": "Distractor Context Features (P.O.S counts)",

	# BERT only,
	"BERT.csv": "Question BERT Embedding Only",
	"BERT_answer.csv": "Answer BERT Embedding Only",

	# Feature importance
	# "answer_features_java_manual_bert_a_ANOVA.csv": "Answer Combined Reduction",
	# "all_features_reduced.csv": "All Features Reduction",

	# Combined Datasets:
	"all_features_fixed_manual.csv": "All Features (Question Java, Answer Java, Manual, "
	                                 "BERT Answer, BERT Question, POS Answer, POS Question)",
}

# Representations that are loaded with ``rename=False`` and whose columns are
# instead tagged here with their role ("question"/"answer") and position.
# Keyed by file name (not list position) so that re-enabling a commented-out
# representation above does not silently change which files get tagged.
_COLUMN_ROLE = {
	"hotpot_java_extracted_features_fixed.csv": "question",
	"pos_tags_question_only_fixed.csv": "question",
	"hotpot_java_extracted_features_answer_fixed.csv": "answer",
	"pos_tags_answer_only_fixed.csv": "answer",
}


def _feature_map(data):
	"""Return the columns of *data* whose names start with a constant-zero feature prefix."""
	return [col for col in data.columns if str(col).startswith(_FEATURES_ALL_0_WHEN_REDUCED)]


def _tag_columns(data, role):
	"""Rename, in place, each non-target column ``c`` at position ``k`` to ``f"{c}_{role}_{k}"``.

	Positions are counted over *all* columns (including ``"target"``), so the
	numeric suffix is the column's original position in the frame.
	"""
	data.rename(
		columns={f"{col}": f"{col}_{role}_{pos}" for pos, col in enumerate(data.columns) if col != "target"},
		inplace=True,
	)


def _align_test_columns(train_columns, test_columns, target="target"):
	"""Return, for each training column, the test column with the same final ``_``-separated token.

	When several test columns share a token, the first one (in *test_columns*
	order) wins. The *target* column is never selected, so the label cannot
	leak into the feature matrix.

	Raises:
		KeyError: if some training column has no test column with a matching token.
	"""
	by_suffix = {}
	for col in test_columns:
		if col != target:
			by_suffix.setdefault(str(col).split("_")[-1], col)

	aligned, missing = [], []
	for col in train_columns:
		match = by_suffix.get(str(col).split("_")[-1])
		if match is None:
			missing.append(col)
		else:
			aligned.append(match)
	if missing:
		raise KeyError(f"no test-set column matches the final '_' token of: {missing}")
	return aligned


def _get_val(setup_result):
	"""Extract the ``(X, y)`` hold-out split from the value returned by pycaret's ``setup``.

	Scans *setup_result* for an element whose first item is a pair named
	``"Setup Config"``; that element is treated as a sequence of ``(name, value)``
	pairs, from which ``"X_test Set"`` and ``"y_test Set"`` are read. Elements
	that cannot be indexed this way are skipped. (This layout matches pycaret
	2.x's ``setup`` return value — confirm if pycaret is upgraded.)

	Raises:
		ValueError: if either split is not found.
	"""
	X_val = y_val = None
	for entry in setup_result:
		try:
			if entry[0][0] != "Setup Config":
				continue
			for pair in entry:
				if pair[0] == "X_test Set":
					X_val = pair[1]
				elif pair[0] == "y_test Set":
					y_val = pair[1]
		except (ValueError, IndexError, KeyError, TypeError):
			# Non-indexable or array-like entries (DataFrames, seeds, pipelines, ...).
			continue
	if X_val is None or y_val is None:
		raise ValueError("could not find 'X_test Set' / 'y_test Set' in the pycaret setup result")
	return X_val, y_val


def _get_results(representation, name, targets, out_dir, label):
	"""Compare pycaret regressors on one representation and write dev/test metrics.

	Args:
		representation: training DataFrame with feature columns and a ``"target"`` column.
		name: representation file name; used to load the matching reduced test
			set and to name the metrics file.
		targets: target specifier forwarded to ``C.load_reduced_test``.
		out_dir: directory that receives ``{name}_best_test_set.txt``.
		label: prefix of every metric line, e.g. ``"2PL"``.

	Returns:
		The model-comparison table returned by pycaret's ``pull()`` after ``compare_models``.
	"""
	from scipy.stats import pearsonr
	from sklearn.metrics import mean_squared_error, r2_score

	setup_result = setup(representation, target="target",
	                     numeric_features=[col for col in representation.columns if col != "target"],
	                     normalize=False, session_id=225530, silent=True, html=False)
	X_val, y_val = _get_val(setup_result)
	best_model = compare_models(exclude=['tr', 'catboost'])
	display_data = pull()

	reduced_test = C.load_reduced_test(file_name=name, targets=targets)
	# NOTE(review): the training feature frame is picked out of the setup
	# result by its row count; 2099 is presumably the training-split size and
	# must be updated if the dataset or split ratio changes.
	train_features = [obj for obj in setup_result
	                  if isinstance(obj, pd.DataFrame) and obj.shape[0] == 2099][0]
	# The test frame is loaded separately and is not filtered here, so select
	# exactly the columns the model was trained on (matched by trailing token).
	X_test = reduced_test[_align_test_columns(train_features.columns, reduced_test.columns)]
	y_test = reduced_test["target"]

	val_preds = best_model.predict(X_val)
	test_preds = best_model.predict(X_test)
	# NOTE(review): "(Dev) MSE" is the lowest MSE in the comparison table,
	# whereas "(Dev) R2/Pearson" are computed for best_model on the hold-out
	# split — they may describe different models.
	report = (
		f"{label} (Dev) MSE: {display_data['MSE'].min()}\n"
		f"{label} (Dev) R2: {r2_score(y_val, val_preds)}\n"
		f"{label} (Dev) Pearson: {pearsonr(val_preds, y_val)}\n"
		f"{label} (Test) MSE: {mean_squared_error(y_test, test_preds)}\n"
		f"{label} (Test) R2: {r2_score(y_test, test_preds)}\n"
		f"{label} (Test) Pearson: {pearsonr(test_preds, y_test)}\n"
	)
	with open(f"{out_dir}/{name}_best_test_set.txt", "w") as f:
		f.write(report)
	return display_data


def _build_dataset(targets, out_dir, label, start=0, stop=None):
	"""Run the regression comparison over ``list(_REPRESENTATION_TO_NAME)[start:stop]``.

	For each representation: load it with ``C.load_reduced_data``, tag
	question/answer columns (see ``_COLUMN_ROLE``), drop constant-zero feature
	columns, run ``_get_results``, and accumulate its Model/R2/MSE/RMSE rows
	together with the representation's display name. After every
	representation, the accumulated results and the list of completed
	representations are written to *out_dir*, with file names carrying the
	representation's index in the full list (useful for resuming via *start*).

	Raises:
		ValueError: if some column still sums to zero after the drop.
	"""
	representations = list(_REPRESENTATION_TO_NAME)
	print(len(representations))
	completed = []
	frames = []
	model_results = pd.DataFrame()

	for idx, representation in enumerate(representations[start:stop], start=start):
		print(f"Now doing {representation}")
		role = _COLUMN_ROLE.get(representation)
		data = C.load_reduced_data(file_name=representation, targets=targets, rename=role is None)
		if role is not None:
			_tag_columns(data, role)

		data.drop(_feature_map(data), inplace=True, axis='columns')
		column_sums = data.sum(axis=0)
		zero_columns = column_sums[column_sums == 0].index.tolist()
		if zero_columns:
			raise ValueError(f"Some Features were 0 at {idx}: {zero_columns}")

		display_data = _get_results(data, representation, targets, out_dir, label)
		display_name = _REPRESENTATION_TO_NAME[representation]
		# reset_index gives each per-representation block a 0..n-1 index.
		local_results = display_data[["Model", "R2", "MSE", "RMSE"]].reset_index(drop=True)
		local_results["Representation"] = display_name
		frames.append(local_results)
		completed.append(display_name)

		model_results = pd.concat(frames)
		model_results.to_csv(f"{out_dir}/model_results_backup_{idx}.csv")
		with open(f"{out_dir}/completed_at_{idx}.txt", "w") as f:
			f.writelines(f"{done}\n" for done in completed)

	model_results.to_csv(f"{out_dir}/model_results_complete.csv")


def twopl(start=0, stop=None):
	"""Benchmark regressors on the 2PL discrimination target, writing to summary_backups/regression_discrim_2pl.

	*start*/*stop* select a slice of the representation list (default: all).
	"""
	_build_dataset(
		targets=C.OriginalDataset.DFGN_148_reduced.twoPL_var_irt_discrim,
		out_dir="summary_backups/regression_discrim_2pl",
		label="2PL",
		start=start,
		stop=stop,
	)


def onepl(start=2, stop=3):
	"""Benchmark regressors on the 1PL difficulty target, writing to summary_backups/regression_diff_1PL.

	*start*/*stop* select a slice of the representation list.
	NOTE(review): the defaults (2, 3) reproduce the previous hard-coded run,
	which processes only the representation at index 2; pass ``start=0,
	stop=None`` to run every representation.
	"""
	_build_dataset(
		targets=C.OriginalDataset.DFGN_148_reduced.onePL_var_irt_diff,
		out_dir="summary_backups/regression_diff_1PL",
		label="1PL",
		start=start,
		stop=stop,
	)


if __name__ == "__main__":
	twopl()
	onepl()
# build_table()
