from pytest import raises

from ..guide import VanillaGuide, StrictGuide, LeastGuide, ModifierGuide, ShortestGuide
from ...constituent_tree import ConstituentTree
from ....tree import ImmutableTree, Tree, HEAD

def test_inorder_guide():
    """Check VanillaGuide on a binarized six-token tree and on a single leaf.

    Pins three things: the guide tree built for the tree, the labels returned
    by calling the guide with node addresses, and the ``untransported``
    attribute. For a tree with a single leaf, no guide tree is built.
    """
    tree = ImmutableTree("(S (NP (PPER 0) (NP (ADJA 3) (NN 4))) (S|<> (VVFIN 1) (S|<> (ADV 2) ($. 5))))")
    tree = ConstituentTree(tree)
    guide = VanillaGuide(tree)

    assert guide.tree == \
        ImmutableTree(1, [
            ImmutableTree(0, [
                ImmutableTree(3, [])
            ]),
            ImmutableTree(2, [
                ImmutableTree(5, [])
            ])
        ])
    assert guide(()) == 1
    assert guide((0,0)) == 3
    assert guide.untransported == 4

    # Degenerate input: a tree consisting of one node over a single leaf.
    tree = ConstituentTree(ImmutableTree("(ROOT+NE 0)"))
    guide = VanillaGuide(tree)
    # Identity check (PEP 8): the guide tree must be exactly None.
    assert guide.tree is None
    assert guide.untransported == 0

def test_inorder_guide_mod():
    """Check StrictGuide on the binarized six-token tree.

    Note: despite this test's name, the guide under test is ``StrictGuide``.
    Pins the guide tree, address lookups via ``guide(...)``, and the
    ``untransported`` attribute.
    """
    bracketed = "(S (NP (PPER 0) (NP (ADJA 3) (NN 4))) (S|<> (VVFIN 1) (S|<> (ADV 2) ($. 5))))"
    strict = StrictGuide(ConstituentTree(ImmutableTree(bracketed)))

    expected = ImmutableTree(1, [
        ImmutableTree(3, [ImmutableTree(4, [])]),
        ImmutableTree(2, [ImmutableTree(5, [])]),
    ])
    assert strict.tree == expected
    assert strict(()) == 1
    assert strict((0, 0)) == 4
    assert strict.untransported == 0

def test_least_guide():
    """Check the guide tree LeastGuide builds for the binarized six-token tree."""
    bracketed = "(S (NP (PPER 0) (NP (ADJA 3) (NN 4))) (S|<> (VVFIN 1) (S|<> (ADV 2) ($. 5))))"
    least = LeastGuide(ConstituentTree(ImmutableTree(bracketed)))

    left_child = ImmutableTree(0, [ImmutableTree(3, [])])
    right_child = ImmutableTree(1, [ImmutableTree(2, [])])
    assert least.tree == ImmutableTree(4, [left_child, right_child])

def test_shortest_guide():
    """Check the guide tree ShortestGuide builds for the binarized six-token tree."""
    bracketed = "(S (NP (PPER 0) (NP (ADJA 3) (NN 4))) (S|<> (VVFIN 1) (S|<> (ADV 2) ($. 5))))"
    shortest = ShortestGuide(ConstituentTree(ImmutableTree(bracketed)))

    left_child = ImmutableTree(3, [ImmutableTree(4, [])])
    right_child = ImmutableTree(1, [ImmutableTree(2, [])])
    assert shortest.tree == ImmutableTree(0, [left_child, right_child])

def test_modifier_guide():
    """Check that ModifierGuide depends on HEAD annotations.

    Before any node of the tree is explicitly marked as ``HEAD``, constructing
    a ModifierGuide raises ``AssertionError``. After marking the head nodes
    below, it builds the expected guide tree.
    """
    tree = Tree("(S (NP (PPER 0) (NP (ADJA 3) (NN 4))) (S|<> (S|<> (VVFIN 1) (ADV 2)) ($. 5)))")
    ctree = ConstituentTree(tree)
    with raises(AssertionError):
        ModifierGuide(ctree)

    # Mark the nodes at these addresses as heads. This mutates ``tree`` in
    # place, so a fresh ConstituentTree is built from it afterwards.
    for position in (1, (0,1), (0,1,1), (1,0), (1,0,0)):
        tree[position].type = HEAD
    ctree = ConstituentTree(tree)
    guide = ModifierGuide(ctree)

    assert guide.tree == \
        ImmutableTree(4, [
            ImmutableTree(0, [
                ImmutableTree(3, [])
            ]),
            ImmutableTree(5, [
                ImmutableTree(2, [])
            ])
        ])