
import os
os.environ["CUDA_VISIBLE_DEVICES"]="2"
import random
import numpy as np
import pickle as pkl
import scipy.sparse as sp
import sys
from tqdm import tqdm
import pickle
import string
#from PyTsetlinMachineCUDA.tm import MultiClassTsetlinMachine
from sklearn.metrics import f1_score
import numpy as np
np.random.seed(400) 

def _read_labelled_docs(path):
    """Read a ``label<TAB>text`` file into parallel lists of labels and texts.

    Each line is stripped and split at its first tab character. A line with no
    tab is kept with the whole line as its label and an empty text (rather than
    losing the label's last character, as a ``find('\\t') == -1`` slice would).

    Args:
        path: path of a UTF-8 text file, one document per line.

    Returns:
        ``(labels, contents)``: two lists with one entry per line of the file.
    """
    labels, contents = [], []
    with open(path, "r", encoding="utf-8") as fh:
        for raw_line in fh:
            label, _, content = raw_line.strip().partition('\t')
            labels.append(label)
            contents.append(content)
    return labels, contents


path_train = 'train.txt'
path_test = 'test.txt'

doc_name_list_train, doc_content_list_train = _read_labelled_docs(path_train)
doc_name_list_test, doc_content_list_test = _read_labelled_docs(path_test)

from sklearn import preprocessing
# Fit one encoder on train+test labels so both splits share the same class ids.
label_encoder = preprocessing.LabelEncoder()
label = label_encoder.fit_transform(doc_name_list_train + doc_name_list_test)
from tensorflow.keras.utils import to_categorical
# NOTE(review): 8 classes and a 5485-document train split are hard-coded
# (presumably the R8 dataset); they must agree with the X_total split and the
# n_classes passed to build_model further down.
label = to_categorical(label, num_classes=8)
y_train = np.array(label[:5485])
y_val = np.array(label[5485:])
import re
def clean_str(string):
    """Normalise and tokenise raw text for the word-level models.

    Adapted from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py.

    Steps:
      * replace every character outside ``A-Za-z0-9(),!?'`` and backtick with a space;
      * split the clitics 's, 've, n't, 're, 'd, 'll into their own tokens;
      * pad the punctuation , ! ( ) ? with spaces so each becomes a token;
      * collapse runs of whitespace, strip, and lower-case.

    Args:
        string: raw document text.

    Returns:
        The cleaned, space-separated, lower-cased text.
    """
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
    string = re.sub(r"\'s", " \'s", string)
    string = re.sub(r"\'ve", " \'ve", string)
    string = re.sub(r"n\'t", " n\'t", string)
    string = re.sub(r"\'re", " \'re", string)
    string = re.sub(r"\'d", " \'d", string)
    string = re.sub(r"\'ll", " \'ll", string)
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    # Plain replacement text: the former " \( " style literals were invalid
    # escape sequences and made re.sub emit a literal backslash ("\(").
    string = re.sub(r"\(", " ( ", string)
    string = re.sub(r"\)", " ) ", string)
    string = re.sub(r"\?", " ? ", string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip().lower()

import string
punct = string.punctuation
import nltk
from nltk.stem import WordNetLemmatizer
lemmatizer = WordNetLemmatizer()
from nltk.corpus import stopwords
nltk.download('stopwords')
# NOTE: rebinds the name ``stopwords`` from the nltk corpus module to a plain set.
stopwords = set(stopwords.words('english'))
total_input = doc_content_list_train + doc_content_list_test

from collections import Counter

# Clean each document exactly once; the token lists feed both passes below.
tokenized_input = [clean_str(doc_content).split() for doc_content in total_input]

# Corpus-wide token counts (train + test), used to drop rare words.
word_freq = Counter(word for words in tokenized_input for word in words)
least_freq = 5

# Keep non-stopword tokens seen at least ``least_freq`` times, one string per document.
clean_docs = [
    ' '.join(word for word in words
             if word not in stopwords and word_freq[word] >= least_freq)
    for words in tokenized_input
]

from keras.preprocessing import text, sequence

max_features = 5000  # vocabulary cap for the Tokenizer (and the model's Embedding rows)
maxlen = 300         # every document is padded/truncated to this many token ids
embed_size = 700     # embedding width handed to build_model

tokenizer1 = text.Tokenizer(num_words=max_features)
tokenizer1.fit_on_texts(clean_docs)
# Despite the name, this holds the sequences of ALL documents (train followed by test).
list_tokenized_train = tokenizer1.texts_to_sequences(clean_docs)
X_total = sequence.pad_sequences(list_tokenized_train, maxlen=maxlen, padding='post')

X_train, X_val = X_total[:5485], X_total[5485:]

word_index = tokenizer1.word_index

with open('TM_embedding.pkl', 'rb') as f:
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    embedding_weights = pickle.load(f)

# Binarise each TM vector: entries strictly greater than 1 become 1, all others 0.
embedding_vector_new = {
    token: np.where(weights > 1, 1, 0)
    for token, weights in embedding_weights.items()
}

def make_glovevec_tm(embeddings_index, max_features, embed_size, word_index, glovepath, glove_dim=300):
    """Build an embedding matrix from TM vectors, falling back to GloVe.

    Row ``i`` of the result belongs to the word whose index in ``word_index`` is ``i``:
      * if the word is a key of ``embeddings_index``, its vector is copied as is
        (so it must have ``embed_size`` entries);
      * otherwise, if the word occurs in the GloVe file, its GloVe vector is
        concatenated with ``embed_size - glove_dim`` uniform [0, 1) values drawn
        from NumPy's global RNG;
      * otherwise the row stays all zeros.

    Args:
        embeddings_index: mapping word -> vector of length ``embed_size``.
        max_features: vocabulary cap; words with index >= max_features are skipped.
        embed_size: width of each row of the matrix.
        word_index: mapping word -> integer index (assumed 1-based, as produced
            by a Keras ``Tokenizer``; row 0 is then never written).
        glovepath: path to a UTF-8 GloVe text file, ``word v1 ... v_glove_dim`` per line.
        glove_dim: number of values per GloVe vector (300 for ``glove.*.300d``).

    Returns:
        ``np.ndarray`` of shape ``(min(max_features, len(word_index) + 1), embed_size)``.
    """
    # With 1-based indices the largest index is len(word_index); the +1 keeps it
    # in range when the vocabulary is smaller than max_features.
    nb_words = min(max_features, len(word_index) + 1)
    embedding_matrix = np.zeros((nb_words, embed_size))

    # Only GloVe vectors of in-range words lacking a TM vector are ever used,
    # so don't keep the rest of the (large) GloVe file in memory.
    needed = {word for word, i in word_index.items()
              if i < nb_words and word not in embeddings_index}
    embeddings_glove = {}
    with open(glovepath, encoding="utf-8") as f:
        for line in f:
            values = line.split()
            # Everything before the last glove_dim fields is the token, which
            # handles tokens that themselves contain spaces.
            word = ' '.join(values[:-glove_dim])
            if word in needed:
                embeddings_glove[word] = np.asarray(values[-glove_dim:], dtype='float32').reshape(-1)

    for word, i in word_index.items():
        if i >= nb_words:
            continue
        if word in embeddings_index:
            embedding_vector = embeddings_index.get(word)
            if embedding_vector is not None:
                embedding_matrix[i] = embedding_vector
        else:
            embedding_vector = embeddings_glove.get(word)
            if embedding_vector is not None:
                # Pad the GloVe vector with random values up to embed_size.
                padding = np.random.rand(embed_size - embedding_vector.shape[0]).astype(np.float32)
                embedding_matrix[i] = np.concatenate((embedding_vector, padding), axis=0)
    return embedding_matrix

# Embedding matrix for the model: binarised TM vectors where available,
# GloVe vectors padded with random values otherwise.
embedding_vector = make_glovevec_tm(
    embeddings_index=embedding_vector_new,
    max_features=max_features,
    embed_size=embed_size,
    word_index=word_index,
    glovepath="glove.6B.300d.txt",
)
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.metrics import roc_auc_score
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, cross_val_score, train_test_split
from sklearn.metrics import classification_report, accuracy_score
from tensorflow.keras.utils import to_categorical

from keras.layers import Layer, Input, Embedding, Bidirectional, LSTM, Concatenate, Dense
import keras.backend as K
from keras.layers import *
from keras import regularizers
from keras import initializers
from keras.models import Model
from tensorflow.keras.optimizers import Adam
# Clear Keras' global backend state before the model below is built
# (presumably to avoid leftovers from earlier runs in the same session).
K.clear_session()
class Attention(Layer):
    """Additive attention pooling over the time axis.

    Expects ``x`` of shape (batch, timesteps, features), e.g. the
    ``return_sequences=True`` output of an RNN. For every timestep it computes
    ``e_t = tanh(x_t W + b)`` (width ``attention_dim``) and the score
    ``a_t = e_t . u``. The exponentiated scores are normalised over time; masked
    timesteps get zero weight when a mask is supplied. The layer returns the
    weighted sum of ``x`` over time, of shape (batch, features).

    Args:
        attention_dim: width of the hidden projection ``e_t`` (default 100).
        **kwargs: standard Keras ``Layer`` arguments (e.g. ``name``).
    """

    def __init__(self, attention_dim=100, **kwargs):
        super(Attention, self).__init__(**kwargs)
        self.supports_masking = True
        self.attention_dim = attention_dim

    def build(self, input_shape):
        # W projects each feature vector into the attention_dim-wide hidden space;
        # u is the context vector that scores each hidden vector.
        self.W = self.add_weight(name="att_weight", shape=(input_shape[-1], self.attention_dim),
                                 initializer="normal", trainable=True)
        self.b = self.add_weight(name="att_bias", shape=(self.attention_dim,),
                                 initializer="normal", trainable=True)
        self.u = self.add_weight(name="u_bias", shape=(self.attention_dim, 1),
                                 initializer="normal", trainable=True)
        super(Attention, self).build(input_shape)

    def compute_mask(self, inputs, mask=None):
        # The time axis is summed away, so there is no per-timestep mask left
        # to propagate to the 2-D output.
        return None

    def call(self, x, mask=None):
        et = K.tanh(K.dot(x, self.W) + self.b)   # (batch, T, attention_dim)
        ait = K.squeeze(K.dot(et, self.u), -1)   # (batch, T)
        ait = K.exp(ait)

        if mask is not None:
            # Zero the weight of padded timesteps before normalising.
            ait *= K.cast(mask, K.floatx())

        ait /= K.cast(K.sum(ait, axis=1, keepdims=True) + K.epsilon(), K.floatx())
        weighted_input = x * K.expand_dims(ait)
        return K.sum(weighted_input, axis=1)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[-1])

    def get_config(self):
        config = super(Attention, self).get_config()
        config["attention_dim"] = self.attention_dim
        return config


def build_model(VOCABULARY_SIZE, MAX_SENTENCE_LENGTH, embedding_weights, n_classes=8, embedding_dim=700):
    """Build and compile a frozen-embedding BiLSTM + attention text classifier.

    Args:
        VOCABULARY_SIZE: number of rows of the embedding table.
        MAX_SENTENCE_LENGTH: length of each (padded) input sequence.
        embedding_weights: initial embedding matrix; must match
            (VOCABULARY_SIZE, embedding_dim). It is not updated during training.
        n_classes: number of output classes (softmax width).
        embedding_dim: width of the embedding vectors.

    Returns:
        A compiled Keras ``Model`` (Adam, categorical cross-entropy, accuracy).
    """
    sentence_in = Input(shape=(MAX_SENTENCE_LENGTH,), name="input_1")
    embedded_word_seq = Embedding(
        VOCABULARY_SIZE,
        embedding_dim,
        input_length=MAX_SENTENCE_LENGTH,
        weights=[embedding_weights],
        trainable=False,
    )(sentence_in)
    word_encoder = Bidirectional(LSTM(512, return_sequences=True, dropout=0.25))(embedded_word_seq)
    attention_weighted_text = Attention(name="sentence_attention")(word_encoder)
    prediction = Dense(n_classes, activation='softmax')(attention_weighted_text)
    model = Model(sentence_in, prediction)
    # ``learning_rate`` is the supported keyword; ``lr`` is deprecated and
    # rejected by newer Keras optimizers.
    optimizer = Adam(learning_rate=0.001)
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
    model.summary()
    return model

model = build_model(max_features, maxlen, embedding_vector, 8, embed_size)

from keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger
filepath = "weights_r8_tm.best.hdf5"
# Save weights only when validation accuracy improves.
checkpoint = ModelCheckpoint(filepath, monitor='val_accuracy', verbose=1, save_best_only=True, mode='max')
# Must be passed to fit() so the per-epoch metrics are actually written.
csv_logger = CSVLogger("r8_log.csv", append=True)
# NOTE(review): X_val/y_val are the test documents, so checkpoint selection on
# val_accuracy looks at the test set.
model.fit(X_train, y_train, validation_data=(X_val, y_val), batch_size=64, epochs=50,
          callbacks=[checkpoint, csv_logger], verbose=1)

from sklearn.metrics import classification_report
# NOTE(review): this evaluates the final-epoch weights, not the best checkpoint
# saved above; call model.load_weights(filepath) first if that is intended.
pred = model.predict(X_val, verbose=1)
predicted = np.argmax(pred, axis=1)
report = classification_report(np.argmax(y_val, axis=1), predicted)
print(report)