# -*- coding: utf-8 -*-

import functools
import gzip
import os
import pickle
from collections import Counter, OrderedDict
from multiprocessing import cpu_count
from typing import List, Optional

from framework.common.dataclass_options import OptionsBase, argfield
from framework.common.logger import LOGGER, open_wrapper

from .const_tree import ConstTree
from .const_tree_preprocess import (FIX_HYPHEN_OPTIONS,
                                    MODIFY_TREE_OPTIONS,
                                    merge_cfg_tree,
                                    modify_const_tree)
from .dataset_utils import ReaderBase
from .graph_io import READERS, TEXT_WRITRES
from .graph_transformers import TRANSFORMS
from .shrg import DETECT_FUNCTIONS, REMOVE_NULL_SEMANTIC_OPTIONS, extract_shrg_rule
from .shrg_anonymize import anonymize_rule
from .utils.container import IndexedCounter


class IgnoreException(Exception):
    """Signal that a sentence should be skipped rather than reported as an error.

    ``DeepBankExtractor._worker`` catches it, logs it at debug level and
    produces no output entry for the sentence.
    """


class ExtractionOptions(OptionsBase):
    """Options controlling how trees/graphs are read and how SHRG rules are extracted."""
    # Key into READERS used to parse the graph data; also selects the graph text writer.
    graph_type: str = argfield(default='eds', choices=READERS.keys())
    # Forwarded to extract_shrg_rule as ``detect_function``.
    detect_function: str = argfield(default='construction', choices=DETECT_FUNCTIONS.keys())

    # Tree modification mode passed to modify_const_tree.
    modify_tree: str = argfield(default='no-suffix', choices=MODIFY_TREE_OPTIONS.keys())
    # Label modifications handed to the graph text writer when graphs are saved.
    modify_label: List[str] = argfield(default_factory=lambda: ['rm-_d'],
                                       choices=['rm-_d', 'rm-HNDL', 'rm-q'],
                                       nargs='+')

    # Names of TRANSFORMS applied, in the given order, to each hyper graph after reading.
    graph_transformers: List[str] = argfield(default_factory=list, choices=TRANSFORMS.keys())
    # When True, sentences whose hyper graph is not connected are skipped.
    remove_disconnected: bool = False
    # Forwarded to extract_shrg_rule as ``fully_lexicalized``.
    fully_lexicalized: bool = False
    # 'cfg': merge labels from the graph data into the tree via merge_cfg_tree;
    # 'hpsg': skip that merge step.
    # NOTE(review): the default is passed positionally here, unlike the other
    # fields -- presumably argfield's first parameter is the default; confirm.
    label_type: str = argfield('cfg', choices=['hpsg', 'cfg'])

    # Forwarded to extract_shrg_rule as ``remove_null_semantic``. If it contains
    # 'ignore_punct', that flag is also passed to ConstTree.generate_lexicons when
    # building the lemma sequence. Annotated Optional, but the default is [].
    remove_null_semantic: Optional[List[str]] = \
        argfield(default_factory=lambda: [],
                 choices=REMOVE_NULL_SEMANTIC_OPTIONS,
                 nargs='*')

    # Forwarded to extract_shrg_rule as ``ep_permutation_methods``.
    ep_permutation: List[str] = argfield(default_factory=lambda: ['stick+span'],
                                         choices=['stick+span', 'combine/start',
                                                  'combine/complete', 'combine/all'])

    # Passed to modify_const_tree (choices come from FIX_HYPHEN_OPTIONS).
    fix_hyphen: str = argfield(default='none', choices=FIX_HYPHEN_OPTIONS.keys())


class MainOptions(OptionsBase):
    """Top-level options consumed by ``extract_shrg_from_dataset``."""
    # Substituted for ``{grammar}`` when ``prefix`` is formatted.
    grammar_name: str

    # Output path prefix. A format string that may reference ``{grammar}``,
    # ``{graph_type}`` and ``{suffix}`` (the detect function); split names are
    # appended to the formatted value.
    prefix: str
    # Directory containing the tree bank files, one file per bank name.
    tree_path: str
    # Root of the graph data, handed to the extractor.
    data_path: str

    extraction: ExtractionOptions
    # Number of worker processes; -1 means min(cpu_count(), 8).
    num_workers: int = -1


@functools.lru_cache()
def get_tree_bank(tree_path, bank):
    """Load the tree strings of one bank file, keyed by sentence id.

    ``<tree_path>/<bank>`` is read as alternating lines: a ``#<sentence_id>``
    header followed by the tree string of that sentence. Reading stops at end of
    file or at the first blank header line.

    The result is memoized with ``functools.lru_cache``: repeated calls with the
    same arguments return the *same* dict object, so callers must not mutate it.

    Raises:
        ValueError: if a header line does not start with ``'#'``.
    """
    tree_strings = {}
    # Decode explicitly as UTF-8 instead of depending on the platform locale.
    with open(os.path.join(tree_path, bank), encoding='utf-8') as fin:
        while True:
            sentence_id = fin.readline().strip()
            if not sentence_id:
                break
            # An explicit check survives ``python -O`` (an assert would not);
            # without it a misaligned file silently pairs ids with wrong trees.
            if not sentence_id.startswith('#'):
                raise ValueError(
                    f'malformed tree bank {bank!r}: expected "#<sentence_id>", '
                    f'got {sentence_id!r}')
            tree_strings[sentence_id[1:]] = fin.readline().strip()
    return tree_strings


def save_rules(output_prefix, rules_counter, options, extra_suffix=''):
    """Write a rule counter to ``<output_prefix><name><extra_suffix>`` files.

    Produces three outputs:

    * ``.mapping.txt`` -- each HRG side with the CFG sides it pairs with,
      grouped by HRG, plus per-head counts;
    * ``.rules.detail.txt`` -- every rule with details;
    * ``.counter.p`` -- a pickle of ``(rules_counter, options)``.

    Rules whose ``hrg`` is None are grouped and counted under their CFG
    left-hand side instead.
    """
    _open = open_wrapper(lambda name: output_prefix + name + extra_suffix)

    hrg2cfg_mapping = OrderedDict()
    head_counter = Counter()
    for rule_index, (shrg_rule, counter_item) in enumerate(rules_counter):
        count = counter_item.count
        if shrg_rule.hrg is not None:
            group_key = shrg_rule.hrg
            head_label = group_key.lhs.unique_label
        else:
            group_key = head_label = shrg_rule.cfg.lhs
        head_counter[head_label] += count
        hrg2cfg_mapping.setdefault(group_key, []).append((rule_index, shrg_rule.cfg, count))

    LOGGER.info('All hrg rules: %s', len(hrg2cfg_mapping))
    LOGGER.info('All rules: %s', len(rules_counter))

    with _open('.mapping.txt', 'w') as mapping_stream:
        TEXT_WRITRES.invoke('mapping', mapping_stream, hrg2cfg_mapping, head_counter,
                            write_detail=False)
    # Release the grouping structures before the (potentially large) detail dump.
    del hrg2cfg_mapping, head_counter

    with _open('.rules.detail.txt', 'w') as detail_stream:
        TEXT_WRITRES.invoke('shrg', detail_stream, rules_counter, write_detail=True)

    with _open('.counter.p', 'wb') as pickle_stream:
        pickle.dump((rules_counter, options), pickle_stream)


class DeepBankExtractor(ReaderBase):
    """Read DeepBank trees and graphs and extract SHRG rules from them.

    Per-file work happens in the ``_worker`` classmethod (NOTE(review):
    presumably dispatched by ``ReaderBase.get_split``; confirm). ``save``
    writes the collected results.
    """
    Options = ExtractionOptions

    def __init__(self, options, data_path, tree_path, split_patterns, logger=LOGGER):
        super().__init__(options, data_path, split_patterns, logger=logger, tree_path=tree_path)

    def on_error(self, filename, error):
        self.logger.error('%s %s', filename, error)

    @staticmethod
    def _sentence_id_from_filename(filename):
        """Return the basename of *filename* with one trailing ``.gz`` removed.

        ``str.rstrip('.gz')`` must not be used here: it strips any trailing run
        of the characters '.', 'g' and 'z', mangling ids ending in those letters.
        """
        name = os.path.basename(filename)
        if name.endswith('.gz'):
            name = name[:-len('.gz')]
        return name

    @classmethod
    def _read_tree_and_graph(cls, options, tree_string, graph_data):
        """Parse one sentence's tree string and graph data.

        Applies the configured tree modifications and graph transformers, and
        stores the space-joined lexicon strings of the tree on
        ``eds_graph.lemma_sequence``.

        Returns ``(const_tree, (hyper_graph, eds_graph))``.
        """
        fields = graph_data.strip().split('\n\n')

        const_tree = ConstTree.from_java_code_and_deepbank_1_1(tree_string, graph_data)[0]
        if options.label_type == 'cfg':
            merge_cfg_tree(fields, const_tree)

        const_tree = modify_const_tree(const_tree, options.modify_tree, options.fix_hyphen)

        hyper_graph, eds_graph = READERS.invoke(options.graph_type, fields, options)

        for transformer in options.graph_transformers:
            hyper_graph = TRANSFORMS.invoke(transformer, hyper_graph)

        # remove_null_semantic may be None or empty; normalize the flag to a bool.
        ignore_punct = 'ignore_punct' in (options.remove_null_semantic or ())

        eds_graph.lemma_sequence = \
            ' '.join(x.string for x in const_tree.generate_lexicons(ignore_punct))

        return const_tree, (hyper_graph, eds_graph)

    @classmethod
    def _worker(cls, args):
        """Process one batch of gzipped graph files.

        ``args`` is ``(files, options, training, extra_args)``, where
        ``extra_args['tree_path']`` locates the tree banks. The bank name is
        taken from the directory of ``files[0]`` only, so presumably all files
        of a batch share one directory.

        Returns a list of ``(ok, filename, payload)``: ``payload`` is
        ``(bank, sentence_id, build_graph output)`` when ``ok`` is True, and
        the raised exception otherwise. Sentences raising ``IgnoreException``
        or dropped by ``build_graph`` (it returned None) produce no entry.
        """
        files, options, training, extra_args = args
        if not files:
            return []

        bank = os.path.basename(os.path.dirname(files[0]))
        tree_strings = get_tree_bank(extra_args['tree_path'], bank)

        outputs = []
        for filename in files:
            sentence_id = cls._sentence_id_from_filename(filename)
            try:
                # A missing tree or unreadable file is reported per sentence like
                # any other extraction error instead of aborting the whole batch.
                tree_string = tree_strings[sentence_id]
                with gzip.open(filename, 'rb') as fp:
                    graph_data = fp.read().decode()
                output = cls._read_tree_and_graph(options, tree_string, graph_data)
                output = cls.build_graph(options, output, sentence_id, training)
            except IgnoreException as err:
                LOGGER.debug('%s %s', err, sentence_id)
                continue
            except Exception as err:
                LOGGER.exception('%s', sentence_id)
                outputs.append((False, filename, err))
                continue

            if output is not None:
                outputs.append((True, filename, (bank, sentence_id, output)))

        return outputs

    @classmethod
    def _extract_rule(cls, options, hyper_graph, const_tree, sentence_id, **kwargs):
        """Call ``extract_shrg_rule`` with the extraction settings from *options*.

        Extra keyword arguments are forwarded unchanged.
        """
        return extract_shrg_rule(hyper_graph, const_tree,
                                 detect_function=options.detect_function,
                                 fully_lexicalized=options.fully_lexicalized,
                                 remove_null_semantic=options.remove_null_semantic,
                                 ep_permutation_methods=options.ep_permutation,
                                 graph_type=options.graph_type,
                                 sentence_id=sentence_id,
                                 **kwargs)

    @classmethod
    def build_graph(cls, options, output, sentence_id, training):
        """Turn a parsed sentence into the per-sentence record used by ``save``.

        Returns None when ``options.remove_disconnected`` is set and the hyper
        graph is not connected. Otherwise returns
        ``(shrg_rules, (edges_info, nodes_info), const_tree, eds_graph)``.
        Outside training the first three rule-related values are None. In
        training:

        * ``edges_info[step]`` is the set of ``to_tuple()`` forms of the
          terminal edges blamed on that extraction step;
        * ``nodes_info[step]`` is the tuple of boundary node names for that
          step, or None.
        """
        const_tree, (hyper_graph, eds_graph) = output

        if options.remove_disconnected and not hyper_graph.is_connected():
            LOGGER.debug('%s disconnected', sentence_id)
            return None

        shrg_rules = nodes_info = edges_info = None
        if training:
            shrg_rules, (_, edge_blame_dict), boundary_node_dict = \
                cls._extract_rule(options, hyper_graph, const_tree, sentence_id)

            nodes_info = [None] * len(shrg_rules)
            edges_info = [set() for _ in range(len(shrg_rules))]
            for edge, step in edge_blame_dict.items():
                assert edge.is_terminal, f'{edge} is not a terminal ???'
                edges_info[step].add(edge.to_tuple())
            for step, nodes in boundary_node_dict.items():
                nodes_info[step] = tuple(node.name for node in nodes)

        return shrg_rules, (edges_info, nodes_info), const_tree, eds_graph

    @classmethod
    def extract(cls, options, tree_string, graph_data, sentence_id, anonymous=False):
        """Extract rules for a single sentence from its raw tree and graph data.

        Returns ``(output, const_tree, hyper_graph)``. ``output`` is the result
        of ``_extract_rule(..., return_derivation_infos=True)``; if extraction
        raises, the exception object is returned in its place. Errors while
        reading the tree/graph are not caught. With ``anonymous=True`` every
        rule in the first element of ``output`` is replaced in place by its
        anonymized form whenever ``anonymize_rule`` returns one.
        """
        const_tree, (hyper_graph, eds_graph) = \
            cls._read_tree_and_graph(options, tree_string, graph_data)

        try:
            output = cls._extract_rule(options, hyper_graph, const_tree, sentence_id,
                                       return_derivation_infos=True)
        except Exception as err:
            return err, const_tree, hyper_graph

        if anonymous:
            shrg_rules, *_ = output
            for index, shrg_rule in enumerate(shrg_rules):
                new_rule = anonymize_rule(shrg_rule)
                if new_rule is not None:
                    shrg_rules[index] = new_rule
        return output, const_tree, hyper_graph

    def save(self, all_results, output_prefix, training):
        """Write graphs and trees for all results; in training also rules and derivations.

        *all_results* is a sized iterable of
        ``(bank, sentence_id, build_graph output)`` tuples. Always writes
        ``.graphs.txt`` and ``.trees.txt`` (each starting with the number of
        results). In training mode it additionally pickles the per-sentence
        derivations to ``.derivations.p`` and saves the rule counter with
        ``save_rules``.
        """
        options = self.options
        _open = open_wrapper(lambda x: output_prefix + x)

        rules_counter = IndexedCounter(5)
        derivations = {}

        tree_writer = TEXT_WRITRES.normalize('tree')
        graph_writer = TEXT_WRITRES.normalize(options.graph_type)

        modify_label = options.modify_label

        total_count = len(all_results)
        with _open('.graphs.txt', 'w') as graph_out, _open('.trees.txt', 'w') as tree_out:
            graph_out.write(str(total_count) + '\n')
            tree_out.write(str(total_count) + '\n')
            for (bank, sentence_id,
                 (shrg_rules, (edges_info, nodes_info), const_tree, eds_graph)) in all_results:
                if training:
                    # Replace every rule by the index returned from the shared
                    # counter, recording (sentence_id, step) as where it occurred.
                    shrg_rules = [
                        rules_counter.add(shrg_rule, (sentence_id, step))
                        for step, shrg_rule in enumerate(shrg_rules)
                    ]

                    derivation = []
                    assert len(shrg_rules) == len(edges_info), 'Strange ???'
                    # Extraction steps line up with the tree's post-order traversal.
                    for rule_index, node, *info in zip(shrg_rules,
                                                       const_tree.traverse_postorder(),
                                                       edges_info, nodes_info):
                        child_indices = (child.index
                                         for child in getattr(node, 'children', [])
                                         if isinstance(child, ConstTree))
                        derivation.append((rule_index, *info, *child_indices))
                else:
                    derivation = shrg_rules
                derivations[sentence_id] = derivation

                qualified_id = bank + os.path.sep + sentence_id
                graph_writer(graph_out, qualified_id, eds_graph, modify_label)
                tree_writer(tree_out, qualified_id, const_tree, shrg_rules)

        LOGGER.info('All graphs: %s', len(derivations))

        if not training:
            return

        with _open('.derivations.p', 'wb') as out:
            pickle.dump(derivations, out)
        del derivations

        save_rules(output_prefix, rules_counter, options)


def anonymize_rules(output_prefix):
    """Derive anonymized rule sets from the counter saved under *output_prefix*.

    Loads ``(rules, params)`` from ``<output_prefix>.counter.p`` (the pickle
    written by ``save_rules``), then writes:

    * ``save_rules`` output with suffix ``.anonymous``: each original rule
      replaced by its anonymized form when ``anonymize_rule`` returns one,
      otherwise kept as is;
    * ``save_rules`` output with suffix ``.merged``: the original rules with the
      anonymized rules added to the same counter;
    * ``.merged.relations.p``: a dict mapping each anonymized rule's index in
      the merged counter to the set of original rule indices it came from.

    The pickle is trusted local output of this pipeline; do not point this at
    untrusted files.
    """
    _open = open_wrapper(lambda x: output_prefix + x)
    with _open('.counter.p', 'rb') as fin:
        rules, params = pickle.load(fin)
    num_rules = len(rules)
    LOGGER.info('Loaded %d rules (%s)', num_rules, output_prefix)

    anonymize_mapping = {}
    anonymous_rules = IndexedCounter(5)
    num_new_rules = 0
    # Iterate a snapshot of the original rules: the loop adds entries to
    # ``rules``, and iterating a container while growing it can revisit the
    # new entries (double counting them) or fail outright.
    for i, (rule, counter_item) in enumerate(list(rules)):
        new_rule = anonymize_rule(rule)
        if new_rule is not None:
            num_new_rules += 1
            new_index = rules.add(new_rule, counter_item)
            anonymize_mapping.setdefault(new_index, set()).add(i)
        # accumulate
        anonymous_rules.add(new_rule or rule, counter_item)

    LOGGER.info('Generate %d new rules (merged from %d)', len(rules) - num_rules, num_new_rules)
    save_rules(output_prefix, anonymous_rules, params, extra_suffix='.anonymous')
    save_rules(output_prefix, rules, params, extra_suffix='.merged')
    with _open('.merged.relations.p', 'wb') as fout:
        pickle.dump(anonymize_mapping, fout)


def extract_shrg_from_dataset(options: MainOptions, split_patterns,
                              train_splits=frozenset({'train'})):
    """Extract SHRG rules for every split and write the results.

    ``options.prefix`` is formatted in place with ``grammar``, ``graph_type``
    and ``suffix`` (the detect function), and the extraction options are saved
    to ``<prefix>config``. For each ``(split, _)`` in *split_patterns* the
    split is read with the configured number of workers and saved under
    ``<prefix><split>``. Splits listed in *train_splits* are processed in
    training mode, and their rules are anonymized afterwards.
    """
    extractor = DeepBankExtractor(options.extraction,
                                  options.data_path, options.tree_path, split_patterns)

    options.prefix = options.prefix.format(grammar=options.grammar_name,
                                           graph_type=options.extraction.graph_type,
                                           suffix=options.extraction.detect_function)

    # dirname() is '' for a bare prefix such as 'out_', and os.makedirs('') raises.
    output_dir = os.path.dirname(options.prefix)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    options.extraction.to_file(options.prefix + 'config')

    num_workers = options.num_workers
    if num_workers == -1:
        num_workers = min(cpu_count(), 8)

    for split, _ in split_patterns:
        output_prefix = options.prefix + split
        LOGGER.info('use %d processes', num_workers)
        LOGGER.info('output_prefix: %s', output_prefix)
        LOGGER.info('begin to extract rules ...')

        training = (split in train_splits)
        results = extractor.get_split(split, num_workers=num_workers, training=training)

        LOGGER.info('begin to write results ...')

        # Sort by sentence id so the output files are deterministic.
        extractor.save(sorted(results.values(), key=lambda item: item[1]),
                       output_prefix,
                       training=training)
        if training:
            anonymize_rules(output_prefix)
