# -*- coding: utf-8 -*-

import pprint
import sys

# Option name -> command-line flag abbreviation, shared by every model.
# Nested dicts follow nested option groups (e.g. hyper_params ->
# external_embedding -> filename).
# NOTE(review): presumably consumed by the argument parser elsewhere in the
# package; confirm against the caller.
COMMON_ABBREVS = {
    'test_paths': '-i',
    'output_prefix': '-p',
    'hyper_params': {
        'external_embedding': {
            'filename': '--embedding-file',
        },
    },
}

# Model selector (matched against the lower-cased first command-line
# argument by get_model_class) -> dotted path of the model class.
MODELS = {
    'fb': 'nn_generator.feature_based.model.SHRGGenerator',
}

# Dotted model-class path -> abbreviation table for that model.
# Every value is the very same COMMON_ABBREVS object (not a copy), so
# mutating any entry mutates COMMON_ABBREVS and all other entries too.
# Only the class paths are needed, hence iterating values(); the
# comprehension also keeps its loop variable out of the module namespace.
ABBREVS = {entry_class: COMMON_ABBREVS for entry_class in MODELS.values()}


def get_model_class(argv=None):
    """Take the model selector out of ``argv`` and resolve it through ``MODELS``.

    The selector is ``argv[1]``. It is removed from the list in place and
    lower-cased before the lookup. When ``argv`` is omitted, ``sys.argv``
    itself is used, and therefore mutated.

    :param argv: argument list whose element 0 is the program name;
        defaults to ``sys.argv``.
    :returns: the ``MODELS`` value registered for the selector.
    :raises SystemExit: with status 1, after printing the available models
        to stdout, when the selector is missing or not a key of ``MODELS``.
    """
    args = sys.argv if argv is None else argv

    # The selector is popped whenever one is present (presumably so later
    # argument parsing does not see it); with no selector we look up ''.
    selector = args.pop(1).lower() if len(args) > 1 else ''

    resolved = MODELS.get(selector)
    if resolved is not None:
        return resolved

    print('Select model first:')
    pprint.pprint(MODELS)
    sys.exit(1)
