from collections import defaultdict
from operator import itemgetter

from coli.hrgguru.hyper_graph import HyperGraph

# Non-surface predicates that `bilexical_transform` may collapse into a single
# edge. Each entry is (predicate name, argument labels in ascending order); the
# labels tuple is compared against the two outgoing edge labels of a node after
# sorting them, so the order here must be ascending.
bilexical_table = {
    ('parenthetical', ('ARG1', 'ARG2')),
    ('unspec_manner', ('ARG1', 'ARG2')),
    ('loc_nonsp', ('ARG1', 'ARG2')),
    ('interval_p_start', ('ARG1', 'ARG2')),
    ('focus_d', ('ARG1', 'ARG2')),
    ('plus', ('ARG2', 'ARG3')),
    ('interval_p_end', ('ARG1', 'ARG2')),
    ('poss', ('ARG1', 'ARG2')),
    ('parg_d', ('ARG1', 'ARG2')),
    ('addressee', ('ARG1', 'ARG2')),
    ('with_p', ('ARG1', 'ARG2')),
    ('refl_mod', ('ARG1', 'ARG2')),
    ('id', ('ARG1', 'ARG2')),
    ('temp_loc_x', ('ARG1', 'ARG2')),
    ('measure', ('ARG1', 'ARG2')),
    ('of_p', ('ARG1', 'ARG2')),
    ('times', ('ARG2', 'ARG3')),
    ('compound', ('ARG1', 'ARG2')),
    ('appos', ('ARG1', 'ARG2')),
}


def bilexical_transform(hg: HyperGraph):
    """Replace binary, non-underscore predicate nodes with single labelled edges.

    A node is collapsed when all of the following hold:

    * its name does not start with ``"_"``;
    * it has exactly two distinct outgoing ``(target, label)`` pairs;
    * it has no incoming edges;
    * ``(name, labels)`` is in ``bilexical_table``, where ``labels`` are its
      two outgoing labels in ascending order.

    A collapsed node is removed. A single edge is added in its place. That
    edge runs from the target of the smaller label to the target of the
    larger label, and is labelled ``"<name>#<label1>-<label2>"``. Every other
    node is kept together with its outgoing edges. Duplicate
    ``(target, label)`` pairs from one source are emitted only once.

    :param hg: input graph; only read via ``to_nodes_and_edges``.
    :return: a new graph built with ``HyperGraph.from_nodes_and_edges``.
    """
    node_list, edge_list = hg.to_nodes_and_edges()

    incoming = defaultdict(set)
    outgoing = defaultdict(set)
    for src, dst, lab in edge_list:
        incoming[dst].add((src, lab))
        outgoing[src].add((dst, lab))

    kept = []
    collapsed = set()
    for node, name in node_list:
        is_candidate = (not name.startswith("_")
                        and len(outgoing[node]) == 2
                        and not incoming[node])
        if not is_candidate:
            kept.append((node, name))
            continue

        # Order the two (target, label) pairs by label so the lookup key and
        # the new edge's direction do not depend on set iteration order.
        (first, second), labels = zip(*sorted(outgoing[node], key=itemgetter(1)))
        if (name, labels) not in bilexical_table:
            kept.append((node, name))
            continue

        collapsed.add((first, second, f'{name}#{"-".join(labels)}'))

    # Collapsed edges come first, then the outgoing edges of every kept node.
    all_edges = list(collapsed)
    all_edges.extend(
        (node, dst, lab)
        for node, _ in kept
        for dst, lab in outgoing[node]
    )
    return HyperGraph.from_nodes_and_edges(kept, all_edges)


def remove_lonely_nodes(hg: HyperGraph):
    """Return a copy of ``hg`` without nodes that touch no edge at all.

    A node is kept if it is the source or the target of at least one edge.
    Edges are rebuilt from the kept nodes' outgoing ``(target, label)`` sets.
    As a result, duplicate edges from the same source collapse into one. Only
    edges whose source is a kept node are emitted.

    :param hg: input graph; only read via ``to_nodes_and_edges``.
    :return: a new graph built with ``HyperGraph.from_nodes_and_edges``.
    """
    node_list, edge_list = hg.to_nodes_and_edges()

    incoming = defaultdict(set)
    outgoing = defaultdict(set)
    for src, dst, lab in edge_list:
        incoming[dst].add((src, lab))
        outgoing[src].add((dst, lab))

    # A node survives if it has any incident edge in either direction.
    kept = [
        (node, name)
        for node, name in node_list
        if incoming[node] or outgoing[node]
    ]

    all_edges = [
        (node, dst, lab)
        for node, _ in kept
        for dst, lab in outgoing[node]
    ]
    return HyperGraph.from_nodes_and_edges(kept, all_edges)
