import os
import argparse
import sys
import time
import torch

from transformers import BertTokenizer

from few_shot_ner.dataset import MetaDataSet
from few_shot_ner.eval import evaluate, sample_eval_tasks, EarlyStopping
from few_shot_ner.logger import init_root_logger
from few_shot_ner.io import print_metrics_per_epochs, save_hyperparam, log_fine_tuning_results, first_file_in_dir, \
    create_directories
from few_shot_ner.losses import ProtoNetNerLoss
from few_shot_ner.models import load, save, build_protonet_model, init_seed
from few_shot_ner.optimizer import create_optimizer_protonet, create_meta_optimizer_protonet
from few_shot_ner.parser import *
from few_shot_ner.predictors import ProtoNetPredictor
from few_shot_ner.trainer import proto_ner_step, proto_ft_ner_step, train


def main(args):
    """Meta-train a ProtoNet NER model and keep the checkpoint that scores best on dev tasks.

    Each outer epoch runs ``args.meta_num_updates`` meta-training steps on the
    train set, then, for every sampled dev task, evaluates the model without
    fine-tuning and after each of ``args.max_epochs`` fine-tuning epochs. The
    model weights are restored to their pre-fine-tuning values after each task.
    Whenever the best mean micro-F1 improves, the weights are written to
    ``proto_net_best.pth`` in ``args.model_folder`` and the hyper-parameters
    are saved to ``args.output_folder``.

    Args:
        args: Parsed command-line namespace. Attributes read directly here:
            ``seed``, ``log_file``, ``verbose``, ``cpu``, ``model_folder``,
            ``output_folder``, ``bert_model_path``, ``alphabets_folder``,
            ``train``, ``dev``, ``max_token_length``, ``delta``, ``patience``,
            ``meta_max_epochs``, ``decoder``, ``meta_num_updates``,
            ``max_epochs``, ``num_updates`` and ``log_validation``. The whole
            namespace is also forwarded to several project helpers.

    Returns:
        None. Results are reported through the logger and files on disk.
    """
    init_seed(args.seed)
    # When SM_OUTPUT_DATA_DIR is set, the log file is redirected into it.
    sm_output_dir = os.environ.get('SM_OUTPUT_DATA_DIR')
    if sm_output_dir is not None:
        args.log_file = os.path.join(sm_output_dir, 'log.txt')
    logger = init_root_logger(args.log_file, level=args.verbose)
    # Snapshot the arguments once (after the log-file override) and log each
    # of them a single time; the snapshot is reused by save_hyperparam below.
    argsdict = dict(vars(args))
    for name, value in argsdict.items():
        logger.info("%s : %s" % (name, str(value)))
    cli_string = 'python ' + ' '.join(sys.argv)
    # Find device
    device = torch.device('cpu') if args.cpu else torch.device('cuda:{}'.format(torch.cuda.current_device()))
    logger.info("Device used: %s" % device)
    # Create output directory
    create_directories(args.model_folder, args.output_folder)
    # Load pre-trained BERT tokenizer
    tokenizer = BertTokenizer.from_pretrained(args.bert_model_path)
    # Load data
    data_train = MetaDataSet(args.alphabets_folder, args.train, MetaDataSet.train, tokenizer, args.max_token_length,
                             {MetaDataSet.ner, MetaDataSet.ic})
    data_dev = MetaDataSet(args.alphabets_folder, args.dev, MetaDataSet.dev, tokenizer, args.max_token_length,
                           {MetaDataSet.ner, MetaDataSet.ic})
    eval_tasks = sample_eval_tasks(data_dev, args)
    # Build model
    proto_net = build_protonet_model(args, device)
    # Early stopping initialization
    early_stopping = EarlyStopping(min_delta=args.delta, patience=args.patience,
                                   max_epochs=args.meta_max_epochs, mode='max')
    # Optimizer creation
    optimizer = create_meta_optimizer_protonet(proto_net, args)
    # Parameter-less modules
    use_softmax_decoder = args.decoder == "softmax"
    ner_loss = ProtoNetNerLoss(proto_net, use_softmax_decoder).to(device)
    ner_predictor = ProtoNetPredictor(proto_net, use_softmax_decoder, args.max_token_length).to(device)
    # Training loop
    dev_f1 = 0
    best_metrics = None
    best_epoch = 0
    tot_start_time = time.time()
    outer_step_func = lambda: proto_ner_step(data_train, ner_loss, optimizer, args, device)
    while early_stopping.step(dev_f1):
        epoch = early_stopping.epoch()
        # loop on batches
        ave_loss = train(outer_step_func, args.meta_num_updates, log=False)
        # Epoch done
        logger.info("epoch={:d}; loss={:f};".format(epoch, ave_loss))
        # Eval performances on dev set
        # Save a copy of the initialization weights
        state_dict_proto_net = save(proto_net)
        metrics = []
        # Train on all tasks and evaluate
        for task, dataset in enumerate(eval_tasks):
            with torch.no_grad():
                # Optimizer creation
                task_optimizer = create_optimizer_protonet(proto_net, args)
                # Metrics placeholder: one (micro_f1, loss) pair per evaluation point
                metrics_on_task = []
                # Evaluate without finetuning
                ner_predictor.set_prototypes(dataset.get_supports(device))
                micro_f1 = evaluate(dataset, ner_predictor, device)
                metrics_on_task.append((micro_f1, 0))
            # Step function closure; it is only invoked within this iteration,
            # so capturing the loop variables by reference is safe here.
            inner_step_func = lambda: proto_ft_ner_step(dataset, ner_loss, task_optimizer, device)
            for valid_epoch in range(args.max_epochs):
                ave_loss = train(inner_step_func, args.num_updates)
                with torch.no_grad():
                    ner_predictor.set_prototypes(dataset.get_supports(device))
                    micro_f1 = evaluate(dataset, ner_predictor, device)
                    metrics_on_task.append((micro_f1, ave_loss))
            # Reset parameters so the next task (and the next meta-epoch)
            # starts from the weights saved before fine-tuning
            load(proto_net, state_dict_proto_net)
            # Display metrics
            metrics.append(metrics_on_task)
            if args.log_validation:
                logger.info("Num tasks : {:d}".format(task + 1))
                msg = print_metrics_per_epochs(metrics)
                logger.info("\n" + msg)

        mean_metrics, std_metrics, best_valid_num_epoch = log_fine_tuning_results(metrics, logger)
        epoch_f1 = mean_metrics[0][best_valid_num_epoch]
        logger.info("micro_f1={}; loss_valid={}; best_valid_num_epoch={};".format(
            epoch_f1, mean_metrics[1][best_valid_num_epoch], best_valid_num_epoch))
        if dev_f1 < epoch_f1:
            dev_f1 = epoch_f1
            best_metrics = metrics
            best_epoch = epoch
            torch.save(proto_net.state_dict(), os.path.join(args.model_folder, 'proto_net_best.pth'))
            save_hyperparam(cli_string, argsdict, mean_metrics, std_metrics, best_valid_num_epoch, args.output_folder)
            logger.info("saved better model")

    logger.info("Training done")
    if best_metrics is None:
        # No epoch ever beat the initial dev micro-F1 of 0 (or the loop never
        # ran), so there is nothing to summarise and no checkpoint was saved.
        logger.info("No epoch improved the dev micro_f1; no model was saved")
    else:
        mean_metrics, std_metrics, best_valid_num_epoch = log_fine_tuning_results(best_metrics, logger)
        logger.info("BEST - num_epochs={}; micro_f1={}; loss_valid={}; best_valid_num_epoch={};".format(
            best_epoch, mean_metrics[0][best_valid_num_epoch],
            mean_metrics[1][best_valid_num_epoch], best_valid_num_epoch))
    tot_train_time = time.time() - tot_start_time
    logger.info("Total training time: {}s, i.e., {}h {}min".format(
        tot_train_time, tot_train_time // 3600, tot_train_time % 3600 // 60))


if __name__ == '__main__':
    # Script entry point: assemble the command-line interface, then train.
    cli_parser = argparse.ArgumentParser(
        description='Train a ProtoNet Model for NER',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Data locations; the defaults are derived from the SM_CHANNEL_* environment variables.
    cli_parser.add_argument(
        "--train",
        type=str,
        default=first_file_in_dir(os.environ.get('SM_CHANNEL_TRAIN')),
        help='Path to the train data',
    )
    cli_parser.add_argument(
        "--dev",
        type=str,
        default=first_file_in_dir(os.environ.get('SM_CHANNEL_DEV')),
        help='Path to the dev data',
    )
    # Shared argument groups, registered in this exact order.
    group_adders = (
        add_decoder_parameters_group,
        add_bert_encoder_parameters_group,
        add_fsl_parameters_group,
        add_meta_learning_optimizer_parameters_group,
        add_optimizer_parameters_group,
        add_early_stopping_parameters_group,
        add_computation_parameters_group,
        add_logging_parameters_group,
        add_read_parameters_group,
        add_eval_parameters_group,
        add_write_parameters_group,
    )
    for add_group in group_adders:
        add_group(cli_parser)
    parsed_args = cli_parser.parse_args()
    main(parsed_args)
