# -*- coding: utf-8 -*-

import os
import pickle
from multiprocessing import Manager

import SHRG.shrg_extract as E
from framework.common.advice import advice_add
from framework.common.logger import LOGGER, open_file
from SHRG.const_tree import ConstTree

# Integer verdicts returned by check_shrg_rule() instead of a component list.
# SKIP: the rule has no HRG part, or fewer than two connected components
# remain after the single-edge `_q` components are filtered out.
SKIP = 0
# EXCHANGE: the rule consists of two single-edge nonterminal components and at
# least one label suffix (the text after the last '@') is in the exchange set.
EXCHANGE = 1


class UnhandledException(Exception):
    """Signal that a sentence still yields flagged rules and must not be used.

    Raised by ``Extractor.forward`` when ``find_interest_steps`` keeps
    reporting exchange or unhandled steps for the sentence being processed.
    """


def exchange_tree_children(const_tree, C):
    """Rotate node ``C`` with its parent inside ``const_tree``, in place.

    The parent of ``C`` is looked up by ``index`` among all ``ConstTree``
    children reachable through ``const_tree.traverse_postorder()``. Both ``C``
    and its parent must have exactly two children.

    Raises:
        KeyError: if no parent is recorded for ``C.index``.
        AssertionError: if ``C`` or its parent is not binary, or if ``C`` is
            not (by identity) one of the parent's children.
    """
    # Map every ConstTree child's index to the node that owns it.
    parent_by_index = {
        child.index: owner
        for owner in const_tree.traverse_postorder()
        for child in owner.children
        if isinstance(child, ConstTree)
    }
    parent = parent_by_index[C.index]

    #      P                  P             P                  P
    #     / \                / \           / \                / \
    #    C   n4     =>      n1  C         n3   C     =>      C   n2
    #   / \                    / \            / \           / \
    #  n1  n2                 n2  n4         n1  n2        n3  n1
    assert len(C.children) == 2 and len(parent.children) == 2
    inner_left, inner_right = C.children
    outer_left, outer_right = parent.children
    assert outer_left is C or outer_right is C

    if outer_left is C:
        # C was the left child: its left child moves up, the sibling moves down.
        parent.children = [inner_left, C]
        C.children = [inner_right, outer_right]
        return

    # C was the right child: its right child moves up, the sibling moves down.
    parent.children = [C, inner_right]
    C.children = [outer_left, inner_left]


def check_shrg_rule(shrg_rule, exchange_set):
    """Classify one SHRG rule by the connectivity of its right-hand side.

    Args:
        shrg_rule: object exposing ``hrg``; when ``hrg`` is not None, its
            ``rhs`` must provide ``connected_components()`` yielding
            collections of edges with ``label`` and ``is_terminal``.
        exchange_set: container of label suffixes (the text after the last
            ``'@'`` of a nonterminal label) that mark a rule as repairable.

    Returns:
        ``SKIP`` when the rule has no HRG part or fewer than two components
        remain after filtering; ``EXCHANGE`` when the rule is made of two
        single nonterminal edges and at least one suffix is in
        ``exchange_set``; otherwise the list of remaining components, which
        callers treat as "unhandled".

    Raises:
        AssertionError: if the two single-edge components mix one terminal
            and one nonterminal edge.
    """
    if shrg_rule.hrg is None:
        return SKIP

    hg = shrg_rule.hrg.rhs
    # filter _q edges: drop components that are a lone edge labelled `*_q`
    components = [component for component in hg.connected_components()
                  if len(component) != 1 or not next(iter(component)).label.endswith('_q')]
    if len(components) < 2:
        return SKIP
    if len(components) > 2:
        # Unpacking into exactly two names below would raise ValueError;
        # report the rule as unhandled instead of crashing the extraction.
        return components

    c1, c2 = components
    if len(c1) == len(c2) == 1:
        # Peek at the only edge instead of pop(): the components are returned
        # to the caller and must not be emptied as a side effect.
        e1 = next(iter(c1))
        e2 = next(iter(c2))
        if not e1.is_terminal and not e2.is_terminal:
            l1 = e1.label.rsplit('@', 1)[-1]
            l2 = e2.label.rsplit('@', 1)[-1]
            if l1 in exchange_set or l2 in exchange_set:
                return EXCHANGE
            return components
        # both edges are two terminal edges or two nonterminal edges
        assert e1.is_terminal and e2.is_terminal, '??? strange rule'

    return components


def find_interest_steps(shrg_rules, exchange_set):
    """Sort the positions of ``shrg_rules`` by their ``check_shrg_rule`` verdict.

    Args:
        shrg_rules: iterable of rules; positions are assigned by enumeration.
        exchange_set: forwarded unchanged to ``check_shrg_rule``.

    Returns:
        A pair ``(exchange_steps, unhandled_steps)`` of position lists.
        Positions whose verdict is ``SKIP`` appear in neither list.
    """
    to_exchange, unhandled = [], []
    for position, rule in enumerate(shrg_rules):
        verdict = check_shrg_rule(rule, exchange_set)
        if not isinstance(verdict, int):
            # A non-integer verdict is the component list of an unhandled rule.
            unhandled.append(position)
        elif verdict is not SKIP:
            to_exchange.append(position)
    return to_exchange, unhandled


class Extractor:
    """Around-advice for rule extraction that repairs or rejects sentences.

    Instances are callable as ``extractor(fn, hyper_graph, const_tree, ...)``
    where ``fn`` is the wrapped extraction function. The first element of
    ``fn``'s result is scanned with ``find_interest_steps``; flagged tree
    nodes are rotated with ``exchange_tree_children`` and the extraction is
    run again.

    Two modes exist, chosen once in ``__init__``:

    * create mode (no cache file): sentences are analysed and the outcome is
      recorded in ``unhandled_sentence_ids`` / ``success_sentence_ids``;
    * cache mode (cache file exists): the recorded outcome is replayed.
    """

    def __init__(self, exchange_set, cache_path=None, keep_all_sentence=False):
        """Set up the extractor and load the cache file if it exists.

        Args:
            exchange_set: label suffixes passed to ``find_interest_steps``.
            cache_path: pickle file holding the per-sentence outcome. ``None``
                means "no cache file": nothing is loaded and nothing is saved.
            keep_all_sentence: when true, a sentence that cannot be handled is
                extracted anyway by calling the wrapped function directly.
        """
        self.exchange_set = exchange_set
        self.cache_path = cache_path
        self._keep_all_sentence = keep_all_sentence

        # os.path.exists(None) raises TypeError, so the default has to be
        # checked explicitly before touching the file system.
        self._create_cache = cache_path is None or not os.path.exists(cache_path)
        if not self._create_cache:
            # NOTE: pickle can execute arbitrary code while loading; the cache
            # file must come from a trusted run of this script.
            with open_file(cache_path, 'rb') as cache_file:
                sentence_infos = pickle.load(cache_file)
            self.unhandled_sentence_ids, self.success_sentence_ids = sentence_infos
        else:
            # Manager-backed dicts can be shared between processes
            # (presumably extraction runs in worker processes - TODO confirm).
            self.manager = Manager()
            self.unhandled_sentence_ids = self.manager.dict()
            self.success_sentence_ids = self.manager.dict()

    def save_stats(self):
        """Log the sentence counts and write the cache file.

        Does nothing unless the extractor is in create mode and
        ``keep_all_sentence`` is off. The file is only written when a
        ``cache_path`` was given.
        """
        if not self._create_cache or self._keep_all_sentence:
            return

        LOGGER.info('unhandled sentences: %d', len(self.unhandled_sentence_ids))
        LOGGER.info('success sentences: %d', len(self.success_sentence_ids))
        if self.cache_path is None:
            return

        # Plain set/dict copies are stored so the file can be loaded without
        # a running Manager; `with` guarantees the data is flushed and closed.
        with open_file(self.cache_path, 'wb') as cache_file:
            pickle.dump((set(self.unhandled_sentence_ids), dict(self.success_sentence_ids)),
                        cache_file)

    def _transfom_steps(self, steps):
        """Return ``steps`` as ``'<sentence_id>/<step>'`` strings for messages."""
        return [f'{self.sentence_id}/{step}' for step in steps]

    def __call__(self, fn, hyper_graph, const_tree, *args, **kwargs):
        """Run ``forward``; optionally fall back to a plain ``fn`` call.

        With ``keep_all_sentence`` an ``UnhandledException`` is logged and the
        wrapped function is called directly; otherwise the exception
        propagates to the caller.
        """
        if self._keep_all_sentence:
            try:
                return self.forward(fn, hyper_graph, const_tree, *args, **kwargs)
            except UnhandledException:
                LOGGER.exception('can not handle this sentence, but give it another try')
                # NOTE(review): `forward` may already have rotated nodes of
                # `const_tree` in place before raising - confirm this is the
                # tree the fallback is meant to use.
                return fn(hyper_graph, const_tree, *args, **kwargs)
        else:
            return self.forward(fn, hyper_graph, const_tree, *args, **kwargs)

    def forward(self, fn, hyper_graph, const_tree, *args, **kwargs):
        """Extract rules for one sentence, rotating tree nodes if required.

        The sentence is identified by the ``sentence_id`` keyword argument
        (``None`` when absent).

        Returns:
            The result of ``fn`` on the original, cached, or rotated tree.

        Raises:
            E.IgnoreException: in cache mode, for a sentence recorded as
                unhandled.
            UnhandledException: in create mode, when flagged steps remain.
            AssertionError: in cache mode, when the replayed extraction still
                has unhandled steps.
        """
        self.sentence_id = sentence_id = kwargs.get('sentence_id')

        if not self._create_cache:
            # Replay: prefer the tree stored by a previous successful run.
            new_const_tree = self.success_sentence_ids.get(sentence_id)
            if new_const_tree is not None:
                const_tree = new_const_tree
            elif sentence_id in self.unhandled_sentence_ids:
                raise E.IgnoreException(f'skip unhandled sentence {sentence_id}')

            result = fn(hyper_graph, const_tree, *args, **kwargs)
            # With an empty exchange set nothing is classified as EXCHANGE, so
            # only the unhandled list is of interest here.
            _, unhandled_steps = find_interest_steps(result[0], {})
            assert not unhandled_steps, f'??? {unhandled_steps}'
            return result

        result = fn(hyper_graph, const_tree, *args, **kwargs)
        exchange_steps, unhandled_steps = find_interest_steps(result[0], self.exchange_set)

        if not exchange_steps:
            if unhandled_steps:
                self.unhandled_sentence_ids[sentence_id] = True
                unhandled_steps = self._transfom_steps(unhandled_steps)
                # do not use this sentence
                raise UnhandledException(
                    f'{unhandled_steps}; U/S = '
                    f'{len(self.unhandled_sentence_ids)}/{len(self.success_sentence_ids)}')
            return result

        # Index the nodes before any rotation; a step number is used as the
        # index of the tree node to rotate.
        tree_nodes = {node.index: node for node in const_tree.traverse_postorder()}
        for step in exchange_steps:
            exchange_tree_children(const_tree, tree_nodes[step])

        # use modified tree
        new_result = fn(hyper_graph, const_tree, *args, **kwargs)
        exchange_steps, unhandled_steps = find_interest_steps(new_result[0], self.exchange_set)

        if exchange_steps or unhandled_steps:
            # Translate the indices the nodes carry now back to the indices
            # captured before the rotation, so messages refer to the original
            # tree (presumably `fn` re-indexes the tree - TODO confirm).
            mapping = {node.index: original_index for original_index, node in tree_nodes.items()}
            self.unhandled_sentence_ids[sentence_id] = True

            unhandled_steps = self._transfom_steps(mapping[step] for step in unhandled_steps)
            exchange_steps = self._transfom_steps(mapping[step] for step in exchange_steps)
            # do not use this sentence
            raise UnhandledException(
                f'not working: {exchange_steps} or unhandled: {unhandled_steps}; U/S = '
                f'{len(self.unhandled_sentence_ids)}/{len(self.success_sentence_ids)}')

        # all disconnected rules are removed
        self.success_sentence_ids[sentence_id] = const_tree
        return new_result


def main(argv=None, keep_all_sentence=False):
    """Run ``extract_shrg_rules.main`` with the ``Extractor`` advice installed.

    Args:
        argv: forwarded unchanged to ``extract_shrg_rules.main``.
        keep_all_sentence: forwarded to ``Extractor``.
    """
    def around_extract_shrg_from_dataset(fn, options, *args, **kwargs):
        # The label type decides the case of the exchangeable label suffixes.
        if options.extraction.label_type == 'cfg':
            exchange_set = ('V', 'P')
        else:
            exchange_set = ('v', 'p')

        cache_path = f'output/{options.grammar_name}.sentence_infos.p'
        extractor = Extractor(exchange_set, cache_path, keep_all_sentence)
        advice_add(E, 'extract_shrg_rule', extractor)

        try:
            # Hand the wrapped function's result back to its caller instead of
            # dropping it; the statistics are saved even if extraction fails.
            return fn(options, *args, **kwargs)
        finally:
            extractor.save_stats()

    # Imported here rather than at module level, as in the original layout.
    import extract_shrg_rules as M
    advice_add(M, 'extract_shrg_from_dataset', around_extract_shrg_from_dataset)

    M.main(argv)


# Script entry point: run the extraction with the default arguments.
if __name__ == '__main__':
    main()
