from collections import deque
from dataclasses import dataclass
from typing import List, Union, Optional

from ..tree import Tree, ImmutableTree, HEAD

def copy_without_pos(tree: Tree, pos_container: List[Optional[str]]) -> Union[ImmutableTree, int]:
    """Copy ``tree`` with its preterminal (POS) nodes collapsed into bare leaf indices.

    A preterminal is a node whose only child is an ``int`` leaf index. Each one
    is replaced by that index, and its label is written into
    ``pos_container[index]`` as a side effect.

    :param tree: tree whose leaves are ``int`` indices wrapped in preterminals.
    :param pos_container: list indexed by leaf index, filled in place with the
        label of every visited preterminal; it must be long enough to hold every
        leaf index (otherwise ``IndexError`` is raised).
    :returns: the leaf index (``int``) when ``tree`` itself is a preterminal,
        otherwise a new ``ImmutableTree`` with the same labels in which every
        preterminal has been replaced by its leaf index.
    """
    if len(tree) == 1 and type(tree[0]) is int:
        # Preterminal: record its POS tag and keep only the leaf index.
        leaf = tree[0]
        pos_container[leaf] = tree.label
        return leaf
    return ImmutableTree(tree.label, (copy_without_pos(child, pos_container) for child in tree))

def head_indices(tree: Tree) -> Optional[ImmutableTree]:
    """Build a tree mirroring ``tree`` whose labels are head-child positions.

    For each inner node, the mirrored node's label is the position of the first
    child whose ``type`` equals ``HEAD``, or ``None`` when no child is marked
    as head. A preterminal (a node whose only child is an ``int`` leaf index)
    is mapped to ``None``.

    :param tree: tree whose non-leaf children expose a ``type`` attribute.
    :returns: ``None`` if ``tree`` is a preterminal, otherwise an
        ``ImmutableTree`` of head positions with one entry per child of ``tree``.
    """
    if len(tree) == 1 and type(tree[0]) is int:
        return None
    head_index = next((i for i, child in enumerate(tree) if child.type == HEAD), None)
    # The recursive call already yields None for preterminal children,
    # so no separate check is needed here.
    return ImmutableTree(head_index, (head_indices(child) for child in tree))

@dataclass(init=False)
class ConstituentTree:
    """Constituent tree split into structure, head positions and POS tags.

    ``constituency`` is the tree with preterminals collapsed into ``int`` leaf
    indices, or a single ``int`` when the tree is just a leaf. ``dependency``
    mirrors it, with head-child positions as labels (``None`` for leaves).
    ``pos`` maps a leaf index to its POS tag, and ``words`` is the optional
    sentence passed at construction.
    """
    constituency: Union[ImmutableTree, int]
    dependency: Union[ImmutableTree, None]
    pos: List[Optional[str]]
    words: Optional[List[str]]

    def __init__(self, discodop_tree: ImmutableTree, sent: Optional[List[str]] = None):
        self.words = sent
        # One slot per leaf, filled in by copy_without_pos below.
        # NOTE(review): assumes leaf indices are exactly 0..n-1 -- confirm for partial trees.
        self.pos = [None] * len(discodop_tree.leaves())
        self.constituency = copy_without_pos(discodop_tree, self.pos)
        self.dependency = head_indices(discodop_tree)

    def leaf_children(self):
        """Return ``(position, leaf_index)`` pairs for the children that are leaves."""
        if self.is_leaf:
            return ()
        return ((i, c) for i, c in enumerate(self.constituency) if type(c) is int)

    def node_children(self):
        """Return views of the children that are not leaves."""
        return (child for child in self.children() if not child.is_leaf)

    def children(self):
        """Return a view of every child, paired with its head-position subtree."""
        if self.is_leaf:
            return ()
        return (ConstituentTreeView(self, c, d) for c, d in zip(self.constituency, self.dependency))

    def __getitem__(self, index):
        """Return a view of the child at ``index``; only valid for inner nodes."""
        assert not self.is_leaf
        return ConstituentTreeView(self, self.constituency[index], self.dependency[index])

    @property
    def yd(self):
        """Set of leaf indices dominated by this node (its yield)."""
        if self.is_leaf:
            return {self.constituency}
        return set(self.constituency.leaves())

    @property
    def is_leaf(self):
        """True when this node is a bare leaf index."""
        return type(self.constituency) is int

    @property
    def label(self):
        """The leaf index for a leaf, otherwise the constituent label."""
        return self.constituency if self.is_leaf else self.constituency.label

    @property
    def head(self):
        """Leaf index reached by following head children down from this node."""
        if self.is_leaf:
            return self.label
        assert self.dependency.label is not None, "Constituent tree was not initialized with dependencies"
        return self[self.dependency.label].head

    @property
    def mod(self):
        """Head of the non-head child of a binary node."""
        # Check is_leaf first: len() on a bare int leaf would raise a confusing TypeError.
        assert not self.is_leaf and len(self.constituency) == 2, "Constituent tree is not binary"
        assert self.dependency.label is not None, \
            "Constituent tree was not initialized with dependencies or is not head-outward binarized"
        return self[1 - self.dependency.label].head

    def bfs_leaves(self):
        """Yield the leaf indices of this tree in breadth-first order."""
        # deque.popleft is O(1); list.pop(0) made the traversal quadratic.
        queue = deque([self.constituency])
        while queue:
            node = queue.popleft()
            if type(node) is int:
                yield node
            else:
                queue.extend(node)

    def __str__(self):
        if self.is_leaf:
            return f"({self.pos[self.label]} {self.label})"
        child_strs = [str(c) for c in self.children()]
        head_pos = self.dependency.label
        if head_pos is not None:
            # Mark the head child with a caret.
            child_strs[head_pos] = "^" + child_strs[head_pos]
        return f"({self.label} {' '.join(child_strs)})"

class ConstituentTreeView(ConstituentTree):
    """View on a subtree of a ``ConstituentTree`` that shares its ``pos`` and ``words``.

    Deliberately skips ``ConstituentTree.__init__``: no copying or head
    extraction happens, and the already-converted subtrees are stored as given.
    """
    def __init__(
        self,
        context: ConstituentTree,
        subtree: Union[ImmutableTree, int],
        subtree_dependency: Optional[ImmutableTree],
    ):
        # Share (not copy) the per-sentence lists of the enclosing tree.
        self.words = context.words
        self.pos = context.pos
        # For a leaf child, subtree is a bare int index and subtree_dependency is None.
        self.constituency = subtree
        self.dependency = subtree_dependency