import json

import numpy as np
import nltk
from nltk.util import ngrams
from nltk import ngrams
import operator
import random
from random import shuffle
import copy

# Seed the module-level RNG so that shuffle() (imported from random) produces
# the same randomized quote assignment on every run.
random.seed(42)

def jaccard(list1, list2):
    """Return the overlap size and Jaccard similarity of two collections.

    Both arguments are treated as sets, so duplicate elements are ignored
    when computing the intersection AND the union.

    Args:
        list1: iterable of hashable items (e.g. a list or set of n-gram tuples).
        list2: iterable of hashable items.

    Returns:
        A tuple ``(intersection, similarity)``: the number of distinct items
        present in both inputs, and ``|A & B| / |A | B|`` as a float. When both
        inputs are empty the union is empty and the similarity is defined as
        0.0 (no overlap) instead of raising ZeroDivisionError.
    """
    set1, set2 = set(list1), set(list2)
    intersection = len(set1 & set2)
    # Use the true set union; len(list1) + len(list2) - intersection
    # over-counts whenever an input contains duplicates.
    union = len(set1 | set2)
    if union == 0:
        return intersection, 0.0
    return intersection, intersection / union

def extract_ngrams(data, num):
    """Tokenize *data* with NLTK and return its *num*-grams as strings.

    Each n-gram's tokens are joined with a single space. No token filtering
    is applied, so punctuation tokens are included.
    """
    tokens = nltk.word_tokenize(data)
    return list(map(" ".join, ngrams(tokens, num)))

# TODO : load the dataset accordingly
# A context manager closes the file handle once the JSON is parsed (a bare
# open() is never closed). JSON text is UTF-8 (RFC 8259), so decode it
# explicitly rather than relying on the platform's locale encoding.
with open("path_to_data", "r", encoding="utf-8") as data_file:
    dataset = json.load(data_file)

# Normalized quote text, one entry per dataset item, in dataset order.
# str.strip() already removes "\n", so a separate strip("\n") is redundant.
quote_list = [d["fields"]["quote"].strip().lower() for d in dataset]

# Baseline pairing: the same quotes in shuffled order. Strings are immutable,
# so a shallow copy is sufficient (no deepcopy needed).
randomized_quote_list = list(quote_list)
shuffle(randomized_quote_list)

def _alnum_tokens(text):
    """Tokenize *text* with NLTK, keeping only tokens where str.isalnum() is True."""
    return [token for token in nltk.word_tokenize(text) if token.isalnum()]


if not dataset:
    # Every average below divides by len(dataset); fail with a clear message
    # instead of a ZeroDivisionError.
    raise SystemExit("dataset is empty: nothing to score")

TOP_K_QUOTES = 5  # how many highest-overlap quotes to report

# Tokenization does not depend on the mode or n-gram size, so do it once per
# text up front instead of once per (mode, n-gram size) pass.
narrative_tokens = [
    _alnum_tokens(d["fields"]["narrative"].strip().lower()) for d in dataset
]
quote_tokens = {quote: _alnum_tokens(quote) for quote in quote_list}

for mode in ["original", "randomized_assignment"]:
    print("mode = ", mode)
    # "original" pairs each narrative with its own quote; the randomized
    # baseline pairs it with the shuffled quote at the same index.
    quotes = quote_list if mode == "original" else randomized_quote_list

    for ngram_size in [1, 2, 3, 4]:
        print("ngram_size = ", ngram_size)

        js_list = []
        intersection_list = []
        js_quote_map = {}  # quote text -> Jaccard scores of every item using it

        for quote, narr_words in zip(quotes, narrative_tokens):
            intersection, js = jaccard(
                set(ngrams(quote_tokens[quote], ngram_size)),
                set(ngrams(narr_words, ngram_size)),
            )
            js_list.append(js)
            intersection_list.append(intersection)
            js_quote_map.setdefault(quote, []).append(js)

        num_examples = len(dataset)
        print("Avg js = ", sum(js_list) / num_examples)
        print("Avg intersection = ", sum(intersection_list) / num_examples)

        # For the original pairing with unigrams, report the quotes with the
        # highest Jaccard overlap (averaged over items sharing the quote).
        if mode == "original" and ngram_size == 1:
            avg_js_per_quote = {
                quote: np.average(np.array(scores))
                for quote, scores in js_quote_map.items()
            }
            ranked = sorted(
                avg_js_per_quote.items(), key=operator.itemgetter(1), reverse=True
            )
            # Slicing (instead of indexing range(5)) tolerates fewer than
            # TOP_K_QUOTES distinct quotes.
            for entry in ranked[:TOP_K_QUOTES]:
                print(entry)