# -*- coding: utf-8 -*-

import itertools
import multiprocessing
import os

import pyshrg
from framework.common.dataclass_options import OptionsBase, SingleOptionsParser, argfield

# Command-line abbreviations handed to SingleOptionsParser for the nested
# ``pyshrg`` (PySHRGOptions) group: option field name -> flag.
PYSHRG_ABBREVS = dict(
    num_contexts='-j',
    parser_type='-P',
    max_pool_size='-S',
    verbose='-v',
    filter_method='-F',
)

# Default abbreviations for PySHRGParserOptions' own fields; callers of
# pyshrg_parse_args may override individual entries.
# NOTE(review): ``grammar_dir`` maps to itself rather than to a dash flag --
# presumably this makes it a positional argument; confirm against
# SingleOptionsParser.
PYSHRG_PARSER_ABBREVS = dict(
    grammar_dir='grammar_dir',
    grammar_suffix='-s',
    graphs_tag='-t',
)


class PySHRGOptions(OptionsBase, active_time='both'):
    """Runtime settings consumed by ``pyshrg_init`` when setting up the manager."""

    # first positional argument of ``manager.init_all``
    parser_type: str = 'linear'
    # passed to ``manager.allocate``; capped at 8.  ``os.cpu_count()`` may
    # return None when the CPU count is undetermined, so fall back to 1 instead
    # of letting ``min`` raise TypeError at import time.
    num_contexts: int = min(os.cpu_count() or 1, 8)
    # forwarded as ``max_pool_size`` to ``manager.init_all``
    max_pool_size: int = 50
    # forwarded as ``filter`` to ``manager.load_grammars``
    filter_method: str = argfield('none', choices=['disconnected', 'none'])
    # forwarded as ``verbose`` to ``manager.init_all``
    verbose: bool = False


class PySHRGParserOptions(OptionsBase):
    """Options used by ``pyshrg_parse_args`` to locate the grammar and graph files."""
    # directory containing ``train.mapping.txt{grammar_suffix}`` and
    # ``{graphs_tag}.graphs.txt``
    grammar_dir: str
    # appended to the grammar file name ``train.mapping.txt``
    grammar_suffix: str = ''
    # selects the graphs file ``{graphs_tag}.graphs.txt`` inside grammar_dir
    graphs_tag: str = 'dev'

    # nested runtime settings forwarded to ``pyshrg_init``
    pyshrg: PySHRGOptions


def pyshrg_init(grammar_path, graphs_path=None, options=None) -> pyshrg.Manager:
    """Initialize the global pyshrg manager, load a grammar and optionally graphs.

    Args:
        grammar_path: path passed to ``manager.load_grammars``.
        graphs_path: optional path passed to ``manager.load_graphs``; skipped
            when None.
        options: ``PySHRGOptions`` controlling context count, parser type,
            pool size, grammar filtering and verbosity. A fresh default
            instance is used when None.

    Returns:
        The manager returned by ``pyshrg.get_manager()`` after ``init_all``.

    Raises:
        RuntimeError: if loading the grammars or the graphs reports failure.
    """
    # Build the default per call instead of sharing one instance created at
    # definition time.
    if options is None:
        options = PySHRGOptions()

    pyshrg.initialize()
    manager = pyshrg.get_manager()
    manager.allocate(options.num_contexts)

    # The loads must not live inside ``assert``: under ``python -O`` asserts
    # are stripped and the loading itself would silently never happen.
    if not manager.load_grammars(grammar_path, filter=options.filter_method):
        raise RuntimeError(f'failed to load grammars from {grammar_path!r}')

    if graphs_path is not None:
        if not manager.load_graphs(graphs_path):
            raise RuntimeError(f'failed to load graphs from {graphs_path!r}')

    manager.init_all(options.parser_type, verbose=options.verbose,
                     max_pool_size=options.max_pool_size)

    return manager


def pyshrg_parse_args(argv=None, options_class=None, default_instance=None, abbrevs=None):
    """Parse command-line options and initialize the pyshrg manager from them.

    Args:
        argv: argument list forwarded to ``parser.parse_args``.
        options_class: options class to parse into; defaults to
            ``PySHRGParserOptions`` when neither it nor ``default_instance``
            is given.
        default_instance: forwarded to ``parser.setup``.
        abbrevs: extra field-name -> flag abbreviations. Entries here take
            precedence over ``PYSHRG_PARSER_ABBREVS``. The dict is not modified.

    Returns:
        ``(manager, options)``: the initialized pyshrg manager and the parsed
        options. The options must expose ``grammar_dir``, ``grammar_suffix``,
        ``graphs_tag`` and ``pyshrg``.
    """
    if options_class is None and default_instance is None:
        options_class = PySHRGParserOptions

    # Work on a copy so the caller's dict is never mutated; caller-supplied
    # keys win over the defaults (same precedence and key order as before).
    abbrevs = dict(abbrevs) if abbrevs is not None else {}
    for key, value in PYSHRG_PARSER_ABBREVS.items():
        abbrevs.setdefault(key, value)

    parser = SingleOptionsParser()
    parser.setup(options_class,
                 default_instance=default_instance,
                 abbrevs={'pyshrg': PYSHRG_ABBREVS, **abbrevs})

    options = parser.parse_args(argv)

    grammar_dir = options.grammar_dir
    grammar_path = os.path.join(grammar_dir, f'train.mapping.txt{options.grammar_suffix}')
    graphs_path = os.path.join(grammar_dir, f'{options.graphs_tag}.graphs.txt')

    manager = pyshrg_init(grammar_path, graphs_path, options.pyshrg)

    return manager, options


def _pyshrg_worker(arguments):
    """Run one task in a pool worker against a pyshrg context.

    Args:
        arguments: ``(process2index, task_fn, user_args)``.
            ``process2index`` is either an explicit context index (``int``)
            or a shared dict-like object mapping worker ids to context indices.
            ``task_fn`` is called as ``task_fn(context, *user_args)``.

    Returns:
        Whatever ``task_fn`` returns.

    Raises:
        RuntimeError: if the resolved context index is not below
            ``manager.context_size``.
    """
    process2index, task_fn, user_args = arguments

    manager = pyshrg.get_manager()
    if isinstance(process2index, int):
        context_index = process2index
    else:
        # Key on the OS pid: id() is only unique among live objects *within a
        # single process*, so it is not a reliable identifier across workers.
        worker_id = os.getpid()
        # NOTE(review): len() and setdefault() are separate calls on the shared
        # dict, so two workers registering concurrently can read the same len()
        # and receive the same index. Preventing that would need a shared lock.
        context_index = process2index.setdefault(worker_id, len(process2index))

    # Explicit check (not ``assert``) so it still guards get_context under -O.
    if context_index >= manager.context_size:
        raise RuntimeError('out of range, some process may exit abnormally')
    context = manager.get_context(context_index)  # type: pyshrg.Context

    return task_fn(context, *user_args)


class PySHRGPool:
    """Process pool whose tasks run against pyshrg contexts.

    The pool has ``pyshrg.get_manager().context_size`` worker processes.
    Every task goes through ``_pyshrg_worker``, which resolves a context and
    calls ``task_fn(context, *args)``. Used as a context manager, the pool is
    terminated on exit.
    """

    def __init__(self, mp=multiprocessing):
        pyshrg_manager = pyshrg.get_manager()
        worker_count = pyshrg_manager.context_size

        self._mp = mp
        self._pool = mp.Pool(processes=worker_count)

    def imap_unordered(self, task_fn, all_arguments, chunksize=10):
        """Lazily yield ``task_fn(context, *args)`` results in completion order.

        Each item of *all_arguments* is the sequence of extra positional
        arguments for one task. Workers claim context indices through a shared
        dict created for this call.
        """
        shared_index_map = self._mp.Manager().dict()
        task_stream = zip(itertools.repeat(shared_index_map),
                          itertools.repeat(task_fn),
                          all_arguments)
        yield from self._pool.imap_unordered(_pyshrg_worker, task_stream, chunksize)

    def apply_async(self, task_fn, arguments, context_index):
        """Schedule ``task_fn(context, *arguments)`` on an explicit context index.

        Returns the ``AsyncResult`` produced by the underlying pool.
        """
        payload = (context_index, task_fn, arguments)
        return self._pool.apply_async(_pyshrg_worker, args=(payload,))

    def terminate(self):
        """Stop the worker processes via ``Pool.terminate``."""
        self._pool.terminate()

    def __enter__(self, *_):
        return self

    def __exit__(self, *_):
        self.terminate()
