from collections import namedtuple
from typing import List, Dict, Set

import torch

from page.const import PAD_ID


def mask_forward(sz: int, diagonal: int = 1) -> torch.Tensor:
    """
    Build a square boolean mask that hides future positions.

    Entry (i, j) is True exactly when j >= i + diagonal, i.e. the result equals the upper-triangular
    part (shifted by `diagonal`) of an all-True matrix.

    :param int sz: Length of the sequence.
    :param int diagonal: Offset of the first masked diagonal (1 = strictly above the main diagonal).
    :rtype: torch.Tensor
    :return: Contiguous boolean mask tensor with shape [sz, sz].
    """
    positions = torch.arange(sz)
    # Broadcast a column of row indices against a row of column indices.
    row_idx = positions.unsqueeze(1)
    col_idx = positions.unsqueeze(0)
    is_future = col_idx >= row_idx + diagonal
    return is_future.contiguous()


def mask_self(sz: int) -> torch.Tensor:
    """
    Build a square boolean mask that marks each position against itself.

    Entry (i, j) is True exactly when i == j (a boolean identity matrix).

    :param int sz: Length of the sequence.
    :rtype: torch.Tensor
    :return: Boolean mask tensor with shape [sz, sz].
    """
    positions = torch.arange(sz)
    # Comparing a column of indices with a row of indices is True only on the main diagonal.
    on_diagonal = positions.unsqueeze(1) == positions.unsqueeze(0)
    return on_diagonal


#: Named tuple that represents one tree of the equation forest.
#: Fields, in order: ``node`` (value identifying the root; `mask_tree` stores the token position),
#: ``children`` (list of sub-trees) and ``is_top_level`` (flag checked before a tree is used as an operand).
Tree = namedtuple('Tree', 'node children is_top_level')


def mask_tree(equation: torch.Tensor, token_infomap: Dict[int, dict]) -> torch.Tensor:
    """
    Generate a mask over an equation forest that is built token by token with a stack.

    Each (b, t, j)-entry will be False (i.e. position j is usable from position t) if:
    - t == j
    - Position j is the root of a tree that is on the forest stack just before token t is consumed.
    Every other entry stays True, including all rows from the first PAD_ID of a sequence onwards.

    :param torch.Tensor equation: Equation to build masks. Shape [B, T], where B = batch size, T = length of sequence.
    :param Dict[int,dict] token_infomap:
        Dictionary mapping from operator token id to its information. Each value must provide the keys
        'arity' (maximum number of operand trees popped from the stack) and 'toplv' (stored as the
        `is_top_level` flag of the resulting tree). Token ids missing from this dictionary are treated as leaves.
    :rtype: torch.Tensor
    :return: Boolean mask tensor with shape [B, T, T], allocated on the same device as `equation`.
    """
    # NOTE(review): earlier documentation described sibling/ancestor visibility; the code below only
    # unmasks the token itself and the current stack roots. Confirm which behavior is intended.
    batch_sz, seq_len = equation.shape

    # 'True' means the position is not used. Shape [B, T, T]
    mask = torch.ones(batch_sz, seq_len, seq_len, dtype=torch.bool, device=equation.device)

    # Convert the whole batch to Python ints once instead of once per row.
    for i, tokens in enumerate(equation.tolist()):
        # Stack of trees built so far for this sequence.
        forest = []

        for t, tok in enumerate(tokens):
            if tok == PAD_ID:
                # Nothing after the first padding token is processed.
                break

            # Token t depends on itself and on the root of every tree currently on the stack.
            # Roots always sit at earlier positions, so a single indexed write covers both cases.
            visible = [tree.node for tree in forest]
            visible.append(t)
            mask[i, t, visible] = False

            if tok in token_infomap:
                info = token_infomap[tok]
                operands = []

                # Gather as many operands as possible, up to the arity of this operator.
                for _ in range(info['arity']):
                    if not forest or forest[-1].is_top_level:
                        # A top-level tree cannot become a child, and it also blocks everything below it.
                        break
                    operands.append(forest.pop())

                # Save the new tree in the stack.
                forest.append(Tree(t, operands, is_top_level=info['toplv']))
            else:
                # Unknown tokens become leaf trees that may later be consumed as operands.
                forest.append(Tree(t, [], is_top_level=False))

    # Return [B, T, T].
    return mask


# Names exported by a star-import of this module; the `Tree` helper is not listed here.
__all__ = ['mask_forward', 'mask_self', 'mask_tree']
