# -*- coding: utf-8 -*-

import os
import pickle

import pyshrg
from framework.common.logger import LOGGER, open_file, open_wrapper
from framework.common.utils import ProgressReporter
from pyshrg_utils.derivation import recover_subgraph_derivation
from pyshrg_utils.parser import PySHRGParserOptions, PySHRGPool, pyshrg_parse_args


class MainOptions(PySHRGParserOptions):
    """Options of the derivation export script, on top of ``PySHRGParserOptions``."""
    # Sub-directory name joined onto ``grammar_dir``; chunk files are written there (``-d``).
    output_dir: str = 'derivation'
    # A chunk file is flushed once this many exported graphs are buffered (``-n``).
    num_instances_per_file: int = 5000
    # Forwarded unchanged to ``context.export_derivation`` for every graph (``-N``).
    num_negative_samples: int = 100


def _export_worker(context, graph_index, derivation, num_negative_samples):
    """Parse one graph and export its derivation.

    Args:
        context: Object providing ``parse`` and ``export_derivation``.
        graph_index: Index of the graph inside the global pyshrg manager.
        derivation: Stored derivation of that graph; it is first converted by
            ``recover_subgraph_derivation``.
        num_negative_samples: Passed through to ``context.export_derivation``.

    Returns:
        A ``(graph_index, code, data)`` tuple. ``code`` is the result of
        ``context.parse``. ``data`` is the exported derivation when parsing
        succeeded and the export is non-empty; in every other case it is the
        result of ``_get_filtered_rule`` (a rule index or ``None``).
    """
    manager = pyshrg.get_manager()
    graph = manager.get_graph(graph_index)
    code = context.parse(graph_index)
    partitions, derivation = recover_subgraph_derivation(derivation, graph)

    parse_failed = code != pyshrg.ParserError.kNone
    exported = None
    if not parse_failed:
        exported = context.export_derivation(derivation, partitions, num_negative_samples)
    if not exported:
        # Nothing usable was exported: report which rule (if any) is empty.
        exported = _get_filtered_rule(derivation, manager)
    return graph_index, code, exported


def _get_filtered_rule(derivation, manager):
    for item in derivation:
        if item is None:
            continue
        rule = manager.get_shrg(item[0])
        assert rule
        if rule.is_empty:
            return item[0]
    return None


def check_exported_derivation(manager, data):
    """Sanity-check exported derivation entries against the manager.

    Every non-``None`` entry of ``data`` is unpacked as
    ``(shrg_index, partitions)``; the HRG index is read from
    ``partitions[0][0]``. The SHRG rule and the HRG rule looked up through
    ``manager`` must be the very same object.

    Raises:
        AssertionError: If the two lookups return different objects.
    """
    present_entries = (entry for entry in data if entry is not None)
    for shrg_index, partitions in present_entries:
        hrg_index = partitions[0][0]
        assert manager.get_shrg(shrg_index) is manager.get_hrg(hrg_index)


def main(argv=None):
    """Export the training derivations of all graphs into pickled chunk files.

    Loads ``train.derivations.p`` from the grammar directory, runs
    ``_export_worker`` for every graph through a ``PySHRGPool`` and writes the
    exported ``(graph_index, data)`` pairs to
    ``<grammar_dir>/<output_dir>/NNNNNN``; a new chunk file is started once
    ``num_instances_per_file`` entries are buffered. Graphs that cannot be
    exported are logged and skipped.

    Args:
        argv: Argument list forwarded to ``pyshrg_parse_args``.
    """
    abbrevs = {'output_dir': '-d',
               'num_negative_samples': '-N',
               'num_instances_per_file': '-n'}
    manager, options = pyshrg_parse_args(argv, MainOptions, abbrevs=abbrevs)

    grammar_dir = options.grammar_dir
    output_dir = os.path.join(grammar_dir, options.output_dir)

    # NOTE(review): unpickling can execute arbitrary code; only point this at
    # derivation files produced by a trusted pipeline.
    # NOTE(review): the handles returned by open_file/_open are never closed
    # explicitly here — confirm the helpers' return type before adding `with`.
    derivations = pickle.load(open_file(os.path.join(grammar_dir, 'train.derivations.p'), 'rb'))

    os.makedirs(output_dir, exist_ok=True)
    # Chunk files are named by their zero-padded index inside output_dir.
    _open = open_wrapper(lambda x: os.path.join(output_dir, '{:06d}'.format(x)))

    num_instances_per_file = options.num_instances_per_file
    # Failure counters: [parse error, export failed with an empty rule,
    # export failed for an unknown reason].
    stats = [0, 0, 0]

    # Derivations are keyed by the second '/'-separated field of sentence_id.
    all_arguments = [
        (index, derivations[graph.sentence_id.split('/')[1]], options.num_negative_samples)
        for index, graph in enumerate(manager.iter_graphs())
    ]

    progress = ProgressReporter(stop=len(all_arguments),
                                message_fn=lambda _: f'R/F/?: {stats[0]}/{stats[1]}/{stats[2]}',
                                print_time=True)
    exported_data = []
    chunk_index = 0
    with PySHRGPool() as pool:
        for graph_index, code, data in progress(pool.imap_unordered(_export_worker, all_arguments)):
            graph = manager.get_graph(graph_index)
            if code != pyshrg.ParserError.kNone:
                # On parse failure `data` is a rule index or None. Compare
                # against None so that rule index 0 is still reported.
                LOGGER.error('%s %s%s', graph.sentence_id, code,
                             f' (filtered: {data})' if data is not None else '')
                stats[0] += 1
                continue

            if isinstance(data, int):
                # Export produced nothing and an empty rule was found.
                stats[1] += 1
                LOGGER.error('%s failed (filtered: %s)', graph.sentence_id, data)
                continue
            if data is None:
                # Export produced nothing and no empty rule explains it.
                stats[2] += 1
                LOGGER.error('%s ???', graph.sentence_id)
                continue

            check_exported_derivation(manager, data)

            # Flush a full chunk before buffering the current instance.
            if len(exported_data) == num_instances_per_file:
                pickle.dump(exported_data, _open(chunk_index, 'wb'))
                chunk_index += 1
                exported_data = []
            exported_data.append((graph_index, data))

        # Write the remaining (possibly partial) chunk.
        pickle.dump(exported_data, _open(chunk_index, 'wb'))


# Run the export only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
