import tensorflow as tf
import model.BiLSTM
import model.Attention
import model.FullyConnectLayer_attbilstm
import os

def AttBiLSTMFrame(wordModel, numClasses):
    """Build the attention + stacked BiLSTM classification graph.

    The default TensorFlow graph is reset first, so any previously built
    graph is discarded.

    Args:
        wordModel: Object exposing ``vocab_size`` and ``numDimensions``; they
            define the shape of the (non-trainable) embedding table.
        numClasses: Width of the label placeholder and of the final layer.

    Returns:
        A 7-tuple ``(embedding_init, embedding_placeholder, input_data,
        labels, input_keep_prob, output_keep_prob, prediction)``.
        ``embedding_init`` is the assign op that copies
        ``embedding_placeholder`` into the embedding variable.
    """
    # Fixed hyper-parameters handed to the project-local model builders.
    hidden_units, stacked_layers = 100, 2
    attention_window = -1  # forwarded as-is to model_BiLSTM_stack; meaning is defined there
    attention_dim = 256
    lstm_forget_bias = 0.5

    tf.reset_default_graph()

    # Placeholders are created in the same order as before so that their
    # auto-generated graph names do not change.
    token_ids = tf.placeholder(tf.int32, [None, None])
    targets = tf.placeholder(tf.float32, [None, numClasses])
    keep_in = tf.placeholder(tf.float32)
    keep_out = tf.placeholder(tf.float32)

    # The embedding table is pinned to the CPU and filled at run time through
    # an assign op rather than being baked into the graph as a constant.
    with tf.device("/cpu:0"):
        table_shape = [wordModel.vocab_size, wordModel.numDimensions]
        table = tf.Variable(
            tf.constant(0.0, shape=table_shape),
            trainable=False,
            name="W",
        )
        table_feed = tf.placeholder(tf.float32, table_shape)
        load_table = table.assign(table_feed)
        embedded = tf.nn.embedding_lookup(table, token_ids)

    encoded = model.BiLSTM.model_BiLSTM_stack(
        embedded,
        hidden_units,
        stacked_layers,
        lstm_forget_bias,
        keep_in,
        keep_out,
        attention_window,
    )
    attended, _ = model.Attention.attention(encoded, attention_dim)
    print(attended.shape)
    logits = model.FullyConnectLayer_attbilstm.fullyConnectLayer(attended, hidden_units, numClasses)

    return load_table, table_feed, token_ids, targets, keep_in, keep_out, logits

def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when ``denominator`` is zero."""
    if not denominator:
        return 0.0
    return numerator / denominator


def _macro_scores(not_as_not, not_as_off, off_as_off, off_as_not):
    """Compute macro-averaged scores for the two-class confusion counts.

    Each argument is a count of validation samples named ``<gold>_as_<predicted>``.
    Any ratio whose denominator is zero is reported as 0.0 instead of raising.

    Returns:
        ``(f1, accuracy, precision, recall)`` where f1/precision/recall are
        the unweighted mean of the two per-class values.
    """
    # Precision divides by everything *predicted* as the class,
    # recall divides by everything that truly *is* the class.
    precision_not = _safe_ratio(not_as_not, not_as_not + off_as_not)
    recall_not = _safe_ratio(not_as_not, not_as_not + not_as_off)
    precision_off = _safe_ratio(off_as_off, off_as_off + not_as_off)
    recall_off = _safe_ratio(off_as_off, off_as_off + off_as_not)

    f1_not = _safe_ratio(2 * precision_not * recall_not, precision_not + recall_not)
    f1_off = _safe_ratio(2 * precision_off * recall_off, precision_off + recall_off)

    total = not_as_not + not_as_off + off_as_off + off_as_not
    accuracy = _safe_ratio(not_as_not + off_as_off, total)
    return (
        (f1_not + f1_off) / 2,
        accuracy,
        (precision_not + precision_off) / 2,
        (recall_not + recall_off) / 2,
    )


def AttBiLSTMRun(wordModel, numClasses, input_keep_prob_, output_keep_prob_, cross_deviation, epoch, cross_multiple=5, batchTotal=32, verTotal=8):
    """Train the attention BiLSTM for one cross-validation fold and log scores.

    After every epoch the model is evaluated on ``verTotal`` validation
    batches; macro F1, accuracy, macro precision and macro recall are printed
    and appended to ``../trained_model/log_attbilstm.txt``. A line ``update``
    is logged whenever the macro F1 exceeds the best value seen so far.

    Args:
        wordModel: Data provider. Must expose ``embedding`` (fed into the
            embedding table) and ``getTrainBatch(verify=, batch_index=,
            cross_deviation=)`` returning ``(inputs, one_hot_labels)``.
        numClasses: Number of output classes. The evaluation only counts
            samples whose gold label column 0 is 0 or 1 and whose predicted
            class index is 0 or 1.
        input_keep_prob_: Value fed to the input keep-prob placeholder while training.
        output_keep_prob_: Value fed to the output keep-prob placeholder while training.
        cross_deviation: Zero-based fold index, forwarded to ``getTrainBatch``.
        epoch: Number of training epochs.
        cross_multiple: Total number of folds (used for log messages only).
        batchTotal: Number of training batches per epoch.
        verTotal: Number of validation batches per epoch.

    Returns:
        None.
    """
    learning_rate = 1e-4
    log_path = "../trained_model/log_attbilstm.txt"

    embedding_init, embedding_placeholder, input_data, labels, input_keep_prob, output_keep_prob, prediction = AttBiLSTMFrame(wordModel, numClasses)
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=labels))
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss)

    prediction_class = tf.argmax(prediction, 1)

    # Fail before training starts (not after the first epoch) if the log
    # directory cannot be created.
    os.makedirs(os.path.dirname(log_path), exist_ok=True)

    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.4
    sess = tf.InteractiveSession(config=config)
    try:
        sess.run(tf.global_variables_initializer())

        # Load the pre-trained vectors into the non-trainable embedding table.
        sess.run(embedding_init, feed_dict={embedding_placeholder: wordModel.embedding})

        # Not used yet: checkpoint saving is currently disabled (see TODO below).
        saver = tf.train.Saver(max_to_keep=1)

        max_macro_f1 = 0.
        for i in range(epoch):
            for j in range(batchTotal):
                nextBatchInputs, nextBatchLabels = wordModel.getTrainBatch(verify=False, batch_index=j, cross_deviation=cross_deviation)
                sess.run(optimizer, {input_data: nextBatchInputs, labels: nextBatchLabels, input_keep_prob: input_keep_prob_, output_keep_prob: output_keep_prob_})

            # Confusion counts, named <gold>_as_<predicted>. Gold "not" is a
            # label whose column 0 is 0 (class index 1); gold "off" is a label
            # whose column 0 is 1 (class index 0).
            not_as_not, not_as_off, off_as_off, off_as_not = 0., 0., 0., 0.
            for j in range(verTotal):
                nextBatchInputs, nextBatchLabels = wordModel.getTrainBatch(verify=True, batch_index=j, cross_deviation=cross_deviation)
                # Dropout is disabled (keep probability 1.0) for evaluation.
                pc = sess.run(prediction_class, {input_data: nextBatchInputs, labels: nextBatchLabels, input_keep_prob: 1.0, output_keep_prob: 1.0})
                for gold, predicted in zip(nextBatchLabels, pc):
                    if gold[0] == 0 and predicted == 1:
                        not_as_not += 1
                    elif gold[0] == 0 and predicted == 0:
                        not_as_off += 1
                    elif gold[0] == 1 and predicted == 0:
                        off_as_off += 1
                    elif gold[0] == 1 and predicted == 1:
                        off_as_not += 1

            mean_macro_f1, mean_macro_accuracy, mean_macro_precision, mean_macro_recall = _macro_scores(
                not_as_not, not_as_off, off_as_off, off_as_not)

            print("deviation {}/{}   ".format(cross_deviation + 1, cross_multiple),
                  "epoch {}/{}   ".format(i + 1, epoch),
                  "F1 {}   ".format(mean_macro_f1),
                  "accuracy {}   ".format(mean_macro_accuracy),
                  "precision {}   ".format(mean_macro_precision),
                  "recall {}   ".format(mean_macro_recall))
            log_lines = [
                "deviation " + str(cross_deviation + 1) + "/" + str(cross_multiple) +
                "   epoch " + str(i + 1) + "/" + str(epoch) +
                "   F1 " + str(mean_macro_f1) +
                "   accuracy " + str(mean_macro_accuracy) +
                "   precision " + str(mean_macro_precision) +
                "   recall " + str(mean_macro_recall) + "\r\n"
            ]

            if mean_macro_f1 > max_macro_f1:
                print("update")
                log_lines.append("update\r\n")
                max_macro_f1 = mean_macro_f1
                # TODO: checkpointing is disabled. To re-enable, save here with
                # saver.save(sess, "../trained_model/attbilstm-" + str(cross_deviation + 1) + ".ckpt", global_step=i + 1)
                # and log the returned path.

            # The context manager closes the file; no explicit close() needed.
            with open(log_path, "a", encoding="utf-8") as f:
                f.writelines(log_lines)
    finally:
        # Release the session so the next fold can safely reset the default
        # graph; an InteractiveSession otherwise stays installed as default.
        sess.close()
