import os
from collections import defaultdict

import CONFIG as C
import matplotlib.pyplot as plt
import pickle
from scipy.stats import pearsonr
from sklearn.metrics import mean_squared_error, r2_score
from pycaret.regression import *
from sklearn.inspection import permutation_importance
from scipy.stats import spearmanr
from scipy.cluster import hierarchy

# Seed NumPy's global random number generator with a fixed value; the same
# number is also passed as `session_id` / `random_state` further down the file.
# NOTE(review): `np` is never imported explicitly in this file - it is
# presumably supplied by the wildcard `from pycaret.regression import *`
# above; confirm, or add an explicit `import numpy as np`.
np.random.seed(225530)

# Column-name prefixes consumed by feature_map(): every column whose name
# starts with one of these entries is collected there.
# NOTE(review): going by the constant's name, these are presumably the
# features that are all zero on the reduced dataset - confirm.
FEATURES_ALL_0_WHEN_REDUCED = [
	"PassiveActiveRatio_answer_",
	"ProportionActiveVPs_answer_",
	"AgentlessPassiveCount_answer_",
	"WDT_answer_",
	"WRB_answer_",
	"RBS_answer_",
	"JJR_answer_",
	"JJS_answer_",
	"EX_answer_",
	"#_answer_",
	"LS_answer_",
	"SYM_question_",
	"ProportionPassiveVPs_answer_",
	"TO_answer_",
	"ConditionalClausesCount_answer_",
	"ConditionalClausesIncidence_answer_",
]


def feature_map(data):
	"""Return the column names of ``data`` that start with a listed prefix.

	Prefixes come from ``FEATURES_ALL_0_WHEN_REDUCED``. Matches are grouped by
	prefix (in list order) and, within one prefix, follow the order of
	``data.columns``. A column that matches several prefixes appears once per
	matching prefix.

	:param data: object exposing a ``columns`` iterable of strings
	:return: list of matching column names
	"""
	return [
		column
		for prefix in FEATURES_ALL_0_WHEN_REDUCED
		for column in data.columns
		if column.startswith(prefix)
	]


def corrfunc(x, y, ax=None, **kws):
	"""Annotate a plot with the Pearson correlation between ``x`` and ``y``.

	The text is placed at position (0.7, 0.1) in the axes' own coordinate
	system (``transAxes``), i.e. towards the lower right of the plot.

	:param x: first sample
	:param y: second sample
	:param ax: axes to annotate; the current matplotlib axes are used when
		this is missing or falsy
	:param kws: accepted for call compatibility and ignored
	:return: None
	"""
	coefficient, _p_value = pearsonr(x, y)
	target_axes = ax if ax else plt.gca()
	# The label starts with the Unicode lowercase rho (U+03C1).
	label = '\u03C1' + f' = {coefficient:.2f}'
	target_axes.annotate(label, xy=(.7, .1), xycoords=target_axes.transAxes)


def plot_corr_feature(preds, difficulties, axis):
	"""Annotate ``axis`` with the correlation of ``preds`` and ``difficulties``.

	Thin wrapper that forwards its arguments to ``corrfunc``.

	:param preds: passed to ``corrfunc`` as ``x``
	:param difficulties: passed to ``corrfunc`` as ``y``
	:param axis: passed to ``corrfunc`` as ``ax``
	:return: None
	"""
	corrfunc(x=preds, y=difficulties, ax=axis)


def _find_setup_splits(setup_result):
	"""Collect the train/validation splits stored inside ``setup``'s result.

	The result is scanned for an entry whose first item is labelled
	"Setup Config"; the (label, value) pairs of that entry are then searched
	for the four split labels.

	:param setup_result: iterable returned by ``setup``
	:return: dict holding whichever of "X_train", "y_train", "X_val" and
		"y_val" could be located
	"""
	splits = {}
	for entry in setup_result:
		try:
			if entry[0][0] == "Setup Config":
				for item in entry:
					label = item[0]
					if label == "X_test Set":
						splits["X_val"] = item[1]
					elif label == "y_test Set":
						splits["y_val"] = item[1]
					elif label == "X_training Set":
						splits["X_train"] = item[1]
					elif label == "y_training Set":
						splits["y_train"] = item[1]
		except (ValueError, IndexError, KeyError, TypeError):
			# Entries that cannot be probed like a list of (label, value)
			# pairs are simply not the one we are looking for.
			continue
	return splits


def fit_models(data_set, include=None):
	"""Fit and compare regression models on ``data_set`` and score the best.

	``setup`` is run with "target" as the target column and every other
	column declared numeric. A constant baseline (the mean of the validation
	targets, scored on those same targets) is printed, then ``compare_models``
	selects the best model, which is scored on the validation split.

	:param data_set: data frame containing the features and a "target" column
	:param include: optional collection of model ids handed to
		``compare_models(include=...)``; when empty or None every model except
		"tr" and "catboost" is compared
	:return: tuple ``(best_model, (BASE_MSE, BASE_RMSE, MODEL_MSE, MODEL_RMSE),
		(X_train, y_train, X_val, y_val), table)`` where ``table`` is the value
		returned by ``pull()`` after the comparison
	:raises RuntimeError: if the train/validation splits cannot be found in
		the value returned by ``setup``
	"""
	exp_reg = setup(
		data_set, target="target", normalize=False,
		numeric_features=[i for i in data_set.columns if i != "target"],
		session_id=225530, verbose=False, html=False, silent=True,
	)

	splits = _find_setup_splits(exp_reg)
	missing = [key for key in ("X_train", "y_train", "X_val", "y_val") if key not in splits]
	if missing:
		# Fail here, before the expensive model comparison, instead of with
		# an unbound-variable error further down.
		raise RuntimeError(
			"could not locate " + ", ".join(missing) + " in the value returned by setup()"
		)
	X_train, y_train = splits["X_train"], splits["y_train"]
	X_val, y_val = splits["X_val"], splits["y_val"]

	# Baseline: predict the mean of the validation targets for every row.
	mean = sum(y_val) / len(y_val)
	baseline_predictions = [mean for _ in y_val]

	BASE_MSE = mean_squared_error(y_val, baseline_predictions)
	BASE_RMSE = np.sqrt(BASE_MSE)
	print("==== BASELINE ====")
	print(f"MSE: {BASE_MSE}")
	print(f"RMSE: {BASE_RMSE}")
	print("==== END BASELINE ====")

	if include:
		best_model = compare_models(include=include)
	else:
		best_model = compare_models(exclude=["tr", "catboost"])
	table = pull()

	MODEL_MSE = mean_squared_error(y_val, best_model.predict(X_val))
	MODEL_RMSE = np.sqrt(MODEL_MSE)

	return best_model, (BASE_MSE, BASE_RMSE, MODEL_MSE, MODEL_RMSE), (X_train, y_train, X_val, y_val), table


def _load_clustered_dataset(irt_parameter):
	"""Load one reduced dataset and keep a single feature per cluster.

	Columns returned by ``feature_map`` are dropped first. The remaining
	feature columns are clustered on their Spearman correlation matrix (Ward
	linkage, flat clusters cut at distance 1) and only the first feature of
	each cluster is kept, together with the "target" column.

	:param irt_parameter: second argument forwarded to
		``C.load_reduced_data``
	:return: data frame with the selected feature columns plus "target"
	"""
	frame = C.load_reduced_data("all_features_fixed_manual.csv", irt_parameter)
	frame.drop(feature_map(frame), inplace=True, axis='columns')

	features = frame[[name for name in frame.columns if name != "target"]]
	corr = spearmanr(features).correlation
	corr_linkage = hierarchy.ward(corr)

	cluster_ids = hierarchy.fcluster(corr_linkage, 1, criterion='distance')
	cluster_id_to_feature_ids = defaultdict(list)
	for idx, cluster_id in enumerate(cluster_ids):
		cluster_id_to_feature_ids[cluster_id].append(idx)
	selected = [members[0] for members in cluster_id_to_feature_ids.values()]

	# The cluster indices refer to positions in `features` (which excludes
	# "target"), so they must be applied to `features`; applying them to
	# `frame` is only correct when "target" happens to be the last column.
	reduced = features.iloc[:, selected].copy()
	reduced["target"] = frame["target"]
	return reduced


def _write_baseline(path, base_mse, base_rmse):
	"""Write the baseline MSE/RMSE report to ``path``."""
	with open(path, "w") as f:
		f.write("==== BASELINE ====\n")
		f.write(f"MSE: {base_mse}\n")
		f.write(f"RMSE: {base_rmse}\n")
		f.write("==== END BASELINE ====")


def _write_significant_importances(path, result, feature_names):
	"""Write features whose mean importance exceeds two standard deviations.

	Features are written in order of decreasing mean importance; a feature is
	kept only when ``mean - 2 * std > 0``.
	"""
	with open(path, "w") as f:
		for i in result.importances_mean.argsort()[::-1]:
			if result.importances_mean[i] - 2 * result.importances_std[i] > 0:
				f.write(f"{feature_names[i]:<8}\n")
				f.write(f"{result.importances_mean[i]:.3f}\n")
				f.write(f" +/- {result.importances_std[i]:.3f}\n")


def reduced_1pl_diff_vs_reduced_2pl_discrim():
	"""
	Experiment Purpose:
	    Predicting the 2PL difficulty on the reduced 2PL fit is hard, whereas the 1PL difficulty can be predicted
	    for both the non-reduced and the reduced dataset. The hypothesis is that the trends visible in the 1PL
	    difficulty are captured by the 2PL discrimination parameters rather than by the 2PL difficulty parameters.
	    This experiment compares which features matter for predicting each of the two.

	    NOTE: an earlier description stated that the experiment aborts when the base_mse is larger than the
	    model_mse; no such check is implemented in this function.

	What it does:
	    For the reduced 2PL discrimination target and the reduced 1PL difficulty target it loads the data, keeps
	    one feature per correlation cluster, fits models with ``fit_models`` and writes into the directory
	    "1pl_diff_reduced_vs_2pl_discrim_clustered_features": the baseline scores (baseline_*.txt), the model
	    comparison table (model_res_*.csv) and the features with a significant permutation importance
	    (output_*.txt).

	Results:
		The features that predict the difficulties for the 1PL difficulties and 2PL discrims are by and large disjoint.
	:return: None
	"""
	print("Starting Feature Importance between reduced 1pl diff and reduced 2pl discrim")
	working_dir = "1pl_diff_reduced_vs_2pl_discrim_clustered_features"
	os.makedirs(working_dir, exist_ok=True)

	two_pl_reduced_discrim = _load_clustered_dataset(C.OriginalDataset.DFGN_148_reduced.twoPL_var_irt_discrim)
	one_pl_reduced_diff = _load_clustered_dataset(C.OriginalDataset.DFGN_148_reduced.onePL_var_irt_diff)

	print(f"ONE PL SHAPE: {one_pl_reduced_diff.shape}, TWO PL SHAPE: {two_pl_reduced_discrim.shape}")

	# Fit both targets first (discrimination, then difficulty) and only
	# afterwards compute the permutation importances, in the same order.
	runs = []
	for label, frame in (("discrim", two_pl_reduced_discrim), ("diff", one_pl_reduced_diff)):
		best_model, scores, splits, table = fit_models(frame)
		# scores[0] / scores[1] are the baseline MSE / RMSE that fit_models
		# computed on the validation targets, so they are not recomputed here.
		_write_baseline(f"{working_dir}/baseline_{label}.txt", scores[0], scores[1])
		table.to_csv(f"{working_dir}/model_res_{label}.csv")
		runs.append((label, frame, best_model, splits[2], splits[3]))

	for label, frame, best_model, X_val, y_val in runs:
		r = permutation_importance(best_model, X_val, y_val, random_state=225530,
		                           scoring='neg_mean_squared_error')
		# Importances are indexed by the columns of the matrix that was
		# permuted, so label them from X_val when it carries column names and
		# fall back to the input frame's columns otherwise.
		feature_names = getattr(X_val, "columns", frame.columns)
		_write_significant_importances(f"{working_dir}/output_{label}.txt", r, feature_names)


# Run the experiment. There is no `if __name__ == "__main__"` guard, so this
# call executes whenever the module is run or imported.
reduced_1pl_diff_vs_reduced_2pl_discrim()
