from typing import Tuple, Iterable, Union, Set

from ...tree import Tree, ImmutableTree
from ...treetransforms import binarize, collapseunary
from ..constituent_tree import ConstituentTree
from ..supertag import Supertag, LexicalRule
from .guide import TransportGuide, GuideFactory
from .nt_constructor import NtConstructor


class SupertagExtractor:
    """Extract lexical supertags from a constituent tree.

    The input tree is unary-collapsed and binarized, a transport guide is
    produced for the result, and the tree is then walked recursively. Each
    visited node (and each leaf child selected by the guide) contributes one
    ``Supertag`` wrapping a ``LexicalRule``.
    """

    def __init__(self, guides: GuideFactory, nt: NtConstructor, **binarization_parameters):
        """Store the collaborators used during extraction.

        Args:
            guides: Factory whose ``produce(constituent_tree)`` returns the
                guide for a tree.
            nt: Nonterminal constructor; called as
                ``nt(label, node_yield, rule_yield, index)`` for inner nodes
                and ``nt.leaf(label, node_yield, index)`` for leaf rules.
            **binarization_parameters: Forwarded unchanged to ``binarize``.
        """
        self.binarize = lambda tree: binarize(tree, **binarization_parameters)
        self.guides = guides
        self.nt = nt

    def __call__(self, tree: ImmutableTree) -> Iterable[Tuple[Supertag, bool]]:
        """Return ``(supertag, is_root)`` pairs ordered by anchor position.

        Args:
            tree: Tree to extract from; it is converted with ``Tree.convert``
                before being transformed, so the argument is not modified here.

        Returns:
            An iterable of ``(Supertag, bool)`` pairs. The flag is ``True``
            only for the supertag extracted at the root of the tree.
        """
        collapsed = collapseunary(Tree.convert(tree), collapsepos=True, collapseroot=True)
        constituent_tree = ConstituentTree(self.binarize(collapsed))
        if constituent_tree.is_leaf:
            # Nothing to recurse into: emit a single "ROOT" rule anchored at
            # position 0, flagged as the root supertag.
            root_rule = LexicalRule(self.nt.leaf("ROOT", {0}, 0), (), ((-1,),))
            return ((Supertag(root_rule, None, None, constituent_tree.pos[0]), True),)
        guide: TransportGuide = self.guides.produce(constituent_tree)
        _, _, tag_derivation = self.extract_from_single_tree_recursive(constituent_tree, guide, -1)
        # Every derivation node is labelled ``(tagged_supertag, position)``.
        # Sort on the position alone so that supertag objects are never
        # compared with each other.
        labels = sorted(
            (node.label for node in tag_derivation.subtrees()),
            key=lambda label: label[1],
        )
        return (tagged for tagged, _ in labels)

    def extract_from_single_tree_recursive(self, constituent_tree, guide, index) -> Tuple[str, Set[int], Tree]:
        """Extract the supertag derivation for ``constituent_tree``.

        Args:
            constituent_tree: Node to process; its ``label``, ``yd``, ``pos``,
                ``node_children()`` and ``leaf_children()`` are read.
            guide: Guide node for ``constituent_tree``; its ``transported``,
                ``untransported`` and ``children()`` are read.
            index: Position of this node among its parent's children, or
                ``-1`` when the node is the root.

        Returns:
            A triple ``(lhs, rule_yd, derivation)``: the left-hand side built
            by ``self.nt``, the set made of the anchor position and all child
            yields, and a ``Tree`` labelled
            ``((Supertag, is_root), anchor_position)`` whose children are the
            child derivations.
        """
        child_derivs, child_nts, child_yds = [], [], []
        transport = None
        leaf = guide.transported
        for child, child_guide in zip(constituent_tree.node_children(), guide.children()):
            child_index = len(child_derivs)
            child_nt, child_yd, child_deriv = self.extract_from_single_tree_recursive(
                child, child_guide, child_index)
            if leaf in child.yd:
                # Remember which child's yield contains this node's anchor.
                transport = child_index
            child_yds.append(child_yd)
            child_nts.append(child_nt)
            child_derivs.append(child_deriv)
        for _, child in constituent_tree.leaf_children():
            if child == guide.untransported:
                # This leaf is not anchored anywhere else: give it its own
                # leaf rule and attach it as an additional child.
                leaf_lhs = self.nt.leaf(constituent_tree.label, constituent_tree.yd, len(child_derivs))
                leaf_rule = LexicalRule(leaf_lhs, (), ((-1,),))
                leaf_tag = (Supertag(leaf_rule, None, None, constituent_tree.pos[child]), False)
                child_derivs.append(Tree((leaf_tag, child), []))
                child_nts.append(leaf_lhs)
                child_yds.append({child})
        # Positions covered by this rule: the anchor plus everything below.
        rule_yd = {leaf}
        for child_yd in child_yds:
            rule_yd |= child_yd
        lhs = self.nt(constituent_tree.label, constituent_tree.yd, rule_yd, max(index, 0))
        rule = LexicalRule(lhs, child_nts, tuple(lexical_composition(rule_yd, child_yds, leaf)))
        # Labels containing "|<" are not recorded as constituent labels
        # (presumably artificial nodes created by binarization -- confirm
        # against ``binarize``).
        conlabel = None if "|<" in constituent_tree.label else constituent_tree.label
        supertag = Supertag(rule, transport, conlabel, constituent_tree.pos[leaf]), index == -1
        return lhs, rule_yd, Tree((supertag, leaf), child_derivs)


def lexical_composition(rootpos: Iterable[int], childpos: Iterable[Iterable[int]], lexical_position: int) -> Iterable[Tuple[int, ...]]:
    """Yield the composition of a lexical rule, one tuple per contiguous span.

    The positions in ``rootpos`` are sorted and split into maximal runs of
    consecutive integers. Within a run every position is replaced by the
    index of the child whose positions contain it, or by ``-1`` for
    ``lexical_position``; adjacent repetitions of the same symbol are
    collapsed into one.

    Args:
        rootpos: All positions covered by the rule.
        childpos: For each child, in order, the positions it covers. If a
            position occurs in several children, the last child wins.
        lexical_position: Position of the lexical anchor, emitted as ``-1``.

    Yields:
        One tuple of symbols per contiguous run. A single empty tuple is
        yielded when ``rootpos`` is empty.

    Raises:
        KeyError: If a position in ``rootpos`` is neither
            ``lexical_position`` nor covered by any child.
    """
    arg_of_position = {
        position: arg
        for arg, positions in enumerate(childpos)
        for position in positions
    }

    component = []
    previous_position = None
    previous_symbol = None
    for position in sorted(rootpos):
        # A gap in the positions closes the current component. ``component``
        # is only non-empty after at least one position has been processed,
        # so ``previous_position`` is an int whenever it is used here.
        if component and position != previous_position + 1:
            yield tuple(component)
            component = []
            previous_symbol = None
        symbol = -1 if position == lexical_position else arg_of_position[position]
        if symbol != previous_symbol:
            component.append(symbol)
        previous_symbol = symbol
        previous_position = position
    yield tuple(component)