import dataclasses
import itertools
import math
from typing import Callable, Optional, Union
import nltk
from nltk import Production, CFG, Nonterminal, EarleyChartParser

import random
import tqdm

import numpy as np

import pyximport
pyximport.install()
import meta_adapters.STEP.sip_sample_from_grammar as sip_sample_from_grammar

def weighted_choice(vec):
    """
    Return an index of vec with probability proportional to the entries.

    :param vec: non-empty sequence of non-negative weights.
    :return: an index ``i`` into ``vec``. The function never returns ``None``: if the running sum
        never reaches the sampled threshold (e.g. NaN weights), the last index is returned.
    :raises ValueError: if ``vec`` is empty.
    """
    if len(vec) == 0:
        raise ValueError("Cannot sample an index from an empty weight vector.")
    total = sum(vec)
    threshold = random.random() * total
    cumulative = 0
    for i, weight in enumerate(vec):
        cumulative += weight
        if cumulative >= threshold:
            return i
    # Defensive fallback: with finite non-negative weights the loop always returns, because the
    # running sum ends equal to ``total`` (same additions, same order). NaNs break that invariant.
    return len(vec) - 1

def possibly_weighted_choice(rules: list[Production]) -> Production:
    """
    Pick one production from ``rules``.

    If every rule exposes a ``prob()`` method (e.g. probabilistic productions), the choice is
    weighted by those probabilities via ``weighted_choice``; otherwise it is uniform.

    :param rules: candidate productions.
    :return: the chosen production.
    :raises ValueError: if ``rules`` is empty.
    """
    if len(rules) == 0:
        raise ValueError("Cannot sample from empty list of rules.")
    # Inspect each rule, not the list object: a list never has a ``prob`` attribute, so checking
    # the container would silently disable weighting.
    if all(hasattr(r, "prob") for r in rules):
        probs = [r.prob() for r in rules]
        return rules[weighted_choice(probs)]
    return random.choice(rules)

class ProductionWithFunction(Production):
    """
    A CFG production that also carries an output mapping for its terminal and a function name.

    At most one right-hand-side symbol may be a terminal; it is stored in ``terminal`` (``None``
    when the production has no terminal, in which case ``map_terminal`` must be ``None`` too).
    ``map_terminal`` is the string the terminal is translated to ("" stands for epsilon) and
    ``fname`` names the function used to combine the children's values during evaluation.
    """

    def __init__(self, lhs, rhs, map_terminal: Optional[str], fname: Optional[str] = None):
        super().__init__(lhs, rhs)
        self.map_terminal = map_terminal

        terminals = [sym for sym in self.rhs() if not isinstance(sym, Nonterminal)]
        assert len(terminals) <= 1
        if len(terminals) == 1:
            self.terminal = terminals[0]
        else:
            self.terminal = None

        # A production without a terminal has nothing to map.
        if self.terminal is None:
            assert self.map_terminal is None
        self.fname = fname

    @property
    def arity(self):
        """Number of non-terminal symbols on the right-hand side."""
        return sum(1 for sym in self.rhs() if isinstance(sym, Nonterminal))

    def __str__(self):
        base = super().__str__()
        if self.terminal is None:
            return f"{base} ({self.fname})"
        # Show an empty mapping explicitly as '' so it is visible in the output.
        shown = self.map_terminal or "''"
        return f"{base} ({self.fname}, {self.terminal} -> {shown})"

@dataclasses.dataclass()
class ProductionRule:
    """
    Flat, integer-encoded form of a production, as produced by ``encode_as_rules``.

    The field order defines the positional constructor and must not change.
    """

    # integer id of the left-hand-side non-terminal
    lhs: int
    # name of the combining function; None for productions without one
    fname: Optional[str]
    # integer id of ``fname``
    fint: int
    # output string of the production's terminal ("" = epsilon)
    map_term: Optional[str]
    # right-hand side: terminals as str, non-terminals as int ids
    rhs: list[Union[str, int]]


class CFGWithSampling(CFG):
    """
    An nltk ``CFG`` that can sample derivation trees under depth or length constraints.

    On construction it precomputes, per non-terminal, the depth of its shallowest derivation tree
    (``min_depths``, see ``compute_min_nt_depths``) and the length of its shortest derivable string
    (``min_lengths``, see ``compute_min_nt_lengths``), plus the minimum number of tokens each
    production contributes (``min_lengths_prod``).
    """

    def __init__(self, nts, start, productions, prune: bool = False):
        """
        :param nts: the non-terminals of the grammar.
        :param start: the start symbol; must survive pruning.
        :param productions: the productions of the grammar.
        :param prune: drop non-terminals that cannot derive a finite string, together with every
            production that mentions one of them.
        :raises ValueError: if ``start`` is not among the (pruned) non-terminals.
        """
        self.min_depths = compute_min_nt_depths(nts, productions)
        if prune:
            nts, productions = CFGWithSampling._prune_min_depths(nts, productions, self.min_depths)
        if start not in nts:
            raise ValueError("Start symbol not part of (pruned) grammar.")

        super().__init__(start, productions)
        self.nts = nts
        self.min_lengths = compute_min_nt_lengths(nts, productions)

        self.lhs_to_prod = {nt: [] for nt in self.nts}
        for prod in self.productions():
            self.lhs_to_prod[prod.lhs()].append(prod)

        # For every production, compute how many tokens it will add at least to the string by choosing it
        self.min_lengths_prod = {prod: sum(self.min_lengths.get(symbol, 0) if isinstance(symbol, Nonterminal)
                                           else 1 for symbol in prod.rhs())
                                 for prod in self.productions()}

    def sample_tree_with_depth_constraint(self, max_depth):
        """
        Sample a derivation tree from the start symbol within a depth budget of ``max_depth``.

        At every node only productions whose non-terminal children can still be completed within
        the remaining budget are candidates (depth is measured as in ``compute_min_nt_depths``);
        the choice among them is made by ``possibly_weighted_choice``. Nodes are
        ``nltk.ImmutableTree`` objects labelled ``(lhs, production)``.

        :raises ValueError: if at some node no production fits the remaining budget.
        """
        def _fits(prod, mdepth):
            # Terminals need depth 0. A child whose min depth is None cannot derive any finite
            # string (possible when the grammar was not pruned), so such a production never fits.
            child_depths = [self.min_depths.get(symbol, 0) for symbol in prod.rhs()]
            if any(d is None for d in child_depths):
                return False
            # ``default`` keeps epsilon productions (empty rhs) from crashing ``max``.
            return max(child_depths, default=0) < mdepth

        def _sample_tree(nt, mdepth):
            # Select a production that won't go over the maximum depth constraint.
            cands = [prod for prod in self.lhs_to_prod[nt] if _fits(prod, mdepth)]
            prod = possibly_weighted_choice(cands)
            children = []
            for symbol in prod.rhs():
                if isinstance(symbol, Nonterminal):
                    children.append(_sample_tree(symbol, mdepth - 1))
                else:
                    children.append(symbol)
            return nltk.ImmutableTree((prod.lhs(), prod), children)

        return _sample_tree(self.start(), max_depth)

    def minimum_str_length(self):
        """Length of the shortest string derivable from the start symbol."""
        return self.min_lengths[self.start()]

    def sample_tree(self, max_len, non_terminal_prob: float):
        """
        Sample a derivation tree subject to the length budget ``max_len``.

        Delegates to the compiled ``sip_sample_from_grammar.sample_tree``; callers in this module
        unpack its result as ``(tree, leaves)``. ``non_terminal_prob`` is forwarded unchanged.
        """
        return sip_sample_from_grammar.sample_tree(self, max_len, non_terminal_prob)
        # Pure-Python reference of the sampling algorithm, kept for documentation. It presumably
        # mirrors the compiled version (verify against sip_sample_from_grammar); note that it
        # returns only the tree, whereas the compiled version is unpacked as (tree, leaves).
        # def _sample_tree(nt, mlen):
        #     # Select a production that won't go over the maximum depth constraint.
        #     cands = [prod for prod in self.lhs_to_prod[nt]
        #              # consider prod, in the worst-case, how much length does this add at least?
        #              # we want to stay below the limit of mlen
        #              if sum(self.min_lengths.get(symbol, 0) if isinstance(symbol, Nonterminal) else
        #                     1 for symbol in prod.rhs()) <= mlen]
        #     non_terminal_cands = [c for c in cands if any(isinstance(r, Nonterminal) for r in c.rhs())]
        #     if non_terminal_cands and random.random() < non_terminal_prob:
        #         prod = possibly_weighted_choice(non_terminal_cands)
        #     else:
        #         prod = possibly_weighted_choice(cands)
        #     children = []
        #
        #     #Create children in random order; this is needed because of length constraint
        #     permutation = list(range(len(prod.rhs())))
        #     random.shuffle(permutation)
        #     rhs = apply_perm(prod.rhs(), permutation)
        #     child_leaves = []
        #     num_tokens_still_needed = sum(self.min_lengths[symbol] if isinstance(symbol, Nonterminal) else 1
        #                                          for symbol in rhs)
        #     for i, symbol in enumerate(rhs):
        #         if isinstance(symbol, Nonterminal):
        #             # leaves_needed_for_rest = sum(self.min_lengths[symbol] if isinstance(symbol, Nonterminal) else 1
        #             #                              for symbol in rhs[i+1:])
        #             num_tokens_still_needed -= self.min_lengths[symbol]
        #             # assert num_tokens_still_needed == leaves_needed_for_rest
        #             child, sub_leaves = _sample_tree(symbol, mlen - num_tokens_still_needed)
        #             mlen -= len(sub_leaves)
        #             children.append(child)
        #             child_leaves.append(sub_leaves)
        #         else:
        #             children.append(symbol)
        #             child_leaves.append([symbol])
        #             mlen -= 1
        #             num_tokens_still_needed -= 1
        #     # Undo the permutation to get everything back in the right order
        #     children = apply_perm(children, get_inv_perm(permutation))
        #     leaves = []
        #     for child_leaves in apply_perm(child_leaves, get_inv_perm(permutation)):
        #         leaves.extend(child_leaves)
        #
        #     return nltk.ImmutableTree((prod.lhs(), prod), children), leaves
        #
        # tree, leaves = _sample_tree(self.start(), max_len)
        # return tree

    @staticmethod
    def _prune_min_depths(nts, productions, min_depths):
        """
        Remove all non-terminals that cannot expand to a finite-length string (min depth None),
        together with every production that mentions one of them.

        :return: ``(nts, productions)``: the surviving non-terminals, sorted to ensure a canonical
            order, and the surviving productions in their original order.
        """
        bad_nts = {nt for nt, depth in min_depths.items() if depth is None}  # min depth is infinite
        nts = sorted(set(nts) - bad_nts)
        kept = [rule for rule in productions
                if rule.lhs() not in bad_nts and not any(r in bad_nts for r in rule.rhs())]
        return nts, kept

    @staticmethod
    def can_derive_graph(nts: list[Nonterminal], productions: list[Production]) -> dict[Nonterminal, list[Nonterminal]]:
        """
        Build the "can derive" graph over non-terminals.

        Nodes are non-terminals; there is an edge from A to B for every occurrence of B in a rule
        ``A -> ... B ...`` (so repeated occurrences give repeated edges).

        :param nts: the non-terminals (every lhs must be among them).
        :param productions: the productions of the grammar.
        :return: adjacency lists keyed by non-terminal.
        """
        edges = {nt: [] for nt in nts}
        for prod in productions:
            for symbol in prod.rhs():
                if isinstance(symbol, Nonterminal):
                    edges[prod.lhs()].append(symbol)
        return edges

    def recursable_non_terminals(self) -> set[Nonterminal]:
        """Non-terminals of this grammar from which recursion is reachable (see ``_recursable_non_terminals``)."""
        return self._recursable_non_terminals(self.nts, self.productions())

    @classmethod
    def _recursable_non_terminals(cls, nts, productions) -> set[Nonterminal]:
        """
        Return the non-terminals from which a recursive derivation is reachable.

        Recursion may be direct (``A -> ... A ...``) or indirect, e.g.
        ====
        A -> ... B ...
        B -> ... A ...
        ===
        The method repeatedly deletes non-terminals that have no outgoing edges in the can-derive
        graph until nothing changes. What survives is every non-terminal that lies on a cycle *or
        can reach one*: with ``S -> ... A ...`` and ``A -> ... A ...``, ``S`` is returned as well.
        """
        edges = CFGWithSampling.can_derive_graph(nts, productions)
        reversed_edges = {nt: [] for nt in nts}
        for nt, successors in edges.items():
            for nt2 in successors:
                reversed_edges[nt2].append(nt)

        changed = True
        while changed:
            changed = False
            for nt in nts:
                if nt in edges and len(edges[nt]) == 0:
                    # nt is a sink: remove it from the successor list of every node pointing to it.
                    for nt2 in reversed_edges[nt]:
                        try:
                            edges[nt2].remove(nt)
                            changed = True
                        except ValueError:
                            # Edge was already removed in an earlier pass.
                            pass
        return {nt for nt in edges if len(edges[nt]) > 0}



class InterpretableCFGWithSampling(CFGWithSampling):
    """
    A sampling CFG whose derivation trees can be evaluated.

    Sub-classes define ``functions``: number of children -> function name -> function. A node
    whose production has more than one child is evaluated by calling
    ``functions[len(children)][prod.fname]`` on the children's values; a terminal contributes the
    value of its production's ``map_terminal`` ("" meaning empty output).
    """

    functions = dict()  # define this in a sub-class

    def __init__(self, nts, start, productions):
        super().__init__(nts, start, productions)
        int2fun = [fname for v in self.functions.values() for fname in v.keys()]
        # Id 0 is reserved for padding; None (no function) gets id 1 and real functions start at 2.
        self.fun2int = {fname: i + 1 for i, fname in enumerate([None] + int2fun)}

    def _fold_tree(self, t: nltk.Tree, leaf_value: Callable):
        """
        Evaluate ``t`` bottom-up, shared by ``eval_tree`` and ``eval_tree_as_length``.

        :param t: a tree whose internal labels are ``(lhs, production)`` pairs, or a tree with a
            plain string label and no children.
        :param leaf_value: maps an output string ("" = epsilon) to the value it contributes.
        :return: the combined value of the whole tree.
        """
        if isinstance(t.label(), str):
            assert len(t) == 0
            return leaf_value(t.label())
        _, prod = t.label()
        vals = []
        for child, child_symbol in zip(t, prod.rhs()):
            if isinstance(child_symbol, Nonterminal):
                vals.append(self._fold_tree(child, leaf_value))
            elif isinstance(child_symbol, str):
                assert child == prod.terminal
                vals.append(leaf_value(prod.map_terminal))

        # A single child (e.g. a terminal-only production) passes its value through unchanged.
        if len(vals) == 1:
            return vals[0]

        assert prod.fname is not None, "Production with arity > 1 must have an fname"
        fun = self.functions[len(vals)][prod.fname]
        return fun(*vals)

    def eval_tree(self, t: nltk.Tree) -> list[str]:
        """Evaluate ``t`` to the list of output tokens it denotes."""
        return self._fold_tree(t, lambda s: [] if s == "" else [s])

    def eval_tree_as_length(self, t: nltk.Tree) -> int:
        """Evaluate ``t`` to the length of its output, using the integer reading of ``functions``."""
        return self._fold_tree(t, lambda s: 0 if s == "" else 1)


class ExtendedSGrammarWithEval(InterpretableCFGWithSampling):
    """
    A randomly generatable grammar in which every production has one terminal anchoring it.

    For left grammars (``is_left``) the terminal is the first right-hand-side symbol, otherwise the
    last. The constructor rejects grammars in which two productions of the same non-terminal share
    that anchoring terminal.
    """

    # arity => function name => function
    # Here "arity" is the total number of children, the terminal included.
    # Note: each function has to simultaneously work for
    #   list[str], list[str] -> list[str]
    # as well as
    #   int, int -> int
    # where the int version makes a statement about the length of the output as a function of the
    # lengths of the inputs.
    functions = {
        2: {"concat": lambda x, y: x + y,
            "rev":    lambda x, y: y + x,
            "wrap-l": lambda x, y: x + y + x,
            "wrap-r": lambda x, y: y + x + y,
        },
        3: {"1-2-3": lambda x, y, z: x + y + z,
            "1-3-2": lambda x, y, z: x + z + y,
            "2-3-1": lambda x, y, z: y + z + x,
            "3-2-1": lambda x, y, z: z + y + x,
            "2-1-3": lambda x, y, z: y + x + z,
            "3-1-2": lambda x, y, z: z + x + y,
            }
    }

    def __init__(self, nts, is_left: bool, productions: list[ProductionWithFunction], start: Nonterminal):
        """
        :param nts: the non-terminals.
        :param is_left: whether the anchoring terminal is the first (True) or last (False) symbol.
        :param productions: the productions.
        :param start: the start symbol.
        :raises ValueError: if a production is not anchored by a terminal, or two productions of
            the same non-terminal share their anchoring terminal.
        """
        super().__init__(nts, start, productions)

        # Shallow per-instance copy of the class-level function table.
        self.functions = dict(self.functions)

        # Validate that the grammar is unambiguous.
        edge = "start" if is_left else "end"
        idx = 0 if is_left else -1
        seen: dict[tuple[Nonterminal, str], ProductionWithFunction] = dict()
        for prod in self.productions():
            anchor = prod.rhs()[idx]
            if not isinstance(anchor, str):
                raise ValueError(f"{prod} does not {edge} with a terminal")
            key = (prod.lhs(), anchor)
            if key in seen:
                raise ValueError(f"Potentially ambiguous grammar. There must be no two productions that"
                                 f" {edge} with the same terminal. Found {prod} and"
                                 f" {seen[key]}")
            seen[key] = prod

    @staticmethod
    def generate_rules(lhs, is_left, vocab, nts, p_drop, p_id):
        """
        Randomly generate productions with left-hand side ``lhs``.

        For each terminal ``v`` in ``vocab`` (skipped with probability ``p_drop``) one production
        is created: ``v`` followed (``is_left``) or preceded by 0 to ``max_children - 1``
        non-terminals drawn uniformly from ``nts``. With probability ``p_id`` the terminal maps
        to itself, otherwise to a uniformly chosen element of ``vocab + [""]`` ("" = epsilon).
        Productions with more than one child get a random function of matching arity.

        The order of ``random`` calls is significant for seeded reproducibility.
        """
        rules = []
        max_children = max(ExtendedSGrammarWithEval.functions.keys())
        for v in vocab:
            if random.random() < p_drop:
                continue
            num_nts = random.randint(0, max_children - 1)  # TODO: might want to de-prioritise arities higher than 0?
            num_children = num_nts + 1

            rhs = [random.choice(nts) for _ in range(num_nts)]

            if random.random() < p_id:
                map_term = v
            else:
                map_term = random.choice(list(vocab) + [""])  # "" corresponds to an epsilon output in an FST.

            if num_children > 1:
                # sorting to make this reproducible
                f = random.choice(sorted(ExtendedSGrammarWithEval.functions[num_children].keys()))
            else:
                f = None
            rules.append(ProductionWithFunction(lhs, [v] + rhs if is_left else rhs + [v], map_term, f))
        return rules

    @staticmethod
    def generate_grammar(nts: int, is_left: bool, vocab: list[str], p_drop: float, p_id: float, attempts: int = 30):
        """
        Randomly generate a grammar over non-terminals ``N0 .. N{nts-1}``.

        Rules are generated with ``generate_rules``; non-terminals that cannot derive a finite
        string are pruned, and generation is retried (up to ``attempts`` times) when nothing
        survives. The start symbol is drawn among non-terminals that have a rule with at least one
        non-terminal child (if any), and non-terminals unreachable from it are dropped.

        :raises ValueError: if no attempt yields a non-empty grammar.
        """
        all_nts = [Nonterminal("N" + str(i)) for i in range(nts)]

        # Generate rules
        for _ in range(attempts):
            rules = []
            for nt in all_nts:
                rules.extend(ExtendedSGrammarWithEval.generate_rules(nt, is_left, vocab, all_nts, p_drop, p_id))

            # Prune out non-terminals which cannot derive strings of finite length. The full
            # non-terminal list is kept separate so a failed attempt does not shrink the next one.
            min_depths = compute_min_nt_depths(all_nts, rules)
            kept_nts, rules = ExtendedSGrammarWithEval._prune_min_depths(all_nts, rules, min_depths)

            if len(kept_nts) > 0:
                break
        else:
            raise ValueError("Didn't manage to generate non-empty grammar.")

        # Find a start symbol that does more than generate a single string, if possible
        cand = sorted({r.lhs() for r in rules if r.arity > 0}, key=lambda x: x.symbol())
        if len(cand) == 0:
            cand = kept_nts

        start_symbol = random.choice(cand)

        # Remove rules that are never used: build the graph with A -> B if there is a rule
        # A -> ... B ..., then traverse it starting from the start symbol.
        edges = ExtendedSGrammarWithEval.can_derive_graph(kept_nts, rules)
        reachable_nts = reachable_from(start_symbol, edges)

        # Remove any productions using a non-reachable NT. Checking the LHS suffices: every
        # non-terminal on the rhs of a reachable lhs is itself reachable.
        non_reachable = set(kept_nts) - reachable_nts
        rules = [r for r in rules if r.lhs() not in non_reachable]

        # TODO: somewhere check that generated grammar allows recursion.

        return ExtendedSGrammarWithEval(sorted(reachable_nts), is_left, rules, start_symbol)

    def encode_as_rules(self) -> list[ProductionRule]:
        """
        Encode every production as a ProductionRule with integer non-terminal ids.

        The start symbol gets id 1; the other non-terminals get 2, 3, ... in the order of
        ``self.nts``. Function names are encoded with ``self.fun2int``. The result is sorted
        (stably) by lhs id.
        """
        ordered = list(self.nts)
        ordered.remove(self.start())
        ordered.insert(0, self.start())
        nt2i = {nt: i for i, nt in enumerate(ordered, start=1)}

        out = [ProductionRule(nt2i[prod.lhs()], prod.fname,
                              self.fun2int[prod.fname],
                              prod.map_terminal,
                              [nt2i[s] if isinstance(s, Nonterminal) else s for s in prod.rhs()])
               for prod in self.productions()]

        out.sort(key=lambda x: x.lhs)
        return out




def apply_perm(l, perm):
    """Return a new list whose i-th element is ``l[perm[i]]``."""
    permuted = []
    for src in perm:
        permuted.append(l[src])
    return permuted

def get_inv_perm(perm):
    """Return the inverse permutation ``inv`` with ``inv[perm[k]] == k`` for every position ``k``."""
    inverse = [None] * len(perm)
    for position, target in enumerate(perm):
        inverse[target] = position
    return inverse


def reachable_from(initial: Nonterminal, edges: dict[Nonterminal, list[Nonterminal]]) -> set[Nonterminal]:
    """
    Return every node reachable from ``initial`` (including ``initial`` itself) via ``edges``.

    :param initial: node to start the traversal from.
    :param edges: adjacency lists; nodes missing from the dict have no successors.
    :return: the set of reached nodes.
    """
    reached = {initial}
    agenda = [initial]
    while agenda:
        curr = agenda.pop()
        for succ in edges.get(curr, ()):
            # Mark on push so each node enters the agenda at most once.
            if succ not in reached:
                reached.add(succ)
                agenda.append(succ)
    return reached


def compute_min_nt_depths(non_terminals: list[Nonterminal], rules: list[ProductionWithFunction]) -> dict[Nonterminal, Union[int, None]]:
    """
    Compute, for every non-terminal, the depth of the shallowest tree derivable from it.

    Depth counts the terminal leaves: the smallest tree, e.g. ``(A a)``, has depth 2 (one level for
    the leaf ``a`` and one for the symbol ``A`` above it). Non-terminals from which no finite
    string can be derived map to None.

    The fixed point is reached by naive repeated sweeps over all rules, which is slower than
    necessary because it does not track where changes happened.

    :param non_terminals: the non-terminals to report on.
    :param rules: the productions of the grammar.
    :return: dict from non-terminal to its minimum depth, or None.
    """
    depth = dict.fromkeys(non_terminals, math.inf)

    nt_rules = []
    for rule in rules:
        if any(isinstance(sym, Nonterminal) for sym in rule.rhs()):
            nt_rules.append(rule)
        else:
            # Terminal-only rule: the tree is the lhs node over its leaves.
            depth[rule.lhs()] = 2

    changed = True
    while changed:
        changed = False
        for rule in nt_rules:
            deepest_child = max((depth[sym] for sym in rule.rhs() if isinstance(sym, Nonterminal)), default=0)
            candidate = deepest_child + 1
            if candidate < depth[rule.lhs()]:
                depth[rule.lhs()] = candidate
                changed = True

    return {nt: (None if d == math.inf else d) for nt, d in depth.items()}

def compute_min_nt_lengths(non_terminals: list[Nonterminal], rules: list[ProductionWithFunction]) -> dict[Nonterminal, Union[int, None]]:
    """
    Compute, for every non-terminal, the length of the shortest string derivable from it.

    Non-terminals from which no finite string can be derived (infinite recursion or no rules)
    keep the value ``math.inf`` rather than None, so the values can be summed directly.

    :param non_terminals: the non-terminals to report on.
    :param rules: the productions of the grammar.
    :return: dict from non-terminal to its minimum string length (``math.inf`` if none exists).
    """
    length = dict.fromkeys(non_terminals, math.inf)

    nt_rules = []
    for rule in rules:
        if any(isinstance(sym, Nonterminal) for sym in rule.rhs()):
            nt_rules.append(rule)
        else:
            terminal_count = sum(1 for sym in rule.rhs() if isinstance(sym, str))
            length[rule.lhs()] = min(length.get(rule.lhs(), math.inf), terminal_count)

    changed = True
    while changed:
        changed = False
        for rule in nt_rules:
            own_terminals = sum(1 for sym in rule.rhs() if isinstance(sym, str))
            children_total = sum(length[sym] for sym in rule.rhs() if isinstance(sym, Nonterminal))
            candidate = own_terminals + children_total
            if candidate < length[rule.lhs()]:
                length[rule.lhs()] = candidate
                changed = True

    return length

def check_exact_one_tree(trials, tree_per_grammar, *, num_nts: int = 5, is_left: bool = True,
                         vocab: Optional[list[str]] = None, p_drop: float = 0.01, p_id: float = 0.4,
                         max_len: int = 11, non_terminal_prob: float = 0.7):
    """
    Sanity check that sampled strings have exactly one parse under their grammar.

    For each of ``trials`` randomly generated grammars, samples ``tree_per_grammar`` trees, parses
    their yields with an Earley parser and asserts there is exactly one parse. When there is more
    than one, the first two parses are drawn with nltk's ``Tree.draw`` before the assertion fires.

    The keyword arguments configure grammar generation and sampling; their defaults reproduce the
    previously hard-coded settings (``p_id`` and ``non_terminal_prob`` match the values used in
    this module's ``__main__`` block).
    """
    if vocab is None:
        vocab = list("qwertyuiopasdfghjkl")
    for _ in tqdm.tqdm(range(trials)):
        g = ExtendedSGrammarWithEval.generate_grammar(num_nts, is_left, vocab, p_drop=p_drop, p_id=p_id)
        earley = EarleyChartParser(g)
        for _ in range(tree_per_grammar):
            # sample_tree returns (tree, leaves); the tree's yield is what gets parsed.
            tree, _ = g.sample_tree(max_len, non_terminal_prob)
            tokens = tree.leaves()
            chart = earley.chart_parse(tokens)
            parses = list(chart.parses(g.start()))
            if len(parses) > 1:
                parses[0].draw()
                print("----")
                parses[1].draw()
            assert len(parses) == 1, f"Expected exactly one parse of {tokens}, found {len(parses)}"

def write_to_file(fname: str, data):
    """
    Write ``(input, output)`` pairs to ``fname``, one tab-separated pair per line.

    The file is written as UTF-8 so the result does not depend on the platform's default
    encoding. Fields are written verbatim: a tab or newline inside a field corrupts the format.

    :param fname: path of the file to create or overwrite.
    :param data: iterable of ``(input, output)`` string pairs.
    """
    with open(fname, "w", encoding="utf-8") as f:
        f.writelines(f"{i}\t{o}\n" for i, o in data)


def eval_trees(grammar: ExtendedSGrammarWithEval, trees: list[nltk.Tree], joiner: str) -> list[tuple[str, str]]:
    """
    Turn each tree into an ``(input, output)`` string pair.

    The input is the tree's yield and the output is ``grammar.eval_tree(tree)``; both token lists
    are joined with ``joiner``.
    """
    return [(joiner.join(tree.leaves()), joiner.join(grammar.eval_tree(tree))) for tree in trees]


def eval_and_write(grammar: ExtendedSGrammarWithEval, trees: list[nltk.Tree], fname: str, joiner: str):
    """Evaluate ``trees`` with ``grammar`` (see ``eval_trees``) and write the pairs to ``fname``."""
    write_to_file(fname, eval_trees(grammar, trees, joiner))


def tree_to_span_matrix(t: nltk.Tree, g: InterpretableCFGWithSampling, padding=(0, 0)) -> np.ndarray:
    """
    Encode which function is applied over which span of the input as a square matrix.

    Entry ``[begin, end]`` (inclusive token positions, shifted right by ``padding[0]``) holds
    ``g.fun2int[prod.fname]`` for the production at the root of the subtree covering tokens
    ``begin..end``. All other entries stay 0, the id the grammar reserves for padding. A node is
    written after its children, so if two nested subtrees covered the same span the outer one
    would win.

    :param t: derivation tree whose internal labels are ``(lhs, production)`` pairs and whose
        leaves are strings
    :param g: the grammar from which the tree was sampled (provides ``fun2int``)
    :param padding: number of special tokenizer tokens ``(before, after)`` the sequence
    :return: int64 matrix of shape ``(n, n)`` with ``n = len(t.leaves()) + sum(padding)``
    """
    n = len(t.leaves()) + sum(padding)
    matrix = np.zeros((n, n), dtype=np.int64)
    pos = padding[0]

    def rec(subt):
        """Fill ``matrix`` for ``subt`` and return its inclusive ``(begin, end)`` token span."""
        nonlocal pos
        if isinstance(subt, str):
            pos += 1
            return pos - 1, pos - 1
        children = [rec(c) for c in subt]
        begin, _ = children[0]
        _, end = children[-1]
        _, prod = subt.label()
        matrix[begin, end] = g.fun2int[prod.fname]
        return begin, end

    rec(t)
    return matrix


# => ["1-3-2", "i", 2, 'f', 1, 1] # embed first into functions, any other string as a token, numbers as NTs.
if __name__ == "__main__":
    # Other seeds used previously: 123471, 1234712, 126667.
    random.seed(755421)
    # TODO: this is too difficult to get a minimum length,
    #  just a different probability for terminal symbol (... setting it to 0...?)
    grammar = ExtendedSGrammarWithEval.generate_grammar(3, False,
                                                        list("abcdefghi"),
                                                        p_drop=0.6, p_id=0.4)
    print(grammar)
    print("Recursable", grammar.recursable_non_terminals())
    print(compute_min_nt_lengths(grammar.nts, grammar.productions()))
    # Further manual checks: grammar.encode_as_rules(), check_exact_one_tree(100, 10).

    # Sample trees and check that the sampler's leaves agree with the tree's yield.
    random.seed(234656388)
    for _ in range(10):
        tree, leaves = grammar.sample_tree(8, 0.7)
        print(tree)
        print(leaves)
        print(tree.leaves())
        assert leaves == tree.leaves()
        print("---")

    # Test minimal depth: A derives "a" directly. C's only rule mentions C again and B needs C,
    # so neither B nor C derives a finite string.
    A, B, C = Nonterminal("A"), Nonterminal("B"), Nonterminal("C")
    rules = [
        ProductionWithFunction(A, ["a"], ""),
        ProductionWithFunction(B, [A, C, "x"], "", "concat"),
        ProductionWithFunction(C, [C, B, "y"], "", "concat"),
    ]
    assert compute_min_nt_depths([A, B, C], rules) == {A: 2, B: None, C: None}












