import inspect
from abc import ABC, abstractmethod
from collections import deque
from typing import Tuple, Set, Union

from ..constituent_tree import ConstituentTree
from ...tree import Tree, ImmutableTree


class TransportGuide(ABC):
    """Base class for transport guides over a constituent tree.

    ``self.tree`` is the ``ImmutableTree`` returned by :meth:`guide_tree`.
    Its node labels are leaves (ints) taken from the constituent tree's yield.
    Exactly one leaf of that yield is not used as a label anywhere in
    ``self.tree``; it is stored in ``self.untransported``. For a leaf
    constituent no guide tree is built and ``self.tree`` is ``None``.

    Subclasses implement :meth:`guide_tree`.
    """

    def __init__(self, constituent_tree: ConstituentTree):
        # A leaf constituent gets no guide tree at all.
        self.tree: Union[ImmutableTree, None] = (
            None if constituent_tree.is_leaf else self.guide_tree(constituent_tree)
        )
        transported_leaves = (
            set() if self.tree is None else {s.label for s in self.tree.subtrees()}
        )
        leftover = constituent_tree.yd - transported_leaves
        # Invariant of every guide: exactly one leaf of the yield stays behind.
        assert len(leftover) == 1, (
            f"expected exactly one untransported leaf, got {leftover!r}"
        )
        self.untransported: int = next(iter(leftover))

    def children(self):
        """Return an iterator of :class:`GuideView` objects, one per child of ``self.tree``.

        The iterator is empty when this guide has no tree (leaf constituent).
        """
        if self.tree is None:
            return iter(())
        return (GuideView(self, c) for c in self.tree)

    def __call__(self, position: Tuple[int, ...]) -> int:
        """Return the leaf label stored at ``position`` in the guide tree.

        Assumes ``ImmutableTree`` supports tuple-position indexing
        (``tree[(0, 1)]``) — TODO confirm against the tree module.
        """
        return self.tree[position].label

    @property
    def transported(self) -> Union[int, None]:
        """Leaf label at the root of the guide tree, or ``None`` if there is no tree."""
        return None if self.tree is None else self.tree.label

    @abstractmethod
    def guide_tree(self, constituent_tree: ConstituentTree, **kwargs) -> ImmutableTree:
        """Build the guide tree for a non-leaf ``constituent_tree``."""
        raise NotImplementedError()


class GuideView(TransportGuide):
    """Guide rooted at one subtree of another guide's tree.

    Does not call ``TransportGuide.__init__``, so no guide tree is computed.
    The view wraps ``subtree`` as its own ``tree`` and copies
    ``untransported`` from ``parent``.
    """

    def __init__(self, parent: TransportGuide, subtree: ImmutableTree):
        # Only the state read by the base-class accessors is set up here.
        self.tree: ImmutableTree = subtree
        self.untransported: int = parent.untransported

    def guide_tree(self, constituent_tree: ConstituentTree, **kwargs) -> ImmutableTree:
        """Not supported: a view never builds its own guide tree."""
        raise NotImplementedError()


class GuideFactory:
    """Create transport guides from a guide class name.

    ``guide_str`` names one of this module's concrete guide classes, with or
    without the ``"Guide"`` suffix (``"Vanilla"`` and ``"VanillaGuide"`` are
    equivalent).
    """

    def __init__(self, guide_str: str = "Vanilla"):
        """Resolve ``guide_str`` to a guide class.

        Raises:
            KeyError: if ``guide_str`` does not name a concrete guide class
                that can be built from a constituent tree.
        """
        if not guide_str.endswith("Guide"):
            guide_str += "Guide"
        guide_t = globals().get(guide_str)
        # Reject unrelated globals, the abstract base class, and GuideView.
        # GuideView's constructor takes (parent, subtree), so produce() could
        # never instantiate it.
        if not self._is_guide_class(guide_t):
            raise KeyError(f"unknown transport guide: {guide_str!r}")
        self.guide_t = guide_t

    @staticmethod
    def _is_guide_class(obj) -> bool:
        """Return whether ``obj`` is a concrete guide class constructible from a tree."""
        return (
            isinstance(obj, type)
            and issubclass(obj, TransportGuide)
            and not inspect.isabstract(obj)
            and obj is not GuideView
        )

    def produce(self, constituent_tree: ConstituentTree) -> TransportGuide:
        """Build a guide of the configured type for ``constituent_tree``."""
        return self.guide_t(constituent_tree)


class VanillaGuide(TransportGuide):
    """Guide that prefers the first child's label when that child is a leaf.

    At each node the transported leaf is chosen as follows:

    - If the first child is a leaf and no terminal was requested, use that
      leaf's label.
    - Otherwise use the smallest leaf in the second child's yield.
    """

    def guide_tree(self, ctree: ConstituentTree, request_terminal: bool = False) -> ImmutableTree:
        """Build the guide tree for the non-leaf ``ctree``.

        Args:
            ctree: inner node of the constituent tree (asserted non-leaf).
            request_terminal: when true, skip the first-child-leaf shortcut
                and always take ``min(ctree[1].yd)``. The flag is inherited by
                all descendants once set.
        """
        assert not ctree.is_leaf
        head = ctree[0]
        if head.is_leaf and not request_terminal:
            leaf = head.label
        else:
            leaf = min(ctree[1].yd)
        # NOTE(review): ``position`` counts entries of node_children(), not of
        # ctree. If node_children() yields only inner-node children (as the
        # name suggests — confirm), position 1 is reached only when both
        # children are inner nodes. That is exactly the case where ``leaf``
        # was taken from the second child's yield.
        subguides = [
            self.guide_tree(node, request_terminal or position == 1)
            for position, node in enumerate(ctree.node_children())
        ]
        return ImmutableTree(leaf, subguides)


class StrictGuide(TransportGuide):
    """Guide that always transports the smallest leaf of the second child's yield."""

    def guide_tree(self, ctree: ConstituentTree) -> ImmutableTree:
        """Build the guide tree for the non-leaf ``ctree`` (asserted)."""
        assert not ctree.is_leaf
        # Choose this node's leaf before recursing, as before.
        transported = min(ctree[1].yd)
        return ImmutableTree(
            transported,
            [self.guide_tree(node) for node in ctree.node_children()],
        )


class LeastGuide(TransportGuide):
    """Guide that gives each node the first leaf in ``bfs_leaves()`` order not yet used below it.

    Children are built first, so each node picks from the leaves that its
    descendants have not already claimed.
    """

    def guide_tree_rec(self, ctree: ConstituentTree) -> Tuple[ImmutableTree, Set[int]]:
        """Build the guide for ``ctree`` and report which leaves it used.

        Returns:
            ``(guide, used)``: the guide tree for ``ctree`` and the set of
            leaves used as labels anywhere in it.

        Raises:
            ValueError: if every leaf from ``ctree.bfs_leaves()`` is already
                used by descendants, leaving none for this node.
        """
        children, used_leaves = [], set()
        # A plain loop rather than a generator expression. A StopIteration
        # escaping a generator is turned into RuntimeError by PEP 479.
        for node in ctree.node_children():
            child, child_used = self.guide_tree_rec(node)
            children.append(child)
            used_leaves |= child_used
        try:
            leaf = next(lf for lf in ctree.bfs_leaves() if lf not in used_leaves)
        except StopIteration:
            raise ValueError("no unused leaf left to transport at this node") from None
        used_leaves.add(leaf)
        return ImmutableTree(leaf, children), used_leaves

    def guide_tree(self, ctree: ConstituentTree) -> ImmutableTree:
        """Build the guide tree for ``ctree``, discarding the used-leaf set."""
        return self.guide_tree_rec(ctree)[0]


class ShortestGuide(TransportGuide):
    """Guide that transports, at each node, the shallowest reachable leaf.

    Leaves already transported by ancestors are avoided: children whose yield
    contains such a leaf are not searched.
    """

    @classmethod
    def closest_leaf(cls, ctree: ConstituentTree, marked_leaves: Set[int]) -> Union[int, None]:
        """Return the label of the shallowest leaf below ``ctree``.

        Only children of ``ctree`` whose yield is disjoint from
        ``marked_leaves`` are searched. The search is breadth-first in
        ``children()`` order.

        Returns:
            The leaf label, or ``None`` if no leaf is found.
        """
        # deque.popleft is O(1); list.pop(0) made the search quadratic.
        queue = deque(c for c in ctree.children() if not (c.yd & marked_leaves))
        while queue:
            node = queue.popleft()
            if node.is_leaf:
                return node.label
            queue.extend(node.children())
        return None

    def guide_tree_rec(self, ctree: ConstituentTree, marked_leaves: Set[int]) -> ImmutableTree:
        """Build the guide for ``ctree``.

        ``marked_leaves`` holds the leaves already transported by ancestors.
        """
        closest = self.closest_leaf(ctree, marked_leaves)
        # Loop-invariant; children only build new sets from it, never mutate it.
        marked_below = marked_leaves | {closest}
        children = (self.guide_tree_rec(c, marked_below) for c in ctree.node_children())
        return ImmutableTree(closest, children)

    def guide_tree(self, ctree: ConstituentTree) -> ImmutableTree:
        """Build the guide tree for ``ctree`` with no leaves marked yet."""
        return self.guide_tree_rec(ctree, set())


class ModifierGuide(TransportGuide):
    """Guide whose node labels are each constituent's ``mod`` attribute."""

    def guide_tree(self, ctree: ConstituentTree) -> ImmutableTree:
        """Build the guide tree by labelling each inner node with ``ctree.mod``."""
        label = ctree.mod
        subguides = (self.guide_tree(node) for node in ctree.node_children())
        return ImmutableTree(label, subguides)

