from abc import ABC, abstractmethod
from enum import IntFlag
from functools import reduce
from operator import or_
from re import compile
from typing import Set, Iterable, List


class NtConstructor(ABC):
    """
        Abstract interface for callables that produce nonterminal symbols.

        Subclasses have to implement __call__; leaf has a default
        implementation that ignores its arguments.
    """

    @abstractmethod
    def __call__(self, constituent: str, yd: Set[int], post_yd: Set[int], childidx: int):
        """
            Build the nonterminal for one constituent.

            constituent: the constituent symbol.
            yd:          the constituent's yield in the constituent tree.
            post_yd:     the yield that remains after at most one leaf was
                         removed and transported to a position above the
                         constituent.
            childidx:    presumably the position of the constituent among
                         its siblings -- TODO confirm against callers.
        """

    def leaf(self, *args):
        """
            Counterpart of __call__ that is only used for untransported
            leaves. All arguments are ignored by this default
            implementation, which always yields the constant "LEAF".
        """
        return "LEAF"


class NtConstructorFeatures(IntFlag):
    """
        Bit flags that select which components CompositionalNtConstructor
        puts into a nonterminal symbol.
    """
    # keep the constituent labels (otherwise each label is replaced by "*")
    Constituent = 1 << 0
    # keep all labels of a "+"-joined chain (otherwise only the first one)
    MergedChainConstituents = 1 << 1
    # append the fanout of the original yield
    OldFanout = 1 << 2
    # append the fanout of the yield after transporting a leaf
    NewFanout = 1 << 3
    # append the difference (new fanout - old fanout)
    FanoutChange = 1 << 4
    # append whether the yield differs from the yield after transport
    Transport = 1 << 5
    # append the child index
    ChildIdx = 1 << 6
    # truncate every label to its first character
    CoarseConstituent = 1 << 7
    # keep the "|<...>" suffix of the constituent symbol
    BinarizationSuffix = 1 << 8

    # shortcuts (inside the class body the names above are still plain ints,
    # so they can be combined with "|")
    Vanilla = (
        Constituent
        | MergedChainConstituents
        | OldFanout
        | FanoutChange
        | Transport
        | BinarizationSuffix
    )
    Classic = Constituent | NewFanout | BinarizationSuffix
    Coarse = Constituent | NewFanout | CoarseConstituent | BinarizationSuffix
    Direction = ChildIdx | NewFanout | BinarizationSuffix
    Star = NewFanout | BinarizationSuffix

    @classmethod
    def from_strs(cls, features: Iterable[str]) -> "NtConstructorFeatures":
        """
            Combine the members named in features into a single flag value.

            Both single flags and shortcut names are accepted. An empty
            iterable yields the zero-valued flag. A name that is not a
            member raises KeyError.
        """
        # Start from cls(0) rather than a bare 0 so that the result is a
        # member of this class even when no feature is given.
        return reduce(or_, (cls[name] for name in features), cls(0))


def fanout(yd: Iterable[int]) -> int:
    """
        Count the maximal runs of consecutive integers in yd.

        The positions may be given in any order; repeated positions are
        counted once. An empty iterable has fanout 0.
    """
    # De-duplicate first: a repeated position must not open a new run.
    positions = sorted(set(yd))
    if not positions:
        return 0
    # The first position always opens a run; every later position opens a
    # new one exactly when it does not directly follow its predecessor.
    # Comparing real neighbours (instead of a numeric start sentinel) keeps
    # the count correct for any starting position, including negative ones.
    gaps = sum(1 for prev, cur in zip(positions, positions[1:]) if cur != prev + 1)
    return 1 + gaps


# Splits a constituent symbol into the part before the first "|" (group 1,
# at least one character) and, if the symbol continues with "|<...>", the
# text between the angle brackets (group 2, otherwise None).
r_binarization = compile(r"([^|]+)(?:\|<([^>]*)>)?")
# Matches one entry of a comma-separated list (group 1) together with the
# separating "," or the end of the string. An entry is either the literal
# "$,", a run of non-comma characters, or a single "," -- presumably so that
# symbols which themselves contain a comma survive the split; TODO confirm.
r_binarization_list = compile(r"(\$\,|[^,]+|,)(?:$|,)")
class CompositionalNtConstructor(NtConstructor):
    """
        Nonterminal constructor whose output is composed of the parts
        selected by a set of NtConstructorFeatures.

        The resulting symbol is a "/"-separated string: first the (possibly
        reduced) constituent symbol, then, depending on the active
        features, the old fanout, the new fanout, the fanout change, the
        transport indicator and the child index, in that order.
    """

    def __init__(self, features: Iterable[str] = ("Constituent", "BinarizationSuffix")):
        """
            features: names of NtConstructorFeatures members (single flags
                      or shortcuts) that are combined into self.features.
        """
        # The default is an immutable tuple so that no mutable object is
        # shared between instances through the default argument.
        self.features: NtConstructorFeatures = NtConstructorFeatures.from_strs(features)

    def _filter_chain(self, chain: List[str]) -> List[str]:
        """
            Apply the label-related features to the labels of one
            "+"-joined chain and return the resulting labels.
        """
        features = self.features
        if not (features & NtConstructorFeatures.MergedChainConstituents):
            # un-merge: only the first label of the chain is kept
            chain = chain[:1]
        if not (features & NtConstructorFeatures.Constituent):
            # hide the labels but keep their number
            chain = ["*"] * len(chain)
        if features & NtConstructorFeatures.CoarseConstituent:
            # coarse labels consist of the first character only
            chain = [label[:1] for label in chain]
        return chain

    def _constituent(self, c_symbol: str) -> str:
        """
            Reduce a constituent symbol according to the active features.

            The symbol is disassembled into the part in front of an
            optional "|<...>" suffix (vertical part) and the comma-separated
            entries inside the suffix (horizontal part). Every part is a
            chain of labels joined by "+", which is filtered by
            _filter_chain and then re-assembled in the original format.
            The suffix is dropped entirely unless BinarizationSuffix is
            active.

            Raises AttributeError if c_symbol is not matched by
            r_binarization, i.e. if it is empty or starts with "|".
        """
        vertical, horizontal = r_binarization.match(c_symbol).groups()
        label = "+".join(self._filter_chain(vertical.split("+")))

        keep_suffix = self.features & NtConstructorFeatures.BinarizationSuffix
        if horizontal is None or not keep_suffix:
            return label

        siblings = ",".join(
            "+".join(self._filter_chain(entry.group(1).split("+")))
            for entry in r_binarization_list.finditer(horizontal)
        )
        return f"{label}|<{siblings}>"

    def __call__(self, constituent: str, yd: Set[int], post_yd: Set[int], childidx: int) -> str:
        """
            Construct the nonterminal for a constituent symbol.

            constituent: the constituent symbol, see _constituent.
            yd:          the constituent's yield.
            post_yd:     the yield after transporting a leaf.
            childidx:    value appended if ChildIdx is active.

            Returns the "/"-joined components selected by self.features.
        """
        features = self.features
        components = [self._constituent(constituent)]

        if features & NtConstructorFeatures.OldFanout:
            components.append(str(fanout(yd)))
        if features & NtConstructorFeatures.NewFanout:
            components.append(str(fanout(post_yd)))
        if features & NtConstructorFeatures.FanoutChange:
            components.append(str(fanout(post_yd) - fanout(yd)))
        if features & NtConstructorFeatures.Transport:
            # "True" if the yield was changed, "False" otherwise
            components.append(str(yd != post_yd))
        if features & NtConstructorFeatures.ChildIdx:
            components.append(str(childidx))

        return "/".join(components)

    def leaf(self, parent_constituent: str, yd: Set[int], chilidx: int) -> str:
        """
            Construct the nonterminal for an untransported leaf: the symbol
            built by __call__ for the parent constituent with an unchanged
            yield, prefixed with "L-".

            NOTE: the parameter name "chilidx" (sic) is kept as it is so
            that callers passing it by keyword continue to work.
        """
        return "L-" + self(parent_constituent, yd, yd, chilidx)