import tensorflow as tf
import numpy as np
import EAttHBiLSTM_FK
import data_handle.DataProcessing_test
import os

def single_test(wordModel, modelfile, numClasses):
    """Restore one trained EAttHBiLSTM checkpoint and score the test set.

    The model is rebuilt in a brand-new ``tf.Graph`` on every call, and the
    session is closed before returning, so the function can be called
    repeatedly (e.g. once per ensemble member) in the same process.

    Args:
        wordModel: data-processing object passed to the model builder; it
            must also provide ``getTestData()`` returning
            ``(ids, inputs, emojis)``.
        modelfile: checkpoint path/prefix passed to ``tf.train.Saver.restore``.
        numClasses: number of output classes passed to the model builder.

    Returns:
        ``(ids, prediction_)``: the test ids from ``getTestData()`` and the
        value of the model's ``prediction`` tensor for all test inputs.
        The prediction is presumably an ``(N, numClasses)`` score array;
        confirm this against ``EAttHBiLSTM_FK.EAttHBiLSTMFrame``.
    """
    # Set before any session is created, i.e. before TensorFlow sees the GPUs.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    # Build each model in its own graph. Building repeatedly into the default
    # graph would add a second copy of every variable (name clash or
    # "_1"-suffixed duplicates). A plain Saver() would then try to restore
    # variables that are absent from this checkpoint.
    graph = tf.Graph()
    with graph.as_default():
        (_, _, _, input_data, _, _, emoji_data, _,
         input_keep_prob, output_keep_prob,
         prediction) = EAttHBiLSTM_FK.EAttHBiLSTMFrame(wordModel, numClasses)
        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver()

    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.4
    # The context manager closes the session, so repeated calls do not each
    # keep another 0.4 share of GPU memory alive.
    with tf.Session(graph=graph, config=config) as sess:
        sess.run(init_op)
        saver.restore(sess, modelfile)

        ids, inputs, emojis = wordModel.getTestData()
        # Keep probabilities of 1.0 presumably disable dropout for inference.
        prediction_ = sess.run(prediction, {input_data: inputs,
                                            emoji_data: emojis,
                                            input_keep_prob: 1.0,
                                            output_keep_prob: 1.0})
    return ids, prediction_

def EAttHBiLSTM_test(wordModel, numClasses, modelfiles=None):
    """Sum the test-set predictions of several EAttHBiLSTM checkpoints.

    Each checkpoint is scored with ``single_test``. The returned score rows
    are added element-wise to form an ensemble score.

    Args:
        wordModel: data-processing object forwarded to ``single_test``.
        numClasses: number of output classes; this is the width of the
            returned score matrix.
        modelfiles: optional sequence of checkpoint paths. Defaults to the
            five checkpoints under ``../trained_model/`` used originally.

    Returns:
        ``(ids, prediction)``: the ids reported by the first model and a
        ``(len(ids), numClasses)`` float64 array of summed scores.

    Raises:
        ValueError: if ``modelfiles`` is empty, or if any model reports ids
            that differ from the first model's (different order or length).
            In that case the rows could not be summed meaningfully.
    """
    if modelfiles is None:
        modelfiles = (
            "../trained_model/eatthbilstm-1.ckpt-31",
            "../trained_model/eatthbilstm-2.ckpt-22",
            "../trained_model/eatthbilstm-3.ckpt-32",
            "../trained_model/eatthbilstm-4.ckpt-26",
            "../trained_model/eatthbilstm-5.ckpt-25",
        )
    if not modelfiles:
        raise ValueError("at least one model checkpoint is required")

    reference_ids = None
    prediction = None
    for number, modelfile in enumerate(modelfiles, start=1):
        ids, scores = single_test(wordModel, modelfile, numClasses)
        print(f"The {number} model is finished")
        if reference_ids is None:
            reference_ids = ids
            prediction = np.zeros((len(ids), numClasses))
        elif list(ids) != list(reference_ids):
            # Summing rows only makes sense when every model scored the same
            # examples in the same order; this also catches length mismatches.
            print("id error")
            raise ValueError(
                f"id error: ids from {modelfile} do not match the first model's")
        # Sum every class column, not just the first two. The addition raises
        # if a model's output shape is not (len(ids), numClasses).
        prediction += np.asarray(scores, dtype=np.float64)
    return reference_ids, prediction

def saveToFile(id, prediction, path="../trained_model/Result.txt"):
    """Append one ``<id>,<label>`` line per example to a result file.

    The predicted class is ``argmax`` over axis 1 of ``prediction``. Class 0
    is written as ``OFF`` and class 1 as ``NOT``. Rows whose argmax is any
    other class are skipped, as in the original behavior.

    Args:
        id: sequence of example ids aligned with the rows of ``prediction``.
            The name shadows the builtin but is kept for compatibility with
            existing callers.
        prediction: 2-D array-like of per-class scores.
        path: output file. It is opened in append mode, so repeated runs add
            lines rather than overwrite.

    Raises:
        ValueError: if ``id`` and ``prediction`` have different lengths.
    """
    labels = {0: "OFF", 1: "NOT"}
    prediction_class = np.argmax(prediction, 1)
    if len(id) != len(prediction_class):
        raise ValueError(
            f"got {len(id)} ids but {len(prediction_class)} prediction rows")

    # An f-string (rather than `+`) also accepts non-string ids.
    lines = [f"{sample_id},{labels[int(cls)]}\n"
             for sample_id, cls in zip(id, prediction_class)
             if int(cls) in labels]
    # The with-block closes the file; no explicit close() is needed.
    with open(path, "a", encoding="utf-8") as f:
        f.writelines(lines)

if __name__ == "__main__":
    # Guarded so that importing this module does not start a full
    # five-checkpoint inference run.
    # Preprocessor for the test set. The file names suggest GloVe Twitter
    # 200d word vectors, the task-A test TSV and 200d emoji2vec vectors.
    # NOTE(review): 200 presumably is the embedding size; the meaning of eab=2
    # is defined in DataProcessing_test, so confirm it there.
    wordModel = data_handle.DataProcessing_test.DataProcessing(
        'D:/SemEval2019-Task6/workspace/wordModel/glove.twitter.27B.200d.txt',
        'D:/SemEval2019-Task6/workspace/data/test/testset-taska.tsv',
        'D:/SemEval2019-Task6/workspace/wordModel/emoji2vec200.txt',
        200, eab=2)
    numClasses = 2
    # `test_ids` rather than `id`, to avoid shadowing the builtin.
    test_ids, prediction = EAttHBiLSTM_test(wordModel, numClasses)
    saveToFile(test_ids, prediction)
