from typing import Dict, List, Tuple, TypeVar, Hashable, Generic, Optional

from dataclasses import dataclass

from coli.hrgguru.hrg import CFGRule
from coli.hrgguru.hyper_graph import PredEdge, HyperEdge, IsTerminal, HyperGraph
from coli.hrgguru.unlexicalized_rules import get_lemma_and_pos, recover_edge_label
from coli.span.const_tree import ConstTree, Lexicon

# A pair of ints; presumably (start, end) offsets of a span — TODO confirm.
SpanType = Tuple[int, int]
# (node name, (edge label, transformed label or None) or None, span or None).
# The label slot holds the pair stored in the node-feature dicts, not a bare str.
NodeInfoType = Tuple[str, Optional[Tuple[str, Optional[str]]], Optional[SpanType]]
T = TypeVar("T", bound=Hashable)


@dataclass
class UFSet(Generic[T]):
    """Union-find (disjoint-set) structure over hashable symbols.

    ``roots`` maps every known symbol to its parent; a symbol mapped to
    itself is the representative (root) of its set.
    """
    # symbol -> parent symbol; roots map to themselves.
    roots: Dict[T, T]

    @classmethod
    def from_leafs(cls, symbols=()):
        """Create a set in which every symbol of *symbols* is its own root."""
        return cls({i: i for i in symbols})

    @classmethod
    def from_dict(cls, inputs):
        """Create a set from a copy of an existing symbol -> parent mapping."""
        return cls(dict(inputs))

    def __getitem__(self, symbol):
        """Return the root of *symbol*, compressing the path on the way.

        Implemented iteratively so that long parent chains cannot exceed the
        interpreter recursion limit.

        :raises KeyError: if *symbol* (or a parent on its chain) is unknown.
        """
        roots = self.roots
        root = symbol
        while True:
            parent = roots[root]
            if parent == root:
                break
            root = parent
        # Path compression: point every symbol on the chain straight at the root.
        node = symbol
        while node != root:
            next_node = roots[node]
            roots[node] = root
            node = next_node
        return root

    def __add__(self, other: "UFSet"):
        """Return a new set whose mapping is ours updated with *other*'s."""
        parents = {}
        parents.update(self.roots)
        parents.update(other.roots)
        return UFSet(parents)

    def __iadd__(self, other: "UFSet"):
        """Update this set's mapping in place with *other*'s entries."""
        self.roots.update(other.roots)
        return self

    def copy(self):
        """Return an independent copy (the mapping dict is copied)."""
        return UFSet(dict(self.roots))

    def union(self, a, b):
        """Merge the sets of *a* and *b*; the root of *a* becomes the new root."""
        self.roots[self[b]] = self[a]


def transform_edge_label(label, tree_node, lexicon_to_lemma):
    """Resolve the ``{NEWLEMMA}`` placeholder of an edge label, if applicable.

    The label is returned unchanged unless all of the following hold, checked
    in this order: a ``lexicon_to_lemma`` table is given, the first child of
    *tree_node* is a ``Lexicon``, and the label contains ``{NEWLEMMA}``.
    In that case the work is delegated to ``recover_edge_label`` with the
    word string of that ``Lexicon`` child.
    """
    needs_recovery = (
        lexicon_to_lemma is not None
        and isinstance(tree_node.children[0], Lexicon)
        and "{NEWLEMMA}" in label
    )
    if not needs_recovery:
        return label
    lexicon_leaf = tree_node.children[0]
    return recover_edge_label(label, lexicon_leaf.string, lexicon_to_lemma)


@dataclass
class SubGraphFeatureTracker(object):
    """Track node/edge features of the graph produced under one CFG node.

    One tracker is built per CFG-tree node by ``from_graph``. It keeps
    references to the trackers of its (at most two) children, so the whole
    derived graph can be rebuilt from the root tracker with ``make_graph``.
    """
    # Maps this rule's external node names and its children's external node
    # names to their union-find roots at the end of ``from_graph``.
    ufset: UFSet
    # Root names of the rule's LHS nodes, in LHS order.
    external_nodes: List[str]
    # Nodes labelled by a 1-ary terminal edge of this rule:
    # root name -> ((edge label, transformed label or None), span).
    own_node_features: Dict[str, Tuple[Tuple[str, str], SpanType]]
    # Known features of the external nodes, handed up to the parent rule.
    boundary_node_features: Dict[str, Tuple[Tuple[str, str], SpanType]]
    # 2-ary terminal edges created (or resolved) here whose endpoints both have features.
    own_known_edge_features: List[Tuple[NodeInfoType, str, NodeInfoType]]
    # 2-ary terminal edges with at least one endpoint whose features are still unknown.
    unknown_edge_features: List[Tuple[NodeInfoType, str, NodeInfoType]]
    left: Optional["SubGraphFeatureTracker"]
    right: Optional["SubGraphFeatureTracker"]

    @classmethod
    def from_graph(cls, cfg_node: ConstTree, sync_rule: CFGRule, prefix="",
                   left: Optional["SubGraphFeatureTracker"] = None,
                   right: Optional["SubGraphFeatureTracker"] = None,
                   lexicon_to_lemma=None
                   ):
        """Build the tracker for ``cfg_node`` from its synchronous rule.

        :param cfg_node: CFG node the rule applies to; its
            ``extra["DelphinSpan"]`` becomes the span of nodes labelled here.
        :param sync_rule: the rule; ``sync_rule.hrg.lhs`` / ``.rhs`` give the
            graphs, and the nodes of ``sync_rule.rhs[0][1]`` / ``rhs[1][1]`` are
            identified with the left / right child's external nodes.
        :param prefix: prepended to this rule's node names (presumably to keep
            names unique across rule applications).
        :param left: tracker of the left child, if any.
        :param right: tracker of the right child, if any.
        :param lexicon_to_lemma: forwarded to ``transform_edge_label``; when
            None the transformed-label slot of each node feature is None.
        :raises Exception: if the RHS holds a terminal edge that is neither
            1-ary nor 2-ary.
        """
        children = []

        # Union node names: glue each child's external nodes to the nodes of
        # the corresponding RHS edge of this rule.
        ufset = UFSet.from_leafs([prefix + node.name for node in sync_rule.hrg.rhs.nodes])
        if left is not None:
            children.append(left)
            ufset += left.ufset
            left_edge = sync_rule.rhs[0][1]
            for a, b in zip(left.external_nodes,
                            (prefix + node.name for node in left_edge.nodes)
                            ):
                ufset.union(a, b)

        if right is not None:
            children.append(right)
            ufset += right.ufset
            right_edge = sync_rule.rhs[1][1]
            for a, b in zip(right.external_nodes,
                            (prefix + node.name for node in right_edge.nodes)
                            ):
                ufset.union(a, b)

        external_nodes = [ufset[prefix + node.name] for node in sync_rule.hrg.lhs.nodes]
        own_node_features = {}
        useful_node_features = {}
        own_known_edge_features = []
        unknown_edge_features = []

        # Nodes labelled in this rule: every 1-ary terminal edge.
        for edge in sync_rule.hrg.rhs.edges:
            if len(edge.nodes) == 1 and edge.is_terminal:
                node_name = ufset[prefix + edge.nodes[0].name]
                value = ((edge.label,
                          transform_edge_label(edge.label, cfg_node, lexicon_to_lemma)
                          if lexicon_to_lemma is not None else None),
                         cfg_node.extra["DelphinSpan"])
                own_node_features[node_name] = value
                useful_node_features[node_name] = value

        # Nodes labelled in children; re-key by the current root because the
        # unions above may have changed it.
        for child in children:
            for k, v in child.boundary_node_features.items():
                useful_node_features[ufset[k]] = v

        # 2-ary terminal edges of this rule.
        for edge in sync_rule.hrg.rhs.edges:
            if len(edge.nodes) == 1 and edge.is_terminal:
                continue
            elif len(edge.nodes) == 2 and edge.is_terminal:
                left_name = ufset[prefix + edge.nodes[0].name]
                right_name = ufset[prefix + edge.nodes[1].name]
                left_node = useful_node_features.get(left_name)
                right_node = useful_node_features.get(right_name)

                left_info = (left_name, left_node[0], left_node[1]) if left_node else (left_name, None, None)
                right_info = (right_name, right_node[0], right_node[1]) if right_node else (right_name, None, None)
                if left_node and right_node:
                    own_known_edge_features.append((left_info, edge.label, right_info))
                else:
                    unknown_edge_features.append((left_info, edge.label, right_info))
            elif not edge.is_terminal:
                continue
            else:
                raise Exception("Unexpected edge {}".format(edge))

        # Children's still-unknown edges: try to resolve them with the features
        # visible here. Names internal to a child never enter this ufset and
        # cannot be affected by this rule's unions, so they are kept as-is.
        for child in children:
            for (left_name, left_label, left_span), label, (right_name, right_label, right_span) in \
                    child.unknown_edge_features:
                left_name = ufset[left_name] if left_name in ufset.roots else left_name
                right_name = ufset[right_name] if right_name in ufset.roots else right_name
                if left_label is None:
                    left_node = useful_node_features.get(left_name)
                    if left_node is not None:
                        left_label, left_span = left_node
                if right_label is None:
                    right_node = useful_node_features.get(right_name)
                    if right_node is not None:
                        right_label, right_span = right_node
                all_info = (left_name, left_label, left_span), label, (right_name, right_label, right_span)
                if left_label and right_label:
                    own_known_edge_features.append(all_info)
                else:
                    unknown_edge_features.append(all_info)

        # Keep the roots of the children's external names too: if both children
        # are glued to the same node, one child's external name stops being a
        # root, and accumulate_features() needs that entry to map the child's
        # names onto the merged node.
        kept_names = list(external_nodes)
        for child in children:
            kept_names.extend(child.external_nodes)

        # Only pass up features that are actually known; unknown external
        # nodes stay referenced from unknown_edge_features and may be resolved
        # by an ancestor rule.
        boundary_node_features = {i: useful_node_features[i] for i in external_nodes
                                  if i in useful_node_features}

        return cls(UFSet.from_dict({i: ufset[i] for i in kept_names}),
                   external_nodes,
                   own_node_features,
                   boundary_node_features,
                   own_known_edge_features,
                   unknown_edge_features,
                   left, right
                   )

    def accumulate_features(self):
        """Recursively collect the features of this tracker's whole subtree.

        :returns: ``(ufset, node_features, edge_features)``. ``node_features``
            maps node names to ``((label, real_label), span)``; ``edge_features``
            is a list of ``(a_name, label, b_name)``. Names coming from children
            are mapped through ``ufset``.
        """
        ufset = self.ufset.copy()
        node_features = dict(self.own_node_features)
        edge_features = [(a_info[0], label, b_info[0])
                         for a_info, label, b_info in self.own_known_edge_features]
        children = [child for child in (self.left, self.right) if child is not None]

        # Every locally referenced name needs a union-find entry.
        for symbol in node_features:
            ufset.roots.setdefault(symbol, symbol)
        for a_name, _label, b_name in edge_features:
            ufset.roots.setdefault(a_name, a_name)
            ufset.roots.setdefault(b_name, b_name)

        # Merge the children's entries without overriding the ones already
        # present here (those reflect the unions made by this rule).
        collected = []
        for child in children:
            child_ufset, child_node_features, child_edge_features = child.accumulate_features()
            collected.append((child_node_features, child_edge_features))
            for symbol, parent in child_ufset.roots.items():
                ufset.roots.setdefault(symbol, parent)

        for child_node_features, child_edge_features in collected:
            for node_name, value in child_node_features.items():
                node_features[ufset[node_name]] = value
            for a_name, label, b_name in child_edge_features:
                edge_features.append((ufset[a_name], label, ufset[b_name]))

        return ufset, node_features, edge_features

    def make_graph(self):
        """Rebuild the HyperGraph derived under this tracker.

        Each labelled node yields a ``(node, predicate edge)`` pair from
        ``PredEdge.as_new``; each collected 2-ary feature becomes a terminal
        ``HyperEdge`` between its two nodes. An empty feature set yields an
        empty graph.
        """
        ufset, node_features, edge_features = self.accumulate_features()
        nodes = []
        pred_edges = []
        for name, ((_label, real_label), span) in node_features.items():
            node, pred_edge = PredEdge.as_new(span, real_label, name)
            nodes.append(node)
            pred_edges.append(pred_edge)
        name_to_node = {node.name: node for node in nodes}
        edges = {HyperEdge([name_to_node[ufset[a_name]], name_to_node[ufset[b_name]]],
                           label, IsTerminal.TERMINAL)
                 for a_name, label, b_name in edge_features}
        edges.update(pred_edges)
        return HyperGraph(frozenset(nodes), frozenset(edges))
