from ..nt_constructor import CompositionalNtConstructor, fanout, NtConstructorFeatures

def test_fanout():
    """Check fanout() on a few small position sets.

    The expected values below are consistent with fanout() counting the
    maximal runs of consecutive integers in the set: {1, 2, 4} splits into
    {1, 2} and {4}, giving 2.
    """
    cases = (
        ({0}, 1),
        ({0, 1}, 1),
        ({1, 2, 4}, 2),
    )
    for positions, expected in cases:
        assert fanout(positions) == expected

def test_compositional_constructor():
    """Exercise CompositionalNtConstructor under several feature selections.

    Each case is ``(callable, positional args, expected label)``. The callable
    is either the constructor instance itself or its ``leaf`` method. Cases
    are checked in the order they are listed.
    """
    def check(cases):
        for make_label, args, expected in cases:
            assert make_label(*args) == expected, args

    # The "Vanilla" selection is expected to expand to this feature sum.
    vanilla = CompositionalNtConstructor(["Vanilla"])
    assert vanilla.features == (
        NtConstructorFeatures.Constituent
        + NtConstructorFeatures.MergedChainConstituents
        + NtConstructorFeatures.OldFanout
        + NtConstructorFeatures.FanoutChange
        + NtConstructorFeatures.Transport
        + NtConstructorFeatures.BinarizationSuffix
    )
    check([
        (vanilla, ("S", {1, 2, 4}, {1, 2}, 0), "S/2/-1/True"),
        (vanilla, ("S", {1, 2, 4}, {1, 4}, 0), "S/2/0/True"),
        (vanilla.leaf, ("S", {1, 2, 4}, 1), "L-S/2/0/False"),
        (vanilla, ("ROOT+S", {1, 2, 4}, {1, 2}, 0), "ROOT+S/2/-1/True"),
        (vanilla, ("ROOT+S|<>", {1, 2, 4}, {1, 2}, 0), "ROOT+S|<>/2/-1/True"),
        (vanilla.leaf, ("S+NP|<>", {1, 2, 4}, 0), "L-S+NP|<>/2/0/False"),
        (vanilla, ("ROOT+S|<NN,NE,VP+VFIN>", {1, 2, 4}, {1, 2}, 0),
         "ROOT+S|<NN,NE,VP+VFIN>/2/-1/True"),
    ])

    # Merged chains without the binarization suffix: the "|<...>" part is dropped.
    merged = CompositionalNtConstructor(["Constituent", "MergedChainConstituents"])
    check([
        (merged, ("ROOT+S", {1, 2, 4}, {1, 2}, 0), "ROOT+S"),
        (merged, ("ROOT+S|<NN,NE,VP+VFIN>", {1, 2, 4}, {1, 2}, 0), "ROOT+S"),
    ])

    # Merged chains with the binarization suffix kept verbatim.
    merged_suffix = CompositionalNtConstructor(
        ["Constituent", "MergedChainConstituents", "BinarizationSuffix"]
    )
    check([
        (merged_suffix, ("ROOT+S", {1, 2, 4}, {1, 2}, 0), "ROOT+S"),
        (merged_suffix, ("ROOT+S|<NN,NE,VP+VFIN>", {1, 2, 4}, {1, 2}, 0),
         "ROOT+S|<NN,NE,VP+VFIN>"),
    ])

    # Without merged chains only the first element of each "+" chain survives.
    suffix_only = CompositionalNtConstructor(["Constituent", "BinarizationSuffix"])
    check([
        (suffix_only, ("ROOT+S", {1, 2, 4}, {1, 2}, 0), "ROOT"),
        (suffix_only, ("ROOT+S|<NN,NE,VP+VFIN>", {1, 2, 4}, {1, 2}, 0),
         "ROOT|<NN,NE,VP>"),
    ])

    # ChildIdx alone: the last positional argument shows up in the label.
    child_idx = CompositionalNtConstructor(["ChildIdx"])
    check([
        (child_idx, ("ROOT+S", {1, 2, 4}, {1, 2}, 0), "*/0"),
        (child_idx, ("ROOT+S", {1, 2, 4}, {1, 2}, 1), "*/1"),
    ])

    # Coarse labels: each label is reduced to its first character.
    coarse = CompositionalNtConstructor(["Coarse"])
    check([
        (coarse, ("ROOT+S", {1, 2, 4}, {1, 2}, 0), "R/1"),
        (coarse, ("ROOT+S|<NN,NE,VP+VFIN>", {1, 2, 4}, {1, 2}, 0), "R|<N,N,V>/1"),
        (coarse.leaf, ("S+NP|<NN,NE,VP+VFIN>", {1, 2, 4}, 1), "L-S|<N,N,V>/2"),
        (coarse, ("ROOT+S|<NN,,,VP+VFIN>", {1, 2, 4}, {1, 2}, 0), "R|<N,,,V>/1"),
        (coarse, ("ROOT+S|<NN,.,VP+VFIN>", {1, 2, 4}, {1, 2}, 0), "R|<N,.,V>/1"),
    ])

    # Coarse labels combined with merged chains keep every chain element.
    coarse_merged = CompositionalNtConstructor(["Coarse", "MergedChainConstituents"])
    check([
        (coarse_merged, ("ROOT", {1, 2, 4}, {1, 2}, 0), "R/1"),
        (coarse_merged, ("ROOT+S|<NN,NE,VP+VFIN>", {1, 2, 4}, {1, 2}, 0),
         "R+S|<N,N,V+V>/1"),
        (coarse_merged.leaf, ("S+NP|<NN,NE,VP+VFIN>", {1, 2, 4}, 1),
         "L-S+N|<N,N,V+V>/2"),
    ])

def test_noconstituent():
    """Check the "Star" feature, whose expected labels show ``*`` in place of
    each constituent name, with and without merged chains.

    Each case is ``(callable, positional args, expected label)``, where the
    callable is the constructor instance or its ``leaf`` method.
    """
    star = CompositionalNtConstructor(["Star"])
    for make_label, args, expected in (
        (star, ("ROOT", {1, 2, 4}, {1, 2}, 0), "*/1"),
        (star, ("ROOT+S|<NN,NE,VP+VFIN>", {1, 2, 4}, {1, 2}, 0), "*|<*,*,*>/1"),
        (star.leaf, ("S+NP|<NN,NE,VP+VFIN>", {1, 2, 4}, 1), "L-*|<*,*,*>/2"),
    ):
        assert make_label(*args) == expected, args

    # With merged chains, each "+"-joined element becomes its own "*".
    star_merged = CompositionalNtConstructor(["Star", "MergedChainConstituents"])
    for make_label, args, expected in (
        (star_merged, ("ROOT", {1, 2, 4}, {1, 2}, 0), "*/1"),
        (star_merged, ("ROOT+S|<NN,NE,VP+VFIN>", {1, 2, 4}, {1, 2}, 0),
         "*+*|<*,*,*+*>/1"),
        (star_merged.leaf, ("S+NP|<NN,NE,VP+VFIN>", {1, 2, 4}, 1),
         "L-*+*|<*,*,*+*>/2"),
        (star_merged, ("ROOT+S|<NN,,,VP+VFIN>", {1, 2, 4}, {1, 2}, 0),
         "*+*|<*,*,*+*>/1"),
    ):
        assert make_label(*args) == expected, args