from datasets.utils import logging
from test_af import *
import argparse
from logging import CRITICAL
from logging import DEBUG
from logging import ERROR
from logging import INFO
from logging import NOTSET
from logging import WARNING

from utils import load_yml_configs

# Maps the `pruning_strategy` name used in the YAML model config to the
# corresponding PRUNING_STRATEGY member (presumably provided by the
# `test_af` star import). Note: `NUERON` is spelled that way on the
# PRUNING_STRATEGY side and must be referenced as-is.
_PRUNING_STRATEGY = dict(
    Origin=PRUNING_STRATEGY.NONE,
    Layer=PRUNING_STRATEGY.LAYER,
    Weight=PRUNING_STRATEGY.WEIGHT,
    Neurons=PRUNING_STRATEGY.NUERON,
)


# Log-level name -> numeric level from the stdlib `logging` module.
# The keys double as the allowed values of the --log-level CLI option.
LOG_LEVEL = dict(
    NOTSET=NOTSET,
    DEBUG=DEBUG,
    INFO=INFO,
    WARNING=WARNING,
    ERROR=ERROR,
    CRITICAL=CRITICAL,
)

def train(config_path, log_level):
    """Run one training job described by a YAML config file.

    Reads the ``train`` and ``model`` sections of the config, loads the
    learning rate for the task from ``lth.yml`` (pruning runs) or ``af.yml``
    (unpruned "Origin" runs), opens a wandb run, and hands everything to
    ``test`` (presumably provided by the ``test_af`` star import).

    Args:
        config_path: path of the YAML config read with ``load_yml_configs``.
        log_level: log-level name (a key of ``LOG_LEVEL``). Accepted for
            symmetry with ``submit``/``evaluate``; not read by this function.

    Raises:
        ValueError: if ``model.pruning_strategy`` is not one of
            Origin / Layer / Weight / Neurons.
        KeyError: if a required config key is missing.
    """
    configs = load_yml_configs(config_path)
    train_config = configs["train"]
    print(train_config)
    task = train_config["task"]
    seed = train_config["seeds"]
    model_config_ori = configs['model']

    # Validate by key membership rather than by the truthiness of the mapped
    # value: an enum member such as PRUNING_STRATEGY.NONE may be falsy
    # (e.g. 0 or ""), which would wrongly reject a valid "Origin" config.
    # An explicit exception also survives `python -O`, unlike `assert`.
    strategy_name = model_config_ori['pruning_strategy']
    if strategy_name not in _PRUNING_STRATEGY:
        raise ValueError(
            "Pruning strategy must be Layer/ Weight/ Neurons/ Origin, "
            f"got {strategy_name!r}"
        )
    pruning_strategy = _PRUNING_STRATEGY[strategy_name]

    af_adapters = model_config_ori['af_adapters']
    save_model_path = model_config_ori['save_model_path']
    skip_layers = model_config_ori.pop('skip_layers', None)

    set_seed(seed)

    # Any pruning strategy other than NONE is treated as a lottery-ticket
    # (LTH) run and takes its learning rate from a separate file.
    is_LTH = pruning_strategy != PRUNING_STRATEGY.NONE
    lr_config = load_yml_configs('lth.yml' if is_LTH else 'af.yml')
    model_config = {
        'seed': seed,
        'lr': float(lr_config['lr'][task]),
        'epoch': train_config['epochs'],
        'fp16': False,
        'num_gpu': 1,
    }

    # NOTE(review): string concatenation assumes PRUNING_STRATEGY members are
    # str values — confirm against test_af.
    run_label = "AF-" + pruning_strategy
    if skip_layers:
        group = "AdapterDrop"
    elif is_LTH:
        group = "LTH"
    else:
        group = "Origin"

    run = wandb.init(
        project="Test-flops",
        tags=["8 Adapter", run_label, task],
        group=group,
        reinit=True,
    )
    print("**** Start " + run_label + " Run", run.name, run.id, "****")
    with run:
        # Optional switches are popped from the model config with defaults,
        # so keys missing from the YAML fall back to the values below.
        test(task,
             af_adapters=af_adapters,
             model_config=model_config,
             save_model_path=save_model_path,
             model_checkpoint=model_config_ori.pop('model_checkpoint', 'bert-base-uncased'),
             pruning_strategy=pruning_strategy,
             wandb_run=run,
             cal_imp=model_config_ori.pop('cal_imp', True),
             imp_training=model_config_ori.pop('imp_training', True),
             cal_flops=model_config_ori.pop('cal_flops', True),
             do_train=model_config_ori.pop('do_train', False),
             save_eval=model_config_ori.pop('save_eval', True),
             gen_submit=model_config_ori.pop('gen_submit', True),
             skip_layers=skip_layers,
             do_eval=model_config_ori.pop('do_eval', True),
             )

def submit(model_dir, log_level):
    """Placeholder for producing a submission from the model in *model_dir*.

    Not implemented yet: both arguments are accepted and ignored, and the
    function returns None without doing anything.
    """
    return None

def evaluate(model_dir, log_level):
    """Placeholder for evaluating the model stored in *model_dir*.

    Not implemented yet: both arguments are accepted and ignored, and the
    function returns None without doing anything.
    """
    return None

if __name__ == '__main__':
    # Imported under an explicit alias: the module-level `logging` name may be
    # rebound by `from test_af import *`, so don't rely on it here.
    from datasets.utils import logging as datasets_logging

    parser = argparse.ArgumentParser()

    parser.add_argument("task",
                        type=str, choices=["train", "submit", "evaluate"],
                        help="action to run: train, submit or evaluate"
    )
    parser.add_argument("-c", "--config",
                        default=None, type=str, required=False,
                        help="path for train configs (required for train)"
    )
    # No `choices` here: this is a filesystem path. It was previously
    # restricted to LOG_LEVEL names by mistake, which rejected every real path.
    parser.add_argument("-m", "--model_dir",
                        default=None, type=str, required=False,
                        help="path of the saved model (required for submit/evaluate)"
    )
    parser.add_argument("-l", "--log-level",
                        default="INFO", type=str, required=False,
                        choices=list(LOG_LEVEL.keys()),
                        help="log level"
    )
    args = parser.parse_args()

    # Apply --log-level; it used to be parsed but never acted on.
    datasets_logging.set_verbosity(LOG_LEVEL[args.log_level])

    # parser.error() prints usage and exits with status 2. An explicit `if` is
    # used instead of `assert`, which would be skipped under `python -O`.
    if args.task == "train":
        if args.config is None:
            parser.error(f"must specify config path for {args.task}")
        train(args.config, args.log_level)

    elif args.task == "submit":
        if args.model_dir is None:
            parser.error(f"must specify model dir for {args.task}")
        submit(args.model_dir, args.log_level)

    else:
        if args.model_dir is None:
            parser.error(f"must specify model dir for {args.task}")
        evaluate(args.model_dir, args.log_level)

