from sklearn.metrics import mean_squared_error

import CONFIG as C
import umap
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Both embedding tables are loaded through the project helper with the same
# target selector; the code below treats each result as a table that has a
# "target" column next to the embedding features.
irt_target = C.OriginalDataset.DFGN_148_reduced.onePL_var_irt_diff
question_bert = C.load_reduced_test("BERT.csv", targets=irt_target)
answer_bert = C.load_reduced_test("BERT_answer.csv", targets=irt_target)
# question_bert_2 = C.load_reduced_data("BERT.csv", targets=C.OriginalDataset.DFGN_148_reduced.onePL_var_irt_diff)
# question_bert = question_bert.append(question_bert_2)

# One independently fitted reducer per table, sharing a fixed seed so the
# layouts are reproducible between runs.
umap_seed = 225530
question_reducer = umap.UMAP(random_state=umap_seed)
answer_reducer = umap.UMAP(random_state=umap_seed)

# Everything except the "target" column is fed to UMAP.
question_features = question_bert[[col for col in question_bert.columns if col != "target"]]
answer_features = answer_bert[[col for col in answer_bert.columns if col != "target"]]

question_reduced = question_reducer.fit_transform(question_features)
answer_reduced = answer_reducer.fit_transform(answer_features)
# question_reduced_2 = question_reducer.fit_transform(question_bert_2[[i for i in question_bert_2.columns if i != "target"]])

# plt.scatter(
# 	question_reduced[:, 0],
# 	question_reduced[:, 1],
# 	c=[sns.color_palette("flare", as_cmap=True)(x) for x in question_bert["target"]]
# )
# plt.show()
# Single-axes figure plus one shared annotation box. The box starts hidden;
# the hover callback repositions it, fills in its text and toggles it.
fig, ax = plt.subplots()
annot = ax.annotate(
    "",
    xy=(0, 0),
    xytext=(20, 20),
    textcoords="offset points",
    bbox=dict(boxstyle="round", fc="w"),
    arrowprops=dict(arrowstyle="->"),
)
annot.set_visible(False)

## TARGETS can be obtained in the counterfactual analysis files
# The per-question absolute error below depends on `targets`, which only
# exists once the two loading lines are enabled (and given a real CSV path).
# It used to run unconditionally and aborted the script with a NameError
# before anything was drawn; its result is not used by the plot, so it is
# disabled together with the lines it depends on. Re-enable all three at once.
#targets = pd.read_csv("")["diff"]
#targets.rename(index={idx: i for idx, i in enumerate(question_bert.index)}, inplace=True)
#mse = abs(targets - question_bert["target"])

# Scatter of the 2-D question embedding, coloured by the "target" column.
sc = ax.scatter(
    question_reduced[:, 0],
    question_reduced[:, 1],
    c=question_bert["target"],
    s=25,
    cmap=sns.color_palette("flare", as_cmap=True),
)
# sc = ax.scatter(question_reduced_2[:, 0], question_reduced_2[:, 1], c=question_bert_2["target"], s=25, cmap=sns.color_palette("flare", as_cmap=True))

# sc = sns.scatterplot(
# 	answer_reduced[:, 0], answer_reduced[:, 1],
# 	hue=answer_bert["target"],
# 	label="full",
# 	palette=sns.color_palette("flare", as_cmap=True)
# )

# Colorbar for the scatter's colour mapping, with its label rotated to read
# top-to-bottom and padded away from the tick labels.
cbar = plt.colorbar(sc)
# cbar.ax.get_yaxis().set_ticks([])
# cbar.ax.get_yaxis().label_pad = 0
cbar.ax.set_ylabel('Predicted Discrim.', rotation=270, labelpad=15)

def update_annot(ind):
	"""Move the shared annotation onto the hovered marker and set its label.

	``ind`` is the details dict that ``sc.contains`` returns; its ``"ind"``
	entry holds the positions of the markers under the cursor. The box is
	anchored at the first of them, and its text is the string form of a list
	with one ``"<row label>, <position>"`` entry per hit marker.
	"""
	hits = ind["ind"]
	anchor = sc.get_offsets()[hits[0]]
	annot.xy = anchor
	labels = [f"{question_bert.index[n]}, {n}" for n in hits]
	# text = "{}".format("".join(str([n for n in ind["ind"]])))
	annot.set_text(str(labels))
	# Semi-transparent box so the points behind it stay visible.
	annot.get_bbox_patch().set_alpha(0.4)

def hover(event):
	"""Show or hide the annotation in response to a mouse event.

	Events outside ``ax`` are ignored. Inside it, the annotation is updated
	and shown when the cursor is over a scatter marker, and hidden (if it was
	visible) when it is not. A redraw is requested only when something changed.
	"""
	was_shown = annot.get_visible()
	if event.inaxes != ax:
		return

	over_marker, details = sc.contains(event)
	if over_marker:
		update_annot(details)
		annot.set_visible(True)
	elif was_shown:
		annot.set_visible(False)
	else:
		# Nothing under the cursor and nothing displayed: no redraw needed.
		return
	fig.canvas.draw_idle()

# Run `hover` on every mouse-move event delivered to this figure's canvas.
fig.canvas.mpl_connect("motion_notify_event", hover)
# Passing empty tick lists removes the tick marks and labels on both axes.
plt.xticks([])
plt.yticks([])

# plt.title("BERT (Question) UMAP Reduction")
plt.show()
# NOTE(review): if the savefig line is re-enabled, it probably needs to move
# above plt.show(); saving after the window is closed may write an empty
# figure, depending on the backend. Confirm before relying on it.
# plt.savefig("BERT_Question_UMAP_Discrim_pred_test.png", dpi=300)
