# -*- coding: utf-8 -*-

import os
import pickle

from framework.common.logger import LOGGER, open_wrapper
from framework.common.utils import ProgressReporter
from framework.data.vocab import VocabularySet
from nn_generator.data_utils import edsgraph_to_dict, pad_batch_graphs, simplify_derivation
from pyshrg_utils.parser import PySHRGParserOptions, pyshrg_parse_args


class MainOptions(PySHRGParserOptions):
    # Options of this export script, on top of those inherited from
    # PySHRGParserOptions.

    # Name of the directory holding the pickled derivation files; ``main``
    # joins it onto ``grammar_dir``.
    derivation_dir: str
    # Name of the directory receiving the exported files; ``parse_args``
    # joins it onto ``grammar_dir``.
    output_dir: str

    # Threshold on the summed partition count of a batch: stage 2 starts a new
    # batch when adding the next instance would push the sum above this value.
    num_samples_per_batch: int = 500
    # Forwarded to ``pad_batch_graphs`` as ``exclude_edges``; when set,
    # ``parse_args`` also appends '.no_edges' to the output directory.
    exclude_edges: bool = False


def export_train_graphs_stage1(manager, derivations, vocabs, options):
    """
    Turn ``(graph_index, derivation)`` pairs into
    ``(num_partitions, graph, derivation)`` triples.

    ``derivations`` is sorted IN PLACE by the node count of the graph each
    entry refers to, so the returned list is ordered from small to large
    graphs. ``options`` is accepted but not used here.

    @type manager: pyshrg.Manager
    """
    def _node_count(entry):
        return manager.get_graph(entry[0]).num_nodes

    def _export_one(graph_index, raw_derivation):
        # Convert the graph first, then simplify the derivation (same call
        # order for every entry).
        graph_dict = edsgraph_to_dict(manager.get_graph(graph_index), vocabs)
        num_partitions, simplified = simplify_derivation(raw_derivation)
        return num_partitions, graph_dict, simplified

    derivations.sort(key=_node_count)

    with_progress = ProgressReporter(len(derivations), step=1000)
    return [_export_one(graph_index, raw_derivation)
            for graph_index, raw_derivation in with_progress(derivations)]


def _expand_shrg_rule_indices(manager, derivation):
    """
    Build, for every step of ``derivation``, the list of SHRG rule indices
    that belong to the same rule: the step's own index first, followed by the
    ``shrg_index`` of every other CFG rule of that SHRG rule.

    @type manager: pyshrg.Manager
    """
    all_partition_indices, shrg_rule_indices, hrg_rule_indices = derivation[1:4]

    expanded = []
    for shrg_index, partition_indices in zip(shrg_rule_indices, all_partition_indices):
        gold_index = partition_indices[0]  # first partition is the gold partition
        shrg_rule = manager.get_shrg(shrg_index)
        # Sanity check: the HRG rule of the gold partition and the SHRG rule
        # must be the very same object.
        assert manager.get_hrg(hrg_rule_indices[gold_index]) is shrg_rule

        indices = [shrg_index]
        if shrg_rule.size == 1:
            assert shrg_rule.get(0).shrg_index == shrg_index
        else:
            indices.extend(cfg_rule.shrg_index
                           for cfg_rule in shrg_rule.iter_cfgs()
                           if cfg_rule.shrg_index != shrg_index)
            assert len(indices) == shrg_rule.size
        expanded.append(indices)

    return expanded


def export_train_graphs_stage2(manager, instances, vocabs, options):
    """
    Expand the SHRG rule indices of every instance, split the instances into
    batches and pad each batch with ``pad_batch_graphs``.

    ``instances`` holds ``(num_partitions, graph, derivation)`` triples as
    produced by ``export_train_graphs_stage1``. A new batch is started when
    adding the next instance would push the summed partition count above
    ``options.num_samples_per_batch``. An instance that exceeds the limit on
    its own gets a batch of its own; no empty batch is ever produced.

    Returns the list of padded batches.

    @type manager: pyshrg.Manager
    """
    num_total_partitions = 0
    batches = []
    current_batch = []

    LOGGER.info("Split batches ...")
    for num_partitions, *instance in ProgressReporter(len(instances), step=1000)(instances):
        graph, derivation = instance

        derivation = list(derivation)
        derivation[2] = _expand_shrg_rule_indices(manager, derivation)  # set shrg_rule_indices
        instance = graph, derivation

        # Flush only a non-empty batch: without the ``current_batch`` guard an
        # instance larger than the limit would append an empty batch (which
        # would then be handed to ``pad_batch_graphs``).
        if current_batch \
                and num_total_partitions + num_partitions > options.num_samples_per_batch:
            batches.append(current_batch)
            num_total_partitions = 0
            current_batch = []

        current_batch.append(instance)
        num_total_partitions += num_partitions

    if current_batch:
        batches.append(current_batch)

    LOGGER.info("Pad batches ...")
    # Replace every batch by its padded form in place (the list length does
    # not change while iterating).
    for index, batch in ProgressReporter(len(batches), step=100)(enumerate(batches)):
        batches[index] = pad_batch_graphs(batch, vocabs, exclude_edges=options.exclude_edges)

    return batches


def parse_args(argv=None):
    """
    Parse the command line and prepare the output directory.

    Defaults are set on a ``MainOptions`` instance before it is handed to
    ``pyshrg_parse_args``. Afterwards ``options.output_dir`` is rewritten to
    the path below ``options.grammar_dir`` (with a '.no_edges' suffix when
    ``options.exclude_edges`` is set) and that directory is created.

    Returns the ``(manager, options)`` pair from ``pyshrg_parse_args``.
    """
    abbrevs = dict(grammar_dir='grammar_dir', derivation_dir='-d', output_dir='-o')

    defaults = MainOptions()
    defaults.derivation_dir = 'derivation'
    defaults.output_dir = 'padded_data'
    defaults.graphs_tag = 'train'
    defaults.pyshrg.num_contexts = 1

    manager, options = pyshrg_parse_args(argv, default_instance=defaults, abbrevs=abbrevs)

    target_dir = os.path.join(options.grammar_dir, options.output_dir)
    target_dir += '.no_edges' if options.exclude_edges else ''

    options.output_dir = target_dir
    os.makedirs(target_dir, exist_ok=True)

    return manager, options


def clip_vocabs(vocabs):
    """
    Replace the 'word' and 'carg' vocabularies of ``vocabs`` with the result
    of ``copy_without_low_frequency(2, name=...)``.

    ``vocabs`` is modified in place through ``vocabs.set``.
    """
    LOGGER.info('remove low-frequency labels...')
    for vocab_name in ('word', 'carg'):
        clipped = vocabs.get(vocab_name).copy_without_low_frequency(2, name=vocab_name)
        vocabs.set(clipped)


def main(argv=None):
    """
    Run the two-stage export of training graphs.

    Stage 1 (skipped when 'vocabs.txt' already exists in the output
    directory) converts every pickled derivation file into
    'train-graphs.<name>' and collects the vocabularies, which are then
    clipped and written to 'vocabs.txt'. Stage 2 reloads each
    'train-graphs.<name>' file, batches and pads it, and writes the result
    back to the same path.

    Every file is opened in a ``with`` block so it is flushed and closed
    before it is read again later in the same run.
    """
    manager, options = parse_args(argv)

    derivation_dir = os.path.join(options.grammar_dir, options.derivation_dir)
    _open_in = open_wrapper(lambda x: os.path.join(derivation_dir, x))
    _open_out = open_wrapper(lambda x: os.path.join(options.output_dir, x))

    skip_stage1 = os.path.exists(os.path.join(options.output_dir, 'vocabs.txt'))
    if skip_stage1:
        with _open_out('vocabs.txt', 'r') as fp:
            vocabs = VocabularySet.from_file(fp)
    else:
        vocabs = VocabularySet()
        LOGGER.info('Exporting vocabs ... ')

    filepaths = []
    # os.listdir returns entries in arbitrary order; sort so that files are
    # processed in the same order on every run.
    for filepath in sorted(os.listdir(derivation_dir)):
        output_path = 'train-graphs.' + filepath
        if not skip_stage1:
            with _open_in(filepath, 'rb') as fp:
                derivations = pickle.load(fp)
            instances = export_train_graphs_stage1(manager, derivations, vocabs, options)
            with _open_out(output_path, 'wb') as fp:
                pickle.dump(instances, fp)
        filepaths.append(output_path)

    if not skip_stage1:
        LOGGER.info('Clip vocabs ... ')
        clip_vocabs(vocabs)
        with _open_out('vocabs.txt', 'w') as fp:
            vocabs.to_file(fp)

    LOGGER.info('\n%s', vocabs)

    # NOTE(review): stage 2 overwrites the stage-1 files in place. When
    # stage 1 is skipped, the files read here are whatever a previous run
    # left behind -- confirm that reruns are only expected after a run that
    # stopped before stage 2 rewrote them.
    for filepath in filepaths:
        with _open_out(filepath, 'rb') as fp:
            instances = pickle.load(fp)
        # Compute the batches before reopening the file for writing, so a
        # failure in stage 2 does not truncate the stage-1 output.
        batches = export_train_graphs_stage2(manager, instances, vocabs, options)
        with _open_out(filepath, 'wb') as fp:
            pickle.dump(batches, fp)
    LOGGER.info('Done')


# Run the export when this file is executed as a script; arguments are taken
# from the command line because ``main`` is called without ``argv``.
if __name__ == '__main__':
    main()
