
from src.config import *
from src.utils import *
from src.dataloader import *
from src.trainer import *
from src.model import *

import torch
import numpy as np
from tqdm import tqdm
import random


def set_random_seed(seed=None):
    """Seed the Python, NumPy and PyTorch RNGs and return the seed used.

    Args:
        seed: Integer seed to apply. When ``None`` (the default, matching the
            previous behaviour) a seed in ``[0, 1000)`` is drawn from the
            unseeded ``random`` module, so such runs are not reproducible
            unless the returned seed is recorded and passed back in.

    Returns:
        The integer seed that was applied, so callers can log it.
    """
    if seed is None:
        seed = int(random.random() * 1000000 % 1000)
    seed = int(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning can select non-deterministic kernels; disable it so the
    # ``deterministic`` flag above actually takes effect.
    torch.backends.cudnn.benchmark = False
    return seed


def train(params):
    """Fine-tune a source-pretrained NER tagger on the target domain.

    Steps, all driven by ``params``:
      1. Initialise the experiment logger.
      2. Build the source label mapping and the target-domain dataloaders.
      3. Build the tagger, move it to the GPU and wrap it in ``BaseTrainer``.
      4. Load ``source_model.pth`` from ``params.path_pretrain``, then either
         load the source-label graph or build it from the target training
         data (``params.load_source_graph``).
      5. Train for ``params.epoch_target`` epochs. Step the scheduler (if any)
         once per epoch, and log dev F1 unless ``params.debug``.
      6. If ``params.test_finetune``, evaluate and log test F1.

    Args:
        params: Experiment configuration as returned by ``get_params()``.

    Raises:
        ValueError: If ``params.model_name`` is not a supported model. This is
            a subclass of ``Exception``, so existing handlers still match.
    """
    logger = init_experiment(params, logger_filename=params.logger_filename)

    source_mapping = get_label_mapping(params.src_dm)
    # NOTE(review): ``dataset_path`` is not defined in this function; it is
    # presumably provided by one of the star imports (e.g. src.config) - confirm.
    (target_dataloader_train,
     target_dataloader_dev,
     target_dataloader_test,
     target_mapping) = get_dataloader(dataset_path,
                                      params.tgt_dm,
                                      params.batch_size_target)

    # Source-domain NER tagger; the source checkpoint is loaded into it below
    # via ``trainer.load_model``.
    if params.model_name not in ('bert-base-cased',):
        raise ValueError('model name %s is invalid' % params.model_name)
    model = BertTagger(params.src_dm,
                       params.tgt_dm,
                       params.hidden_dim,
                       params.model_name,
                       params.ckpt)
    model.cuda()
    trainer = BaseTrainer(params, model, label_mapping=(source_mapping, target_mapping))

    # Load the pretrained source-domain model.
    trainer.load_model("source_model.pth", path=params.path_pretrain, is_source_model=True)
    # Graph in the source label space: either reuse a stored one or build it
    # from the target training data (is_save=True presumably persists it).
    if params.load_source_graph:
        trainer.load_source_graph()
    else:
        trainer.set_source_graph(target_dataloader_train, is_save=True)

    # Fine-tune on the target domain.
    logger.info("Training on target domain : %s ..." % (params.tgt_dm))

    # Presumably switches the trainer to target-domain mode before the
    # optimizer is (re)created.
    trainer.set_source(False)
    trainer.set_optimizer()

    for e in range(params.epoch_target):
        logger.info("============== epoch %d ==============" % e)

        for X, y in tqdm(target_dataloader_train, total=len(target_dataloader_train)):
            X, y = X.cuda(), y.cuda()
            trainer.train_step(X, y)

        # The scheduler (if configured) is stepped once per epoch, not per batch.
        if trainer.scheduler is not None:
            trainer.scheduler.step()

        if not params.debug:
            f1_dev = trainer.evaluate(target_dataloader_dev)
            logger.info('f1_dev: %.4f' % f1_dev)

    logger.info("Finish training on target domain : %s ..." % (params.tgt_dm))

    # NOTE(review): within this function, dev F1 is only logged. Test F1 is
    # computed on the final-epoch model, not a dev-selected checkpoint.
    if params.test_finetune:
        logger.info("Testing on target domain : %s ..." % (params.tgt_dm))
        f1_test = trainer.evaluate(target_dataloader_test)
        logger.info('f1_test: %.4f' % f1_test)


if __name__ == "__main__":
    # Script entry point. The order matters and is kept explicit:
    # parse the configuration first, then seed every RNG, then run
    # target-domain fine-tuning with that configuration.
    params = get_params()

    set_random_seed()

    train(params)
