from sys import meta_path, exit
import numpy as np
import os
import math
import csv
from sklearn.metrics import f1_score
from sklearn.neighbors import KNeighborsClassifier
import sklearn
import argparse

def get_demb(mat_file):
    """Load per-document embedding matrices from a space-delimited text file.

    Format (as parsed here):
      * line 1: integer header whose first three values are ``n_docs d1 d2``
        (any extra values are ignored);
      * each following non-empty line: one leading token (skipped), then
        ``d1 * d2`` floats that are reshaped row-major into a ``(d2, d1)``
        matrix.

    Prints the shape of the allocated array before filling it.

    Args:
        mat_file: path to the document-embedding file.

    Returns:
        ``(doc_emb, d1, d2)`` where ``doc_emb`` has shape ``(n_docs, d2, d1)``.
    """
    with open(mat_file, 'r') as f:
        header = f.readline()
        lines = f.readlines()
    # split() rather than split(' ') so trailing whitespace in the header
    # does not produce an empty token that int() rejects.
    dimension = [int(x) for x in header.split()]
    n_docs, d1, d2 = dimension[0], dimension[1], dimension[2]
    doc_emb = np.zeros((n_docs, d2, d1))
    print(doc_emb.shape)
    # Skip blank lines (e.g. an empty last line); they carry no vector and
    # would otherwise fail to reshape.
    records = (line.strip().split(' ') for line in lines if line.strip())
    for i, tokens in enumerate(records):
        vec = np.array([float(ele) for ele in tokens[1:]])
        doc_emb[i] = np.reshape(vec, (d2, d1))
    return doc_emb, d1, d2

def get_wemb(vec_file):
    """Load a word-embedding text file into a matrix plus vocabulary maps.

    Format (as parsed here): the first line is an integer header whose first
    two values are ``vocab_size dim``; each following non-empty line is
    ``word v1 v2 ... v_dim`` separated by single spaces. Undecodable bytes
    are dropped (``errors='ignore'``).

    Args:
        vec_file: path to the word-embedding file.

    Returns:
        ``(W, vocabulary, vocabulary_inv)``:
          * ``W``: array of shape ``(vocab_size, dim)``; row ``i`` holds the
            vector of the ``i``-th word line;
          * ``vocabulary``: word -> row index (a repeated word keeps its last
            index);
          * ``vocabulary_inv``: row index -> word.
    """
    with open(vec_file, 'r', errors='ignore') as f:
        header = f.readline()
        lines = f.readlines()
    # split() so trailing whitespace in the header is harmless.
    dimension = [int(x) for x in header.split()]
    W = np.zeros((dimension[0], dimension[1]))
    vocabulary = {}
    vocabulary_inv = {}
    # Word lines keep single-space splitting so a word token is taken
    # verbatim; blank lines (e.g. at EOF) are skipped instead of crashing.
    records = (line.strip().split(' ') for line in lines if line.strip())
    for i, tokens in enumerate(records):
        word = tokens[0]
        W[i] = [float(ele) for ele in tokens[1:]]
        vocabulary[word] = i
        vocabulary_inv[i] = word
    return W, vocabulary, vocabulary_inv

def print_topics(demb, W, vocab, ivocab, topn=20):
    """Print the ``topn`` highest-scoring words for each row of ``demb``.

    Each row of ``demb`` is L2-normalised and scored against every row of
    ``W`` by dot product. Higher scores are closer; results are printed in
    descending order even though the header says "Cosine distance". The
    scores equal cosine similarity only if the rows of ``W`` are unit
    length. NOTE(review): presumably true for the embeddings this script
    loads; confirm.

    Args:
        demb: 2-D array; each row is a query vector of the same length as
            the rows of ``W``.
        W: word-embedding matrix, one row per vocabulary entry.
        vocab: word -> index map. Not used here; kept for call
            compatibility.
        ivocab: index -> word map used to label the results.
        topn: number of words printed per row (default 20).
    """
    for vec_result in demb:
        print(vec_result)
        # A zero row gives d == 0 and therefore NaN scores.
        d = np.sum(vec_result ** 2) ** 0.5
        vec_norm = vec_result / d

        dist = np.dot(W, vec_norm)
        best = np.argsort(-dist)[:topn]

        print("\n                               Word       Cosine distance\n")
        print("---------------------------------------------------------\n")
        for x in best:
            print("%35s\t\t%f\n" % (ivocab[x], dist[x]))

def get_corpus(corpus_file):
    """Read the raw corpus file and return its lines.

    Presumably one document per line: the main script indexes the result
    in parallel with the document embeddings. Undecodable bytes are dropped
    (``errors='ignore'``). Prints the number of lines read.

    Args:
        corpus_file: path to the corpus text file.

    Returns:
        list of str: the file's lines, each keeping its trailing newline.
    """
    with open(corpus_file, 'r', errors='ignore') as f:
        lines = f.readlines()
    print(len(lines))
    return lines
if __name__ == "__main__":
    # Find documents whose per-document embedding vectors are mutually
    # dissimilar (low mean pairwise cosine similarity), then print their
    # text, their embedding matrix and the nearest words for each vector.

    parser = argparse.ArgumentParser(description='main',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--dataset', default='20news')
    parser.add_argument('--word_file', default='jose.txt')
    parser.add_argument('--doc_file', default='doc.txt')
    parser.add_argument('--corpus', default='text.txt')
    parser.add_argument('--threshold', type=float, default=0.8,
                        help='report documents whose mean pairwise cosine '
                             'similarity is below this value')
    parser.add_argument('--max_docs', type=int, default=1,
                        help='stop after reporting this many documents '
                             '(<= 0 means no limit)')

    args = parser.parse_args()
    print(args)
    data_dir = os.path.join('datasets', args.dataset)
    corpus = get_corpus(os.path.join(data_dir, args.corpus))
    W, vocab, ivocab = get_wemb(os.path.join(data_dir, args.word_file))
    doc_emb, rows, cols = get_demb(os.path.join(data_dir, args.doc_file))

    reported = 0
    for i, mat in enumerate(doc_emb):
        # mat has `cols` rows (one vector per row). Normalising the Gram
        # matrix by the outer product of row norms gives the pairwise
        # cosine-similarity matrix; this replaces a per-element Python loop.
        # A zero row yields NaN, which never passes the `<` test below.
        norms = np.linalg.norm(mat, axis=1)
        cos = (mat @ mat.T) / np.outer(norms, norms)
        mean_sim = np.sum(cos) / (cols * cols)
        if mean_sim < args.threshold:
            print(i, corpus[i])
            print(mat)
            print_topics(mat, W, vocab, ivocab)
            reported += 1
            if 0 < args.max_docs <= reported:
                break
