from __future__ import print_function
import hashlib
from collections import defaultdict, UserList, OrderedDict
from operator import itemgetter

from itertools import permutations
from string import ascii_uppercase
from typing import List, Tuple, Union, Optional, Sequence, Dict, Set, FrozenSet, Any

from dataclasses import dataclass, field

from coli.basic_tools.common_utils import deprecated
from coli.hrgguru.compound_split import NodeDistributor
from coli.hrgguru.const_tree import ConstTree, Lexicon
from coli.hrgguru.hyper_graph import HyperEdge, GraphNode, HyperGraph, IsTerminal


@dataclass(eq=True, frozen=True)
class HRGRule(object):
    lhs: HyperEdge
    rhs: HyperGraph
    comment: Dict[str, str] = field(default_factory=dict, compare=False, hash=False)

    def __str__(self):
        """Render the rule as ``<lhs> -> `` followed by one RHS edge per line."""
        rhs_lines = "\n".join(map(str, self.rhs.edges))
        return "{} -> \n{}\n".format(self.lhs, rhs_lines)

    def __repr__(self):
        """Return the same human-readable text as ``__str__``."""
        return str(self)

    def to_standardized_node_names(self, return_mapping=False, left_and_right_span=None,
                                   ep_permutation_methods=None):
        """Re-extract this rule so that its nodes carry canonical names.

        The rule's own RHS edges, internal nodes and external nodes are fed
        back into :meth:`extract`, which renames every node to a decimal index.

        Args:
            return_mapping: when true, return ``(node_mapping, new_rule)``
                instead of only ``new_rule``.
            left_and_right_span: optional pair of spans, forwarded to
                :meth:`extract` as the keyword argument of the same name.
            ep_permutation_methods: names of the heuristics :meth:`extract`
                may use to order the external nodes. When omitted, ``"spans"``
                is enabled if ``left_and_right_span`` is given (it is the only
                heuristic that reads the span) and no heuristic otherwise; in
                the latter case the external-node order follows set iteration
                order.

        Returns:
            The renamed rule, or ``(node_mapping, renamed_rule)`` when
            ``return_mapping`` is true.
        """
        external_nodes = set(self.lhs.nodes)
        internal_nodes = self.rhs.nodes - external_nodes
        # sanity check: every edge endpoint must be a declared RHS node
        for edge in self.rhs.edges:
            for node in edge.nodes:
                assert node in self.rhs.nodes

        if ep_permutation_methods is None:
            ep_permutation_methods = (
                frozenset({"spans"}) if left_and_right_span is not None else frozenset())

        # ``left_and_right_span`` is keyword-only in ``extract``; passing it
        # positionally would land in the ``ep_permutation_methods`` slot.
        node_mapping, new_rule = self.extract(
            self.rhs.edges, internal_nodes,
            external_nodes, self.lhs.label,
            ep_permutation_methods,
            left_and_right_span=left_and_right_span)
        if return_mapping:
            return node_mapping, new_rule
        return new_rule

    @classmethod
    def extract(cls,
                edges: Union[Set[HyperEdge], FrozenSet[HyperEdge]],
                internal_nodes: Union[Set[GraphNode], FrozenSet[GraphNode]],
                external_nodes: Union[Set[GraphNode], FrozenSet[GraphNode]],
                label: str,
                ep_permutation_methods: Union[Set[str], FrozenSet[str]],
                *,
                left_and_right_span: Optional[Tuple[Tuple[int, int], Tuple[int, int]]] = None,
                extra_labels=None,
                combine_dict: Optional[Dict[GraphNode, Sequence[int]]] = None
                ):
        """Build a rule with canonically named nodes from a sub-graph.

        Every node of ``internal_nodes | external_nodes`` is renamed to a
        ``GraphNode`` whose name is a decimal index. The order is derived from
        iteratively refined MD5 hashes of each node's neighbourhood, so it does
        not depend on the original node names.

        Args:
            edges: hyperedges that form the right-hand side of the rule.
            internal_nodes: nodes that are not attached to the LHS edge.
            external_nodes: nodes attached to the LHS edge.
            label: label of the LHS edge.
            ep_permutation_methods: names of the heuristics used to order the
                external nodes. Names read here: ``"stick"``, ``"spans"``,
                one of ``"first_combine_time"`` / ``"combine_time"`` /
                ``"finish_time"``, and ``"extra_label"``.
            left_and_right_span: pair of spans, only read by ``"spans"``.
            extra_labels: indexed by node, only read by ``"extra_label"``.
            combine_dict: indexed by node, only read by the three
                combine-time heuristics (which index it unconditionally, so it
                must be supplied when one of them is requested).

        Returns:
            ``(node_rename_map, rule)`` where ``node_rename_map`` maps each
            original node to its renamed ``GraphNode`` and ``rule`` is
            ``cls(lhs, rhs, comment)``; the LHS edge is created with
            ``IsTerminal.NONTERMINAL_PENDING``.
        """
        node_rename_map = {}
        nodes = internal_nodes.union(external_nodes)
        edge_by_node = defaultdict(list)  # node -> (edge, index of this node in this edge)
        for edge in edges:
            for idx, node in enumerate(edge.nodes):
                edge_by_node[node].append((edge, idx))

        # every node starts from the same seed hash; the refinement rounds
        # below make the hashes diverge according to graph structure
        default_hash = hashlib.md5(b"13").digest()
        node_hashes = {node: default_hash for node in nodes}  # node -> hash

        def get_edge_hashes(
                node_hashes,  # type: Dict[GraphNode, bytes]
                edge,  # type: HyperEdge
                idx  # type: int
        ):
            """hash(edge) = hash(edge.label#edge.is_terminal#idx + hash(node_1)# + hash(node_2)# + ...)

            ``idx`` is the position, inside ``edge``, of the node the hash is
            being computed for.
            """
            md5_obj = hashlib.md5((edge.label + "#" + edge.is_terminal.name + "#" +
                                   str(idx)).encode())
            for adj_node in edge.nodes:
                md5_obj.update(node_hashes[adj_node] + b"#")
            return md5_obj.digest()

        def get_sibling_hashes(
                node_hashes,  # type: Dict[GraphNode, bytes]
                node  # type: GraphNode
        ):
            """hash(node)=hash(set of hash(sibling_edges)#node_name if exist)"""
            md5_obj = hashlib.md5()
            # sorting makes the result independent of edge iteration order
            edge_hashes = sorted(get_edge_hashes(node_hashes, edge, idx)
                                 for edge, idx in edge_by_node[node])
            for h in edge_hashes:
                md5_obj.update(h)
            # a node that already received its new name mixes that name in,
            # which separates it from nodes it was previously tied with
            if node_rename_map.get(node) is not None:
                md5_obj.update(("#" + node_rename_map[node].name).encode("utf-8"))
            return md5_obj.digest()

        def recalculate_hashes():
            """Run one refinement round and return the new node -> hash dict."""
            new_node_hashes = {}
            # recalculate hashes
            for node in nodes:
                md5_obj = hashlib.md5()
                md5_obj.update(get_sibling_hashes(node_hashes, node))
                # external and internal nodes never share a hash
                md5_obj.update(b'\x01' if node in external_nodes else b'\x00')
                new_node_hashes[node] = md5_obj.digest()
            return new_node_hashes

        # a fixed number of rounds (node count plus a margin of 10)
        for cycle in range(len(nodes) + 10):
            node_hashes = recalculate_hashes()
        # snapshot taken before any tie-breaking name is mixed in; used by the
        # "spans" heuristic below to tell whether two nodes are symmetric
        node_hashes_original = dict(node_hashes)

        while len(node_rename_map) < len(nodes):
            nodes_in_order = sorted(node_hashes.items(), key=itemgetter(1))
            has_symmetric = False
            for idx, (node, hash_value) in enumerate(nodes_in_order):
                if idx != len(nodes_in_order) - 1 and nodes_in_order[idx + 1][1] == hash_value:
                    # detect symmetric
                    # two adjacent equal hashes: name the first of them, then
                    # re-hash so the tie is broken, and restart the scan
                    has_symmetric = True
                    assert node not in node_rename_map
                    idx = len(node_rename_map)
                    node_rename_map[node] = GraphNode(str(idx))
                    for cycle in range(len(nodes) + 10):
                        node_hashes = recalculate_hashes()
                    break
            if not has_symmetric:
                # all hashes are distinct: name the remaining nodes in hash order
                for node, hash_value in nodes_in_order:
                    if node not in node_rename_map:
                        idx = len(node_rename_map)
                        node_rename_map[node] = GraphNode(str(idx))
                break

        # get rhs
        new_edges = []
        for edge in edges:
            # only nodes, label and is_terminal are carried over to the new edge
            new_edges.append(
                HyperEdge((node_rename_map[node] for node in edge.nodes),
                          edge.label, edge.is_terminal))
        rhs = HyperGraph(frozenset(node_rename_map.values()),
                         frozenset(new_edges))

        # determine external nodes permutation
        def get_external_nodes_permutation() -> Tuple[List[GraphNode], Dict[str, str]]:
            """Return the ordered, renamed external nodes and a comment dict
            describing which heuristic decided the order."""
            if "stick" in ep_permutation_methods:
                # accept an ordering only if exactly one permutation of the
                # external nodes equals the node sequence of some edge
                # NOTE(review): the comparison is against a tuple, so this
                # presumably relies on edge.nodes being a tuple - confirm
                pending = []
                for permutation in permutations(external_nodes):
                    if any(edge.nodes == permutation for edge in edges):
                        pending.append(permutation)

                if len(pending) == 1:
                    permutation = pending[0]
                    comment = {"EP permutation": "Stick hyperedge to one edge"}
                    return [node_rename_map[i] for i in permutation], comment

            if "spans" in ep_permutation_methods:
                if len(external_nodes) == 2:
                    if left_and_right_span is not None:
                        left_span, right_span = left_and_right_span
                        # nodes of single-node edges whose span equals the
                        # left / right span
                        left_node = [edge.nodes[0] for edge in edges
                                     if len(edge.nodes) == 1 and edge.span == left_span]
                        right_node = [edge.nodes[0] for edge in edges
                                      if len(edge.nodes) == 1 and edge.span == right_span]
                        # usable only when both exist, their pre-tie-break
                        # hashes differ, and together they are exactly the
                        # two external nodes
                        if left_node and right_node and \
                                node_hashes_original[left_node[0]] != node_hashes_original[right_node[0]] \
                                and {left_node[0], right_node[0]} == external_nodes:
                            comment = {"EP permutation": "judge #EP2 edge direction by spans of left and right node"}
                            return [node_rename_map[left_node[0]], node_rename_map[right_node[0]]], comment

            def key_func(node):
                """Sort key of an external node: a tuple built from the
                requested combine-time variant and/or the extra label."""
                keys = []
                # combine time
                if "first_combine_time" in ep_permutation_methods:
                    keys.append(combine_dict[node][0])
                elif "combine_time" in ep_permutation_methods:
                    keys.append(combine_dict[node])
                elif "finish_time" in ep_permutation_methods:
                    keys.append(combine_dict[node][::-1])

                # extra_label
                if "extra_label" in ep_permutation_methods:
                    keys.append(extra_labels[node])
                return tuple(keys)

            # fallback: sort the external nodes by key; the sort is stable, so
            # nodes with equal keys keep the iteration order of external_nodes
            external_nodes_list = list(external_nodes)
            keys = [key_func(i) for i in external_nodes_list]
            indexed_keys = [(idx, key) for idx, key in enumerate(keys)]
            indexed_keys.sort(key=lambda x: x[1])
            sorted_renamed_external_nodes = [node_rename_map[external_nodes_list[i]]
                                             for i, _ in indexed_keys]

            methods_string = " ".join(sorted(ep_permutation_methods))
            # "partial" marks an ordering in which at least two keys were equal
            if len(set(keys)) == len(keys):
                comment = {"EP permutation": methods_string}
            else:
                comment = {"EP permutation": "partial " + methods_string}
            if len(external_nodes) == 1:
                # a single external node has nothing to order
                comment = {}
            else:
                comment["EP permutation"] += " *** keys: {}".format(
                    ", ".join(["{} : {}".format(external_nodes_list[idx], key)
                               for idx, key in indexed_keys]))

            return sorted_renamed_external_nodes, comment

        # get lhs
        ep_permutation, comment = get_external_nodes_permutation()
        lhs = HyperEdge(ep_permutation,
                        label=label,
                        is_terminal=IsTerminal.NONTERMINAL_PENDING,
                        )
        return node_rename_map, cls(lhs, rhs, comment)

    def apply(self, hg: HyperGraph, edge: HyperEdge):
        """Check that this rule is applicable to ``edge`` inside ``hg``.

        Only the preconditions are asserted (edge membership, matching label,
        matching number of attached nodes); nothing is rewritten and ``None``
        is returned.
        """
        lhs = self.lhs
        assert edge in hg
        assert edge.label == lhs.label
        assert len(edge.nodes) == len(lhs.nodes)

    def draw_in_graph(self):
        """Placeholder: always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def rhs_to_hgraph(self):
        """Convert the right-hand side into a ``common.hgraph`` ``Hgraph``.

        Each node becomes an identifier ``"_" + node.name`` with an empty
        concept label. Nodes that also occur in the LHS edge are registered as
        external nodes under their LHS position, and position 0 is appended to
        the roots. Each RHS edge becomes a triple from its first node to the
        tuple of its remaining nodes.
        """
        from common.cfg import NonterminalLabel
        from common.hgraph.hgraph import Hgraph

        graph = Hgraph()
        lhs_nodes = self.lhs.nodes

        for node in self.rhs.nodes:  # type: GraphNode
            ident = "_" + node.name
            try:
                ext_id = lhs_nodes.index(node)
            except ValueError:
                # not attached to the LHS edge: an internal node
                ext_id = None

            # Insert a node into the AMR
            graph[ident]  # looked up only to initialize the dictionary for this node
            graph.node_to_concepts[ident] = ""
            if ext_id is None:
                continue

            if ident in graph.external_nodes and graph.external_nodes[ident] != ext_id:
                raise Exception("Incompatible external node IDs for node %s." % ident)
            graph.external_nodes[ident] = ext_id
            graph.rev_external_nodes[ext_id] = ident
            if ext_id == 0:
                graph.roots.append(ident)

        next_nt_index = 0
        for edge in self.rhs.edges:  # type: HyperEdge
            targets = tuple("_" + node.name for node in edge.nodes[1:])
            source = "_" + edge.nodes[0].name
            text = edge.label

            # labels containing "_" or starting with ARG / BV are kept as plain
            # labels; everything else is treated as a nonterminal
            looks_terminal = "_" in text or text.startswith("ARG") or text.startswith("BV")
            if looks_terminal:
                triple_label = text
            else:
                # this is a nonterminal Edge
                triple_label = NonterminalLabel(text)
                if not triple_label.index:
                    triple_label.index = "_%i" % next_nt_index
                    next_nt_index += 1

            graph._add_triple(source, triple_label, targets)

        return graph

    def to_grammar(self, rule_id):
        """Wrap this rule in a ``parser.vo_rule.VoRule``.

        The rule is built from ``rule_id``, the LHS label, a weight of ``0.0``
        and the graph produced by :meth:`rhs_to_hgraph`.
        """
        from parser.vo_rule import VoRule

        lhs_label = self.lhs.label
        rhs_graph = self.rhs_to_hgraph()
        return VoRule(rule_id, lhs_label, 0.0, rhs_graph, None,
                      nodelabels=False, logprob=False)

    def get_node_name_map(self):
        """Return display names for all nodes of the rule.

        External nodes (those on the LHS edge) are named by their position in
        the LHS edge (``"0"``, ``"1"``, ...). Every other RHS node is named
        with uppercase letters derived from ``int(node.name)``: ``A`` .. ``Z``
        for 0 .. 25, then ``AA``, ``AB``, ... so that rules with more than 26
        nodes no longer raise ``IndexError``.

        Returns:
            A dict mapping each node to its display name.

        Raises:
            ValueError: if an internal node's name is not an integer string.
        """

        def letters(index):
            # single letters keep the original direct indexing
            if index < len(ascii_uppercase):
                return ascii_uppercase[index]
            # bijective base-26: 26 -> "AA", 27 -> "AB", ...
            name = ""
            remaining = index + 1
            while remaining > 0:
                remaining, digit = divmod(remaining - 1, len(ascii_uppercase))
                name = ascii_uppercase[digit] + name
            return name

        node_name_map = {node: str(idx) for idx, node in enumerate(self.lhs.nodes)}
        for node in self.rhs.nodes:
            # the keys so far are exactly the LHS nodes
            if node not in node_name_map:
                node_name_map[node] = letters(int(node.name))
        return node_name_map

    def draw(self, save_path, draw_format="png"):
        """Draw the RHS graph with the external (LHS) nodes highlighted.

        External nodes are drawn as red squares; node display names come from
        :meth:`get_node_name_map`. Returns whatever ``self.rhs.draw`` returns.
        """
        highlight = {node: {"color": "red", "shape": "square"}
                     for node in self.lhs.nodes}
        display_names = self.get_node_name_map()
        return self.rhs.draw(save_path, draw_format, highlight, display_names)


class HRGDerivation(UserList):
    @staticmethod
    def detect_small(hg, rule, direct_edges):
        """Select the smallest sub-graph around ``direct_edges``.

        The selection consists of ``direct_edges`` plus every span-less edge of
        ``hg`` whose nodes are all touched by ``direct_edges``. A touched node
        is internal when every edge of ``hg`` incident to it is selected;
        otherwise it is external. ``rule`` is not used.

        Returns:
            ``None`` when ``direct_edges`` is empty, otherwise
            ``(edges, internal_nodes, external_nodes)``.
        """
        if not direct_edges:
            return None

        related_nodes = {node for edge in direct_edges for node in edge.nodes}
        enclosed_edges = {edge for edge in hg.edges
                          if edge.span is None
                          and all(node in related_nodes for node in edge.nodes)}
        related_edges = direct_edges.union(enclosed_edges)

        def fully_covered(node):
            # true when no edge outside the selection touches ``node``
            return all(edge in related_edges
                       for edge in hg.edges if node in edge.nodes)

        internal_nodes = {node for node in related_nodes if fully_covered(node)}
        external_nodes = related_nodes - internal_nodes

        if not external_nodes:
            # no external node at all: promote an arbitrary internal node
            external_nodes.add(internal_nodes.pop())

        return related_edges, internal_nodes, external_nodes

    @staticmethod
    @deprecated
    def detect_lexicalized(hg, rule, direct_edges):
        """Deprecated detector that also absorbs outgoing edges of lexical rules.

        The base selection is ``direct_edges`` plus every span-less edge whose
        nodes are all touched by ``direct_edges``. When the first child of
        ``rule`` is a ``Lexicon``, every span-less two-node edge that starts at
        a touched node is added as well, together with its target node.

        ``staticmethod`` is the outermost decorator so that the attribute stays
        a static method after ``deprecated`` has wrapped the function.

        Returns:
            ``None`` when ``direct_edges`` is empty, otherwise
            ``(edges, internal_nodes, external_nodes)``.
        """
        is_lexical = isinstance(rule.child[0], Lexicon)
        if not direct_edges:
            return None
        related_nodes = {node for edge in direct_edges for node in edge.nodes}
        related_edges = direct_edges.union(
            edge for edge in hg.edges
            if edge.span is None and all(node in related_nodes for node in edge.nodes))

        def get_outgoing_edges(node):
            # if some external node only have internal edges and outgoing edges,
            # it can be converted into internal node
            # only two-node edges qualify because the target is read from nodes[1]
            return [edge for edge in hg.edges
                    if edge.span is None
                    and len(edge.nodes) == 2
                    and node == edge.nodes[0]]

        # edges that start with related_nodes
        if is_lexical:
            # iterate over the touched *nodes*: comparing an edge with
            # ``edge.nodes[0]`` can never match
            outgoing_edges = {edge for node in related_nodes
                              for edge in get_outgoing_edges(node)}
            outgoing_nodes = {edge.nodes[1] for edge in outgoing_edges}
            all_edges = related_edges.union(outgoing_edges)
            all_nodes = related_nodes.union(outgoing_nodes)
        else:
            all_edges = related_edges
            all_nodes = related_nodes

        # a node is internal when every edge of hg touching it is selected
        internal_nodes = {node for node in all_nodes
                          if all(edge in all_edges for edge in hg.edges
                                 if node in edge.nodes)}
        external_nodes = all_nodes - internal_nodes

        if not external_nodes:
            # if not external nodes, random select one
            external_nodes.add(internal_nodes.pop())

        return all_edges, internal_nodes, external_nodes

    @classmethod
    def detect_lexicalized_2(cls, hg, rule, direct_edges):
        """Select ``direct_edges`` plus the outgoing edges of their predicates.

        For each direct edge with a truthy ``is_terminal``, its first node is
        taken as a predicate node; every two-node edge of ``hg`` with a truthy
        ``is_terminal`` that starts at such a node is added, along with its
        target node. ``rule`` is not used.

        Returns:
            ``None`` when ``direct_edges`` is empty, otherwise
            ``(edges, internal_nodes, external_nodes)``.
        """
        related_nodes = {node for edge in direct_edges for node in edge.nodes}

        if not direct_edges:
            return None

        def binary_terminal_edges_from(node):
            # if some external node only have internal edges and outgoing edges,
            # it can be converted into internal node
            return [edge for edge in hg.edges
                    if edge.is_terminal
                    and len(edge.nodes) == 2
                    and node == edge.nodes[0]]

        pred_nodes = {edge.nodes[0] for edge in direct_edges if edge.is_terminal}
        outgoing_edges = {edge for node in pred_nodes
                          for edge in binary_terminal_edges_from(node)}
        outgoing_nodes = {edge.nodes[1] for edge in outgoing_edges}
        all_edges = direct_edges.union(outgoing_edges)
        all_nodes = related_nodes.union(outgoing_nodes)

        def fully_covered(node):
            # true when no edge outside the selection touches ``node``
            return all(edge in all_edges
                       for edge in hg.edges if node in edge.nodes)

        internal_nodes = {node for node in all_nodes if fully_covered(node)}
        external_nodes = all_nodes - internal_nodes

        if not external_nodes:
            # no external node at all: promote an arbitrary internal node
            external_nodes.add(internal_nodes.pop())

        return all_edges, internal_nodes, external_nodes

    @classmethod
    def detect_lfrg(cls, hg, rule, direct_edges):
        """Like :meth:`detect_lexicalized_2`, additionally following ``QEQ`` edges.

        After the outgoing two-node edges of the predicate nodes have been
        collected, every edge labelled ``"QEQ"`` whose first node is already
        selected is added too, together with its second node. ``rule`` is not
        used.

        Returns:
            ``None`` when ``direct_edges`` is empty, otherwise
            ``(edges, internal_nodes, external_nodes)``.
        """
        related_nodes = {node for edge in direct_edges for node in edge.nodes}

        if not direct_edges:
            return None

        def binary_terminal_edges_from(node):
            # if some external node only have internal edges and outgoing edges,
            # it can be converted into internal node
            return [edge for edge in hg.edges
                    if edge.is_terminal
                    and len(edge.nodes) == 2
                    and node == edge.nodes[0]]

        pred_nodes = {edge.nodes[0] for edge in direct_edges if edge.is_terminal}
        outgoing_edges = {edge for node in pred_nodes
                          for edge in binary_terminal_edges_from(node)}
        outgoing_nodes = {edge.nodes[1] for edge in outgoing_edges}
        all_nodes = related_nodes.union(outgoing_nodes)

        # QEQ edges are looked up before their endpoints are added, so only
        # one hop is followed
        outgoing_qeq = {edge for edge in hg.edges
                        if edge.label == "QEQ" and edge.nodes[0] in all_nodes}
        all_nodes.update(edge.nodes[1] for edge in outgoing_qeq)
        all_edges = direct_edges | outgoing_edges | outgoing_qeq

        def fully_covered(node):
            # true when no edge outside the selection touches ``node``
            return all(edge in all_edges
                       for edge in hg.edges if node in edge.nodes)

        internal_nodes = {node for node in all_nodes if fully_covered(node)}
        external_nodes = all_nodes - internal_nodes

        if not external_nodes:
            # no external node at all: promote an arbitrary internal node
            external_nodes.add(internal_nodes.pop())

        return all_edges, internal_nodes, external_nodes

    @staticmethod
    def detect_large(hg,  # type: HyperGraph
                     rule, direct_edges):
        """Select a sub-graph around ``direct_edges`` and grow it where possible.

        A first selection is made of ``direct_edges`` plus all span-less edges
        lying entirely inside the touched nodes. For every node that is still
        external after that, if all of its remaining edges are span-less edges
        starting at that node, those edges (and their second nodes) are
        absorbed as well. ``rule`` is not used.

        Returns:
            ``None`` when ``direct_edges`` is empty, otherwise
            ``(edges, internal_nodes, external_nodes)``.
        """
        if not direct_edges:
            return None

        def covered_by(node, edge_set):
            # true when every edge of hg touching ``node`` is in ``edge_set``
            return all(edge in edge_set
                       for edge in hg.edges if node in edge.nodes)

        # node that connected with direct_edges
        related_nodes = {node for edge in direct_edges for node in edge.nodes}
        # edge that connects internally
        internal_edges = {edge for edge in hg.edges
                          if edge.span is None
                          and all(node in related_nodes for node in edge.nodes)}

        all_edges_0 = direct_edges | internal_edges
        internal_nodes_0 = {node for node in related_nodes
                            if covered_by(node, all_edges_0)}
        external_nodes_0 = related_nodes - internal_nodes_0

        def absorbable_edges(node):
            # if some external node only have internal edges and outgoing edges,
            # it can be converted into internal node
            collected = []
            for edge in hg.edges:
                if node not in edge.nodes or edge in all_edges_0:
                    continue
                if not (edge.span is None and node == edge.nodes[0]):
                    # one edge of another kind is enough to keep the node external
                    return []
                collected.append(edge)
            return collected

        # edges that start with related_nodes
        outgoing_edges = {edge for node in external_nodes_0
                          for edge in absorbable_edges(node)}
        outgoing_nodes = {edge.nodes[1] for edge in outgoing_edges}

        # all edges
        all_edges = direct_edges | internal_edges | outgoing_edges
        all_nodes = related_nodes | outgoing_nodes

        internal_nodes = {node for node in all_nodes if covered_by(node, all_edges)}
        external_nodes = all_nodes - internal_nodes

        if not external_nodes:
            # no external node at all: promote an arbitrary internal node
            external_nodes.add(internal_nodes.pop())

        return all_edges, internal_nodes, external_nodes

    @classmethod
    def convert_cfg_node(cls, node):
        """Return a one-level copy of ``node`` with non-semantic children flattened.

        A ``Lexicon`` is returned unchanged. Otherwise a new ``ConstTree`` with
        the same tag and span is built: children that are a ``Lexicon`` or have
        ``has_semantics`` set are kept as they are, and every other child is
        replaced by the items of its ``generate_words()``.
        """
        if isinstance(node, Lexicon):
            return node

        converted = ConstTree(node.tag)
        for child in node.child:
            keep_whole = isinstance(child, Lexicon) or child.has_semantics
            if keep_whole:
                converted.child.append(child)
            else:
                converted.child.extend(child.generate_words())
        converted.span = node.span
        return converted

    @classmethod
    def draw(cls, hg, path,
             all_edges=(),
             internal_nodes=(),
             external_nodes=(),
             last_new_edge=None,
             draw_format="png"
             ):
        """Draw ``hg`` with the current selection highlighted.

        Selected edges and nodes are coloured red. ``last_new_edge``, when
        given, is coloured violet if it is part of ``all_edges`` and blue
        otherwise. External nodes are labelled with their enumeration index.
        Returns whatever ``hg.draw`` returns.
        """
        attrs = {}
        for group in (all_edges, internal_nodes, external_nodes):
            for item in group:
                attrs[item] = {"color": "red"}

        # draw old pic
        if last_new_edge:
            colour = "violet" if last_new_edge in all_edges else "blue"
            attrs[last_new_edge] = {"color": colour}

        external_names = {node: str(i) for i, node in enumerate(external_nodes)}
        return hg.draw(path, draw_format, attrs, node_name_map=external_names)


@dataclass
class HRGRulePrecursor:
    """Raw ingredients of an HRG rule before node names are standardized."""

    # edges selected for the rule's right-hand side
    all_edges: Any
    # selected nodes that do not stay in the condensed graph
    internal_nodes: Any
    # selected nodes the new nonterminal edge is attached to
    external_nodes: Any


@dataclass
class DerivationPrecursor(object):
    """Per-rule record produced by stage 1 of ``CFGRule.extract`` and
    completed by stage 2."""
    # the CFG tree node this step belongs to
    cfg: ConstTree
    # selected edges / internal nodes / external nodes for the HRG rule
    hrg: HRGRulePrecursor
    # the graph after the selection has been replaced by ``new_edge``
    hg_pending: HyperGraph
    # nonterminal edge created in stage 1; its node order is not final yet
    new_edge: HyperEdge
    # precursors of the left / right child (None for leaf rules and for
    # children without semantics)
    left: Optional["DerivationPrecursor"]
    right: Optional["DerivationPrecursor"]
    # ``new_edge`` of the left / right child precursor, or None
    left_child_edge: Optional[HyperEdge]
    right_child_edge: Optional[HyperEdge]
    # ``new_edge`` with its nodes in the decided order; filled in by stage 2
    new_edge_sorted: Optional[HyperEdge] = None


@dataclass(eq=True, frozen=True)
class CFGRule(object):
    # CFG LHS
    lhs: str
    # internal node: rhs = Sequence[(cfg_rhs_1, corresponding hrg edge1), ...]
    # leaf node: rhs = Sequence[(lexicon, corresponding lexical edge or None)]
    rhs: Sequence[Tuple[Union[str, Lexicon], Optional[HyperEdge]]]
    # hrg: corresponding hrg rule
    hrg: Optional[HRGRule]

    def draw_source(self):
        """Return the drawing source of this rule.

        Rules without an HRG part yield a fixed one-node placeholder graph;
        otherwise the HRG rule is drawn with ``draw_format="source"``.
        """
        if self.hrg is None:
            return """digraph g {
                a [label="No semantics!"];
            }"""
        return self.hrg.draw("", draw_format="source")

    def __hash__(self, recalculate=False):
        """Hash ``(lhs, rhs, hrg)`` and memoize the value on the instance.

        The cache is written straight into ``__dict__`` because the dataclass
        is frozen. Pass ``recalculate=True`` to refresh the cached value.
        """
        state = self.__dict__
        cached = state.get("_hash_cache")
        if recalculate or cached is None:
            cached = hash((self.lhs, self.rhs, self.hrg))
            state["_hash_cache"] = cached
        return cached

    def __getstate__(self):
        """Return the instance ``__dict__`` after dropping the memoized hash.

        The cache entry is removed from the live instance as well, so the next
        ``hash()`` call recomputes it.
        """
        state = self.__dict__
        state.pop("_hash_cache", None)
        return state

    @classmethod
    def extract(cls,
                hg,  # type: HyperGraph
                cfg_root,  # type: ConstTree
                ep_permutation_methods,
                *,
                extra_labeler_class=None,
                draw=False,
                sent_id=None,
                detect_func=None,
                lexicalize_null_semantic=False,
                root_gather_all=False,
                graph_type="eds",
                fully_lexicalized=False,
                log_func=print
                ):
        """Extract one synchronous CFG/HRG rule per rule of ``cfg_root``.

        Stage 1 walks ``cfg_root.generate_rules()`` in order. For each rule it
        asks ``detect_func`` for a sub-graph, replaces that sub-graph by a new
        nonterminal edge, and records bookkeeping data (blame dicts, combine
        times, a ``DerivationPrecursor``). Stage 2 replays the steps, calls
        ``HRGRule.extract`` to fix node names and the external-node order, and
        builds the ``CFGRule`` objects.

        Args:
            hg: the semantic graph of the sentence.
            cfg_root: root of the constituent tree; ``add_postorder_idx()`` is
                called on it, and ``span`` / ``has_semantics`` are set on its
                nodes.
            ep_permutation_methods: forwarded to ``HRGRule.extract``.
            extra_labeler_class: optional class offering
                ``from_derivation_precursors``; the resulting object's
                ``extra_tags[step]`` is forwarded as ``extra_labels``. When
                ``None`` no labeler is built and ``extra_labels`` is ``None``.
            draw: when true, collect a drawing source per step in ``pics``.
            sent_id: not used by this method.
            detect_func: sub-graph detector called as
                ``detect_func(hg, rule, edges)``; defaults to
                ``HRGDerivation.detect_large``.
            lexicalize_null_semantic: when true, constituents without
                semantics are expanded into their words in the CFG RHS.
            root_gather_all: when true, the root rule takes all remaining
                edges of the graph.
            graph_type, fully_lexicalized, log_func: forwarded to
                ``NodeDistributor``.

        Returns:
            A dict with the keys ``"derivations"``, ``"node_distribution"``,
            ``"blame_dicts"``, ``"original_node_map"``, ``"extra_labeler"``
            and ``"pics"``.
        """

        cfg_root.add_postorder_idx()
        node_distribution = NodeDistributor(
            hg, cfg_root, graph_type, fully_lexicalized, log_func).solve()

        if detect_func is None:
            detect_func = HRGDerivation.detect_large

        original_hg = hg
        # Node x in step N is corresponding to which node in original graph
        original_node_map: Dict[int, Dict[GraphNode, GraphNode]] = {}
        pics: List[Optional[str]] = []

        node_to_parent_node: Dict[ConstTree, Optional[ConstTree]] = {cfg_root: None}
        for rule in cfg_root.generate_rules():
            for i in rule.children:
                if isinstance(i, ConstTree):
                    node_to_parent_node[i] = rule

        # record in which step this terminal edge is selected
        edge_blame_dict: Dict[HyperEdge, int] = {}
        # record in which step this node becomes internal node
        node_blame_dict: Dict[GraphNode, int] = {}

        cache_type = Tuple[GraphNode, List[int], ConstTree, HyperEdge]
        # parent_node: (node, [result], corresponding_rule, current_edge)
        combine_cache: Dict[ConstTree, List[cache_type]] = defaultdict(list)
        combine_dict: Dict[ConstTree, Dict[GraphNode, Sequence[int]]] = defaultdict(lambda: {})

        derivation_precursors: List[Optional[DerivationPrecursor]] = []

        rules = list(cfg_root.generate_rules())  # root last

        # stage 1: condense graph recursively and collect information
        for step, rule in enumerate(rules):
            new_span = (rule.child[0].span[0], rule.child[-1].span[1])
            rule.span = new_span

            collected_pred_edges = set(node_distribution[rule])
            for idx, child_rule in enumerate(rule.children):
                if isinstance(child_rule, ConstTree):
                    # children come earlier in ``rules``; their precursor is
                    # looked up by postorder index
                    child_result = derivation_precursors[child_rule.postorder_idx]
                    if child_result is not None:
                        if idx == 0:
                            child_result.new_edge.is_terminal = IsTerminal.NONTERMINAL_LEFT
                        else:
                            child_result.new_edge.is_terminal = IsTerminal.NONTERMINAL_RIGHT
                        collected_pred_edges.add(child_result.new_edge)
            result = detect_func(hg, rule, collected_pred_edges)

            # null semantic node
            if result is None:
                rule.has_semantics = False
                derivation_precursors.append(None)
                continue

            rule.has_semantics = True
            all_edges, internal_nodes, external_nodes = result

            # select all edges in root node
            if root_gather_all and rule is cfg_root:
                all_edges = hg.edges
                internal_nodes = hg.nodes - external_nodes

            # NOTE: direction of this edge may be incorrect because EP permutation is not solved
            new_edge = HyperEdge(external_nodes, rule.tag,
                                 IsTerminal.NONTERMINAL_PENDING, new_span)

            new_nodes = hg.nodes - internal_nodes
            new_edges = (hg.edges - all_edges) | {new_edge}

            hg_new = HyperGraph(new_nodes, new_edges)

            for i in all_edges:
                if i.is_terminal:
                    assert i not in edge_blame_dict
                    edge_blame_dict[i] = step

            for i in internal_nodes:
                assert i not in node_blame_dict
                node_blame_dict[i] = step

            # record combine time
            parent_node = node_to_parent_node[rule]
            if parent_node is not None:
                # solve external node combination in parent node
                for external_node in external_nodes:
                    combine_cache[parent_node].append((external_node, [], rule, new_edge))

                # solve external node combination of child nodes
                for last_external_node, last_combinations, original_rule, original_edge in combine_cache[rule]:
                    # when node becoming internal node, stop propagate and write
                    if last_external_node in internal_nodes:
                        combine_dict[original_rule][last_external_node] = tuple(last_combinations) + \
                                                                          (rule.postorder_idx,)
                    else:
                        connecting_edges = set(i for i in all_edges if last_external_node in i.nodes)
                        # when node only connect with current_edge, do not count and propagate to parent node
                        # when propagate, rewrite current_edge
                        if connecting_edges == {original_edge}:
                            combine_cache[parent_node].append(
                                (last_external_node, last_combinations, original_rule, new_edge))
                        else:
                            # count and propagate
                            combine_cache[parent_node].append(
                                (last_external_node, last_combinations + [rule.postorder_idx], original_rule,
                                 new_edge))
            else:
                # is root node
                # for self
                for external_node in external_nodes:
                    combine_dict[rule][external_node] = (rule.postorder_idx,)
                # for child nodes
                for last_external_node, last_combinations, original_rule, original_edge in combine_cache[rule]:
                    combine_dict[original_rule][last_external_node] = tuple(last_combinations) + \
                                                                      (rule.postorder_idx,)

            hg = hg_new

            # finding corresponding edges for subtrees
            if isinstance(rule.child[0], Lexicon):
                # leaf node
                assert len(rule.child) == 1
                left_precursor = right_precursor = None
                left_child_edge = right_child_edge = None
            else:
                # internal node
                # NOTE(review): children[1] is read unconditionally, so this
                # presumably expects a binarized tree - confirm
                assert all(isinstance(i, ConstTree) for i in rule.children)
                left_idx = rule.children[0].postorder_idx
                left_precursor = derivation_precursors[left_idx] \
                    if rule.children[0].has_semantics else None
                left_child_edge = left_precursor.new_edge \
                    if rule.children[0].has_semantics else None
                right_idx = rule.children[1].postorder_idx
                right_precursor = derivation_precursors[right_idx] \
                    if rule.children[1].has_semantics else None
                right_child_edge = right_precursor.new_edge \
                    if rule.children[1].has_semantics else None

            # follow up steps is delayed into stage 2
            derivation_precursors.append(DerivationPrecursor(
                rule,
                HRGRulePrecursor(all_edges, internal_nodes, external_nodes),
                hg_new, new_edge,
                left_precursor, right_precursor,
                left_child_edge, right_child_edge
            ))

        # stage 2: decide EP permutation and construct sync rule
        derivations = []
        hg = original_hg
        last_new_edge = None
        # the labeler is optional: the parameter defaults to None
        if extra_labeler_class is not None:
            extra_labeler = extra_labeler_class.from_derivation_precursors(derivation_precursors)
        else:
            extra_labeler = None

        def rewrite_nonterminal_left_right(edge, owner):
            """Replace a child's pending edge by its sorted version.

            ``owner`` is the precursor whose children are consulted; with
            ``owner`` set to ``None`` the edge is returned unchanged.
            """
            if owner is None:
                return edge
            if edge is owner.left_child_edge:
                return HyperEdge(
                    owner.left.new_edge_sorted.nodes, edge.label,
                    IsTerminal.NONTERMINAL_LEFT, edge.span)
            if edge is owner.right_child_edge:
                return HyperEdge(
                    owner.right.new_edge_sorted.nodes, edge.label,
                    IsTerminal.NONTERMINAL_RIGHT, edge.span)
            return edge

        # most recent precursor that has semantics; used for the final picture
        last_precursor = None

        for step, (rule, precursor) in enumerate(zip(rules, derivation_precursors)):
            # solve null semantics
            if precursor is None:
                if lexicalize_null_semantic:
                    cfg_rhs = tuple((j, None) for j in rule.generate_words())  # type: Tuple[Tuple[Lexicon, None]]
                else:
                    cfg_rhs = tuple((i if isinstance(i, Lexicon) else i.tag, None)
                                    for i in rule.child)
                ret_rule = CFGRule(rule.tag, cfg_rhs, None)
                this_pic = ret_rule.draw_source() if draw else None
                derivations.append(ret_rule)
                pics.append(this_pic)
                continue

            if len(rule.child) == 2:
                left_and_right_span = [rule.child[0].span, rule.child[1].span]
            else:
                left_and_right_span = None

            all_edges = set(rewrite_nonterminal_left_right(i, precursor)
                            for i in precursor.hrg.all_edges)

            node_rename_map, hrg_rule = HRGRule.extract(
                all_edges,
                precursor.hrg.internal_nodes,
                precursor.hrg.external_nodes, rule.tag,
                ep_permutation_methods=ep_permutation_methods,
                left_and_right_span=left_and_right_span,
                extra_labels=extra_labeler.extra_tags[step] if extra_labeler is not None else None,
                combine_dict=combine_dict[rule]
            )

            reverse_node_rename_map = original_node_map[step] = {v: k for k, v in node_rename_map.items()}
            # translate the decided LHS order back to the original nodes
            sorted_eps = [reverse_node_rename_map[i]
                          for i in hrg_rule.lhs.nodes]
            new_edge_sorted = HyperEdge(sorted_eps, rule.tag,
                                        IsTerminal.NONTERMINAL_PENDING, rule.span)
            precursor.new_edge_sorted = new_edge_sorted

            # draw the hg *before* replacement
            if draw:
                pics.append(HRGDerivation.draw(
                    HyperGraph(hg.nodes,
                               frozenset(rewrite_nonterminal_left_right(i, precursor)
                                         for i in hg.edges)),
                    None, all_edges,
                    precursor.hrg.internal_nodes,
                    new_edge_sorted.nodes,
                    last_new_edge,
                    draw_format="source"))

            last_new_edge = new_edge_sorted
            last_precursor = precursor
            hg = precursor.hg_pending

            # create cfg label-hrg edge mapping
            if isinstance(rule.children[0], Lexicon):
                # leaf node
                assert len(rule.children) == 1
                cfg_rhs = ((rule.children[0], None),)
            else:
                # internal node
                assert all(isinstance(i, ConstTree) for i in rule.children)
                cfg_rhs: List = []
                for i, target_edge in zip(
                        rule.child,
                        [precursor.left_child_edge, precursor.right_child_edge]):
                    if not i.has_semantics:
                        if lexicalize_null_semantic:
                            cfg_rhs.extend((j, None) for j in i.generate_words())
                        else:
                            cfg_rhs.append((i.tag, None))
                    else:
                        assert target_edge in precursor.hrg.all_edges, "Non-consistant CFG-HRG Mapping"
                        target_edge_ordered = rewrite_nonterminal_left_right(target_edge, precursor)
                        target_edge_r = HyperEdge((node_rename_map[node] for node in target_edge_ordered.nodes),
                                                  target_edge_ordered.label, target_edge_ordered.is_terminal)
                        assert target_edge_r in hrg_rule.rhs.edges
                        cfg_rhs.append((i.tag, target_edge_r))

            derivations.append(CFGRule(rule.tag, tuple(cfg_rhs), hrg_rule))

        if draw:
            # final picture of the fully condensed graph; works even when no
            # rule, or not the last rule, had semantics
            pics.append(HRGDerivation.draw(
                HyperGraph(hg.nodes,
                           frozenset(rewrite_nonterminal_left_right(i, last_precursor)
                                     for i in hg.edges)),
                None, last_new_edge=last_new_edge,
                draw_format="source"))

        return {"derivations": derivations,
                "node_distribution": node_distribution,
                "blame_dicts": (node_blame_dict, edge_blame_dict),
                "original_node_map": original_node_map,
                "extra_labeler": extra_labeler,
                "pics": pics
                }
