import argparse
import datetime
import collections
import pathlib
import subprocess
import sys

import skemb.pipeline_conf


# Command-line parser.  The ``--stage`` option is added further down, after
# every stage has been registered, because its choices are the stage names.
parser = argparse.ArgumentParser()


# Registry of pipeline stages: stage name -> callable invoked as
# ``fn(conf, args)``.  Iteration (and therefore execution) order is the order
# in which the stages are registered below.
stage_mapping = collections.OrderedDict()

def learn_sgpatterns(conf, args):
    """Run the ``skemb.cmd.learn_sgpatterns`` module as a subprocess.

    The command is given ``dataset/train`` as its dataset and ``sgpatterns``
    as its output directory (both relative paths), together with ``--force``.

    Args:
        conf: Pipeline configuration; not used by this stage.
        args: Parsed command-line arguments; not used by this stage.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status (``run_cmd`` runs it with ``check=True``).
    """
    # Launch the child with the interpreter that is running this script: a
    # bare 'python' is resolved through PATH and may be a different
    # installation that does not have skemb available.  sys.executable can be
    # empty when the interpreter path is unknown, hence the fallback.
    python = sys.executable or 'python'
    cmd = (
        python, '-m', 'skemb.cmd.learn_sgpatterns',
        '--dataset', pathlib.Path('dataset') / 'train',
        '--output-dir', 'sgpatterns',
        '--force',
    )
    run_cmd(cmd)
stage_mapping['sgpatterns'] = learn_sgpatterns


def embedding(conf, args):
    """Train the embedding model, then run the embedding extraction command.

    Two subprocesses are run one after the other:

    1. ``skemb.cmd.train_embedding_model`` with ``--method`` taken from
       ``conf.EMBEDDING_METHOD`` and ``--params`` pointing at
       ``conf/embedding_params.py``.
    2. ``skemb.cmd.extract_embeddings`` (announced on stderr as the
       extraction of embeddings from the test data).

    Both are passed ``--force``.

    Args:
        conf: Pipeline configuration; must provide ``EMBEDDING_METHOD``.
        args: Parsed command-line arguments; not used by this stage.

    Raises:
        subprocess.CalledProcessError: If either command exits with a
            non-zero status; the second command is not run if the first
            one fails.
    """
    # Launch the children with the interpreter that is running this script:
    # a bare 'python' is resolved through PATH and may be a different
    # installation that does not have skemb available.  sys.executable can be
    # empty when the interpreter path is unknown, hence the fallback.
    python = sys.executable or 'python'

    msg('LEARNING AND EXTRACTING EMBEDDINGS FROM TRAIN DATA')
    train_cmd = (
        python, '-m', 'skemb.cmd.train_embedding_model',
        '--method', conf.EMBEDDING_METHOD,
        '--params', pathlib.Path('conf') / 'embedding_params.py',
        '--force',
    )
    run_cmd(train_cmd)

    msg('')
    msg('EXTRACTING EMBEDDINGS FROM TEST DATA')
    extract_cmd = (
        python, '-m', 'skemb.cmd.extract_embeddings',
        '--force',
    )
    run_cmd(extract_cmd)
stage_mapping['embedding'] = embedding


def classification(conf, args):
    """Run the ``skemb.cmd.classification`` module as a subprocess.

    The command receives ``--method`` from ``conf.CLASSIFICATION_METHOD``,
    ``--method-hyperparams`` pointing at ``conf/classifier_hyperparams.py``
    and ``--log-file classification-log.txt``.

    Args:
        conf: Pipeline configuration; must provide ``CLASSIFICATION_METHOD``.
        args: Parsed command-line arguments; not used by this stage.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status (``run_cmd`` runs it with ``check=True``).
    """
    # Launch the child with the interpreter that is running this script: a
    # bare 'python' is resolved through PATH and may be a different
    # installation that does not have skemb available.  sys.executable can be
    # empty when the interpreter path is unknown, hence the fallback.
    python = sys.executable or 'python'
    cmd = (
        python, '-m', 'skemb.cmd.classification',
        '--method', conf.CLASSIFICATION_METHOD,
        '--method-hyperparams', pathlib.Path('conf') / 'classifier_hyperparams.py',
        '--log-file', 'classification-log.txt',
    )
    run_cmd(cmd)
stage_mapping['classification'] = classification


# ``--stage`` accepts any registered stage name or one of two aggregate modes.
# The choices are read from ``stage_mapping``, so this call has to come after
# all stage registrations.
parser.add_argument(
    '--stage',
    choices=list(stage_mapping) + ['ALL', 'REALLY_ALL'],
    default='ALL',
    help=(
        'Stage to run. ALL (the default) runs every stage except sgpatterns; '
        'REALLY_ALL runs every stage, including sgpatterns.'
    ),
)


def msg(*k, **kw):
    """Print to standard error, taking the same arguments as ``print``.

    A ``file`` keyword supplied by the caller is overridden, so the output
    always goes to ``sys.stderr``.
    """
    print(*k, **{**kw, 'file': sys.stderr})


def run_cmd(cmd, check=True, **kw):
    """Run a command with ``subprocess.run`` after stringifying its parts.

    Args:
        cmd: Iterable of command components; each one is converted with
            ``str`` so that, for instance, ``pathlib.Path`` objects can be
            passed directly.
        check: Forwarded to ``subprocess.run``; when true (the default) a
            non-zero exit status raises ``subprocess.CalledProcessError``.
        **kw: Any further keyword arguments for ``subprocess.run``.

    Returns:
        The ``subprocess.CompletedProcess`` returned by ``subprocess.run``.
    """
    argv = tuple(map(str, cmd))
    return subprocess.run(argv, check=check, **kw)

def run_stage(stage, conf, args):
    """Run one registered stage and report its duration on stderr.

    Args:
        stage: Key into ``stage_mapping`` naming the stage to run.
        conf: Pipeline configuration, passed through to the stage function.
        args: Parsed command-line arguments, passed through as well.

    Raises:
        KeyError: If ``stage`` is not registered in ``stage_mapping``.
    """
    msg(f'RUNNING STAGE: {stage}')
    # Timezone-aware UTC timestamps: subtracting two naive local times gives
    # a wrong duration when a daylight-saving change happens while the stage
    # is running.
    started = datetime.datetime.now(datetime.timezone.utc)
    stage_mapping[stage](conf, args)
    elapsed = datetime.datetime.now(datetime.timezone.utc) - started
    msg(f'FINISHED RUNNING STAGE: {stage}    Execution time: {elapsed}')


def run(args):
    """Load the workspace configuration and run the requested stage(s).

    The configuration is read from ``conf/conf.py`` relative to the current
    working directory.  When ``args.stage`` is ``'ALL'`` every registered
    stage except ``'sgpatterns'`` is run, in registration order;
    ``'REALLY_ALL'`` runs every stage including ``'sgpatterns'``; any other
    value runs just that one stage.

    Args:
        args: Parsed command-line arguments; ``args.stage`` selects what to
            run, and the whole object is handed on to each stage.

    Raises:
        SystemExit: With status 1 if the configuration file is not found.
    """
    conf_path = pathlib.Path('conf') / 'conf.py'
    try:
        conf = skemb.pipeline_conf.from_path(conf_path)
    except FileNotFoundError:
        print(f'Configuration file {conf_path} not found. Are you sure you are at the root of a workspace?', file=sys.stderr)
        # sys.exit instead of the builtin exit(): the latter is installed by
        # the site module and is missing under ``python -S`` and in some
        # frozen or embedded interpreters.
        sys.exit(1)

    if args.stage not in ('ALL', 'REALLY_ALL'):
        run_stage(args.stage, conf, args)
        return

    for stage in stage_mapping:
        # 'sgpatterns' is only part of the run when explicitly requested
        # through REALLY_ALL.
        if stage == 'sgpatterns' and args.stage != 'REALLY_ALL':
            continue
        run_stage(stage, conf, args)


# Script entry point: parse the command line and run the selected stage(s).
if __name__ == '__main__':
    args = parser.parse_args()
    run(args)
