import operator
from collections import defaultdict
from functools import reduce

from coli.hrgguru.const_tree import Lexicon

# Attachment tables, one per graph type, keyed by predicate label.
#
# "eds" / "dmrs" values are ``(direction, edge_label)``: with direction 0 the
# predicate's node is looked up as the source of a structural edge carrying
# ``edge_label``; with direction 1 it is looked up as the target of that edge.
# "lfrg" values are ``(outgoing_label, incoming_labels)`` and are derived
# entirely from the "eds" table by the loop below.
cate_arg_dict = {
    "eds": {
        "measure": (0, "ARG2"),
        "times": (0, "ARG3"),
        "plus": (0, "ARG2"),
        "loc_nonsp": (0, "ARG2"),
        "nominalization": (0, "ARG1"),
        "parg_d": (0, "ARG1"),
        "superl": (0, "ARG1"),
        "comp": (0, "ARG1"),
        "comp_equal": (0, "ARG1"),
        "comp_less": (0, "ARG1"),
        "comp_so": (0, "ARG1"),
        "time": (1, "ARG1"),
        "place": (1, "ARG1"),
        "udef_q": (0, "BV"),
        "proper_q": (0, "BV"),
        "number_q": (0, "BV"),
        "def_implicit_q": (0, "BV"),
        "def_explicit_q": (0, "BV"),
        "neg": (0, "ARG1"),
        "subord": (0, "ARG2"),
        "elliptical_n": (1, "ARG2"),
        "num_seq": (0, "R-INDEX"),
        "ellipsis": (1, "ARG2"),
        "idiom_q_i": (0, "BV"),
        "place_n": (1, "ARG1"),
        "time_n": (1, "ARG1"),
        "_pre-_a_ante": (0, "ARG1"),
        "_un-_a_rvrs": (0, "ARG1"),
        "pronoun_q": (0, "BV"),
    },
    # Hand-written DMRS entries; labels missing here are derived from "eds" below.
    "dmrs": {
        "parg_d": (0, "ARG1/EQ"),
        "comp": (0, "ARG1/EQ"),
        "superl": (0, "ARG1/EQ"),
        "comp_less": (0, "ARG1/EQ"),
        "comp_so": (0, "ARG1/EQ"),
        "comp_equal": (0, "ARG1/EQ"),
        "neg": (0, "ARG1/H"),
        "subord": (0, "ARG2/H"),
        "nominalization": (0, "ARG1/HEQ"),
        "plus": (0, "ARG2/HEQ"),
        "times": (0, "ARG3/HEQ"),
        "elliptical_n": (1, "ARG2/NEQ"),
        "time_n": (1, "ARG1/EQ"),
        "place_n": (1, "ARG1/EQ"),
        "_pre-_a_ante": (0, "ARG1/EQ"),
        "_un-_a_rvrs": (0, "ARG1/EQ"),
        "num_seq": (0, "R-INDEX/NEQ"),
    },
    # Filled completely by the loop below.
    "lfrg": {},
}

# Derive the missing DMRS entries and every LFRG entry from the EDS table.
# Explicit DMRS entries always win (setdefault never overwrites them).
for k, v in cate_arg_dict["eds"].items():
    if v[1] == "BV":
        # Quantifiers: bound variable in EDS, RSTR/H in DMRS, ARG0 in LFRG.
        cate_arg_dict["dmrs"].setdefault(k, (0, "RSTR/H"))
        cate_arg_dict["lfrg"][k] = ("ARG0", ("ARG0", "LBL"))
    elif v[0] == 0 and v[1].startswith("ARG"):
        # Outgoing ARGn edges: DMRS defaults to the /NEQ variant of the label.
        cate_arg_dict["dmrs"].setdefault(k, (0, f"{v[1]}/NEQ"))
        cate_arg_dict["lfrg"][k] = (v[1], ("ARG0", "LBL"))

# Extra surface-form -> lemma equivalences tried when the regular lemma
# comparison fails; a list value means any of the listed lemmas is accepted.
lemmatizer_extra = {
    "an": "a",
    "/": ["and", "per"],
    "auto": "automobile",
    "hoped": "hope",
    "rated": "rate",
}

# Labels of edges that are not tied to a single word; NodeDistributor defers
# them and places them on a common ancestor of the nodes they connect.
compound_dict = {
    "compound",
    "unknown",
    "appos",
    "part_of",
    "implicit_conj",
    "generic_entity",
}

# Attachment entries (same ``(direction, edge_label)`` shape as the "eds"
# attachment table) merged into the solver's table in fully-lexicalized mode.
# Only the "eds" graph type is defined here.
compound_arg_dict = {
    "eds": {
        "compound": (0, "ARG2"),
        "unknown": (0, "ARG"),
        "appos": (0, "ARG1"),
        "part_of": (0, "ARG1"),
        "implicit_conj": (0, "L-INDEX"),
        "generic_entity": (1, "ARG1"),
        "focus_d": (0, "ARG1"),
        "with_p": (0, "ARG2"),
        "id": (0, "ARG2"),
        "relative_mod": (0, "ARG2"),
        "parenthetical": (0, "ARG1"),
        "eventuality": (0, "ARG1"),
        "pron": (1, "ARG1"),
    },
}


# for k, v in compound_arg_dict["eds"].items():
#     if v[0] == 0 and v[1].startswith("ARG"):
#         compound_arg_dict["dmrs"][k] = (0, f"{v[1]}/NEQ")

# Surface word -> normalized CARG value, grouped by the label of the edge that
# carries the CARG.  Consulted as a last resort when a leaf's (lower-cased)
# string does not literally equal the edge's lower-cased CARG.
carg_compare_dict = {
    # Ordinals.  "nineteenth" belongs here (it was previously swapped with the
    # cardinal "nineteen", so neither form could ever match).
    "ord": {
        'first': '1', 'second': '2', 'third': '3', 'fourth': '4', 'fifth': '5', 'sixth': '6',
        'seventh': '7', 'eighth': '8', 'ninth': '9', 'tenth': '10', 'eleventh': '11', 'twelfth': '12',
        'thirteenth': '13', 'fourteenth': '14', 'fifteenth': '15', 'sixteenth': '16',
        'seventeenth': '17', 'eighteenth': '18', 'nineteenth': '19', 'twentieth': '20',
        'thirtieth': '30', 'fortieth': '40', 'fiftieth': '50', 'sixtieth': '60', 'seventieth': '70',
        'eightieth': '80', 'ninetieth': '90', 'hundredth': '100', 'thousandth': '1000',
        'millionth': '1000000', 'billionth': '1000000000', 'trillionth': '1000000000000'},
    # Cardinals ("a" counts as one).
    "card": {
        'a': '1', 'one': '1', 'two': '2', 'three': '3', 'four': '4', 'five': '5', 'six': '6', 'seven': '7',
        'eight': '8',
        'nine': '9', 'ten': '10', 'eleven': '11', 'twelve': '12', 'thirteen': '13', 'fourteen': '14',
        'fifteen': '15', 'sixteen': '16', 'seventeen': '17', 'eighteen': '18', 'nineteen': '19', 'twenty': '20',
        'thirty': '30', 'forty': '40', 'fifty': '50', 'sixty': '60', 'seventy': '70', 'eighty': '80',
        'ninety': '90', 'hundred': '100', 'thousand': '1000', 'million': '1000000', 'billion': '1000000000',
        'trillion': '1000000000000'},
    # Month names -> three-letter abbreviations.
    "mofy": {
        "january": "jan", "february": "feb", "march": "mar", "april": "apr", "may": "may",
        "june": "jun", "july": "jul", "august": "aug", "september": "sep", "october": "oct",
        "november": "nov", "december": "dec"
    },
    # NOTE(review): NodeDistributor.solve_leaf lower-cases the leaf string before
    # looking it up here, so the mixed-case key "_IMF" looks unreachable --
    # confirm whether "_imf" was intended.
    "named_n": {"_IMF": "imf", "u.s.": "us"}
}

# Non-underscore predicate labels that are nevertheless realized by a word:
# label -> list of acceptable surface forms.
lexical_edges_extra = {"much-many_a": ["more"]}

# Underscore-prefixed labels that must NOT be treated as lexical edges.
# (The original literal listed "_counter-_a_anti" twice; a set keeps a single
# copy, so dropping the repeat changes nothing.
# NOTE(review): confirm the repeat was not meant to be a different prefix label.)
is_not_lexical_edge = {
    "_pre-_a_ante",
    "_un-_a_rvrs",
    "_re-_a_again",
    "_mis-_a_error",
    "_counter-_a_anti",
    "_co-_a_with",
}

# Words that can realize an interval; the three interval labels below share
# this one list object.
interval_marker = ["to", "–"]
special_lexical_edge = {"interval": interval_marker, "interval_p_start": interval_marker,
                        "interval_p_end": interval_marker}


def lowest_common_ancestor(tree_nodes, nodes):
    """Return the first entry of ``tree_nodes`` whose rules contain all ``nodes``.

    Candidates are tried in the order given, so the result is only the
    *lowest* ancestor if the caller lists lower nodes before higher ones.

    :param tree_nodes: candidate ancestors, each providing ``generate_rules()``
    :param nodes: nodes that must all appear among a candidate's rules
    :raises Exception: if no candidate covers every node
    """
    for candidate in tree_nodes:
        covered = list(candidate.generate_rules())
        for wanted in nodes:
            if wanted not in covered:
                # this candidate misses a node; try the next one
                break
        else:
            return candidate
    raise Exception(f"No common ancestor for {nodes}")


class EDSAttachmentSolver(object):
    """Turn a predicate edge into waiting-list keys using a label table.

    ``arg_dict`` maps a predicate label to ``(direction, label_or_labels)``.
    """

    def __init__(self, arg_dict):
        self.arg_dict = arg_dict

    def solve(self, pred_edge, structual_edges=None):
        """Return one ``(node, direction, label)`` key per configured label.

        :param pred_edge: edge whose ``label`` selects the table entry and
            whose first node becomes the node of every key
        :param structual_edges: accepted but not read by this solver
        :raises KeyError: if the edge's label has no table entry
        """
        direction, labels = self.arg_dict[pred_edge.label]
        if isinstance(labels, str):
            # a single label is shorthand for a one-element list
            labels = [labels]

        result = []
        for label in labels:
            result.append((pred_edge.nodes[0], direction, label))
        return result


class LFRGAttachmentSolver(object):
    """Turn a predicate edge into waiting-list keys by following a structural edge.

    ``arg_dict`` maps a predicate label to ``(outgoing_label, incoming_labels)``.
    """

    def __init__(self, arg_dict):
        self.arg_dict = arg_dict

    def solve(self, pred_edge, structual_edges):
        """Return ``(node, 1, label)`` keys for every configured incoming label.

        The node is the target of the predicate's single ``outgoing_label``
        edge; if that target has exactly one outgoing "QEQ" edge, the QEQ
        target is used instead.

        :param pred_edge: predicate edge to attach
        :param structual_edges: iterable of edges; it is traversed twice
        :raises Exception: unless exactly one matching outgoing edge exists
        """
        out_label, in_labels = self.arg_dict[pred_edge.label]

        outgoing = []
        for edge in structual_edges:
            if edge.nodes[0] == pred_edge.nodes[0] and edge.label == out_label:
                outgoing.append(edge)
        if len(outgoing) != 1:
            raise Exception(f"{out_label} count of {pred_edge} != 1")

        anchor = outgoing[0].nodes[1]
        qeq_targets = []
        for edge in structual_edges:
            if edge.nodes[0] == anchor and edge.label == "QEQ":
                qeq_targets.append(edge.nodes[1])
        if len(qeq_targets) == 1:
            # follow an unambiguous QEQ link
            anchor = qeq_targets[0]

        return [(anchor, 1, label) for label in in_labels]


# Attachment-solver class per graph type; DMRS reuses the EDS solver.
attachment_solvers = {
    "eds": EDSAttachmentSolver,
    "dmrs": EDSAttachmentSolver,
    "lfrg": LFRGAttachmentSolver,
}


class NodeDistributor(object):
    """Distribute the span-bearing edges of a graph over the nodes of a tree.

    ``solve()`` fills and returns ``results``, a mapping from tree node to the
    set of edges assigned to it.

    Working state:

    * ``edge_to_node`` -- graph node (first node of an assigned edge) -> the
      tree node that edge was assigned to.
    * ``attachment_waiting_list`` -- ``(graph node, direction, label)`` key
      produced by the attachment solver -> edge still waiting to be attached.
    * ``internal_waiting_list`` -- groups of edges that ``solve_internal``
      still has to place.
    """

    def __init__(self, hg, cfg_root, graph_type,
                 fully_lexicalized=False,
                 log_func=print):
        """
        :param hg: graph whose ``edges`` are to be distributed
        :param cfg_root: root of the constituency tree
        :param graph_type: key into ``cate_arg_dict`` / ``attachment_solvers``
        :param fully_lexicalized: if true, the compound-like labels are handled
            through the attachment table (``compound_arg_dict``) instead of
            being deferred to a common ancestor
        :param log_func: callable receiving one log line per mapping
        :raises KeyError: for an unknown ``graph_type``, or when
            ``fully_lexicalized`` is set and ``compound_arg_dict`` has no entry
            for ``graph_type``
        """
        self.hg = hg
        self.cfg_root = cfg_root
        self.log_func = log_func
        self.results = defaultdict(set)
        self.edge_to_node = {}
        self.attachment_waiting_list = {}
        self.internal_waiting_list = []
        self.graph_type = graph_type
        self.fully_lexicalized = fully_lexicalized
        # Work on a private copy: the update() below must not leak the compound
        # entries into the module-level table shared by every other instance.
        self.arg_dict = dict(cate_arg_dict[graph_type])

        if fully_lexicalized:
            self.arg_dict.update(compound_arg_dict[graph_type])
            self.compound_dict = {}
        else:
            self.compound_dict = compound_dict

        self.attachment_solver = attachment_solvers[graph_type](self.arg_dict)

    def log_mapping(self, tree_node, pred_edge, reason=""):
        """Report via ``log_func`` that ``pred_edge`` was mapped to ``tree_node``."""
        surface_string = " ".join(i.string for i in tree_node.generate_words())
        edge_string = pred_edge.label
        self.log_func(f"Map {edge_string} to {surface_string} ({reason})")

    def solve(self):
        """Assign every span-bearing edge of ``hg`` to a tree node.

        :return: ``self.results`` (tree node -> set of edges)
        :raises AssertionError: if an edge is not terminal, or (when not
            fully lexicalized) an edge span matches no tree node
        :raises Exception: if some edge cannot be placed
        """
        # sometimes the span is incorrect
        span_rewrite = {}
        span_to_pred_edges = defaultdict(list)
        span_to_tree_nodes = defaultdict(list)
        structural_edges = set()

        # 1. Group preterminals by span; a preterminal overlapping an already
        #    registered span is folded into that span.
        for node in self.cfg_root.generate_preterminals():
            if node.span[1] - node.span[0] == 0:
                continue
            rewrote_span = span_rewrite.get(node.span) or node.span
            for another_span in span_to_tree_nodes:
                # NOTE(review): the bounds are compared asymmetrically (>= vs <),
                # so a span ending exactly where another starts counts as
                # overlapping -- confirm this is intended.
                if not (node.span[0] >= another_span[1] or node.span[1] < another_span[0]):
                    rewrote_span = another_span
                    span_rewrite[node.span] = another_span
                    break
            span_to_tree_nodes[rewrote_span].append(node)

        # 2. Internal rules: a rule whose non-empty children all map to one
        #    span shares that span.
        for node in self.cfg_root.generate_rules():
            # skip preterminals
            if isinstance(node.children[0], Lexicon):
                continue
            if node.span in span_rewrite:
                rewrote_span = span_rewrite[node.span]
            else:
                children_spans = set(span_rewrite.get(i.span) or i.span
                                     for i in node.children
                                     if i.span[1] - i.span[0] != 0)
                if len(children_spans) == 1:
                    rewrote_span = list(children_spans)[0]
                    span_rewrite[node.span] = rewrote_span
                else:
                    rewrote_span = node.span
            span_to_tree_nodes[rewrote_span].append(node)

        # 3. Edges with a span are predicate edges; the rest are structural.
        for edge in self.hg.edges:
            assert edge.is_terminal
            if edge.span is not None:
                rewrote_span = span_rewrite.get(edge.span) or edge.span
                span_to_pred_edges[rewrote_span].append(edge)
            else:
                structural_edges.add(edge)

        # 4. Predicate spans without any tree node can only be handled by
        #    attachment (fully lexicalized mode).
        redundant_spans = set(span_to_pred_edges.keys()) - set(span_to_tree_nodes.keys())
        if self.fully_lexicalized:
            for span in redundant_spans:
                for edge in span_to_pred_edges[span]:
                    keys = self.attachment_solver.solve(edge, structural_edges)
                    for key in keys:
                        self.attachment_waiting_list[key] = edge
        else:
            assert not redundant_spans, f"Redundant nodes in graph: {redundant_spans}"

        many_to_many = []

        # 5. Dispatch each span according to the tree nodes that share it.
        for span, tree_nodes in span_to_tree_nodes.items():
            pred_edges = span_to_pred_edges[span]
            if not pred_edges:
                continue

            terminal_tree_nodes = []
            internal_tree_nodes = []

            for node in tree_nodes:
                if isinstance(node.children[0], Lexicon):
                    terminal_tree_nodes.append(node)
                else:
                    internal_tree_nodes.append(node)

            if len(terminal_tree_nodes) == 1:
                # exactly one word: it takes all edges of the span
                tree_node = terminal_tree_nodes[0]
                self.results[tree_node].update(pred_edges)
                for edge in pred_edges:
                    self.edge_to_node[edge.nodes[0]] = tree_node
            elif len(terminal_tree_nodes) > 1:
                # several words share the span: resolved later by solve_leaf
                many_to_many.append((span, terminal_tree_nodes))
            elif len(internal_tree_nodes) >= 1 and self.fully_lexicalized:
                for edge in pred_edges:
                    keys = self.attachment_solver.solve(edge, structural_edges)
                    for key in keys:
                        self.attachment_waiting_list[key] = edge
            elif len(internal_tree_nodes) == 1:
                tree_node = internal_tree_nodes[0]
                self.results[tree_node].update(pred_edges)
                for edge in pred_edges:
                    self.edge_to_node[edge.nodes[0]] = tree_node
            elif len(terminal_tree_nodes) == 0 and len(internal_tree_nodes) > 1:
                self.internal_waiting_list.append(pred_edges)
            else:
                raise Exception(f"Invalid tree nodes: {span} {terminal_tree_nodes} {internal_tree_nodes} {pred_edges}")

        all_nodes = [i for i in self.cfg_root.generate_rules()
                     if not isinstance(i.children[0], Lexicon)]

        # 6. Resolve the deferred work.
        if many_to_many:
            for span, terminal_tree_nodes in many_to_many:
                pred_edges = span_to_pred_edges[span]
                self.solve_leaf(terminal_tree_nodes, pred_edges, structural_edges)

        if self.fully_lexicalized:
            self.solve_attachment(structural_edges)
        else:
            if self.internal_waiting_list:
                self.solve_internal(self.internal_waiting_list, all_nodes, structural_edges)

        # 7. An edge may sit under several keys; drop every key whose edge has
        #    been placed, then fail if anything is still waiting.
        flat_results = reduce(operator.or_, self.results.values(), set())
        for key, value in dict(self.attachment_waiting_list).items():
            if value in flat_results:
                self.attachment_waiting_list.pop(key)

        if self.attachment_waiting_list:
            raise Exception(f"Don't know how to solve these edges {list(self.attachment_waiting_list.values())}")

        return self.results

    def solve_leaf(self, leaf_nodes, pred_edges, structural_edges):
        """Match the edges of one span against the several words sharing it.

        Lexical edges are matched to a leaf by comparing lemmas / CARGs with
        the leaf's string; compound-like edges are queued in
        ``internal_waiting_list``; edges listed in the attachment table are
        queued in ``attachment_waiting_list``.  Finishes by running
        ``solve_attachment``.

        :raises Exception: for an edge of unknown kind, or if lexical edges
            remain unmatched
        """
        from coli.hrgguru.unlexicalized_rules import get_lemma_and_pos, wordnet_lemma

        # edge -> (lemma, postag, carg_type) for every edge still to be matched
        rest_lexical_edges = {}
        top_waiting = set()

        for edge in pred_edges:
            if edge.label in special_lexical_edge:
                rest_lexical_edges[edge] = (None, None, None)
            elif edge.carg is not None:
                rest_lexical_edges[edge] = (edge.carg.rstrip("-").lower(), None, edge.label)
            elif edge.label.startswith("_") and edge.label not in is_not_lexical_edge:
                lemma, postag = get_lemma_and_pos(edge.label)
                rest_lexical_edges[edge] = (lemma.rstrip("-").lower(), postag, None)
            elif edge.label in self.compound_dict:
                top_waiting.add(edge)
            elif edge.label in lexical_edges_extra:
                # here the "lemma" slot holds a LIST of acceptable surface forms
                rest_lexical_edges[edge] = (lexical_edges_extra[edge.label], None, None)
            elif edge.label in self.arg_dict:
                keys = self.attachment_solver.solve(edge, structural_edges)
                for key in keys:
                    self.attachment_waiting_list[key] = edge
            else:
                raise Exception(f"Don't know how to solve edge {edge}")

        self.internal_waiting_list.append(top_waiting)

        # to which tree node this pred node is assigned
        rest_leaf_nodes = set(leaf_nodes)

        def assign_edge_to_node(node, pred_edge, reason=""):
            self.results[node].add(pred_edge)
            self.log_mapping(node, pred_edge, reason)
            rest_lexical_edges.pop(pred_edge)
            self.edge_to_node[pred_edge.nodes[0]] = node

        for node in leaf_nodes:
            for pred_edge, (lemma, postag, carg_type) in list(rest_lexical_edges.items()):
                node_string = node.children[0].string
                eq = False
                special_marker = special_lexical_edge.get(pred_edge.label)
                if special_marker and node_string in special_marker:
                    eq = True

                if not eq:
                    if postag in ("n", "v", "a"):
                        # try the edge's own part of speech first, then the others
                        for pos in (postag, "n", "v", "a"):
                            node_lemma = wordnet_lemma.lemmatize(node_string, pos)
                            node_lemma = node_lemma.rstrip("-").lower()
                            eq = (node_lemma == lemma)
                            if not eq:
                                graph_lemma_test = wordnet_lemma.lemmatize(lemma, postag)
                                eq = (node_lemma == graph_lemma_test)
                            if eq:
                                break
                    else:
                        node_lemma = node_string.rstrip("-").lower()
                        if isinstance(lemma, list):
                            # lexical_edges_extra supplies a list of surface
                            # forms; "==" against a string could never succeed
                            eq = node_lemma in lemma
                        else:
                            eq = (node_lemma == lemma)

                # noinspection PyUnboundLocalVariable
                if not eq and node_lemma in lemmatizer_extra:
                    extra_mapping = lemmatizer_extra[node_lemma]
                    if isinstance(extra_mapping, list):
                        eq = any(i == lemma for i in extra_mapping)
                    else:
                        eq = (extra_mapping == lemma)
                if not eq and carg_type is not None:
                    compare_dict = carg_compare_dict.get(carg_type)
                    if compare_dict:
                        eq = (compare_dict.get(node_lemma) == lemma)
                if eq:
                    rest_leaf_nodes.discard(node)
                    assign_edge_to_node(node, pred_edge)
                    if pred_edge.label not in special_lexical_edge:
                        # one lexical edge only
                        break

        # last try
        if len(rest_leaf_nodes) == 1 and len(rest_lexical_edges) == 1:
            rest_leaf_node = list(rest_leaf_nodes)[0]
            rest_lexical_edge = list(rest_lexical_edges)[0]
            self.results[rest_leaf_node].add(rest_lexical_edge)
            self.log_mapping(rest_leaf_node, rest_lexical_edge, "rest")
            rest_lexical_edges.pop(rest_lexical_edge)
            self.edge_to_node[rest_lexical_edge.nodes[0]] = rest_leaf_node

        if rest_lexical_edges:
            raise Exception(f"Don't know how to solve these edges {list(rest_lexical_edges.keys())}")

        self.solve_attachment(structural_edges)

    def solve_attachment(self, structural_edges):
        """Attach waiting edges next to already placed neighbours.

        Repeatedly scans ``structural_edges``: a waiting edge registered under
        the source (direction 0) or target (direction 1) of a structural edge
        is assigned to the tree node of that edge's other end, once that end
        has a tree node.  Stops when a full pass attaches nothing.
        """
        can_process = True
        while can_process:
            can_process = False
            # attach to leaf
            if self.attachment_waiting_list:
                for edge in structural_edges:
                    as_source = self.attachment_waiting_list.get((edge.nodes[0], 0, edge.label))
                    as_target = self.attachment_waiting_list.get((edge.nodes[1], 1, edge.label))
                    if as_source:
                        if edge.nodes[1] in self.edge_to_node:
                            tree_node = self.edge_to_node[edge.nodes[1]]
                            self.log_mapping(tree_node, as_source)
                            self.results[tree_node].add(as_source)
                            self.attachment_waiting_list.pop((edge.nodes[0], 0, edge.label))
                            # newly placed node may unlock further attachments
                            self.edge_to_node[edge.nodes[0]] = tree_node
                            can_process = True
                    elif as_target:
                        if edge.nodes[0] in self.edge_to_node:
                            tree_node = self.edge_to_node[edge.nodes[0]]
                            self.log_mapping(tree_node, as_target)
                            self.results[tree_node].add(as_target)
                            self.attachment_waiting_list.pop((edge.nodes[1], 1, edge.label))
                            self.edge_to_node[edge.nodes[1]] = tree_node
                            can_process = True

    def solve_internal(self, pred_edges_list, tree_nodes, structural_edges):
        """Place edges that belong to internal (non-word) tree nodes.

        Compound-like edges go to the first covering tree node that contains
        the tree nodes of all their structural neighbours; edges in the
        attachment table are queued for ``solve_attachment``, which runs at
        the end.

        :param pred_edges_list: groups of edges to place
        :param tree_nodes: candidate internal tree nodes
        :raises Exception: for an edge of unknown kind, or if a compound-like
            edge cannot be placed
        """
        for pred_edges in pred_edges_list:
            top_waiting = set()
            for edge in pred_edges:
                if edge.label in self.compound_dict:
                    # solve "compound"
                    top_waiting.add(edge)
                elif edge.label in self.arg_dict:
                    # solve attachment
                    keys = self.attachment_solver.solve(edge, structural_edges)
                    for key in keys:
                        self.attachment_waiting_list[key] = edge
                else:
                    raise Exception(f"Don't know how to solve edge {edge}")

            can_process = True
            while top_waiting and can_process:
                can_process = False
                # iterate over a copy: placed edges are removed from top_waiting
                for top_edge in set(top_waiting):
                    target_tree_nodes = []
                    for edge in structural_edges:
                        if edge.nodes[0] == top_edge.nodes[0]:
                            if self.graph_type != "lfrg":
                                target = self.edge_to_node.get(edge.nodes[1])
                            else:
                                if not edge.label.startswith("ARG0"):
                                    continue
                                var_node = edge.nodes[1]
                                target = None
                                for s_edge in structural_edges:
                                    if s_edge.label == "ARG0" and s_edge.nodes[1] == var_node:
                                        # NOTE(review): this looks up edge.nodes[0]
                                        # (the top edge's own node), not
                                        # s_edge.nodes[0] -- confirm which is meant.
                                        target = self.edge_to_node.get(edge.nodes[0])
                                if target is None:
                                    continue
                            target_tree_nodes.append(target)
                    # place the edge only once every neighbour has a tree node
                    if all(i is not None for i in target_tree_nodes):
                        coverage_nodes = [i for i in tree_nodes
                                          if i.span[0] <= top_edge.span[0] <= top_edge.span[1] <= i.span[1]]
                        lca = lowest_common_ancestor(coverage_nodes, target_tree_nodes)
                        self.results[lca].add(top_edge)
                        self.edge_to_node[top_edge.nodes[0]] = lca
                        self.log_mapping(lca, top_edge)
                        top_waiting.remove(top_edge)
                        can_process = True

            if top_waiting:
                raise Exception(f"Don't know how to solve these edges {list(top_waiting)}")

        self.solve_attachment(structural_edges)
