# -*- coding: utf-8 -*-

import os
import glob
import random
random.seed(42)
from time import time
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import train_test_split
import pickle
import pandas as pd
import numpy as np
from collections import defaultdict
from tmu.tsetlin_machine import TMAutoEncoder
from sklearn.metrics.pairwise import cosine_similarity
# Module-level accumulators; both hand out an empty list for a missing key.
# target_word_weight is not written to in the visible part of this script.
target_word_weight = defaultdict(list)
# Filled near the end of the script with (word, neighbour) -> cosine similarity.
target_similarity = defaultdict(list)
import re
import nltk
from nltk.corpus import brown
# Request the NLTK 'brown' and 'punkt' resources every time the script starts.
# NOTE(review): neither resource is referenced in the visible part of this
# script, and the call may need network access — confirm it is still required.
nltk.download('brown')
nltk.download('punkt')

# Load the two pickled word collections of the MTurk word pairs.
# NOTE(review): pickle.load can execute arbitrary code; only load files that
# come from a trusted source.
with open('word1_mturk.pkl', 'rb') as f:
    word1 = pickle.load(f)
with open('word2_mturk.pkl', 'rb') as f:
    word2 = pickle.load(f)

# Run index; it is embedded in the names of the output files written below.
num = 0

# Unique words across both collections. Sorting makes the order reproducible:
# the iteration order of a set of strings depends on the per-process hash
# seed, so list(set(...)) could yield a different word order on every run.
# NOTE(review): assumes all entries are strings (mutually comparable) — confirm.
word_total = sorted(set(word1 + word2))

# --- Experiment configuration ------------------------------------------------

# Clause weights below this value are replaced by 0 when the per-word profiles
# are assembled after training.
clause_weight_threshold = 0

# Passed to tm.fit() as number_of_examples on every epoch.
number_of_examples = 2000

# The next three values are forwarded to the TMAutoEncoder constructor under
# keyword arguments of the same name.
accumulation = 25
type_i_ii_ratio = 1.0
clause_drop_p = 0.0

# A single scale factor drives both the clause count and T.
factor = 30
# The clause count is additionally divided by (1 - clause_drop_p).
clauses = int(20 * factor / (1.0 - clause_drop_p))
T = 40 * factor
# Third positional argument of the TMAutoEncoder constructor.
s = 5.0

# Load the fitted vectorizer and the pre-computed document matrix from disk.
# The `with` blocks close each file even when unpickling raises.
# NOTE(review): pickle.load can execute arbitrary code; only load files that
# come from a trusted source.
print("Loading Vectorizer")
with open("/home/bimalb/copied_from_cair-gpu05/bimalb/Tsetlin/word_profile/vectorizer_X.pickle", "rb") as f_vectorizer_X:
    vectorizer_X = pickle.load(f_vectorizer_X)

# Kept for reference — X.pickle was presumably produced with:
#     f_X = open("X.pickle", "wb")
#     pickle.dump(X, f_X, protocol=4)
#     f_X.close()

print("Loading Data")
with open("/home/bimalb/copied_from_cair-gpu05/bimalb/Tsetlin/word_profile/X.pickle", "rb") as f_X:
    X_csr = pickle.load(f_X)

# The loaded matrix is used for training as-is; no split is made here.
X_train = X_csr

# Fetch the vocabulary terms once and take the feature count from that same
# array rather than calling get_feature_names_out() a second time.
feature_names = vectorizer_X.get_feature_names_out()
number_of_features = feature_names.shape[0]

# Keep only the evaluation words that the vectorizer has in its vocabulary;
# the order of word_total is preserved.
target_words = [word for word in word_total if word in vectorizer_X.vocabulary_]

print("target word length", len(target_words))

# Persist the retained words for this run. The `with` block closes the file
# even when pickling raises.
with open('./billion_mturk/target_mturk_' + str(num) + '.pkl', "wb") as output:
    pickle.dump(target_words, output)


print(len(target_words))
print("feature name", feature_names)

# Vocabulary index of every target word, in target_words order, stored as
# uint32 and handed to the TMAutoEncoder constructor below.
output_active = np.empty(len(target_words), dtype=np.uint32)
for position, target_word in enumerate(target_words):
    output_active[position] = vectorizer_X.vocabulary_[target_word]

# Build the autoencoder from the configuration values defined earlier in the
# script; output_active holds the vocabulary ids of the target words.
tm = TMAutoEncoder(
    clauses,
    T,
    s,
    output_active,
    max_included_literals=3,
    type_i_ii_ratio=type_i_ii_ratio,
    accumulation=accumulation,
    feature_negation=False,
    clause_drop_p=clause_drop_p,
    platform='CPU',
    output_balancing=True,
)


# Train for 40 epochs. Despite the header text, the loop only reports the
# wall-clock duration of each tm.fit() call; no accuracy is computed here.
print("\nAccuracy Over 40 Epochs:")
for epoch_index in range(40):
    epoch_started = time()
    tm.fit(X_train, number_of_examples=number_of_examples)
    total_time = time() - epoch_started
    print("\nEpoch #%d\n" % (epoch_index + 1))
    print(f'epoch per time: {total_time}')

# One row per target word, one column per clause: the clause weights returned
# by tm.get_weights(i), with every weight below clause_weight_threshold
# replaced by 0.
profile = np.empty((len(target_words), clauses))
for i in range(len(target_words)):
    weights = tm.get_weights(i)
    profile[i, :] = np.where(weights >= clause_weight_threshold, weights, 0)

# Persist the profile matrix for this run. The `with` block closes the file
# even when pickling raises.
with open('./billion_mturk/tm_weights_mturk_' + str(num) + '.pkl', "wb") as output:
    pickle.dump(profile, output)

# Pairwise cosine similarity between the per-word clause-weight profiles.
similarity = cosine_similarity(profile)
print("\nWord Similarity\n")

# For every target word, print the other target words ordered by decreasing
# similarity and record each (word, neighbour) -> similarity pair.
for i, word in enumerate(target_words):
    print(word, end=': ')
    sorted_index = np.argsort(-1 * similarity[i, :])
    for j in sorted_index:
        # Skip the word itself by index instead of assuming it is ranked
        # first: with tied similarities or an all-zero profile row the word
        # is not guaranteed to be at position 0, in which case dropping
        # position 0 would lose a real neighbour and keep the self-pair.
        if j == i:
            continue
        print("%s(%.2f) " % (target_words[j], similarity[i, j]), end=' ')
        target_similarity[(word, target_words[j])] = similarity[i, j]
    print()

# Persist the similarity dictionary for this run. The `with` block closes the
# file even when pickling raises.
with open('./billion_mturk/profile_dict_mturk_' + str(num) + '.pkl', "wb") as output:
    pickle.dump(target_similarity, output)

