import json
import pickle
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from enum import Enum
import random as rand
import os

# Root directory that every data path in this module is prefixed with.
# NOTE: while this is the empty string, the derived paths below start with "/".
DATA_PATH = ""
# Directory holding the embedding files, and the GloVe 6B sub-directory inside it.
EMBEDDING_PATH = f"{DATA_PATH}/hotpot_embeddings"
GLOVE_PATH = f"{EMBEDDING_PATH}/glove_embeddings/glove_6B"
# Size of the regular vocabulary; the special-token indices below are offsets from it.
# NOTE(review): index USABLE_VOCAB_SIZE itself is not assigned to any token here - confirm intended.
USABLE_VOCAB_SIZE = 14919
START_TOKEN_IDX = USABLE_VOCAB_SIZE + 1
END_TOKEN_IDX = USABLE_VOCAB_SIZE + 2
PADD_TOKEN_IDX = USABLE_VOCAB_SIZE + 3
UNK_TOKEN_IDX = USABLE_VOCAB_SIZE + 4
# One past the largest special-token index above.
MAX_VOCAB_SIZE = USABLE_VOCAB_SIZE + 5


def get_textual_features(file_name):
	"""Load a textual-features table indexed by its ``ItemID`` column.

	The file is looked up under ``{DATA_PATH}/hotpot_textual_features``. It is
	first parsed as comma-separated; if that raises ``ValueError`` the file is
	re-read as whitespace-delimited.

	:param file_name: name of the features file inside the features directory
	:return: DataFrame whose index is the ``ItemID`` column
	"""
	features_path = f"{DATA_PATH}/hotpot_textual_features/{file_name}"
	try:
		return pd.read_csv(features_path, index_col="ItemID")
	except ValueError:
		# Presumably the comma parse could not find an "ItemID" column because
		# the whole header ended up in a single field.
		print("Seems like this might be space delimited. Trying that now.")
		# sep=r"\s+" is the documented replacement for the deprecated
		# delim_whitespace=True flag.
		return pd.read_csv(features_path, index_col="ItemID", sep=r"\s+")


class TaskTypes(Enum):
	"""Dataset variants that ``get_labeled_dataset`` can build."""
	# Regression variants: the target is the "buff_2" column of the targets file.
	REGRESSION = 1
	BOW_REGRESSION = 2
	# Two-class variants built from textual features.
	IMBALANCED_CLASSIFICATION = 3
	BALANCED_CLASSIFICATION = 4
	# Two-class variants built from the question text (BOW) dataset.
	BOW_IMBALANCED_CLASSIFICATION = 5
	BOW_BALANCED_CLASSIFICATION = 6


class Fit:
	"""Collection of CSV paths found under a single fit directory.

	Every attribute is a plain string built from ``path``; nothing is read
	from disk here.
	"""

	def __init__(self, path):
		"""Derive the 1PL/2PL ``VAR_IRT`` CSV paths below ``path``."""
		self._path = path
		one_pl_dir = f"{path}/1PL/VAR_IRT"
		two_pl_dir = f"{path}/2PL/VAR_IRT"
		self.onePL_var_irt_diff = f"{one_pl_dir}/difficulties.csv"
		self.twoPL_var_irt_diff = f"{two_pl_dir}/difficulties.csv"
		self.twoPL_var_irt_discrim = f"{two_pl_dir}/discriminations.csv"
		self.twoPL_var_irt_abilities = f"{two_pl_dir}/abilities.csv"


class OriginalDataset:
	"""Namespace of known target locations under ``DATA_PATH``."""
	# Fit path bundles for the DFGN response-pattern directories.
	DFGN_148 = Fit(f"{DATA_PATH}/hotpot_model_response_patterns_data/148_DFGN")
	DFGN_148_reduced = Fit(f"{DATA_PATH}/hotpot_model_response_patterns_data/148_DFGN_Reduced")
	DFGN_294 = Fit(f"{DATA_PATH}/hotpot_model_response_patterns_data/294_DFGN")
	# Plain path string (not a Fit) to a single diffs file.
	TEMP = f"{DATA_PATH}/hotpot_model_response_patterns_data/response_patterns_large.fixed.diffs"
	# Bare label string; unlike the entries above it is neither a Fit nor a path.
	HGN = "HGN"


def get_questions():
	"""Return the parsed JSON content of ``hotpot_dev_distractor_v1.json`` under ``DATA_PATH``."""
	questions_path = f"{DATA_PATH}/hotpot_dev_distractor_v1.json"
	with open(questions_path, "r") as handle:
		return json.load(handle)


def get_questions_train():
	"""Return the parsed JSON content of ``hotpot_train_v1.1.json`` under ``DATA_PATH``."""
	train_path = f"{DATA_PATH}/hotpot_train_v1.1.json"
	with open(train_path, "r") as handle:
		return json.load(handle)

def get_nltk_questions():
	"""Return the parsed JSON content of ``hotpot_dev_distractor_v1_nltk.json`` under ``DATA_PATH``."""
	nltk_path = f"{DATA_PATH}/hotpot_dev_distractor_v1_nltk.json"
	with open(nltk_path, "r") as handle:
		return json.load(handle)


def get_targets(targets: str) -> pd.DataFrame:
	"""Read a targets CSV and index it by its ``buff_1`` column.

	The path is typically one of the attributes of a ``Fit`` held by
	``OriginalDataset`` (e.g. ``OriginalDataset.DFGN_148.onePL_var_irt_diff``).

	:param targets: path to the specific targets file
	:return: DataFrame indexed by ``buff_1``
	:raises Exception: if the file cannot be opened because it does not exist;
		the underlying OS error is attached as ``__cause__``
	"""
	try:
		with open(targets, "r") as f:
			return pd.read_csv(f, index_col="buff_1")
	except (FileExistsError, FileNotFoundError) as err:
		# Chain explicitly so the traceback shows the missing path as the
		# direct cause instead of an incidental "during handling" context.
		raise Exception("Target has not yet been created!") from err


def _get_classification_indexed_features(file_name, easy_indicies, hard_indicies, normalize):
	"""Build a two-class dataset from the textual features of selected items.

	:param file_name: textual-features file passed to ``get_textual_features``;
		when it contains ``"BERT"`` normalization is told the table has BERT columns
	:param easy_indicies: index labels of the rows labelled ``target == 1``
	:param hard_indicies: index labels of the rows labelled ``target == 0``
	:param normalize: when truthy, run ``_normalize_data`` before selecting rows
	:return: DataFrame with the easy rows followed by the hard rows and an
		added ``target`` column
	"""
	# load and index textual features
	textual_features = get_textual_features(file_name)
	if normalize:
		textual_features = _normalize_data(textual_features, "BERT" in file_name)

	# Explicit copies: the label column must never be written into (or warn
	# about being written into) a slice of the full feature table.
	hard_rows = textual_features.loc[hard_indicies, :].copy()
	hard_rows["target"] = 0
	easy_rows = textual_features.loc[easy_indicies, :].copy()
	easy_rows["target"] = 1

	# DataFrame.append no longer exists in pandas 2.x; concat keeps the same
	# easy-then-hard row order.
	return pd.concat([easy_rows, hard_rows])


def _normalize_data(data, has_bert):
	"""Standard-scale feature columns of ``data`` in place and return it.

	:param data: DataFrame of features; it is modified, not copied
	:param has_bert: when truthy, only the columns located before the column
		named ``'0'`` are scaled; otherwise every column is scaled
	:return: the same ``data`` object
	"""
	scaler = StandardScaler()
	if has_bert:
		# Everything from the column labelled '0' onwards is left untouched.
		cutoff = data.columns.tolist().index('0')
		columns_to_scale = data.columns[0:cutoff]
	else:
		columns_to_scale = data.columns
	data[columns_to_scale] = scaler.fit_transform(data[columns_to_scale])
	return data


def _indices_with_correct_count(response_patterns, n_correct):
	"""Return the column labels (cast to int) whose column sum equals ``n_correct``."""
	column_totals = response_patterns.sum(axis=0)
	matching = column_totals[column_totals == n_correct]
	return [int(label) for label in matching.index.values]


def _stack_easy_and_hard(dataset, easy_indicies, hard_indicies):
	"""Select easy/hard rows of ``dataset``, relabel ``target`` as 1/0 and stack easy over hard."""
	# Copies so that relabelling never touches the dataset the rows came from.
	easy_questions = dataset.loc[easy_indicies, :].copy()
	hard_questions = dataset.loc[hard_indicies, :].copy()
	easy_questions.loc[:, "target"] = 1
	hard_questions.loc[:, "target"] = 0
	# DataFrame.append no longer exists in pandas 2.x.
	return pd.concat([easy_questions, hard_questions])


def get_labeled_dataset(task=TaskTypes.REGRESSION, **kwargs):
	"""Build the labelled dataset for ``task``.

	Keyword arguments read from ``kwargs``:

	* ``targets`` (required) - path of the targets file, handed to ``get_targets``
	* ``file_name`` - textual-features file; required for ``REGRESSION`` and the
	  non-BOW classification tasks
	* ``normalize`` (default ``False``) - normalize the textual features first
	* ``bag_focus`` (default ``"question"``) - key read from every question item
	  for the BOW tasks; ``"question_answer"`` joins question and answer text
	* ``textual_features`` - optional features file joined onto the corpus for
	  ``BOW_REGRESSION``

	:param task: one of ``TaskTypes``
	:param kwargs: see above
	:return: DataFrame containing a ``target`` column; ``None`` when ``task``
		matches none of the handled task types
	:raises Exception: when the ``targets`` keyword is missing or falsy
	"""
	targets_path = kwargs.get("targets")
	if not targets_path:
		raise Exception("targets keyword must be specified! Must be a string from OriginalDataset. etc etc")
	normalize = kwargs.get("normalize", False)

	if task in {TaskTypes.REGRESSION, TaskTypes.BOW_REGRESSION}:
		targets = get_targets(targets_path)
		targets.rename(columns={"buff_2": "target"}, inplace=True)

		if task == TaskTypes.REGRESSION:
			textual_features = get_textual_features(kwargs["file_name"])
			if normalize:
				has_bert = "BERT" in kwargs["file_name"]
				textual_features = _normalize_data(textual_features, has_bert)
			return textual_features.join(targets["target"])

		# BOW_REGRESSION: one text per question, in file order.
		bag_focus = kwargs.get("bag_focus", "question")
		corpus = []
		for item in get_questions():
			if bag_focus == "question_answer":
				corpus.append(f"{item['question']} {item['answer']}")
			else:
				corpus.append(item[bag_focus])
		corpus = pd.DataFrame(corpus)

		features_to_use = kwargs.get("textual_features", False)
		if features_to_use:
			corpus = corpus.join(get_textual_features(features_to_use))
		return corpus.join(targets["target"])

	if task in {TaskTypes.IMBALANCED_CLASSIFICATION,
	            TaskTypes.BALANCED_CLASSIFICATION,
	            TaskTypes.BOW_IMBALANCED_CLASSIFICATION,
	            TaskTypes.BOW_BALANCED_CLASSIFICATION}:
		# Find Questions that no model/all models got correct via response patterns
		with open(
				f"{DATA_PATH}/hotpot_model_response_patterns_data/raw_data/original_48/hotpot_qa_response_patterns.csv") as f:
			response_patterns = pd.read_csv(f)
		patterns = response_patterns.reset_index(drop=True)

		# Hard questions: columns whose responses sum to zero.
		indicies_of_hard = _indices_with_correct_count(patterns, 0)

		if task in {TaskTypes.IMBALANCED_CLASSIFICATION, TaskTypes.BOW_IMBALANCED_CLASSIFICATION}:
			# NOTE(review): 47 is hard-coded; presumably the number of response
			# rows in the file - confirm against the data.
			indicies_of_easy = _indices_with_correct_count(patterns, 47)
		else:
			# The targets path has to be forwarded; get_targets has no default.
			question_difficulties = get_targets(targets_path)
			# Easy questions: "buff_2" strictly between -10 and -4.
			# NOTE(review): relies on the targets file having a "buff" column - confirm.
			in_easy_range = ((question_difficulties["buff_2"] > -10)
			                 & (question_difficulties["buff_2"] < -4))
			indicies_of_easy = question_difficulties.loc[in_easy_range]["buff"]

		if task in {TaskTypes.IMBALANCED_CLASSIFICATION, TaskTypes.BALANCED_CLASSIFICATION}:
			return _get_classification_indexed_features(kwargs["file_name"],
			                                            indicies_of_easy, indicies_of_hard, normalize)

		# TODO: Support feature grabbing along with bow model
		bag_focus = kwargs.get("bag_focus", "question")
		# "targets" must be passed on, otherwise the nested call is rejected by
		# the keyword check at the top of this function.
		full_dataset = get_labeled_dataset(task=TaskTypes.BOW_REGRESSION, bag_focus=bag_focus,
		                                   targets=targets_path)
		return _stack_easy_and_hard(full_dataset, indicies_of_easy, indicies_of_hard)


def generate_word2idx_and_idx_to_vector():
	"""Convert ``glove.6B.300d.txt`` into the lookup files used by this module.

	Reads the GloVe text file under ``GLOVE_PATH`` line by line (first field is
	the word, the remaining fields are its vector) and writes, into the same
	directory:

	* ``6B.300.h5`` - the vectors as a transposed DataFrame (one column per word), key ``"ic"``
	* ``6B.300_words.pkl`` - the list of words in file order
	* ``6B.300_idx.pkl`` - dict mapping each word to its line index

	:return: None
	"""
	words = []
	word2idx = {}
	vectors = []
	with open(f'{GLOVE_PATH}/glove.6B.300d.txt', 'rb') as f:
		for idx, raw_line in enumerate(f):
			fields = raw_line.decode().split()
			word = fields[0]
			words.append(word)
			word2idx[word] = idx
			# np.float was only an alias of the builtin float and is gone from
			# current NumPy; float64 is the dtype it produced.
			vectors.append(np.array(fields[1:]).astype(np.float64))

	pd.DataFrame(vectors).transpose().to_hdf(f'{GLOVE_PATH}/6B.300.h5', key="ic")
	# Context managers so the pickles are flushed and closed before returning.
	with open(f'{GLOVE_PATH}/6B.300_words.pkl', 'wb') as words_file:
		pickle.dump(words, words_file)
	with open(f'{GLOVE_PATH}/6B.300_idx.pkl', 'wb') as idx_file:
		pickle.dump(word2idx, idx_file)


def _load_glove_word_to_idx():
	"""Unpickle the word -> index dict from ``6B.300_idx.pkl`` and close the file."""
	# NOTE: pickle executes code on load; this file must only ever be one
	# written by generate_word2idx_and_idx_to_vector.
	with open(f'{GLOVE_PATH}/6B.300_idx.pkl', 'rb') as idx_file:
		return pickle.load(idx_file)


# Import-time load of the GloVe word index; on any failure the lookup files are
# regenerated from the raw GloVe text file and the load is retried once.
try:
	GLoVE_WRD_TO_IDX = _load_glove_word_to_idx()
except Exception:
	print("Could not load GLoVE files, generating")
	generate_word2idx_and_idx_to_vector()
	GLoVE_WRD_TO_IDX = _load_glove_word_to_idx()


def _load_json_file(path):
	"""Parse the JSON file at ``path``, closing the handle afterwards."""
	with open(path) as handle:
		return json.load(handle)


def _load_reduced_split(index_file, file_name, targets, rename):
	"""Shared implementation of ``load_reduced_data`` / ``load_reduced_test``.

	:param index_file: name of the JSON file under ``DATA_PATH`` listing the
		item indices of the wanted split
	:param file_name: textual-features file passed to ``get_textual_features``
	:param targets: targets path passed to ``get_targets``
	:param rename: when truthy, every feature column ``c`` at position ``p`` is
		renamed to ``"{c}_{p}"``
	:return: feature rows of the split with an added ``target`` column
	"""
	indicies = _load_json_file(f"{DATA_PATH}/{index_file}")
	flags = _load_json_file(f"{DATA_PATH}/reduced_data_indicies.txt")["indicies"]
	# Positions whose flag is falsy, numbered consecutively: maps an index of
	# the large dataset to its label in the reduced targets table.
	kept_positions = [position for position, flag in enumerate(flags) if not flag]
	indicies_to_map = {large_idx: small_idx for small_idx, large_idx in enumerate(kept_positions)}

	difficulties = get_targets(targets=targets)["buff_2"]
	target = [difficulties[indicies_to_map[i]] for i in indicies]

	# Copy so that renaming/adding columns never writes into the full table.
	data_set = get_textual_features(file_name=file_name).loc[indicies, :].copy()
	if rename:
		data_set = data_set.rename(
			columns={column: f"{column}_{position}" for position, column in enumerate(data_set.columns)})
	data_set["target"] = target
	return data_set


def load_reduced_data(file_name, targets, rename=True):
	"""Load the train/dev part of the reduced dataset with its targets.

	Indicies to map : Large (7405) - > Small (4000)

	:param file_name: textual-features file passed to ``get_textual_features``
	:param targets: targets path passed to ``get_targets``
	:param rename: when truthy, suffix every feature column with its position
	:return: DataFrame of the rows listed in ``reduced_data_indicies_train_dev.txt``
		plus a ``target`` column taken from the targets' ``buff_2`` column
	"""
	return _load_reduced_split("reduced_data_indicies_train_dev.txt", file_name, targets, rename)


def load_reduced_test(file_name, targets):
	"""Load the test part of the reduced dataset with its targets.

	Feature columns are always suffixed with their position.

	:param file_name: textual-features file passed to ``get_textual_features``
	:param targets: targets path passed to ``get_targets``
	:return: DataFrame of the rows listed in ``reduced_data_indicies_test.txt``
		plus a ``target`` column taken from the targets' ``buff_2`` column
	"""
	return _load_reduced_split("reduced_data_indicies_test.txt", file_name, targets, True)
