from dataclasses import dataclass, field, fields
from typing import Tuple, Iterable, Optional, Dict, List, Set

from numpy import array, newaxis, ndarray

from ..treetransforms import unbinarize
from ..containers import Grammar
from ..plcfrs import parse_supertags
from ..kbest import lazykbest, onlybest
from ..tree import Tree
from .extraction.extraction import Supertag


class DerivationConverter:
    """ Turns derivation trees over supertag nonterminals into constituent
        trees, re-attaching the terminal that each supertag carries.
    """

    def __init__(self, tags: Iterable[Supertag]):
        """ Registers every supertag under ``tag.str_tag(indicator=True)``;
            a later tag with the same string replaces an earlier one.
        """
        lookup = {}
        for supertag in tags:
            lookup[supertag.str_tag(indicator=True)] = supertag
        self.pos_to_tag = lookup

    def _convert_(self,
                deriv: Tree,
                prop: Optional[Tree],
                opt: Optional[List[str]],
                oc: Optional[List[str]],
                ot: Optional[List[int]]
            ) -> Tuple[List[Tree], int]:
        """ Recursively converts the derivation ``deriv``.

            :param deriv: (sub-)derivation to convert.
            :param prop: terminal node handed down from an ancestor that
                still has to be attached somewhere below, or None.
            :param opt: pos labels that replace ``tag.pos``, indexed by
                terminal position; None keeps the labels of the tags.
            :param oc: constituent labels that replace ``tag.constituent``,
                indexed by terminal position; None keeps those of the tags.
            :param ot: transport values, indexed by terminal position, that
                are consulted in addition to ``tag.transport``.
            :returns: the converted trees together with the smallest terminal
                position found in them.
            :raises NotImplementedError: if the rule of a tag has neither one
                nor two right-hand side entries.
        """
        if len(deriv) == 1:
            only_child = deriv[0]

            # pos node: a supertag nonterminal directly above a terminal
            if type(only_child) is Tree and len(only_child) == 1 and type(deriv[(0, 0)]) is int:
                position = deriv[(0, 0)]
                supertag = self.pos_to_tag[only_child.label]
                postag = supertag.pos if opt is None else opt[position]
                constituent = supertag.constituent if oc is None else oc[position]
                leaf = Tree(postag, [position])
                # a handed-down terminal is put before or after the own
                # terminal, depending on which position is smaller
                if prop is None:
                    nodes, first = [leaf], position
                elif prop[0] < position:
                    nodes, first = [prop, leaf], prop[0]
                else:
                    nodes, first = [leaf, prop], position
                if constituent is None:
                    return nodes, first
                return [Tree(constituent, nodes)], first

            # pos is a wildcard rule
            if type(only_child) is int:
                assert prop is None
                label = "NOPARSE" if opt is None else opt[only_child]
                return [Tree(label, [only_child])], only_child

            # skip unary wildcard rules
            assert prop is None
            return self._convert_(only_child, None, opt, oc, ot)

        supertag = self.pos_to_tag.get(deriv[1].label)

        # binary wildcard rule: the label is not a registered supertag
        if supertag is None:
            assert prop is None
            converted = [self._convert_(sub, None, opt, oc, ot) for sub in deriv]
            flattened = (node for nodes, _ in converted for node in nodes)
            return [Tree("NOPARSE", flattened)], min(start for _, start in converted)

        rhs_size = len(supertag.rule.rhs)

        # monic rule
        if rhs_size == 1:
            position = deriv[(1, 0)]
            terminal = Tree(supertag.pos if opt is None else opt[position], [position])
            constituent = supertag.constituent if oc is None else oc[position]
            # NOTE(review): unlike the pos/constituent overrides, `ot` does not
            # replace tag.transport -- a zero in either one triggers the swap,
            # and an empty `ot` is ignored; confirm that this is intended.
            if (ot and ot[position] == 0) or supertag.transport == 0:
                # attach the handed-down terminal here, pass the own one down
                terminal, prop = prop, terminal
            nodes, first = self._convert_(deriv[0], prop, opt, oc, ot)
            if terminal is not None:
                if terminal[0] < first:
                    nodes = [terminal] + nodes
                    first = terminal[0]
                else:
                    nodes.append(terminal)
            if constituent is None:
                return nodes, first
            return [Tree(constituent, nodes)], first

        # binary rule
        if rhs_size == 2:
            position = deriv[(1, 1, 0)]
            terminal = Tree(supertag.pos if opt is None else opt[position], [position])
            constituent = supertag.constituent if oc is None else oc[position]
            # NOTE(review): same combination of `ot` and tag.transport as in
            # the monic case above; confirm that this is intended.
            if (ot and ot[position] == 0) or supertag.transport == 0:
                terminal, prop = prop, terminal
            left, left_start = self._convert_(deriv[0], prop, opt, oc, ot)
            right, right_start = self._convert_(deriv[(1, 0)], terminal, opt, oc, ot)
            # the group with the smaller first position goes first
            if left_start > right_start:
                left, right = right, left
                left_start = right_start
            merged = left + right
            if constituent is None:
                return merged, left_start
            return [Tree(constituent, merged)], left_start

        raise NotImplementedError()

    def __call__(self,
                deriv: Tree,
                override_pos_tags: Optional[List[str]] = None,
                override_constituents: Optional[List[str]] = None,
                override_transports: Optional[List[int]] = None
            ) -> Tree:
        """ Undoes the binarization of lexical rules.

            Only the first tree of the converted result is returned; the
            optional lists are forwarded to the conversion and indexed by
            terminal position.
        """
        trees, _ = self._convert_(deriv, None, override_pos_tags, override_constituents, override_transports)
        return trees[0]


@dataclass(eq=False, order=False)
class SupertagGrammar:
    """ Grammar that is assembled from a collection of supertags and parses
        sequences of weighted supertag candidates into trees.

        Only ``tags``, ``roots`` and ``_fallback_prob`` are constructor
        arguments; all other fields are derived in ``__post_init__``.
    """
    # supertags the grammar is built from; tag indices passed to `parse`
    # refer to positions in this tuple
    tags: Tuple[Supertag, ...]
    # nonterminals that get a rule from the artificial "**ROOT**" start symbol
    roots: Tuple[str, ...]
    # backing value of the `fallback_prob` property
    _fallback_prob: float = 0.0
    # grammar object built from the rules of all tags plus the root rules
    internal_grammar: Grammar = field(init=False)
    # turns derivations of `internal_grammar` into constituent trees
    convert: DerivationConverter = field(init=False)
    # grammar label index of `tag.str_tag()`, aligned with `tags`
    idx_tag_to_nt: Tuple[int, ...] = field(init=False)
    # grammar label index of `tag.str_tag()`, keyed by that string
    str_tag_to_nt: Dict[str, int] = field(init=False)

    def __post_init__(self):
        """ Builds the internal grammar, the tag lookup tables and the
            derivation converter from ``tags`` and ``roots``.
        """
        rules = [(r, 1) for tag in self.tags for r in tag.discodop_rules()]
        rules += [((("**ROOT**", start_nt), ((0,),)), 1) for start_nt in self.roots]
        # every left-hand side of a tag that is not a root is a fallback nonterminal
        fallback_nts = {tag.rule.lhs for tag in self.tags if tag.rule.lhs not in self.roots}

        self.internal_grammar = Grammar(rules, start="**ROOT**", fallback_prob=self.fallback_prob, fallback_nts=fallback_nts, normalize=False)
        str_nt_to_idx = {nt: idx for idx, nt in self.internal_grammar.labels() if not nt.startswith("NOPARSE")}
        # compute each tag string once and reuse it for both lookup tables
        tag_strings = [tag.str_tag() for tag in self.tags]
        self.idx_tag_to_nt = tuple(str_nt_to_idx[tag_string] for tag_string in tag_strings)
        self.str_tag_to_nt = {tag_string: str_nt_to_idx[tag_string] for tag_string in tag_strings}
        self.convert = DerivationConverter(self.tags)

    @property
    def fallback_prob(self):
        """ Fallback probability the internal grammar was configured with. """
        return self._fallback_prob

    @fallback_prob.setter
    def fallback_prob(self, value: float):
        # update the grammar first, so the stored value is left untouched
        # if the grammar raises
        self.internal_grammar.set_fallback_prob(value)
        self._fallback_prob = value

    @property
    def str_tags(self) -> Iterable[str]:
        """ String forms (``tag.str_tag()``) of all supertags. """
        return self.str_tag_to_nt.keys()

    def parse(self,
            tags: Iterable[Iterable[Tuple[int, float]]],
            pos: Optional[List[str]] = None,
            constituent: Optional[List[str]] = None,
            transport: Optional[List[int]] = None,
            length: Optional[int] = None,
            ktags: Optional[int] = None,
            estimates: Optional[ndarray] = None,
            str_tag_mode: bool = False,
            ktrees: int = 1) -> Iterable[Tree]:
        """ Parses a sequence of supertag candidates and yields up to
            ``ktrees`` trees. This is a generator: nothing is computed and
            no error is raised before the first tree is requested.

            :param tags: for each sentence position the candidate supertags
                as ``(tag, weight)`` pairs; ``tag`` is an index into
                ``self.tags``, or a tag string if ``str_tag_mode`` is set.
                Must support ``len()`` if ``length`` is omitted, and must be
                non-empty and indexable if ``ktags`` is omitted.
            :param pos: forwarded to the derivation converter as pos labels.
            :param constituent: forwarded to the derivation converter as
                constituent labels.
            :param transport: forwarded to the derivation converter as
                transport values.
            :param length: sentence length; defaults to ``len(tags)``.
            :param ktags: number of candidates per position, only used to
                size the item estimate; defaults to ``len(tags[0])``.
            :param estimates: optional array that is extended by three
                trailing axes, cast to double and passed to the parser as
                ``'Constants'`` estimates -- presumably one-dimensional,
                confirm against the caller.
            :param str_tag_mode: look up tags by string instead of by index.
            :param ktrees: number of trees to extract; values above 1 make
                the parser run exhaustively.
        """
        if length is None:
            length = len(tags)
        if ktags is None:
            ktags = len(tags[0])

        if estimates is not None:
            estimates = 'Constants', estimates[:, newaxis, newaxis, newaxis].astype("double")

        # translate the tags to grammar label indices lazily, the weights are kept
        tag_to_nt = self.str_tag_to_nt if str_tag_mode else self.idx_tag_to_nt
        tags = (((tag_to_nt[tag], w) for (tag, w) in word_tags) for word_tags in tags)

        chart, _ = parse_supertags(length, tags, self.internal_grammar, beam_beta=0.0,
            itemsestimate=(2*ktags*length if ktags else None), exhaustive=ktrees>1, estimates=estimates)

        for tree, _ in lazykbest(chart, ktrees):
            # skip artificial root node and undo lex. procedure,
            # unbinarize un-merges unary nodes
            yield unbinarize(self.convert(Tree(tree)[0], pos, constituent, transport))

    def __getstate__(self):
        """ Serializes only the constructor fields; the derived fields are
            rebuilt by ``__setstate__``.
        """
        return {f.name: getattr(self, f.name) for f in fields(self) if f.init}

    def __setstate__(self, state):
        """ Restores the constructor fields and rebuilds all derived fields. """
        for attr, v in state.items():
            setattr(self, attr, v)
        self.__post_init__()
