from ..extraction import SupertagExtractor
from ..guide import GuideFactory
from ..nt_constructor import CompositionalNtConstructor
from ...supertag import Supertag, LexicalRule
from ...constituent_tree import ConstituentTree
from ....tree import ImmutableTree, Tree


def test_supertag_extraction():
    """Check supertag extraction with the default compositional nonterminals.

    For one fixed tree with six leaves (positions 0..5) this verifies:

    * the derivation tree returned as element ``[2]`` of
      ``extract_from_single_tree_recursive``,
    * the flat list produced by calling the extractor on the tree, which is
      expected to list one supertag per leaf in leaf-position order, and
    * the single-leaf tree ``(ROOT+NE 0)``.
    """
    bracketing = "(S (NP (PPER 0) (NP (ADJA 3) (NN 4))) (S|<> (VVFIN 1) (S|<> (ADV 2) ($. 5))))"
    tree = ImmutableTree(bracketing)
    ctree = ConstituentTree(tree)

    extractor = SupertagExtractor(GuideFactory(), CompositionalNtConstructor())
    guide = extractor.guides.produce(ctree)
    derivation = extractor.extract_from_single_tree_recursive(ctree, guide, -1)[2]

    # Expected (supertag, flag) pair for each leaf, named after its POS tag.
    # Only the VVFIN pair, which labels the top of the expected derivation,
    # carries True (NOTE(review): presumably a "root" marker — confirm).
    pper = (Supertag(LexicalRule("NP;NP;*:0"), None, "NP", "PPER"), False)
    vvfin = (Supertag(LexicalRule("S;NP;S|<>;0*101"), 1, "S", "VVFIN"), True)
    adv = (Supertag(LexicalRule("S|<>;S|<>;*:0"), 0, None, "ADV"), False)
    adja = (Supertag(LexicalRule("NP;L-NP;*0"), None, "NP", "ADJA"), False)
    nn = (Supertag(LexicalRule("L-NP;*"), None, None, "NN"), False)
    punct = (Supertag(LexicalRule("S|<>;*"), None, None, "$."), False)

    # Each derivation node is labelled (pair, leaf position).
    expected_derivation = Tree((vvfin, 1), [
        Tree((pper, 0), [
            Tree((adja, 3), [
                Tree((nn, 4), []),
            ]),
        ]),
        Tree((adv, 2), [
            Tree((punct, 5), []),
        ]),
    ])
    assert derivation == expected_derivation

    assert list(extractor(tree)) == [pper, vvfin, adv, adja, nn, punct]

    single_leaf = ImmutableTree("(ROOT+NE 0)")
    assert list(extractor(single_leaf)) == [
        (Supertag(LexicalRule("L-ROOT;*"), None, None, "ROOT+NE"), True),
    ]

def test_directional_nt_extraction():
    """Check recursive extraction when nonterminals use the "ChildIdx" feature.

    The tree is the same as in the default-constructor test. In the expected
    lexical rules, the nonterminal labels are replaced by ``*/<n>``-style
    placeholders (NOTE(review): ``<n>`` looks like a child index — confirm in
    CompositionalNtConstructor). The remaining Supertag fields are unchanged.
    """
    bracketing = "(S (NP (PPER 0) (NP (ADJA 3) (NN 4))) (S|<> (VVFIN 1) (S|<> (ADV 2) ($. 5))))"
    tree = ImmutableTree(bracketing)
    ctree = ConstituentTree(tree)

    extractor = SupertagExtractor(GuideFactory(), CompositionalNtConstructor(["ChildIdx"]))
    guide = extractor.guides.produce(ctree)
    derivation = extractor.extract_from_single_tree_recursive(ctree, guide, -1)[2]

    # Expected (supertag, flag) pair for each leaf, named after its POS tag.
    pper = (Supertag(LexicalRule("*/0;*/0;*:0"), None, "NP", "PPER"), False)
    vvfin = (Supertag(LexicalRule("*/0;*/0;*/1;0*101"), 1, "S", "VVFIN"), True)
    adv = (Supertag(LexicalRule("*/1;*/0;*:0"), 0, None, "ADV"), False)
    adja = (Supertag(LexicalRule("*/0;L-*/0;*0"), None, "NP", "ADJA"), False)
    nn = (Supertag(LexicalRule("L-*/0;*"), None, None, "NN"), False)
    punct = (Supertag(LexicalRule("*/0;*"), None, None, "$."), False)

    # Each derivation node is labelled (pair, leaf position).
    expected_derivation = Tree((vvfin, 1), [
        Tree((pper, 0), [
            Tree((adja, 3), [
                Tree((nn, 4), []),
            ]),
        ]),
        Tree((adv, 2), [
            Tree((punct, 5), []),
        ]),
    ])
    assert derivation == expected_derivation