from pytest import raises

from ..constituent_tree import ConstituentTree, head_indices
from ...tree import ImmutableTree, Tree, HEAD


def test_constituent_tree():
    """Check labels, yields, projections and leaf traversal of a discontinuous tree."""
    sentence = "pper vvfin adv adja nn ."
    ctree = ConstituentTree(
        ImmutableTree("(S (NP (PPER 0) (NP (ADJA 3) (NN 4))) (VVFIN 1) (ADV 2) ($. 5))"),
        sentence.split(),
    )

    # Inner node labels; the POS node below (0, 0) is expected to be split off,
    # leaving the bare leaf index as the node label.
    assert ctree.label == "S"
    assert ctree[0].label == "NP"
    assert ctree[0, 0].label == 0
    assert ctree[0, 0].is_leaf

    # Yields are sets of leaf indices.
    assert ctree.yd == {0, 1, 2, 3, 4, 5}
    assert ctree[0].yd == {0, 3, 4}
    assert ctree[0, 0].yd == {0}

    # Projections: bracketing without POS, the POS tags, and the words.
    assert ctree.constituency == ImmutableTree("(S (NP 0 (NP 3 4)) 1 2 5)")
    assert ctree.pos == ["PPER", "VVFIN", "ADV", "ADJA", "NN", "$."]
    assert ctree.words == sentence.split()

    # Breadth-first leaf order, from the root and from two inner nodes.
    assert list(ctree.bfs_leaves()) == [1, 2, 5, 0, 3, 4]
    assert list(ctree[0].bfs_leaves()) == [0, 3, 4]
    assert list(ctree[0, 1].bfs_leaves()) == [3, 4]


def test_singleton_tree():
    """A tree consisting of one preterminal is expected to collapse to a bare leaf."""
    ctree = ConstituentTree(ImmutableTree("(ROOT+$. 0)"), ["."])

    assert ctree.label == 0
    assert ctree.yd == {0}
    assert ctree.constituency == 0
    # The merged root/POS label remains as the only part-of-speech tag.
    assert ctree.pos == ["ROOT+$."]
    assert ctree.words == ["."]
    assert list(ctree.bfs_leaves()) == [0]


def test_head_indices():
    """head_indices: inner nodes carry the index of their head child, leaves carry None."""
    tree = Tree("(S (NP (PPER 0) (NP (ADJA 3) (NN 4))) (VVFIN 1) (ADV 2) ($. 5))")
    # VVFIN heads S, the inner NP heads the outer NP, and NN heads the inner NP.
    for head_position in (1, (0, 1), (0, 1, 1)):
        tree[head_position].type = HEAD

    inner_np = ImmutableTree(1, [None, None])
    outer_np = ImmutableTree(1, [None, inner_np])
    expected = ImmutableTree(1, [outer_np, None, None, None])
    assert head_indices(tree) == expected


def test_dependencies():
    """Dependency view of a binarized tree, before and after head annotation."""
    tree = Tree("(S (NP (PPER 0) (NP (ADJA 3) (NN 4))) (S|<> (S|<> (VVFIN 1) (ADV 2)) ($. 5)))")

    # Without head marks: the dependency skeleton has only None labels, and
    # reading head/mod is expected to trip an assertion.
    unheaded = ConstituentTree(tree)
    expected_dependency = ImmutableTree(None, [
        ImmutableTree(None, [None, ImmutableTree(None, [None, None])]),
        ImmutableTree(None, [ImmutableTree(None, [None, None]), None]),
    ])
    assert unheaded.dependency == expected_dependency
    with raises(AssertionError):
        unheaded.head
    with raises(AssertionError):
        unheaded.mod

    # Mark the head path S -> S|<> -> S|<> -> VVFIN and NP -> NP -> NN.
    for head_position in (1, (0, 1), (0, 1, 1), (1, 0), (1, 0, 0)):
        tree[head_position].type = HEAD
    headed = ConstituentTree(tree)

    # Root: head word is leaf 1, `mod` is leaf 4; outer NP: head 4, `mod` 0.
    assert headed.head == 1
    assert headed.mod == 4
    assert headed[0].head == 4
    assert headed[0].mod == 0


def test_dependency_binarization():
    """Head-outward binarization marks the new intermediate nodes on the head path."""
    from ...treetransforms import binarize

    tree = Tree("(S (NP (PPER 0) (NP (ADJA 3) (NN 4))) (VVFIN 1) (ADV 2) ($. 5))")
    for head_position in (1, (0, 1), (0, 1, 1)):
        tree[head_position].type = HEAD

    binarized = binarize(tree, headoutward=True, horzmarkov=0, vertmarkov=1)
    assert binarized == Tree("(S (NP (PPER 0) (NP (ADJA 3) (NN 4))) (S|<> (S|<> (VVFIN 1) (ADV 2)) ($. 5)))")

    # NP-internal marks are unchanged; S|<> -> S|<> -> VVFIN forms the head path.
    for head_position in (1, (0, 1), (0, 1, 1), (1, 0), (1, 0, 0)):
        assert binarized[head_position].type == HEAD


def test_str():
    """str() matches the plain Tree until heads are marked, then prefixes them with '^'."""
    tree = Tree("(S (NP (PPER 0) (NP (ADJA 3) (NN 4))) (S|<> (S|<> (VVFIN 1) (ADV 2)) ($. 5)))")
    assert str(ConstituentTree(tree)) == str(tree)

    for head_position in (1, (0, 1), (0, 1, 1), (1, 0), (1, 0, 0)):
        tree[head_position].type = HEAD

    rendered = str(ConstituentTree(tree))
    assert rendered == "(S (NP (PPER 0) ^(NP (ADJA 3) ^(NN 4))) ^(S|<> ^(S|<> ^(VVFIN 1) (ADV 2)) ($. 5)))"