import tensorflow as tf
import model.RCNN
import model.FullyConnectLayer_cnn
import os

def RCNNFrame(wordModel, batchSize, numClasses):
    """Build a fresh TF1 graph for the RCNN text classifier.

    The default graph is reset first, so every call starts from an empty graph.

    Args:
        wordModel: object exposing ``vocab_size`` and ``numDimensions``; these
            size the embedding table.
        batchSize: fixed batch dimension of the input and label placeholders.
        numClasses: number of output classes (width of the label placeholder).

    Returns:
        An 8-tuple, in this order:
        ``(embedding_init, embedding_placeholder, input_data, labels,
        input_keep_prob, output_keep_prob, keepratio_cnn, prediction)``.
        - ``embedding_init`` is an assign op. It copies the value fed to
          ``embedding_placeholder`` into the frozen embedding variable ``W``.
        - ``input_data`` is an int32 placeholder of shape
          ``[batchSize, None]``.
        - ``labels`` is a float32 placeholder of shape
          ``[batchSize, numClasses]``.
        - The three ``*keep*`` entries are scalar float32 placeholders.
          Presumably they are dropout keep probabilities (RCNNRun feeds 1.0 to
          all three when evaluating).
        - ``prediction`` is the classifier head output. RCNNRun treats it as
          unnormalised logits.
    """
    # Fixed hyper-parameters forwarded to the recurrent/convolutional stack.
    lstm_units, num_filters, lstm_forget_bias = 256, 100, 0.5

    tf.reset_default_graph()

    # Graph inputs. Creation order is kept stable so auto-generated op names
    # ("Placeholder", "Placeholder_1", ...) stay the same.
    token_ids = tf.placeholder(tf.int32, [batchSize, None])
    target_labels = tf.placeholder(tf.float32, [batchSize, numClasses])
    cnn_keep = tf.placeholder(tf.float32)
    rnn_in_keep = tf.placeholder(tf.float32)
    rnn_out_keep = tf.placeholder(tf.float32)

    # The embedding table lives on the CPU. It starts as all zeros and is not
    # trainable; real vectors must be pushed in through `table_feed`/`table_init`.
    with tf.device("/cpu:0"):
        table_shape = [wordModel.vocab_size, wordModel.numDimensions]
        table = tf.Variable(tf.constant(0.0, shape=table_shape), trainable=False, name="W")
        table_feed = tf.placeholder(tf.float32, table_shape)
        table_init = table.assign(table_feed)
        embedded = tf.nn.embedding_lookup(table, token_ids)

    features = model.RCNN.model_RCNN(embedded, lstm_units, num_filters, lstm_forget_bias,
                                     rnn_in_keep, rnn_out_keep, cnn_keep)
    logits = model.FullyConnectLayer_cnn.fullyConnectLayer(features, num_filters, numClasses)

    return (table_init, table_feed, token_ids, target_labels,
            rnn_in_keep, rnn_out_keep, cnn_keep, logits)

def _append_log(path, text):
    """Append ``text`` verbatim to the UTF-8 text file at ``path``.

    ``newline=""`` disables newline translation. The CRLF terminators written by
    callers therefore land as-is on every platform; in default text mode, Windows
    would turn each CRLF into CR CR LF.
    """
    with open(path, "a", encoding="utf-8", newline="") as f:
        f.write(text)


def RCNNRun(wordModel, batchSize, numClasses, input_keep_prob_, output_keep_prob_, keepratio_cnn_,
            cross_deviation, epoch, cross_multiple=5, batchTotal=32, verTotal=8,
            learning_rate=1e-4, embedding_matrix=None, save_best=False):
    """Train and validate the RCNN classifier on one cross-validation fold.

    Builds a fresh graph with ``RCNNFrame`` and minimises mean softmax
    cross-entropy with Adam. After every epoch it prints the mean validation
    accuracy and appends it to ``../trained_model/log_rcnn.txt``.

    Args:
        wordModel: data source exposing ``vocab_size``, ``numDimensions`` and
            ``getTrainBatch(verify=..., batch_index=..., cross_deviation=...)``,
            which returns an ``(inputs, labels)`` pair fed to the placeholders
            built by ``RCNNFrame``.
        batchSize: fixed batch dimension of the input/label placeholders.
        numClasses: number of output classes.
        input_keep_prob_, output_keep_prob_, keepratio_cnn_: values fed to the
            matching keep-probability placeholders while training. Validation
            always feeds 1.0.
        cross_deviation: zero-based fold index. It is forwarded to
            ``getTrainBatch`` and used (1-based) in log lines and checkpoint names.
        epoch: number of training epochs.
        cross_multiple: total number of folds; only used in log messages.
        batchTotal: training batches per epoch.
        verTotal: validation batches averaged per epoch.
        learning_rate: Adam learning rate (default 1e-4).
        embedding_matrix: array of shape
            ``[wordModel.vocab_size, wordModel.numDimensions]`` loaded into the
            frozen embedding table ``W``. If ``None``, ``W`` keeps its all-zero
            initial value and a ``RuntimeWarning`` is issued.
        save_best: if true, write a checkpoint to
            ``../trained_model/rcnn-<fold>.ckpt-<epoch>`` whenever validation
            accuracy improves. Only the most recent checkpoint is kept.
    """
    import warnings

    log_path = "../trained_model/log_rcnn.txt"

    # CUDA reads this when its runtime initialises, so set it before any graph/session work.
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

    (embedding_init, embedding_placeholder, input_data, labels,
     input_keep_prob, output_keep_prob, keepratio_cnn, prediction) = RCNNFrame(wordModel, batchSize, numClasses)
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=labels))
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss)

    correctPred = tf.equal(tf.argmax(prediction, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correctPred, tf.float32))

    saver = tf.train.Saver(max_to_keep=1) if save_best else None

    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.4
    sess = tf.InteractiveSession(config=config)
    # Always close the session. Otherwise an InteractiveSession stays installed as
    # the default session, and each later fold keeps the previous graph's
    # variables (and their memory) alive.
    try:
        sess.run(tf.global_variables_initializer())
        # W is created as a zero constant with trainable=False. Unless the real
        # vectors are assigned here, every embedding lookup returns zeros.
        if embedding_matrix is not None:
            sess.run(embedding_init, {embedding_placeholder: embedding_matrix})
        else:
            warnings.warn("RCNNRun: no embedding_matrix given; the embedding table W stays all zeros",
                          RuntimeWarning)

        train_keep = {input_keep_prob: input_keep_prob_,
                      output_keep_prob: output_keep_prob_,
                      keepratio_cnn: keepratio_cnn_}
        eval_keep = {input_keep_prob: 1.0, output_keep_prob: 1.0, keepratio_cnn: 1.0}

        max_acc = 0.
        for i in range(epoch):
            for j in range(batchTotal):
                nextBatchInputs, nextBatchLabels = wordModel.getTrainBatch(
                    verify=False, batch_index=j, cross_deviation=cross_deviation)
                feed = {input_data: nextBatchInputs, labels: nextBatchLabels}
                feed.update(train_keep)
                sess.run(optimizer, feed)

            mean_acc = 0.
            for j in range(verTotal):
                nextBatchInputs, nextBatchLabels = wordModel.getTrainBatch(
                    verify=True, batch_index=j, cross_deviation=cross_deviation)
                feed = {input_data: nextBatchInputs, labels: nextBatchLabels}
                feed.update(eval_keep)
                mean_acc += sess.run(accuracy, feed)
            # A plain mean of per-batch accuracies is exact: the placeholders fix every batch at batchSize rows.
            mean_acc /= verTotal

            print("deviation {}/{}   ".format(cross_deviation + 1, cross_multiple),
                  "epoch {}/{}   ".format(i + 1, epoch),
                  "accuracy {}   ".format(mean_acc))
            _append_log(log_path,
                        "deviation " + str(cross_deviation + 1) + "/" + str(cross_multiple) +
                        "   epoch " + str(i + 1) + "/" + str(epoch) +
                        "   accuracy " + str(mean_acc) + "...\r\n")

            if save_best and mean_acc > max_acc:
                save_path = saver.save(sess, "../trained_model/rcnn-" + str(cross_deviation + 1) + ".ckpt",
                                       global_step=i + 1)
                print("saved to %s" % save_path)
                _append_log(log_path, "saved to " + save_path + "\r\n")
                max_acc = mean_acc
    finally:
        sess.close()
