import numpy as np
import pickle
import codecs
import pandas as pd
from collections import defaultdict

import numpy as np
import pickle
import codecs
from collections import defaultdict
from scipy.stats import spearmanr
from sklearn import preprocessing
import pickle
from collections import defaultdict
# NOTE(review): final_dict is not referenced anywhere in the visible part of
# this file -- confirm whether it is still needed.
final_dict=defaultdict(list)
from gensim.models import Word2Vec
import gensim.downloader
from gensim.models import FastText
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.cluster import AgglomerativeClustering, KMeans
# Pre-trained embedding models, loaded at import time from paths relative to
# the current working directory (one level above it).
model =  Word2Vec.load("../word2vec_billion.model")
model_fast= FastText.load("../billion_fasttext.model")
from glove import Corpus, Glove
model_glove= Glove.load("../billion_glove_updated.model")
# Seed shared by the sample shuffle and by KMeans inside purity().
seed=42

# Evaluation set: one row per word, with its gold category.
data = pd.read_csv('ap.csv')
category = data.category.tolist()
words = data.word.tolist()

categories = np.array(category)
print("cat", len(set(categories)))

# NOTE(review): pickle.load can execute arbitrary code; only load artefacts
# that were produced by a trusted run of this project.
with open('category_ap.pkl', 'rb') as f:
    # Not referenced in the visible part of this file.
    labels = pickle.load(f)

with open('target_ap_0.pkl', 'rb') as f:
    total_word = pickle.load(f)

with open('tm_weights_ap_0.pkl', 'rb') as f:
    weights = pickle.load(f)

# Position of each target word inside `weights`. setdefault keeps the FIRST
# occurrence, matching what list.index() returned, while avoiding a full list
# scan for every evaluation word.
word_to_row = {}
for row, target in enumerate(total_word):
    word_to_row.setdefault(target, row)

total_weights_TM = []
total_weights_w2v = []
total_weights_fast = []
total_weights_glove = []
kept_rows = []  # rows of `data` whose word has a TM weight vector
for row, word in enumerate(words):
    target_row = word_to_row.get(word)
    if target_row is None:
        # Word has no TM weights: drop it from every embedding matrix.
        continue
    kept_rows.append(row)
    total_weights_TM.append(weights[target_row])
    total_weights_w2v.append(model.wv.get_vector(word))
    total_weights_fast.append(model_fast.wv.get_vector(word))
    total_weights_glove.append(model_glove.word_vectors[model_glove.dictionary[word]])
total_weights_TM = np.array(total_weights_TM)
total_weights_w2v = np.array(total_weights_w2v)
total_weights_fast = np.array(total_weights_fast)
total_weights_glove = np.array(total_weights_glove)

# Keep only the labels of the retained words so that categories[i] describes
# row i of every total_weights_* matrix. Without this, skipping a single word
# shifts all later labels against their vectors.
categories = categories[np.array(kept_rows, dtype=int)]

def calculate_purity(y_true, y_pred):
    """
    Calculate purity for given true and predicted cluster labels.

    Each predicted cluster is credited with the size of its largest overlap
    with a single true class; purity is the sum of those credits divided by
    the number of samples.

    Parameters
    ----------
    y_true: array-like, shape: (n_samples,) or (n_samples, 1)
      True cluster labels.
    y_pred: array-like, shape: (n_samples,) or (n_samples, 1)
      Cluster assignment.
    Returns
    -------
    purity: float
      Calculated purity.
    Raises
    ------
    AssertionError
      If the two labelings have different lengths.
    ValueError
      If the labelings are empty (purity is undefined).
    """
    # Accept lists and column vectors as well as flat arrays.
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    assert len(y_true) == len(y_pred), "y_true and y_pred must have the same length"
    if len(y_true) == 0:
        raise ValueError("purity is undefined for an empty labeling")

    # Map arbitrary labels to dense 0..k-1 indices.
    true_names, true_idx = np.unique(y_true, return_inverse=True)
    pred_names, pred_idx = np.unique(y_pred, return_inverse=True)

    # overlap[p, t] = number of samples in predicted cluster p with true class t.
    # The matrix is sized by the number of PREDICTED clusters, so having more
    # clusters than classes no longer overflows the row index.
    overlap = np.zeros((pred_names.size, true_names.size), dtype=np.int64)
    np.add.at(overlap, (pred_idx, true_idx), 1)

    # Divide the integer total once; `1. / n * total` can round a perfect
    # score to 0.9999999999999999 (e.g. n == 49).
    return overlap.max(axis=1).sum() / len(y_true)

def _agglomerative(n_clusters, metric, linkage):
    """Build an AgglomerativeClustering with the distance keyword this scikit-learn accepts."""
    import inspect  # local import keeps the block self-contained

    # scikit-learn 1.2 introduced `metric` and 1.4 removed `affinity`; pick
    # whichever keyword the installed constructor exposes.
    accepted = inspect.signature(AgglomerativeClustering).parameters
    distance_kw = "metric" if "metric" in accepted else "affinity"
    return AgglomerativeClustering(n_clusters=n_clusters, linkage=linkage,
                                   **{distance_kw: metric})


def purity(total_weights, categories, n_clusters=9):
    """
    Return the best clustering purity reached on one embedding matrix.

    The samples are shuffled with the module-level `seed`, then clustered with
    Ward agglomerative clustering, with average/complete linkage under cosine
    and euclidean distance, and with KMeans. The highest purity against
    `categories` is returned. The number of distinct labels is printed as a
    side effect.

    Parameters
    ----------
    total_weights: numpy array, one row per sample
      Embedding vectors to cluster.
    categories: numpy array
      Gold labels; categories[i] must describe total_weights[i]. Only the
      first len(total_weights) entries are read.
    n_clusters: int, default 9
      Number of clusters requested from every algorithm.
    Returns
    -------
    best_purity: float
      Maximum purity over all clusterings tried.
    """
    ids = np.random.RandomState(seed).choice(range(len(total_weights)), len(total_weights), replace=False)
    shuffled_weights = total_weights[ids]
    shuffled_categories = categories[ids]
    print("cat", len(set(shuffled_categories)))

    # Same candidates, in the same order, as the original hand-unrolled code.
    estimators = [_agglomerative(n_clusters, "euclidean", "ward")]
    estimators.extend(
        _agglomerative(n_clusters, metric, linkage)
        for metric in ("cosine", "euclidean")
        for linkage in ("average", "complete")
    )
    estimators.append(KMeans(random_state=seed, n_init=10, n_clusters=n_clusters))

    return max(
        calculate_purity(shuffled_categories, estimator.fit_predict(shuffled_weights))
        for estimator in estimators
    )

# Report the best purity of each embedding, in a fixed order.
for report_label, embedding_matrix in (
    ("purity TM", total_weights_TM),
    ("purity w2v", total_weights_w2v),
    ("purity fast", total_weights_fast),
    ("purity glove", total_weights_glove),
):
    print(report_label, purity(embedding_matrix, categories))
