import json
import numpy as np
import gensim
from copy import deepcopy
import operator

# Quote retrieval baseline: ranks candidate quotes for each narrative using averaged word2vec embeddings.

# Pre-trained embeddings stored in the binary word2vec format.
word2vec = gensim.models.KeyedVectors.load_word2vec_format('path_to_word2vec', binary=True) # TODO: download word2vec model

# TODO: load dataset by giving proper path
# The context manager closes the file handle as soon as parsing is done.
with open("path_to_data", "r") as dataset_file:
    dataset = json.load(dataset_file)

# Distinct quotes after normalisation (surrounding whitespace removed,
# lower-cased). Sorting gives a reproducible order: iterating a set of strings
# can differ between interpreter runs, which would make any index derived from
# this list (and tie-breaking that depends on it) non-deterministic.
all_quotes = sorted(
    {str(d["fields"]["quote"]).strip("\n").strip().lower() for d in dataset}
)

# Number of distinct candidate quotes.
print(len(all_quotes))

def _mean_word_vector(text, embeddings):
    """Return the mean embedding of the whitespace-separated tokens of text.

    Tokens for which ``embeddings[token]`` raises ``KeyError`` are skipped.
    Returns ``None`` when no token could be embedded, so the caller decides
    how to treat such texts instead of silently receiving NaN from ``np.mean``
    over an empty array.
    """
    vectors = []
    for token in text.split():
        try:
            vectors.append(embeddings[token])
        except KeyError:
            # Out-of-vocabulary token: leave it out of the average.
            continue
    if not vectors:
        return None
    return np.mean(np.array(vectors), axis=0)


# Maps each distinct (lower-cased) quote to its row index in quote_vec.
label_map = {}
count_map = {}

for q in all_quotes:
    # setdefault keeps the first index assigned to a quote; len(label_map) is
    # evaluated before insertion, so indices stay contiguous from 0.
    label_map.setdefault(q.lower(), len(label_map))
reverse_label_map = {int(v):str(k) for k,v in label_map.items()}

# One row per quote. The matrix is sized from the data and from the model's
# own dimensionality rather than a fixed (250, 300), so it neither overflows
# nor contains phantom all-zero rows that would take part in the ranking.
quote_vec = np.zeros((len(label_map), word2vec.vector_size))

for q, ind in label_map.items():
    q_vec = _mean_word_vector(q, word2vec)
    if q_vec is not None:
        quote_vec[ind] = q_vec
    # Otherwise the row stays all-zero (score 0 against every narrative).
    # Writing a NaN mean here would make np.argmax pick this row for every
    # narrative, because NaN is treated as the maximum.


matched = 0
reciprocal_rank = 0.0

for d in dataset:
    q = str(d["fields"]["quote"]).strip("\n").strip().lower()
    n = str(d["fields"]["narrative"]).strip("\n").strip().lower()

    correct_quote_index = label_map[q]

    n_vec = _mean_word_vector(n, word2vec)
    if n_vec is None:
        # No token of the narrative has an embedding, so no meaningful score
        # can be computed. Count the example as a miss: it adds nothing to
        # matched or reciprocal_rank but still counts in the denominators.
        continue

    # Dot product of every quote embedding with the narrative embedding;
    # the result has one score per quote, shape (len(label_map),).
    score = np.matmul(quote_vec, n_vec)

    pred_label = np.argmax(score)
    if pred_label == correct_quote_index:
        matched += 1

    # Quote indices ordered from highest to lowest score.
    reverse_score_rank = np.argsort(score)[::-1]

    # 1-based position of the correct quote in that ordering.
    rank = int(np.flatnonzero(reverse_score_rank == correct_quote_index)[0]) + 1

    reciprocal_rank += 1.0 / rank


num_examples = len(dataset)
if num_examples:
    print("Acc = ", float(matched)/ float(num_examples) )
    print("MRR = ", float(reciprocal_rank)/ float(num_examples) )
else:
    # Avoid a ZeroDivisionError when the dataset contains no examples.
    print("No examples to evaluate.")
