import pandas as pd
import numpy as np
from scipy import spatial
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from flair.embeddings import WordEmbeddings
from flair.data import Sentence
import torch
from scipy import spatial
from sklearn import neighbors
from sklearn.preprocessing import normalize
import nltk
nltk.download('punkt')
from nltk.tokenize import word_tokenize

def _parse_glove_line(line):
    """Split one whitespace-separated GloVe text line into (token, float32 vector)."""
    fields = line.split()
    return fields[0], np.asarray(fields[1:], "float32")


# Vocabulary word -> vector, in file order; a repeated token keeps its last vector.
embeddings_dict = {}

with open('glove.6B.50d.txt', 'r', encoding='utf-8') as f:
    embeddings_dict.update(_parse_glove_line(line) for line in f)

# NOTE(review): tuple_embeds is not used by any other code in this file.
tuple_embeds = []

# Parallel lists: row k of all_embeddings is the vector for keys[k], which is
# what lets a neighbour row index be translated back into a word.
keys = list(embeddings_dict)
values = list(embeddings_dict.values())

all_embeddings = np.asarray(values)

# NOTE(review): neither tree is queried anywhere in this file (lookups go
# through `neigh`), yet both are built over the full vocabulary.
word_tree_sci = spatial.KDTree(all_embeddings)
word_tree_sk = neighbors.BallTree(all_embeddings)

# Single-nearest-neighbour index over the vocabulary vectors.
# NOTE: parallel querying (n_jobs=-1) is left disabled.
neigh = neighbors.NearestNeighbors(n_neighbors=1)
neigh.fit(all_embeddings)

def find_nearest_embeddings_fast_sklearn(embedding):
    """Return the vocabulary word closest to *embedding*, wrapped in a list.

    Queries the module-level ``neigh`` model (fitted on ``all_embeddings``
    with ``n_neighbors=1``) and maps each matched row index back to its word
    via the parallel ``keys`` list.

    Returns:
        One word per query row; a single vector is queried, so the result is
        a one-element list.
    """
    neighbour_rows = neigh.kneighbors([embedding], return_distance=False)
    nearest_words = []
    for row in neighbour_rows:
        best_index = row[0]
        nearest_words.append(keys[best_index])
    return nearest_words

# Flair word-embedding wrapper loaded from a local gensim-format file.
# NOTE(review): glove_embedding is not referenced by any other code in this
# file; transform_sentence works on embeddings_dict instead.
glove_embedding = WordEmbeddings('glove50dw2v.gensim')

# Input corpus; only its 'review' column is read further down.
original = pd.read_csv(filepath_or_buffer='train_test_set.csv', encoding='utf-8')

def perturb_embedding(embedding, epsilon):
    """Return *embedding* plus random noise whose size shrinks as epsilon grows.

    The noise direction is uniform on the unit sphere (a normalized standard
    Gaussian sample). Its length is drawn from Gamma(shape=d, scale=1/epsilon),
    where d is the embedding dimension. Together this samples noise with
    density proportional to exp(-epsilon * ||z||) in d dimensions.

    Args:
        embedding: 1-D array-like vector to perturb; it is not modified.
        epsilon: Positive noise parameter. The noise length has mean
            d / epsilon, so larger values perturb less.

    Returns:
        A new numpy array of the same shape as *embedding*.
    """
    embedding = np.asarray(embedding)
    # Derive the dimension from the input instead of assuming 50-d vectors;
    # both the sphere sampling and the Gamma shape must match it.
    dim = embedding.shape[-1]

    direction = np.random.standard_normal(dim)
    direction /= np.linalg.norm(direction)
    magnitude = np.random.gamma(dim, scale=1 / epsilon)

    return embedding + direction * magnitude

def transform_sentence(sentence: str, epsilon: float) -> str:
    """Return a noisy rewrite of *sentence*, replacing in-vocabulary words.

    The sentence is lower-cased and tokenized with NLTK's ``word_tokenize``.
    Each token with a vector in ``embeddings_dict`` is replaced by the
    vocabulary word nearest to a perturbed copy of that vector (see
    ``perturb_embedding`` and ``find_nearest_embeddings_fast_sklearn``).
    Out-of-vocabulary tokens are kept as-is.

    Args:
        sentence: Raw input text.
        epsilon: Noise parameter forwarded to ``perturb_embedding``; the noise
            length is drawn with scale 1/epsilon, so larger values perturb less.

    Returns:
        The output tokens, each followed by one space. A non-empty result
        therefore ends with a trailing space, and an empty token list yields ''.
    """
    out_tokens = []
    for word in word_tokenize(sentence.lower()):
        # Dict membership is O(1); testing against the `keys` list would scan
        # the whole vocabulary for every token.
        if word in embeddings_dict:
            perturbed_emb = perturb_embedding(embeddings_dict[word], epsilon)
            word = find_nearest_embeddings_fast_sklearn(perturbed_emb)[0]
        out_tokens.append(word)

    # Build once instead of repeated string `+=`; keeps the trailing space.
    return ''.join(token + ' ' for token in out_tokens)

from pathlib import Path

# Driver: perturb the first NUM_REVIEWS reviews and write each to its own file.
reviews = original['review'].to_list()

NUM_REVIEWS = 10000
EPSILON = 10
OUTPUT_DIR = Path('./reviews_epsilon10_d50_imdb_nozeros')

# open() below does not create missing directories, so create it up front.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Slicing (rather than indexing up to a fixed count) avoids an IndexError when
# the CSV holds fewer than NUM_REVIEWS rows.
for i, review in enumerate(reviews[:NUM_REVIEWS]):
    if i % 50 == 0:
        print(i)  # progress indicator
    new_review = transform_sentence(review, EPSILON)
    with open(OUTPUT_DIR / f'{i}.txt', 'w', encoding='utf-8') as f:
        f.write(new_review)
