import spacy
import pandas as pd
import nltk
import string
import IRT.CONFIG as C
import pickle
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from IRT.deep_irt.models.embedders import BERTEmbedder
import torch

from IRT.deep_irt.pre_processing import build_idx_list

# Shared NLP resources, created once at import time.
nlp = spacy.load('en_core_web_sm')  # spaCy pipeline; this module uses its noun chunks
stop_words = set(nltk.corpus.stopwords.words('english'))  # NLTK English stop-word list
punctuation = set(string.punctuation)  # single ASCII punctuation characters
PUNC_FILTER = str.maketrans('', '', string.punctuation)  # str.translate table deleting those characters


def get_subject_noun_phrases(text, subjects_only=False):
	"""Return the spaCy noun chunks of *text* as a list of Span objects.

	Despite the name, every noun chunk is returned by default (this is what all
	existing callers get). Pass ``subjects_only=True`` to keep only chunks that
	contain a token whose dependency label is ``nsubj``.

	Args:
		text: raw text to parse with the module-level spaCy pipeline.
		subjects_only: optional filter restricting the result to subject chunks.

	Returns:
		list of spaCy Span objects, in document order.
	"""
	# Named `chunk` rather than `np` so the numpy module name is not shadowed.
	chunks = list(nlp(text).noun_chunks)
	if subjects_only:
		chunks = [chunk for chunk in chunks if any(tok.dep_ == "nsubj" for tok in chunk)]
	return chunks


def create_count_feature(name):
	"""Create the (gold, distractor) pair of one-row Series for count feature *name*.

	The Series are named ``{name}_in_gold_contexts`` and
	``{name}_in_distractor_contexts``. Both have the index ``[0]`` and are
	created without data; callers fill cell 0.
	"""
	shared_kwargs = dict(index=range(0, 1), dtype=int)
	gold = pd.Series(name=f"{name}_in_gold_contexts", **shared_kwargs)
	distractor = pd.Series(name=f"{name}_in_distractor_contexts", **shared_kwargs)
	return gold, distractor


def get_counts(counts, golds, distractors):
	"""Count occurrences of each term of *counts* in gold and distractor sentences.

	A term only matches when it is preceded by a space (`" " + term`), so a
	term at the very start of a sentence is not counted. Sentences are
	lower-cased before searching; the terms are used as given.

	Args:
		counts: iterable of search terms (stringified with ``str``).
		golds, distractors: sequences of contexts where ``context[1]`` is an
			iterable of sentence strings.

	Returns:
		(gold_count, distractor_count) as ints.
	"""
	def _occurrences(contexts):
		return sum(
			sentence.lower().count(" " + str(term))
			for context in contexts
			for sentence in context[1]
			for term in counts
		)

	return _occurrences(golds), _occurrences(distractors)


def get_x_snp(x, x_name, golds, distractors):
	"""Count the noun chunks of *x* in the gold and distractor contexts.

	Returns the pair of one-row Series ``no_snp_{x_name}_in_gold_contexts`` and
	``no_snp_{x_name}_in_distractor_contexts`` produced by create_count_feature,
	filled via get_counts with the lower-cased chunk texts.
	"""
	gold_series, distractor_series = create_count_feature(f"no_snp_{x_name}")
	# NOTE(review): set() here is taken over spaCy Span objects, not strings; if Spans
	# hash by position, repeated phrase text is not deduplicated — confirm intent.
	phrases = [str(chunk).lower() for chunk in set(get_subject_noun_phrases(x))]
	gold_total, distractor_total = get_counts(phrases, golds, distractors)
	gold_series[0] = gold_total
	distractor_series[0] = distractor_total
	return gold_series, distractor_series


def get_question_snp(question, golds, distractors):
	"""Noun-chunk counts of *question* in gold vs distractor contexts (``no_snp_question_*``)."""
	return get_x_snp(x=question, x_name="question", golds=golds, distractors=distractors)


def get_answer_snp(answer, golds, distractors):
	"""Noun-chunk counts of *answer* in gold vs distractor contexts (``no_snp_answer_*``)."""
	return get_x_snp(x=answer, x_name="answer", golds=golds, distractors=distractors)


def get_x_ngram(x, x_name, golds, distractors, ngram, ngram_name):
	"""Count the *ngram*-grams of *x* in the gold and distractor contexts.

	*x* is tokenized with NLTK. Tokens that are single punctuation characters or
	(case-insensitively) stop words are dropped, the rest are lower-cased and
	joined into space-separated n-grams, which get_counts then searches for.

	Returns:
		the pair of one-row Series ``no_{ngram_name}_{x_name}_in_gold_contexts``
		and ``no_{ngram_name}_{x_name}_in_distractor_contexts``.
	"""
	gold_feature, distractor_feature = create_count_feature(f"no_{ngram_name}_{x_name}")
	kept_tokens = []
	for token in nltk.word_tokenize(x):
		if token in punctuation or token.lower() in stop_words:
			continue
		kept_tokens.append(token.lower())
	# Each n-gram tuple becomes one space-separated search string.
	grams = [" ".join(map(str, gram)) for gram in nltk.ngrams(kept_tokens, n=ngram)]
	gold_feature[0], distractor_feature[0] = get_counts(grams, golds, distractors)
	return gold_feature, distractor_feature


def get_question_ngram(question, golds, distractors, ngram, ngram_name):
	"""N-gram counts of *question* in gold vs distractor contexts (``no_{ngram_name}_question_*``)."""
	return get_x_ngram(question, "question", golds, distractors, ngram=ngram, ngram_name=ngram_name)


def get_answer_ngram(answer, golds, distractors, ngram, ngram_name):
	"""N-gram counts of *answer* in gold vs distractor contexts (``no_{ngram_name}_answer_*``)."""
	return get_x_ngram(answer, "answer", golds, distractors, ngram=ngram, ngram_name=ngram_name)


def _load_pickle(path):
	"""Unpickle the object stored at *path*, closing the file handle afterwards.

	NOTE(review): pickle.load can execute arbitrary code; only point
	C.GLOVE_PATH at trusted files.
	"""
	with open(path, 'rb') as handle:
		return pickle.load(handle)


# GloVe resources (presumably the 6B / 300-d release, per the file names), loaded once at import.
vectors = pd.read_hdf(f'{C.GLOVE_PATH}/6B.300.h5', key="ic")
words = _load_pickle(f'{C.GLOVE_PATH}/6B.300_words.pkl')
word2idx = _load_pickle(f'{C.GLOVE_PATH}/6B.300_idx.pkl')
# glove[i] is vectors[word2idx[words[i]]]; get_sentence_embedding indexes it with
# word2idx[word], which presumes word2idx[words[i]] == i — TODO confirm.
glove = [vectors[word2idx[w]] for w in words]


def get_sentence_embedding(sentence):
	"""Return the mean GloVe vector of the in-vocabulary words of *sentence*.

	The sentence is split on single spaces. Each piece is stripped of ASCII
	punctuation and lower-cased before lookup in ``word2idx``; unknown words
	are skipped.

	Returns:
		the element-wise mean of the found vectors, or ``np.array([np.nan])``
		when no word was found.
	"""
	found = []
	for token in sentence.split(" "):
		try:
			index = word2idx[token.translate(PUNC_FILTER).lower()]
		except KeyError:
			continue
		found.append(glove[index])
	if not found:
		return np.array([np.nan])
	return np.mean(np.array(found), axis=0)


def _sentence_embeddings(contexts):
	"""Return the embedding of every sentence in *contexts*, skipping all-NaN ones.

	Each context is read as ``context[1]`` -> iterable of sentence strings.
	"""
	embeddings = []
	for context in contexts:
		for sentence in context[1]:
			embedding = get_sentence_embedding(sentence)
			# get_sentence_embedding returns [nan] when no word is in the vocabulary.
			if np.isnan(embedding).all():
				continue
			embeddings.append(embedding)
	return embeddings


def _max_cosine(query, embeddings):
	"""Return the largest cosine similarity between the 1 x d *query* and any of *embeddings*.

	Returns 0 when *embeddings* is empty, matching the fallback this module
	already used for questions without distractor sentences.
	"""
	if not embeddings:
		return 0
	# One vectorized call over the stacked (n, d) matrix instead of n pairwise calls.
	return float(cosine_similarity(query, np.vstack(embeddings)).max())


def get_cosin_sim(question, golds, distractors):
	"""Max cosine similarity between the question and any gold / distractor sentence.

	Sentence vectors are mean GloVe embeddings (get_sentence_embedding).
	Sentences with no in-vocabulary word are ignored. If there is no usable
	sentence on a side, or the question itself has no in-vocabulary word,
	that feature is 0.

	Previously, exactly one gold sentence made ``max(*[x])`` raise TypeError,
	and a question without known words made cosine_similarity raise.

	Returns:
		(max_cosine_similarity_gold, max_cosine_similarity_distractors), each a
		one-row float Series.
	"""
	max_golds = pd.Series(name="max_cosine_similarity_gold", index=range(0, 1), dtype=float)
	max_distractors = pd.Series(name="max_cosine_similarity_distractors", index=range(0, 1), dtype=float)
	question_embedding = get_sentence_embedding(question)
	if np.isnan(question_embedding).all():
		# Nothing to compare: the 1-element NaN vector cannot be compared with d-dim embeddings.
		max_golds[0] = 0
		max_distractors[0] = 0
		return max_golds, max_distractors
	question_embedding = question_embedding.reshape(1, -1)
	max_golds[0] = _max_cosine(question_embedding, _sentence_embeddings(golds))
	max_distractors[0] = _max_cosine(question_embedding, _sentence_embeddings(distractors))
	return max_golds, max_distractors


def get_gold_ngram(golds, distractors, ngram, ngram_name):
	"""Count how often n-grams of the gold contexts occur in the distractor contexts.

	All gold sentences are joined with spaces and tokenized with NLTK. Tokens
	that are single punctuation characters or stop words are dropped, and the
	rest are lower-cased and turned into space-separated *ngram*-grams. Each
	gram is then counted (space-prefixed, as in get_counts) in every
	lower-cased distractor sentence.

	Fixes over the previous version:
	- the counter started as NaN, so ``+=`` left it NaN;
	- ``golds[0][1]`` was extended in place, mutating the caller's data on every call;
	- sentences were concatenated with no separator, gluing words across boundaries;
	- only the first two gold contexts were considered (and fewer raised IndexError).

	Args:
		golds, distractors: sequences of contexts where ``context[1]`` is an
			iterable of sentence strings.
		ngram: n-gram order.
		ngram_name: label used in the feature name.

	Returns:
		one-row Series ``no_{ngram_name}_gold_ctx_distractor_ctx``.
	"""
	feature = pd.Series(name=f"no_{ngram_name}_gold_ctx_distractor_ctx",
	                    index=range(0, 1), dtype=int)
	# Build a fresh list so the caller's context sentence lists stay untouched.
	gold_sentences = [sentence for context in golds for sentence in context[1]]
	gold_text = " ".join(gold_sentences)
	tokens = [tok.lower() for tok in nltk.word_tokenize(gold_text)
	          if tok.lower() not in stop_words and tok not in punctuation]
	gold_ngrams = [" ".join(gram) for gram in nltk.ngrams(tokens, n=ngram)]

	total = 0
	for context in distractors:
		for sentence in context[1]:
			lowered = sentence.lower()
			total += sum(lowered.count(" " + gram) for gram in gold_ngrams)
	# Assign once: a data-less dtype=int Series holds NaN, which `+=` would never clear.
	feature[0] = total
	return feature


def answer_question_features(answer, question):
	"""Count how often pieces of the question re-appear inside the answer.

	Produces four one-row Series: occurrences in *answer* of the question's
	noun chunks and of its unigrams, bigrams and trigrams. Tokens that are
	single punctuation characters or stop words are dropped before the n-grams
	are built. As elsewhere in this module, a match must be preceded by a space,
	and both sides are compared lower-cased. Previously the noun-chunk count
	alone was case-sensitive.

	Returns:
		(no_snp_question_in_answer, no_unigram_question_in_answer,
		 no_bigram_question_in_answer, no_trigram_question_in_answer)
	"""
	def _zero_feature(name):
		# A Series built from only index/dtype holds no value yet, so set cell 0 explicitly.
		feature = pd.Series(name=name, index=range(0, 1), dtype=int)
		feature[0] = 0
		return feature

	no_snp_question_in_answer = _zero_feature("no_snp_question_in_answer")
	no_unigram_question_in_answer = _zero_feature("no_unigram_question_in_answer")
	no_bigram_question_in_answer = _zero_feature("no_bigram_question_in_answer")
	no_trigram_question_in_answer = _zero_feature("no_trigram_question_in_answer")

	answer_lower = answer.lower()

	# `chunk`, not `np`, so the numpy module name is not shadowed.
	for chunk in get_subject_noun_phrases(question):
		no_snp_question_in_answer[0] += answer_lower.count(" " + str(chunk).lower())

	# Tokenize and filter the question once; every n-gram order is derived from this list.
	question_tokens = [tok.lower() for tok in nltk.word_tokenize(question)
	                   if tok.lower() not in stop_words and tok not in punctuation]

	def _ngram_hits(n):
		return sum(answer_lower.count(" " + " ".join(gram))
		           for gram in nltk.ngrams(question_tokens, n=n))

	no_unigram_question_in_answer[0] += _ngram_hits(1)
	no_bigram_question_in_answer[0] += _ngram_hits(2)
	no_trigram_question_in_answer[0] += _ngram_hits(3)

	return no_snp_question_in_answer, no_unigram_question_in_answer, no_bigram_question_in_answer, no_trigram_question_in_answer


def get_pos_tags_for_text(text, name):
	"""Count NLTK part-of-speech tags of *text* as a one-row DataFrame.

	Args:
		text: passed straight to ``nltk.pos_tag`` (presumably a list of word
			tokens — confirm with callers).
		name: suffix for the column names.

	Returns:
		DataFrame with index ``[0]`` and one column ``{tag}_{name}`` per tag
		seen, holding that tag's count. The column axis is named ``"ItemID"``,
		as before. Columns are sorted by tag. Previously their order came from
		set() iteration, which varies between interpreter runs because string
		hashing is randomized.
	"""
	tag_counts = {}
	for _word, tag in nltk.pos_tag(text):
		tag_counts[tag] = tag_counts.get(tag, 0) + 1

	pos_features = pd.DataFrame(
		{f"{tag}_{name}": [count] for tag, count in sorted(tag_counts.items())},
		index=[0],
	)
	pos_features.columns.name = "ItemID"
	return pos_features


def pos_counts(question_nltk, answer_nltk):
	"""Return the POS-tag count frames of the question and answer token lists.

	Naive approach: only counts how often each part of speech occurs (see
	get_pos_tags_for_text).

	Args:
		question_nltk: question tokens; must be a list that already excludes stop words.
		answer_nltk: answer tokens.

	Returns:
		(question_pos_frame, answer_pos_frame)

	Raises:
		ValueError: if *question_nltk* is not a stop-word-free list.
	"""
	# Explicit check instead of `assert`, which is stripped when Python runs with -O.
	if question_nltk != [tok for tok in question_nltk if tok not in stop_words]:
		raise ValueError("question_nltk must be a list with stop words already removed")
	return get_pos_tags_for_text(question_nltk, "question"), get_pos_tags_for_text(answer_nltk, "answer")


def _question_feature_frame(idx, question_info, nltk_info):
	"""Build the one-row feature DataFrame for a single question, indexed by *idx*."""
	question = question_info["question"]
	answer = question_info["answer"]
	question_nltk = nltk_info["question"]
	answer_nltk = nltk_info["answer"]

	# Contexts whose title appears in supporting_facts are gold; the rest are distractors.
	gold_titles = {fact[0] for fact in question_info["supporting_facts"]}
	golds = [ctx for ctx in question_info["context"] if ctx[0] in gold_titles]
	distractors = [ctx for ctx in question_info["context"] if ctx[0] not in gold_titles]

	# Call order is kept as before so results match the other helpers exactly.
	no_snp_q_gold, no_snp_q_dis = get_question_snp(question, golds, distractors)

	no_unigram_q_gold, no_unigram_q_dis = get_question_ngram(question, golds, distractors, 1, "unigram")
	no_bigram_q_gold, no_bigram_q_dis = get_question_ngram(question, golds, distractors, 2, "bigram")
	no_trigram_q_gold, no_trigram_q_dis = get_question_ngram(question, golds, distractors, 3, "trigram")

	no_unigram_gold_dis = get_gold_ngram(golds, distractors, 1, "unigram")
	no_bigram_gold_dis = get_gold_ngram(golds, distractors, 2, "bigram")
	no_trigram_gold_dis = get_gold_ngram(golds, distractors, 3, "trigram")

	no_snp_a_gold, no_snp_a_dis = get_answer_snp(answer, golds, distractors)

	no_unigram_a_gold, no_unigram_a_dis = get_answer_ngram(answer, golds, distractors, 1, "unigram")
	no_bigram_a_gold, no_bigram_a_dis = get_answer_ngram(answer, golds, distractors, 2, "bigram")
	no_trigram_a_gold, no_trigram_a_dis = get_answer_ngram(answer, golds, distractors, 3, "trigram")

	no_snp_a_q, no_uni_a_q, no_bi_a_q, no_tri_a_q = answer_question_features(answer, question)

	q_pos, a_pos = pos_counts(question_nltk, answer_nltk)

	indexer = pd.Series(name="index", index=range(0, 1), dtype=int)
	indexer[0] = idx
	all_features = [
		indexer,
		no_snp_q_gold, no_snp_q_dis,
		no_unigram_q_gold, no_unigram_q_dis,
		no_bigram_q_gold, no_bigram_q_dis,
		no_trigram_q_gold, no_trigram_q_dis,
		no_unigram_gold_dis, no_bigram_gold_dis,
		no_trigram_gold_dis, no_unigram_a_gold, no_unigram_a_dis,
		no_bigram_a_gold, no_bigram_a_dis,
		no_trigram_a_gold, no_trigram_a_dis,
		no_snp_a_q, no_uni_a_q, no_bi_a_q, no_tri_a_q,
		no_snp_a_gold, no_snp_a_dis
	]

	# Each Series becomes a row; transposing yields one row with one column per feature.
	local_frame = pd.DataFrame(all_features).transpose()
	local_frame = local_frame.join(q_pos).join(a_pos)
	return local_frame.set_index(["index"])


def get_features_manual(questions, questions_nltk):
	"""Build the hand-crafted feature table for a list of questions.

	Args:
		questions: question dicts with keys "question", "answer",
			"supporting_facts" (items whose [0] is a context title) and
			"context" (items whose [0] is a title and [1] its sentences).
		questions_nltk: parallel list of dicts with token lists under
			"question" and "answer"; entry ``idx`` belongs to ``questions[idx]``.

	Returns:
		DataFrame with one row per question, indexed by its position. Features
		missing for a question are filled with 0. An empty input gives an
		empty DataFrame.
	"""
	frames = [
		_question_feature_frame(idx, question_info, questions_nltk[idx])
		for idx, question_info in enumerate(questions)
	]
	if not frames:
		return pd.DataFrame()
	# A single concat avoids re-copying the growing frame on every question (quadratic).
	return pd.concat(frames).fillna(0)


_BERT_CACHE = {}


def _load_bert(model_name='bert-base-cased'):
	"""Load the torch.hub BERT model and tokenizer for *model_name*, reusing earlier loads."""
	if model_name not in _BERT_CACHE:
		model = torch.hub.load('huggingface/pytorch-transformers', 'model', model_name)
		tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', model_name)
		_BERT_CACHE[model_name] = (model, tokenizer)
	return _BERT_CACHE[model_name]


def build_bert_embeddings(question, id, output_path=None):
	"""Embed *question* with bert-base-cased and return the first-position vector.

	The question is tokenized and encoded with the hub tokenizer, then run
	through the hub model. The vector taken is ``output[0][0, 0, :]``, i.e.
	batch 0, sequence position 0 of the model's first output (presumably the
	last hidden layer — confirm).

	The previous version opened ``f""`` for writing, which always raised, so
	nothing was ever written. It also built an unused BERTEmbedder and reloaded
	the hub model on every call.

	Args:
		question: raw question text.
		id: currently unused; kept so existing callers keep working (presumably
			meant to name the output file — TODO confirm).
		output_path: if given, the one-row frame is written there as CSV.

	Returns:
		one-row DataFrame with one column per embedding dimension.
	"""
	model, tokenizer = _load_bert()
	token_ids = tokenizer.encode(tokenizer.tokenize(question))
	# Inference only: skip building the autograd graph.
	with torch.no_grad():
		outputs = model(torch.tensor(token_ids).unsqueeze(0))
	embedding = outputs[0][0, 0, :]
	df = pd.DataFrame(pd.Series(embedding.cpu().numpy())).transpose()
	if output_path is not None:
		df.to_csv(output_path)
	return df
