# -*- coding: utf-8 -*-

from framework.common.dataclass_options import SingleOptionsParser
from framework.common.logger import LOGGER
from SHRG.shrg_extract import MainOptions, extract_shrg_from_dataset

# Split name -> list of wildcard file-name patterns selecting that split's
# members. The trailing '*' suggests glob/fnmatch-style matching, and the
# prefixes look like DeepBank WSJ section ids ('wsj0*' + 'wsj1*' -> sections
# 00-19, 'wsj20*' -> section 20, 'wsj21*' -> section 21) -- TODO confirm
# against how extract_shrg_from_dataset consumes them.
STANDARD_SPLITS = dict(
    train=['wsj1*', 'wsj0*'],
    dev=['wsj20*'],
    test=['wsj21*'],
)


def main(argv=None):
    """Command-line entry point: extract an SHRG grammar from dataset splits.

    Registers ``MainOptions`` (via ``parser.setup``) plus two extra flags,
    ``--force`` and ``--tags``, parses ``argv``, prints the resolved options
    and hands the requested splits to ``extract_shrg_from_dataset``.

    Args:
        argv: Argument list to parse. ``None`` lets the parser pick its own
            default (presumably ``sys.argv[1:]`` as with argparse -- confirm
            in ``SingleOptionsParser``).

    Raises:
        SystemExit: if ``--tags`` names a split that is not a key of
            ``STANDARD_SPLITS``.
    """
    parser = SingleOptionsParser()
    parser.add_argument(
        '--force', '-F', default=False, action='store_true',
        help='treat every requested split (not only "train") as a training split')
    parser.add_argument(
        '--tags', '-T', default=['train', 'dev', 'test'], nargs='+',
        help='splits to process; any of: ' + ', '.join(STANDARD_SPLITS))

    # Defaults for MainOptions; parse_args may override them from the CLI.
    options = MainOptions()
    options.prefix = 'output/{grammar}.{graph_type}.{suffix}/'
    options.tree_path = 'data/trees'
    options.data_path = 'data/deepbank_export_1.1'

    # abbrevs maps MainOptions field names to CLI flag aliases; the nested
    # 'extraction' dict covers fields of the extraction sub-options.
    parser.setup(
        MainOptions,
        default_instance=options,
        abbrevs={
            'grammar_name': 'grammar-name',
            'prefix': '-p',
            'tree_path': '-J',
            'data_path': '-D',
            'extraction': {
                'graph_type': '-g',
                'detect_function': '-d',
                'modify_tree': '-s',
                'modify_label': '-S',
                'remove_null_semantic': '-l',
                'remove_disconnected': '-r',
                'fully_lexicalized': '--fully-lexicalized',
                'fix_hyphen': '-f',
                'ep_permutation': '-e',
                'label_type': '-L'
            }
        }
    )

    options = parser.parse_args(argv)

    tags = parser.extra_options.tags
    unknown_tags = [tag for tag in tags if tag not in STANDARD_SPLITS]
    if unknown_tags:
        # Report bad tags clearly instead of crashing later with a KeyError.
        raise SystemExit('Unknown split tag(s): {}. Expected any of: {}'.format(
            ', '.join(unknown_tags), ', '.join(STANDARD_SPLITS)))

    print(options.pretty_format())

    # Without --force only 'train' is passed as a training split; with it,
    # every requested split is.
    train_splits = tags if parser.extra_options.force else {'train'}
    split_patterns = [(tag, STANDARD_SPLITS[tag]) for tag in tags]

    try:
        extract_shrg_from_dataset(options, split_patterns, train_splits=train_splits)
    except KeyboardInterrupt:
        # Ctrl-C during the (long) extraction is a user stop, not an error.
        LOGGER.warning('Stop.')


# Only run the extraction when executed as a script, not when imported.
if __name__ == '__main__':
    main(argv=None)
