#!/usr/bin/env python3 -u

import logging
import sys
import yaml
import gc

import optuna
from optuna.pruners import (MedianPruner, NopPruner, PercentilePruner,
                            SuccessiveHalvingPruner, HyperbandPruner)


from fairseq import options
from .train import main as train_main

# Route log records to stdout at INFO level using a timestamped,
# pipe-separated layout. Like any logging.basicConfig call, this does nothing
# if the root logger already has handlers configured.
logging.basicConfig(
    level=logging.INFO,
    stream=sys.stdout,
    format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
# Logger for this module; supress_other_loggers() leaves this name untouched.
logger = logging.getLogger('fairseq_cli.hyperopt')


def array_type(params):
    """Pick a converter for a list of categorical values given as strings.

    The result is:
      * ``int``   if every value parses with ``int()`` (also for an empty list),
      * ``float`` if every value parses with ``float()`` but not all with ``int()``,
      * a string-to-bool function if every value is one of
        ``'True'``, ``'False'``, ``'true'``, ``'false'``,
      * ``str`` otherwise.

    ``params`` must be a list (it is concatenated with another list).
    """
    def parses_as(converter, text):
        try:
            converter(text)
        except ValueError:
            return False
        return True

    bool_literals = ['True', 'False', 'true', 'false']

    all_float = all(parses_as(float, p) for p in params)
    all_int = all(parses_as(int, p) for p in params)
    # The union has exactly 4 members only if every param is a bool literal.
    all_bool = len(set(params + bool_literals)) == 4

    if not all_float and not all_int:
        if all_bool:
            return lambda x: x in ('True', 'true')
        return str
    if not all_int:
        return float
    return int


def params_autotype(form, params):
    """Convert the string arguments of a parameter spec to typed values.

    For ``form == 'categorical'`` the element type is inferred with
    :func:`array_type` and the converted list is returned wrapped in a
    1-tuple, so it can be splatted as the single ``choices`` argument.
    For ``'int'`` every value goes through ``int``; for any other form
    through ``float``, and the plain list is returned.
    """
    is_categorical = form == 'categorical'
    if is_categorical:
        convert = array_type(params)
    elif form == 'int':
        convert = int
    else:
        convert = float

    converted = list(map(convert, params))
    return (converted,) if is_categorical else converted


def suggest(spec, trial, args):
    """Sample one hyperparameter from ``trial`` and store it on ``args``.

    ``spec`` is ``[name, form, *params]``:
      * ``name``   - attribute name on ``args``; dashes become underscores,
      * ``form``   - selects ``trial.suggest_<form>``,
      * ``params`` - string arguments for that method, converted by
        :func:`params_autotype`.

    If the attribute currently holds a list, the sampled value is wrapped in
    a one-element list so the attribute keeps its shape.

    Raises:
        ValueError: if ``form`` is not a supported suggestion type.
        AttributeError: if ``args`` has no attribute called ``name``.
    """
    name, form, *params = spec

    # Validate before doing any work. An ``assert`` here would be stripped
    # under ``python -O`` and let an unsupported form slip through.
    supported_forms = ('categorical', 'discrete_uniform',
                       'int', 'loguniform', 'uniform')
    if form not in supported_forms:
        raise ValueError(
            f'Unsupported form {form!r} for parameter {name!r}; '
            f'expected one of {", ".join(supported_forms)}.')

    name = name.replace('-', '_')
    old_val = getattr(args, name)
    params = params_autotype(form, params)

    suggest_fn = getattr(trial, f'suggest_{form}')
    val = suggest_fn(name, *params)

    if isinstance(old_val, list):
        val = [val]
    setattr(args, name, val)


def final_score(valid_losses, args):
    """Average the last ``args.avg_last`` validation values of a trial.

    If fewer than ``args.avg_last`` values exist, all of them are averaged.
    Note that ``args.avg_last == 0`` selects the whole list, because
    ``values[-0:]`` is the full sequence.

    Raises:
        ValueError: if the selection is empty, e.g. when no intermediate
            value was reported. Previously this was an opaque
            ``ZeroDivisionError``.
    """
    selected = valid_losses[-args.avg_last:]
    if not selected:
        raise ValueError(
            'Cannot compute the final score: no validation values selected '
            f'({len(valid_losses)} reported, avg_last={args.avg_last!r}).')
    return sum(selected) / len(selected)


def objective(trial, args):
    """Optuna objective: sample hyperparameters, train, and score the trial.

    Each spec in ``args.param`` is sampled into ``args`` in place via
    :func:`suggest`. Training then runs with the trial attached; it
    presumably reports intermediate validation values through the trial
    (TODO confirm in ``train.main``). The returned score is
    :func:`final_score` over those values, ordered by step.
    """
    for spec in args.param:
        suggest(spec, trial, args)

    train_main(args, trial=trial)

    # Storage lookups are keyed by the storage-wide trial id, not by the
    # per-study ``trial.number``. The two differ as soon as the storage
    # (e.g. a shared database) holds more than one study, and then the
    # number would fetch some other trial's values.
    trial_id = getattr(trial, '_trial_id', trial.number)
    intermediate_values = trial.storage.get_trial(trial_id).intermediate_values
    valid_scores = [value for _, value in sorted(intermediate_values.items())]

    return final_score(valid_scores, args)


def finish_study(study, args):
    """Log the best trial of ``study`` and optionally save its params as YAML.

    When ``args.output`` is set, ``study.user_attrs`` merged with the best
    params (best params win on key clashes) is written there with
    ``yaml.dump``. If no trial has completed yet (for example, the search
    was interrupted during its first trial), a warning is logged and
    nothing is written.
    """
    try:
        best_value = study.best_value
        best_params = study.best_params
    except ValueError:
        # Optuna raises ValueError here when no trial has completed. This
        # function is also reached after Ctrl+C, so report instead of crashing.
        logger.warning('Finished optimization, but no trial has completed; '
                       'there are no best params to report.')
        return

    logger.info('Finished optimization. Best value: %s.', best_value)
    logger.info('Params:')

    for k, v in best_params.items():
        logger.info(' - %s: %s', k, v)

    if args.output:
        logger.info('Saving determined params in %s', args.output)
        with open(args.output, 'w') as outfile:
            yaml.dump({**study.user_attrs, **best_params}, outfile)


def supress_other_loggers():
    """Set every unrelated registered logger to WARNING level.

    Loggers named exactly ``fairseq_cli.hyperopt`` or ``fairseq_cli``, and
    every logger whose name contains ``optuna``, keep their level. Only
    names already registered with the logging manager are affected.
    """
    logger.info('Supressing loggers unrelated to hyperopt...')
    keep_exact = ('fairseq_cli.hyperopt', 'fairseq_cli')
    for name in logging.root.manager.loggerDict:
        if name in keep_exact or 'optuna' in name:
            continue
        logging.getLogger(name).setLevel(logging.WARNING)


def pruner(args):
    """Build the Optuna pruner selected by ``args.pruner``.

    Recognised values are ``'median'``, ``'percentile'``,
    ``'successive_halving'`` and ``'hyperband'``. Any other value (including
    ``None``) yields a ``NopPruner``, which disables pruning. The relevant
    ``args.pruner_*`` options are passed to each constructor.
    """
    if args.pruner == 'median':
        return MedianPruner(n_startup_trials=args.pruner_startup_trials,
                            n_warmup_steps=args.pruner_warmup_steps,
                            interval_steps=args.pruner_interval_steps)
    if args.pruner == 'percentile':
        return PercentilePruner(args.pruner_percentile,
                                n_startup_trials=args.pruner_startup_trials,
                                n_warmup_steps=args.pruner_warmup_steps,
                                interval_steps=args.pruner_interval_steps)
    if args.pruner == 'successive_halving':
        # The Optuna class is SuccessiveHalvingPruner; a bare
        # ``SuccessiveHalving`` name is not defined and raises NameError.
        return SuccessiveHalvingPruner(reduction_factor=args.pruner_reduction_factor)
    if args.pruner == 'hyperband':
        return HyperbandPruner(reduction_factor=args.pruner_reduction_factor)
    return NopPruner()

def add_user_attributes(study, args):
    """Record the run configuration on ``study`` as user attributes.

    Every attribute of ``args`` is stored via ``study.set_user_attr``,
    except the options this script consumes to drive the search itself
    (parameter specs, output path, pruner/trial/study settings, averaging
    window, quiet flag). The stored attributes are later merged into the
    YAML written by ``finish_study``.
    """
    excluded = {
        'param', 'output', 'pruner', 'trials', 'avg_last', 'study_name',
        'quiet', 'study_storage', 'pruner_startup_trials',
        'pruner_warmup_steps', 'pruner_interval_steps',
        'pruner_percentile', 'pruner_reduction_factor',
    }
    config = vars(args)
    for key in config:
        if key in excluded:
            continue
        study.set_user_attr(key, config[key])

def cli_main():
    """Run an Optuna hyperparameter search over fairseq training.

    Parses the hyperopt command line and creates the study
    ``args.study_name`` in ``args.study_storage``, or loads it if it already
    exists. It then runs up to ``args.trials`` trials of :func:`objective`
    and reports the best trial via :func:`finish_study`. Ctrl+C stops the
    search early but still reports what was gathered so far.

    Returns:
        ``1`` if the command line is rejected (no parameters to optimize,
        or distributed training requested), otherwise ``None``. This lets
        ``sys.exit(cli_main())`` produce a non-zero exit status on bad input.
    """
    parser = options.get_hyperopt_parser()
    args, _ = options.parse_args_and_arch(parser, parse_known=True)

    if not args.param:
        logger.error('You must specify parameters to optimize.')
        return 1

    if args.distributed_init_method is not None:
        logger.error('Hyperoptimization does not support distributed training at the moment.')
        return 1

    # Work around unexpected fairseq behavior when training is stopped in
    # the middle of an epoch: automatic garbage collection is disabled for
    # the rest of the process, and gc_after_trial=False is passed to
    # study.optimize() below. This runs after argument validation, so a
    # rejected invocation leaves the interpreter untouched.
    # TODO: investigate this in the future
    gc.disable()

    # Propagate logs to the root logger
    optuna.logging.enable_propagation()
    optuna.logging.disable_default_handler()

    if args.quiet:
        supress_other_loggers()

    direction = 'maximize' if args.maximize_best_checkpoint_metric else 'minimize'
    study = optuna.create_study(study_name=args.study_name, storage=args.study_storage,
                                load_if_exists=True, pruner=pruner(args),
                                direction=direction)

    add_user_attributes(study, args)

    try:
        study.optimize(lambda t: objective(t, args), n_trials=args.trials, gc_after_trial=False)
    except KeyboardInterrupt:
        logger.info('Stopping hyperopt manually (aborted by ctrl+c)...')

    finish_study(study, args)


if __name__ == '__main__':
    # Use cli_main's return value as the process exit status
    # (None -> 0, an integer error code otherwise).
    sys.exit(cli_main())
