import gzip
from collections import UserDict

from delphin.mrs import simplemrs

from coli.hrgguru.hyper_graph import GraphNode, HyperEdge, HyperGraph, IsTerminal


class SmartDefaultDict(UserDict):
    """Dict-like mapping that fills missing keys with ``default_factory(key)``.

    Unlike ``collections.defaultdict``, the factory is called with the
    missing key, so the generated value can depend on it.  The generated
    value is stored in the mapping before it is returned, so repeated
    lookups of the same key yield the same object.

    :param default_factory: callable taking the missing key and returning
        the value to store for it.
    :param seq: initial contents, passed straight to ``UserDict``.
    """

    def __init__(self, default_factory, seq=()):
        self.default_factory = default_factory
        super().__init__(seq)

    def __missing__(self, key):
        # UserDict.__getitem__ delegates here when ``key`` is absent.
        created = self.default_factory(key)
        self[key] = created
        return created


def extract_lfrg(mrs_obj, include_qeq=False, hyperedge_style=True):
    """Convert an MRS object into a HyperGraph made of terminal edges only.

    Two layouts are supported:

    * ``hyperedge_style=True``: every EP becomes a single hyperedge
      labelled with its predicate string.  The edge is attached to the EP's
      label node followed by one node per argument, with arguments taken in
      sorted ``(role, value)`` order.
    * ``hyperedge_style=False``: every EP gets its own "center node".  That
      node carries a unary predicate edge (whose ``carg`` attribute is set
      from the ``CARG`` argument, or ``None``), a binary ``LBL`` edge to the
      label node, and one binary edge per role other than ``CARG``.

    Nodes are created on first mention and looked up by name afterwards,
    so EPs that mention the same variable or handle share a GraphNode.

    :param mrs_obj: parsed MRS providing ``eps()`` and ``hcons()``
        (presumably a pydelphin 0.x MRS object; verify against caller).
    :param include_qeq: when true, add a binary ``QEQ`` edge ``[hi, lo]``
        for every handle constraint.
    :param hyperedge_style: selects the layout described above.
    :return: HyperGraph over all created nodes and edges.
    """
    nodes = SmartDefaultDict(GraphNode)
    edges = set()

    for ep in mrs_obj.eps():
        pred = ep.pred.string
        span = tuple(ep.lnk)[1]

        if hyperedge_style:
            # NOTE(review): unlike the center-node layout, CARG is not skipped
            # here, so a constant value becomes a node named after it; confirm
            # this is intended.
            attached = [nodes[ep.label]]
            attached.extend(nodes[value] for _, value in sorted(ep.args.items()))
            edges.add(HyperEdge(attached, pred, IsTerminal.TERMINAL, span))
            continue

        # The center node name combines label, predicate and span so that
        # different EPs get different center nodes.
        center = nodes["{}-{}-{}-center-node".format(ep.label, pred, span)]
        pred_edge = HyperEdge([center], pred, IsTerminal.TERMINAL, span)
        pred_edge.carg = ep.args.get("CARG")
        edges.add(pred_edge)
        edges.add(HyperEdge([center, nodes[ep.label]], "LBL",
                            IsTerminal.TERMINAL, None))
        for role, value in ep.args.items():
            if role != "CARG":
                edges.add(HyperEdge([center, nodes[value]], role,
                                    IsTerminal.TERMINAL, None))

    if include_qeq:
        # Every handle constraint is labelled "QEQ"; the relation on the
        # hcon object is not inspected.
        for hcon in mrs_obj.hcons():
            edges.add(HyperEdge([nodes[hcon.hi], nodes[hcon.lo]], "QEQ",
                                IsTerminal.TERMINAL, None))

    return HyperGraph(frozenset(nodes.values()), frozenset(edges))


def lfrg_to_mrs(hg: HyperGraph):
    """Render a center-node style hypergraph as SimpleMRS text.

    The graph appears to be expected in the layout produced by
    ``extract_lfrg(..., hyperedge_style=False)``:

    * every unary edge is a predicate on its center node;
    * every binary edge ``[center, value]`` is a role of that predicate;
    * binary edges labelled ``QEQ`` are written to the HCONS list.

    Variable sorts are inferred coarsely.  A node that is the target of an
    ``ARG0`` edge is written as ``x<name>``; every other node is written as
    ``h<name>``.

    The output order is deterministic:

    * relations are sorted by character span, then by predicate label;
    * within a relation, ``LBL`` comes first and the other roles follow in
      alphabetical order;
    * QEQ constraints are sorted by their rendered text.

    NOTE(review): the ``carg`` attribute that ``extract_lfrg`` stores on
    predicate edges is not read here, so CARG values are not emitted.

    :param hg: graph to render.  Node names are standardized first via
        ``hg.to_standardized_node_names()``.
    :return: the SimpleMRS string.
    :raises ValueError: if a ``QEQ`` edge does not have exactly two nodes.
    """
    hg = hg.to_standardized_node_names()

    # Index the graph in a single pass instead of rescanning every edge for
    # every predicate (previously O(preds * edges)).
    pred_edges = []
    roles_by_node = {}
    variables = set()
    for edge in hg.edges:
        if len(edge.nodes) == 1:
            pred_edges.append(edge)
        elif len(edge.nodes) == 2:
            roles_by_node.setdefault(edge.nodes[0], []).append(
                (edge.label, edge.nodes[1]))
            if edge.label == "ARG0":
                variables.add(edge.nodes[1])

    def node_name(node):
        # Only ARG0 targets become "x" variables; all else is rendered as a handle.
        return ("x" if node in variables else "h") + node.name

    def pred_edge_to_rel(pred_edge: HyperEdge):
        lines = ["          [ {}<{}:{}>".format(
            pred_edge.label, pred_edge.span[0], pred_edge.span[1])]
        # LBL first, then the remaining roles alphabetically (ARG0, ARG1, ...,
        # BODY, RSTR) so the output does not depend on set iteration order.
        roles = sorted(roles_by_node.get(pred_edge.nodes[0], ()),
                       key=lambda role: (role[0] != "LBL", role[0]))
        lines.extend("            {}: {}".format(label, node_name(value))
                     for label, value in roles)
        return "\n".join(lines) + " ]"

    ordered_preds = sorted(
        pred_edges, key=lambda e: (e.span[0], e.span[1], e.label))
    rels = "\n".join(pred_edge_to_rel(e) for e in ordered_preds)

    qeqs = []
    for edge in hg.edges:
        if edge.label == "QEQ":
            if len(edge.nodes) != 2:
                raise ValueError(
                    "QEQ edge must connect exactly two nodes, got {}".format(
                        len(edge.nodes)))
            qeqs.append("{} QEQ {} ".format(node_name(edge.nodes[0]),
                                            node_name(edge.nodes[1])))
    qeqs.sort()

    return ("   [ RELS: <\n" + rels + " >\n    HCONS: < "
            + "".join(qeqs) + "> ]")


if __name__ == '__main__':
    # Debug driver: read a gzipped export file, take one blank-line-separated
    # section of it as a SimpleMRS string, convert it to the center-node
    # graph layout (with QEQ edges) and draw it as a PDF.
    #
    # Usage: python <this file> [EXPORT_GZ [OUTPUT_PATH]]
    # With no arguments, the original hard-coded paths are used.
    import sys

    default_export = ("/home/chenyufei/Development/large-data/deepbank1.1/"
                      "export/wsj00a/20001001.gz")
    export_path = sys.argv[1] if len(sys.argv) > 1 else default_export
    output_path = sys.argv[2] if len(sys.argv) > 2 else "/tmp/abcdrrr"

    with gzip.open(export_path, "rb") as f:
        fields = f.read().decode("utf-8").strip().split("\n\n")
    # NOTE(review): assumes the MRS is the third-from-last section of the
    # export file; confirm against the export format.
    mrs_literal = fields[-3]

    mrs_obj = simplemrs.loads_one(mrs_literal)
    hg = extract_lfrg(mrs_obj, include_qeq=True, hyperedge_style=False)
    # Hide the synthetic center-node names in the drawing.
    hg.draw(output_path, file_format="pdf",
            node_name_map={node: "" if node.name.endswith("-center-node") else node.name
                           for node in hg.nodes})
