# coding: utf-8
import os, sys
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import numpy as np
import time, pickle
import datetime
import data_helpers
import tensorflow as tf
import pickle
from tensorflow.contrib import learn
import logging
from collections import defaultdict

from collections import Counter
from scipy.stats import wasserstein_distance

# Model hyper-parameters
tf.flags.DEFINE_string("word2vec", 'glove/glove.840B.300d.txt',
                       "Pre-trained word-embedding file; an empty string disables pre-trained initialization")
tf.flags.DEFINE_integer("embedding_dim", 300, "Dimensionality of word embedding")
tf.flags.DEFINE_integer("num_units", 256, "Number of filters per filter size ")
tf.flags.DEFINE_float("dropout_rate", 0.5, "Dropout keep probability (default: 0.5)")
tf.flags.DEFINE_float("l2_reg_lambda", 0.0, "L2 regularization lambda (default: 0.0)")
tf.flags.DEFINE_integer("max_document_length", 50, "length of the padded sentence")
tf.flags.DEFINE_integer("dense_units", 128, "the last layer of the CNN encoder")

# Training parameters
tf.flags.DEFINE_integer("batch_size_s", 2300, "Source batch size")
tf.flags.DEFINE_integer("batch_size_t", 100, "Target batch size")
tf.flags.DEFINE_integer("num_epochs", 8, "Number of training epochs")

# Session / schedule parameters
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")
tf.flags.DEFINE_integer("base_model_pretrain", 50, "Pretrain the base model for i iterations")
tf.flags.DEFINE_integer("RL_pretrain", 0, "Pretrain the RL model for i iterations after base model pretrain")
tf.flags.DEFINE_boolean("reinforced_data_selector", True,
                        "whether to use the reinforced source data selector")

# Domain critic / data selector parameters
tf.flags.DEFINE_integer("wd_num", 5, "train number for domain critic")
tf.flags.DEFINE_integer("critic_units", 128, "units of the hidden layer in domain critic")
tf.flags.DEFINE_integer("rl_units", 128, "units of the hidden layer in data selector")
tf.flags.DEFINE_float("l2_param", 1e-4, "l2_param for domain critic")
tf.flags.DEFINE_float("wd_param", 0.1, "wd_param for domain critic")
tf.flags.DEFINE_integer("gp_param", 10, "gp_param for domain critic")
# tf.flags is the same module as tf.app.flags in TF1; use one spelling throughout.
tf.flags.DEFINE_float("reward_decay", 0.8, "reward_decay")
tf.flags.DEFINE_float("coef", 1, "reward coefficient")

FLAGS = tf.flags.FLAGS
# Read from the flag so that --reinforced_data_selector=false really disables the selector
# (this used to be hard-coded to True, making the flag a no-op).
reinforced_data_selector = FLAGS.reinforced_data_selector

print("Loading source domain data...")
x1_s_str, x2_s_str, y_s = data_helpers.load_mnli('data/mnli_train.txt')
# Target domain (SciTail) train / dev / test splits.
x1_t_train_str, x2_t_train_str, y_t_train = data_helpers.load_data_and_labels_tsv('data/scitail_1.0_train.tsv')
x1_t_val, x2_t_val, y_t_val = data_helpers.load_data_and_labels_tsv('data/scitail_1.0_dev.tsv')
x1_t_test, x2_t_test, y_t_test = data_helpers.load_data_and_labels_tsv('data/scitail_1.0_test.tsv')

lent = len(x1_t_train_str)

# Alternative dataset pair (Quora -> CIKM), kept for reference:
# x1_s, x2_s, y_s = data_helpers.load_data_and_labels_csv('data/quora_train.csv')
# x1_t_train, x2_t_train, y_t_train = data_helpers.load_data_and_labels_nli('data/cikm_train.txt')
# x1_t_val, x2_t_val, y_t_val = data_helpers.load_data_and_labels_nli('data/cikm_validation.txt')
# x1_t_test, x2_t_test, y_t_test = data_helpers.load_data_and_labels_nli('data/cikm_test.txt')

# A single vocabulary fitted on every split, so source and target sentences share word ids.
vocab_processor = learn.preprocessing.VocabularyProcessor(FLAGS.max_document_length)
vocab_processor.fit(x1_s_str + x2_s_str + x1_t_train_str + x2_t_train_str + x1_t_val + x2_t_val + x1_t_test + x2_t_test)
# Map every sentence to an id sequence padded/truncated to max_document_length.
x1_s = np.array(list(vocab_processor.transform(x1_s_str)))
x2_s = np.array(list(vocab_processor.transform(x2_s_str)))
x1_t_train = np.array(list(vocab_processor.transform(x1_t_train_str)))
x2_t_train = np.array(list(vocab_processor.transform(x2_t_train_str)))
x1_t_val = np.array(list(vocab_processor.transform(x1_t_val)))
x2_t_val = np.array(list(vocab_processor.transform(x2_t_val)))
x1_t_test = np.array(list(vocab_processor.transform(x1_t_test)))
x2_t_test = np.array(list(vocab_processor.transform(x2_t_test)))

print(x1_s.shape, y_s.shape)
print(x1_t_train.shape, x1_t_val.shape, x1_t_test.shape, y_t_train.shape, y_t_val.shape, y_t_test.shape)

if FLAGS.word2vec:
    # Rows for words without a pre-trained vector keep this small random initialization.
    initW = np.random.uniform(-0.25, 0.25, (len(vocab_processor.vocabulary_), FLAGS.embedding_dim))
    # The vectors are read from a pickled {word: vector} dict that sits next to the --word2vec
    # file and has the same stem (the default flag resolves to glove/glove.840B.300d.pkl).
    glove_pkl = os.path.splitext(FLAGS.word2vec)[0] + '.pkl'
    print("Load glove file {}\n".format(glove_pkl))
    # NOTE(review): pickle.load can execute arbitrary code; only point this at trusted local files.
    with open(glove_pkl, 'rb') as fp:
        embeddings_index = pickle.load(fp)
    print('Found %s word vectors.' % len(embeddings_index))
    for word, vector in embeddings_index.items():
        idx = vocab_processor.vocabulary_.get(word)
        # The fitted vocabulary returns 0 for unknown words (id 0 is also the padding/UNK id
        # in tf.contrib.learn's vocabulary -- verify), so those are skipped.
        if idx != 0:
            initW[idx] = vector
from shared_dam_classifier import SharedNN_DAM
from immi_ag import Agent

with tf.Graph().as_default():
    # Apply the device-placement flags; they were defined but never passed to the session.
    session_conf = tf.ConfigProto(
        allow_soft_placement=FLAGS.allow_soft_placement,
        log_device_placement=FLAGS.log_device_placement)
    sess = tf.Session(config=session_conf)
    with sess.as_default():
        # Shared model with a source head (loss1/accuracy1, inputs x_a/x_b/y1) and a target head
        # (loss2/accuracy2/auc2, inputs x_c/x_d/y2), plus a domain critic (all_wd_loss over theta_D).
        shared_nn = SharedNN_DAM(
            sequence_length=FLAGS.max_document_length,
            num_classes=2,
            vocab_size=len(vocab_processor.vocabulary_),
            embedding_size=FLAGS.embedding_dim,
            num_units=FLAGS.num_units,
            critic_units=FLAGS.critic_units,
            l2_param=FLAGS.l2_param,
            wd_param=FLAGS.wd_param,
            gp_param=FLAGS.gp_param,
            l2_reg_lambda=FLAGS.l2_reg_lambda)

        # Data-selector policy: one keep/drop action per source sample. Its observation row is
        # [o1 | losses1 | losses2 | prob2] (built in the training loop); n_features presumably
        # assumes o1 has 4*num_units columns and prob2 has 2 -- TODO confirm against SharedNN_DAM.
        LR_A = 0.001
        agent = Agent(sess, n_features=4*FLAGS.num_units+4, n_actions=2, n_units=FLAGS.rl_units,lr=LR_A)
        global_step = tf.Variable(0, name="global_step", trainable=False)
        # Lets the loop advance global_step when the selector drops an entire source batch.
        increment_global_step_op = tf.assign(global_step, global_step + 1)
        # train_op1 (source head) and train_op2 (target head) share one Adam instance and
        # both advance global_step.
        optimizer = tf.train.AdamOptimizer(1e-3)
        grads_and_vars1 = optimizer.compute_gradients(shared_nn.loss1)
        train_op1 = optimizer.apply_gradients(grads_and_vars1, global_step=global_step)
        grads_and_vars2 = optimizer.compute_gradients(shared_nn.loss2)
        train_op2 = optimizer.apply_gradients(grads_and_vars2, global_step=global_step)

        timestamp = str(int(time.time()))
        out_dir = os.path.abspath(os.path.join('output/', "runs", timestamp))
        os.makedirs(out_dir)
        print("Writing to {}\n".format(out_dir))

        loss1_summary = tf.summary.scalar("loss1", shared_nn.loss1)
        loss2_summary = tf.summary.scalar("loss2", shared_nn.loss2)
        loss_summary = tf.summary.scalar("total_loss", shared_nn.total_loss)
        acc_summary1 = tf.summary.scalar("accuracy1", shared_nn.accuracy1)
        acc_summary2 = tf.summary.scalar("accuracy2", shared_nn.accuracy2)
        auc_summary2 = tf.summary.scalar("auc2", shared_nn.auc2)

        # Critic step (maximize the critic objective over theta_D) and the joint generator step
        # on total_loss over theta_G; neither advances global_step.
        wd_d_op = tf.train.AdamOptimizer(1e-4).minimize(shared_nn.all_wd_loss, var_list=shared_nn.theta_D)
        train_op = tf.train.AdamOptimizer(1e-4).minimize(shared_nn.total_loss,var_list=shared_nn.theta_G)
        # Train Summaries
        train_summary_op1 = tf.summary.merge([loss1_summary, acc_summary1])
        train_summary_op2 = tf.summary.merge([loss2_summary, acc_summary2])
        train_summary_op = tf.summary.merge([loss_summary])
        train_summary_dir = os.path.join(out_dir, "summaries", "train")
        train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

        # Dev summaries
        dev_summary_op = tf.summary.merge([loss2_summary, acc_summary2, auc_summary2])
        dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
        dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)

        # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
        checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
        checkpoint_prefix = os.path.join(checkpoint_dir, "model")
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        saver = tf.train.Saver(tf.global_variables())

        # Write vocabulary
        vocab_processor.save(os.path.join(out_dir, "vocab"))

        # Initialize all variables; local variables back the streaming auc2 metric.
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        # Overwrite the embedding matrix with the GloVe-initialized one built above.
        if FLAGS.word2vec:
            sess.run(shared_nn.W.assign(initW))
        def _nearest_mean_reward(src_rep, tgt_rep):
            """Score each source sample by how close its mean activation is to some target sample.

            Args:
                src_rep: 2-D array of source representations, one row per source sample.
                tgt_rep: 2-D array of target representations, one row per target sample.

            Returns:
                (rewards, selected): ``rewards[i]`` is ``|mean(src_rep) - mean(tgt_rep)|`` minus
                ``min(max(src_rep) - min(tgt_rep), min_j |rowmean_i - rowmean_j|)``; ``selected``
                is the list of indices whose reward is strictly positive.
            """
            src_rep = np.asarray(src_rep)
            tgt_rep = np.asarray(tgt_rep)
            if src_rep.shape[0] == 0:
                return np.array([]), []
            gap = abs(np.mean(src_rep) - np.mean(tgt_rep))
            # Upper bound on the nearest-neighbour distance (kept exactly as before, even if negative).
            cap = np.max(src_rep) - np.min(tgt_rep)
            # Row means are compared in float64, matching the former Python-float loop.
            src_means = np.mean(src_rep, axis=1).astype(np.float64)
            tgt_means = np.mean(tgt_rep, axis=1).astype(np.float64)
            # (n_src, n_tgt) pairwise distances, then the closest target for every source row.
            nearest = np.abs(src_means[:, None] - tgt_means[None, :]).min(axis=1)
            nearest = np.minimum(nearest, cap)
            rewards = gap - nearest
            selected = np.flatnonzero(rewards > 0).tolist()
            return rewards, selected

        def train_step(x1_s_batch, x2_s_batch, y_s_batch, x1_t_batch, x2_t_batch, y_t_batch):
            """Run one joint update (train_op: total_loss over theta_G) on a source+target batch.

            Also logs the total-loss summary.

            Returns:
                (rewards, selected) from ``_nearest_mean_reward`` applied to the source (o1) and
                target (o2) representations produced in the same step.
            """
            feed_dict = {
                shared_nn.input_x_a: x1_s_batch,
                shared_nn.input_x_b: x2_s_batch,
                shared_nn.input_y1: y_s_batch,
                shared_nn.input_x_c: x1_t_batch,
                shared_nn.input_x_d: x2_t_batch,
                shared_nn.input_y2: y_t_batch,
                shared_nn.dropout_rate: FLAGS.dropout_rate
            }
            _, step, summaries, src_rep, tgt_rep = sess.run(
                [train_op, global_step, train_summary_op, shared_nn.o1, shared_nn.o2],
                feed_dict)
            train_summary_writer.add_summary(summaries, step)
            return _nearest_mean_reward(src_rep, tgt_rep)


        def train_step_src(x1_s_batch, x2_s_batch, y_s_batch):
            """Apply one source-head update (train_op1 on loss1) with dropout enabled.

            Logs the loss1/accuracy1 summary at the current global step.

            Returns:
                The source representation ``shared_nn.o1`` computed during the update.
            """
            fetches = [train_op1, global_step, train_summary_op1,
                       shared_nn.loss1, shared_nn.accuracy1, shared_nn.o1]
            results = sess.run(fetches, {
                shared_nn.input_x_a: x1_s_batch,
                shared_nn.input_x_b: x2_s_batch,
                shared_nn.input_y1: y_s_batch,
                shared_nn.dropout_rate: FLAGS.dropout_rate,
            })
            step, summaries, representation = results[1], results[2], results[-1]
            train_summary_writer.add_summary(summaries, step)
            return representation


        def train_step_tgt(x1_t_batch, x2_t_batch, y_t_batch):
            """
            A single training step for the target model (target head, train_op2 on loss2).

            The batch is first streamed into the auc2 metric (with dropout enabled); then the
            update is applied and the loss2/accuracy2 summary is logged. Finally the AUC
            accumulators are reset so later evaluations start from a clean state.
            """
            feed_dict = {
                shared_nn.input_x_c: x1_t_batch,
                shared_nn.input_x_d: x2_t_batch,
                shared_nn.input_y2: y_t_batch,
                shared_nn.dropout_rate: FLAGS.dropout_rate
            }
            sess.run(shared_nn.update_auc_op, feed_dict)
            _, step, summaries = sess.run(
                [train_op2, global_step, train_summary_op2],
                feed_dict)
            train_summary_writer.add_summary(summaries, step)
            sess.run(running_vars_initializer)


        def train_step_return_train_loss(x1_s_batch, x2_s_batch, y_s_batch):
            """Evaluate per-example source-head losses (losses1) without updating any weights.

            Dropout is fed at FLAGS.dropout_rate, i.e. the same setting used for training.

            Returns:
                A one-element list holding the losses1 array (sess.run was given a list).
            """
            source_view = {
                shared_nn.input_x_a: x1_s_batch,
                shared_nn.input_x_b: x2_s_batch,
                shared_nn.input_y1: y_s_batch,
                shared_nn.dropout_rate: FLAGS.dropout_rate,
            }
            return sess.run([shared_nn.losses1], source_view)


        def dev_step(x1_t_batch, x2_t_batch, y_t_batch, writer=None):
            """Evaluate the target head on a labelled batch with dropout disabled.

            Streams the batch into the auc2 metric, reads loss2/accuracy2/auc2, optionally
            writes the dev summary to ``writer``, then resets the AUC accumulators.

            Returns:
                (accuracy2, auc2) for the batch.
            """
            eval_feed = {
                shared_nn.input_x_c: x1_t_batch,
                shared_nn.input_x_d: x2_t_batch,
                shared_nn.input_y2: y_t_batch,
                shared_nn.dropout_rate: 1.0,
            }
            sess.run(shared_nn.update_auc_op, eval_feed)
            step, summaries, _, acc, auc = sess.run(
                [global_step, dev_summary_op, shared_nn.loss2, shared_nn.accuracy2, shared_nn.auc2],
                eval_feed)
            if writer:
                writer.add_summary(summaries, step)
            sess.run(running_vars_initializer)
            return acc, auc


        def dev_step_return_test_loss(x1_s_batch, x2_s_batch, y_s_batch):
            """Per-example target-head loss (losses2) for a *source* batch, dropout disabled.

            The source pairs are fed through the target inputs (input_x_c / input_x_d / input_y2),
            so the result measures how the target model scores each source example.

            Returns:
                A one-element list holding the losses2 array (sess.run was given a list).
            """
            target_view = {
                shared_nn.input_x_c: x1_s_batch,
                shared_nn.input_x_d: x2_s_batch,
                shared_nn.input_y2: y_s_batch,
                shared_nn.dropout_rate: 1.0,
            }
            # A list fetch keeps the individual per-instance losses rather than the batch mean.
            return sess.run([shared_nn.losses2], target_view)


        def test_step(x1_t_batch, x2_t_batch, y_t_batch):
            """Evaluate the target head on a labelled batch (dropout off) and print a report line.

            Streams the batch into the auc2 metric, fetches all target-head outputs, prints
            loss/accuracy/AUC, and resets the AUC accumulators afterwards.

            Returns:
                (prob2, predictions2, correct_predictions2, accuracy2, auc2).
            """
            eval_feed = {
                shared_nn.input_x_c: x1_t_batch,
                shared_nn.input_x_d: x2_t_batch,
                shared_nn.input_y2: y_t_batch,
                shared_nn.dropout_rate: 1.0,
            }
            sess.run(shared_nn.update_auc_op, eval_feed)
            fetched = sess.run(
                [global_step, dev_summary_op, shared_nn.loss2, shared_nn.accuracy2, shared_nn.auc2,
                 shared_nn.prob2, shared_nn.predictions2, shared_nn.correct_predictions2],
                eval_feed)
            step, _, loss, acc, auc, probs, preds, correct = fetched
            stamp = datetime.datetime.now().isoformat()
            print("test{}: step {}, loss2 {:g}, acc2 {:g}, auc2 {:g}".format(stamp, step, loss, acc, auc))
            sess.run(running_vars_initializer)
            return probs, preds, correct, acc, auc

        batches_s = data_helpers.batch_iter(
            list(zip(x1_s, x2_s, y_s)), FLAGS.batch_size_s, FLAGS.num_epochs)
        batches_t = data_helpers.batch_iter(
            list(zip(x1_t_train, x2_t_train, y_t_train)), FLAGS.batch_size_t, FLAGS.num_epochs)
        i_episode, i_batch, dev_ep = 0, 0, 0
        # Baseline validation accuracy that the first selector reward is measured against.
        last_reward_metric2 = 0.6
        num_batches_per_epoch = int((len(x1_s) - 1) / FLAGS.batch_size_s) + 1
        print('num_batches_per_epoch:', num_batches_per_epoch)

        # Local (streaming) variables of the auc2 metric; reset after every evaluation.
        running_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope='auc2')

        # Define initializer to initialize/reset running variables
        running_vars_initializer = tf.variables_initializer(var_list=running_vars)

        # No ops may be added past this point.
        tf.get_default_graph().finalize()
        dev_acc, dev_auc, test_acc, test_auc = 0.0, 0.0, 0.0, 0.0

        def train_domain_critic(x1_s_batch, x2_s_batch, x1_t_batch, x2_t_batch):
            """Run FLAGS.wd_num domain-critic updates (wd_d_op) on one source/target batch pair."""
            critic_feed = {
                shared_nn.input_x_a: x1_s_batch,
                shared_nn.input_x_b: x2_s_batch,
                shared_nn.input_x_c: x1_t_batch,
                shared_nn.input_x_d: x2_t_batch,
                shared_nn.dropout_rate: FLAGS.dropout_rate,
            }
            for _ in range(FLAGS.wd_num):
                sess.run(wd_d_op, feed_dict=critic_feed)

        for batch_s, batch_t in zip(batches_s, batches_t):
            i_batch += 1
            x1_s_batch, x2_s_batch, y_s_batch = zip(*batch_s[0])
            x1_t_batch, x2_t_batch, y_t_batch = zip(*batch_t[0])

            # Every phase starts by training the domain critic on the current batch pair.
            train_domain_critic(x1_s_batch, x2_s_batch, x1_t_batch, x2_t_batch)

            if i_batch <= FLAGS.base_model_pretrain:
                # Warm-up: train both heads on the full source batch; validation is only logged.
                train_step_src(x1_s_batch, x2_s_batch, y_s_batch)
                train_step_tgt(x1_t_batch, x2_t_batch, y_t_batch)
                dev_step(x1_t_val, x2_t_val, y_t_val, writer=dev_summary_writer)
            else:
                if reinforced_data_selector:
                    # Joint update on the full batch; also yields a per-source-sample reward term.
                    rr, _ = train_step(x1_s_batch, x2_s_batch, y_s_batch, x1_t_batch, x2_t_batch, y_t_batch)

                    # Selector state, one row per source sample: o1 (dropout off), per-sample
                    # source loss losses1 (dropout on), per-sample target-head loss losses2 and
                    # target-head prob2 output (dropout off).
                    feature1 = sess.run(shared_nn.o1, {
                        shared_nn.input_x_a: x1_s_batch,
                        shared_nn.input_x_b: x2_s_batch,
                        shared_nn.dropout_rate: 1.0,
                    })
                    feature2 = np.array(train_step_return_train_loss(x1_s_batch, x2_s_batch, y_s_batch)).reshape(-1, 1)
                    feature3 = np.array(dev_step_return_test_loss(x1_s_batch, x2_s_batch, y_s_batch)).reshape(-1, 1)
                    feature7 = sess.run(shared_nn.prob2, {
                        shared_nn.input_x_c: x1_s_batch,
                        shared_nn.input_x_d: x2_s_batch,
                        shared_nn.dropout_rate: 1.0,
                    })
                    observation = np.concatenate((feature1, feature2, feature3, feature7), axis=1)
                    action = agent.choose_action(observation)

                    # Keep only the source samples the selector chose (action == 1).
                    keep = action.reshape(-1, ) == 1
                    x1_s_batch = np.array(x1_s_batch)[keep]
                    x2_s_batch = np.array(x2_s_batch)[keep]
                    y_s_batch = np.array(y_s_batch)[keep]

                    if len(y_s_batch) > 0:
                        train_step_src(x1_s_batch, x2_s_batch, y_s_batch)
                    else:
                        # train_op1 would have advanced global_step; keep step numbering aligned.
                        sess.run(increment_global_step_op)
                    train_step_tgt(x1_t_batch, x2_t_batch, y_t_batch)
                    val_acc, val_auc = dev_step(x1_t_val, x2_t_val, y_t_val, writer=dev_summary_writer)

                    # Selector reward: per-sample reward term plus the change in validation accuracy.
                    td_error = FLAGS.coef * rr + val_acc - last_reward_metric2
                    agent.learn(observation, action, td_error)
                    last_reward_metric2 = val_acc
                else:
                    train_step_src(x1_s_batch, x2_s_batch, y_s_batch)
                    train_step_tgt(x1_t_batch, x2_t_batch, y_t_batch)
                    train_step(x1_s_batch, x2_s_batch, y_s_batch, x1_t_batch, x2_t_batch, y_t_batch)
                    val_acc, val_auc = dev_step(x1_t_val, x2_t_val, y_t_val, writer=dev_summary_writer)

                # New best validation score: evaluate test on these same parameters right away.
                # (Previously only the epoch-boundary check could record test metrics, but it
                # re-evaluated the very state already counted here, so its strict '>' never fired.)
                if val_acc + val_auc > dev_acc + dev_auc:
                    dev_acc, dev_auc, dev_ep = val_acc, val_auc, i_batch
                    _, _, _, test_acc, test_auc = test_step(x1_t_test, x2_t_test, y_t_test)

            if i_batch % num_batches_per_epoch == 0:
                # Per-epoch dev/test evaluation (logged); also covers epochs ending in warm-up.
                acc_val, auc_val = dev_step(x1_t_val, x2_t_val, y_t_val, writer=dev_summary_writer)
                _, _, _, acc_test, auc_test = test_step(x1_t_test, x2_t_test, y_t_test)
                if acc_val + auc_val > dev_acc + dev_auc:
                    dev_acc, dev_auc, dev_ep = acc_val, auc_val, i_batch
                    test_acc, test_auc = acc_test, auc_test

        print("best dev at batch {}: acc {:g}, auc {:g}; test at that batch: acc {:g}, auc {:g}".format(
            dev_ep, dev_acc, dev_auc, test_acc, test_auc))


