import os
import re
from collections import namedtuple
from itertools import zip_longest
from random import Random

from coli.hrgguru.const_tree import Lexicon, ConstTree

# Integer identifiers for the available tree-stripping modes. The code that
# dispatches on them is not visible in this section of the file; the names
# presumably line up with the strip_* / fuzzy_cfg helpers defined below
# (NOTE(review): confirm against the caller). The value 4 is not assigned.
DONT_STRIP = 0
STRIP_ALL_LABELS = 1
STRIP_TO_UNLABEL = 2
FUZZY_TREE = 3
STRIP_INTERNAL_LABELS = 5
STRIP_TO_HEADS = 6


def span_overlap(a, b):
    """Return True when span ``a`` lies inside span ``b`` and differs from it.

    Despite the name, this is a containment test rather than a generic
    overlap test: ``a`` must start no earlier and end no later than ``b``,
    and the two spans must not be equal.
    """
    if a == b:
        return False
    starts_inside = a[0] >= b[0]
    return starts_inside and a[1] <= b[1]


def merge_punct_hyphen(tree, hg):
    """Collapse hyphenated / punctuation-attached words under ``tree`` in place.

    Three cases are handled, tried in this order:

    1. Every word below ``tree`` either carries the same span as ``tree`` or
       has a zero-width span, and at least two words carry the full span:
       all words are concatenated into one ``Lexicon`` that becomes the only
       child of ``tree``.
    2. ``tree`` has exactly four preterminals of the shape
       ``A``, ``-``, ``B``, ``punct*`` where ``A`` and ``B`` start at the same
       position and ``A`` ends one position before ``B``: the first three
       words are merged into one lexicon covering the first half of the
       span, the punctuation keeps the second half, and hypergraph edges
       that referenced the span of ``A`` or ``B`` are re-pointed to the span
       of ``tree``.
    3. Otherwise, if the children are ``ConstTree`` nodes, recurse into them.

    :param tree: ``ConstTree`` node to rewrite (mutated in place).
    :param hg: object exposing ``edges``, each edge having a ``span``
        attribute (mutated in place in case 2).
    """
    span = tree.span
    words = list(tree.generate_words())
    preterminals = list(tree.generate_preterminals())
    if all(i.span == span or i.span[1] - i.span[0] == 0
           for i in words) and sum(1 for i in words if i.span == span) >= 2:
        lexicon = Lexicon("".join(i.string for i in words))
        lexicon.span = span
        tree.child = [lexicon]
    elif len(preterminals) == 4 and words[1].string == "-" and \
            preterminals[3].tag.startswith("punct") and \
            words[0].span[0] == words[2].span[0] and \
            words[0].span[1] == words[2].span[1] - 1:
        # The split point must be an absolute position: half of the width
        # added to the start. Using the bare half-width is only correct for
        # a subtree that starts at position 0 and yields inverted or
        # overlapping spans everywhere else.
        middle = span[0] + (span[1] - span[0]) // 2
        span_1 = (span[0], middle)
        span_2 = (middle, span[1])
        left_tree = ConstTree(tree.tag)
        left_lexicon = Lexicon("".join(i.string for i in words[:3]))
        left_lexicon.span = span_1
        left_tree.span = span_1
        left_tree.child = [left_lexicon]
        right_tree = ConstTree(preterminals[3].tag)
        right_lexicon = Lexicon(words[3].string)
        right_tree.child = [right_lexicon]
        right_tree.span = span_2
        right_lexicon.span = span_2
        tree.child = [left_tree, right_tree]
        for edge in hg.edges:
            if edge.span == words[0].span or edge.span == words[2].span:
                edge.span = span
    elif isinstance(tree.child[0], ConstTree):
        for i in tree.child:
            assert isinstance(i, ConstTree)
            merge_punct_hyphen(i, hg)


def fix_span_error(tree, hg):
    """Repair inconsistent spans in ``A - B <punct>`` subtrees, in place.

    A subtree qualifies when it has exactly four preterminals, the second
    word is ``"-"``, the fourth preterminal's tag starts with ``"punct"``,
    and the first and third words start at the same position while the
    first ends one position earlier than the third. For such a subtree the
    span of the third word is treated as the correct one and is assigned to

    * every node yielded by ``tree.generate_rules()``,
    * every word below ``tree``,
    * every hypergraph edge whose span equalled the first word's old span.

    Subtrees that do not qualify are searched recursively as long as their
    children are ``ConstTree`` nodes.

    :param tree: ``ConstTree`` node to repair (mutated in place).
    :param hg: object exposing ``edges``, each edge having a ``span``
        attribute (mutated in place).
    """
    words = list(tree.generate_words())
    preterminals = list(tree.generate_preterminals())
    if len(preterminals) == 4 and words[1].string == "-" and \
            preterminals[3].tag.startswith("punct") and \
            words[0].span[0] == words[2].span[0] and \
            words[0].span[1] == words[2].span[1] - 1:
        # Snapshot both spans before mutating anything: the edge loop below
        # has to compare against the first word's *original* span.
        wrong_span = words[0].span
        right_span = words[2].span
        for rule in tree.generate_rules():
            rule.span = right_span
        # The words themselves must be fixed as well; previously this loop
        # walked generate_rules() a second time and never touched them.
        for word in words:
            word.span = right_span
        for edge in hg.edges:
            if edge.span == wrong_span:
                edge.span = right_span
    elif isinstance(tree.child[0], ConstTree):
        for i in tree.child:
            assert isinstance(i, ConstTree)
            fix_span_error(i, hg)


def split_punct_hyphen(tree, hg):
    """Unfinished: always raises ``NotImplementedError``.

    Everything after the ``raise`` below is an unreachable draft. As
    drafted, it assigns separate spans to the two halves of a hyphenated
    word (``A - B`` with an optional trailing ``punct*`` preterminal), sized
    by ``len(words[0].string)``, and re-points matching hypergraph edges.

    :param tree: ``ConstTree`` node (would be mutated in place).
    :param hg: object exposing ``edges`` with ``span`` and ``label``
        attributes (would be mutated in place).
    :raises NotImplementedError: unconditionally.
    """
    # TODO: finish split_punct_hyphen
    raise NotImplementedError
    # ---- unreachable draft below this line ----
    # NOTE(review): the draft reads ``tree.children`` while every other
    # function in this file uses ``tree.child`` -- confirm the attribute
    # name before enabling this code.
    words = list(tree.generate_words())
    preterminals = list(tree.generate_preterminals())
    if len(words) == 3 and words[1].string == "-" and words[1].span[1] - words[1].span[0] == 0:
        # TODO: attach nodes
        new_left_span_size = len(words[0].string)
        new_left_span = (words[0].span[0], words[0].span[0] + new_left_span_size)
        new_right_span = (new_left_span[1], words[2].span[1])
        tree.children[0].span = new_left_span
        tree.children[1].span = new_right_span
        preterminals[0].span = new_left_span
        # The hyphen becomes a zero-width span at the split point.
        preterminals[1].span = (new_left_span[1], new_left_span[1])
        preterminals[2].span = new_right_span
        words[0].span = new_left_span
        words[1].span = (new_left_span[1], new_left_span[1])
        words[2].span = new_right_span
    elif len(preterminals) == 4 and words[1].string == "-" and \
            preterminals[3].tag.startswith("punct") and \
            words[0].span[0] == words[2].span[0] and \
            words[0].span[1] == words[2].span[1] - 1:
        new_left_span_size = len(words[0].string)
        new_left_span = (words[0].span[0], words[0].span[0] + new_left_span_size)
        new_right_span = (new_left_span[1], words[3].span[1])
        tree.children[0].span = new_left_span
        tree.children[1].span = new_right_span
        preterminals[0].span = new_left_span
        preterminals[1].span = (new_left_span[1], new_left_span[1])
        preterminals[2].span = new_right_span
        # The trailing punctuation becomes a zero-width span at the end.
        preterminals[3].span = (new_right_span[1], new_right_span[1])
        words[0].span = new_left_span
        words[1].span = (new_left_span[1], new_left_span[1])
        words[2].span = new_right_span
        words[3].span = (new_right_span[1], new_right_span[1])

        # NOTE(review): words[0].span and words[2].span were already
        # overwritten above, so these comparisons run against the *new*
        # spans, not the original ones -- likely not what was intended.
        for edge in hg.edges:
            if edge.span == words[0].span:
                edge.span = new_left_span
            if edge.span == words[2].span and edge.label != "compound":
                edge.span = new_right_span

    elif isinstance(tree.child[0], ConstTree):
        for i in tree.child:
            assert isinstance(i, ConstTree)
            split_punct_hyphen(i, hg)


# Dispatch table from a strategy name to the function implementing it.
# Each callable is one of the (tree, hg) functions defined above; "none"
# accepts any arguments and does nothing. Note that "split" currently maps
# to a function that always raises NotImplementedError.
punct_hyphen_fixer = {"merge": merge_punct_hyphen, "split": split_punct_hyphen,
                      "fix-error": fix_span_error,
                      "none": lambda *args, **kwargs: None}


def strip_label(tree):
    """Cut every ``ConstTree`` tag under ``tree`` down to the part before its first ``_``.

    The tree is modified in place. Anything that is not a ``ConstTree``
    (such as a leaf) is ignored and not descended into.
    """
    if not isinstance(tree, ConstTree):
        return
    prefix, _, _ = tree.tag.partition("_")
    tree.tag = prefix
    for sub_node in tree.child:
        strip_label(sub_node)


def strip_label_internal(tree):
    """Shorten the tags of internal nodes only, leaving preterminal tags intact.

    A node counts as internal when its first child is itself a
    ``ConstTree``. For each such node the tag is cut at the first ``_`` and
    the children are visited recursively. The tree is modified in place.
    """
    if not isinstance(tree, ConstTree):
        return
    if not isinstance(tree.child[0], ConstTree):
        # First child is not a tree node: keep this node's tag as it is.
        return
    prefix, _, _ = tree.tag.partition("_")
    tree.tag = prefix
    for sub_node in tree.child:
        strip_label_internal(sub_node)


def strip_unary(node):
    """Remove unary chains whose parent and child share the same tag, in place.

    While ``node`` has exactly one child, that child is a ``ConstTree`` and
    both carry the same tag, the child is spliced out by adopting its
    children. The same clean-up is then applied to every ``ConstTree``
    child.
    """
    while True:
        if len(node.child) != 1:
            break
        only_child = node.child[0]
        if not isinstance(only_child, ConstTree) or node.tag != only_child.tag:
            break
        node.child = only_child.child
    for descendant in node.child:
        if not isinstance(descendant, ConstTree):
            continue
        strip_unary(descendant)


def strip_to_unlabel(node):
    """Relabel every tree node as ``"X"`` and drop all unary tree-to-tree links, in place.

    While ``node`` has a single ``ConstTree`` child, that child is spliced
    out regardless of its tag. The node is then tagged ``"X"`` and the
    procedure repeats on each ``ConstTree`` child.
    """
    while True:
        kids = node.child
        if len(kids) != 1 or not isinstance(kids[0], ConstTree):
            break
        node.child = kids[0].child
    node.tag = "X"
    for descendant in node.child:
        if not isinstance(descendant, ConstTree):
            continue
        strip_to_unlabel(descendant)


class ToHeadStripper(object):
    """Relabel a tree so that every node carries the tag of its lexical head.

    Head positions for internal rules come from a whitespace-separated rules
    file with three columns per line: rule name, number of children, and
    index of the head child. The file is read lazily on first access to
    :attr:`rules`.
    """

    Rule = namedtuple("Rule", ["child_count", "head"])

    def __init__(self, rule_file=os.path.dirname(__file__) + "/rules.hds"):
        """Remember the rules file path; nothing is read yet.

        :param rule_file: path of the head-rule file; defaults to
            ``rules.hds`` next to this module.
        """
        self.rule_file = rule_file

    def load_rules(self):
        """Parse the rules file into ``{rule name: Rule(child_count, head)}``.

        Blank and whitespace-only lines are skipped.

        :raises ValueError: if a non-blank line does not have exactly three
            fields or its counts are not integers.
        """
        results = {}
        with open(self.rule_file) as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate empty lines (e.g. a trailing newline at EOF).
                    continue
                name, child_count, head = fields
                results[name] = self.Rule(int(child_count), int(head))
        return results

    @property
    def rules(self):
        """Head rules, loaded from :attr:`rule_file` on first access and cached."""
        if not hasattr(self, "_rules"):
            self._rules = self.load_rules()
        return self._rules

    # sometimes there are special labels such as "house_n2"
    # Tags accepted as final preterminal labels; anything else is reported.
    label_allow = {"n", "generic", "punct", "aj", "v", "p", "av", "cm", "pp", "d", "c", "pt", "x"}
    # Renames applied to tags extracted from digit-suffixed labels.
    special_tag_mapping = {"a": "aj", "adv": "av"}
    # Renames applied to the first "_"-separated field of ordinary labels.
    normal_tag_mapping = {"genericname": "generic", "hasnt": "v", "hadnt": "v", "havent": "v",
                          "have": "v", "have-poss": "v",
                          "all": "n", "sharply": "av"}

    def modify(self, tree):
        """Rewrite the tags of ``tree`` in place and return its head preterminal.

        Preterminals (nodes whose first child is a ``Lexicon``) get a
        normalised tag:

        * ``but_np_not_conj`` becomes ``c``;
        * a label ending in one or two word characters plus a digit (such as
          ``house_n2``) becomes its last ``_``-separated field without the
          trailing digits, then passes through ``special_tag_mapping``;
        * any other label becomes its first ``_``-separated field, then
          passes through ``normal_tag_mapping``.

        A message is printed for each special rewrite and for every
        resulting tag that is not in ``label_allow``.

        Internal nodes take over the tag of their head child. The head is
        child 0 for root-like or unary nodes, the non-punctuation child of a
        binary node with a punctuation child, and otherwise the index given
        by :attr:`rules` for the node's tag.

        :param tree: ``ConstTree`` node to rewrite.
        :returns: the preterminal node that heads ``tree``.
        :raises KeyError: if an internal tag has no entry in the rules file.
        """
        if isinstance(tree.child[0], Lexicon):
            orig_tag = tree.tag
            assert len(tree.child) == 1
            if tree.tag == "but_np_not_conj":
                tree.tag = "c"
                print("rewrite but_np_not_conj to c")
            elif re.match(r"^.*\w{1,2}\d$", tree.tag):
                # [-1] rather than [1]: a matching tag without any "_"
                # yields a single-element list and must not raise IndexError.
                new_tag = tree.tag.rsplit("_", 1)[-1].rstrip("0123456789")
                if new_tag in self.special_tag_mapping:
                    new_tag = self.special_tag_mapping[new_tag]
                print(f"rewrite {tree.tag} to {new_tag}")
                tree.tag = new_tag
            else:
                tree.tag = tree.tag.split("_")[0]
                if tree.tag in self.normal_tag_mapping:
                    tree.tag = self.normal_tag_mapping[tree.tag]

            if tree.tag not in self.label_allow:
                print(f"Invalid label {orig_tag} -> {tree.tag} for {tree.child[0]}")
            return tree
        else:
            # Children are processed first, so the punctuation checks below
            # see their already-rewritten tags.
            sub_tree_heads = [self.modify(sub_tree) for sub_tree in tree.child]
            if "root" in tree.tag.lower() or len(tree.child) == 1:
                child_count = 1
                head_idx = 0
            elif len(tree.child) == 2 and "punct" in tree.child[0].tag:
                child_count = 2
                head_idx = 1
            elif len(tree.child) == 2 and "punct" in tree.child[1].tag:
                child_count = 2
                head_idx = 0
            else:
                child_count, head_idx = self.rules[tree.tag]
            assert len(tree.child) == child_count
            head = sub_tree_heads[head_idx]
            tree.tag = head.tag
            return head


# Module-level shortcut: the bound ``modify`` method of a ToHeadStripper built
# with the default rules path. Constructing it does not read the rules file;
# that happens on the first access to its ``rules`` property.
strip_to_heads = ToHeadStripper().modify


def fuzzy_cfg(cfg, names):
    """Build an unlabelled, randomly binarised tree that respects the given spans.

    Each distinct span taken from ``names`` becomes an ``"X"`` node. Nodes
    are nested by containment, uncovered word positions are filled with
    single-word nodes, and any node with more than two children is
    binarised by repeatedly merging a randomly chosen adjacent pair. The
    random generator is seeded with a fixed value, so the result is
    deterministic for a given input.

    :param cfg: tree whose ``generate_words()`` yields the leaves to reuse.
    :param names: iterable of items whose first element is a
        ``(start, end)`` half-open range of word indices.
    :returns: the new root ``ConstTree`` after
        ``populate_spans_internal()`` has been called on it.
    """
    random_obj = Random(45)
    spans = {i[0] for i in names}
    words = list(cfg.generate_words())

    def wrap_word(span):
        """Return an "X" node holding the single word at ``span[0]``."""
        ret = ConstTree("X")
        ret.word_span = span
        ret.child.append(words[span[0]])
        return ret

    def make_sub_tree(span):
        """Return a node for ``span``: a word wrapper if one word wide, else empty."""
        if span[1] - span[0] == 1:
            return wrap_word(span)
        ret = ConstTree("X")
        ret.word_span = span
        return ret

    # Widest span first, so that the list is consumed from the narrowest end.
    sub_trees = [make_sub_tree(i) for i in spans]
    sub_trees.sort(key=lambda x: x.word_span[1] - x.word_span[0], reverse=True)

    # Attach each node (narrowest first) to the tightest span containing it;
    # nodes that nothing contains are kept aside as top-level constituents.
    top_trees = []
    while len(sub_trees) > 1:
        this_tree = sub_trees[-1]
        parent_tree = None
        for other_tree in sub_trees[:-1]:
            if span_overlap(this_tree.word_span, other_tree.word_span):
                if parent_tree is None or span_overlap(other_tree.word_span, parent_tree.word_span):
                    parent_tree = other_tree
        if parent_tree is None:
            top_trees.append(this_tree)
        else:
            parent_tree.child.append(this_tree)
        sub_trees.pop()

    # Always hang the result under a fresh root spanning every word. At this
    # point sub_trees holds at most one (the widest) node. The former
    # ``len(sub_trees) == 0`` branch indexed an empty list and could only
    # raise IndexError, and top_trees was collected but never attached, so
    # parentless constituents were silently lost.
    root = ConstTree("X")
    root.word_span = (0, len(words))
    root.child = sub_trees + top_trees

    def sort_and_fill_blank(node):
        """Order children by span and wrap every uncovered word position."""
        if not node.child:
            node.child = [wrap_word((i, i + 1)) for i in range(*node.word_span)]
        elif isinstance(node.child[0], ConstTree):
            node.child.sort(key=lambda x: x.word_span)
            new_child_list = []
            # Gap between the node's start and its first child.
            for i in range(node.word_span[0], node.child[0].word_span[0]):
                new_child_list.append(wrap_word((i, i + 1)))
            # Each child, followed by the gap up to the next child (or up to
            # the node's end after the last child).
            for child_node, next_child_node in zip_longest(node.child, node.child[1:]):
                new_child_list.append(child_node)
                end = next_child_node.word_span[0] if next_child_node is not None else node.word_span[1]
                for i in range(child_node.word_span[1], end):
                    new_child_list.append(wrap_word((i, i + 1)))
            origin_children = node.child
            node.child = new_child_list
            # Freshly created word wrappers need no further processing.
            for child in origin_children:
                sort_and_fill_blank(child)

    sort_and_fill_blank(root)

    def random_merge(node):
        """Binarise ``node`` bottom-up by merging random adjacent child pairs."""
        children = node.child
        for child_node in children:
            if isinstance(child_node, ConstTree):
                random_merge(child_node)
            else:
                assert len(children) == 1
        while len(children) > 2:
            idx = random_obj.randint(0, len(children) - 2)
            tree_a = children[idx]
            tree_b = children[idx + 1]
            new_tree = ConstTree("X")
            new_tree.word_span = (tree_a.word_span[0], tree_b.word_span[1])
            new_tree.child = [tree_a, tree_b]
            children[idx] = new_tree
            children.pop(idx + 1)

    random_merge(root)
    root.populate_spans_internal()
    return root
