# cython: language_level=3
import random

from nltk import Nonterminal, Production

import nltk

import cython

def apply_perm(l, perm):
    """Return a new list whose k-th element is ``l[perm[k]]``."""
    return [l[source_index] for source_index in perm]

def get_inv_perm(perm):
    """Return the inverse of ``perm``.

    For every position ``p`` holding value ``v`` in ``perm``, the result holds
    ``p`` at index ``v``. Slots that no entry of ``perm`` points to stay
    ``None``.
    """
    size = len(perm)
    inverse = [None] * size
    for position, target in enumerate(perm):
        inverse[target] = position
    return inverse

def weighted_choice(vec):
    """
    Draw an index of ``vec`` with probability proportional to its entries.

    A threshold is drawn uniformly from ``[0, sum(vec))`` and the first index
    whose running total reaches that threshold is returned. If no running
    total reaches it (e.g. ``vec`` is empty), ``None`` is returned.
    """
    total = sum(vec)
    threshold = random.random() * total
    running = 0
    for index, weight in enumerate(vec):
        running += weight
        if running >= threshold:
            return index
    return None

def possibly_weighted_choice(rules: list[Production]) -> Production:
    """Pick one rule, weighted by probability when every rule carries one.

    If every element of ``rules`` has a ``prob`` attribute, a rule is drawn
    with probability proportional to ``rule.prob()`` using
    ``weighted_choice``. Otherwise the draw is uniform.

    Args:
        rules: Non-empty list of productions to sample from.

    Returns:
        The selected element of ``rules``.

    Raises:
        ValueError: If ``rules`` is empty.
    """
    if not rules:
        raise ValueError("Cannot sample from empty list of rules.")
    # The check must be made on each rule, not on the list itself: a list
    # never has a ``prob`` attribute, so testing the list would always fall
    # through to the uniform draw.
    if all(hasattr(rule, "prob") for rule in rules):
        return rules[weighted_choice([rule.prob() for rule in rules])]
    return random.choice(rules)

def sample_tree(grammar, max_len, non_terminal_prob: float):
    """Sample a random derivation tree from ``grammar`` within a leaf budget.

    Starting from ``grammar.start()``, a production is drawn for each
    nonterminal with ``possibly_weighted_choice`` and expanded recursively.
    Only productions whose minimum yield (as reported by
    ``grammar.min_lengths``) fits into the remaining budget are considered.
    NOTE(review): the "at most ``max_len`` leaves" guarantee relies on
    ``grammar.min_lengths`` holding the true minimum yield length of every
    nonterminal; that table is built elsewhere -- confirm.

    Args:
        grammar: Object providing ``start()``, ``lhs_to_prod`` (indexable by
            nonterminal, yielding its productions) and ``min_lengths``
            (mapping from nonterminal to a length).
        max_len: Leaf budget for the whole tree.
        non_terminal_prob: When at least one fitting production has a
            nonterminal on its right-hand side, the probability of
            restricting the draw to such productions.

    Returns:
        An ``nltk.ImmutableTree`` whose node labels are
        ``(lhs, production)`` tuples and whose leaves are the terminal
        symbols, in the order given by each production's right-hand side.

    Raises:
        ValueError: Propagated from ``possibly_weighted_choice`` when no
            production of a nonterminal fits into the remaining budget.
    """
    def _sample_tree(nt, mlen):
        """Expand ``nt`` with a budget of ``mlen`` leaves.

        Returns a ``(tree, num_leaves)`` pair.
        """
        # Keep only productions that can be completed within the budget:
        # each nonterminal contributes at least its minimum length, each
        # terminal exactly one leaf.
        cands = [prod for prod in grammar.lhs_to_prod[nt]
                 if sum(grammar.min_lengths.get(symbol, 0) if isinstance(symbol, Nonterminal) else
                        1 for symbol in prod.rhs()) <= mlen]
        non_terminal_cands = [c for c in cands if any(isinstance(r, Nonterminal) for r in c.rhs())]
        # random.random() is only consulted when there is a real choice to
        # make between the two candidate pools.
        if non_terminal_cands and random.random() < non_terminal_prob:
            prod = possibly_weighted_choice(non_terminal_cands)
        else:
            prod = possibly_weighted_choice(cands)

        # Expand the children in random order. Each child may use whatever
        # budget is left after reserving the minimum for the not-yet-expanded
        # siblings, so the expansion order affects how the slack is
        # distributed (presumably the shuffle avoids always favouring the
        # left-most children).
        permutation = list(range(len(prod.rhs())))
        random.shuffle(permutation)
        rhs = apply_perm(prod.rhs(), permutation)

        children = []
        num_leaves = 0
        # Minimum number of leaves the symbols not yet processed will need.
        # NOTE(review): the candidate filter above tolerates nonterminals
        # missing from min_lengths (.get(..., 0)) while this lookup raises
        # KeyError for them -- confirm which behaviour is intended.
        num_tokens_still_needed = sum(grammar.min_lengths[symbol] if isinstance(symbol, Nonterminal) else 1
                                      for symbol in rhs)
        for symbol in rhs:
            if isinstance(symbol, Nonterminal):
                # Release this symbol's own reservation; what remains is the
                # minimum required by the symbols that follow it.
                num_tokens_still_needed -= grammar.min_lengths[symbol]
                child, sub_leaves = _sample_tree(symbol, mlen - num_tokens_still_needed)
                mlen -= sub_leaves
                children.append(child)
                num_leaves += sub_leaves
            else:
                # A terminal is a single leaf.
                children.append(symbol)
                num_leaves += 1
                mlen -= 1
                num_tokens_still_needed -= 1
        # Undo the permutation to get everything back in the right order
        children = apply_perm(children, get_inv_perm(permutation))

        return nltk.ImmutableTree((prod.lhs(), prod), children), num_leaves

    tree, _ = _sample_tree(grammar.start(), max_len)
    return tree