#    Copyright (C) 2017 Tiancheng Zhao, Carnegie Mellon University

import os
import sys
import time

import numpy as np
import tensorflow as tf
from beeprint import pp

from config import KgCVAEConfig as Config
from corpus import SWDADialogCorpus
from data_utils import SWDADataLoader
from model import CVAE

# constants
# Command-line flags (TF1 ``tf.app.flags``).
# NOTE(review): main() builds the corpus from config.data_name and the run directory from
# config.work_dir, not from the data_dir / work_dir flags below -- confirm those flags are still needed.
tf.app.flags.DEFINE_string("word2vec_path", "data/vocab_embeddings", "Path to the word2vec embedding file.")
tf.app.flags.DEFINE_string("data_dir", "data/full_swda_clean_42da_sentiment_dialog_corpus.p",
                           "Path to the SWDA dialog corpus file.")
tf.app.flags.DEFINE_string("work_dir", "working",
                           "Experiment results directory.")
tf.app.flags.DEFINE_bool("equal_batch", True, "Make each batch have similar length.")
tf.app.flags.DEFINE_bool("resume", False, "Resume training from the run directory given by test_path.")
tf.app.flags.DEFINE_bool("forward_only", False, "Skip training and only run validation with a saved model.")
tf.app.flags.DEFINE_bool("save_model", False, "Save a checkpoint whenever validation perplexity improves.")
tf.app.flags.DEFINE_string("test_path", "run1500783422",
                           "Run directory (under work_dir) to load checkpoints from for forward_only / resume.")
FLAGS = tf.app.flags.FLAGS


def _sample_file_name(epoch, global_t, ppl):
    """Return the file name for a sample dump taken at (epoch, global_t) with validation perplexity *ppl*."""
    return 'samples_epoch_{:0>4d}_batches_{:0>6d}_ppl_{}_result'.format(epoch, global_t, ppl)


def _write_samples(path, samples):
    """Write generated samples to *path* (overwriting it) as a human-readable text file.

    Args:
        path: destination file path.
        samples: indexable whose first five entries are parallel sequences, written per record
            as: the "True A" text, the "True B" text, the generated "B" text, a direction id
            and a focuses object (the last two are written with ``str``). Only entries 0-4 are read.
    """
    with open(path, 'w', encoding='utf-8') as f:
        for true_sent, t2, re_sent, dir_id, focuses in zip(samples[0], samples[1], samples[2],
                                                           samples[3], samples[4]):
            f.write('True A, True B: \n')
            f.write(true_sent + '\n')
            f.write(t2 + '\n')
            f.write('Generate B: \n')
            f.write(re_sent + '\n')
            f.write(str(dir_id) + '\n')
            f.write(str(focuses) + '\n\n')


def main():
    """Build the CVAE train/valid/test graphs and run training or forward-only evaluation.

    Training (default): trains ``model`` until ``config.max_epoch`` epochs have been run,
    validating whenever the global step is a multiple of 100. On every new best validation
    perplexity it dumps validation and test samples and, with ``--save_model``, saves a
    checkpoint.

    Forward-only (``--forward_only``): restores the checkpoint found under
    ``config.work_dir/FLAGS.test_path`` (if any) and runs one validation pass.
    """
    # config for training
    config = Config()

    for directory in (config.logsdir, config.samples_dir, config.test_samples_dir):
        os.makedirs(directory, exist_ok=True)

    # The number of entries already in logsdir serves as an experiment id, so each run gets
    # its own ppl log file and its own sample folders.
    exp_time = len(os.listdir(config.logsdir))
    exp_name = 'exp_time_{}'.format(exp_time)
    valid_samples_dir = os.path.join(config.samples_dir, exp_name)
    test_samples_dir = os.path.join(config.test_samples_dir, exp_name)
    # Create each folder independently so an existing first folder cannot prevent the second
    # one from being created.
    os.makedirs(valid_samples_dir, exist_ok=True)
    os.makedirs(test_samples_dir, exist_ok=True)

    # config for validation
    valid_config = Config()
    valid_config.keep_prob = 1.0
    valid_config.dec_keep_prob = 1.0
    valid_config.batch_size = 64

    # configuration for testing
    test_config = Config()
    test_config.keep_prob = 1.0
    test_config.dec_keep_prob = 1.0
    test_config.batch_size = 64

    pp(config)

    # get data set
    api = SWDADialogCorpus(config.data_name, max_sent_len=config.max_utt_len, word2vec=FLAGS.word2vec_path,
                           word2vec_dim=config.embed_size)
    dial_corpus = api.get_dialog_corpus()
    meta_corpus = api.get_meta_corpus()

    train_meta, valid_meta, test_meta = meta_corpus.get("train"), meta_corpus.get("valid"), meta_corpus.get("test")
    train_dial, valid_dial, test_dial = dial_corpus.get("train"), dial_corpus.get("valid"), dial_corpus.get("test")

    # convert to numeric inputs/outputs that fit into the TF models
    train_feed = SWDADataLoader("Train", train_dial, train_meta, config)
    valid_feed = SWDADataLoader("Valid", valid_dial, valid_meta, config)
    test_feed = SWDADataLoader("Test", test_dial, test_meta, config)

    if FLAGS.forward_only or FLAGS.resume:
        log_dir = os.path.join(config.work_dir, FLAGS.test_path)
    else:
        log_dir = os.path.join(config.work_dir, "run" + str(int(time.time())))

    graph_cvae = tf.get_default_graph()
    graph_cvae.seed = config.Graphseed

    os.environ['CUDA_VISIBLE_DEVICES'] = config.gpuSet
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.allow_growth = True

    with tf.Session(graph=graph_cvae, config=sess_config) as sess:
        initializer = tf.random_uniform_initializer(-1.0 * config.init_w, config.init_w)
        scope = "model"
        # The training model creates the variables; the valid/test models share them (reuse=True).
        with tf.variable_scope(scope, reuse=None, initializer=initializer):
            model = CVAE(sess, config, api, log_dir=None if FLAGS.forward_only else log_dir, forward=False,
                         reu_conv=False, scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            valid_model = CVAE(sess, valid_config, api, log_dir=None, forward=False, reu_conv=True, scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            test_model = CVAE(sess, test_config, api, log_dir=None, forward=True, reu_conv=True, scope=scope)

        # write config to a file for logging
        if not FLAGS.forward_only:
            with open(os.path.join(log_dir, "run.log"), "w", encoding='utf-8') as f:
                f.write(pp(config, output=False))

        # create the checkpoint folder if it is missing
        ckp_dir = os.path.join(log_dir, "checkpoints")
        if not os.path.exists(ckp_dir):
            os.mkdir(ckp_dir)

        ckpt = tf.train.get_checkpoint_state(ckp_dir)
        sess.run(tf.global_variables_initializer())
        if ckpt:
            print("Reading dm models parameters from %s" % ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            print("Created models with fresh parameters.")

        if not FLAGS.forward_only:
            dm_checkpoint_path = os.path.join(ckp_dir, model.__class__.__name__ + ".ckpt")
            global_t = 0
            best_dev_loss = np.inf
            epoch = 0
            flag_valid_batches = config.update_limit
            while True:
                # Start a new epoch when the train feed is exhausted; stop only at an epoch
                # boundary so the final epoch is trained in full (the previous
                # `while epoch != max_epoch` check exited after its first chunk).
                if train_feed.num_batch is None or train_feed.ptr >= train_feed.num_batch:
                    if epoch == config.max_epoch:
                        break
                    train_feed.epoch_init(config.batch_size, config.backward_size,
                                          config.step_size, shuffle=True)
                    epoch += 1

                print(">> Epoch %d with lr %f" % (epoch, model.learning_rate.eval()))
                global_t, _ = model.train(global_t, sess, train_feed, update_limit=flag_valid_batches)

                if global_t % 100 != 0:
                    # Validation only happens when global_t is a multiple of 100. If train()
                    # stopped early (presumably at the end of the epoch), shorten the next chunk
                    # so global_t realigns to a multiple of 100.
                    # NOTE(review): this arithmetic assumes config.update_limit == 100 -- confirm.
                    flag_valid_batches = config.update_limit - (global_t % 100)
                    continue
                flag_valid_batches = config.update_limit

                # begin validation
                print('\n')
                valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                      valid_config.step_size, shuffle=False, intra_shuffle=False)

                start_valid_time = time.time()
                sample_masks = None  # only produced by valid() in multi-direction mode
                if config.if_multi_direction:
                    ppl, sample_masks = valid_model.valid("ELBO_VALID", sess, True, valid_feed)
                else:
                    ppl = valid_model.valid("ELBO_VALID", sess, False, valid_feed)
                end_valid_time = time.time()

                with open(os.path.join(config.logsdir, 'ppl_loss_{}.txt'.format(exp_time)), 'a',
                          encoding='utf-8') as f:
                    f.write('{}\t{}\t{}\t{:.4f}\n'.format(
                        epoch,
                        global_t,
                        ppl,
                        end_valid_time - start_valid_time))

                # Rewind both feeds so the sampling passes below start from the first batch.
                valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                      valid_config.step_size, shuffle=False, intra_shuffle=False)
                test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                     test_config.step_size, shuffle=False, intra_shuffle=False)

                # With SGD, decay the learning rate at every validation point once more than
                # config.lr_hold epochs have started (independent of dev improvement).
                if config.op == "sgd" and epoch > config.lr_hold:
                    sess.run(model.learning_rate_decay_op)

                # only dump samples / save the model if the dev perplexity improved
                if ppl < best_dev_loss:
                    sample_name = _sample_file_name(epoch, global_t, ppl)
                    if sample_masks is not None:
                        valid_samples = test_model.valid_for_sample("ELBO_VALID", sess, True, sample_masks,
                                                                    valid_feed)
                        _write_samples(os.path.join(valid_samples_dir, sample_name), valid_samples)
                    else:
                        # Previously a NameError: sample_masks is only bound in multi-direction mode.
                        print("Skipping validation samples: sample masks require config.if_multi_direction.")

                    test_samples = test_model.test_for_sample("TEST", sess, True, test_feed)
                    _write_samples(os.path.join(test_samples_dir, sample_name), test_samples)

                    # still save the best train model
                    if FLAGS.save_model:
                        print("Save model!!")
                        model.saver.save(sess, dm_checkpoint_path, global_step=epoch)
                    best_dev_loss = ppl

            print("Best validation loss %f" % best_dev_loss)
            print("Done training")
        else:
            # Forward-only: run one validation pass with the (restored) parameters.
            valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                  valid_config.step_size, shuffle=False, intra_shuffle=False)
            # valid() takes (name, sess, multi_direction_flag, feed), as in the training branch.
            valid_model.valid("ELBO_VALID", sess, valid_config.if_multi_direction, valid_feed)


if __name__ == "__main__":
    # forward_only evaluates an existing run, so a run directory must be named.
    # A falsy check also rejects an empty --test_path (the default is never None).
    if FLAGS.forward_only and not FLAGS.test_path:
        print("Set test_path before forward only", file=sys.stderr)
        sys.exit(1)
    main()
