from sklearn.metrics import mean_squared_error

import IRT.CONFIG as C
import pandas as pd
from pycaret.regression import *
import json
import matplotlib.pyplot as plt

# Column-name prefixes; feature_map() returns every column whose name starts
# with one of these so the caller can drop it.
# NOTE(review): the constant's name suggests these features are all zero in the
# reduced data set -- confirm against the data before relying on that.
FEATURES_ALL_0_WHEN_REDUCED = [
	"PassiveActiveRatio_answer_",
	"AgentlessPassiveCount_answer_",
	"WDT_answer_",
	"WRB_answer_",
	"RBS_answer_",
	"JJR_answer_",
	"JJS_answer_",
	"EX_answer_",
	"#_answer_",
	"LS_answer_",
	"SYM_question_",
	"ProportionPassiveVPs_answer_",
	"TO_answer_",
	"ConditionalClausesCount_answer_",
	"ConditionalClausesIncidence_answer_",
]


def feature_map(data, prefixes=None):
	"""Return the columns of *data* whose names start with one of *prefixes*.

	Args:
		data: Object exposing an iterable ``columns`` attribute of string
			names (presumably a pandas DataFrame, as used by this script).
		prefixes: Iterable of column-name prefixes to match. When omitted,
			``FEATURES_ALL_0_WHEN_REDUCED`` is used.

	Returns:
		A list of matching column names, grouped by prefix in the order the
		prefixes are given and, within a prefix, in column order. A column
		that matches several prefixes appears once per matching prefix.
	"""
	if prefixes is None:
		prefixes = FEATURES_ALL_0_WHEN_REDUCED
	return [
		column
		for prefix in prefixes
		for column in data.columns
		if column.startswith(prefix)
	]
# manual_features = pd.read_csv("manual_features_fixed.csv", index_col="ItemID")
# java_features_question = C.get_textual_features("hotpot_java_extracted_features_fixed.csv")
# java_features_answer = C.get_textual_features("hotpot_java_extracted_features_answer_fixed.csv")
# bert_question = C.get_textual_features("BERT.csv")
# bert_answer = C.get_textual_features("BERT_answer.csv")
# java_features_question.rename(columns={i:f"{i}_question" for i in java_features_question}, inplace=True)
# bert_question.rename(columns={i:f"{i}_question" for i in bert_question}, inplace=True)
# bert_answer.rename(columns={i:f"{i}_answer" for i in bert_answer}, inplace=True)
# java_features_answer.rename(columns={i:f"{i}_answer" for i in java_features_answer}, inplace=True)
# full = pd.concat([manual_features, java_features_answer, java_features_question, bert_question, bert_answer], axis=1)
# full.to_csv(f"{C.DATA_PATH}/hotpot_textual_features/all_features_fixed_manual.csv")


# Load the index bookkeeping files; each handle is closed as soon as its JSON
# content has been parsed.
with open(f"{C.DATA_PATH}/reduced_data_indicies_test.txt") as indicies_file:
	indicies = json.load(indicies_file)
with open(f"{C.DATA_PATH}/reduced_data_indicies.txt") as indicies_file:
	indicies_flags = json.load(indicies_file)["indicies"]

# Maps 0, 1, 2, ... to the positions in the stored "indicies" list whose entry
# is falsy, in ascending position order.
indicies_to_map = dict(enumerate(
	position for position, flag in enumerate(indicies_flags) if not flag
))

# Load the reduced development data with the target selected by
# C.OriginalDataset.DFGN_148_reduced.onePL_var_irt_diff, then drop every column
# that feature_map() matches.
reduced_dev = C.load_reduced_data(file_name="all_features_fixed_manual.csv",
                                  targets=C.OriginalDataset.DFGN_148_reduced.onePL_var_irt_diff)
reduced_dev.drop(feature_map(reduced_dev), inplace=True, axis='columns')
# Initialise the PyCaret regression experiment: no normalisation, every
# non-target column declared numeric, no interactive/HTML output.
# NOTE(review): session_id is presumably PyCaret's random seed -- confirm
# against the installed PyCaret version.
exp_reg = setup(reduced_dev, target="target", normalize=False,
                numeric_features=[i for i in reduced_dev.columns if i != "target"], session_id=225530, verbose=False,
                html=False, silent=True)
# setup()'s return value is iterated and every element is probed: an element
# whose first entry is labelled "Setup Config" is treated as a sequence of
# (label, value) pairs, from which the hold-out ("X_test Set"/"y_test Set") and
# training ("X_training Set"/"y_training Set") splits are picked out.
for i in exp_reg:
	try:
		if i[0][0] == "Setup Config":
			for j in i:
				if j[0] == "X_test Set":
					X_val = j[1]
				if j[0] == "y_test Set":
					y_val = j[1]
				if j[0] == "X_training Set":
					X_train = j[1]
				if j[0] == "y_training Set":
					y_train = j[1]
	# Elements that cannot be indexed or compared this way are skipped.
	# NOTE(review): if no element matches, X_val/y_val/X_train/y_train are never
	# assigned and their first later use raises NameError.
	except (ValueError, IndexError, KeyError, TypeError):
		pass

def _constant_baseline_mse(y_true, constant):
	"""Return the MSE of predicting *constant* for every entry of *y_true*."""
	return mean_squared_error(y_true, np.full(len(y_true), constant))


def _report_baseline(mse):
	"""Print one baseline section (MSE and its square root) to stdout."""
	print("==== BASELINE ====")
	print(f"MSE: {mse}")
	print(f"RMSE: {np.sqrt(mse)}")
	print("==== END BASELINE ====")


# Naive baseline: always predict the arithmetic mean of the training target.
mean = sum(y_train) / len(y_train)

# Baseline error on the hold-out split taken from the PyCaret setup output.
BASE_MSE = _constant_baseline_mse(y_val, mean)
BASE_RMSE = np.sqrt(BASE_MSE)

# Load the reduced test data with the same target selector as the development
# data and drop the columns matched by feature_map().
reduced_test = C.load_reduced_test(
	file_name="all_features_fixed_manual.csv",
	targets=C.OriginalDataset.DFGN_148_reduced.onePL_var_irt_diff,
)
reduced_test.drop(columns=feature_map(reduced_test), inplace=True)

# Three sections are printed in this order, all with the same header text:
# hold-out split, training split, reduced test set.
_report_baseline(BASE_MSE)
_report_baseline(_constant_baseline_mse(y_train, mean))
_report_baseline(_constant_baseline_mse(reduced_test["target"], mean))
# Compare PyCaret's candidate regressors, excluding "tr" and "catboost"; the
# result is unpacked as (best model, results table).
best_model, table = compare_models(exclude=["tr", "catboost"])
# best_model, table = compare_models(include=['lr'])
table.to_csv("out.csv")
print(table)
save_model(best_model, "best_model.joblib")

# Figure and a hidden annotation box; the hover callback defined later in the
# file repositions and shows this annotation.
fig, ax = plt.subplots()
annot = ax.annotate(
	"",
	xy=(0, 0),
	xytext=(20, 20),
	textcoords="offset points",
	bbox=dict(boxstyle="round", fc="w"),
	arrowprops=dict(arrowstyle="->"),
)
annot.set_visible(False)

# Reload the reduced test data and drop the columns matched by feature_map().
# NOTE(review): an identically built frame is already created earlier in the
# script; this reload looks redundant -- confirm before removing.
reduced_test = C.load_reduced_test(file_name="all_features_fixed_manual.csv",
                                   targets=C.OriginalDataset.DFGN_148_reduced.onePL_var_irt_diff)
reduced_test.drop(feature_map(reduced_test), inplace=True, axis='columns')

# The second DataFrame found in setup()'s return value is written out, and its
# columns are paired with test columns by the text after the last "_".
data = [i for i in exp_reg if isinstance(i, pd.DataFrame)][1]
col = []
col_map = {}
data.to_csv("Feature_set_with_col.csv")
for j in data.columns:
	# The suffix does not depend on the inner loop, so it is computed once.
	suffix = j.split("_")[-1]
	# Raises IndexError if no test column shares this suffix.
	col_name = [i for i in reduced_test.columns if i.split("_")[-1] == suffix][0]
	col.append(col_name)
	col_map[col_name] = j

best_model = load_model("best_model.joblib")
# X_test = reduced_test[col]
# NOTE(review): X_test is the whole reduced_test frame, "target" column
# included; presumably the loaded pipeline copes with that -- confirm.
X_test = reduced_test
X_test.to_csv("X_TEST")
# X_test.rename(columns=col_map, inplace=True)
y_test = reduced_test["target"]

# Predict once and reuse the result for both the scatter plot and the CSV dump
# (the model was previously run twice on the same input).
test_predictions = best_model.predict(X_test)
sc = plt.scatter(x=test_predictions, y=y_test)
out_final = pd.DataFrame(test_predictions)
out_final.to_csv("out_preds_final.csv")
def update_annot(ind):
	"""Move the hover annotation to the first hit point and describe all hits.

	Args:
		ind: Mapping whose ``"ind"`` entry holds the positions of the scatter
			points under the cursor (as returned by ``sc.contains``).

	The annotation text is the string form of a list with one
	``"<row label>:<position> & <target value>"`` entry per hit point.
	"""
	hit_positions = ind["ind"]
	# Anchor the arrow at the first point under the cursor.
	annot.xy = sc.get_offsets()[hit_positions[0]]
	labels = [
		f"{reduced_test.iloc[n, :].name}:{n} & {y_test.iloc[n]}"
		for n in hit_positions
	]
	# text = "{}".format("".join(str([n for n in ind["ind"]])))
	annot.set_text(str(labels))
	annot.get_bbox_patch().set_alpha(0.4)

def hover(event):
	"""Show, update or hide the annotation in response to a mouse event.

	Args:
		event: Matplotlib event object; only events whose ``inaxes`` is the
			script's ``ax`` are acted upon.
	"""
	was_visible = annot.get_visible()
	if event.inaxes != ax:
		return

	over_point, hit_info = sc.contains(event)
	if over_point:
		# Cursor is on at least one scatter point: refresh and show the box.
		update_annot(hit_info)
		annot.set_visible(True)
	elif was_visible:
		# Cursor left the points while the box was showing: hide it.
		annot.set_visible(False)
	else:
		# Nothing changed, so no redraw is needed.
		return
	fig.canvas.draw_idle()

# Run hover() on every mouse-move event of the figure's canvas; this must be
# registered before plt.show() hands control to the GUI.
fig.canvas.mpl_connect("motion_notify_event", hover)
# Title and axis labels: predictions on the x axis, ground truth on the y axis.
plt.title("1PL Difficulty Prediction VS Ground Truth (Test)")
plt.xlabel("Predictions")
plt.ylabel("Ground Truth")
# Display the figure.
plt.show()
