import argparse
import functools
import os
import sys
import tarfile

import torch
from transformers import BertTokenizer

from few_shot_ner.io import create_directories
from few_shot_ner.logger import init_root_logger
from few_shot_ner.dataset import MetaDataSet
from few_shot_ner.eval import evaluate
from few_shot_ner.io import first_file_in_dir, log_fine_tuning_results, save_hyperparam, print_metrics_per_epochs
from few_shot_ner.losses import BaselineNerLoss
from few_shot_ner.models import save, load, build_baseline_model, init_seed
from few_shot_ner.optimizer import create_optimizer
from few_shot_ner.parser import *
from few_shot_ner.predictors import NERPredictor
from few_shot_ner.trainer import bert_ner_step, train


def _resolve_model_dir(trained_model):
    """Return the directory holding the checkpoint to fine-tune from.

    Any path containing ``'opt/ml/input'`` is replaced by the fixed archive path
    ``/opt/ml/input/data/model/model.tar.gz`` (presumably the SageMaker ``model``
    input channel -- confirm against the job definition). A path ending in
    ``.tar.gz`` is extracted into the same path without the suffix, and that
    directory is returned. Any other path is returned unchanged.
    """
    suffix = '.tar.gz'
    if 'opt/ml/input' in trained_model:
        trained_model = '/opt/ml/input/data/model/model.tar.gz'
    # ``endswith`` (not ``in``): the slice below strips exactly the trailing
    # suffix, so a path that merely *contains* '.tar.gz' must not be treated
    # as an archive.
    if trained_model.endswith(suffix):
        model_dir = trained_model[:-len(suffix)]
        with tarfile.open(trained_model) as tar:
            # The 'data' filter (Python 3.12+ and security backports) refuses
            # absolute paths and members/links escaping the target directory.
            if hasattr(tarfile, 'data_filter'):
                tar.extractall(model_dir, filter='data')
            else:
                tar.extractall(model_dir)
        trained_model = model_dir
    return trained_model


def _checkpoint_path(model_dir):
    """Return ``base_net_best.pth`` in ``model_dir`` if it exists, else ``base_net.pth``."""
    best_path = os.path.join(model_dir, 'base_net_best.pth')
    if os.path.exists(best_path):
        return best_path
    return os.path.join(model_dir, 'base_net.pth')


def main(args):
    """Fine-tune a trained baseline NER model on sampled few-shot tasks and report metrics.

    For each of ``args.num_samples`` tasks sampled from the test data, the model
    is evaluated once before fine-tuning. It is then trained for
    ``args.max_epochs`` epochs of ``args.num_updates`` updates each, and
    evaluated after every epoch. Weights are restored to the loaded checkpoint
    before the next task. Aggregated metrics are logged and passed to
    ``save_hyperparam`` together with ``args.output_folder``.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).
            ``args.log_file`` and ``args.trained_model`` may be overwritten in place.
    """
    init_seed(args.seed)
    # SM_OUTPUT_DATA_DIR is set inside SageMaker training jobs; log there if present.
    sm_output_dir = os.environ.get('SM_OUTPUT_DATA_DIR')
    if sm_output_dir is not None:
        args.log_file = os.path.join(sm_output_dir, 'log.txt')
    logger = init_root_logger(args.log_file, level=args.verbose)

    # Snapshot the arguments (after the log_file override) for the report; log each once.
    argsdict = dict(vars(args))
    for name, value in argsdict.items():
        logger.info("%s : %s", name, value)
    cli_string = 'python ' + ' '.join(sys.argv)

    # Find device
    if args.cpu:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:{}'.format(torch.cuda.current_device()))
    logger.info("Device used: %s", device)

    # Create directories
    create_directories(args.output_folder)

    # Load pre-trained BERT tokenizer
    tokenizer = BertTokenizer.from_pretrained(args.bert_model_path)

    # Load test data (read with the MetaDataSet.dev split mode, NER + IC annotations)
    data_test = MetaDataSet(args.alphabets_folder, args.test, MetaDataSet.dev, tokenizer, args.max_token_length,
                            {MetaDataSet.ner, MetaDataSet.ic})

    # Build model and load the trained weights
    base_net = build_baseline_model(args, device)
    args.trained_model = _resolve_model_dir(args.trained_model)
    model_path = _checkpoint_path(args.trained_model)
    # NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
    map_location = 'cpu' if device == torch.device('cpu') else None
    base_net.load_state_dict(torch.load(model_path, map_location=map_location))

    # Parameter-less modules
    use_softmax = args.decoder == "softmax"
    ner_loss = BaselineNerLoss(base_net, use_softmax).to(device=device)
    ner_predictor = NERPredictor(base_net, use_softmax, args.max_token_length).to(device=device)

    # Save a copy of the initialization weights so every task starts from them
    init_parameters = save(base_net)

    # Train on all tasks and evaluate
    metrics = []
    for task in range(args.num_samples):
        with torch.no_grad():
            dataset = data_test.sample_task(args.batch_size, args.n, args.k)
            ner_predictor.set_mapping(dataset.tgt_slots)
            optimizer = create_optimizer(base_net, args)
            # Entry for "epoch 0": score before any fine-tuning, loss recorded as 0.
            metrics_on_task = [(evaluate(dataset, ner_predictor, device), 0)]
        # Bind this task's dataset/optimizer eagerly instead of via a late-binding lambda.
        step_func = functools.partial(bert_ner_step, dataset, ner_loss, optimizer, device)
        for _ in range(args.max_epochs):
            ave_loss = train(step_func, args.num_updates)
            micro_f1 = evaluate(dataset, ner_predictor, device)
            metrics_on_task.append((micro_f1, ave_loss))
        # Restore the initial weights before the next task.
        load(base_net, init_parameters)
        metrics.append(metrics_on_task)
        if args.log_validation:
            logger.info("Num tasks : {:d}".format(task + 1))
            logger.info("\n" + print_metrics_per_epochs(metrics))

    # The third value (best epoch) is not used: this script reports the fixed epoch count.
    mean_metrics, std_metrics, _ = log_fine_tuning_results(metrics, logger)
    # Each task recorded max_epochs + 1 entries, so index max_epochs is the final epoch
    # (assumes results are indexed [metric][epoch] with metric 0 = F1, 1 = loss -- confirm).
    last = args.max_epochs
    logger.info("Number of epochs used : {}".format(args.max_epochs))
    logger.info("Micro F1 : {:2.4f}\u00B1{:2.4f}".format(mean_metrics[0][last], std_metrics[0][last]))
    logger.info("Loss : {:f}\u00B1{:f}".format(mean_metrics[1][last], std_metrics[1][last]))

    # Machine-readable summary line; best_num_epoch is the configured epoch count.
    logger.info("micro_f1={}; loss={}; best_num_epoch={};".format(
        mean_metrics[0][last], mean_metrics[1][last], args.max_epochs))
    save_hyperparam(cli_string, argsdict, mean_metrics, std_metrics, args.max_epochs, args.output_folder)


if __name__ == '__main__':
    # Build the CLI: the test-data path plus the shared few_shot_ner argument groups.
    parser = argparse.ArgumentParser(
        description='Finetune a BERT Model for NER for a given number of epochs',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Default test file is derived from the SM_CHANNEL_TEST directory
    # (presumably its first file -- see first_file_in_dir).
    default_test_path = first_file_in_dir(os.environ.get('SM_CHANNEL_TEST'))
    parser.add_argument("--test", type=str, help='Path to the test data',
                        default=default_test_path)
    # Registered in the same order as before, so --help output is unchanged.
    for register_group in (
        add_decoder_parameters_group,
        add_bert_encoder_parameters_group,
        add_fsl_parameters_group,
        add_optimizer_parameters_group,
        add_computation_parameters_group,
        add_logging_parameters_group,
        add_read_parameters_group,
        add_eval_parameters_group,
        add_write_parameters_group,
    ):
        register_group(parser)
    main(parser.parse_args())
