from collections import defaultdict
from nltk.stem import WordNetLemmatizer

from coli.hrgguru.compound_split import is_not_lexical_edge
from coli.hrgguru.hrg import HRGRule, CFGRule

# Module-wide WordNet lemmatizer. Its lemmatize(word, pos) is used below to map
# surface words to lemmas for the POS tags "n", "v" and "a".
wordnet_lemma = WordNetLemmatizer()

from coli.hrgguru.const_tree import Lexicon
from coli.hrgguru.hyper_graph import HyperEdge, HyperGraph


def get_lemma_and_pos(edge_label, get_start_end=False):
    """Split a predicate label such as ``_dog_n_1`` into its lemma and POS.

    The lemma is the text between the first ``_`` and the next ``_`` (or the
    end of the label when no further ``_`` exists). If that span contains a
    ``/`` (e.g. ``_foo/NN_u_unknown``), the lemma stops at the last ``/``.
    The POS is the text between the end of the lemma and the following
    ``_``. It is ``None`` when no such ``_`` exists, so ``_the_q`` yields POS
    ``None``.

    :param edge_label: the edge label to split.
    :param get_start_end: when true, also return the slice bounds of the
        lemma inside ``edge_label``.
    :return: ``(lemma, pos)``, or ``(lemma, pos, lemma_start, lemma_end)``
        when ``get_start_end`` is true.
    """
    label = edge_label
    lemma_start = label.find("_") + 1
    lemma_end = label.find("_", lemma_start)
    if lemma_end == -1:
        # No underscore after the lemma (e.g. "_dog"): the lemma runs to the
        # end of the label. Leaving -1 here would make the slice below drop
        # the last character, and callers would re-append it after the lemma.
        lemma_end = len(label)
    lemma_end_slash = label.rfind("/", lemma_start, lemma_end)
    if lemma_end_slash != -1:
        # Unknown-word predicates carry a "/TAG" suffix on the lemma.
        lemma_end = lemma_end_slash
    old_lemma = label[lemma_start:lemma_end]

    tag_end = label.find("_", lemma_end + 1)
    if tag_end != -1:
        pos = label[lemma_end + 1:tag_end]
    else:
        pos = None

    if get_start_end:
        return old_lemma, pos, lemma_start, lemma_end
    return old_lemma, pos


def transform_edge_label(label, lexicon, check=True):
    """Replace the lemma inside a ``_``-prefixed label with ``{NEWLEMMA}``.

    Labels that do not start with ``_`` are returned unchanged. Otherwise the
    lemma found in the label is compared with the lemma expected from
    ``lexicon``. For POS ``n``/``v``/``a`` that is the WordNet lemma of
    ``lexicon``; for other POS it is ``lexicon`` verbatim. On a mismatch,
    ``ArithmeticError`` is raised if ``check`` is true.

    :param label: the edge label to anonymize.
    :param lexicon: the surface word the label was aligned to.
    :param check: whether to verify the lemma before anonymizing.
    :return: the label with its lemma span replaced by ``{NEWLEMMA}``.
    :raises ArithmeticError: on a lemma mismatch when ``check`` is true.
    """
    if not label.startswith("_"):
        return label
    lemma_in_label, pos, start, end = get_lemma_and_pos(label, True)
    if pos in ("n", "v", "a"):
        expected = wordnet_lemma.lemmatize(lexicon, pos)
        if check and lemma_in_label != expected:
            raise ArithmeticError("Unmatched lemma {} {}".format(lemma_in_label, expected))
    elif check and lexicon != lemma_in_label:
        raise ArithmeticError("{} {} {}".format(lexicon, lemma_in_label, label))
    return "".join((label[:start], "{NEWLEMMA}", label[end:]))


def transform_edge(edge, lexicon):
    """Return ``edge`` with the lemma in its label anonymized.

    Edges whose label does not start with ``_`` are returned as the very same
    object. Otherwise a new ``HyperEdge`` is built with the same nodes,
    terminal flag and span, and with its label passed through
    ``transform_edge_label``, which may raise ``ArithmeticError`` on a lemma
    mismatch.
    """
    if edge.label.startswith("_"):
        anonymized_label = transform_edge_label(edge.label, lexicon)
        return HyperEdge(edge.nodes, anonymized_label, edge.is_terminal, edge.span)
    return edge


def counter_factory():
    """Create an empty ``defaultdict`` whose missing entries default to ``0``.

    NOTE(review): this is presumably a named module-level function rather than
    a lambda so that nested defaultdicts built with it stay picklable; confirm
    against callers.
    """
    zero_default_counts = defaultdict(int)
    return zero_default_counts


# Abstract predicate labels (not "_"-prefixed) that anonymize_rule still treats
# as lexical edges. Edges carrying one of these labels keep their label as-is;
# no lemma is replaced.
ners = {"named", "named_n", "ord", "card", "yofc", "dofw", "dofm", "mofy", "pron",
        "year_range", "numbered_hour", "fraction", "meas_np", "holiday",
        }


def anonymize_rule(sync_rule: CFGRule):
    """Replace the lexical predicate of a lexical synchronous rule by a placeholder.

    ``sync_rule`` must have a single ``Lexicon`` right-hand-side item and a
    non-``None`` HRG part. One lexical edge of the HRG right-hand side is
    chosen. A lexical edge is one whose label starts with ``_`` or is in
    ``ners``, excluding labels in ``is_not_lexical_edge``. The choice works
    as follows:

    * no lexical edge, or exactly the labels ``{"_both_q", "card"}``: the rule
      is returned unchanged with no mapping;
    * several lexical edges including ``pron``: the first non-``pron`` edge is
      used (NOTE(review): "first" follows ``sub_graph.edges`` iteration order,
      which may not be deterministic if that is a set);
    * any other multi-edge case, including edges that are all ``pron``:
      ``ArithmeticError``.

    A ``_``-prefixed label has its lemma replaced by ``{NEWLEMMA}``. A
    ``ners`` label is kept verbatim. The subgraph is then re-standardized and
    wrapped in a new ``CFGRule`` whose lexicon is ``{NEWLEMMA}``.

    :return: ``(rule, lemma_mapping)``. ``lemma_mapping`` is ``None`` or
        ``((lookup_word, anonymized_label), original_lemma)``. ``lookup_word``
        is the WordNet lemma of the word for POS ``n``/``v``/``a`` and the raw
        word otherwise. For other POS, a mapping is only produced when the
        word differs from the label's lemma.
    :raises ArithmeticError: when no single lexical edge can be chosen.
    """
    assert isinstance(sync_rule.rhs[0][0], Lexicon) and sync_rule.hrg is not None
    assert len(sync_rule.rhs) == 1
    word = sync_rule.rhs[0][0].string
    origin_lhs = sync_rule.hrg.lhs
    sub_graph = sync_rule.hrg.rhs
    lexical_edges = [edge for edge in sub_graph.edges
                     if (edge.label.startswith("_") or edge.label in ners) and
                     edge.label not in is_not_lexical_edge]
    if len(lexical_edges) > 1:
        edge_names = {edge.label for edge in lexical_edges}
        if "pron" in edge_names:
            non_pron_edges = [edge for edge in lexical_edges if edge.label != "pron"]
            if not non_pron_edges:
                # Only "pron" edges: there is nothing to anonymize. Report it
                # like the other ambiguous cases (callers catch
                # ArithmeticError) instead of letting an IndexError escape.
                raise ArithmeticError(f"Too many lexical edges: {lexical_edges}")
            lexical_edge = non_pron_edges[0]
        elif edge_names == {"_both_q", "card"}:
            return sync_rule, None
        else:
            raise ArithmeticError(f"Too many lexical edges: {lexical_edges}")
    elif len(lexical_edges) == 0:
        return sync_rule, None
    else:
        lexical_edge = lexical_edges[0]

    # transform lexical edge into anonymous form
    label = lexical_edge.label
    lemma_mapping = None
    if label in ners:
        new_edge = lexical_edge
    else:
        old_lemma, pos, lemma_start, lemma_end = get_lemma_and_pos(label, True)
        transformed_label = label[:lemma_start] + "{NEWLEMMA}" + label[lemma_end:]
        if pos in ("n", "v", "a"):
            pred_lemma = wordnet_lemma.lemmatize(word, pos)
            lemma_mapping = ((pred_lemma, transformed_label), old_lemma)
        else:
            if word != old_lemma:
                lemma_mapping = ((word, transformed_label), old_lemma)
        new_edge = HyperEdge(lexical_edge.nodes, transformed_label,
                             lexical_edge.is_terminal, lexical_edge.span)

    # Swap the lexical edge for its anonymized copy, then rename nodes to the
    # standardized names and remap the lhs / new edge through node_map.
    new_subgraph = HyperGraph(sub_graph.nodes,
                              frozenset(new_edge
                                        if edge is lexical_edge else edge
                                        for edge in sub_graph.edges)
                              )
    standard_new_subgraph, node_map = new_subgraph.to_standardized_node_names(True)
    new_lhs = HyperEdge([node_map[i] for i in origin_lhs.nodes],
                        origin_lhs.label, origin_lhs.is_terminal,
                        origin_lhs.span)
    new_edge_standard = HyperEdge([node_map[i] for i in new_edge.nodes], new_edge.label,
                                  new_edge.is_terminal, new_edge.span)
    new_rule = HRGRule(new_lhs, standard_new_subgraph, comment=sync_rule.hrg.comment)
    return CFGRule(sync_rule.lhs, ((Lexicon("{NEWLEMMA}"), new_edge_standard),), new_rule), lemma_mapping


def anonymize_derivation(derivation):
    """Anonymize every lexical rule in ``derivation``.

    Rules whose first right-hand-side item is not a ``Lexicon`` are kept
    unchanged. A lexical rule without an HRG part is replaced by a
    ``CFGRule`` with the same lhs, a ``{NEWLEMMA}`` lexicon and no HRG. A
    lexical rule with an HRG part is passed through ``anonymize_rule``. If
    that raises ``ArithmeticError``, the error is printed and the original
    rule is kept.

    :return: ``(new_derivation, lemma_mappings)``. ``lemma_mappings`` collects
        every truthy mapping returned by ``anonymize_rule``, in order.
    """
    anonymized_rules = []
    mappings = []
    for rule in derivation:
        if not isinstance(rule.rhs[0][0], Lexicon):
            anonymized_rules.append(rule)
            continue
        if rule.hrg is None:
            anonymized_rules.append(CFGRule(rule.lhs,
                                            ((Lexicon("{NEWLEMMA}"), None),),
                                            None))
            continue
        mapping = None
        try:
            rule, mapping = anonymize_rule(rule)
        except ArithmeticError as err:
            # Best effort: report the problem and keep the rule as it was.
            print(err)
        anonymized_rules.append(rule)
        if mapping:
            mappings.append(mapping)
    return anonymized_rules, mappings


def recover_edge_label(label, word, lexicon_to_lemma):
    """Fill the ``{NEWLEMMA}`` placeholder of an anonymized label.

    The lookup key mirrors the mappings produced by ``anonymize_rule``. It is
    ``(WordNet lemma of word, label)`` for POS ``n``/``v``/``a`` and
    ``(word, label)`` otherwise. ``lexicon_to_lemma`` maps such keys to
    ``{lemma: score}`` mappings (presumably occurrence counts, TODO confirm).
    The lemma with the highest score wins; ties go to the first lemma in
    iteration order. If the key is missing or has no candidates, ``word``
    itself is used.

    :return: ``label`` with every ``{NEWLEMMA}`` replaced by the chosen lemma.
    """
    _, pos = get_lemma_and_pos(label)
    if pos in ("n", "v", "a"):
        key_word = wordnet_lemma.lemmatize(word, pos)
    else:
        key_word = word
    candidates = lexicon_to_lemma.get((key_word, label))
    if candidates:
        # max() returns the first maximal key, the same pick as a stable
        # descending sort, but without sorting and without an IndexError
        # when the candidate mapping is empty.
        lemma = max(candidates, key=candidates.get)
    else:
        lemma = word
    return label.replace("{NEWLEMMA}", lemma)
