from ..supertag import Supertag, LexicalRule
from ..grammar import SupertagGrammar
from ...tree import Tree
from pickle import dumps, loads

def test_grammar():
    """Parse a six-token sentence with exactly one supertag per position.

    Three checks run in turn:
    * a freshly built grammar yields exactly one derivation, the expected tree;
    * a pickle round-trip of the grammar yields the same single derivation;
    * passing per-position POS and constituent overrides to ``parse``
      replaces the labels in the resulting tree.
    """
    tags = [
            Supertag(LexicalRule("NP_2;NP;*:0"), None, "NP", "PPER" ),
            Supertag(LexicalRule("S;NP_2;S|<>_2;0*101"), 1, "S", "VVFIN"),
            Supertag(LexicalRule("S|<>_2;S|<>;*:0"), 0, "S|<>", "ADV" ),
            Supertag(LexicalRule("NP;LEAF;*0"), None, "NP", "ADJA"),
            Supertag(LexicalRule("LEAF;*"), None, None, "NN"),
            Supertag(LexicalRule("S|<>;*"), None, "S|<>", "$.")
        ]
    expected = "(S (NP (PPER 0) (NP (ADJA 3) (NN 4))) (VVFIN 1) (ADV 2) ($. 5))"

    def one_tag_per_position():
        # Position i offers only supertag i, paired with the value 0.
        # A fresh list is built on every call so no parse shares its input.
        return [[(position, 0)] for position in range(len(tags))]

    fresh = SupertagGrammar(tuple(tags), ("S",))
    assert list(fresh.parse(one_tag_per_position())) == [Tree(expected)]

    # A grammar restored from its pickled bytes must parse identically.
    restored = loads(dumps(fresh))
    assert list(restored.parse(one_tag_per_position())) == [Tree(expected)]

    # Per-position overrides: POS labels first, then constituent labels
    # (None at position 4, where the supertag itself has no constituent).
    pos_labels = ["pper", "vvfin", "adv", "adja", "nn", "s."]
    constituent_labels = ["np", "s", "s|<>", "np", None, "s|<>"]
    relabelled = list(restored.parse(one_tag_per_position(), pos_labels, constituent_labels))
    assert relabelled == [Tree("(s (np (pper 0) (np (adja 3) (nn 4))) (vvfin 1) (adv 2) (s. 5))")]


def test_softmatch1():
    """Check the k-best derivation set for a three-token input with fallback enabled.

    The grammar is built with ``_fallback_prob=1e-2``, and ``parse`` is asked for
    up to 10 trees. The set of printed derivations must be exactly the one
    ROOT-rooted tree plus the two NOPARSE-rooted trees listed below.
    """
    tags = [
            Supertag(LexicalRule("a,c;LEAF;*:0"), None, "a,c", None),
            Supertag(LexicalRule("ROOT;a,c;0*0"), None, "ROOT", None),
            Supertag(LexicalRule("LEAF;*"), None, None, None),
        ]
    grammar = SupertagGrammar(tuple(tags), ("ROOT",), _fallback_prob=1e-2)

    expected = {
        "(ROOT (a,c (a 0) (c 2)) (b 1))",
        "(NOPARSE (a,c (a 0) (c 2)) (b 1))",
        "(NOPARSE (NOPARSE (a 0) (b 1)) (c 2))",
    }
    sentence = [[(position, 0)] for position in range(3)]
    parses = grammar.parse(sentence, pos=["a", "b", "c"], ktrees=10)
    assert {str(tree) for tree in parses} == expected


def test_softmatch2():
    """Check the k-best derivations for a four-token input with fallback enabled.

    The grammar is built with ``_fallback_prob=1e-2``, and ``parse`` is asked for
    up to 100 trees. The ROOT-rooted derivation and two NOPARSE-rooted
    derivations must be present, and no other derivation may appear.
    """
    tags = [
            Supertag(LexicalRule("A;LEAF;*:0"), None, "A", None),
            Supertag(LexicalRule("ROOT;A;B;0*01"), 1, "ROOT", None),
            Supertag(LexicalRule("LEAF;*"), None, None, None),
            Supertag(LexicalRule("B;*"), None, "B", None),
        ]
    grammar = SupertagGrammar(tuple(tags), ("ROOT",), _fallback_prob=1e-2)

    derivs = [
        str(t)
        for t in grammar.parse([[(tag, 0)] for tag in range(4)], pos=["A1", "B1", "A2", "B2"], ktrees=100)
    ]
    assert "(ROOT (A (A1 0) (A2 2)) (B (B1 1) (B2 3)))" in derivs
    assert "(NOPARSE (NOPARSE (A (A1 0) (A2 2)) (B1 1)) (B2 3))" in derivs
    assert "(NOPARSE (NOPARSE (NOPARSE (A1 0) (B1 1)) (A2 2)) (B2 3))" in derivs
    # TODO: this check currently fails. The parser also produces an extra
    # fallback derivation on top of (ROOT (A (A1 0) (A2 2)) (B (B1 1) (B2 3))),
    # and that derivation should be avoided.
    # The assertion message shows the derivations found, in place of a debug print.
    assert len(derivs) == 3, derivs