# coding: utf-8
import os, sys
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import numpy as np
import time, pickle
import datetime
import data_helpers
import tensorflow as tf
from tensorflow.contrib import learn
import logging
from collections import defaultdict,Counter
from scipy.stats import wasserstein_distance
# ---------------------------------------------------------------------------
# Command-line flags.
# ---------------------------------------------------------------------------
# Data / embedding parameters
tf.flags.DEFINE_string("word2vec", 'glove/glove.twitter.27B.100d.txt',
                       "Word2vec file with pre-trained embeddings "
                       "(default: glove/glove.twitter.27B.100d.txt)")
tf.flags.DEFINE_string("target_file", '../Watches.txt.gz', "Data source for the positive data.")
tf.flags.DEFINE_integer("embedding_dim", 100, "Dimensionality of character embedding (default: 100)")
# NOTE(review): num_units is only read when sizing the agent's input
# (n_features = 4 * num_units + 2); the network itself is built from
# num_filters. Confirm the two stay consistent when either is changed.
tf.flags.DEFINE_integer("num_units", 256, "Number of filters per filter size")
# The value is fed to shared_nn.dropout_rate during training; evaluation
# feeds 1.0 instead.
tf.flags.DEFINE_float("dropout_rate", 0.5, "Dropout keep probability (default: 0.5)")
tf.flags.DEFINE_integer("eval_every", 200, "eval steps")
tf.flags.DEFINE_integer("max_document_length", 100, "length of the padded sentence")
tf.flags.DEFINE_float("coef", 1, "reward coefficient")
tf.flags.DEFINE_float("l2_reg_lambda", 1, "L2 regularization lambda (default: 1)")
# Training parameters
tf.flags.DEFINE_integer("batch_size_s", 64, "Source Batch Size")
tf.flags.DEFINE_integer("batch_size_t", 64, "Target Batch Size")
tf.flags.DEFINE_integer("num_epochs", 200, "Number of training epochs (default: 200)")

# Session / schedule parameters
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")
tf.flags.DEFINE_integer("base_model_pretrain", 50, "Pretrain the base model for i iterations")
tf.flags.DEFINE_integer("RL_pretrain", 0, "Pretrain the RL model for i iterations after base model pretrain")
tf.flags.DEFINE_boolean("reinforced_data_selector", True,
                        "whether to use the reinforced source data selector")

# Domain critic / agent parameters
tf.flags.DEFINE_integer("wd_num", 5, "train number for domain critic")
tf.flags.DEFINE_integer("critic_units", 128, "units of the hidden layer in domain critic")
tf.flags.DEFINE_integer("rl_units", 128, "units of the hidden layer in agent")
tf.flags.DEFINE_float("l2_param", 1e-4, "l2_param for domain critic")
tf.flags.DEFINE_float("wd_param", 1, "wd_param for domain critic")
tf.flags.DEFINE_integer("gp_param", 10, "gp_param for domain critic")
tf.flags.DEFINE_integer("num_filters", 128, "Number of filters per filter size (default: 128)")

FLAGS = tf.flags.FLAGS
# Read the switch from the flag instead of shadowing it with a hard-coded
# True, so --reinforced_data_selector=False actually disables the selector.
# The flag defaults to True, so the default behaviour is unchanged.
reinforced_data_selector = FLAGS.reinforced_data_selector

# Load the raw texts and labels: the source domain comes from a fixed file,
# the target domain from --target_file.
x_s_str, y_s = data_helpers.load_data_and_labels("../Electronics.txt.gz")
x_t_str, y_t = data_helpers.load_data_and_labels(FLAGS.target_file)
max_document_length = 100

# Fit a single vocabulary on both domains, then map every document to ids.
vocab_processor = learn.preprocessing.VocabularyProcessor(FLAGS.max_document_length)
vocab_processor.fit(x_s_str + x_t_str)
x_s = np.array(list(vocab_processor.transform(x_s_str)))
x_t = np.array(list(vocab_processor.transform(x_t_str)))
y_s = np.array(y_s)
y_t = np.array(y_t)

# Shuffle each domain with a fixed seed. The source permutation is drawn
# first, then the target one; keep this order so the random stream is stable.
np.random.seed(10)
shuffle_indices = np.random.permutation(np.arange(len(y_s)))
x_shuffled_s, y_shuffled_s = x_s[shuffle_indices], y_s[shuffle_indices]

shuffle_indices = np.random.permutation(np.arange(len(y_t)))
x_shuffled_t, y_shuffled_t = x_t[shuffle_indices], y_t[shuffle_indices]

# Target split: the last 20% is held out; the first half of that tail is the
# dev set and the final 10% is the test set. Both indices are negative offsets.
target_count = float(len(y_t))
dev_sample_index = -1 * int(2 * 0.1 * target_count)
test_sample_index = -1 * int(0.1 * target_count)

x_train_t = x_shuffled_t[:dev_sample_index]
x_dev_t = x_shuffled_t[dev_sample_index:test_sample_index]
x_test_t = x_shuffled_t[test_sample_index:]

y_train_t = y_shuffled_t[:dev_sample_index]
y_dev_t = y_shuffled_t[dev_sample_index:test_sample_index]
y_test_t = y_shuffled_t[test_sample_index:]


if FLAGS.word2vec:
    # Start every vocabulary row from a small uniform random vector; rows
    # whose word has a pre-trained vector are overwritten below.
    initW = np.random.uniform(-0.25, 0.25, (len(vocab_processor.vocabulary_), FLAGS.embedding_dim))
    print("Load glove file {}\n".format(FLAGS.word2vec))

    print('Indexing word vectors.')
    # The vectors are read from a pickled dict stored next to the file named
    # by --word2vec (same path, ".pkl" extension). For the default flag value
    # this is 'glove/glove.twitter.27B.100d.pkl', the path that used to be
    # hard-coded here regardless of the flag.
    embeddings_path = os.path.splitext(FLAGS.word2vec)[0] + '.pkl'
    # NOTE(review): pickle.load can execute arbitrary code; only point this
    # at a cache file you generated yourself.
    with open(embeddings_path, 'rb') as fp:
        embeddings_index = pickle.load(fp)

    print('Found %s word vectors.' % len(embeddings_index))
    num_words = len(vocab_processor.vocabulary_) + 1

    for word, coef in embeddings_index.items():
        idx = vocab_processor.vocabulary_.get(word)
        # presumably id 0 is what vocabulary_.get returns for words outside
        # the fitted vocabulary -- confirm against VocabularyProcessor.
        if idx != 0:
            initW[idx] = coef

from shared_net import Shared_CNN
from ag import Agent
with tf.Graph().as_default():
    # Pass the device-placement flags to the session; they were defined but
    # previously never used, so the session ran with TensorFlow's defaults.
    session_conf = tf.ConfigProto(
        allow_soft_placement=FLAGS.allow_soft_placement,
        log_device_placement=FLAGS.log_device_placement)
    sess = tf.Session(config=session_conf)
    with sess.as_default():
        # Shared network for both domains, including the domain critic.
        shared_nn = Shared_CNN(
            sequence_length=FLAGS.max_document_length,
            num_classes=1,
            vocab_size=len(vocab_processor.vocabulary_),
            embedding_size=FLAGS.embedding_dim,
            filter_sizes=[2,3,4,5],
            num_filters=FLAGS.num_filters,
            critic_units=FLAGS.critic_units,
            l2_param=FLAGS.l2_param,
            wd_param=FLAGS.wd_param,
            gp_param=FLAGS.gp_param,
            l2_reg_lambda=FLAGS.l2_reg_lambda
            )

        # Data-selection agent with two actions. Its observation is built in
        # the training loop as concat(o1, o2, scores1, scores2).
        # NOTE(review): n_features is derived from FLAGS.num_units while the
        # network is sized by FLAGS.num_filters -- confirm the widths still
        # match if either flag is changed from its default.
        LR_A = 0.001
        agent = Agent(sess, n_features=4*FLAGS.num_units+2, n_actions=2, n_units=FLAGS.rl_units,lr=LR_A)

        # train_op1 (loss1) and train_op2 (loss2) share one Adam optimizer and
        # each advances global_step; increment_global_step_op advances it
        # without training.
        global_step = tf.Variable(0, name="global_step", trainable=False)
        increment_global_step_op = tf.assign(global_step, global_step + 1)
        optimizer = tf.train.AdamOptimizer(1e-3)
        grads_and_vars1 = optimizer.compute_gradients(shared_nn.loss1)
        train_op1 = optimizer.apply_gradients(grads_and_vars1, global_step=global_step)
        grads_and_vars2 = optimizer.compute_gradients(shared_nn.loss2)
        train_op2 = optimizer.apply_gradients(grads_and_vars2, global_step=global_step)

        # wd_d_op updates only theta_D; train_op updates only theta_G.
        # Neither of them touches global_step.
        wd_d_op = tf.train.AdamOptimizer(1e-4).minimize(shared_nn.all_wd_loss, var_list=shared_nn.theta_D)
        train_op = tf.train.AdamOptimizer(1e-4).minimize(shared_nn.total_loss,var_list=shared_nn.theta_G)

        # Per-run output directory, named after the start time in seconds.
        timestamp = str(int(time.time()))
        out_dir = os.path.abspath(os.path.join('output/', "runs", timestamp))
        os.makedirs(out_dir)
        print("Writing to {}\n".format(out_dir))

        # Summaries for the losses and the pcc1 / pcc2 metrics
        loss1_summary = tf.summary.scalar("loss1", shared_nn.loss1)
        loss2_summary = tf.summary.scalar("loss2", shared_nn.loss2)
        loss_summary = tf.summary.scalar("loss", shared_nn.loss)

        pcc_summary1 = tf.summary.scalar("pcc1", shared_nn.pcc1)
        pcc_summary2 = tf.summary.scalar("pcc2", shared_nn.pcc2)

        # One merged summary per training op, all written to the "train" dir.
        train_summary_op1 = tf.summary.merge([loss1_summary,pcc_summary1])
        train_summary_op2 = tf.summary.merge([loss2_summary,pcc_summary2])
        train_summary_op = tf.summary.merge([loss_summary])
        train_summary_dir = os.path.join(out_dir, "summaries", "train")
        train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

        # Dev summaries reuse the target-side scalars (loss2, pcc2).
        dev_summary_op = tf.summary.merge([loss2_summary,pcc_summary2])
        dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
        dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)

        # Checkpoint directory and saver over all global variables.
        checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
        checkpoint_prefix = os.path.join(checkpoint_dir, "model")
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        saver = tf.train.Saver(tf.global_variables())

        # Write vocabulary
        vocab_processor.save(os.path.join(out_dir, "vocab"))

        # Initialize all variables
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        # Overwrite the embedding matrix with the pre-trained initialisation.
        if FLAGS.word2vec:
            sess.run(shared_nn.W.assign(initW))

        def train_step(x_s_batch, y_s_batch, x_t_batch, y_t_batch):
            """Run one `train_op` update on a source/target batch pair and
            derive a per-source-example reward from the fetched features.

            Args:
                x_s_batch, y_s_batch: source inputs and labels.
                x_t_batch, y_t_batch: target inputs and labels.

            Returns:
                np.ndarray with one value per row of `shared_nn.o1`:
                |mean(o1) - mean(o2)| minus the smallest absolute gap between
                that row's mean and any row mean of `shared_nn.o2`, where the
                gap is capped at max(o1) - min(o2).
            """
            feed_dict = {
                shared_nn.input_x1: x_s_batch,
                shared_nn.input_y1: y_s_batch,
                shared_nn.input_x2: x_t_batch,
                shared_nn.input_y2: y_t_batch,
                shared_nn.dropout_rate: FLAGS.dropout_rate
            }
            # The fetch list is kept exactly as before; only the values that
            # are actually used get a name.
            _, step, summaries, _, _, _, _, _, _, ss, tt = sess.run(
                [train_op, global_step, train_summary_op, shared_nn.clf_loss, shared_nn.wd_loss,
                 shared_nn.l2_loss, shared_nn.total_loss,
                 shared_nn.pcc1, shared_nn.pcc2, shared_nn.o1, shared_nn.o2],
                feed_dict)

            mean = abs(np.mean(ss) - np.mean(tt))
            slist = np.mean(ss, axis=1).tolist()
            tlist = np.mean(tt, axis=1).tolist()
            # Upper bound on the nearest-target gap. It does not depend on the
            # source row, so compute it once rather than once per row.
            gap_cap = np.max(ss) - np.min(tt)

            rr = []
            for src_mean in slist:
                nearest = gap_cap
                for tgt_mean in tlist:
                    gap = abs(src_mean - tgt_mean)
                    if gap < nearest:
                        nearest = gap
                rr.append(mean - nearest)

            train_summary_writer.add_summary(summaries, step)
            return np.array(rr)

        def train_step_src(x_s_batch, y_s_batch):
            """Run one `train_op1` update on a source batch and write its
            training summaries.

            Args:
                x_s_batch: source inputs fed to `shared_nn.input_x1`.
                y_s_batch: source labels fed to `shared_nn.input_y1`.

            Returns:
                None.
            """
            feed_dict = {
                shared_nn.input_x1: x_s_batch,
                shared_nn.input_y1: y_s_batch,
                shared_nn.dropout_rate: FLAGS.dropout_rate
            }
            # The fetch list is unchanged; only step and summaries are used.
            _, step, summaries, _, _, _ = sess.run(
                [train_op1, global_step, train_summary_op1, shared_nn.loss1, shared_nn.pcc1,
                 shared_nn.scores1],
                feed_dict)
            train_summary_writer.add_summary(summaries, step)


        def train_step_tgt(x_t_batch, y_t_batch):
            """Run one `train_op2` update on a target batch and write its
            training summaries.

            Args:
                x_t_batch: target inputs fed to `shared_nn.input_x2`.
                y_t_batch: target labels fed to `shared_nn.input_y2`.

            Returns:
                None.
            """
            feed_dict = {
                shared_nn.input_x2: x_t_batch,
                shared_nn.input_y2: y_t_batch,
                shared_nn.dropout_rate: FLAGS.dropout_rate
            }
            # The fetch list is unchanged; only step and summaries are used.
            _, step, summaries, _, _, _ = sess.run(
                [train_op2, global_step, train_summary_op2, shared_nn.loss2, shared_nn.pcc2,
                 shared_nn.scores2],
                feed_dict)
            train_summary_writer.add_summary(summaries, step)
    
        def dev_step(x_t_batch, y_t_batch, writer=None):
            """Evaluate the target-side loss and pcc2 on a batch without
            training, feeding 1.0 as the dropout value.

            Args:
                x_t_batch: target inputs fed to `shared_nn.input_x2`.
                y_t_batch: target labels fed to `shared_nn.input_y2`.
                writer: optional summary writer; when given, the dev summaries
                    are recorded under the current global step.

            Returns:
                Tuple `(loss2, pcc2)` as fetched from the network.
            """
            feed_dict = {
                shared_nn.input_x2: x_t_batch,
                shared_nn.input_y2: y_t_batch,
                shared_nn.dropout_rate: 1.0
            }
            # The fetch list is unchanged; the predictions are not used here.
            step, summaries, loss2, pcc2, _ = sess.run(
                [global_step, dev_summary_op, shared_nn.loss2, shared_nn.pcc2, shared_nn.scores2],
                feed_dict)

            if writer is not None:
                writer.add_summary(summaries, step)
            return loss2, pcc2

        def test_step(x1_t_batch, y_t_batch):
            """Evaluate the target-side loss and pcc2 on a batch without
            training or summary writing, feeding 1.0 as the dropout value.

            Args:
                x1_t_batch: target inputs fed to `shared_nn.input_x2`.
                y_t_batch: target labels fed to `shared_nn.input_y2`.

            Returns:
                Tuple `(loss2, pcc2)` as fetched from the network.
            """
            feed_dict = {
                shared_nn.input_x2: x1_t_batch,
                shared_nn.input_y2: y_t_batch,
                shared_nn.dropout_rate: 1.0
            }
            # The fetch list is unchanged; the step value is not used here.
            _, loss2, pcc2 = sess.run(
                [global_step, shared_nn.loss2, shared_nn.pcc2],
                feed_dict)
            return loss2, pcc2

        # Batch iterators: all shuffled source examples, and the training
        # part of the target split.
        source_examples = list(zip(x_shuffled_s, y_shuffled_s))
        batches_s = data_helpers.batch_iter(source_examples, FLAGS.batch_size_s, FLAGS.num_epochs)
        target_examples = list(zip(x_train_t, y_train_t))
        batches_t = data_helpers.batch_iter(target_examples, FLAGS.batch_size_t, FLAGS.num_epochs)

        # Loop bookkeeping.
        i_episode = 0
        i_batch = 0
        best_ep = 0
        # Previous dev pcc, used as the baseline in the agent's reward.
        last_pcc2 = 0.1
        ans = 0
        dev_best = 0
        # Number of source batches in one pass over the source data.
        num_batches_per_epoch = int((len(x_shuffled_s) - 1) / FLAGS.batch_size_s) + 1

        # Freeze the graph so the training loop cannot add ops to it.
        tf.get_default_graph().finalize()
        # Main loop: consume one source batch and one target batch per step
        # until either iterator is exhausted.
        for batch_s, batch_t in zip(batches_s, batches_t):
            i_batch += 1
            x_s_batch, y_s_batch = zip(*batch_s[0])
            x_t_batch, y_t_batch = zip(*batch_t[0])

            # Every phase starts by updating the domain critic wd_num times on
            # the full (unselected) batch pair, so this is done once up front
            # rather than repeated in each branch.
            critic_feed = {
                shared_nn.input_x1: x_s_batch,
                shared_nn.input_x2: x_t_batch,
                shared_nn.dropout_rate: FLAGS.dropout_rate,
            }
            for _ in range(FLAGS.wd_num):
                sess.run(wd_d_op, feed_dict=critic_feed)

            if i_batch <= FLAGS.base_model_pretrain:
                # Pre-training phase: plain source and target updates.
                train_step_src(x_s_batch, y_s_batch)
                train_step_tgt(x_t_batch, y_t_batch)
            elif reinforced_data_selector:
                # Per-source-example reward term, computed on the full batch.
                rr = train_step(x_s_batch, y_s_batch, x_t_batch, y_t_batch)

                # Agent observation: the source batch is pushed through both
                # inputs of the shared network (features and scores) with a
                # dropout value of 1.0. Features and scores that share a feed
                # are fetched in one run.
                feed_dict_src_rep = {
                    shared_nn.input_x1: x_s_batch,
                    shared_nn.dropout_rate: 1.0
                }
                feed_dict_test_prob = {
                    shared_nn.input_x2: x_s_batch,
                    shared_nn.dropout_rate: 1.0
                }
                feature1, feature3 = sess.run(
                    [shared_nn.o1, shared_nn.scores1], feed_dict_src_rep)
                feature2, feature7 = sess.run(
                    [shared_nn.o2, shared_nn.scores2], feed_dict_test_prob)
                observation = np.concatenate((feature1, feature2, feature3, feature7), axis=1)
                action = agent.choose_action(observation)
                keep_idx = action.reshape(-1, )

                # Keep only the source examples for which the action is 1.
                keep_mask = keep_idx == 1
                x_s_batch = np.array(x_s_batch)[keep_mask]
                y_s_batch = np.array(y_s_batch)[keep_mask]

                if len(y_s_batch) > 0:
                    train_step_src(x_s_batch, y_s_batch)
                else:
                    # Nothing selected: still advance global_step once, as a
                    # source update would have done.
                    sess.run(increment_global_step_op)
                train_step_tgt(x_t_batch, y_t_batch)
                loss2, pcc = dev_step(x_dev_t, y_dev_t, writer=dev_summary_writer)

                # NOTE(review): `ans` is compared against the dev pcc here but
                # is overwritten with the *test* pcc in the periodic
                # evaluation below -- confirm which of the two is intended.
                if pcc > ans:
                    ans = pcc
                    best_ep = i_batch

                # Reward: scaled per-example term plus the change in dev pcc
                # since the previous step.
                td_error = FLAGS.coef * rr + pcc - last_pcc2
                agent.learn(observation, action, td_error)
                if i_batch % num_batches_per_epoch == 0:
                    i_episode += 1
                last_pcc2 = pcc
            else:
                # Selector disabled: train on the whole source batch.
                train_step_src(x_s_batch, y_s_batch)
                train_step_tgt(x_t_batch, y_t_batch)
                loss2, pcc = dev_step(x_dev_t, y_dev_t, writer=dev_summary_writer)

            # Periodic evaluation: run the test set whenever dev pcc improves.
            if i_batch % FLAGS.eval_every == 0:
                loss2, pcc = dev_step(x_dev_t, y_dev_t, writer=dev_summary_writer)
                if pcc > dev_best:
                    dev_best = pcc
                    loss_test, pcc_test = test_step(x_test_t, y_test_t)
                    ans = pcc_test

        # Flush and close the event files so the last summaries reach disk.
        train_summary_writer.close()
        dev_summary_writer.close()
 

