import pandas as pd
from joblib import load
import IRT.CONFIG as C
from pycaret.regression import *
import numpy as np

# Reference item used throughout the script: `org_idx` selects the row of the
# original feature table whose answer-side features are reused further down.
# NOTE(review): `ID` and `org_score` are not read anywhere in the visible part
# of this script; `ID` presumably corresponds to the "0" directory under
# generated_features/ -- confirm before wiring it into the paths.
ID = 0
org_idx = 3619
org_score = 2.27

# Question-side features from a tab-separated export. The "ItemID" and
# "Meantime" columns and the last column of the file are discarded.
java_features = pd.read_csv("questions_for_feature_generation/out.csv", sep="\t")
java_features.drop(["ItemID", "Meantime"], axis=1, inplace=True)
java_features.drop(java_features.columns[-1], axis=1, inplace=True)
# The answer-side names are derived from the bare column names, so this must
# run before the "_question" suffix is appended on the next line.
answer_features = [f"{i}_answer" for i in java_features.columns]
java_features.rename(columns={i: f"{i}_question" for i in java_features.columns}, inplace=True)

# BERT features: drop the first column of the file, then tag every remaining
# column as question-side.
bert = pd.read_csv("generated_features/0/bert.csv")
bert.drop([bert.columns[0]], inplace=True, axis=1)
bert.rename(columns={i: f"{i}_question" for i in bert.columns}, inplace=True)
# Manual features, without the stored "index" column. (The original line ended
# with a stray "Pe" token that made the whole file a syntax error.)
manual = pd.read_csv("generated_features/0/manual.csv").drop(["index"], axis=1)

# Row of the original feature table at position `org_idx` (a Series whose
# index holds the feature names).
org_answer_features = C.get_textual_features("all_features_fixed_manual.csv").iloc[org_idx, :]

# Names of the 768 answer-side BERT columns: "0_answer" ... "767_answer".
answer_bert_columns = [f"{i}_answer" for i in range(768)]

# Keep only the answer-side names that actually exist in the original row.
# The column index is computed once instead of rebuilding and transposing a
# DataFrame for every candidate name.
available_columns = pd.DataFrame(org_answer_features).transpose().columns
answer_features = [i for i in answer_features if i in available_columns]

# Turn the selected values into one-row frames relabelled to index 0.
# NOTE(review): the relabelling only takes effect when the row's index label
# equals `org_idx`; with any other label the join below would not align on
# index 0 -- confirm the original table uses a default integer index.
answer_bert = pd.DataFrame(org_answer_features[answer_bert_columns]).transpose().rename(index={org_idx: 0})
answer_java = pd.DataFrame(org_answer_features[answer_features]).transpose().rename(index={org_idx: 0})

# Index-aligned join of the manual, BERT and question features with the
# answer-side features taken from the original item.
final = manual.join([bert, java_features, answer_java, answer_bert])

# Reduced test-set features without the "target" column, restricted to the
# rows whose index label equals the reference item.
# Earlier variant, kept for reference:
# reduced_dev = C.get_textual_features(file_name="all_features_fixed_manual.csv")
reduced_dev = C.load_reduced_test(
    file_name="all_features_fixed_manual.csv",
    targets=C.OriginalDataset.DFGN_148_reduced.onePL_var_irt_diff,
).drop(["target"], axis=1)
# Use `org_idx` rather than repeating the literal 3619 so the reference item
# is defined in exactly one place.
reduced_dev = reduced_dev.loc[reduced_dev.index == org_idx]
# for col in reduced_dev.columns:
# 	exact_col = [i for i in final.columns if col.startswith(i)]
# 	try:
# 		exact_col = exact_col[0]
# 	except IndexError:
# 		continue
# 	print(reduced_dev[col])
# 	if not (-10E-4 <= final[exact_col][0] - reduced_dev[col].iloc[0, :] <= 10E-4).all():
# 		print("ERROR AT")
# 		print(final[exact_col])
# 		print(reduced_dev[col][75])
# 		# print(reduced_dev.iloc[126, idx])
# 		import sys
# 		sys.exit(0)


#
# Persisted model used for the final prediction.
# NOTE(review): pycaret's load_model is usually given the name without an
# extension and appends ".pkl" itself -- confirm which file is actually read.
model = load_model("best_model_diff.joblib")

# Same feature set loaded through load_reduced_data; only its column names are
# used here, as the naming the columns of `reduced_dev` are aligned to.
reduced_dev_for_col_names = C.load_reduced_data(
    file_name="all_features_fixed_manual.csv",
    targets=C.OriginalDataset.DFGN_148_reduced.onePL_var_irt_diff,
).drop(["target"], axis=1)

# For every reference column name, find the first `reduced_dev` column that is
# a prefix of it and schedule that column to be renamed to the reference name.
reduced_dev_rename = {}
for reference_name in reduced_dev_for_col_names.columns:
    prefix_matches = [
        current_name
        for current_name in reduced_dev.columns
        if reference_name.startswith(current_name)
    ]
    if not prefix_matches:
        # No column of reduced_dev corresponds to this reference name.
        continue
    reduced_dev_rename[prefix_matches[0]] = reference_name
reduced_dev.rename(columns=reduced_dev_rename, inplace=True)

# Map each column of `final` to the first `reduced_dev` column whose name,
# with its last "_"-separated token removed, equals that column. The lookup
# table is built once instead of rescanning reduced_dev.columns for every
# column of `final`; setdefault keeps the first match, as the original scan did.
first_column_by_stem = {}
for reduced_name in reduced_dev.columns:
    stem = "_".join(reduced_name.split("_")[0:-1])
    first_column_by_stem.setdefault(stem, reduced_name)

rename = {
    column: first_column_by_stem[column]
    for column in final.columns
    if column in first_column_by_stem
}

# Column labels `final` will carry once `rename` is applied. Computed once:
# the original re-ran final.rename(...) (a full frame copy) for every element
# of both comprehensions below.
renamed_columns = final.rename(columns=rename).columns

# Columns the reduced set has but `final` lacks, and the other way round.
not_in_out = [i for i in reduced_dev.columns if i not in renamed_columns]
not_in_in = [i for i in renamed_columns if i not in reduced_dev.columns]
print(len(not_in_in))

print(final.shape)
# When the column count does not match what the trained estimator was fitted
# on, add every missing reduced-set column filled with zeros.
if final.shape[1] != model["trained_model"].n_features_in_:
    for key in not_in_out:
        final[key] = [0] * final.shape[0]
print(final.shape)

# Apply the column mapping computed above so `final` uses the reduced-set names.
final.rename(columns=rename, inplace=True)

# Feature table indexed by "ItemID"; its columns are renamed to the first
# column of `final` that shares the same trailing "_"-separated token.
original_data_used = pd.read_csv("Feature_set_with_col.csv", index_col="ItemID")
rename = {}
for source_name in original_data_used.columns:
    source_suffix = source_name.split("_")[-1]
    same_suffix = [
        candidate
        for candidate in final.columns
        if source_suffix == candidate.split("_")[-1]
    ]
    if same_suffix:
        rename[source_name] = same_suffix[0]
original_data_used.rename(columns=rename, inplace=True)

# Remove the columns that have no counterpart in the reduced set, then order
# the remaining columns by the integer that ends each column name.
final.drop(columns=not_in_in, inplace=True)
ordered_columns = sorted(final.columns, key=lambda name: int(name.split("_")[-1]))
final = final.reindex(columns=ordered_columns)
# Alternative ordering, currently disabled:
# final = final[original_data_used.columns]


print(final.shape)
# Disabled debugging snippet:
# print("here")
#
# breack = reduced_test = C.load_reduced_test(file_name="all_features_fixed_manual.csv",
#                                    targets=C.OriginalDataset.DFGN_148_reduced.onePL_var_irt_diff).drop(["target"], axis=1)
#[75]
print(model.predict(final))
