import argparse
import logging
import os
import random
import shutil
import time
from collections import Counter

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import ExponentialLR
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
import numpy as np
from antu.io import Vocabulary
from antu.io import glove_reader
from antu.utils.dual_channel_logger import dual_channel_logger
from module.exp_scheduler import ExponentialLRwithWarmUp

import warnings
warnings.filterwarnings("ignore")

from parser import Parser
from eval.PTB_evaluator import ptb_evaluation
from utils.data_loader_mp import DataLoaderX
from utils.conllu_reader import PTBReader
from utils.conllu_dataset import CoNLLUDataset, conllu_fn


def parse_args(argv=None):
    """Parse the training / model configuration from the command line.

    :param argv: optional list of argument strings to parse instead of
        ``sys.argv[1:]`` (handy for tests and programmatic launches). The
        default ``None`` keeps the usual command-line behavior.
    :return: ``argparse.Namespace`` with one attribute per option below.
        Options not given are ``None``, except the ``store_true`` flags,
        which default to ``False``.
    """
    parser = argparse.ArgumentParser(description="Training for multi-source dependency parsing.")
    # Data IO
    parser.add_argument('--TRAIN', type=str, help="Train set path.")
    parser.add_argument('--DEV', type=str, help="Development set path.")
    parser.add_argument('--TEST', type=str, help="Test set path.")
    parser.add_argument('--GLOVE', type=str, help="Glove path.")
    parser.add_argument('--PRED_DEV', type=str, help="Predicted development set path.")
    parser.add_argument('--PRED_TEST', type=str, help="Predicted test set path.")
    parser.add_argument('--LOG', type=str, help="Log path.")
    parser.add_argument('--MIN_COUNT', type=int, help="Minimum occurrences of word vocabulary.")
    parser.add_argument('--LAST', type=str, help="Last checkpoint path.")
    parser.add_argument('--BEST', type=str, help="Best checkpoint path.")
    # Training setup
    parser.add_argument('--SEED', type=int, help="Set random seed.")
    parser.add_argument('--N_EPOCH', type=int, help="#Epoch for training & testing.")
    parser.add_argument('--N_BATCH', type=int, help="True Batch size for training & testing.")
    parser.add_argument('--STEP_UPDATE', type=int, help="Step of update for training.")
    parser.add_argument('--STEP_VALID', type=int, help="Step of validate for training.")
    parser.add_argument('--N_WORKER', type=int, help="#Worker for data loader.")
    parser.add_argument('--IS_RESUME', default=False, action='store_true', help="Continue training.")
    # Optimizer
    parser.add_argument('--LR', type=float, help="Learning rate in Adam.")
    parser.add_argument('--BETAS', type=float, nargs=2, help="Beta1 and Beta2 in Adam.")
    parser.add_argument('--EPS', type=float, help="EPS in Adam.")
    parser.add_argument('--LR_DECAY', type=float, help="Decay rate of LR.")
    parser.add_argument('--LR_WARM', type=int, help="Warm up step of LR.")
    parser.add_argument('--LR_ANNEAL', type=int, help="Anneal step of LR.")
    parser.add_argument('--CLIP', type=float, help="Gradient clipping.")
    # Network setup
    parser.add_argument('--IS_FIX_GLOVE', default=False, action='store_true', help="Fix GLOVE emb.")
    parser.add_argument('--D_TAG', type=int, help="Dimension of tag embedding.")
    parser.add_argument('--EMB_DROP', type=float, help="Dropout rate of embedding representation.")
    parser.add_argument('--D_ARC', type=int, help="Dimension of ARC_MLP vector.")
    parser.add_argument('--D_REL', type=int, help="Dimension of REL_MLP vector.")
    parser.add_argument('--MLP_DROP', type=float, help="Dropout rate of MLP representation.")
    parser.add_argument('--MODEL_TYPE', type=str, help="Type of P_EMB: glove/sskip, ENC: xformer/rnn.")
    # ENC=RNN
    parser.add_argument('--D_RNN_HID', type=int, help="Dimension of RNN hidden vector.")
    parser.add_argument('--N_RNN_LAYER', type=int, help="Number of RNN layers.")
    parser.add_argument('--RNN_DROP', type=float, help="Dropout rate of RNN representation.")
    # ENC=Xformer
    parser.add_argument('--D_MODEL', type=int, help="Dimension of xformer hidden vector.")
    # Help text previously duplicated the RNN dropout description by mistake.
    parser.add_argument('--N_HEAD', type=int, help="Number of xformer attention heads.")
    parser.add_argument('--D_FF', type=int, help="Dimension of xformer Feed Forward.")
    parser.add_argument('--XFMR_DROP', type=float, help="Dropout rate of Xformer representation.")
    parser.add_argument('--N_XFMR_LAYER', type=int, help="Number of Xformer layers.")
    parser.add_argument('--PE_TYPE', type=str, help="Type of PE (sincos/random, cat/add).")
    parser.add_argument('--PE_DROP', type=float, help="Dropout rate of PE.")
    parser.add_argument('--N_PE', type=int, help="Maximum number of PE.")

    return parser.parse_args(argv)


def main():
    """Train the dependency parser and report dev/test UAS and LAS.

    All settings come from ``parse_args()``. Every ``STEP_VALID`` epochs the
    model is scored on the dev set. The current state (including the updated
    best dev scores) is then written to ``cfg.LAST``, and copied to
    ``cfg.BEST`` when the dev score improved. The test set is also scored for
    monitoring. After training, the best checkpoint is reloaded (if one was
    ever written) and scored on the test set.
    """
    # Configuration file processing
    cfg = parse_args()

    # Set seeds
    random.seed(cfg.SEED)
    np.random.seed(cfg.SEED)
    torch.manual_seed(cfg.SEED)
    torch.cuda.manual_seed(cfg.SEED)

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Logger setting
    logger = dual_channel_logger(
        __name__,
        file_path=cfg.LOG,
        file_model='w',
        formatter="%(asctime)s - %(levelname)s - %(message)s",
        time_formatter="%m-%d %H:%M")

    # Build data reader; 'sskip' embeddings additionally read the cased word field
    field_list = ['word', 'tag', 'head', 'rel']
    if 'sskip' in cfg.GLOVE:
        field_list.append('word_cased')
    data_reader = PTBReader(
        field_list=field_list,
        root='0\t**root**\t_\t**rcpos**\t**rpos**\t_\t0\t**rrel**\t_\t_',
        spacer=r'[\t]',)
    # Build vocabulary with pretrained glove
    vocabulary = Vocabulary()
    g_word, _ = glove_reader(cfg.GLOVE)
    vocabulary.extend_from_pretrained_vocab({'glove': g_word,})
    counters = {'word': Counter(), 'tag': Counter(), 'rel': Counter()}
    # Build the dataset. Only the training set receives the counters and
    # MIN_COUNT (presumably it is the one that extends the vocabulary -- confirm
    # in CoNLLUDataset).
    train_set = CoNLLUDataset(
        cfg.TRAIN, data_reader, vocabulary, counters, {'word': cfg.MIN_COUNT},
        no_pad_namespace={'rel'}, no_unk_namespace={'rel'})
    dev_set  = CoNLLUDataset(cfg.DEV,  data_reader, vocabulary)
    test_set = CoNLLUDataset(cfg.TEST, data_reader, vocabulary)
    # Build the data-loader
    train = DataLoader(train_set, cfg.N_BATCH, True,  num_workers=cfg.N_WORKER, pin_memory=cfg.N_WORKER>0, collate_fn=conllu_fn)
    dev   = DataLoader(dev_set,   cfg.N_BATCH, False, num_workers=cfg.N_WORKER, pin_memory=cfg.N_WORKER>0, collate_fn=conllu_fn)
    test  = DataLoader(test_set,  cfg.N_BATCH, False, num_workers=cfg.N_WORKER, pin_memory=cfg.N_WORKER>0, collate_fn=conllu_fn)

    # Build parser model
    parser = Parser(vocabulary, cfg)

    # create losses
    CELoss = nn.CrossEntropyLoss()

    # if running on GPU
    if use_cuda:
        parser = parser.cuda()

    # build optimizers
    optim = AdamW(parser.parameters(), cfg.LR, cfg.BETAS, cfg.EPS)
    # gamma = LR_DECAY ** (1 / LR_ANNEAL): a per-update factor that multiplies the
    # LR by LR_DECAY over LR_ANNEAL updates, assuming ExponentialLR-style
    # semantics after LR_WARM warm-up updates (confirm in module.exp_scheduler).
    # Alternatives that were left commented out here: torch's ExponentialLR and
    # transformers' get_linear_/get_cosine_schedule_with_warmup. The latter two
    # need a total number of optimizer updates.
    sched = ExponentialLRwithWarmUp(optim, cfg.LR_DECAY**(1/cfg.LR_ANNEAL), cfg.LR_WARM)

    # load checkpoint if wanted
    start_epoch = best_uas = best_las = best_epoch = 0

    def load_ckpt(ckpt_path: str):
        """Restore parser, optimizer and scheduler state from ``ckpt_path``.

        :return: (start_epoch, best_uas, best_las, best_epoch). start_epoch is
            the epoch after the one stored in the checkpoint.
        """
        # map_location lets a checkpoint written on a GPU be restored on a
        # CPU-only machine (and vice versa).
        ckpt = torch.load(ckpt_path, map_location=device)
        parser.load_state_dict(ckpt['parser'])
        optim.load_state_dict(ckpt['optim'])
        sched.load_state_dict(ckpt['sched'])
        best_uas, best_las, best_epoch = ckpt['best']
        return ckpt['epoch'] + 1, best_uas, best_las, best_epoch

    if cfg.IS_RESUME:
        start_epoch, best_uas, best_las, best_epoch = load_ckpt(cfg.LAST)

    def to_device(data):
        """Move every tensor of a collated batch dict to the GPU, in place.

        As before, this only happens when loader workers are used. Presumably
        conllu_fn already builds device tensors when N_WORKER == 0 -- TODO
        confirm. It is skipped without CUDA so CPU-only runs do not crash on
        ``.cuda()``.
        """
        if cfg.N_WORKER and use_cuda:
            for key in data.keys():
                data[key] = data[key].cuda()
        return data

    @torch.no_grad()
    def validation(data_loader: DataLoader, pred_path: str, gold_path: str):
        """Decode every batch of ``data_loader`` and score it against ``gold_path``.

        The caller must call ``parser.eval()`` first. In that mode the forward
        pass is unpacked as (arcs, rels), whereas the training loop unpacks
        four values (presumably the parser switches on ``self.training``).

        :return: (uas, las) from ptb_evaluation (which presumably also writes
            the predictions to ``pred_path``).
        """
        pred = {'arcs': [], 'rels': []}
        for data in data_loader:
            arcs, rels = parser(to_device(data))
            pred['arcs'].extend(arcs)
            pred['rels'].extend(rels)
        uas, las = ptb_evaluation(vocabulary, pred, pred_path, gold_path)
        return uas, las

    # Train model
    prepare_time_tot = process_time_tot = 0
    for epoch in range(start_epoch, cfg.N_EPOCH):
        parser.train()
        arc_losses, rel_losses = [], []
        start_time = time.time()
        for n_iter, data in enumerate(train):
            data = to_device(data)
            # Time spent fetching, collating and moving the batch.
            prepare_time = time.time()-start_time
            prepare_time_tot += prepare_time
            pred_arc, gold_arc, pred_rel, gold_rel = parser(data)
            arc_loss = CELoss(pred_arc, gold_arc)
            rel_loss = CELoss(pred_rel, gold_rel)
            arc_losses.append(arc_loss.item())
            rel_losses.append(rel_loss.item())
            # Gradient accumulation: average the loss over STEP_UPDATE batches.
            ((arc_loss+rel_loss)/cfg.STEP_UPDATE).backward()
            # Actual update, once every STEP_UPDATE batches.
            # NOTE(review): when len(train) is not a multiple of STEP_UPDATE, the
            # leftover gradients carry into the next epoch's first update.
            if n_iter % cfg.STEP_UPDATE == cfg.STEP_UPDATE-1:
                nn.utils.clip_grad_norm_(parser.parameters(), cfg.CLIP)
                optim.step()
                optim.zero_grad()
                sched.step()
            process_time = time.time()-start_time-prepare_time
            process_time_tot += process_time
            start_time = time.time()
        # Validate and checkpoint only every STEP_VALID epochs.
        if epoch % cfg.STEP_VALID != cfg.STEP_VALID-1:
            continue

        # Validate the parser on the dev set BEFORE saving. The checkpoint's
        # 'best' entry then includes this epoch, so resuming from LAST cannot
        # later overwrite BEST with a worse model.
        parser.eval()
        uas, las = validation(dev, cfg.PRED_DEV, cfg.DEV)
        improved = (uas > best_uas and las > best_las) or uas+las > best_uas+best_las
        if improved:
            best_uas, best_las, best_epoch = uas, las, epoch
        # save current parser
        torch.save({
            'epoch': epoch,
            'best': (best_uas, best_las, best_epoch),
            'parser': parser.state_dict(),
            'optim': optim.state_dict(),
            'sched': sched.state_dict(),
        }, cfg.LAST)
        if improved:
            # Synchronous, shell-free copy (replaces an unwaited os.popen('cp ...')).
            shutil.copyfile(cfg.LAST, cfg.BEST)
        logger.info(
            f'|{epoch:4}| Arc({float(np.mean(arc_losses)):.2f}) '
            f'Rel({float(np.mean(rel_losses)):.2f}) Best({best_epoch})')
        logger.info(f'| Dev| UAS:{uas:6.2f}, LAS:{las:6.2f}')
        # view performance on test set
        uas, las = validation(test, cfg.PRED_TEST, cfg.TEST)
        logger.info(f'|Test| UAS:{uas:6.2f}, LAS:{las:6.2f}\n')

    print(f"prepare_time_tot:{prepare_time_tot:.2f} process_time_tot:{process_time_tot:.2f}")
    logger.info(f'*Best Dev Result* UAS:{best_uas:6.2f}, LAS:{best_las:6.2f}, Epoch({best_epoch})')
    # BEST only exists if some validation epoch improved. Otherwise, evaluate
    # the current weights instead of crashing in torch.load.
    if os.path.exists(cfg.BEST):
        load_ckpt(cfg.BEST)
    else:
        logger.warning(f'Best checkpoint {cfg.BEST} not found; evaluating the last model.')
    parser.eval()
    uas, las = validation(test, cfg.PRED_TEST, cfg.TEST)
    logger.info(f'*Final Test Result* UAS:{uas:6.2f}, LAS:{las:6.2f}')

if __name__ == '__main__':
    # Run training only when executed as a script, not when imported.
    main()
