import tensorflow as tf
import numpy as np
import model.BiLSTM
import model.Attention
import model.FullyConnectLayer_attbilstm
import os

def AttBiLSTMFrame(wordModel, numClasses):
    """Build the attention-BiLSTM classification graph on a fresh default graph.

    The default TensorFlow graph is reset first, so any previously built
    graph is discarded.

    Args:
        wordModel: object exposing ``vocab_size`` and ``numDimensions``; these
            two attributes fix the shape of the (non-trainable) embedding
            table.
        numClasses: width of the label placeholder and of the final fully
            connected layer.

    Returns:
        A 7-tuple ``(embedding_init, embedding_placeholder, input_data,
        labels, input_keep_prob, output_keep_prob, prediction)``:
        the op that copies ``embedding_placeholder`` into the embedding
        table, the placeholder feeding that op, the int32 input placeholder,
        the float32 label placeholder, the two keep-probability placeholders
        and the output tensor of the fully connected layer.
    """
    # Fixed hyper-parameters handed to the sub-models below.
    hidden_units = 60
    attn_length = -1  # NOTE(review): meaning is defined by model_BiLSTM_stack - confirm there
    attention_size = 100
    num_layers = 1
    forget_bias = 0.5

    tf.reset_default_graph()

    # Placeholders fed at run time.
    # presumably input_data is (batch, sequence) of vocabulary indices - TODO confirm
    input_data = tf.placeholder(tf.int32, [None, None])
    labels = tf.placeholder(tf.float32, [None, numClasses])
    input_keep_prob = tf.placeholder(tf.float32)
    output_keep_prob = tf.placeholder(tf.float32)

    embedding_shape = [wordModel.vocab_size, wordModel.numDimensions]
    with tf.device("/cpu:0"):
        # The table starts as zeros and is filled by running `embedding_init`
        # with `embedding_placeholder` fed; it is excluded from training.
        zero_table = tf.constant(0.0, shape=embedding_shape)
        embedding_table = tf.Variable(zero_table, trainable=False, name="W")
        embedding_placeholder = tf.placeholder(tf.float32, embedding_shape)
        embedding_init = embedding_table.assign(embedding_placeholder)
        embedding_input = tf.nn.embedding_lookup(embedding_table, input_data)

    # Embeddings -> stacked BiLSTM -> attention -> fully connected layer.
    value = model.BiLSTM.model_BiLSTM_stack(
        embedding_input,
        hidden_units,
        num_layers,
        forget_bias,
        input_keep_prob,
        output_keep_prob,
        attn_length,
    )
    attout, _ = model.Attention.attention(value, attention_size)
    feature_width = attout.shape[1].value
    prediction = model.FullyConnectLayer_attbilstm.fullyConnectLayer(
        attout, feature_width, numClasses
    )

    return (
        embedding_init,
        embedding_placeholder,
        input_data,
        labels,
        input_keep_prob,
        output_keep_prob,
        prediction,
    )

def _count_binary_outcomes(batch_labels, predicted):
    """Tally the four (true class, predicted class) cells for one batch.

    The true class is read from the first label column: ``0`` there is
    counted as "not" (matched against predicted index 1) and ``1`` as "off"
    (matched against predicted index 0).

    Returns:
        ``(Tnot, Fnot, Toff, Foff)`` as floats, where ``Tnot``/``Toff`` are
        correctly predicted "not"/"off" rows, ``Fnot`` are "not" rows
        predicted as index 0 and ``Foff`` are "off" rows predicted as index 1.
    """
    predicted = np.asarray(predicted)
    if len(predicted) == 0:
        return 0., 0., 0., 0.
    first_col = np.asarray(batch_labels)[:len(predicted), 0]
    truly_not = first_col == 0
    truly_off = first_col == 1
    Tnot = float(np.count_nonzero(truly_not & (predicted == 1)))
    Fnot = float(np.count_nonzero(truly_not & (predicted == 0)))
    Toff = float(np.count_nonzero(truly_off & (predicted == 0)))
    Foff = float(np.count_nonzero(truly_off & (predicted == 1)))
    return Tnot, Fnot, Toff, Foff


def _macro_scores(Tnot, Fnot, Toff, Foff):
    """Return ``(f1, accuracy, precision, recall)`` macro-averaged over both classes.

    Requires ``Tnot > 0`` and ``Toff > 0``; that is sufficient for every
    denominator below to be non-zero.
    """
    # Precision divides by everything *predicted* as the class, recall by
    # everything that *truly is* the class.
    precision_not = Tnot / (Tnot + Foff)
    recall_not = Tnot / (Tnot + Fnot)
    precision_off = Toff / (Toff + Fnot)
    recall_off = Toff / (Toff + Foff)
    f1_not = (2 * precision_not * recall_not) / (precision_not + recall_not)
    f1_off = (2 * precision_off * recall_off) / (precision_off + recall_off)
    accuracy = (Tnot + Toff) / (Tnot + Fnot + Toff + Foff)
    return (
        (f1_not + f1_off) / 2,
        accuracy,
        (precision_not + precision_off) / 2,
        (recall_not + recall_off) / 2,
    )


def _append_log(path, text):
    """Append ``text`` to the UTF-8 log file at ``path``."""
    with open(path, "a", encoding="utf-8") as f:
        f.write(text)


def AttBiLSTMRun(wordModel, numClasses, input_keep_prob_, output_keep_prob_, cross_deviation, iterations, cross_multiple=5):
    """Train the attention-BiLSTM model and track the best macro-F1.

    Builds the graph with ``AttBiLSTMFrame``, minimises softmax cross-entropy
    with Adam (learning rate 1e-4) for ``iterations`` steps, and every 10th
    step scores one verification batch. Scores are printed and appended to
    ``../trained_model/log_attbilstm.txt``.

    Args:
        wordModel: data provider. Must expose ``embedding`` (fed into the
            embedding table) and ``getTrainBatch2(verify=..., cross_deviation=...)``
            returning ``(inputs, labels)``; ``verify=False`` is used for
            training batches and ``verify=True`` for scoring batches.
        numClasses: number of output classes. The scoring code only looks at
            the first label column and predicted indices 0/1.
        input_keep_prob_: value fed to the input keep-probability placeholder
            while training (1.0 is fed while scoring).
        output_keep_prob_: value fed to the output keep-probability
            placeholder while training (1.0 is fed while scoring).
        cross_deviation: forwarded to ``getTrainBatch2``; logged as
            ``cross_deviation + 1`` out of ``cross_multiple`` (presumably the
            cross-validation fold index - confirm against the caller).
        iterations: number of training steps.
        cross_multiple: only used in log output.

    Returns:
        ``(bestret, max_macro_f1)``: the raw prediction output of the
        best-scoring verification batch and its macro-F1. If no batch could
        be scored, ``bestret`` is the int ``0`` and ``max_macro_f1`` is ``0.``.
    """
    log_path = "../trained_model/log_attbilstm.txt"

    (embedding_init, embedding_placeholder, input_data, labels,
     input_keep_prob, output_keep_prob, prediction) = AttBiLSTMFrame(wordModel, numClasses)
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=labels))

    learning_rate = 1e-4
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss)

    os.environ["CUDA_VISIBLE_DEVICES"] = "2"
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.3
    sess = tf.InteractiveSession(config=config)

    max_macro_f1 = 0.
    bestret = 0
    try:
        sess.run(tf.global_variables_initializer())
        sess.run(embedding_init, feed_dict={embedding_placeholder: wordModel.embedding})

        for i in range(iterations):
            nextBatchInputs, nextBatchLabels = wordModel.getTrainBatch2(verify=False, cross_deviation=cross_deviation)
            sess.run(optimizer, {input_data: nextBatchInputs, labels: nextBatchLabels, input_keep_prob: input_keep_prob_, output_keep_prob: output_keep_prob_})
            if i % 10 != 0:
                continue

            testBatchInputs, testBatchLabels = wordModel.getTrainBatch2(verify=True, cross_deviation=cross_deviation)
            retpc = sess.run(prediction, {input_data: testBatchInputs, labels: testBatchLabels, input_keep_prob: 1.0, output_keep_prob: 1.0})
            pc = np.argmax(retpc, 1)
            Tnot, Fnot, Toff, Foff = _count_binary_outcomes(testBatchLabels, pc)

            # With no correct prediction for a class its F1 is 0/0, so the
            # batch cannot be scored. Zero *error* counts are fine: an
            # error-free batch must still be scored.
            if Tnot == 0 or Toff == 0:
                print("Tnot:" + str(Tnot) + "  Fnot:" + str(Fnot) + "  Toff:" + str(Toff) + "   Foff:" + str(Foff))
                continue

            (mean_macro_f1, mean_macro_accuracy,
             mean_macro_precision, mean_macro_recall) = _macro_scores(Tnot, Fnot, Toff, Foff)

            print("deviation {}/{}   ".format(cross_deviation + 1, cross_multiple),
                  "iterations {}/{}   ".format(i + 1, iterations),
                  "F1 {}   ".format(mean_macro_f1),
                  "accuracy {}   ".format(mean_macro_accuracy),
                  "precision {}   ".format(mean_macro_precision),
                  "recall {}   ".format(mean_macro_recall))
            _append_log(
                log_path,
                "deviation {}/{}   iterations {}/{}   F1 {}   accuracy {}   precision {}   recall {}\r\n".format(
                    cross_deviation + 1, cross_multiple,
                    i + 1, iterations,
                    mean_macro_f1, mean_macro_accuracy,
                    mean_macro_precision, mean_macro_recall))

            if mean_macro_f1 > max_macro_f1:
                print("update")
                _append_log(log_path, "update\r\n")
                max_macro_f1 = mean_macro_f1
                bestret = retpc
    finally:
        # Release the session (and its GPU memory) even if training fails;
        # the next call resets the default graph, which must not happen
        # while a session is still active.
        sess.close()
    return bestret, max_macro_f1
        # if mean_acc > max_acc:
        #     save_path = saver.save(sess, "../trained_model/attbilstm-" + str(cross_deviation + 1) + ".ckpt", global_step=i+1)
        #     print("saved to %s" % save_path)
        #     with open("../trained_model/log_attbilstm.txt", "a", encoding="utf-8") as f:
        #         f.write("saved to../trained_model/attbilstm-" + str(cross_deviation + 1) + ".ckpt-" + str(i + 1) + "\r\n")
        #         f.close()
        #     max_acc = mean_acc
