import json
import os

import numpy as np

from few_shot_ner.eval import aggregate_metrics


def create_directories(*directories):
    """Ensure every given path exists as a directory, creating parents as needed.

    Paths that already exist as directories are left untouched, so repeated
    calls (or concurrent callers racing on the same path) do not fail.

    Args:
        *directories: paths to create; passing none is a no-op.

    Raises:
        OSError: if a path cannot be created, or if it exists but is not a
            directory (e.g. a regular file is in the way).
    """
    for d in directories:
        try:
            os.makedirs(d)
        except OSError:
            # Create first, check afterwards: an exists()-then-makedirs()
            # sequence can fail if someone else creates the directory in
            # between. Only ignore the error when the directory is now there.
            if not os.path.isdir(d):
                raise


def first_file_in_dir(dir_path):
    """Return the path of the alphabetically first entry in ``dir_path``.

    Entries are sorted by name so the result is deterministic: ``os.listdir``
    itself returns entries in arbitrary, filesystem-dependent order. Any
    entry counts, including sub-directories.

    Args:
        dir_path: directory to look in, or ``None``.

    Returns:
        ``os.path.join(dir_path, <first entry>)``, or ``None`` when
        ``dir_path`` is ``None``.

    Raises:
        IndexError: if ``dir_path`` contains no entries.
    """
    if dir_path is None:
        return None
    entries = sorted(os.listdir(dir_path))
    if not entries:
        # Same exception type as a bare ``[0]`` on an empty list, but with a
        # message that names the offending directory.
        raise IndexError("Directory {!r} is empty".format(dir_path))
    return os.path.join(dir_path, entries[0])


def log_fine_tuning_results(metrics, logger):
    """Log the metrics of the best fine-tuning epoch and return the aggregates.

    The selected epoch is the ``np.argmax`` of the first row of the
    aggregated means (the row logged as micro F1).

    Args:
        metrics: raw metrics, passed unchanged to ``aggregate_metrics``.
        logger: object exposing ``info`` (e.g. a ``logging.Logger``).

    Returns:
        Tuple ``(mean_metrics, std_metrics, best_num_epoch)``: the two values
        returned by ``aggregate_metrics`` and the selected epoch index.
    """
    means, stds = aggregate_metrics(metrics)
    best_epoch = np.argmax(means[0])

    def _at_best(row):
        # (mean, std) of one metric row at the selected epoch.
        return means[row][best_epoch], stds[row][best_epoch]

    logger.info("Best number of epochs : {}".format(best_epoch))
    logger.info(u"Micro F1: {:2.4f}\u00B1{:2.4f}".format(*_at_best(0)))
    logger.info(u"Loss: {:f}\u00B1{:f}".format(*_at_best(1)))
    return means, stds, best_epoch


def print_metrics_per_epochs(_metrics):
    """Format aggregated per-epoch metrics as a right-aligned text table.

    The first line lists the epoch indices. Each following line holds one
    metric row ("Micro F1", then "Loss") with ``mean±std`` per epoch, and
    every cell is right-justified to 15 characters. Only two row labels are
    defined, so more than two aggregated metric rows raise ``IndexError``.

    Args:
        _metrics: raw metrics, passed unchanged to ``aggregate_metrics``.

    Returns:
        The table as a single newline-separated string.
    """
    mean_table, std_table = aggregate_metrics(_metrics)
    width = 15
    labels = ["Micro F1: ", "Loss    : "]
    n_rows, n_epochs = mean_table.shape[0], mean_table.shape[1]

    lines = ["# Epochs: " + "".join(str(epoch).rjust(width) for epoch in range(n_epochs))]
    for row in range(n_rows):
        label = labels[row]
        cells = [u"{:2.4f}\u00B1{:2.2f}".format(mu, sigma).rjust(width)
                 for mu, sigma in zip(mean_table[row].tolist(), std_table[row].tolist())]
        lines.append(label + "".join(cells))
    return "\n".join(lines)


def save_hyperparam(cli_string, argsdict, mean_metrics, std_metrics, num_epoch, output_folder):
    """Write the run's command line, arguments and selected-epoch metrics to JSON.

    The output file is ``<output_folder>/hyperparam.json``. It is created, or
    overwritten if it already exists.

    Args:
        cli_string: command line used to launch the run, stored verbatim.
        argsdict: argument mapping; must be JSON-serializable.
        mean_metrics: 2-D indexable; row 0 is stored as micro F1 and row 1 as
            loss, each read at column ``num_epoch``.
        std_metrics: same layout as ``mean_metrics``, holding the spreads.
        num_epoch: column index of the selected epoch (e.g. an ``np.argmax``
            result); stored as a plain ``int``.
        output_folder: existing directory to write into.
    """
    def _summary(row):
        # float() turns NumPy scalars (np.float32 is not JSON-serializable)
        # into plain Python floats; np.float64 values serialize identically.
        return {'mean': float(mean_metrics[row][num_epoch]),
                'std': float(std_metrics[row][num_epoch])}

    hyperparam = {'cli_string': cli_string,
                  'args': argsdict,
                  'results': {
                      'micro_f1': _summary(0),
                      'loss': _summary(1),
                      'best_num_epochs_finetuning': int(num_epoch)
                  }}
    # Serialize before opening: opening with 'w' truncates the file, so a
    # serialization error would otherwise leave an empty/partial file behind.
    payload = json.dumps(hyperparam)
    # The context manager guarantees the handle is flushed and closed.
    with open(os.path.join(output_folder, 'hyperparam.json'), 'w') as fp:
        fp.write(payload)
