import argparse, _pickle, math, os, random, sys, time, logging
random.seed(666)
import numpy as np
np.random.seed(666)
from collections import Counter
from antu.io.configurators.ini_configurator import IniConfigurator
from antu.io.vocabulary import Vocabulary
from antu.io.ext_embedding_readers import fasttext_reader
from utils.PTB_dataset import DatasetSetting, PTBDataset
from utils.conllu_reader import PTBReader
from antu.utils.dual_channel_logger import dual_channel_logger


def main():
    """Train a graph-based dependency parser with DyNet, then test the best model.

    Steps, as wired below:
      1. Parse CLI flags. Flags unknown to argparse are forwarded, together
         with ``--config_file``, to ``IniConfigurator`` to build ``cfg``.
      2. Configure DyNet (memory, seed, optional GPU) via ``dynet_config``
         before ``import dynet``, which is why DyNet is imported in here.
      3. Read train/dev/test with ``PTBReader`` and build the vocabulary from
         the pretrained embeddings at ``cfg.GLOVE``.
      4. Build the token representation (plus an RNN encoder for ``RNN*``
         models) and a ``GraphNNDecoder``. Train with Adam and an
         exponentially decayed learning rate.
      5. Every ``cfg.VALID_ITER`` iterations: log average losses, evaluate on
         dev, and save ``cfg.LAST_FILE``. Copy it to ``cfg.BEST_FILE`` when
         ``my_eval.evaluation`` returns truthy (presumably a new best dev
         score), then record test scores.
      6. Reload the best checkpoint written by this run and evaluate on test.
    """
    import shutil

    # Configuration file processing
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--config_file', default='../configs/debug.cfg')
    # NOTE(review): parsed but never read below; resuming is not implemented here.
    argparser.add_argument('--continue_training', action='store_true', help='Load model Continue Training')
    argparser.add_argument('--name', default='experiment', help='The name of the experiment.')
    # Adjacent string literals are concatenated, so each entry needs its own separator.
    argparser.add_argument(
        '--model', default='RNN',
        help='RNN+Gx: RNN encoder with x-layer Graph Decoder; '
             'OnlyGraph: Only use Graph Decoder; '
             'WinGraph: n-gram windows + Graph Decoder; '
             'LPGraph: Link prediction + Graph Decoder')
    argparser.add_argument('--gpu', default='0', help='GPU ID (-1 to cpu)')
    args, extra_args = argparser.parse_known_args()
    # Any model name starting with 'RNN' (e.g. 'RNN', 'RNN+G2') uses the RNN encoder.
    use_rnn = args.model.startswith('RNN')
    if not (use_rnn or args.model in ('OnlyGraph', 'WinGraph', 'LPGraph')):
        # Fail fast instead of leaving `token_repre` unbound after data loading.
        argparser.error('unknown --model %r' % args.model)
    cfg = IniConfigurator(args.config_file, extra_args)

    # Logger setting
    logger = dual_channel_logger(
        __name__,
        file_path=cfg.LOG_FILE,
        file_model='w',
        formatter='%(asctime)s - %(levelname)s - %(message)s',
        time_formatter='%m-%d %H:%M')
    from eval.script_evaluator import ScriptEvaluator

    # DyNet setting: configured through dynet_config before `import dynet`.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    import dynet_config
    dynet_config.set(mem=cfg.DYNET_MEM, random_seed=cfg.DYNET_SEED)
    if args.gpu != '-1':
        # '-1' hides every CUDA device above, so only request a GPU otherwise.
        dynet_config.set_gpu()
    import dynet as dy
    from models.token_representation import TokenRepresentation
    from antu.nn.dynet.seq2seq_encoders import DeepBiRNNBuilder, orthonormal_VanillaLSTMBuilder
    from models.graph_nn_decoder import GraphNNDecoder

    # Build the dataset of the training process
    ## Build data reader
    data_reader = PTBReader(
        field_list=['word', 'tag', 'head', 'rel'],
        root='0\t**root**\t_\t**rcpos**\t**rpos**\t_\t0\t**rrel**\t_\t_',
        spacer=r'[\t]',)
    ## Build vocabulary with pretrained glove
    vocabulary = Vocabulary()
    g_word = fasttext_reader(cfg.GLOVE, True)
    pretrained_vocabs = {'glove': g_word}
    vocabulary.extend_from_pretrained_vocab(pretrained_vocabs)
    ## Setup datasets
    datasets_settings = {
        'train': DatasetSetting(cfg.TRAIN, True),
        'dev': DatasetSetting(cfg.DEV, False),
        'test': DatasetSetting(cfg.TEST, False),}
    datasets = PTBDataset(vocabulary, datasets_settings, data_reader)
    counters = {'word': Counter(), 'tag': Counter(), 'rel': Counter()}
    datasets.build_dataset(counters, no_pad_namespace={'rel'}, no_unk_namespace={'rel'})

    # Build model
    pc = dy.ParameterCollection()
    trainer = dy.AdamTrainer(pc, cfg.LR, cfg.ADAM_BETA1, cfg.ADAM_BETA2, cfg.EPS)

    if args.model == 'WinGraph':
        # Imported lazily: each of these modules is only needed by its own model.
        from models.win_representation import WinRepresentation
        token_repre = WinRepresentation(pc, cfg, datasets.vocabulary)
    elif args.model == 'LPGraph':
        from models.graph_representation import GraphRepresentation
        token_repre = GraphRepresentation(pc, cfg, datasets.vocabulary)
    else:  # 'RNN*' or 'OnlyGraph' (validated right after argument parsing)
        token_repre = TokenRepresentation(pc, cfg, datasets.vocabulary)
    encoder = None
    if use_rnn:
        encoder = DeepBiRNNBuilder(
            pc, cfg.ENC_LAYERS, token_repre.token_dim,
            cfg.ENC_H_DIM, orthonormal_VanillaLSTMBuilder)
    decoder = GraphNNDecoder(pc, cfg, datasets.vocabulary)
    # Objects saved (in this order) by dy.save and unpacked again after dy.load.
    modules = [token_repre, encoder, decoder] if use_rnn else [token_repre, decoder]

    # Train model
    BEST_ITER = 0
    saved_best = False  # True once this run has written cfg.BEST_FILE
    # The counter starts negative (warm-up); the LR schedule clamps it at 0.
    cnt_iter = -cfg.WARM * cfg.GRAPH_LAYERS

    def by_length(ins):
        return len(ins['word'])

    train_batch = datasets.get_batches('train', cfg.TRAIN_BATCH_SIZE, True, by_length, True)
    valid_batch = list(datasets.get_batches('dev', cfg.TEST_BATCH_SIZE, False, by_length, False))
    test_batch = list(datasets.get_batches('test', cfg.TEST_BATCH_SIZE, False, by_length, False))

    my_eval = ScriptEvaluator(['Valid', 'Test'], datasets.vocabulary)

    # One bucket per logged loss term, averaged and reset at each validation.
    n_loss_buckets = 1 if args.model == 'LPGraph' else cfg.GRAPH_LAYERS + 3

    def fresh_loss_buckets():
        """Return empty per-term loss accumulators."""
        return [[] for _ in range(n_loss_buckets)]

    def encode_eval(indexes, masks, truth, token_repre, encoder):
        """Encode one batch with the training flags set to False."""
        if args.model == 'LPGraph':
            vectors = token_repre(indexes, masks, truth, False)
        else:
            vectors = token_repre(indexes, False)
        if use_rnn:
            vectors = encoder(vectors, None, cfg.RNN_DROP, cfg.RNN_DROP,
                              np.array(masks['1D']).T, False, False)
        return vectors

    def predict_split(split, batches, token_repre, encoder, decoder, iteration):
        """Clear `split` in the evaluator, then feed it truth and predictions."""
        my_eval.clear(split)
        for indexes, masks, truth in batches:
            dy.renew_cg()
            vectors = encode_eval(indexes, masks, truth, token_repre, encoder)
            pred = decoder(vectors, masks, None, iteration, False, True)
            my_eval.add_truth(split, truth)
            my_eval.add_pred(split, pred)

    valid_loss = fresh_loss_buckets()
    logger.info("Experiment name: %s, Model name: %s" % (args.name, args.model))
    with os.popen('git log -1 | head -n 1 | cut -c 8-13') as pipe:
        SHA = pipe.readline().rstrip()
    logger.info('Git SHA: %s' % SHA)
    while cnt_iter < cfg.MAX_ITER:
        dy.renew_cg()
        cnt_iter += 1
        indexes, masks, truth = next(train_batch)
        if args.model == 'LPGraph':
            vectors, loss_LP = token_repre(indexes, masks, truth, True)
        else:
            vectors = token_repre(indexes, True)
        if use_rnn:
            vectors = encoder(vectors, None, cfg.RNN_DROP, cfg.RNN_DROP,
                              np.array(masks['1D']).T, False, True)
        loss, part_loss = decoder(vectors, masks, truth, cnt_iter, True, True)
        if args.model == 'LPGraph':
            # Only the link-prediction loss is logged, and it is NOT added to
            # `loss`, so it does not contribute to the backward pass below.
            tracked = [loss_LP]
        else:
            tracked = [loss] + part_loss
        for i, term in enumerate(tracked):
            valid_loss[i].append(term.value())
        loss.backward()
        trainer.learning_rate = cfg.LR * cfg.LR_DECAY ** (max(0, cnt_iter) / cfg.LR_ANNEAL)
        trainer.update()

        # Validation
        if cnt_iter % cfg.VALID_ITER:
            continue
        # Empty buckets are rendered as 'nan' (as np.mean would) without a warning.
        avg_loss = ', '.join(
            str(round(np.mean(bucket), 2)) if bucket else 'nan'
            for bucket in valid_loss)
        logger.info("")
        logger.info("Iter: %d-%d, Avg_loss: %s, LR (%f), Best (%d)" %
                    (cnt_iter/cfg.VALID_ITER, cnt_iter, avg_loss,
                     trainer.learning_rate, BEST_ITER))
        valid_loss = fresh_loss_buckets()

        predict_split('Valid', valid_batch, token_repre, encoder, decoder, cnt_iter)
        dy.save(cfg.LAST_FILE, modules)
        if my_eval.evaluation('Valid', cfg.PRED_DEV, cfg.DEV):
            BEST_ITER = cnt_iter/cfg.VALID_ITER
            # Copy both checkpoint files; raises instead of silently failing.
            for suffix in ('.data', '.meta'):
                shutil.copyfile(cfg.LAST_FILE + suffix, cfg.BEST_FILE + suffix)
            saved_best = True

        # Just record test result
        predict_split('Test', test_batch, token_repre, encoder, decoder, cnt_iter)
        my_eval.evaluation('Test', cfg.PRED_TEST, cfg.TEST)
    my_eval.print_best_result('Valid')

    # Final Test: only trust a best checkpoint written by *this* run; an
    # existing cfg.BEST_FILE could be left over from an earlier experiment.
    if not saved_best:
        logger.warning('No best checkpoint was saved in this run; skipping final test.')
        return
    test_pc = dy.ParameterCollection()
    loaded = dy.load(cfg.BEST_FILE, test_pc)
    if use_rnn:
        token_repre, encoder, decoder = loaded
    else:
        token_repre, decoder = loaded
    predict_split('Test', test_batch, token_repre, encoder, decoder, 0)
    my_eval.evaluation('Test', cfg.PRED_TEST, cfg.TEST)


if __name__ == '__main__':
    # Script entry point: train/evaluate only when run directly, not on import.
    main()
