from collections import defaultdict
from typing import Any, List

from nltk import WordNetLemmatizer

from coli.hrgguru.hrg import CFGRule, HRGRule
from coli.hrgguru.hyper_graph import GraphNode, HyperEdge, HyperGraph
from coli.hrgguru.unlexicalized_rules import transform_edge

# Module-level WordNet lemmatizer, created once at import time and reused by
# compile_oracle to lemmatize the rule's word when filling "{NEWLEMMA}"
# placeholders in edge labels.
wordnet_lemma = WordNetLemmatizer()


def get_oracle(cfg_rule: CFGRule):
    """Serialize the HRG attached to ``cfg_rule`` into a flat list of actions.

    The right-hand-side graph is walked depth-first, starting from every
    external node (in ``hrg.lhs.nodes`` order) and then from every remaining
    node. When a node is visited, its outgoing edges (edges whose first node
    is that node) are handled in order of increasing arity: the edge's target
    nodes are announced first, then the edge itself, then the targets are
    explored recursively.

    Action tuples produced:

    * ``("create_node", "external" | "internal", node.name)`` -- at most once
      per node.
    * ``("create_edge", arity, label, [node names])`` -- when the edge's first
      node is visited. In a binary rule the two child hyperedges are labelled
      ``"__LEFT__"`` / ``"__RIGHT__"``. For a single-item (lexical) RHS the
      label comes from ``transform_edge``, falling back to the raw label when
      it raises ``ArithmeticError``.
    * ``("END",)`` -- always the last action.

    Returns ``[("END",)]`` when the rule has no HRG attached.

    NOTE(review): the walk is recursive, so very deep graphs could hit
    Python's recursion limit.
    """
    hrg = cfg_rule.hrg
    rhs_items = cfg_rule.rhs
    if len(rhs_items) == 2:
        left_child, right_child = rhs_items[0][1], rhs_items[1][1]
    else:
        left_child = right_child = None

    if hrg is None:
        return [("END",)]

    exposed = hrg.lhs.nodes
    hidden = hrg.rhs.nodes - set(exposed)
    announced = set()
    seen = set()
    actions = []

    def outgoing(source):
        # Edges leaving `source`, lowest arity first (list.sort is stable).
        candidates = [edge for edge in hrg.rhs.edges
                      if edge.nodes[0] == source]
        candidates.sort(key=lambda edge: len(edge.nodes))
        return candidates

    def label_of(edge: HyperEdge):
        if edge == left_child:
            return "__LEFT__"
        if edge == right_child:
            return "__RIGHT__"
        if len(rhs_items) != 1:
            return edge.label
        # Lexical rule: let transform_edge rewrite the label using the word.
        try:
            return transform_edge(edge, rhs_items[0][0].string).label
        except ArithmeticError:
            return edge.label

    def announce(node: GraphNode):
        if node in announced:
            return
        kind = "external" if node in exposed else "internal"
        actions.append(("create_node", kind, node.name))
        announced.add(node)

    def walk(node: GraphNode):
        if node in seen:
            return
        seen.add(node)
        for edge in outgoing(node):
            assert node == edge.nodes[0]
            members = edge.nodes
            for target in members[1:]:
                announce(target)
            actions.append(("create_edge", len(members), label_of(edge),
                            [member.name for member in members]))
            for target in members[1:]:
                walk(target)

    # External nodes first (in LHS order), then all remaining graph nodes.
    for start in (*exposed, *hidden):
        announce(start)
        walk(start)

    actions.append(("END",))
    return actions


def get_oracle_2(cfg_rule: CFGRule):
    """Serialize the HRG attached to ``cfg_rule`` into symbolic build actions.

    Unlike ``get_oracle``, nodes are referred to by position-based symbols
    rather than by their names:

    * ``("left", i)`` / ``("right", i)`` -- the i-th node of the left/right
      child hyperedge of a binary rule (these are never announced, since the
      consumer is expected to create them itself);
    * ``("side", k)`` -- every other node, numbered in creation order.

    The graph is walked depth-first from each external node (in
    ``hrg.lhs.nodes`` order) and then from each remaining node; at a node,
    outgoing edges are handled in order of increasing arity.

    Action tuples produced:

    * ``("create_node", ("side", k))`` -- for each node without a symbol yet.
    * ``("create_edge", arity, label, [symbols])`` -- for each edge with at
      most two nodes that is not one of the child hyperedges. For a
      single-item (lexical) RHS the label comes from ``transform_edge``,
      falling back to the raw label when it raises ``ArithmeticError``.
    * ``("expose", symbol)`` -- once per external node, in LHS order.
    * ``("END",)`` -- always last.

    Returns ``[("END",)]`` when the rule has no HRG attached. Asserts that the
    two child hyperedges of a binary rule share no nodes.

    NOTE(review): non-child edges with more than two nodes are silently left
    out of the oracle.
    """
    hrg = cfg_rule.hrg
    if hrg is None:
        return [("END",)]

    exposed = hrg.lhs.nodes
    hidden = hrg.rhs.nodes - set(exposed)
    actions: List[Any] = []
    seen = set()
    fresh_names = []
    names = {}

    left_child = right_child = None
    if len(cfg_rule.rhs) == 2:
        left_child, right_child = cfg_rule.rhs[0][1], cfg_rule.rhs[1][1]
        # Nodes of the child hyperedges are named by their position in them.
        for side, child in (("left", left_child), ("right", right_child)):
            if child is None:
                continue
            for position, node in enumerate(child.nodes):
                names[node] = (side, position)
        if left_child is not None and right_child is not None:
            # Overlapping children would make the symbols ambiguous.
            assert not set(left_child.nodes).intersection(right_child.nodes)

    def outgoing(source):
        # Edges leaving `source`, lowest arity first (list.sort is stable).
        candidates = [edge for edge in hrg.rhs.edges
                      if edge.nodes[0] == source]
        candidates.sort(key=lambda edge: len(edge.nodes))
        return candidates

    def relabel(edge: HyperEdge):
        if len(cfg_rule.rhs) != 1:
            return edge.label
        # Lexical rule: let transform_edge rewrite the label using the word.
        try:
            return transform_edge(edge, cfg_rule.rhs[0][0].string).label
        except ArithmeticError:
            return edge.label

    def name_of(node: GraphNode):
        # Symbols are always tuples, so None unambiguously means "unnamed".
        known = names.get(node)
        if known is not None:
            return known
        fresh = ("side", len(fresh_names))
        actions.append(("create_node", fresh))
        names[node] = fresh
        fresh_names.append(fresh)
        return fresh

    def walk(node: GraphNode):
        if node in seen:
            return
        seen.add(node)
        for edge in outgoing(node):
            assert node == edge.nodes[0]
            # Child hyperedges come from the sub-derivations, not the oracle.
            if edge == left_child or edge == right_child:
                continue
            if len(edge.nodes) > 2:
                continue
            targets = edge.nodes[1:]
            for target in targets:
                name_of(target)
            actions.append(("create_edge", len(edge.nodes), relabel(edge),
                            [names[member] for member in edge.nodes]))
            for target in targets:
                walk(target)

    # External nodes first (in LHS order), then all remaining graph nodes.
    for start in (*exposed, *hidden):
        name_of(start)
        walk(start)

    actions.extend(("expose", names[node]) for node in exposed)
    actions.append(("END",))
    return actions


def compile_oracle(cfg_lhs,
                   cfg_rhs,
                   oracle,
                   answer=None
                   ):
    """Rebuild an HRG rule by replaying a sequence of oracle actions.

    The action format matches what ``get_oracle_2`` emits:

    * ``("create_node", symbol)`` -- create a fresh node under ``symbol``.
    * ``("create_edge", 1, label, [symbol])`` -- add a unary edge; a
      ``{NEWLEMMA}`` placeholder in ``label`` is filled from the rule's word.
    * ``("create_edge", 2, label, [source, target])`` -- add a binary edge.
    * ``("expose", symbol)`` -- append the node to the external nodes of the
      LHS hyperedge, in the order the actions appear.
    * ``("END",)`` -- stop; any later actions are ignored.

    Args:
        cfg_lhs: ``(lhs_label, lhs_degree)`` pair. Only ``lhs_label`` is used,
            as the label of the LHS hyperedge.
        cfg_rhs: for a binary rule, two ``(label, node_count)`` pairs. The
            nodes of the left/right child hyperedge are pre-created under the
            symbols ``("left", i)`` / ``("right", i)``, and each child with at
            least one node becomes an edge spanning ``(0, 1)`` / ``(1, 2)``.
            Otherwise ``cfg_rhs[0][0].string`` is used as the rule's word.
        oracle: iterable of action tuples as described above.
        answer: optional expected rule; when given, the compiled rule must
            compare equal to it.

    Returns:
        The rebuilt ``HRGRule`` after ``to_standardized_node_names``.

    Raises:
        ValueError: on an unknown action type, or an edge arity other than
            1 or 2.
        KeyError: if an action refers to a node symbol that was never created.
        AssertionError: if ``answer`` is given and the compiled rule differs.
    """
    # The LHS degree is implied by the number of "expose" actions.
    lhs_label, _lhs_degree = cfg_lhs
    nodes = {}
    external_nodes = []
    edges = set()

    if len(cfg_rhs) == 2:
        # A binary rule has no word; this sentinel is only substituted if an
        # oracle nevertheless asks for a {NEWLEMMA} replacement.
        lexicon = "__LexiconNotExist__"
        for idx, ((rhs_label, rhs_node_count), side) in \
                enumerate(zip(cfg_rhs, ("left", "right"))):
            child_nodes = [GraphNode() for _ in range(rhs_node_count)]
            for position, node in enumerate(child_nodes):
                nodes[(side, position)] = node
            if child_nodes:
                edges.add(HyperEdge(child_nodes, rhs_label, False,
                                    (idx, idx + 1)))
    else:
        lexicon = cfg_rhs[0][0].string

    def fill_new_lemma(edge_label):
        """Replace ``{NEWLEMMA}`` in ``edge_label`` with a form of ``lexicon``.

        The field between the second and third "_" is taken as a POS tag
        (labels presumably look like ``<prefix>_<lemma>_<pos>_<rest>``); the
        word is lemmatized with WordNet only when that tag is "n", "v" or
        "a", otherwise it is inserted verbatim.
        """
        lemma_start = edge_label.find("_") + 1
        lemma_end = edge_label.find("_", lemma_start)
        if lemma_end == -1:
            # No "_" after the lemma, hence no POS field. Without this guard
            # the next search would restart at index 0 and mistake the text
            # *before* the lemma for the POS tag.
            return edge_label.replace("{NEWLEMMA}", lexicon)
        tag_end = edge_label.find("_", lemma_end + 1)
        # NOTE(review): a POS field must be followed by another "_" to be
        # recognised, so labels ending in "_<pos>" are not lemmatized.
        pos = edge_label[lemma_end + 1:tag_end] if tag_end != -1 else None
        pred_lemma = (wordnet_lemma.lemmatize(lexicon, pos)
                      if pos in ("n", "v", "a") else lexicon)
        return edge_label.replace("{NEWLEMMA}", pred_lemma)

    for action in oracle:
        action_type = action[0]
        if action_type == "create_node":
            nodes[action[1]] = GraphNode()
        elif action_type == "create_edge":
            edge_type = action[1]
            if edge_type == 1:
                edge_label = action[2]
                if "{NEWLEMMA}" in edge_label:
                    edge_label = fill_new_lemma(edge_label)
                edges.add(HyperEdge([nodes[action[3][0]]], edge_label, True))
            elif edge_type == 2:
                source_symbol, target_symbol = action[3][0], action[3][1]
                edges.add(HyperEdge([nodes[source_symbol],
                                     nodes[target_symbol]],
                                    action[2], True))
            else:
                raise ValueError("Invalid edge type {}".format(edge_type))
        elif action_type == "expose":
            external_nodes.append(nodes[action[1]])
        elif action_type == "END":
            break
        else:
            raise ValueError("Invalid action type {}".format(action_type))

    hrg_lhs = HyperEdge(external_nodes, lhs_label, False)
    hrg_rhs = HyperGraph(frozenset(nodes.values()), frozenset(edges))
    rule = HRGRule(hrg_lhs, hrg_rhs).to_standardized_node_names(
        left_and_right_span=((0, 1), (1, 2)))

    # Explicit raise rather than `assert`, so the check survives `python -O`.
    if answer is not None and rule != answer:
        raise AssertionError(
            "compiled rule does not match the expected rule")

    return rule


if __name__ == '__main__':
    # Round-trip check over the whole grammar: build each rule's oracle with
    # get_oracle_2, replay it with compile_oracle (which checks the result
    # against the rule's HRG), and finally print the oracle-length histogram.
    from serve_rules import Grammar

    grammar = Grammar()
    length_histogram = defaultdict(int)
    for rule_index, (cfg_rule, (_count, _example)) in enumerate(grammar.rules):
        print(rule_index)
        actions = get_oracle_2(cfg_rule)
        length_histogram[len(actions)] += 1
        print(actions)

        hrg = cfg_rule.hrg
        lhs_degree = 0 if hrg is None else len(hrg.lhs.nodes)
        if len(cfg_rule.rhs) == 2:
            # Binary rule: describe each child hyperedge as (label, node count).
            child_specs = [(label, len(edge.nodes) if edge is not None else 0)
                           for label, edge in cfg_rule.rhs]
        else:
            assert len(cfg_rule.rhs) == 1
            child_specs = cfg_rule.rhs
        compile_oracle((cfg_rule.lhs, lhs_degree), child_specs, actions, hrg)

    print(sorted(length_histogram.items()))
