import argparse
import os
import sys
import time

import torch
from transformers import BertTokenizer

from few_shot_ner.logger import init_root_logger
from few_shot_ner.dataset import MetaDataSet
from few_shot_ner.eval import evaluate, sample_eval_tasks
from few_shot_ner.io import first_file_in_dir, print_metrics_per_epochs, save_hyperparam, log_fine_tuning_results, \
    create_directories
from few_shot_ner.losses import BaselineNerLoss
from few_shot_ner.models import load, save, build_baseline_model, init_seed
from few_shot_ner.optimizer import create_optimizer
from few_shot_ner.parser import *
from few_shot_ner.predictors import NERPredictor
from few_shot_ner.trainer import bert_ner_step, train


def main(args):
    """Fine-tune the baseline NER model on each sampled dev task and log the results.

    For every task returned by ``sample_eval_tasks`` the model is evaluated
    once before any update, then for ``args.max_epochs`` epochs it is trained
    with ``args.num_updates`` updates and re-evaluated. The weights captured
    before the first task are reloaded after each task, so every task starts
    from the same state. Per-task metrics are aggregated with
    ``log_fine_tuning_results`` and written out with ``save_hyperparam``.

    :param args: parsed command-line namespace (see the ``add_*_parameters_group``
        helpers used by the script entry point for the expected attributes).
    """
    init_seed(args.seed)
    # When SM_OUTPUT_DATA_DIR is set, the log file is redirected into that directory.
    sm_output_dir = os.environ.get('SM_OUTPUT_DATA_DIR')
    if sm_output_dir is not None:
        args.log_file = os.path.join(sm_output_dir, 'log.txt')
    logger = init_root_logger(args.log_file, level=args.verbose)
    # Log every argument exactly once and keep a plain-dict copy for save_hyperparam.
    argsdict = dict(vars(args))
    for name, value in argsdict.items():
        logger.info("%s : %s", name, value)
    cli_string = 'python ' + ' '.join(sys.argv)
    # Find device
    device = torch.device('cpu') if args.cpu else torch.device('cuda:{}'.format(torch.cuda.current_device()))
    logger.info("Device used: %s", device)
    # Create directories
    create_directories(args.model_folder, args.output_folder)
    # Load pre-trained BERT tokenizer
    tokenizer = BertTokenizer.from_pretrained(args.bert_model_path)
    # Load data
    data_dev = MetaDataSet(args.alphabets_folder, args.dev, MetaDataSet.dev, tokenizer, args.max_token_length,
                           {MetaDataSet.ner, MetaDataSet.ic})
    eval_tasks = sample_eval_tasks(data_dev, args)
    # Build model
    base_net = build_baseline_model(args, device)
    # Parameter-less modules
    use_softmax = args.decoder == "softmax"
    ner_loss = BaselineNerLoss(base_net, use_softmax).to(device=device)
    ner_predictor = NERPredictor(base_net, use_softmax, args.max_token_length).to(device=device)
    # Save a copy of the initialization weights
    base_net_state = save(base_net)
    metrics = []
    tot_start_time = time.time()
    # Train on all tasks and evaluate
    for task, dataset in enumerate(eval_tasks):
        with torch.no_grad():
            ner_predictor.set_mapping(dataset.tgt_slots)
            # Optimizer creation
            optimizer = create_optimizer(base_net, args)
            # Evaluate without finetuning; the loss slot is 0 because no update has run yet
            micro_f1 = evaluate(dataset, ner_predictor, device)
            metrics_on_task = [(micro_f1, 0)]
        # Step function closure (only called within this iteration, so the
        # late binding of `dataset` / `optimizer` is harmless)
        step_func = lambda: bert_ner_step(dataset, ner_loss, optimizer, device)
        for _ in range(args.max_epochs):
            ave_loss = train(step_func, args.num_updates)
            micro_f1 = evaluate(dataset, ner_predictor, device)
            metrics_on_task.append((micro_f1, ave_loss))
        # Reload the saved initial state so the next task starts from the same weights
        load(base_net, base_net_state)
        metrics.append(metrics_on_task)
        if args.log_validation:
            logger.info("Num tasks : {:d}".format(task + 1))
            msg = print_metrics_per_epochs(metrics)
            logger.info("\n" + msg)

    mean_metrics, std_metrics, best_num_epoch = log_fine_tuning_results(metrics, logger)
    logger.info("micro_f1={}; loss={}; best_num_epoch={};".format(
        mean_metrics[0][best_num_epoch], mean_metrics[1][best_num_epoch], best_num_epoch))
    save_hyperparam(cli_string, argsdict, mean_metrics, std_metrics, best_num_epoch, args.output_folder)
    # NOTE(review): base_net was reloaded from base_net_state after the last task, so this
    # presumably stores the initial (not fine-tuned) weights under 'base_net_best.pth' -- confirm intent.
    torch.save(base_net.state_dict(), os.path.join(args.model_folder, 'base_net_best.pth'))
    logger.info("Training done")
    tot_train_time = time.time() - tot_start_time
    # Report whole hours/minutes as integers (previously rendered as floats, e.g. "1.0h 2.0min")
    hours, remainder = divmod(int(tot_train_time), 3600)
    logger.info("Total training time: {}s, i.e., {}h {}min".format(
        tot_train_time, hours, remainder // 60))


if __name__ == '__main__':
    # Script entry point: assemble the command-line interface, then run main().
    arg_parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Finetune a BERT Model for NER and find the best number of epochs',
    )
    # The default dev path is looked up from the SM_CHANNEL_DEV environment variable.
    default_dev_path = first_file_in_dir(os.environ.get('SM_CHANNEL_DEV'))
    arg_parser.add_argument("--dev", default=default_dev_path, type=str, help='Path to the dev data')
    # Register the shared option groups in the same order as before.
    option_group_builders = (
        add_decoder_parameters_group,
        add_bert_encoder_parameters_group,
        add_fsl_parameters_group,
        add_optimizer_parameters_group,
        add_computation_parameters_group,
        add_logging_parameters_group,
        add_read_parameters_group,
        add_eval_parameters_group,
        add_write_parameters_group,
    )
    for register_group in option_group_builders:
        register_group(arg_parser)
    parsed_args = arg_parser.parse_args()
    main(parsed_args)
