from queue import PriorityQueue

from .lexcorpus import SupertagCorpus, Nullary, Monic, Binary, Left, Right, unlexicalize
from .plcfrs import parse_supertags
from .kbest import lazykbest
from .tree import Tree
from .containers import Grammar
from math import log

from typing import List, Iterable
from numpy import array, newaxis

def fanout(tag):
    """ Gives the fanout of the left-hand side nonterminal of a supertag:
        * 1 for `Nullary` tags,
        * the stored value `tag.fo` for `Left` tags,
        * the length of `tag.yf` for `Monic` and `Binary` tags, and
        * None for every other kind of tag (e.g. `Right`).
    """
    kind = type(tag)
    if kind is Nullary:
        return 1
    elif kind is Left:
        return tag.fo
    elif kind is Monic or kind is Binary:
        return len(tag.yf)
    # no fanout can be derived from any other kind of tag
    return None

def rhss(tag):
    """ Enumerates the right-hand side nonterminals of a `Monic` or `Binary`
        supertag, each paired with the number of occurrences of its variable
        (0 for `rhs1`, 1 for `rhs2`) in `tag.yf`. Yields nothing for all
        other kinds of tags.
    """
    def occurrences(variable):
        # count how often `variable` is used across all components of yf
        count = 0
        for component in tag.yf:
            for symbol in component:
                if symbol == variable:
                    count += 1
        return count

    if type(tag) is Monic or type(tag) is Binary:
        first = occurrences(0)
        yield tag.rhs1, first
    if type(tag) is Binary:
        second = occurrences(1)
        yield tag.rhs2, second

def validate_prop_annot(tag):
    """ Checks that the `_+`/`_-` suffixes of the nonterminals in a supertag
        are consistent:
        * `Nullary` tags are always accepted,
        * `Left` and `Right` tags must not have a suffixed lhs,
        * `Monic` tags need a suffix at both or at neither of lhs and rhs1,
        * `Binary` tags need a suffixed rhs2, and a suffix at both or at
          neither of lhs and rhs1.
        Returns None for any other kind of tag.
    """
    import re
    suffixed = re.compile(r".*?(?:_([+-]))").fullmatch

    # the lhs is inspected first, so every tag needs to provide one
    lhs_marked = suffixed(tag.lhs) is not None
    kind = type(tag)
    if kind is Nullary:
        return True
    if kind is Left or kind is Right:
        return not lhs_marked
    if kind is Monic:
        rhs1_marked = suffixed(tag.rhs1) is not None
        return lhs_marked == rhs1_marked
    if kind is Binary:
        rhs1_marked = suffixed(tag.rhs1) is not None
        rhs2_marked = suffixed(tag.rhs2) is not None
        return rhs2_marked and lhs_marked == rhs1_marked
    return None

def filter_supertags(taglist):
    """ Removes duplicates and inconsistent supertags from `taglist`.

        For each lhs nonterminal, the fanout that most of the (distinct)
        tags agree on is determined. A tag is yielded if
        * it is a `Right` tag or its fanout equals the agreed fanout of its
          lhs,
        * each of its rhs nonterminals has an agreed fanout that equals the
          number of occurrences of its variable in the tag (cf. `rhss`), and
        * it passes `validate_prop_annot`.

        The tags must be hashable. They are yielded in the order of their
        first occurrence in `taglist`; a tag whose lhs has no agreed fanout
        at all is dropped.
    """
    from collections import Counter, defaultdict
    # Deduplicate like a set, but keep the order of first occurrence: the
    # result (and the tie-breaking of `most_common`) must not depend on the
    # hash order of the tags.
    tags = tuple(dict.fromkeys(taglist))
    # find agreement for fanouts per nonterminal
    votes = defaultdict(Counter)
    for tag in tags:
        if fo := fanout(tag):
            votes[tag.lhs][fo] += 1
    fanouts = {lhs: fos.most_common(1)[0][0] for lhs, fos in votes.items()}
    for tag in tags:
        # `Right` tags carry no fanout. For every other tag the lhs needs an
        # agreed fanout; the membership test avoids a KeyError for tags that
        # did not vote themselves (fanout of None or 0).
        lhs_agrees = type(tag) is Right or (
            tag.lhs in fanouts and fanout(tag) == fanouts[tag.lhs])
        if lhs_agrees and \
                all(fanouts.get(rhs) == fo for rhs, fo in rhss(tag)) and \
                validate_prop_annot(tag):
            yield tag

class SupertagGrammar:
    """ Couples a collection of supertags with the plcfrs `Grammar` that is
        built from them, and parses sequences of weighted supertags.

        The state consists of the supertags (`tags`), the pos tags (`pos`),
        the root labels (`roots`) and the fallback probability. Everything
        else (`fanouts`, `plcfrs`, `tag_to_nt`, `pos_to_nt`, `unsplit`) is
        derived from the state by `sync_grammar`.
    """

    def __init__(self, corpus: "SupertagCorpus", fallback_prob: float = 0.0):
        """ :param corpus: provides the supertags, the pos tags and the trees
                whose root labels are used as start nonterminals.
            :param fallback_prob: handed over to the underlying `Grammar`.
        """
        self.tags = corpus.supertags
        self.pos = corpus.pos
        self.__fallback_prob__ = fallback_prob
        self.roots = set(t.label for t in corpus.tree_corpus)
        self.sync_grammar()

    def subgrammar(self, tags):
        """ Creates a grammar from a sequence of supertags. The tags do not
            need to occur in `self`, the occurring nonterminal's fanouts also
            don't need to match to those in `self`, except the ones of `Left`
            tags, which assume the fanout as defined in `self`.

            The given tags are passed through `filter_supertags`; pos tags,
            roots and the fallback probability are taken over from `self`.
        """
        obj = SupertagGrammar.__new__(SupertagGrammar)
        obj.tags = tuple(filter_supertags(
            Left(t.lhs, self.fanouts[t.lhs]) if type(t) is Left else t
            for t in tags
        ))
        obj.pos = self.pos
        obj.roots = self.roots
        obj.__fallback_prob__ = self.fallback_prob
        obj.sync_grammar()
        return obj

    @property
    def fallback_prob(self):
        """ The fallback probability the grammar was built with. """
        return self.__fallback_prob__

    @fallback_prob.setter
    def fallback_prob(self, value: float):
        # keep the already constructed plcfrs in sync with the stored value
        self.plcfrs.set_fallback_prob(value)
        self.__fallback_prob__ = value

    def sync_grammar(self):
        """ (Re-)builds everything that is derived from the tags, the roots
            and the fallback probability:
            * `fanouts`: the fanout per lhs nonterminal,
            * `plcfrs`: the `Grammar` consisting of the binarized rules of
              all tags, one "Epsilon" rule per tag and one "**ROOT**" rule
              per root label,
            * `tag_to_nt`/`pos_to_nt`: the index of the nonterminal
              `tag.pos()` in `plcfrs`, by position in `tags` and by the
              string `tag.pos()`, respectively, and
            * `unsplit`: the `Unsplitter` for derivations of `plcfrs`.
        """
        # A `Right` tag shares its lhs with a `Left` tag, but `fanout` gives
        # None for it. Such a missing value must never replace a fanout that
        # is already known, otherwise the result depends on the tag order.
        fanouts = {}
        for tag in self.tags:
            fo = fanout(tag)
            if fo is not None or tag.lhs not in fanouts:
                fanouts[tag.lhs] = fo
        self.fanouts = fanouts

        rules = list(set((r, 1) for tag in self.tags for r in tag.binary()))
        rules += [(((tag.pos(), "Epsilon"), (" ",)), 1) for tag in self.tags]
        rules += [((("**ROOT**", start_nt), ((0,),)), 1) for start_nt in self.roots]
        # nonterminals with a "_-" or "_+" suffix get no fallback
        fallback_nts = set(
            tag.lhs for tag in self.tags
            if not tag.lhs.endswith(("_-", "_+")))
        self.plcfrs = Grammar(rules, start="**ROOT**", fallback_prob=self.fallback_prob, fallback_nts=fallback_nts, normalize=False)
        str_nt_to_idx = { nt: idx for idx, nt in self.plcfrs.labels() if not nt.startswith("NOPARSE") }
        self.tag_to_nt = tuple(str_nt_to_idx[tag.pos()] for tag in self.tags)
        self.pos_to_nt = { tag.pos(): str_nt_to_idx[tag.pos()] for tag in self.tags }
        self.unsplit = Unsplitter(self.tags, self.plcfrs.fallback_indicators())

    def deintegerize_and_parse(self, sent, pos, tags, weights, k=1, beta=1):
        """ Parses with integerized pos tags and separately given weights.

            :param sent: unused.
            :param pos: for each word, an index into `self.pos`.
            :param tags: for each word, a sequence of supertag indices.
            :param weights: for each word, the weights of the supertags in
                `tags`; paired position by position with them.
            :param k: see `parse`.
            :param beta: see `parse`.
        """
        tags = (
            zip(word_tags, word_weights)
            for word_tags, word_weights in zip(tags, weights))
        pos = tuple(self.pos[nt] for nt in pos)
        return self.parse(pos, tags, k, beta)

    def todict(self):
        """ Returns the state of the grammar as a dictionary, cf.
            `__getstate__`.
        """
        return self.__getstate__()

    @classmethod
    def fromdict(cls, dict):
        """ Restores a grammar from a dictionary as returned by `todict`. """
        obj = cls.__new__(cls)
        obj.__setstate__(dict)
        return obj

    def __getstate__(self):
        # only the primary attributes are stored, the rest is rebuilt by
        # `sync_grammar` in `__setstate__`
        return {
            "tags": self.tags,
            "pos": self.pos,
            "roots": self.roots,
            "__fallback_prob__": self.__fallback_prob__
        }

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.sync_grammar()

    def parse(self, pos: List[str], tags: Iterable[Iterable[int]], k: int = 1, beta: float = 1.0, posmode: bool = False, ktags: int = 0, estimates: array = None):
        """ Parses a sequence of weighted supertags and yields up to `k`
            parse trees. This is a generator, nothing is parsed before the
            first tree is requested.

            :param pos: a list of postags; must be subscriptable.
            :param tags: a 2-step iterable giving, for each word, pairs of a
                supertag and its weight. The supertag is an index into
                `self.tags`, or the string `tag.pos()` if `posmode` is set.
            :param k: determines the number of derivation trees to enumerate.
            :param beta: determines the beam width in the parsing process;
                it is passed to the parser as -log(`beta`), so it must be
                greater than 0.
            :param posmode: if True, assumes string representations of
                supertags in `tags`.
            :param ktags: the number of supertags per word. If given, the
                number of items in the parsing process is estimated as
                2⋅`ktags`⋅n, where n is the length of the parsed sentence.
            :param estimates: the best supertag score for each sentence
                position. If given, computes estimates while parsing which
                assume the best supertag score for each position.
        """
        # translate every supertag into its nonterminal index in the plcfrs
        if posmode:
            tags = (((self.pos_to_nt[p], w) for (p, w) in word_tags) for word_tags in tags)
        else:
            tags = (((self.tag_to_nt[tag], w) for (tag, w) in word_tags) for word_tags in tags)

        if estimates is not None:
            estimates = 'Constants', estimates[:, newaxis, newaxis, newaxis].astype("double")
        chart, _ = parse_supertags(len(pos), tags, self.plcfrs, beam_beta=-log(beta),
            itemsestimate=(2*ktags*len(pos) if ktags else None), exhaustive=(estimates is None and k > 1), estimates=estimates)

        for str_t, _ in lazykbest(chart, k):
            lex_deriv = self.unsplit(Tree(str_t)[0]) # skip artificial root node
            parsetree = unlexicalize(lex_deriv, pos)
            yield parsetree

    def __str__(self):
        return str(self.plcfrs)


class Unsplitter:
    """ Turns derivations of the binarized grammar back into derivations
        whose nodes are labeled with supertags.
    """
    def __init__(self, tags, fallback_lhs):
        """ :param tags: the supertags; each is indexed by its `pos()` symbol.
            :param fallback_lhs: container of node labels that mark fallback
                ("wildcard") rules; only used for membership tests.
        """
        self.pos_to_tag = { tag.pos(): tag for tag in tags }
        self.skip_lhs = fallback_lhs

    def _unsplit_(self, deriv: Tree):
        """ Recursively converts `deriv`, see `__call__` for the labels of
            the constructed nodes. The cases below are tried in order and
            the order matters. Raises NotImplementedError if none matches.
        """
        # pos rule: a single child that is a terminal
        if len(deriv.children) == 1 and not type(deriv.children[0]) is Tree:
            tag = self.pos_to_tag.get(deriv.label, None)
            terminal = deriv.children[0]
            if type(tag) is Nullary:
                return Tree((tag, terminal), children=())
            else:
                # the label is unknown or belongs to a non-Nullary tag
                return Tree((Nullary("NOPARSE"), terminal), children=())

        # skip unary wildcard rules
        if len(deriv.children) == 1:
            assert deriv.label in self.skip_lhs or deriv.label == "ROOT"
            return self._unsplit_(deriv.children[0])

        # binary wildcard rule
        if deriv.label in self.skip_lhs:
            return Tree(((), None), tuple(self._unsplit_(d) for d in deriv))

        # Double-lex. rule: the second child is the pos symbol of a Right tag
        indicator_tag = self.pos_to_tag.get(deriv[1].label)
        if type(indicator_tag) is Right:
            t1, t2 = deriv.leaves()
            # fanout 1 if the second terminal directly follows the first
            fo = 1 if t1+1 == t2 else 2
            return Tree((Left(indicator_tag.lhs, fo), t1, t2), ())

        # Binary, type(indicator) is Monic and maybe even a supertag
        # so we need to check this case before Monic
        if type(deriv[(1,0)]) is Tree:
            # find the Binary tag's pos symbol in right successor
            indicator_tag = self.pos_to_tag[deriv[(1,0)].label]
            assert type(indicator_tag) is Binary
            terminal, = deriv[1][0].leaves()
            left = self._unsplit_(deriv[0])
            right = self._unsplit_(deriv[(1,1)])
            return Tree((indicator_tag, terminal), (left, right))

        # Monic: the second child is the tag's pos symbol with its terminal
        if type(indicator_tag) is Monic:
            terminal = deriv[(1, 0)]
            return Tree((indicator_tag, terminal), (self._unsplit_(deriv[0]),))

        raise NotImplementedError()

    def __call__(self, deriv: Tree):
        """ Undoes the binarization of lexical rules. The returned `Tree`
            contains nodes with labels of the form
            * (tag, t1) without children for pos rules, where tag is the
              `Nullary` supertag of the pos symbol, or `Nullary("NOPARSE")`
              if the symbol does not belong to a `Nullary` supertag,
            * (Left(lhs, fo), t1, t2) without children for double-lexical
              rules, where fo is 1 if t1+1 == t2 and 2 otherwise,
            * (tag, t1) with one child for `Monic` supertags,
            * (tag, t1) with two children for `Binary` supertags, and
            * ((), None) for fallback rules with more than one child;
              fallback rules with a single child are skipped,
            where t1, t2 are terminals (leaves) of `deriv`.
        """
        return self._unsplit_(deriv)

def test_grammar():
    """ Parses two small discontinuous trees with their own supertags and
        compares the only derivation with the expected lcfrs productions.
    """
    from .tree import Tree
    from .grammar import lcfrsproductions

    def to_ast(tree, sent):
        return lcfrsproductions(tree, list(range(len(sent))), as_tree=True)

    def gold_supertags(treebank):
        # one supertag with weight 0 per word of the first sentence
        return [[(tag, 0)] for tag in treebank.supertag_corpus[0]]

    # 1st case: the words are also used as pos tags
    gold = Tree("(ROOT (a,c (a 0) (c 2)) (b 1))")
    words = "a b c".split()
    treebank = SupertagCorpus([gold], [words])
    parser = SupertagGrammar(treebank)
    parses = list(parser.parse(words, gold_supertags(treebank)))
    expected = Tree( (('ROOT', 'a,c', 'b'), ((0, 1, 0),)), (
        Tree( (('a,c', 'a', 'c'), ((0,), (1,))), (
            Tree( (('a', 'Epsilon'), (0,)), ()),
            Tree( (('c', 'Epsilon'), (2,)), ()))),
        Tree( (('b', 'Epsilon'), (1,)), ())))
    assert [to_ast(parse, words) for parse in parses] == [expected]

    # 2nd case: two interleaved discontinuous constituents
    gold = Tree("(ROOT (A (A1 0) (A2 2)) (B (B1 1) (B2 3)))")
    words = "a c b d".split()
    treebank = SupertagCorpus([gold], [words])
    parser = SupertagGrammar(treebank)
    parses = list(parser.parse(["A1", "B1", "A2", "B2"], gold_supertags(treebank)))
    expected = Tree((("ROOT", "A", "B"), ((0,1,0,1),)), (
        Tree((("A", "A1", "A2"), ((0,), (1,))), (
            Tree((("A1", "Epsilon"), (0,)), ()),
            Tree((("A2", "Epsilon"), (2,)), ())
        )),
        Tree((("B", "B1", "B2"), ((0,), (1,))), (
            Tree((("B1", "Epsilon"), (1,)), ()),
            Tree((("B2", "Epsilon"), (3,)), ())
        ))))
    assert [to_ast(parse, words) for parse in parses] == [expected]

def test_parsing_on_corpus():
    """ For each tree of the treebank in the export file given by the
        environment variable TREEBANKFILE, builds a grammar from that single
        tree and checks that the tree is among the best k (k = 1, ..., 4)
        parses for its own supertags. Skipped if the variable is not set.
    """
    from os import environ
    import pytest
    from .treebank import READERS
    from .treetransforms import addfanoutmarkers
    treebankfile = environ.get("TREEBANKFILE")
    if not treebankfile:
        # report the test as skipped instead of silently passing
        pytest.skip("environment variable TREEBANKFILE not set")
    treebank = READERS['export'](treebankfile, 'utf8')
    trees = list(treebank.trees().values())
    sents = list(treebank.sents().values())
    for cur in range(len(trees)):
        # FIXME discodop doesn't parse these
        if len(sents[cur]) == 1 or cur in (5239, 7560,):
            #                              n.p.  b.a.
            continue
        tree = addfanoutmarkers(trees[cur])
        corpus = SupertagCorpus([tree], [sents[cur]])
        grammar = SupertagGrammar(corpus)
        pos_sent = tuple(corpus.pos[tag] for tag in corpus.pos_corpus[0])
        supertag_sent = tuple([(tag, 0)] for tag in corpus.supertag_corpus[0])
        derivs = ()
        # don't take too long, k < 5 is enough for Tiger's first 10k sents
        for k in range(1, 5):
            derivs = tuple(grammar.parse(pos_sent, supertag_sent, k))
            if trees[cur] in derivs:
                break
        assert trees[cur] in derivs

def test_parsing_on_combined_corpus():
    """ Builds one grammar from all trees of the treebank in the export file
        given by the environment variable TREEBANKFILE and checks for each
        tree that it is among the best k (k = 3, 4) parses for its own
        supertags. Skipped if the variable is not set.
    """
    from os import environ
    import pytest
    from .treebank import READERS
    from .treetransforms import addfanoutmarkers
    treebankfile = environ.get("TREEBANKFILE")
    if not treebankfile:
        pytest.skip("environment variable TREEBANKFILE not set")
    treebank = READERS['export'](treebankfile, 'utf8')
    trees = list(addfanoutmarkers(t) for t in treebank.trees().values())
    sents = list(treebank.sents().values())
    corpus = SupertagCorpus(trees, sents)
    grammar = SupertagGrammar(corpus)
    for cur in range(len(trees)):
        # FIXME discodop doesn't parse these
        if len(sents[cur]) == 1 or cur in (5239, 7560, 17631):
            #                              n.p.  b.a.  n.p.
            continue
        pos_sent = tuple(corpus.pos[tag] for tag in corpus.pos_corpus[cur])
        supertag_sent = tuple([(tag, 0)] for tag in corpus.supertag_corpus[cur])
        derivs = ()
        # don't take too long, k < 5 is enough for Tiger's first 20k sents
        for k in range(3, 5):
            # `parse` is a generator; materialize it so that the result of
            # every attempt (including the last one) is actually inspected
            derivs = tuple(grammar.parse(pos_sent, supertag_sent, k))
            if trees[cur] in derivs:
                break
        assert trees[cur] in derivs, \
            "gold tree of sentence %d is not among the best parses" % cur

def test_softmatch():
    """ Parses with a positive fallback probability and checks that, next to
        the gold tree, derivations using NOPARSE fallback nodes are
        enumerated.
    """
    from .tree import Tree
    t = Tree("(ROOT (a,c (a 0) (c 2)) (b 1))")
    sent = "a b c".split()
    corpus = SupertagCorpus([t], [sent])

    grammar = SupertagGrammar(corpus, fallback_prob=1e-2)
    derivs = grammar.parse(sent, [[(n, 0)] for n in corpus.supertag_corpus[0]], k=10)
    assert set(str(d) for d in derivs) == {
        "(ROOT (a,c (a 0) (c 2)) (b 1))",
        "(NOPARSE (a,c (a 0) (c 2)) (NOPARSE (b 1)))",
        "(NOPARSE (NOPARSE (NOPARSE (a 0)) (NOPARSE (b 1))) (NOPARSE (c 2)))"
    }

    t = Tree("(ROOT (A (A1 0) (A2 2)) (B (B1 1) (B2 3)))")
    sent = "a c b d".split()
    corpus = SupertagCorpus([t], [sent])

    grammar = SupertagGrammar(corpus, fallback_prob=1e-2)
    derivs = list(str(t) for t in grammar.parse(["A1", "B1", "A2", "B2"], [[(tag, 0)] for tag in corpus.supertag_corpus[0]], k=100))
    assert "(ROOT (A (A1 0) (A2 2)) (B (B1 1) (B2 3)))" in derivs
    assert "(NOPARSE (NOPARSE (A (A1 0) (A2 2)) (NOPARSE (B1 1))) (NOPARSE (B2 3)))" in derivs
    assert "(NOPARSE (NOPARSE (NOPARSE (NOPARSE (A1 0)) (NOPARSE (B1 1))) (NOPARSE (A2 2))) (NOPARSE (B2 3)))" in derivs
    # TODO: fails, because there is another fallback derivation on top of
    # (ROOT (A (A1 0) (A2 2)) (B (B1 1) (B2 3))), that should be avoided
    assert len(derivs) == 3

def test_subgrammar():
    """ Derives subgrammars from hand-written supertags and checks that
        parsing with exactly these tags (in pos mode) gives the expected
        lcfrs productions.
    """
    from .grammar import lcfrsproductions

    def to_ast(tree, sent):
        return lcfrsproductions(tree, list(range(len(sent))), as_tree=True)

    def one_tag_per_word(supertags):
        # each word gets exactly one tag, given by its pos symbol, weight 0
        return (((tag.pos(), 0.0),) for tag in supertags)

    gold = Tree("(ROOT (A (A1 0) (A2 2)) (B (B1 1) (B2 3)))")
    words = "a c b d".split()
    postags = "A1 B1 A2 B2".split()
    expected = Tree((("ROOT", "A", "B"), ((0,1,0,1),)), (
            Tree((("A", "A1", "A2"), ((0,), (1,))), (
                Tree((("A1", "Epsilon"), (0,)), ()),
                Tree((("A2", "Epsilon"), (2,)), ())
            )),
            Tree((("B", "B1", "B2"), ((0,), (1,))), (
                Tree((("B1", "Epsilon"), (1,)), ()),
                Tree((("B2", "Epsilon"), (3,)), ())
            ))))

    base = SupertagGrammar(SupertagCorpus([gold], [words]))

    # 1st case: four words
    supertags = (Left("A", None), Binary("ROOT", "A", "B_-", ((0,2,0,1),)), Right("A"), Nullary("B_-"))
    restricted = base.subgrammar(supertags)
    found = restricted.parse(postags, one_tag_per_word(supertags), posmode=True)
    assert [to_ast(parse, words) for parse in found] == [expected]

    # 2nd case: six words, with nonterminals that do not occur in `base`
    words = "a c b d e f".split()
    postags = "A1 B1 C1 A2 B2 C2".split()
    expected = Tree((("ROOT", "A", "ROOT|<>"), ((0,1,0,1),)), (
            Tree((("A", "A1", "A2"), ((0,), (1,))), (
                Tree((("A1", "Epsilon"), (0,)), ()),
                Tree((("A2", "Epsilon"), (3,)), ())
            )),
            Tree((("ROOT|<>", "B", "C"), ((0,1), (0,1,))), (
                Tree((("B", "B1", "B2"), ((0,), (1,))), (
                    Tree((("B1", "Epsilon"), (1,)), ()),
                    Tree((("B2", "Epsilon"), (4,)), ())
                )),
                Tree((("C", "C1", "C2"), ((0,), (1,))), (
                    Tree((("C1", "Epsilon"), (2,)), ()),
                    Tree((("C2", "Epsilon"), (5,)), ())
                ))
            ))))
    supertags = (
        Left("A", None),
        Binary("ROOT", "A", "ROOT|<>_+", ((0,2,1,0,1),)),
        Binary("ROOT|<>_+", "B_-", "C_-", ((2,),(0,1))),
        Right("A"),
        Nullary("B_-"),
        Nullary("C_-")
    )
    restricted = base.subgrammar(supertags)
    found = restricted.parse(postags, one_tag_per_word(supertags), posmode=True)
    assert [to_ast(parse, words) for parse in found] == [expected]
