import os
import pickle

import torch
from pytorch_pretrained_bert import BertTokenizer
from torch.autograd import Variable

import config_ssdm_xlm
from utils import data_utils_SP
from model import models_xlm_SP
from utils import train_ssdm_helper
from utils import data_utils_xlm_SP

from utils.ud_to_list import load_conll_dataset

from model.models_xlm_SP import vgvae
from torch.utils.tensorboard import SummaryWriter
from config_ssdm_xlm import EVAL_YEAR

# Allow the process to keep running when more than one copy of the OpenMP
# runtime ends up loaded (presumably required by this torch/MKL setup --
# confirm before removing).
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Module-level result trackers. run() and eval() update them through
# ``global`` and the exit handler in the ``__main__`` section prints them.
best_dev_res = 0
test_bm_res = 0
test_avg_res = 0
best_distance = 0


def run(e):
    """Train the ``vgvae`` model with periodic evaluation and checkpointing.

    Steps, in order:

    1. Load the processed dataset from ``data/data_DP_xlm.pkl``. When that
       cache does not exist, build the dataset with
       ``data_utils_xlm_SP.data_processor`` and write the cache.
    2. Build the model and, when ``e.config.resume`` is set, try to restore
       the ``"latest"`` checkpoint.
    3. For every epoch and mini-batch: run the model, optimise on the
       returned loss and accumulate training statistics.
    4. Every ``e.config.print_every`` / ``e.config.eval_every`` iterations
       (and once per ``len(train_batch)`` iterations) log statistics and
       evaluate on the dev and test data with both the ``'pred'`` and
       ``'predz'`` modes.
    5. Save a checkpoint on a new best dev result, save a ``"distant"``
       checkpoint when the gap between the ``'pred'`` and ``'predz'`` test
       averages reaches a new maximum, and save ``"latest"`` after each
       epoch.
    6. Run a final ``'pred'`` / ``'predz'`` evaluation on the test data.

    Args:
        e: Experiment object. This function reads ``e.config``, logs through
            ``e.log`` and, when ``e.config.summarize`` is set, writes
            TensorBoard scalars under ``e.experiment_dir``.

    Side effects:
        Updates the module-level ``best_dev_res``, ``test_bm_res``,
        ``test_avg_res`` and ``best_distance``; writes checkpoints through
        ``model.save`` and may create the dataset cache file.
    """
    global best_dev_res, test_bm_res, test_avg_res, best_distance

    cache_path = "data/data_DP_xlm.pkl"
    if not os.path.exists(cache_path):
        train_dp_1 = load_conll_dataset("data/DP/sentence1_tree.txt")
        train_dp_2 = load_conll_dataset("data/DP/sentence2_tree.txt")
        dp = data_utils_xlm_SP.data_processor(
            train_path=e.config.train_file,
            eval_path=e.config.eval_file,
            dp_1=train_dp_1,
            dp_2=train_dp_2,
            experiment=e)
        data, _, W = dp.process()
        # Serialise before opening the file so that a pickling failure does
        # not leave an empty cache behind for the next run to trip over.
        payload = pickle.dumps(data)
        with open(cache_path, 'wb') as cache_file:
            cache_file.write(payload)
    else:
        # NOTE: unpickling executes arbitrary code; only load caches that
        # were produced locally by this script.
        with open(cache_path, 'rb') as cache_file:
            data = pickle.load(cache_file)
        # The cache stores only ``data``; the embedding initialiser is
        # replaced by the "lm" marker in this branch.
        W = "lm"

    e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25)
    e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25)

    model = vgvae(
        vocab_size=len(data.vocab),
        embed_dim=e.config.edim,
        embed_init=W,
        experiment=e)

    start_epoch = true_it = 0
    if e.config.resume:
        resumed = True
        try:
            # NOTE(review): the second returned value (presumably the saved
            # iteration) is discarded, so the log below always reports 0.
            start_epoch, _, best_dev_res, test_avg_res = \
                model.load(name="latest")
        except Exception as err:
            # Best effort: fall back to a fresh start, but do not swallow
            # KeyboardInterrupt / SystemExit as a bare ``except`` would.
            resumed = False
            start_epoch = 0
            e.log.info(
                "could not resume from 'latest' checkpoint ({!r}); "
                "starting from epoch 0".format(err))
        if e.config.use_cuda:
            model.cuda()
            e.log.info("transferred model to gpu")
        if resumed:
            e.log.info(
                "resumed from previous checkpoint: start epoch: {}, "
                "iteration: {}, best dev res: {:.3f}, test avg res: {:.3f}"
                    .format(start_epoch, true_it, best_dev_res, test_avg_res))

    e.log.info(model)
    e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25)

    writer = None
    if e.config.summarize:
        writer = SummaryWriter(e.experiment_dir)

    if e.config.decoder_type.startswith("bag"):
        minibatcher = data_utils_xlm_SP.tree_bow_minibatcher
        e.log.info("using BOW batcher")
    else:
        minibatcher = data_utils_xlm_SP.minibatcher
        e.log.info("using sequential batcher")

    train_batch = minibatcher(
        data1=data.train_data[0],
        data2=data.train_data[1],
        dp1=data.train_data[2],
        dp2=data.train_data[3],
        vocab_size=len(data.vocab),
        batch_size=e.config.batch_size,
        score_func=model.score,
        shuffle=False,
        mega_batch=e.config.mb,
        p_scramble=e.config.ps)

    evaluator = train_ssdm_helper.evaluator(model, e)

    e.log.info("Training start ...")
    train_stats = train_ssdm_helper.tracker(["loss", "vmf_kl", "gauss_kl",
                                        "rec_logloss", "para_logloss",
                                        "wploss", "dp_loss"])
    # Number of consecutive evaluations that produced neither a new best dev
    # result nor a new best distance; training stops once it hits `patience`.
    no_new = 0
    patience = 15
    stop_early = False
    # Pre-bind so the end-of-epoch "latest" save cannot hit an unbound name
    # when an epoch finishes without a single evaluation (e.g. empty batcher).
    dev_stats = test_stats = None
    for epoch in range(start_epoch, e.config.n_epoch):
        if stop_early:
            break
        if epoch > 1 and train_batch.mega_batch != e.config.mb:
            train_batch.mega_batch = e.config.mb
            train_batch._reset()
        e.log.info("current mega batch: {}".format(train_batch.mega_batch))

        for it, (s1, m1, s2, m2, t1, tm1, t2, tm2,
                 n1, nm1, nt1, ntm1, n2, nm2, nt2, ntm2, y, _) in \
                enumerate(train_batch):
            # Global, 1-based iteration counter across epochs.
            true_it = it + 1 + epoch * len(train_batch)
            loss, vkl, gkl, rec_logloss, para_logloss, wploss, dploss = \
                model(s1, m1, s2, m2, t1, tm1, t2, tm2,
                      n1, nm1, nt1, ntm1, n2, nm2, nt2, ntm2,
                      e.config.vmkl, e.config.gmkl, y,
                      e.config.mb > 1, true_it=true_it)
            # Disabled experiment, kept for reference: adding an L1 (and
            # optionally L2) norm of all model parameters to ``loss``:
            #   for param in model.parameters():
            #       l1_regularization += torch.norm(param, 1)
            #       # l2_regularization += torch.norm(param, 2)
            #   loss += l1_regularization + l2_regularization

            model.optimize(loss)

            train_stats.update(
                {"loss": loss, "vmf_kl": vkl, "gauss_kl": gkl,
                 "para_logloss": para_logloss, "rec_logloss": rec_logloss,
                 "wploss": wploss, "dp_loss": dploss},
                len(s1))

            if (true_it + 1) % e.config.print_every == 0 or \
                    (true_it + 1) % len(train_batch) == 0:
                summarization = train_stats.summarize(
                    "epoch: {}, it: {} (max: {}), kl_temp: {:.2E}|{:.2E}"
                        .format(epoch, it, len(train_batch),
                                e.config.vmkl, e.config.gmkl))
                e.log.info(summarization)
                if e.config.summarize:
                    for name, value in train_stats.stats.items():
                        writer.add_scalar(
                            "train/" + name, value, true_it)
                train_stats.reset()

            if (true_it + 1) % e.config.eval_every == 0 or \
                    (true_it + 1) % len(train_batch) == 0:

                e.log.info("*" * 25 + " DEV SET EVALUATION " + "*" * 25)

                dev_stats, _, dev_res, _ = evaluator.evaluate(
                    data.dev_data, 'pred')

                e.log.info("*" * 25 + " DEV SET EVALUATION " + "*" * 25)

                if e.config.summarize:
                    writer.add_scalar(
                        "dev/pearson", dev_stats[EVAL_YEAR][1], true_it)
                    writer.add_scalar(
                        "dev/spearman", dev_stats[EVAL_YEAR][2], true_it)

                e.log.info("*" * 25 + " TEST EVAL: SEMANTICS " + "*" * 25)

                test_stats, test_bm_res, test_avg_res, test_avg_s = \
                    evaluator.evaluate(data.test_data, 'pred')

                e.log.info("*" * 25 + " TEST EVAL: SEMANTICS " + "*" * 25)
                e.log.info("*" * 25 + " TEST EVAL: SYNTAX " + "*" * 25)

                tz_stats, tz_bm_res, tz_avg_res, tz_avg_s = \
                    evaluator.evaluate(data.test_data, 'predz')
                e.log.info("Summary - benchmark: {:.4f}, test avg: {:.4f}"
                           .format(tz_bm_res, tz_avg_res))
                e.log.info("*" * 25 + " TEST EVAL: SYNTAX " + "*" * 25)
                # Gap between the 'pred' and 'predz' test averages.
                distance = abs(test_avg_res - tz_avg_res)

                if best_dev_res < dev_res:
                    # New best dev result: reset patience and checkpoint.
                    no_new = 0
                    best_dev_res = dev_res

                    model.save(
                        dev_avg=best_dev_res,
                        dev_perf=dev_stats,
                        test_avg=test_avg_res,
                        test_perf=test_stats,
                        iteration=true_it,
                        epoch=epoch)

                    if e.config.summarize:
                        for year, stats in test_stats.items():
                            writer.add_scalar(
                                "test/{}_pearson".format(year),
                                stats[1], true_it)
                            writer.add_scalar(
                                "test/{}_spearman".format(year),
                                stats[2], true_it)

                        writer.add_scalar(
                            "test/avg_pearson", test_avg_res, true_it)
                        writer.add_scalar(
                            "test/avg_spearman", test_avg_s, true_it)
                        writer.add_scalar(
                            "test/STSBenchmark_pearson", test_bm_res, true_it)
                        writer.add_scalar(
                            "dev/best_pearson", best_dev_res, true_it)

                        writer.add_scalar(
                            "testz/avg_pearson", tz_avg_res, true_it)
                        writer.add_scalar(
                            "testz/avg_spearman", tz_avg_s, true_it)
                        writer.add_scalar(
                            "testz/STSBenchmark_pearson", tz_bm_res, true_it)
                    if distance > best_distance:
                        best_distance = distance
                elif distance > best_distance:
                    # No dev improvement, but the largest gap so far: keep
                    # this state as the separate "distant" checkpoint.
                    best_distance = distance
                    model.save(
                        dev_avg=best_dev_res,
                        dev_perf=dev_stats,
                        test_avg=test_avg_res,
                        test_perf=test_stats,
                        iteration=true_it,
                        epoch=epoch,
                        name="distant")
                    # Deliberately skips the summary log and the patience
                    # check below; ``no_new`` is left unchanged.
                    continue
                else:
                    no_new += 1

                train_stats.reset()
                e.log.info("best dev result: {:.4f}, "
                           "STSBenchmark result: {:.4f}, "
                           "test average result: {:.4f}"
                           .format(best_dev_res, test_bm_res, test_avg_res))
            if no_new == patience:
                if best_dev_res:
                    e.log.info("stop early!")
                    stop_early = True
                    break
                else:
                    # Nothing useful learned yet (best dev is still 0):
                    # restart the patience counter instead of stopping.
                    no_new = 0

        # End of epoch (also reached after an early-stop break): always
        # refresh the "latest" checkpoint used by ``e.config.resume``.
        model.save(
            dev_avg=best_dev_res,
            dev_perf=dev_stats,
            test_avg=test_avg_res,
            test_perf=test_stats,
            iteration=true_it,
            epoch=epoch + 1,
            name="latest")

    e.log.info("*" * 25 + " TEST EVAL: SEMANTICS " + "*" * 25)

    test_stats, test_bm_res, test_avg_res, test_avg_s = \
        evaluator.evaluate(data.test_data, 'pred')

    e.log.info("*" * 25 + " TEST EVAL: SEMANTICS " + "*" * 25)
    e.log.info("*" * 25 + " TEST EVAL: SYNTAX " + "*" * 25)

    tz_stats, tz_bm_res, tz_avg_res, tz_avg_s = \
        evaluator.evaluate(data.test_data, 'predz')
    e.log.info("Summary - benchmark: {:.4f}, test avg: {:.4f}"
               .format(tz_bm_res, tz_avg_res))
    e.log.info("*" * 25 + " TEST EVAL: SYNTAX " + "*" * 25)

    if writer is not None:
        # Flush pending TensorBoard events before returning.
        writer.close()


# Default artefact locations for ``eval``; callers may override them through
# the keyword arguments instead of editing this file.
_EVAL_CHECKPOINT_PATH = (
    "/home/user3/wlj/LLDS/result/xlm-0902_DP/"
    "dpratio1dratio1edim1280lr3e-05mb20posratio0/distant.ckpt")
_EVAL_VOCAB_PATH = (
    "/home/user3/wlj/LLDS/data/para/para-vocab-xlm/pre_vocab_50000")


def eval(e, checkpoint_path=_EVAL_CHECKPOINT_PATH,
         vocab_path=_EVAL_VOCAB_PATH):
    """Evaluate a saved checkpoint on the test data.

    Loads the processed dataset from ``data/data_DP_xlm_test.pkl`` (building
    and caching it when missing), restores a ``vgvae`` model from
    ``checkpoint_path`` and runs ``evaluator.evaluate`` on ``data.test_data``
    in both the ``'pred'`` and ``'predz'`` modes.

    Note: this function shadows the ``eval`` builtin inside this module; the
    name is kept because callers rely on it.

    Args:
        e: Experiment object. Its ``config`` supplies the train/eval file
            paths when the dataset cache has to be built, and its ``log`` is
            used until the checkpoint's own experiment has been created.
        checkpoint_path: File passed to ``torch.load``; the loaded dict must
            provide the ``'config'`` and ``'state_dict'`` keys.
        vocab_path: Pickle file holding an ``(embedding_init, vocab)`` pair.

    Side effects:
        Updates the module-level ``test_bm_res`` and ``test_avg_res``.
    """
    global test_bm_res, test_avg_res

    cache_path = "data/data_DP_xlm_test.pkl"
    if not os.path.exists(cache_path):
        train_dp_1 = load_conll_dataset("data/DP/sentence1_tree.txt")
        train_dp_2 = load_conll_dataset("data/DP/sentence2_tree.txt")
        dp = data_utils_xlm_SP.data_processor(
            train_path=e.config.train_file,
            eval_path=e.config.eval_file,
            dp_1=train_dp_1,
            dp_2=train_dp_2,
            experiment=e)
        # Only the dataset is needed here; the embedding initialiser and
        # vocabulary come from ``vocab_path`` below.
        data, _, _ = dp.process()
        # Serialise first so a pickling failure cannot leave an empty cache.
        payload = pickle.dumps(data)
        with open(cache_path, 'wb') as cache_file:
            cache_file.write(payload)
    else:
        # NOTE: unpickling executes arbitrary code; only load caches that
        # were produced locally by this script.
        with open(cache_path, 'rb') as cache_file:
            data = pickle.load(cache_file)

    e.log.info("*" * 25 + " DATA PREPARATION " + "*" * 25)
    e.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25)

    # The identity ``map_location`` keeps every tensor on the CPU at load
    # time, regardless of the device it was saved from.
    save_dict = torch.load(
        checkpoint_path,
        map_location=lambda storage, loc: storage)
    config = save_dict['config']
    checkpoint = save_dict['state_dict']
    config.debug = True
    config.embed_type = "lm"
    with open(vocab_path, "rb") as fp:
        W, vocab = pickle.load(fp)
    with train_ssdm_helper.experiment(config, config.save_prefix) as ckpt_exp:
        model_syntax = models_xlm_SP.vgvae(
            vocab_size=len(vocab),
            embed_dim=ckpt_exp.config.edim,
            embed_init=W,
            experiment=ckpt_exp)
        model_syntax.eval()
        model_syntax.load(checkpointed_state_dict=checkpoint)

    # From here on the experiment built from the checkpoint's config is used
    # for logging and evaluation, not the caller's ``e``. (The previous code
    # achieved the same by rebinding ``e`` in the ``with`` statement.)
    # NOTE(review): ``ckpt_exp`` is used after its ``with`` block has exited;
    # confirm the experiment context manager leaves its logger usable.
    ckpt_exp.log.info("*" * 25 + " MODEL INITIALIZATION " + "*" * 25)

    evaluator = train_ssdm_helper.evaluator(model_syntax, ckpt_exp)
    ckpt_exp.log.info("*" * 25 + " TEST EVAL: SEMANTICS " + "*" * 25)

    test_stats, test_bm_res, test_avg_res, test_avg_s = \
        evaluator.evaluate(data.test_data, 'pred')

    ckpt_exp.log.info("*" * 25 + " TEST EVAL: SEMANTICS " + "*" * 25)
    ckpt_exp.log.info("*" * 25 + " TEST EVAL: SYNTAX " + "*" * 25)

    tz_stats, tz_bm_res, tz_avg_res, tz_avg_s = \
        evaluator.evaluate(data.test_data, 'predz')
    ckpt_exp.log.info("Summary - benchmark: {:.4f}, test avg: {:.4f}"
                      .format(tz_bm_res, tz_avg_res))
    ckpt_exp.log.info("*" * 25 + " TEST EVAL: SYNTAX " + "*" * 25)


if __name__ == '__main__':
    # Script entry point: parse the CLI arguments, register the exit handler
    # and run the selected mode inside an experiment context.
    args = config_ssdm_xlm.get_base_parser().parse_args()
    args.use_cuda = torch.cuda.is_available()
    # args.use_cuda = True


    def exit_handler(*handler_args):
        """Print the best results recorded so far, then terminate.

        ``handler_args`` are whatever ``register_exit_handler`` passes to the
        callback; they are printed as-is. The parameter is named differently
        from the module-level ``args`` so it no longer shadows the parsed
        command-line arguments.
        """
        print(handler_args)
        print("best dev result: {:.4f}, "
              "STSBenchmark result: {:.4f}, "
              "test average result: {:.4f}"
              .format(best_dev_res, test_bm_res, test_avg_res))
        # ``exit()`` is injected by the ``site`` module for interactive use
        # and may be absent (e.g. ``python -S``); raise SystemExit directly.
        raise SystemExit


    train_ssdm_helper.register_exit_handler(exit_handler)

    with train_ssdm_helper.experiment(args, args.save_prefix) as e:
        e.log.info("*" * 25 + " ARGS " + "*" * 25)
        e.log.info(args)
        e.log.info("*" * 25 + " ARGS " + "*" * 25)

        # Mode switch: training is currently disabled; only the checkpoint
        # evaluation runs. Swap the two lines below to train instead.
        #run(e)
        eval(e)
