import argparse, _pickle, math, os, random, sys, time, logging
random.seed(666)
import numpy as np
np.random.seed(666)
from collections import Counter
from antu.io.configurators.ini_configurator import IniConfigurator
from antu.io.vocabulary import Vocabulary
from antu.io.ext_embedding_readers import fasttext_reader
from utils.ULP_dataset import DatasetSetting, ULPDataset
from utils.ULP_reader import ULPReader
from antu.utils.dual_channel_logger import dual_channel_logger


def _ratio(numerator, denominator):
    """Return ``numerator / denominator`` as a float, or 0.0 when the denominator is 0.

    Guards the accuracy computations against empty evaluation sets.
    """
    return numerator * 1.0 / denominator if denominator else 0.0


def _git_short_sha():
    """Return the first 6 hex characters of HEAD's commit hash.

    Equivalent to ``git log -1 | head -n 1 | cut -c 8-13`` but without a shell
    and without leaking the pipe. Returns '' when git is missing or the
    working directory is not a repository.
    """
    import subprocess
    try:
        out = subprocess.run(
            ['git', 'log', '-1', '--format=%H'],
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            universal_newlines=True,
            check=False).stdout
    except OSError:
        return ''
    return out.strip()[:6]


def _evaluate(model, batches, renew_cg):
    """Run ``model`` in evaluation mode over ``batches`` and sum its counts.

    Each call ``model(indexes, masks, truth, False)`` is unpacked into four
    values ``(right, total, r1, tot1)``. Using one helper for the dev, test
    and final-test loops keeps that unpacking consistent everywhere.

    Args:
        model: callable model, e.g. a ``GraphRepresentation`` instance.
        batches: iterable of ``(indexes, masks, truth)`` triples.
        renew_cg: zero-argument callable that starts a fresh computation
            graph; it is called before every batch.

    Returns:
        Tuple ``(right, total, r1, tot1)`` of the summed per-batch counts.
    """
    right = total = r1 = tot1 = 0
    for indexes, masks, truth in batches:
        renew_cg()
        right_, total_, r1_, tot1_ = model(indexes, masks, truth, False)
        right += right_
        total += total_
        r1 += r1_
        tot1 += tot1_
    return right, total, r1, tot1


def main():
    """Train a GraphRepresentation model and report dev/test accuracy.

    Command-line flags select the config file, experiment name, model variant
    and GPU. Unrecognised flags are forwarded to ``IniConfigurator``.

    Every ``cfg.VALID_ITER`` iterations the model is:
      - evaluated on dev and test,
      - saved to ``cfg.LAST_FILE``,
      - copied to ``cfg.BEST_FILE`` when dev accuracy improves.

    After ``cfg.MAX_ITER`` iterations the best checkpoint saved by this run is
    reloaded and scored on test. If no checkpoint was saved, the final
    in-memory model is scored instead.
    """
    import shutil

    # Configuration file processing
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--config_file', default='../configs/debug.cfg')
    # NOTE(review): parsed but not consumed anywhere in this function.
    argparser.add_argument('--continue_training', action='store_true', help='Load model Continue Training')
    argparser.add_argument('--name', default='experiment', help='The name of the experiment.')
    argparser.add_argument(
        '--model', default='RNN',
        # Adjacent string literals concatenate; the separators keep the
        # choices from running together in --help output.
        help='RNN+Gx:    RNN encoder with x-layer Graph Decoder; '
             'OnlyGraph: Only use Graph Decoder; '
             'WinGraph:  n-gram windows + Graph Decoder; '
             'LPGraph:   Link prediction + Graph Decoder')
    argparser.add_argument('--gpu', default='0', help='GPU ID (-1 to cpu)')
    args, extra_args = argparser.parse_known_args()
    cfg = IniConfigurator(args.config_file, extra_args)

    # Logger setting
    logger = dual_channel_logger(
        __name__,
        file_path=cfg.LOG_FILE,
        file_model='w',
        formatter='%(asctime)s - %(levelname)s - %(message)s',
        time_formatter='%m-%d %H:%M')

    # DyNet setting. The GPU mask and dynet_config must be applied before
    # `import dynet`, which is why these imports live inside main().
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    import dynet_config
    dynet_config.set(mem=cfg.DYNET_MEM, random_seed=cfg.DYNET_SEED)
    dynet_config.set_gpu()
    import dynet as dy
    from models.dual_role_graph_representation import GraphRepresentation

    # Build the dataset of the training process
    ## Build data reader
    data_reader = ULPReader(
        field_list=['word', 'tag', 'head', 'rel'],
        beg='0\t**beg**\t_\t**rcpos**\t**rpos**\t_\t0\t**rrel**\t_\t_',
        end='0\t**end**\t_\t**rcpos**\t**rpos**\t_\t0\t**rrel**\t_\t_',
        spacer=r'[\t]',)
    ## Build vocabulary with pretrained glove
    vocabulary = Vocabulary()
    g_word = fasttext_reader(cfg.GLOVE, True)
    pretrained_vocabs = {'glove': g_word}
    vocabulary.extend_from_pretrained_vocab(pretrained_vocabs)
    ## Setup datasets
    datasets_settings = {
        'train': DatasetSetting(cfg.TRAIN, True),
        'dev': DatasetSetting(cfg.DEV, False),
        'test': DatasetSetting(cfg.TEST, False),}
    datasets = ULPDataset(vocabulary, datasets_settings, data_reader)
    counters = {'word': Counter(), 'tag': Counter(), 'rel': Counter()}
    datasets.build_dataset(
        counters, no_pad_namespace={'rel'}, no_unk_namespace={'rel'})

    # Build model
    pc = dy.ParameterCollection()
    trainer = dy.AdamTrainer(
        pc,
        alpha=cfg.LEARNING_RATE,
        beta_1=cfg.ADAM_BETA1,
        beta_2=cfg.ADAM_BETA2,
        eps=cfg.EPS)
    token_repre = GraphRepresentation(pc, cfg, datasets.vocabulary)

    # Batches are keyed on sentence length.
    def cmp(ins): return len(ins['word'])
    # The training batch source is consumed with next(); dev/test are
    # materialised once because they are re-iterated at every validation.
    train_batch = datasets.get_batches('train', cfg.TRAIN_BATCH_SIZE, True, cmp, True)
    valid_batch = list(datasets.get_batches('dev', cfg.TEST_BATCH_SIZE, False, cmp, False))
    test_batch = list(datasets.get_batches('test', cfg.TEST_BATCH_SIZE, False, cmp, False))

    logger.info("Experiment name: %s, Model name: %s", args.name, args.model)
    logger.info('Git SHA: %s', _git_short_sha())

    # Train model
    best_dev = 0
    best_saved = False  # True once BEST_FILE has been written by *this* run.
    train_losses = [[]]  # one list of per-iteration loss values per reported loss
    cnt_iter = 0
    while cnt_iter < cfg.MAX_ITER:
        dy.renew_cg()
        cnt_iter += 1
        indexes, masks, truth = next(train_batch)
        loss_LP = token_repre(indexes, masks, truth, True)
        train_losses[0].append(loss_LP.value())
        loss_LP.backward()
        # Exponential decay: LEARNING_RATE * LR_DECAY ** (iter / LR_ANNEAL).
        trainer.learning_rate = cfg.LEARNING_RATE*cfg.LR_DECAY**(cnt_iter / cfg.LR_ANNEAL)
        trainer.update()

        if cnt_iter % cfg.VALID_ITER:
            continue

        # Report the mean training loss since the previous validation.
        avg_loss = ', '.join(str(round(np.mean(losses), 2)) for losses in train_losses)
        logger.info("")
        logger.info("Iter: %d-%d, Avg_loss: %s", cnt_iter // cfg.VALID_ITER, cnt_iter, avg_loss)
        train_losses = [[]]

        right, total, r1, tot1 = _evaluate(token_repre, valid_batch, dy.renew_cg)
        dev_acc = _ratio(right, total)
        logger.info("Dev:  %f, %f", round(dev_acc, 4), _ratio(r1, tot1))
        # The copies below expect dy.save to write LAST_FILE.data and LAST_FILE.meta.
        dy.save(cfg.LAST_FILE, [token_repre,])
        if dev_acc > best_dev:
            best_dev = dev_acc
            # shutil avoids shell quoting issues and raises on failure,
            # where `cp` via os.system failed silently.
            for suffix in ('.data', '.meta'):
                shutil.copyfile('%s%s' % (cfg.LAST_FILE, suffix),
                                '%s%s' % (cfg.BEST_FILE, suffix))
            best_saved = True

        right, total, _, _ = _evaluate(token_repre, test_batch, dy.renew_cg)
        logger.info("Test: %f", round(_ratio(right, total), 4))

    logger.info("BEST_DEV: %f", round(best_dev, 4))
    if best_saved:
        test_pc = dy.ParameterCollection()
        final_model, = dy.load(cfg.BEST_FILE, test_pc)
    else:
        # No validation beat the initial best (or none ran). Do not load a
        # missing BEST_FILE, or a stale one left over from an earlier run.
        logger.warning("No best checkpoint was saved in this run; "
                       "evaluating the final in-memory model instead.")
        final_model = token_repre
    # The reloaded model returns four values, like the dev/test loops above.
    right, total, _, _ = _evaluate(final_model, test_batch, dy.renew_cg)
    logger.info("FINAL_TEST: %f", round(_ratio(right, total), 4))


if __name__ == '__main__':
    # Script entry point: training runs only when this file is executed
    # directly, never when it is imported.
    main()
