import itertools
from typing import Tuple, Dict, Generator, Union

from dataclasses import dataclass

from coli.basic_tools.common_utils import add_slots
from coli.hrgguru.hrg import CFGRule
from coli.hrgguru.hyper_graph import HyperGraph, GraphNode, HyperEdge
from coli.hrgguru.unlexicalized_rules import recover_edge_label
from coli.span.const_tree import ConstTree


@add_slots
@dataclass()
class SubGraph(object):
    """
    Subgraph is a hypergraph corresponding to a span in the string,
    together with its ordered external nodes.
    It is expected to contain no non-terminal edges (``merge`` raises if a
    rule would introduce one).
    """
    graph: HyperGraph
    # Ordered external nodes; matched by position against the nodes of the
    # rule non-terminal edge this subgraph is substituted for in ``merge``.
    external_nodes: Tuple[GraphNode, ...]

    @staticmethod
    def transform_edge(mapping, edge, word=None, lexicon_to_lemma=None):
        """Instantiate a rule edge as an edge of a concrete graph.

        :param mapping: maps every node of ``edge`` to its concrete node;
            a node missing from it raises ``KeyError``.
        :param edge: edge taken from a rule's right-hand-side graph.
        :param word: forwarded to ``recover_edge_label`` when the label
            contains the ``{NEWLEMMA}`` placeholder.
        :param lexicon_to_lemma: forwarded to ``recover_edge_label``.
        :return: a new ``HyperEdge`` with mapped nodes, the (possibly
            recovered) label, the same terminal flag and a ``None`` span.
        """
        label = edge.label
        if "{NEWLEMMA}" in label:
            label = recover_edge_label(label, word, lexicon_to_lemma)
        return HyperEdge((mapping[i] for i in edge.nodes),
                         label,
                         edge.is_terminal,
                         None)

    @staticmethod
    def transform_edge_2(mapping, edge):
        """Copy ``edge``, renaming the nodes that appear in ``mapping``.

        Nodes absent from ``mapping`` are kept as they are; label, terminal
        flag and span are copied from ``edge``.
        """
        # ``get(i, i)`` instead of ``get(i) or i``: a mapped value must never
        # be dropped just because it happens to be falsy.
        return HyperEdge((mapping.get(i, i) for i in edge.nodes),
                         edge.label,
                         edge.is_terminal,
                         edge.span)

    @staticmethod
    def _external_nodes_map(abstract_edge, sub_graph):
        """Pair the nodes of a rule non-terminal edge with ``sub_graph``'s external nodes by position."""
        if abstract_edge is None:
            return {}
        assert len(sub_graph.external_nodes) == len(abstract_edge.nodes), \
            "external node count of the sub graph does not match the rule edge"
        return dict(zip(abstract_edge.nodes, sub_graph.external_nodes))

    @classmethod
    def merge(cls, cfg_node: ConstTree,
              sync_rule: CFGRule,
              left_sub_graph: "SubGraph",
              right_sub_graph: "SubGraph",
              return_blame_map: bool = False
              ) -> Union["SubGraph", Tuple["SubGraph", dict]]:
        """Combine the subgraphs of two children using a binary sync rule.

        A fresh concrete node is created for every node of the rule's
        right-hand-side graph. The nodes of the two non-terminal edges in
        ``sync_rule.rhs`` are instead unified with the external nodes of
        ``left_sub_graph`` / ``right_sub_graph``. When a rule node belongs to
        both non-terminal edges, the right subgraph's concrete node is used
        and the left subgraph's edges are rewritten to refer to it.

        :param cfg_node: tree node being built; ``cfg_node.extra["DelphinSpan"]``
            becomes the span of new single-node edges that have no span.
        :param sync_rule: rule whose ``rhs`` holds two ``(name, edge)`` pairs;
            an edge of ``None`` means that side contributes no subgraph.
        :param left_sub_graph: subgraph substituted for the left edge.
        :param right_sub_graph: subgraph substituted for the right edge.
        :param return_blame_map: if true, also return a dict mapping nodes and
            edges to ``"this"`` (created from the rule), ``"right"`` or
            ``"left"`` (taken from that subgraph; these overwrite ``"this"``
            for shared external nodes).
        :return: the merged ``SubGraph``, or ``(SubGraph, blame_map)``.
        :raises Exception: if the rule contributes a non-terminal edge other
            than the two substituted ones.
        """
        blame_map = {}

        # create concrete node and unify with external nodes of subgraphs
        nodes_mapping: Dict[GraphNode, GraphNode] = {
            i: GraphNode() for i in sync_rule.hrg.rhs.nodes}
        _, left_edge = sync_rule.rhs[0]
        external_nodes_map_left = cls._external_nodes_map(left_edge, left_sub_graph)
        _, right_edge = sync_rule.rhs[1]
        external_nodes_map_right = cls._external_nodes_map(right_edge, right_sub_graph)

        # The right map is applied last, so for a rule node shared by both
        # edges the right subgraph's concrete node wins ...
        nodes_mapping.update(external_nodes_map_left)
        nodes_mapping.update(external_nodes_map_right)
        # ... and the left subgraph's copy of that node is renamed to it.
        # (If either side is absent its map is empty, so this stays empty.)
        common_mapping = {
            external_nodes_map_left[abstract_node]: external_nodes_map_right[abstract_node]
            for abstract_node in external_nodes_map_left.keys() & external_nodes_map_right.keys()}

        # Edges contributed by the rule itself (all but the substituted ones).
        own_edges = [cls.transform_edge(nodes_mapping, edge)
                     for edge in sync_rule.hrg.rhs.edges
                     if edge != left_edge and edge != right_edge]
        non_terminals = [i for i in own_edges if not i.is_terminal]
        if non_terminals:
            raise Exception("Non-terminals {} found by rule {} in node {}".format(
                non_terminals, sync_rule, cfg_node))
        # Assign spans BEFORE hashing the edges into a frozenset, so no set
        # member is mutated after insertion.
        for new_edge in own_edges:
            if len(new_edge.nodes) == 1 and new_edge.span is None:
                new_edge.span = cfg_node.extra["DelphinSpan"]
        edges = frozenset(own_edges)
        nodes = frozenset(nodes_mapping.values())

        if return_blame_map:
            for i in nodes:
                blame_map[i] = "this"
            for i in edges:
                blame_map[i] = "this"

        if right_edge is not None:
            edges |= right_sub_graph.graph.edges
            nodes |= right_sub_graph.graph.nodes

            if return_blame_map:
                for i in right_sub_graph.graph.nodes:
                    blame_map[i] = "right"
                for i in right_sub_graph.graph.edges:
                    blame_map[i] = "right"

        if left_edge is not None:
            new_edges = frozenset(cls.transform_edge_2(common_mapping, edge)
                                  for edge in left_sub_graph.graph.edges)
            edges |= new_edges
            # Left nodes replaced by right-side nodes are dropped.
            new_nodes = left_sub_graph.graph.nodes - common_mapping.keys()
            nodes |= new_nodes

            if return_blame_map:
                for i in new_nodes:
                    blame_map[i] = "left"
                for i in new_edges:
                    blame_map[i] = "left"

        external_nodes = tuple(nodes_mapping[node]
                               for node in sync_rule.hrg.lhs.nodes)
        ret = cls(HyperGraph(nodes, edges), external_nodes)
        if return_blame_map:
            return ret, blame_map
        return ret

    @classmethod
    def create_leaf_graph(cls, cfg_node, sync_rule, word=None, lexicon_to_lemma=None):
        """Instantiate the right-hand-side graph of a rule with fresh nodes.

        :param cfg_node: tree node; ``cfg_node.extra["DelphinSpan"]`` is
            assigned as the span of every single-node edge.
        :param sync_rule: rule whose ``hrg.rhs`` graph is copied; the
            external nodes follow the order of ``hrg.lhs.nodes``.
        :param word: forwarded to ``transform_edge`` for ``{NEWLEMMA}`` labels.
        :param lexicon_to_lemma: forwarded to ``transform_edge``.
        :rtype: SubGraph
        """
        nodes_mapping: Dict[GraphNode, GraphNode] = {
            i: GraphNode() for i in sync_rule.hrg.rhs.nodes}
        new_edges = [cls.transform_edge(nodes_mapping, edge, word, lexicon_to_lemma)
                     for edge in sync_rule.hrg.rhs.edges]
        # Assign spans before the edges are hashed into a frozenset.
        for new_edge in new_edges:
            if len(new_edge.nodes) == 1:
                new_edge.span = cfg_node.extra["DelphinSpan"]
        sub_graph = HyperGraph(frozenset(nodes_mapping.values()),
                               frozenset(new_edges))
        external_nodes = tuple(nodes_mapping[node]
                               for node in sync_rule.hrg.lhs.nodes)
        return cls(sub_graph, external_nodes)

    def permutations(self) -> Generator["SubGraph", None, None]:
        """Yield this subgraph under every ordering of its external nodes.

        A subgraph with exactly one external node yields itself; otherwise a
        new subgraph sharing ``self.graph`` is yielded for each permutation,
        including the original order.
        """
        if len(self.external_nodes) == 1:
            yield self
        else:
            for external_nodes in itertools.permutations(self.external_nodes):
                yield type(self)(self.graph, external_nodes)
