import tensorflow as tf
import model.HBiLSTM
import model.Attention
import model.FullyConnectLayer_attbilstm
import os

def EAttHBiLSTMFrame(wordModel, numClasses):
    """Build the TF1 graph of the emoji-aware attention BiLSTM classifier.

    The default graph is reset first, so anything previously built in it is
    discarded. Two frozen (non-trainable) embedding tables are created, one
    for words and one for emojis. Each table is a variable holding rows
    ``1..vocab_size-1`` with an all-zero row prepended, so index 0 always
    looks up a zero vector. The variables start as zeros and must be filled
    by running the returned ``*_init`` ops with the matching placeholders.

    The word embeddings go through ``model.HBiLSTM.model_HBiLSTM`` and then
    ``model.Attention.attention``; the emoji embeddings are transposed with
    permutation ``[1, 0, 2]`` and go straight through attention. The two
    attention outputs are concatenated on axis 1 and passed to
    ``model.FullyConnectLayer_attbilstm.fullyConnectLayer``.

    Args:
        wordModel: object exposing ``vocab_size``, ``numDimensions``,
            ``emoji_vocab_size`` and ``emoji_numDimensions``.
        numClasses: width of the label placeholder; also forwarded to the
            fully connected layer (presumably as its output size -- confirm
            against ``fullyConnectLayer``).

    Returns:
        An 11-tuple ``(learning_rate, embedding_init, embedding_placeholder,
        input_data, emoji_embedding_init, emoji_embedding_placeholder,
        emoji_data, labels, input_keep_prob, output_keep_prob, prediction)``.
    """
    lstmUnitNum = 150
    attn_length = -1  # forwarded unchanged to model_HBiLSTM
    attention_size = 256
    layer_num = 1
    forget_bias = 0.5

    tf.reset_default_graph()
    input_data = tf.placeholder(tf.int32, [None, None])
    emoji_data = tf.placeholder(tf.int32, [None, None])
    labels = tf.placeholder(tf.float32, [None, numClasses])
    input_keep_prob = tf.placeholder(tf.float32)
    output_keep_prob = tf.placeholder(tf.float32)
    learning_rate = tf.placeholder(tf.float32)

    with tf.device("/cpu:0"):
        # Word embedding table: frozen rows for ids >= 1, zero row for id 0.
        word_rows = tf.Variable(tf.constant(0.0, shape=[wordModel.vocab_size - 1, wordModel.numDimensions]), trainable=False)
        embedding_placeholder = tf.placeholder(tf.float32, [wordModel.vocab_size - 1, wordModel.numDimensions])
        embedding_init = word_rows.assign(embedding_placeholder)
        word_zero_row = tf.constant(0.0, shape=[1, wordModel.numDimensions])
        word_table = tf.concat([word_zero_row, word_rows], axis=0)
        embedding_input = tf.nn.embedding_lookup(word_table, input_data)

        # Emoji embedding table, built the same way.
        emoji_rows = tf.Variable(tf.constant(0.0, shape=[wordModel.emoji_vocab_size - 1, wordModel.emoji_numDimensions]), trainable=False)
        emoji_embedding_placeholder = tf.placeholder(tf.float32, [wordModel.emoji_vocab_size - 1,  wordModel.emoji_numDimensions])
        emoji_embedding_init = emoji_rows.assign(emoji_embedding_placeholder)
        # The zero row must match the emoji embedding width; reusing the
        # word-sized row would break tf.concat whenever the widths differ.
        emoji_zero_row = tf.constant(0.0, shape=[1, wordModel.emoji_numDimensions])
        emoji_table = tf.concat([emoji_zero_row, emoji_rows], axis=0)
        emoji_embedding_input = tf.nn.embedding_lookup(emoji_table, emoji_data)

    output_word = model.HBiLSTM.model_HBiLSTM(embedding_input, lstmUnitNum, layer_num, forget_bias, input_keep_prob, output_keep_prob, attn_length)
    attout_word, _ = model.Attention.attention(output_word, attention_size)
    # Swap the first two axes before attention (presumably to match the
    # layout of the BiLSTM output -- confirm against model.Attention).
    emoji_embedding_input = tf.transpose(emoji_embedding_input, [1, 0, 2])
    attout_emoji, _ = model.Attention.attention(emoji_embedding_input, attention_size)
    output = tf.concat([attout_word, attout_emoji], axis=1)
    prediction = model.FullyConnectLayer_attbilstm.fullyConnectLayer(output, output.shape[1].value, numClasses)
    return learning_rate, embedding_init, embedding_placeholder, input_data, emoji_embedding_init, emoji_embedding_placeholder, emoji_data, labels, input_keep_prob, output_keep_prob, prediction

def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def _binary_macro_metrics(batch_labels, predicted_classes):
    """Compute macro F1, accuracy, precision and recall for a two-class batch.

    A row whose first label entry is 0 is treated as class 1 ("not") and a
    row whose first label entry is 1 as class 0 ("off"). Rows whose label or
    prediction matches neither pattern are ignored.

    Returns:
        ``(mean_f1, accuracy, mean_precision, mean_recall)``; any ratio with
        a zero denominator is reported as 0.0.
    """
    not_hit = not_missed = off_hit = off_missed = 0.0
    for label_row, predicted in zip(batch_labels, predicted_classes):
        first = label_row[0]
        if first == 0 and predicted == 1:
            not_hit += 1        # actual "not", predicted "not"
        elif first == 0 and predicted == 0:
            not_missed += 1     # actual "not", predicted "off"
        elif first == 1 and predicted == 0:
            off_hit += 1        # actual "off", predicted "off"
        elif first == 1 and predicted == 1:
            off_missed += 1     # actual "off", predicted "not"

    # Precision divides by what was *predicted* as the class, recall by what
    # actually *is* the class.
    precision_not = _safe_ratio(not_hit, not_hit + off_missed)
    precision_off = _safe_ratio(off_hit, off_hit + not_missed)
    recall_not = _safe_ratio(not_hit, not_hit + not_missed)
    recall_off = _safe_ratio(off_hit, off_hit + off_missed)

    f1_not = _safe_ratio(2 * precision_not * recall_not, precision_not + recall_not)
    f1_off = _safe_ratio(2 * precision_off * recall_off, precision_off + recall_off)

    accuracy = _safe_ratio(not_hit + off_hit, not_hit + not_missed + off_hit + off_missed)
    mean_precision = (precision_not + precision_off) / 2
    mean_recall = (recall_not + recall_off) / 2
    mean_f1 = (f1_not + f1_off) / 2
    return mean_f1, accuracy, mean_precision, mean_recall


def EAttHBiLSTMRun(wordModel, numClasses, input_keep_prob_, output_keep_prob_, cross_deviation, iterations, cross_multiple=5):
    """Train the EAttHBiLSTM model and log validation metrics every 10 steps.

    Builds the graph with ``EAttHBiLSTMFrame``, minimises softmax
    cross-entropy plus a 1e-5 L2 penalty over all trainable variables with
    Adam, and loads the pretrained embeddings (all rows but the first) from
    ``wordModel``. Every 10th iteration a batch is fetched with
    ``verify=True``, scored with dropout disabled, and the metrics are
    printed and appended to ``../trained_model/log_eatthbilstm.txt``
    (relative to the current working directory). The metric code assumes a
    two-class one-hot label layout.

    Args:
        wordModel: object providing ``embedding``, ``emoji_embedding`` and
            ``getTrainBatch_eab(verify=..., cross_deviation=...)``, plus the
            attributes required by ``EAttHBiLSTMFrame``.
        numClasses: number of label columns.
        input_keep_prob_: value fed to the input keep-prob placeholder
            during training steps.
        output_keep_prob_: value fed to the output keep-prob placeholder
            during training steps.
        cross_deviation: forwarded to ``getTrainBatch_eab``; shown 1-based in
            the log (presumably the cross-validation fold index).
        iterations: number of training steps.
        cross_multiple: only used in the log output as the fold total.

    Returns:
        None.
    """
    log_path = "../trained_model/log_eatthbilstm.txt"
    eval_every = 10
    lr_decay = 1.0  # decay factor applied at each evaluation; 1.0 keeps the rate constant

    learning_rate, embedding_init, embedding_placeholder, input_data, emoji_embedding_init, emoji_embedding_placeholder, emoji_data, labels, input_keep_prob, output_keep_prob, prediction = EAttHBiLSTMFrame(wordModel, numClasses)
    tv = tf.trainable_variables()
    regularization_cost = 0.00001 * tf.reduce_sum([tf.nn.l2_loss(v) for v in tv])
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=labels)) + regularization_cost
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

    prediction_class = tf.argmax(prediction, 1)

    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.4
    sess = tf.InteractiveSession(config=config)
    try:
        sess.run(tf.global_variables_initializer())

        # Row 0 of each table is the fixed zero row, so only rows 1.. are loaded.
        sess.run(embedding_init, feed_dict={embedding_placeholder: wordModel.embedding[1:]})
        sess.run(emoji_embedding_init, feed_dict={emoji_embedding_placeholder: wordModel.emoji_embedding[1:]})

        # NOTE: checkpointing is currently disabled (see the commented-out
        # save call below); the saver is kept so it can be re-enabled.
        saver = tf.train.Saver(max_to_keep=1)

        max_macro_f1 = 0.
        learning_rate_ = 1e-4
        for i in range(iterations):
            nextBatchInputs, nextBatchEmojis, nextBatchLabels = wordModel.getTrainBatch_eab(verify=False, cross_deviation=cross_deviation)
            sess.run(optimizer, {input_data: nextBatchInputs, emoji_data: nextBatchEmojis, labels: nextBatchLabels, input_keep_prob: input_keep_prob_, output_keep_prob: output_keep_prob_, learning_rate: learning_rate_})
            if (i + 1) % eval_every != 0:
                continue

            learning_rate_ = learning_rate_ * lr_decay
            nextBatchInputs, nextBatchEmojis, nextBatchLabels = wordModel.getTrainBatch_eab(verify=True, cross_deviation=cross_deviation)
            # Keep probabilities of 1.0 disable dropout for evaluation.
            pc = sess.run(prediction_class, {input_data: nextBatchInputs, emoji_data: nextBatchEmojis, labels: nextBatchLabels, input_keep_prob: 1.0, output_keep_prob: 1.0})
            mean_macro_f1, mean_macro_accuracy, mean_macro_precision, mean_macro_recall = _binary_macro_metrics(nextBatchLabels, pc)

            print("deviation {}/{}   ".format(cross_deviation + 1, cross_multiple),
                  "iterations {}/{}   ".format(i + 1, iterations),
                  "F1 {}   ".format(mean_macro_f1),
                  "accuracy {}   ".format(mean_macro_accuracy),
                  "precision {}   ".format(mean_macro_precision),
                  "recall {}   ".format(mean_macro_recall))

            improved = mean_macro_f1 > max_macro_f1
            log_line = "deviation {}/{}   iterations {}/{}   F1 {}   accuracy {}   precision {}   recall {}\r\n".format(
                cross_deviation + 1, cross_multiple, i + 1, iterations,
                mean_macro_f1, mean_macro_accuracy, mean_macro_precision, mean_macro_recall)
            with open(log_path, "a", encoding="utf-8") as f:
                f.write(log_line)
                if improved:
                    f.write("update\r\n")

            if improved:
                print("update")
                # saver.save(sess, "../trained_model/eatthbilstm-" + str(cross_deviation + 1) + ".ckpt", global_step=i+1)
                max_macro_f1 = mean_macro_f1
    finally:
        # Release the session (and its GPU memory) so that a later call can
        # safely reset the default graph.
        sess.close()
