import base64
import functools
import hashlib

from collections import defaultdict
from enum import Enum
from operator import itemgetter

from typing import Iterable, Optional, Tuple, FrozenSet, Mapping, Union

import os

import six
from dataclasses import dataclass

from delphin.mrs import Pred, Xmrs
from delphin.mrs.eds import Eds
from delphin.mrs.components import links as mrs_links


@functools.lru_cache(maxsize=65536)
def strip_category(cat, return_tuple=False):
    """
    Reduce a predicate string to its ``lemma``/``pos``/``sense`` parts.

    Predicates ending in ``"u_unknown"`` are split textually: everything
    after the last ``"/"`` is cut at its first ``"_"`` into pos and sense,
    and the lemma is replaced by ``"X"``. Every other predicate is parsed
    with ``Pred.stringpred``; its lemma is replaced by ``"X"`` only when the
    string starts with ``"_"``.

    Results are memoized, so ``cat`` must be hashable.

    :param cat: predicate string
    :param return_tuple: return the three parts as a tuple instead of
        joining them with ``"_"``
    :return: ``(lemma, pos, sense)`` or ``"lemma_pos_sense"``
    """
    if not cat.endswith("u_unknown"):
        parsed = Pred.stringpred(cat)
        lemma_part = "X" if cat.startswith("_") else parsed.lemma
        pos_part = str(parsed.pos)
        sense_part = str(parsed.sense)
    else:
        # The text before the last "/" is the lemma; it is dropped in favour of "X".
        _, tail = cat.rsplit("/", 1)
        pos_part, sense_part = tail.split("_", 1)
        lemma_part = "X"

    parts = (lemma_part, pos_part, sense_part)
    return parts if return_tuple else "_".join(parts)


@six.python_2_unicode_compatible
class GraphNode(object):
    """
    A graph vertex identified by its ``name``.

    Hashing and equality depend on ``name`` only; ``is_root`` is ignored
    by both.
    """

    def __init__(self, id_=None, is_root=False):
        """
        :param id_: node name; any falsy value (``None``, ``""``) is replaced
            by a random 20-character base64 string
        :param is_root: whether this node is marked as a root
        """
        if not id_:
            # 15 random bytes encode to exactly 20 base64 characters.
            id_ = base64.b64encode(os.urandom(15)).decode("ascii")
        self.name = id_
        self.is_root = is_root

    def __str__(self):
        return f"GraphNode: {self.name}"

    def __repr__(self):
        return self.__str__()

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        return self.name == other.name


class IsTerminal(Enum):
    """
    Kind marker carried by a hyperedge.

    Only ``TERMINAL`` is truthy; all ``NONTERMINAL_*`` members evaluate to
    ``False`` in a boolean context.
    """
    TERMINAL = 10001
    NONTERMINAL_LEFT = 10002
    NONTERMINAL_RIGHT = 10003
    NONTERMINAL_PENDING = 10004

    def __bool__(self):
        # Look the member up on the class rather than through ``self``:
        # member-to-member attribute access has not been uniformly supported
        # across Python versions.
        return self is IsTerminal.TERMINAL


class HyperEdge(object):
    """
    A labelled edge connecting an ordered tuple of nodes.

    Equality and hashing are based on ``(nodes, label, is_terminal, span)``.
    The hash is computed lazily and cached on the instance as
    ``_hash_cache``; the cache is never part of the pickled state.
    """

    def __init__(self,
                 nodes,  # type: Iterable[GraphNode]
                 label,  # type: str
                 is_terminal,  # type: IsTerminal
                 span=None  # type: Optional[Tuple[int, int]]
                 ):
        """
        :param nodes: nodes attached to this edge, in order
        :param label: edge label
        :param is_terminal: terminal / non-terminal marker
        :param span: optional ``(start, end)`` pair
        """
        self.nodes = tuple(nodes)  # as immutable list
        self.label = label  # type: str
        self.is_terminal = is_terminal  # type: IsTerminal
        self.span = span  # type: Optional[Tuple[int, int]]

    def __str__(self):
        return "{}{}: {}".format(self.span if self.span is not None else "", self.label,
                                 " -- ".join(i.name for i in self.nodes))

    def __repr__(self):
        return self.__str__()

    def __hash__(self):
        if not hasattr(self, "_hash_cache"):
            self._hash_cache = hash((self.nodes, self.label, self.is_terminal, self.span))
        return self._hash_cache

    def __getstate__(self):
        """
        Return the instance state without the cached hash.

        The hash may differ on another machine or process, so it must not be
        serialized. A copy of ``__dict__`` is returned so that pickling or
        copying does not mutate this instance or discard its cache.
        """
        state = dict(self.__dict__)
        state.pop("_hash_cache", None)
        return state

    def __eq__(self, other):
        return isinstance(other, HyperEdge) and self.nodes == other.nodes and \
               self.label == other.label and self.is_terminal == other.is_terminal and \
               self.span == other.span


class PredEdge(HyperEdge):
    """
    A terminal edge attached to exactly one node, carrying a predicate label,
    a span and an optional ``carg`` value.

    ``carg`` is stored on the instance but takes no part in the equality or
    hashing inherited from ``HyperEdge``.
    """

    def __init__(self,
                 pred_node,  # type: GraphNode
                 span,  # type: Tuple[int, int]
                 label,  # type: str
                 carg=None
                 ):
        super().__init__((pred_node,), label, IsTerminal.TERMINAL, span)
        self.carg = carg

    @classmethod
    def as_new(cls,
               span,  # type: Tuple[int, int]
               label,  # type: str
               name,  # type: str
               carg=None
               ):
        """
        Create a fresh node called ``name`` together with its predicate edge.

        :return: ``(node, edge)``
        """
        node = GraphNode(name)
        edge = cls(node, span, label, carg)
        return node, edge

    @classmethod
    def from_eds_node(cls,
                      eds_node,  # type:
                      lemma_to_x=False
                      ):
        """
        Build a node and predicate edge from an object exposing ``pred``,
        ``lnk.data``, ``nodeid`` and ``carg``.

        :param lemma_to_x: pass the predicate string through
            ``strip_category`` before using it as the edge label
        :return: ``(node, edge)``; the node is named ``str(eds_node.nodeid)``
        """
        pred_name = str(eds_node.pred)
        label = strip_category(pred_name) if lemma_to_x else pred_name
        return cls.as_new(eds_node.lnk.data, label,
                          str(eds_node.nodeid), eds_node.carg)


@dataclass(eq=True)
class HyperGraph(object):
    """
    A hypergraph stored as a frozen set of nodes and a frozen set of edges.

    Instances compare equal when both sets are equal. The hash is derived
    from the two sets, computed lazily and cached on the instance as
    ``_hash_cache``; the cache is never part of the pickled state.
    """
    # All vertices of the graph.
    nodes: FrozenSet[GraphNode]
    # All (hyper)edges of the graph.
    edges: FrozenSet[HyperEdge]

    def __hash__(self):
        if not hasattr(self, "_hash_cache"):
            self._hash_cache = hash(self.nodes) ^ hash(self.edges)
        return self._hash_cache

    def __getstate__(self):
        """
        Return the instance state without the cached hash.

        A copy of ``__dict__`` is returned so that pickling or copying does
        not mutate this instance or discard its cache.
        """
        state = dict(self.__dict__)
        state.pop("_hash_cache", None)
        return state

    # Graphviz shape of the centre node drawn for an edge with that many
    # attached nodes; any other count falls back to "doublecircle".
    shapes = defaultdict(lambda: "doublecircle",
                         {3: "invtriangle", 4: "diamond", 5: "star",
                          6: "hexagon", 7: "polygon", 8: "octagon"})

    @classmethod
    def from_eds(cls,
                 e: Eds,
                 lemma_to_x: bool = False,
                 strip_options: set = frozenset()
                 ) -> "HyperGraph":
        """
        Build a graph from an EDS object.

        Every EDS node becomes a ``GraphNode`` plus a one-node ``PredEdge``;
        every EDS edge becomes a two-node terminal ``HyperEdge``.

        :param e: the EDS to convert
        :param lemma_to_x: forwarded to ``PredEdge.from_eds_node``
        :param strip_options: ``"strip_d"`` drops nodes whose predicate string
            ends with ``"_d"`` together with every edge touching them;
            ``"strip-hndl"`` drops edges whose label ends with ``"-HNDL"``
        """
        nodes = []
        nodes_by_pred_label = {}
        edges = []
        for node in e.nodes():
            if "strip_d" in strip_options and str(node.pred).endswith("_d"):
                # Remember the id so edges touching this node can be skipped below.
                nodes_by_pred_label[node.nodeid] = "__STRIPPED__"
                continue
            graph_node, edge = PredEdge.from_eds_node(node, lemma_to_x)
            graph_node.is_root = (node.nodeid == e.top)
            nodes_by_pred_label[node.nodeid] = graph_node
            nodes.append(graph_node)
            edges.append(edge)

        for node in e.nodes():
            for label, target in e.edges(node.nodeid).items():
                start_node = nodes_by_pred_label[node.nodeid]
                end_node = nodes_by_pred_label[target]
                if start_node == "__STRIPPED__" or end_node == "__STRIPPED__":
                    continue
                if "strip-hndl" in strip_options and label.endswith("-HNDL"):
                    continue
                edges.append(HyperEdge([start_node,
                                        end_node], label=label,
                                       is_terminal=IsTerminal.TERMINAL))

        return cls(frozenset(nodes), frozenset(edges))

    @classmethod
    def from_mrs(cls,
                 m: Xmrs,
                 lemma_to_x: bool = False,
                 strip_options: set = frozenset()
                 ):
        """
        Build a graph from an MRS object.

        Every EP becomes a ``GraphNode`` plus a one-node ``PredEdge``; every
        link reported by ``mrs_links`` becomes a two-node terminal
        ``HyperEdge`` labelled ``rargname + "/" + post``. A node is marked as
        root when its label equals ``m.top``.

        :param m: the MRS to convert
        :param lemma_to_x: forwarded to ``PredEdge.from_eds_node``
        :param strip_options: same options as :meth:`from_eds`
        :rtype: HyperGraph
        """
        nodes = []
        nodes_by_pred_label = {}
        edges = []
        for node in m.eps():
            if "strip_d" in strip_options and str(node.pred).endswith("_d"):
                nodes_by_pred_label[node.nodeid] = "__STRIPPED__"
                continue
            graph_node, edge = PredEdge.from_eds_node(node, lemma_to_x)
            graph_node.is_root = (node.label == m.top)
            nodes_by_pred_label[node.nodeid] = graph_node
            nodes.append(graph_node)
            edges.append(edge)

        for start, end, rargname, post in mrs_links(m):
            # NOTE(review): start id 0 presumably denotes the link from the
            # top of the MRS, which has no node here - confirm against delphin.
            if start == 0:
                continue
            start_node = nodes_by_pred_label[start]
            end_node = nodes_by_pred_label[end]
            if start_node == "__STRIPPED__" or end_node == "__STRIPPED__":
                continue
            if "strip-hndl" in strip_options and rargname.endswith("-HNDL"):
                continue
            edges.append(HyperEdge([start_node, end_node],
                                   label=rargname + "/" + post,
                                   is_terminal=IsTerminal.TERMINAL))

        return cls(frozenset(nodes), frozenset(edges))

    def draw(self, output, file_format="pdf",
             attr_map: Mapping[Union[HyperEdge, GraphNode], dict] = None,
             node_name_map: Mapping[GraphNode, str] = None,
             show_span: bool = True) -> Optional[str]:
        """
        Render the graph with graphviz.

        One-node edges are drawn as an arrow to an invisible dummy node,
        two-node edges as a plain arrow, and larger edges as a shaped centre
        node connected to every attached node (arrows numbered by position).

        :param output: file name passed to ``Digraph.render``
        :param file_format: graphviz output format, or "source"
        :param attr_map: specify attributes for edges or nodes; the mapping
            and its values are not modified
        :param node_name_map: render node name
        :param show_span: add span to edge label
        :return dot source code if file_format == "source", otherwise None
        """
        if attr_map is None:
            attr_map = {}
        if node_name_map is None:
            node_name_map = {}

        from graphviz import Digraph
        dot = Digraph()

        # add nodes
        for node in self.nodes:
            # Work on a copy so the caller's attribute dict is left untouched.
            attr = dict(attr_map.get(node) or {})
            attr.update({"width": "0.075", "height": "0.075", "fixedsize": "true"})
            label = node_name_map.get(node) or ""
            if label != "":
                # Labelled nodes need more room than the default small dot.
                attr.update({"width": "0.35", "height": "0.35"})
            dot.node(node.name, label=label, _attributes=attr)

        # add edges
        for edge in self.edges:
            # Copy for the same reason as above; "arrowsize" may be popped below.
            attr = dict(attr_map.get(edge) or {})
            attr.update({"arrowsize": "0.5"})

            # edge label
            if edge.span is not None:
                label = "{}({},{})".format(edge.label, edge.span[0], edge.span[1]) if show_span else edge.label
            else:
                label = edge.label

            if edge.is_terminal == IsTerminal.NONTERMINAL_LEFT:
                label += "(LEFT)"

            if edge.is_terminal == IsTerminal.NONTERMINAL_RIGHT:
                label += "(RIGHT)"

            carg = getattr(edge, "carg", None)
            if carg:
                label += f"[{carg}]"

            if len(edge.nodes) == 1:
                # pred edge: add an invisible node as fake end; the random
                # suffix keeps the dummy node ids distinct.
                fake_end = edge.nodes[0].name + label + "_end" + base64.b64encode(os.urandom(3)).decode()
                dot.node(fake_end, label="",
                         _attributes={"width": "0.005", "height": "0.005", "fixedsize": "true",
                                      "color": "white"})
                dot.edge(edge.nodes[0].name, fake_end, label=label,
                         _attributes=attr)
            elif len(edge.nodes) == 2:
                # normal edge
                dot.edge(edge.nodes[0].name, edge.nodes[1].name, label,
                         _attributes=attr)
            else:
                # hyper edge: a centre node carries the label and attributes
                center_node = "{}_{}_hyperedge_center_{}".format(
                    edge.label, edge.span,
                    str(edge.is_terminal).replace(" ", ""))
                attr.pop("arrowsize")
                attr["shape"] = self.shapes[len(edge.nodes)]
                dot.node(center_node, label=label, _attributes=attr)
                for idx, end_point in enumerate(edge.nodes):
                    dot.edge(center_node, end_point.name,
                             label=str(idx),
                             )

        if file_format == "source":
            return dot.source
        else:
            dot.format = file_format
            dot.render(output, cleanup=True)

    def to_hgraph(self):
        """
        Convert to a ``common.hgraph.hgraph.Hgraph``.

        Node identifiers are the node names prefixed with ``"_"``. Each edge
        is added as a triple ``(first node, label, tuple of remaining nodes)``.
        The resulting object keeps a reference to this graph in
        ``my_hyper_graph``.
        """
        from common.hgraph.hgraph import Hgraph
        hgraph = Hgraph()
        hgraph.my_hyper_graph = self

        for node in self.nodes:  # type: GraphNode
            ident = "_" + node.name

            # NOTE(review): the lookup result is unused; presumably the lookup
            # itself makes Hgraph create the entry for this node - confirm.
            ignoreme = hgraph[ident]
            hgraph.node_to_concepts[ident] = ""
            if node.is_root:
                hgraph.roots.append(ident)

        for edge in self.edges:  # type: HyperEdge
            hyperchild = tuple("_" + node.name for node in edge.nodes[1:])
            ident = "_" + edge.nodes[0].name
            new_edge = edge.label
            hgraph._add_triple(ident, new_edge, hyperchild)

        return hgraph

    def to_standardized_node_names(self, return_mapping=False):
        """
        Return a copy of the graph whose nodes are renamed "0", "1", ...

        Each node gets an MD5 digest that is recomputed for 10 rounds from
        the labels of its incident edges, its position within each edge and
        the current digests of all nodes on those edges. Nodes are then
        numbered in ascending order of their final digest; nodes with equal
        digests keep the iteration order of ``self.nodes``.

        All edges of the result are plain ``HyperEdge`` objects; a ``carg``
        attribute is copied over when the original edge has one.

        :param return_mapping: also return the old-node -> new-node mapping
        :return: the renamed graph, or ``(graph, mapping)``
        """
        edge_by_node = defaultdict(list)
        for edge in self.edges:
            for idx, node in enumerate(edge.nodes):
                edge_by_node[node].append((edge, idx))

        default_hash = hashlib.md5(b"13").digest()
        node_hashes = {node: default_hash for node in self.nodes}

        def get_edge_hashes(edge,  # type: HyperEdge
                            idx  # type: int
                            ):
            # Digest of an edge as seen from the node at position ``idx``.
            md5_obj = hashlib.md5((edge.label + "#" + str(idx)).encode())
            for adj_node in edge.nodes:
                md5_obj.update(node_hashes[adj_node] + b"#")
            return md5_obj.digest()

        def get_sibling_hashes(node  # type: GraphNode
                               ):
            # Sorting makes the result independent of edge iteration order.
            md5_obj = hashlib.md5()
            edge_hashes = sorted(get_edge_hashes(edge, idx)
                                 for edge, idx in edge_by_node[node])
            for h in edge_hashes:
                md5_obj.update(h)
            return md5_obj.digest()

        def recalculate_hashs():
            # All new digests are computed from the previous round's digests.
            new_node_hashes = {}
            for node in self.nodes:
                md5_obj = hashlib.md5()
                md5_obj.update(get_sibling_hashes(node))
                new_node_hashes[node] = md5_obj.digest()
            return new_node_hashes

        for _ in range(10):
            node_hashes = recalculate_hashs()

        nodes_in_order = sorted(node_hashes.items(), key=itemgetter(1))

        node_rename_map = {}
        for node_idx, (node, hash_value) in enumerate(nodes_in_order):
            node_rename_map[node] = GraphNode(str(node_idx))

        new_edges = []
        for edge in self.edges:
            new_edge = HyperEdge((node_rename_map[node] for node in edge.nodes),
                                 edge.label, edge.is_terminal, edge.span)
            if hasattr(edge, "carg"):
                new_edge.carg = edge.carg
            new_edges.append(new_edge)

        ret = self.__class__(frozenset(node_rename_map.values()),
                             frozenset(new_edges))
        if return_mapping:
            return ret, node_rename_map
        else:
            return ret

    def dfs(self):
        """
        Depth-first traversal treating every edge as undirected.

        Two nodes are neighbours when they occur on the same edge.

        :return: dict mapping each node to the node it was discovered from;
            the first node visited in each connected component maps to ``None``
        """
        not_visited = set(self.nodes)
        siblings = {node: set() for node in self.nodes}
        for edge in self.edges:
            for node in edge.nodes:
                for sibling_node in edge.nodes:
                    if sibling_node != node:
                        siblings[node].add(sibling_node)

        parents = {}
        stack = []

        while len(not_visited) > 0:
            if not stack:
                # Start a new component from an arbitrary unvisited node.
                node = not_visited.pop()
                stack.append((node, iter(siblings[node])))
                parents[node] = None

            node, sibling_iter = stack[-1]

            try:
                # Advance to the next neighbour that has not been visited yet.
                sibling = None
                while sibling is None or sibling not in not_visited:
                    sibling = next(sibling_iter)
            except StopIteration:
                # Neighbours exhausted: backtrack.
                stack.pop()
                continue

            not_visited.remove(sibling)
            stack.append((sibling, iter(siblings[sibling])))
            parents[sibling] = node

        return parents

    def is_connected(self):
        """Return True when :meth:`dfs` finds exactly one component start."""
        return len([i for i in self.dfs().values() if i is None]) == 1

    def to_eds(self):
        """Unfinished placeholder: performs no conversion and returns None."""
        return None

    @staticmethod
    def format_node(pred_node: GraphNode, pred_edge: HyperEdge, with_span=True):
        """
        Format a node as ``name`` or ``name@start,end@carg``.

        In the carg, spaces are replaced by ``"_"`` and ``"@"`` by ``"###"``;
        a missing or empty carg is written as ``"!!!None!!!"``.

        :param pred_node: the node to format
        :param pred_edge: the one-node edge providing span and carg; an edge
            without a ``carg`` attribute is treated as having no carg
        :param with_span: append the span and carg to the name
        """
        ret = pred_node.name
        if with_span:
            # getattr: a plain one-node HyperEdge has no ``carg`` attribute.
            carg = getattr(pred_edge, "carg", None) or "!!!None!!!"
            carg = carg.replace(" ", "_").replace("@", "###")
            ret += f"@{pred_edge.span[0]},{pred_edge.span[1]}@{carg}"
        return ret

    def to_nodes_and_edges(self, return_spans=False, node_name_with_span=True):
        """
        Flatten the graph into plain node and edge tuples.

        One-node edges define the nodes, two-node edges become edges. Edges
        with more nodes, a second one-node edge on the same node, and two-node
        edges touching a node without a one-node edge are reported with
        ``print`` and skipped.

        :param return_spans: emit ``(name, label, span)`` instead of
            ``(name, label)`` for nodes
        :param node_name_with_span: forwarded to :meth:`format_node`
        :return: ``(nodes, edges)`` where edges are ``(source, target, label)``
        """
        # draw eds
        node_mapping = {}
        real_edges = []
        nodes = []
        edges = []
        for edge in self.edges:  # type: HyperEdge
            if len(edge.nodes) == 1:
                main_node = edge.nodes[0]  # type: GraphNode
                if node_mapping.get(main_node) is None:
                    node_mapping[main_node] = edge
                else:
                    print("Dumplicate node name {} and {}!".format(
                        node_mapping[main_node],
                        edge.label
                    ))
            elif len(edge.nodes) == 2:
                real_edges.append(edge)
            else:
                print("Invalid hyperedge with node count {}".format(len(edge.nodes)))

        for node, pred_edge in node_mapping.items():
            assert pred_edge.span is not None
            new_name = self.format_node(node, pred_edge, node_name_with_span)
            if not return_spans:
                nodes.append((new_name, pred_edge.label))
            else:
                nodes.append((new_name, pred_edge.label, pred_edge.span))

        for edge in real_edges:
            node_1, node_2 = edge.nodes
            pred_edge_1, pred_edge_2 = pred_edges = [node_mapping.get(i) for i in edge.nodes]
            if any(i is None for i in pred_edges):
                print("No span for edge {}, nodes {}!".format(edge, pred_edges))
                continue
            edges.append((
                self.format_node(node_1, pred_edge_1, node_name_with_span),
                self.format_node(node_2, pred_edge_2, node_name_with_span),
                edge.label))
        return nodes, edges

    @classmethod
    def from_nodes_and_edges(cls, nodes, edges):
        """
        Rebuild a graph from plain node and edge tuples.

        :param nodes: iterable of ``(name_and_span, label)`` pairs, where
            ``name_and_span`` is ``"name@start,end"`` or
            ``"name@start,end@carg"``; a carg of ``"!!!None!!!"`` means no carg
        :param edges: iterable of ``(source, target, label)`` triples whose
            source and target are ``name_and_span`` strings from ``nodes``

        NOTE: the character replacements applied to the carg by
        :meth:`format_node` are not reversed here.
        """
        hrg_nodes = set()
        hrg_edges = set()
        node_to_pred_node = {}

        for node_name_and_span, label in nodes:
            # Split from the right so "@" inside the node name is preserved.
            fields = node_name_and_span.rsplit("@", 2)
            if len(fields) == 2:
                node_name, span_str = fields
                carg = None
            else:
                node_name, span_str, carg = fields
                if carg == "!!!None!!!":
                    carg = None
            start_str, end_str = span_str.split(",")
            start = int(start_str)
            end = int(end_str)
            pred_node, edge = PredEdge.as_new((start, end), label, node_name, carg)
            hrg_edges.add(edge)
            hrg_nodes.add(pred_node)
            node_to_pred_node[node_name_and_span] = pred_node

        for s, t, l in edges:
            hrg_edges.add(HyperEdge(
                [node_to_pred_node[s], node_to_pred_node[t]], l, IsTerminal.TERMINAL, None))

        return cls(frozenset(hrg_nodes), frozenset(hrg_edges))
