import argparse
import os
import warnings

from config.config import get_cfg_defaults
from runner import SecondStageRunner

# Suppress every warning category for the rest of the process: this installs
# a catch-all "ignore" filter at the front of the warnings filter list.
warnings.simplefilter("ignore")


def parse_args(argv=None):
    """Parse the command-line options of this train/eval script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, exactly as before;
            passing an explicit list makes the parser usable from tests or
            other code.

    Returns:
        argparse.Namespace with attributes ``config`` (str or None),
        ``start_from`` (int), ``eval_epoch`` (int) and ``test_file``
        (str or None).
    """
    parser = argparse.ArgumentParser(
        description="Train the second-stage model, or load a saved "
                    "checkpoint and run it on a test file.")

    parser.add_argument(
        '--config', type=str, default=None,
        help="config name; loads config/<NAME>.yaml next to this script "
             "on top of the default config")
    parser.add_argument(
        '--start-from', type=int, default=-1,
        help="epoch checkpoint (models-<N>.pt in cfg.saved_path) to load "
             "before training; -1 (default) loads no checkpoint")
    parser.add_argument(
        '--eval-epoch', type=int, default=-1,
        help="epoch checkpoint to load for testing instead of training; "
             "-1 (default) trains")
    parser.add_argument(
        '--test-file', type=str, default=None,
        help="test file to run the --eval-epoch checkpoint on")
    return parser.parse_args(argv)

# Directory containing this script; config files are resolved relative to it.
BASEDIR = os.path.dirname(os.path.realpath(__file__))
print('BASEDIR: ', BASEDIR)


def _config_path(name):
    """Return the YAML path for config *name* under ``BASEDIR/config``.

    ``name`` is normally given without an extension (``exp1`` resolves to
    ``config/exp1.yaml``); a trailing ``.yaml`` is also accepted instead of
    producing ``exp1.yaml.yaml``.
    """
    if not name.endswith(".yaml"):
        name += ".yaml"
    return os.path.join(BASEDIR, 'config', name)


def _checkpoint_path(saved_path, epoch):
    """Return the path of the ``models-<epoch>.pt`` checkpoint in *saved_path*."""
    return os.path.join(saved_path, "models-{}.pt".format(epoch))


def main():
    """Build the config from CLI options, then either test a checkpoint or train.

    With ``--eval-epoch N`` the ``models-N.pt`` checkpoint is loaded and, if
    ``--test-file`` is given, ``runner.test`` is run on it. Otherwise the
    runner trains, first loading ``models-<start_from>.pt`` when
    ``--start-from`` is non-negative.
    """
    args = parse_args()
    cfg = get_cfg_defaults()
    if args.config is not None:
        cfg.merge_from_file(_config_path(args.config))
    cfg.freeze()
    runner = SecondStageRunner(cfg)

    if args.eval_epoch != -1:
        model_path = _checkpoint_path(cfg.saved_path, args.eval_epoch)
        runner.load_model(model_path)
        print('load model: ', model_path)
        if args.test_file:
            runner.test(args.eval_epoch, args.test_file)
        else:
            # Previously this path exited silently after loading the model.
            print('no --test-file given; model loaded but nothing was evaluated')
    else:
        if args.start_from >= 0:
            runner.load_model(_checkpoint_path(cfg.saved_path, args.start_from))
        runner.train()


if __name__ == "__main__":
    main()
