# Non-lexical predicate label -> (direction, structural edge label).
# The predicate is attached to the same tree node as the neighbour reached
# through a structural edge carrying the given label. Direction 0 means the
# predicate's node is that edge's source (follow it to the target); direction 1
# means it is the edge's target (follow it back to the source).
cate_arg_dict = {
    "measure": (0, "ARG2"),
    "times": (0, "ARG3"),
    "plus": (0, "ARG2"),
    "loc_nonsp": (0, "ARG2"),
    "nominalization": (0, "ARG1"),
    "parg_d": (0, "ARG1"),
    "superl": (0, "ARG1"),
    "comp": (0, "ARG1"),
    "comp_equal": (0, "ARG1"),
    "comp_less": (0, "ARG1"),
    "comp_so": (0, "ARG1"),
    "time": (1, "ARG1"),
    "place": (1, "ARG1"),
    # quantifiers follow their bound variable
    "udef_q": (0, "BV"),
    "proper_q": (0, "BV"),
    "number_q": (0, "BV"),
    "def_implicit_q": (0, "BV"),
    "def_explicit_q": (0, "BV"),
    "neg": (0, "ARG1"),
    "subord": (0, "ARG2"),
    "elliptical_n": (1, "ARG2"),
    "num_seq": (0, "R-INDEX"),
    "ellipsis": (1, "ARG2"),
    "idiom_q_i": (0, "BV"),
    "place_n": (1, "ARG1"),
    "time_n": (1, "ARG1"),
    "_pre-_a_ante": (0, "ARG1"),
    "_un-_a_rvrs": (0, "ARG1"),
}

# Lower-cased surface word -> lemma (or list of alternative lemmas) it may
# stand for when the regular lemma comparison fails.
lemmatizer_extra = {
    "an": "a",
    "/": ["and", "per"],
    "auto": "automobile",
    "hoped": "hope",
    "rated": "rate",
}

# Labels of edges that are attached to the lowest common ancestor of the tree
# nodes holding their arguments (not to a single word).
compound_dict = {"compound", "unknown", "appos", "part_of", "implicit_conj", "generic_entity"}

# Predicate label -> {lower-cased surface word: lower-cased CARG value}.
# Used when a CARG-bearing predicate does not literally equal the word that
# realises it (e.g. "third" vs. CARG "3"). Keys MUST be lower-case because the
# word they are looked up with is always lower-cased.
carg_compare_dict = {
    "ord": {
        'first': '1', 'second': '2', 'third': '3', 'fourth': '4', 'fifth': '5', 'sixth': '6',
        'seventh': '7', 'eighth': '8', 'ninth': '9', 'tenth': '10', 'eleventh': '11', 'twelfth': '12',
        'thirteenth': '13', 'fourteenth': '14', 'fifteenth': '15', 'sixteenth': '16',
        'seventeenth': '17', 'eighteenth': '18', 'nineteenth': '19', 'twentieth': '20',
        'thirtieth': '30', 'fortieth': '40', 'fiftieth': '50', 'sixtieth': '60', 'seventieth': '70',
        'eightieth': '80', 'ninetieth': '90', 'hundredth': '100', 'thousandth': '1000',
        'millionth': '1000000', 'billionth': '1000000000', 'trillionth': '1000000000000'},
    "card": {
        'a': '1', 'one': '1', 'two': '2', 'three': '3', 'four': '4', 'five': '5', 'six': '6', 'seven': '7',
        'eight': '8',
        'nine': '9', 'ten': '10', 'eleven': '11', 'twelve': '12', 'thirteen': '13', 'fourteen': '14',
        'fifteen': '15', 'sixteen': '16', 'seventeen': '17', 'eighteen': '18', 'nineteen': '19', 'twenty': '20',
        'thirty': '30', 'forty': '40', 'fifty': '50', 'sixty': '60', 'seventy': '70', 'eighty': '80',
        'ninety': '90', 'hundred': '100', 'thousand': '1000', 'million': '1000000', 'billion': '1000000000',
        'trillion': '1000000000000'},
    "mofy": {
        "january": "jan", "february": "feb", "march": "mar", "april": "apr", "may": "may",
        "june": "jun", "july": "jul", "august": "aug", "september": "sep", "october": "oct",
        "november": "nov", "december": "dec"
    },
    # "_imf" was previously spelled "_IMF", which a lower-cased lookup can never hit.
    "named_n": {"_imf": "imf", "u.s.": "us"}
}

# Labels that are not lexical by name (no leading "_") but are still matched
# against surface words; the value lists the acceptable lemmas.
lexical_edges_extra = {"much-many_a": ["more"]}

# Labels that start with "_" yet must not be treated as lexical predicates.
is_not_lexical_edge = {"_pre-_a_ante", "_un-_a_rvrs"}

# Surface tokens that may realise an interval predicate.
interval_marker = ["to", "–"]
# All interval labels share the very same marker list object.
special_lexical_edge = dict.fromkeys(
    ("interval", "interval_p_start", "interval_p_end"),
    interval_marker,
)


def lowest_common_ancestor(tree_nodes, nodes):
    """Return the first node of ``tree_nodes`` that dominates all of ``nodes``.

    A candidate dominates ``nodes`` when each of them occurs in the output of
    ``candidate.generate_rules()``. Candidates are tried in the given order, so
    the result is the *lowest* common ancestor only if ``tree_nodes`` lists
    descendants before their ancestors.
    NOTE(review): that ordering is assumed of the caller; confirm.

    With an empty ``nodes`` the first candidate is returned.

    Raises:
        Exception: if no candidate dominates every element of ``nodes``.
    """
    for candidate in tree_nodes:
        dominated = list(candidate.generate_rules())
        if not any(target not in dominated for target in nodes):
            return candidate
    raise Exception(f"No common ancestor for {nodes}")


def log_mapping(log_func, tree_node, pred_edge, reason=""):
    """Send a one-line "Map <label> to <words> (<reason>)" message to ``log_func``.

    ``<words>`` is the space-joined ``.string`` of every word yielded by
    ``tree_node.generate_words()``; ``<label>`` is ``pred_edge.label``.
    """
    words = [word.string for word in tree_node.generate_words()]
    surface = " ".join(words)
    label = pred_edge.label
    log_func(f"Map {label} to {surface} ({reason})")


def _lemma_equals(candidate, lemma):
    """Return whether ``candidate`` matches ``lemma``.

    ``lemma`` is normally a single string. For labels in ``lexical_edges_extra``
    it is a list of acceptable strings, and membership is tested instead of
    equality.
    """
    if isinstance(lemma, (list, tuple, set, frozenset)):
        return candidate in lemma
    return candidate == lemma


def _leaf_matches_edge(node_string, pred_edge, lemma, postag, carg_type, wordnet_lemma):
    """Decide whether the surface word ``node_string`` realises ``pred_edge``.

    ``lemma``, ``postag`` and ``carg_type`` are the triple ``solve_compound``
    stores for the edge. Checks, in order:

    1. interval markers from ``special_lexical_edge``;
    2. lemma comparison, via ``wordnet_lemma.lemmatize`` when ``postag`` is
       "n", "v" or "a", otherwise on the lower-cased word;
    3. the ``lemmatizer_extra`` table;
    4. ``carg_compare_dict`` for CARG-bearing predicates.
    """
    special_marker = special_lexical_edge.get(pred_edge.label)
    if special_marker and node_string in special_marker:
        return True

    if postag in ("n", "v", "a"):
        # The lemmatised graph lemma does not depend on ``pos``; compute it once.
        graph_lemma = wordnet_lemma.lemmatize(lemma, postag)
        for pos in (postag, "n", "v", "a"):
            node_lemma = wordnet_lemma.lemmatize(node_string, pos).rstrip("-").lower()
            if node_lemma == lemma or node_lemma == graph_lemma:
                return True
        # On failure ``node_lemma`` keeps the value for the last POS tried ("a"),
        # and the fallback tables below are consulted with it.
    else:
        node_lemma = node_string.rstrip("-").lower()
        if _lemma_equals(node_lemma, lemma):
            return True

    if node_lemma in lemmatizer_extra:
        extra_mapping = lemmatizer_extra[node_lemma]
        if isinstance(extra_mapping, list):
            if any(_lemma_equals(i, lemma) for i in extra_mapping):
                return True
        elif _lemma_equals(extra_mapping, lemma):
            return True

    if carg_type is not None:
        compare_dict = carg_compare_dict.get(carg_type)
        if compare_dict and _lemma_equals(compare_dict.get(node_lemma), lemma):
            return True

    return False


def solve_compound(leaf_nodes, tree_nodes, pred_edges, structural_edges, ret, log_func=print):
    """Assign every predicate edge in ``pred_edges`` to a tree node, updating ``ret``.

    The work happens in three stages:

    * Lexical edges (interval labels, CARG-bearing edges, "_"-prefixed labels
      not in ``is_not_lexical_edge``, and ``lexical_edges_extra`` labels) are
      matched against the word under each leaf (``node.children[0].string``).
      A leaf may receive several edges. If exactly one leaf and one edge are
      left unmatched, they are paired as a last resort.
    * ``cate_arg_dict`` edges are attached to the tree node of the neighbour
      reached through a structural edge with the configured label.
    * ``compound_dict`` edges are attached to the ``lowest_common_ancestor``
      of the tree nodes holding their structural targets.

    Args:
        leaf_nodes: tree leaves to match lexical edges against.
        tree_nodes: candidate ancestors passed to ``lowest_common_ancestor``.
        pred_edges: predicate edges to place (uses ``.label``, ``.carg``, ``.nodes``).
        structural_edges: edges linking graph nodes; ``.nodes[0]`` is the source
            and ``.nodes[1]`` the target.
        ret: mapping tree node -> set of edges, updated in place via
            ``ret[node].add``; presumably a ``defaultdict(set)``. Verify at the call site.
        log_func: receives one message per lexical/compound mapping.

    Raises:
        Exception: for an unsupported label, or when some edges cannot be placed.
    """
    from coli.hrgguru.unlexicalized_rules import get_lemma_and_pos, wordnet_lemma

    # edge -> (lemma, postag, carg_type) used for matching against leaf words
    rest_lexical_edges = {}
    # (pred node, direction, structural label) -> edge waiting for a neighbour
    waiting_list = {}
    top_waiting = set()

    for edge in pred_edges:
        if edge.label in special_lexical_edge:
            rest_lexical_edges[edge] = (None, None, None)
        elif edge.carg is not None:
            rest_lexical_edges[edge] = (edge.carg.rstrip("-").lower(), None, edge.label)
        elif edge.label.startswith("_") and edge.label not in is_not_lexical_edge:
            lemma, postag = get_lemma_and_pos(edge.label)
            rest_lexical_edges[edge] = (lemma.rstrip("-").lower(), postag, None)
        elif edge.label in compound_dict:
            top_waiting.add(edge)
        elif edge.label in lexical_edges_extra:
            # the stored "lemma" is a list of acceptable words
            rest_lexical_edges[edge] = (lexical_edges_extra[edge.label], None, None)
        elif edge.label in cate_arg_dict:
            direction, which_label = cate_arg_dict[edge.label]
            waiting_list[edge.nodes[0], direction, which_label] = edge
        else:
            raise Exception(f"Don't know how to solve edge {edge}")

    # to which tree node this pred node is assigned
    edge_to_node = {}
    rest_leaf_nodes = set(leaf_nodes)

    def assign_edge_to_node(node, pred_edge, reason=""):
        ret[node].add(pred_edge)
        log_mapping(log_func, node, pred_edge, reason)
        rest_lexical_edges.pop(pred_edge)
        edge_to_node[pred_edge.nodes[0]] = node

    for node in leaf_nodes:
        if not rest_lexical_edges:
            break
        node_string = node.children[0].string
        # iterate over a snapshot: matched edges are removed from the dict
        for pred_edge, (lemma, postag, carg_type) in list(rest_lexical_edges.items()):
            if _leaf_matches_edge(node_string, pred_edge, lemma, postag, carg_type, wordnet_lemma):
                rest_leaf_nodes.discard(node)
                assign_edge_to_node(node, pred_edge)

    # last try: a single unmatched word and a single unmatched edge must belong together
    if len(rest_leaf_nodes) == 1 and len(rest_lexical_edges) == 1:
        (rest_leaf_node,) = rest_leaf_nodes
        (rest_lexical_edge,) = rest_lexical_edges
        assign_edge_to_node(rest_leaf_node, rest_lexical_edge, "rest")

    if rest_lexical_edges:
        raise Exception(f"Don't know how to solve these edges {list(rest_lexical_edges.keys())}")

    # Keep sweeping until neither waiting edges nor compound edges make progress;
    # attaching one edge can unblock another in a later sweep.
    can_process = True
    while can_process:
        can_process = False
        # attach to leaf
        if waiting_list:
            for edge in structural_edges:
                source, target = edge.nodes[0], edge.nodes[1]
                source_key = (source, 0, edge.label)
                target_key = (target, 1, edge.label)
                as_source = waiting_list.get(source_key)
                as_target = waiting_list.get(target_key)
                if as_source is not None:
                    if target in edge_to_node:
                        ret[edge_to_node[target]].add(as_source)
                        waiting_list.pop(source_key)
                        edge_to_node[source] = edge_to_node[target]
                        can_process = True
                elif as_target is not None:
                    if source in edge_to_node:
                        ret[edge_to_node[source]].add(as_target)
                        waiting_list.pop(target_key)
                        edge_to_node[target] = edge_to_node[source]
                        can_process = True

        if top_waiting:
            for top_edge in set(top_waiting):
                target_tree_nodes = [
                    edge_to_node.get(edge.nodes[1])
                    for edge in structural_edges
                    if edge.nodes[0] == top_edge.nodes[0]
                ]
                # NOTE(review): with no outgoing structural edges the list is empty and
                # lowest_common_ancestor returns the first tree node.
                if all(i is not None for i in target_tree_nodes):
                    lca = lowest_common_ancestor(tree_nodes, target_tree_nodes)
                    ret[lca].add(top_edge)
                    log_mapping(log_func, lca, top_edge)
                    top_waiting.remove(top_edge)
                    can_process = True

    if top_waiting:
        raise Exception(f"Don't know how to solve these edges {list(top_waiting)}")

    if waiting_list:
        raise Exception(f"Don't know how to solve these edges {list(waiting_list.values())}")


def solve_internal(internal_nodes_and_pred_edges, tree_nodes, structural_edges, ret, log_func):
    """Attach non-lexical predicate edges to tree nodes, updating ``ret`` in place.

    ``ret`` maps a tree node to the predicate edges already assigned to it.
    Those assignments seed a predicate-node -> tree-node table. For each
    ``(internal_nodes, pred_edges)`` pair (``internal_nodes`` is not used here):

    * ``compound_dict`` edges are placed first, on the
      ``lowest_common_ancestor`` of the tree nodes owning their structural
      targets;
    * ``cate_arg_dict`` edges are then placed on the tree node of the neighbour
      reached through a structural edge with the configured label.

    Raises:
        Exception: for an unsupported label, or when some edges cannot be placed.
    """
    # predicate node -> tree node it is currently attached to
    node_owner = {
        assigned.nodes[0]: tree_node
        for tree_node, assigned_edges in ret.items()
        for assigned in assigned_edges
    }

    def place_top_edges(pending):
        """Do one pass over the compound-like edges in ``pending``; True on progress."""
        progressed = False
        for top_edge in set(pending):
            anchors = [
                node_owner.get(structural.nodes[1])
                for structural in structural_edges
                if structural.nodes[0] == top_edge.nodes[0]
            ]
            if any(anchor is None for anchor in anchors):
                continue
            lca = lowest_common_ancestor(tree_nodes, anchors)
            ret[lca].add(top_edge)
            log_mapping(log_func, lca, top_edge)
            pending.remove(top_edge)
            progressed = True
        return progressed

    def place_waiting_edges(pending):
        """Do one pass over structural edges to resolve ``pending``; True on progress."""
        progressed = False
        for structural in structural_edges:
            head, tail = structural.nodes[0], structural.nodes[1]
            source_key = (head, 0, structural.label)
            target_key = (tail, 1, structural.label)
            source_waiter = pending.get(source_key)
            target_waiter = pending.get(target_key)
            if source_waiter:
                if tail in node_owner:
                    owner = node_owner[tail]
                    ret[owner].add(source_waiter)
                    pending.pop(source_key)
                    node_owner[head] = owner
                    progressed = True
            elif target_waiter:
                if head in node_owner:
                    owner = node_owner[head]
                    ret[owner].add(target_waiter)
                    pending.pop(target_key)
                    node_owner[tail] = owner
                    progressed = True
        return progressed

    for _, pred_edges in internal_nodes_and_pred_edges:
        waiting = {}
        tops = set()
        for pred_edge in pred_edges:
            label = pred_edge.label
            if label in compound_dict:
                tops.add(pred_edge)
                continue
            if label not in cate_arg_dict:
                raise Exception(f"Don't know how to solve edge {pred_edge}")
            direction, which_label = cate_arg_dict[label]
            waiting[pred_edge.nodes[0], direction, which_label] = pred_edge

        progress = True
        while tops and progress:
            progress = place_top_edges(tops)
        if tops:
            raise Exception(f"Don't know how to solve these edges {list(tops)}")

        # attach to leaf
        progress = True
        while waiting and progress:
            progress = place_waiting_edges(waiting)
        if waiting:
            raise Exception(f"Don't know how to solve these edges {list(waiting.values())}")
