from dataclasses import dataclass, fields, MISSING
from typing import Tuple, Iterable, Any, Optional, Set


@dataclass(init=False, frozen=True)
class LexicalRule:
    """Immutable rule with a left-hand side, right-hand sides and a composition.

    Attributes:
        lhs: left-hand side symbol.
        rhs: right-hand side symbols; the methods below handle 0, 1 or 2.
        yf: composition, one inner tuple per component.  Each entry is a
            right-hand side index (rendered as ``x``/``y`` for 0/1) or ``-1``,
            which is rendered as ``*``.

    A rule can be built from its parts or parsed from the string produced by
    ``repr``: ``lhs;rhs1;..;rhsN;yf`` where ``yf`` separates components with
    ``:`` and writes each entry as a single character (a digit or ``*``).
    """
    lhs: str
    rhs: Tuple[str, ...]
    yf: Tuple[Tuple[int, ...], ...]

    def __init__(self, lhs: str, rhs: Optional[Iterable[str]] = None, yf: Optional[Iterable[Iterable[int]]] = None):
        """Build a rule from its parts, or parse ``lhs`` if it is the only argument.

        Args:
            lhs: the left-hand side, or the whole serialized rule when ``rhs``
                and ``yf`` are omitted.
            rhs: right-hand side symbols; must be given together with ``yf``.
            yf: composition as an iterable of components.

        Raises:
            AssertionError: if exactly one of ``rhs`` and ``yf`` is given.
            ValueError: if the serialized form has no ``;`` separated
                composition or contains a non-digit, non-``*`` entry.
        """
        assert (rhs is None) == (yf is None), "rhs and yf must be given together"
        if rhs is None:
            lhs, *rhs, yfstr = lhs.split(";")
            yf = (
                (-1 if v == '*' else int(v) for v in c)
                for c in yfstr.split(":"))
        # The dataclass is frozen, so attributes are set through object.
        object.__setattr__(self, "lhs", lhs)
        object.__setattr__(self, "rhs", tuple(rhs))
        # Normalize like rhs: nested tuples keep the rule hashable and make
        # it compare equal to the same rule parsed from its string form.
        object.__setattr__(self, "yf", tuple(tuple(c) for c in yf))

    def __yfstr__(self, var_sep=" ", component_sep=", "):
        """Render the composition with numbered variables, e.g. ``x0 * y0, x1``.

        Occurrences of right-hand side 0 are named ``x0, x1, ..`` and those of
        right-hand side 1 ``y0, y1, ..``; ``-1`` entries are printed as ``*``.
        """
        prefixes = ("x", "y")
        counters = [0, 0]
        rendered = []
        for component in self.yf:
            symbols = []
            for v in component:
                if v == -1:
                    symbols.append("*")
                else:
                    symbols.append(f"{prefixes[v]}{counters[v]}")
                    counters[v] += 1
            rendered.append(var_sep.join(symbols))
        return component_sep.join(rendered)

    def __repr__(self):
        """Return the serialized form accepted by the single-argument constructor."""
        yfstrs = ("".join(str(v) if v >= 0 else "*" for v in c) for c in self.yf)
        return ";".join((self.lhs, *self.rhs, ":".join(yfstrs)))

    def __str__(self):
        """Return a human-readable form ``lhs -> [composition] (rhs, ..)``."""
        return f"{self.lhs} -> [{self.__yfstr__()}] ({', '.join(self.rhs)})"

    def str_tag(self):
        """Return the compact ``LexicalRule{lhs,rhs..,yf}`` form used inside tags."""
        rhs_str = "," + ",".join(self.rhs) if self.rhs else ""
        yf_str = ":".join("".join(str(v) if v != -1 else "*" for v in c) for c in self.yf)
        return f"LexicalRule{{{self.lhs}{rhs_str},{yf_str}}}"

    def discodop_rules(self, terminal_tag: str):
        """Yield ``(symbols, composition)`` pairs that introduce ``terminal_tag``.

        * no right-hand side: a single unary rule ``lhs -> terminal_tag``;
        * one right-hand side: one binary rule in which ``-1`` entries of the
          composition become variable 1 (the position of ``terminal_tag``);
        * two right-hand sides: the rule is binarized via ``split_composition``
          with an intermediate symbol ``terminal_tag + "-SPLIT"``.

        Rules with more right-hand sides yield nothing but the final entry.
        The last item yielded is always ``((terminal_tag, "Epsilon"), (" ",))``.
        """
        arity = len(self.rhs)
        if arity == 0:
            yield ((self.lhs, terminal_tag), ((0,),))
        elif arity == 1:
            # Negation maps 0 -> 0 and -1 -> 1.
            yield ((self.lhs, self.rhs[0], terminal_tag), tuple(tuple(-v for v in c) for c in self.yf))
        elif arity == 2:
            intermediate = terminal_tag + "-SPLIT"
            bottomyf, topyf = split_composition(self.yf, (1, -1), (0,))
            yield ((self.lhs, self.rhs[0], intermediate), topyf)
            yield ((intermediate, self.rhs[1], terminal_tag), bottomyf)
        yield ((terminal_tag, "Epsilon"), (" ",))


@dataclass(frozen=True)
class Supertag:
    """Immutable tag: a lexical rule plus optional annotations.

    Attributes:
        rule: the underlying rule (required).
        transport: optional integer annotation.
        constituent: optional constituent label.
        pos: optional part-of-speech label.

    Annotations left at ``None`` are omitted from ``str_tag``.
    """
    rule: LexicalRule
    transport: Optional[int] = None
    constituent: Optional[str] = None
    pos: Optional[str] = None

    def str_tag(self, indicator: bool = False):
        """Return ``Supertag{..}`` listing every field that is not ``None``.

        With ``indicator`` set, ``-SPLIT`` is appended when the rule has two
        or more right-hand sides.
        """
        suffix = "-SPLIT" if len(self.rule.rhs) >= 2 and indicator else ""
        values = (getattr(self, f.name) for f in fields(self))
        parts = (str_or_repr(v) for v in values if v is not None)
        return f"Supertag{{{','.join(parts)}}}{suffix}"

    def discodop_rules(self):
        """Return the rule's ``discodop_rules`` using this tag's string as terminal tag."""
        return self.rule.discodop_rules(self.str_tag())

    def __iter__(self):
        """Iterate as ``rule, constituent, transport, pos``.

        Note that this order differs from the field declaration order.
        """
        yield self.rule
        yield self.constituent
        yield self.transport
        yield self.pos

    def _restricted_to(self, names):
        """Return a copy that keeps only the fields in ``names``; others take their defaults."""
        kept = {f.name: getattr(self, f.name) for f in fields(self) if f.name in names}
        return Supertag(**kept)

    def core(self, core_attribs: Iterable[Any]):
        """Return a copy restricted to the given attributes.

        Args:
            core_attribs: field names to keep.  ``dataclasses.Field`` objects
                are accepted as well and identified by their ``name``.

        Raises:
            TypeError: if a field without a default (``rule``) is not kept.
        """
        names = {a if isinstance(a, str) else a.name for a in core_attribs}
        return self._restricted_to(names)

    def split(self, *attribs):
        """Separate the named attributes from the tag.

        Returns ``(self,)`` when no names are given; otherwise a tuple whose
        first item is a copy with the named fields reset to their defaults,
        followed by the removed values in the order they were requested.
        """
        if not attribs:
            return (self,)
        removed = set(attribs)
        remainder = self._restricted_to({f.name for f in fields(self)} - removed)
        return (remainder, *(getattr(self, a) for a in attribs))



def str_or_repr(obj: Any) -> str:
    """Return the tag-friendly string form of ``obj``.

    Plain strings are returned unchanged, ``LexicalRule`` and ``Supertag``
    instances contribute their ``str_tag()``, and anything else falls back to
    ``repr``.  The checks are on the exact type, so subclasses use ``repr``.
    """
    kind = type(obj)
    if kind is str:
        return obj
    return obj.str_tag() if kind is LexicalRule or kind is Supertag else repr(obj)


def split_composition(yf, vars1, vars2):
    """ Splits a composition yf with ≥ 2 variables into two: one for a given set
        of variables, and one for the remainder. The result of the first composition
        is then used as the last argument of the second as follows:
            yf(arg1, .., argk) = yf2(argj1, .., argjm, yf1(argi1, .., argil))
        where vars1 = (i1, .., il) and vars2 = (j1, .., jm) are pairwise distinct
        and together list the variables occurring in yf.

        Variables are renumbered by their position in vars1 resp. vars2. Every
        maximal run of vars1 variables inside a component becomes one component
        of yf1 and is referenced in yf2 by the index len(vars2). Variables that
        occur in neither vars1 nor vars2 are dropped.

        For example: ((0, 2, 1, 0), (0, 1, 2)) with vars1 = (0,1) and vars2 = (2,)
            ↦ ((0,), (1, 0), (0, 1)) and ((1, 0, 1), (1, 0))

        Returns the pair (yf1, yf2) as tuples of tuples.
    """
    index1 = {v: i for i, v in enumerate(vars1)}
    index2 = {v: i for i, v in enumerate(vars2)}
    # Index under which yf2 refers to the result of yf1 (its last argument).
    inner_slot = len(index2)
    inner, outer = [], []
    for component in yf:
        run, outer_component = [], []
        for v in component:
            if v in index1:
                run.append(index1[v])
            elif v in index2:
                if run:
                    # A vars2 variable ends the current run of vars1 variables.
                    inner.append(tuple(run))
                    outer_component.append(inner_slot)
                    run = []
                outer_component.append(index2[v])
        if run:
            # Flush a run that reaches the end of the component.
            inner.append(tuple(run))
            outer_component.append(inner_slot)
        outer.append(tuple(outer_component))
    return tuple(inner), tuple(outer)