from collections import defaultdict, namedtuple
from functools import reduce
from itertools import product
from operator import itemgetter, mul
from queue import PriorityQueue
import numpy as np
from .treetransforms import collapseunary

import re

from .tree import Tree, escape
from .grammar import lcfrsproductions

class SupertagCorpus:
    """Parallel corpus of trees, sentences, supertag and POS sequences.

    Supertags are extracted from each tree's LCFRS derivation with
    ``supertags(propterm(fuseterm(...)))``.  After construction both tag
    sequences are integerized: ``supertag_corpus`` holds indices into
    ``self.supertags`` (sorted inventory) and ``pos_corpus`` holds indices
    into ``self.pos`` (order of first appearance).
    """

    def __init__(self, trees=(), sents=()):
        """Build the corpus from parallel sequences of trees and sentences.

        :param trees: parse trees, aligned one-to-one with ``sents``.
        :param sents: sentences; stored as given, and in this constructor
            only their lengths are read (terminals are numbered
            ``0 .. len(sent) - 1``).
        """
        # Immutable tuple defaults: the former ``[]`` defaults were single
        # list objects shared by every call.
        assert len(trees) == len(sents)

        self.sent_corpus = tuple(sents)
        self.supertag_corpus = []
        self.pos_corpus = []
        self.supertags = ()
        self.pos = ()
        self.tree_corpus = tuple(trees)

        if not trees and not sents:
            return

        for tree, sent in zip(self.tree_corpus, self.sent_corpus):
            deriv = lcfrsproductions(tree, list(range(len(sent))), as_tree=True)
            word2pos = tree.pos()
            word2tag = supertags(propterm(fuseterm(deriv)))
            # Sorting the pairs by their first element (the leaf / terminal
            # index) puts the tags in sentence order.
            self.pos_corpus.append(tuple(tag for _, tag in sorted(word2pos)))
            self.supertag_corpus.append(tuple(tag for _, tag in sorted(word2tag)))
        self.__integerize__()

    def __integerize__(self):
        """Replace symbolic tag sequences by integer ids.

        Expects ``supertag_corpus`` and ``pos_corpus`` to contain symbolic
        tags.  Rebuilds ``self.supertags`` as the sorted set of supertags
        and ``self.pos`` in order of first appearance; the corpora then
        hold indices into these tuples.
        """
        self.supertags = tuple(sorted(set(tag for tags in self.supertag_corpus for tag in tags)))
        supertag_to_id = {tag: i for i, tag in enumerate(self.supertags)}
        self.supertag_corpus = [tuple(supertag_to_id[tag] for tag in tags) for tags in self.supertag_corpus]
        pos_to_id = {}
        self.pos_corpus = [tuple(pos_to_id.setdefault(ptag, len(pos_to_id)) for ptag in ptags) for ptags in self.pos_corpus]
        self.pos = tuple(pos_to_id.keys())

    def __getstate__(self):
        """Return picklable state with tags in symbolic (not id) form.

        The id assignment is recomputed by ``__setstate__``.
        """
        return {
            "sent_corpus": self.sent_corpus,
            "tree_corpus": self.tree_corpus,
            "supertag_corpus": tuple(tuple(self.supertags[tag] for tag in tags) for tags in self.supertag_corpus),
            "pos_corpus": tuple(tuple(self.pos[pos] for pos in poss) for poss in self.pos_corpus),
        }

    def __setstate__(self, state):
        """Restore from ``__getstate__`` output and re-integerize the tags."""
        self.__dict__.update(state)
        self.__integerize__()

    def subcorpus(self, indices: slice):
        """Return a new corpus restricted to ``indices`` of every sequence.

        The tag inventories of the result only contain tags that occur in
        the selection, so ids are renumbered.
        """
        obj = SupertagCorpus.__new__(SupertagCorpus)
        obj.sent_corpus = self.sent_corpus[indices]
        obj.tree_corpus = self.tree_corpus[indices]
        obj.supertag_corpus = tuple(tuple(self.supertags[tag] for tag in tags) for tags in self.supertag_corpus[indices])
        obj.pos_corpus = tuple(tuple(self.pos[pos] for pos in poss) for poss in self.pos_corpus[indices])
        obj.__integerize__()
        return obj

class Left(namedtuple("Left", "lhs fo")):
    """Supertag for a nonterminal that directly covers two terminals.

    ``lhs`` is the nonterminal and ``fo`` its fan-out: with ``fo == 1``
    both terminals form one yield component, otherwise two.  The second
    terminal carries a companion ``Right`` supertag.

    Instances only compare equal to other ``Left`` instances (never to
    plain tuples or other supertag types), and ``Left`` sorts after every
    other supertag type.
    """

    def __eq__(self, other):
        if type(other) is Left:
            return tuple(self) == tuple(other)
        return False

    def __ne__(self, other):
        # tuple defines its own __ne__, which would otherwise be inherited
        # and compare raw fields, contradicting __eq__ above.
        return not self == other

    def __hash__(self):
        return hash(tuple(self))

    def __lt__(self, other):
        if type(other) is Left:
            return tuple(self) < tuple(other)
        return False

    def binary(self):
        """Yield the rule ``lhs -> Left{lhs} Right{lhs}`` and its yield function."""
        yf = ((0, 1),) if self.fo == 1 else ((0,), (1,))
        yield ((self.lhs, self.pos(), Right(self.lhs).pos()), yf)

    def pos(self):
        """Return the string form ``Left{lhs}`` (the fan-out is not encoded)."""
        return f"Left{{{self.lhs}}}"

class Right(namedtuple("Right", "lhs")):
    """Companion supertag for the second terminal covered by a ``Left``.

    It contributes no binary rule of its own; the rule mentioning
    ``Right{lhs}`` is produced by ``Left.binary``.  Instances only compare
    equal to other ``Right`` instances; ``Right`` sorts before ``Left``
    and after every other supertag type.
    """

    def __eq__(self, other):
        if type(other) is Right:
            return tuple(self) == tuple(other)
        return False

    def __ne__(self, other):
        # Keep != consistent with __eq__ instead of inheriting tuple.__ne__.
        return not self == other

    def __hash__(self):
        return hash(tuple(self))

    def __lt__(self, other):
        if type(other) is Left:
            return True
        if type(other) is Right:
            return tuple(self) < tuple(other)
        return False

    def binary(self):
        """Return no rules (see ``Left.binary``)."""
        return ()

    def pos(self):
        """Return the string form ``Right{lhs}``."""
        return f"Right{{{self.lhs}}}"

class Nullary(namedtuple("Nullary", "lhs")):
    """Supertag for a nonterminal with no nonterminal children.

    Its string form is the bare nonterminal ``lhs``.  Instances only
    compare equal to other ``Nullary`` instances; ``Nullary`` sorts before
    ``Left``/``Right`` and after ``Monic``/``Binary``.
    """

    def __eq__(self, other):
        if type(other) is Nullary:
            return tuple(self) == tuple(other)
        return False

    def __ne__(self, other):
        # Keep != consistent with __eq__ instead of inheriting tuple.__ne__.
        return not self == other

    def __hash__(self):
        return hash(tuple(self))

    def __lt__(self, other):
        if type(other) in [Left, Right]:
            return True
        if type(other) is Nullary:
            return tuple(self) < tuple(other)
        return False

    def binary(self):
        """Return no rules: a nullary supertag needs no binarized rule."""
        return ()

    def pos(self):
        """Return the bare nonterminal."""
        return self.lhs

def yfstr(yf):
    """Serialize a yield function to its compact string form.

    Variables of a component are concatenated and components are joined
    with commas, e.g. ``((0, 1), (2,))`` becomes ``"01,2"``.
    """
    parts = []
    for component in yf:
        parts.append("".join(map(str, component)))
    return ",".join(parts)

def stryf(yf):
    """Parse the string form produced by ``yfstr`` back into a yield function.

    Each character of a comma-separated component is one variable, so
    ``"01,2"`` becomes ``((0, 1), (2,))``.
    """
    components = yf.split(",")
    return tuple(tuple(map(int, comp)) for comp in components)

class Monic(namedtuple("Monic", "lhs rhs1 yf swap")):
    """Supertag for a rule with one nonterminal child and one anchored terminal.

    ``rhs1`` is the child nonterminal and ``yf`` the rule's yield function;
    as built by ``_fuseterm_``, variable 0 refers to the child and
    variable 1 to the anchored terminal.  ``swap`` is set by ``_propterm_``
    when the anchored terminal and the terminal propagated from the child
    trade places.

    Instances only compare equal to other ``Monic`` instances; ``Monic``
    sorts after ``Binary`` and before every other supertag type.
    """

    def __eq__(self, other):
        if type(other) is Monic:
            return tuple(self) == tuple(other)
        return False

    def __ne__(self, other):
        # Keep != consistent with __eq__ instead of inheriting tuple.__ne__.
        return not self == other

    def __hash__(self):
        return hash(tuple(self))

    def __lt__(self, other):
        if type(other) in [Left, Right, Nullary]:
            return True
        if type(other) is Monic:
            return tuple(self) < tuple(other)
        return False

    def binary(self):
        """Yield the single rule ``lhs -> rhs1 Monic{...}`` with ``yf``."""
        yield ((self.lhs, self.rhs1, self.pos()), self.yf)

    def pos(self):
        """Return the string form ``Monic{lhs,rhs1,yf,swap}``."""
        return f"Monic{{{self.lhs},{self.rhs1},{yfstr(self.yf)},{self.swap}}}"

class Binary(namedtuple("Binary", "lhs rhs1 rhs2 yf")):
    """Supertag for a rule with two nonterminal children.

    Instances only compare equal to other ``Binary`` instances; ``Binary``
    sorts before every other supertag type.
    """

    def __eq__(self, other):
        if type(other) is Binary:
            return tuple(self) == tuple(other)
        return False

    def __ne__(self, other):
        # Keep != consistent with __eq__ instead of inheriting tuple.__ne__.
        return not self == other

    def __hash__(self):
        return hash(tuple(self))

    def __lt__(self, other):
        # Always return a bool; the earlier version fell off the end and
        # returned None for unrelated types.
        if type(other) is Binary:
            return tuple(self) < tuple(other)
        return type(other) in [Left, Right, Nullary, Monic]

    def _split_yf_(self):
        """Split ``yf`` into the yield functions of the two binarized rules.

        Returns ``(yf1, yf2)``.  ``yf2`` keeps each component of ``yf`` with
        variables 1 and 2 both mapped to 1 (variable 0 stays 0); a 2
        immediately followed by 1 is merged into a single 1.  ``yf1`` gets
        one singleton component ``(2 - v,)`` for each variable ``v`` in
        {1, 2} outside such a pair, but is reset to ``((0, 1),)`` whenever
        such a 2-1 pair is met.
        NOTE(review): the reset discards components collected before it;
        confirm this is intended for yield functions with several pairs.
        """
        yf1, yf2 = (), ()
        for c in self.yf:
            c2 = ()
            for i in range(len(c)):
                if not (i > 0 and c[i-1] == 2 and c[i] == 1):
                    c2 += (1,) if c[i] in (1,2) else (0,)
                    yf1 += ((2-c[i],),) if c[i] else ()
                else:
                    yf1 = ((0,1),)
            yf2 += (c2,)
        return yf1, yf2

    def binary(self):
        """Yield the two rules this supertag binarizes into.

        ``intermediate -> Binary{...} rhs2`` with ``yf1`` and
        ``lhs -> rhs1 intermediate`` with ``yf2``, where the intermediate
        nonterminal is the string form of ``Monic(lhs, rhs1, yf2, False)``.
        """
        yf1, yf2 = self._split_yf_()
        intermediate_nt = Monic(self.lhs, self.rhs1, yf2, False).pos()
        yield ((intermediate_nt, self.pos(), self.rhs2), yf1)
        yield ((self.lhs, self.rhs1, intermediate_nt), yf2)

    def pos(self):
        """Return the string form ``Binary{lhs,rhs1,rhs2,yf}``."""
        return f"Binary{{{self.lhs},{self.rhs1},{self.rhs2},{yfstr(self.yf)}}}"


def read_supertag(st_str):
    """Parse the string form of a supertag back into a supertag object.

    Inverts the ``pos()`` methods.  The forms ``Left{X}``, ``Right{X}``,
    a bare nonterminal (``Nullary``), ``Monic{lhs,rhs,yf,swap}`` and
    ``Binary{lhs,rhs1,rhs2,yf}`` are tried in that order.

    NOTE(review): ``Left{X}`` carries no fan-out, so the result is always
    ``Left(X, None)`` and does not compare equal to the original ``Left``.

    :raises ValueError: if ``st_str`` has none of these forms.
    """
    NT_RE = "([^,]*(?:<[^<]*>)?(?:_[0-9]+)?(?:_[+-])?)"

    # Left/Right must be tried before NT_RE, which would also accept them.
    match = re.fullmatch(r"Left\{(.*)\}", st_str)
    if match is not None:
        return Left(match.group(1), None)

    match = re.fullmatch(r"Right\{(.*)\}", st_str)
    if match is not None:
        return Right(match.group(1))

    match = re.fullmatch(NT_RE, st_str)
    if match is not None:
        return Nullary(match.group(1))

    match = re.fullmatch(f"Monic\\{{{NT_RE},{NT_RE},([012,]*),([^,]*)\\}}", st_str)
    if match is not None:
        lhs, rhs, yf, swap = match.groups()
        return Monic(lhs, rhs, stryf(yf), swap == "True")

    match = re.fullmatch(f"Binary\\{{{NT_RE},{NT_RE},{NT_RE},([012,]*)\\}}", st_str)
    if match is not None:
        lhs, rhs1, rhs2, yf = match.groups()
        return Binary(lhs, rhs1, rhs2, stryf(yf))

    raise ValueError(f"not a supertag: {st_str}")


def fuseterm(tree):
    """ Removes rules of the form p → [t] () from the
        derivation, removes the nonterminal in the parent
        rule referring to it and replaces the variable in
        the parent rule by t.

        A derivation that consists of a single lexical rule
        (a one-word sentence) becomes a single ``Nullary``
        node anchored at that rule's terminal.
    """
    if not tree.children:
        ((lhs, _), yf) = tree.label
        assert len(yf) == 1
        # Anchor at the terminal the rule actually yields (read the same
        # way as in _fuseterm_) instead of assuming index 0.
        return Tree((Nullary(lhs), yf[0]), ())
    return _fuseterm_(tree)[0]

def _fuseterm_(tree):
    """Recursive worker for ``fuseterm``.

    Returns ``(result, is_terminal)``.  A node whose label has two
    nonterminals (``(lhs, rhs)``) is treated as a lexical rule and yields
    its terminal index with ``is_terminal=True``.  Otherwise the node's two
    children are fused:

    * both terminal  -> ``Left`` anchoring both terminals;
    * one terminal   -> ``Monic`` anchored at it (when the terminal is the
      first child, variables 0/1 of the yield function are swapped so
      that variable 1 always denotes the terminal);
    * none terminal  -> ``Binary``.

    NOTE(review): unary rules also have two nonterminals and would be
    taken as lexical; presumably callers collapse unaries first.
    """
    nonterminals, yf = tree.label
    if len(nonterminals) == 2:
        (terminal,) = yf
        return terminal, True

    lhs = nonterminals[0]
    first, first_is_term = _fuseterm_(tree.children[0])
    second, second_is_term = _fuseterm_(tree.children[1])

    if first_is_term and second_is_term:
        return Tree((Left(lhs, len(yf)), first, second), ()), False
    if first_is_term:
        flipped = tuple(tuple(1 - var for var in comp) for comp in yf)
        return Tree((Monic(lhs, nonterminals[2], flipped, False), first), (second,)), False
    if second_is_term:
        return Tree((Monic(lhs, nonterminals[1], yf, False), second), (first,)), False
    return Tree((Binary(lhs, nonterminals[1], nonterminals[2], yf),), (first, second)), False

def propterm(tree):
    """Propagate terminals through a derivation produced by ``fuseterm``.

    Thin wrapper around ``_propterm_`` that discards the propagated
    terminal and the removal flag and returns only the rewritten tree.
    """
    rewritten, _prop, _rm = _propterm_(tree)
    return rewritten

def _paste_(yf, index, rm, pasteindex):
    """Rewrite yield function ``yf`` while propagating a terminal.

    Only the *first* occurrence of each variable is affected:

    * if the variable equals ``index``, ``pasteindex`` is inserted just
      before it (``index=None`` inserts nothing);
    * if ``rm[var]`` is true, that occurrence is dropped.

    Components left empty are dropped from the result.

    :param yf: tuple of components, each a tuple of variables.  Only
        variables 0 and 1 are supported (``first`` has two slots).
    :param index: variable to paste in front of, or ``None``.
    :param rm: pair of flags for variables 0 and 1.
    :param pasteindex: variable number to insert.
    :returns: ``(new_yf, rm_)``; ``rm_`` is True when a component was
        dropped, i.e. the fan-out shrank.
    """
    yf_ = ()
    rm_ = False
    # first[v]: variable v has not been seen yet
    first = [True, True]

    for c in yf:
        c_ = ()
        for var in c:
            if first[var] and var == index:
                c_ += (pasteindex,)
            if not (first[var] and rm[var]):
                c_ += (var,)
            first[var] = False
        yf_ += (c_,) if c_ else ()
        if not c_:
            rm_ = True

    # NOTE(review): yf_ only ever receives non-empty components, so this
    # condition looks unreachable; if every component was dropped, yf_ is
    # empty and yf_[0] raises IndexError instead.  Confirm intent.
    if len(yf_[0]) == 0:
        rm_ = True
        yf_ = yf[1:]
    return yf_, rm_

def _nt_(lab, rm):
    return lab + ("_-" if rm else "_+")

def _propterm_(tree, require_term=False):
    """Recursive worker for ``propterm``.

    Returns a triple ``(new_tree, prop, rm)``:

    * ``new_tree`` -- the rewritten subtree;
    * ``prop`` -- the terminal handed up to the parent, or ``None`` when
      nothing is handed up (e.g. the subtree was copied unchanged);
    * ``rm`` -- the flag passed to ``_nt_`` for this node's left-hand side
      (``True`` gives the ``_-`` marker).

    :param require_term: whether the parent needs a terminal handed up
        from this subtree.  The right child of a ``Binary`` and the child
        of a ``Monic`` that must itself hand one up are visited with
        ``True``.  ``Binary``, ``Left`` and ``Monic`` nodes can satisfy
        it; any other node trips the ``assert`` in the last branch.
    """
    if type(tree.label[0]) is Binary:
        ((lhs, rhs1, rhs2, yf),) = tree.label
        # The left child inherits the caller's requirement; the right child
        # must always hand up a terminal (``insert``), which becomes this
        # node's anchor.
        (c1, prop, rm1) = _propterm_(tree.children[0], require_term)
        (c2, insert, rm2) = _propterm_(tree.children[1], True)
        # Insert variable 2 (the anchor) before the first occurrence of
        # variable 1, and drop first occurrences flagged by the children.
        (yf, rm) = _paste_(yf, 1, (rm1, rm2), 2)
        rhs2 = _nt_(rhs2, rm2)
        # lhs/rhs1 are only marked when this node itself hands a terminal up.
        if require_term:
            lhs, rhs1 = _nt_(lhs, rm), _nt_(rhs1, rm1)
        return Tree((Binary(lhs, rhs1, rhs2, yf), insert), (c1, c2)), prop, rm
    if type(tree.label[0]) is Left and require_term:
        # A Left covering two terminals becomes a Nullary: the first
        # terminal (``prop``) is handed up, the second (``keep``) stays as
        # the anchor.  With fan-out 2 the nonterminal is marked ``_-``.
        ((lhs, fo), prop, keep) = tree.label
        lhs = _nt_(lhs, fo == 2)
        rm = fo == 2
        return Tree((Nullary(lhs), keep), ()), prop, rm
    if type(tree.label[0]) is Monic and require_term:
        (c, prop, rm1) = _propterm_(tree.children[0], True)
        # NOTE: this ``rm = False`` is overwritten by _paste_ below.
        ((lhs, rhs, yf, _), term), rm = tree.label, False
        # Variable 1 is the anchored terminal (as built by _fuseterm_).  If
        # the yield starts with it, the anchored terminal is handed up and
        # the child's propagated terminal becomes the new anchor.
        swap_terminals = yf[0][0] == 1
        yf, rm = _paste_(yf, 0 if swap_terminals else None, (rm1, swap_terminals), 1)
        if swap_terminals:
            term, prop = prop, term
        lhs, rhs = _nt_(lhs, rm), _nt_(rhs, rm1)
        return Tree((Monic(lhs, rhs, yf, swap_terminals), term), (c,)), prop, rm
    else:
        # Nodes that need not hand up a terminal are copied; their children
        # are visited with require_term=False.
        assert not require_term
        return Tree(tree.label, (_propterm_(c)[0] for c in tree.children)), None, False


def supertags(tree):
    """Yield ``(terminal_index, supertag)`` pairs for every node of ``tree``.

    Node labels are ``(supertag, index)``, except ``Left`` nodes, whose
    labels are ``(supertag, index, second_index)``; the second terminal
    receives a companion ``Right`` supertag (yielded first).
    """
    for node in tree.subtrees():
        if type(node.label[0]) is not Left:
            tag, index = node.label
            yield (index, tag)
            continue
        tag, index, second_index = node.label
        yield (second_index, Right(tag.lhs))
        yield (index, tag)


def unlexicalize(tree, pos, term=None):
    """Turn a supertag derivation back into a plain constituency tree.

    :param tree: derivation as produced by ``propterm(fuseterm(...))``;
        each node label starts with a supertag followed by terminal indices.
    :param pos: mapping from terminal index to POS tag.
    :param term: terminal handed down by the parent (the anchor of an
        enclosing ``Binary``), or ``None``.
    :returns: a ``Tree`` whose terminals sit under POS preterminals and
        whose nonterminal labels have the ``_+``/``_-`` markers removed.
    :raises ValueError: if a node's label is not a recognized supertag.
    """
    NT_RE = r"(.*?)(?:_([+-]))?"

    def plain(label):
        # strip the _+ / _- marker appended by propterm
        label, _ = re.fullmatch(NT_RE, label).groups()
        return label

    def preterminal(index):
        return Tree(pos[index], (index,))

    tag_type = type(tree.label[0])
    assert tag_type != Right

    if tag_type is Left:
        assert term is None
        assert not tree.children
        (lhs, _), first, second = tree.label
        return Tree(lhs, (preterminal(first), preterminal(second)))

    if tag_type is Nullary:
        assert not tree.children
        if term is None:
            # initial terminal: nothing was handed down
            anchor = tree.label[1]
            if tree.label[0].lhs == "NOPARSE":
                return Tree("NOPARSE", (preterminal(anchor),))
            # singleton tree with rule.lhs == pos[anchor]
            return preterminal(anchor)
        (lhs,), anchor = tree.label
        return Tree(plain(lhs), (preterminal(term), preterminal(anchor)))

    if tag_type is Monic:
        (lhs, rhs1, yf, swap), own_term = tree.label
        assert len(tree.children) == 1
        child = tree.children[0]
        if term is not None:
            lhs = plain(lhs)
            if swap:
                term, own_term = own_term, term
            leaf = preterminal(own_term)
            sub = unlexicalize(child, pos, term)
            return Tree(lhs, (leaf, sub) if swap else (sub, leaf))
        assert swap == False  # noqa: E712 -- unlike `not swap`, also rejects None
        sub_first = yf[0][0] == 0
        leaf = preterminal(own_term)
        sub = unlexicalize(child, pos, term)
        return Tree(lhs, (sub, leaf) if sub_first else (leaf, sub))

    if tag_type is Binary:
        (lhs, rhs1, rhs2, yf), anchor = tree.label
        assert len(tree.children) == 2
        lhs = plain(lhs)
        first_child = unlexicalize(tree.children[0], pos, term)
        second_child = unlexicalize(tree.children[1], pos, anchor)
        return Tree(lhs, (first_child, second_child))

    if tree.label[1] is None:
        # combine two separately derived parts under a NOPARSE node
        first_child = unlexicalize(tree.children[0], pos)
        second_child = unlexicalize(tree.children[1], pos)
        return Tree("NOPARSE", (first_child, second_child))
    raise ValueError(f"invalid derivation: {tree}")


def test_fuseterm():
    """fuseterm folds lexical children into Left/Monic supertags."""
    # two lexical children -> one Left node anchoring both terminals
    derivation = Tree((("a", "b", "c"), ((0, 1),)), (
        Tree((("b", "Epsilon"), (0,)), ()),
        Tree((("c", "Epsilon"), (1,)), ())))
    expected = Tree((Left("a", 1), 0, 1), ())
    assert fuseterm(derivation) == expected

    # one lexical child -> Monic anchored at it, wrapping a Left
    derivation = Tree((("abc", "a,c", "b"), ((0, 1, 0),)), (
        Tree((("a,c", "a", "c"), ((0,), (1,))), (
            Tree((("a", "Epsilon"), (0,)), ()),
            Tree((("c", "Epsilon"), (2,)), ()))),
        Tree((("b", "Epsilon"), (1,)), ())))
    expected = Tree((Monic("abc", "a,c", ((0, 1, 0),), False), 1), (
        Tree((Left("a,c", 2), 0, 2), ()),))
    assert fuseterm(derivation) == expected


def test_propterm_paste():
    """A Left right child is reduced to a Nullary and its terminal pasted."""
    derivation = Tree((Binary("abcd", "ab", "cd", ((0, 1),)),), (
        Tree((Left("ab", 1), "a", "b"), ()),
        Tree((Left("cd", 1), "c", "d"), ())))
    expected = Tree((Binary("abcd", "ab", "cd_+", ((0, 2, 1),)), "c"), (
        Tree((Left("ab", 1), "a", "b"), ()),
        Tree((Nullary("cd_+"), "d"), ())))
    assert propterm(derivation) == expected


def test_propterm_propleft():
    """A Monic whose yield starts with its terminal swaps terminals."""
    issue = Tree((Left("NP_1", 1), "the", "issue"), ())
    derivation = Tree(
        (Binary("NP_2", "NP_1", "PP", ((0,), (1,))),),
        (
            Tree((Left("NP_1", 1), "A", "hearing"), ()),
            Tree((Monic("PP", "NP_1", ((1, 0),), False), "on"), (issue,)),
        ),
    )
    expected = Tree(
        (Binary("NP_2", "NP_1", "PP_+", ((0,), (2, 1))), "on"),
        (
            Tree((Left("NP_1", 1), "A", "hearing"), ()),
            Tree(
                (Monic("PP_+", "NP_1_+", ((1, 0),), True), "the"),
                (Tree((Nullary("NP_1_+"), "issue"), ()),),
            ),
        ),
    )
    assert propterm(derivation) == expected


def test_propterm_propright():
    """A Monic whose yield starts with its child keeps its own terminal."""
    issue = Tree((Left("NP_1", 1), "the", "issue"), ())
    derivation = Tree(
        (Binary("NP_2", "NP_1", "PP", ((0,), (1,))),),
        (
            Tree((Left("NP_1", 1), "A", "hearing"), ()),
            Tree((Monic("PP", "NP_1", ((0, 1),), False), "on"), (issue,)),
        ),
    )
    expected = Tree(
        (Binary("NP_2", "NP_1", "PP_+", ((0,), (2, 1))), "the"),
        (
            Tree((Left("NP_1", 1), "A", "hearing"), ()),
            Tree(
                (Monic("PP_+", "NP_1_+", ((0, 1),), False), "on"),
                (Tree((Nullary("NP_1_+"), "issue"), ()),),
            ),
        ),
    )
    assert propterm(derivation) == expected


def test_propterm_pasteprop():
    """Propagation through a nested Binary pastes into the outer rule."""
    fake = Tree((Left("Fake", 2), "fake1", "fake2"), ())
    inner = Tree(
        (Binary("NP_2", "NP_1", "PP", ((0,), (1,))),),
        (
            Tree((Left("NP_1", 1), "A", "hearing"), ()),
            Tree(
                (Monic("PP", "NP_1", ((1, 0),), False), "on"),
                (Tree((Left("NP_1", 1), "the", "issue"), ()),),
            ),
        ),
    )
    derivation = Tree(
        (Binary("Root", "Fake", "NP_2", ((0, 1, 0, 1),)),),
        (fake, inner),
    )

    expected_inner = Tree(
        (Binary("NP_2_+", "NP_1_+", "PP_+", ((0,), (2, 1))), "on"),
        (
            Tree((Nullary("NP_1_+"), "hearing"), ()),
            Tree(
                (Monic("PP_+", "NP_1_+", ((1, 0),), True), "the"),
                (Tree((Nullary("NP_1_+"), "issue"), ()),),
            ),
        ),
    )
    expected = Tree(
        (Binary("Root", "Fake", "NP_2_+", ((0, 2, 1, 0, 1),)), "A"),
        (
            Tree((Left("Fake", 2), "fake1", "fake2"), ()),
            expected_inner,
        ),
    )
    assert propterm(derivation) == expected


def test_supertags():
    """supertags() yields one tag per terminal, plus Right for Left's second one."""
    derivation = Tree((Binary("abcd", "ab", "cd", ((0, 2, 1),)), 2), (
        Tree((Left("ab", 1), 0, 1), ()),
        Tree((Nullary("cd"), 3), ())))

    assert set(supertags(derivation)) == {
        (0, Left("ab", 1)),
        (1, Right("ab")),
        (2, Binary("abcd", "ab", "cd", ((0, 2, 1),))),
        (3, Nullary("cd")),
    }

    rules = {rule for _, tag in supertags(derivation) for rule in tag.binary()}
    assert rules == {
        (("abcd", "ab", "Monic{abcd,ab,01,False}"), ((0, 1),)),
        (("Monic{abcd,ab,01,False}", "Binary{abcd,ab,cd,021}", "cd"), ((0, 1),)),
        (("ab", "Left{ab}", "Right{ab}"), ((0, 1),)),
    }

    strings = [tag.pos() for _, tag in sorted(supertags(derivation))]
    assert strings == ["Left{ab}", "Right{ab}", "Binary{abcd,ab,cd,021}", "cd"]

    derivation = Tree((Binary("ROOT", "A", "B_-", ((0, 2, 0, 1),)), 1), (
        Tree((Left("A", 2), 0, 2), ()),
        Tree((Nullary("B_-"), 3), ())))

    ordered = tuple(tag for _, tag in sorted(supertags(derivation)))
    assert ordered == (
        Left("A", 2),
        Binary("ROOT", "A", "B_-", ((0, 2, 0, 1),)),
        Right("A"),
        Nullary("B_-"),
    )

    rules = {rule for _, tag in supertags(derivation) for rule in tag.binary()}
    assert rules == {
        (("ROOT", "A", "Monic{ROOT,A,0101,False}"), ((0, 1, 0, 1),)),
        (("Monic{ROOT,A,0101,False}", "Binary{ROOT,A,B_-,0201}", "B_-"), ((0,), (1,))),
        (("A", "Left{A}", "Right{A}"), ((0,), (1,))),
    }

    strings = [tag.pos() for _, tag in sorted(supertags(derivation))]
    assert strings == ["Left{A}", "Binary{ROOT,A,B_-,0201}", "Right{A}", "B_-"]


def test_fuseterm_propterm_on_corpus():
    """Round-trip every tree of an export treebank through supertagging.

    Does nothing unless the ``TREEBANKFILE`` environment variable names a
    treebank in export format.
    """
    from os import environ
    from .treebank import READERS
    treebankfile = environ.get("TREEBANKFILE")
    if not treebankfile:
        return
    corpus = READERS['export'](treebankfile, 'utf8')
    trees = list(corpus.trees().values())
    sents = list(corpus.sents().values())
    for tree, sent in zip(trees, sents):
        ast = lcfrsproductions(tree, list(range(len(sent))), as_tree=True)
        # get_pos (removed in 1443af8eaf565d0f282389f08daa309e362a8c48) used
        # to supply this mapping; the tree's own (leaf, tag) pairs give the
        # terminal-index -> POS mapping that unlexicalize requires.
        pos = dict(tree.pos())
        lexed = propterm(fuseterm(ast))
        unlexed = unlexicalize(lexed, pos)
        assert tree == unlexed
