from typing import Dict, List, Union

import torch
from torch import nn
from torch.nn import functional as F

from page.const import PAD_ID, OUT_GEN, NEG_INF
from .attention import MultiheadAttention, MultiheadAttentionWeights


def apply_module_dict(modules: nn.ModuleDict, encoded: torch.Tensor, **kwargs) -> torch.Tensor:
    """
    Pass a tensor through every module of a module dictionary, one after another.

    Modules are visited in ascending (sorted) order of their keys. Attention layers, i.e. instances of
    MultiheadAttention or MultiheadAttentionWeights, are called with the running output as `query` plus all
    extra keyword arguments. Any other module (e.g. nn.Linear, nn.LayerNorm) is called with the running
    output as its only argument.

    :param nn.ModuleDict modules:
        Dictionary of modules to be applied, in ascending order of keys.
    :param torch.Tensor encoded:
        Float Tensor that represents encoded vectors.
        Shape [B, T, H], where B = batch size, T = length of equation, and H = hidden dimension.
    :keyword torch.Tensor key_value:
        Float Tensor that represents key and value vectors when computing attention.
        Shape [B, K, H], where K = length of keys. Forwarded untouched to attention layers.
    :keyword torch.Tensor key_ignorance_mask:
        Bool Tensor whose True values at (b, k) make attention layer ignore k-th key on b-th item in the batch.
        Shape [B, K]. Forwarded untouched to attention layers.
    :keyword attention_mask:
        Bool Tensor whose True values at (t, k) make attention layer ignore k-th key when computing t-th query.
        Shape [T, K]. Forwarded untouched to attention layers.
    :rtype: torch.Tensor
    :return:
        Output of the last module in the chain. Shape will be [B, T, ?]
    """
    attention_types = (MultiheadAttention, MultiheadAttentionWeights)
    hidden = encoded

    for name in sorted(modules.keys()):
        module = modules[name]
        # Only attention layers understand the extra keyword arguments.
        hidden = module(query=hidden, **kwargs) if isinstance(module, attention_types) else module(hidden)

    return hidden


def apply_across_dim(function, dim=1, shared_keys=None, **tensors) -> Dict[str, torch.Tensor]:
    """
    Apply a function repeatedly for each tensor slice through the given dimension.
    For example, we have tensor [B, X, S] and dim = 1, then we will concatenate the following matrices on dim=1.
    - function([:, 0, :])
    - function([:, 1, :])
    - ...
    - function([:, X-1, :]).

    Keyword arguments that are not tensors, or whose key is listed in `shared_keys`, are passed unchanged to
    every call. All other tensors are sliced along `dim` and must have the same size on that dimension.
    Entries of the function's result whose key is in `shared_keys` are dropped from the output.

    :param function:
        Function to apply. It is called with keyword arguments only and must return a dictionary of tensors.
    :param int dim: Dimension through which we'll apply function.
    :param set shared_keys:
        Set of keys representing tensors to be shared. None (default) means that no tensor is shared.
    :param torch.Tensor tensors: Keyword arguments of tensors to compute. Dimension should >= `dim`.
    :rtype: Dict[str, torch.Tensor]
    :return: Dictionary that maps each output key to the per-slice results concatenated along `dim`.
    :raises AssertionError:
        If the sliced tensors do not share one size on `dim` (this includes the case where there is nothing
        to slice), or if some output key is not produced for every slice.
    """
    # Normalise the default so that every membership test below works without a None guard.
    if shared_keys is None:
        shared_keys = ()

    # Separate shared and non-shared tensors
    shared_arguments = {}
    repeat_targets = {}
    for key, tensor in tensors.items():
        if not isinstance(tensor, torch.Tensor) or key in shared_keys:
            shared_arguments[key] = tensor
        else:
            repeat_targets[key] = tensor

    # Check whether the size of the given dimension is the same across sliced_tensors.
    size = {key: tensor.shape[dim] for key, tensor in repeat_targets.items()}
    assert len(set(size.values())) == 1, 'Tensors does not have same size on dimension %s: We found %s' % (dim, size)

    # Since the sizes are the same, we will represent the size using the first entry.
    size = next(iter(size.values()))

    # Dictionary for storing outputs
    output = {}

    for i in range(size):
        # Build kwargs for the function.
        kwargs = {key: tensor.select(dim=dim, index=i).contiguous() for key, tensor in repeat_targets.items()}
        kwargs.update(shared_arguments)

        # Apply function on the slice and restore the dimension for concatenation.
        for key, tensor in function(**kwargs).items():
            if key in shared_keys:
                continue

            output.setdefault(key, []).append(tensor.unsqueeze(dim=dim))

    # Check whether every output key was produced once per slice.
    assert all(len(t) == size for t in output.values())

    # Concatenate all outputs, and return.
    return {key: torch.cat(tensor, dim=dim).contiguous() for key, tensor in output.items()}


def shift_target(target: torch.Tensor, fill_value=PAD_ID) -> torch.Tensor:
    """
    Shift a matrix one step to the left along its last dimension to build generation targets.

    The first column is dropped and a new last column filled with `fill_value` is appended, so the
    result keeps the shape of the input.

    :param torch.Tensor target: Target tensor to build generation targets. Shape [B, T]
    :param fill_value: Value written into the appended last column. `PAD_ID` by default.
    :rtype: torch.Tensor
    :return: Tensor with shape [B, T], where (i, j)-entries are (i, j+1) entry of target tensor.
    """
    with torch.no_grad():
        batch_size = target.shape[0]
        # One extra column per row, created with the dtype and device of the input.
        trailing = target.new_full((batch_size, 1), fill_value)
        shifted = torch.cat((target[:, 1:], trailing), dim=-1)
        return shifted.contiguous()


def var_pointers(target: torch.Tensor) -> torch.Tensor:
    """
    Replace each entry with the column index at which the same value first occurs in its row.

    Entries equal to `PAD_ID` are skipped and stay `PAD_ID` in the output.

    :param torch.Tensor target: Tensor of values, processed row by row. Expected to be 2-dimensional.
    :rtype: torch.Tensor
    :return: Tensor with the same shape and dtype as `target`, holding first-occurrence column indices.
    """
    with torch.no_grad():
        pointers = torch.full_like(target, fill_value=PAD_ID)

        for row, values in enumerate(target.tolist()):
            # Maps a value to the first column where it appeared in this row.
            first_seen = {}
            for col, value in enumerate(values):
                if value != PAD_ID:
                    pointers[row, col] = first_seen.setdefault(value, col)

        return pointers


def onehot_accumulated_target(target: torch.Tensor, num_classes: int, ignore_index: int = PAD_ID) -> torch.Tensor:
    """
    Build a multi-hot indicator of the classes that occur anywhere along dimension 1 of the target.

    :param torch.Tensor target:
        Long Tensor of class indices. Entries equal to `ignore_index` do not contribute.
        Negative entries are clamped to 0 before one-hot encoding.
    :param int num_classes: Number of classes C, i.e. the size of the one-hot dimension.
    :param int ignore_index: Index to be ignored. `PAD_ID` by default.
    :rtype: torch.Tensor
    :return:
        Long Tensor whose entries are 1 for classes that occur at least once and 0 otherwise.
        Dimension 1 is kept with size 1 and a class dimension of size C is appended, e.g. [B, T] -> [B, 1, C].
    """
    with torch.no_grad():
        # Positions to ignore, with a trailing axis so the mask broadcasts over classes.
        ignored = target.eq(ignore_index).unsqueeze(-1)

        # One-hot vectors with ignored positions zeroed out: [B, T, C]
        onehots = F.one_hot(target.clamp_min(0), num_classes).masked_fill(ignored, 0)

        # A class is present when any position along dim 1 selected it: [B, 1, C]
        present = onehots.sum(dim=1, keepdim=True) > 0
        return present.long()


def exp_accumulated_target(target: torch.Tensor, num_classes: int, ignore_index: int = PAD_ID,
                           discount: float = 0.8, window: int = 5) -> torch.Tensor:
    """
    Build, for every position, a normalized class distribution from discounted one-hot targets.

    Position i accumulates the one-hot vector of every position j, weighted by
    discount ** clamp(j - i - 1, 0, window). Each row is then divided by its sum over classes.
    Rows whose division yields a non-finite value (e.g. a zero sum) are set to 0.

    :param torch.Tensor target:
        Long Tensor of class indices. Shape [B, T]. Entries equal to `ignore_index` do not contribute.
    :param int num_classes: Number of classes C.
    :param int ignore_index: Index to be ignored. `PAD_ID` by default.
    :param float discount: Base of the discount factor. 0.8 by default.
    :param int window: Upper bound applied to the exponent of the discount factor. 5 by default.
    :rtype: torch.Tensor
    :return: Float Tensor of accumulated and normalized targets. Shape [B, T, C].
    """
    with torch.no_grad():
        device = target.device

        # One-hot vectors with ignored positions zeroed out: [B, T, C]
        ignored = (target == ignore_index).unsqueeze(-1)
        onehots = F.one_hot(target.clamp_min(0), num_classes).masked_fill(ignored, 0).float()
        batch_size, seq_len = onehots.shape[0], onehots.shape[1]

        # Weight matrix whose (i, j) entry is discount ** clamp(j - i - 1, low=0, high=window): [T, T]
        steps = torch.arange(seq_len).float()
        offsets = steps.view(1, -1) - steps.view(-1, 1) - 1
        weights = discount ** offsets.clamp(0, window)

        # Share the same weights across the batch: [B, T, T]
        weights = weights.unsqueeze(0).expand(batch_size, -1, -1).to(device)

        # Compute multiplication: [B, T, T] X [B, T, C] = [B, T, C].
        accumulated = torch.bmm(weights, onehots)

        # Normalize row sum to 1, then replace non-finite results by 0.
        normalized = accumulated / accumulated.sum(dim=-1, keepdim=True)
        return normalized.masked_fill(~torch.isfinite(normalized), 0)


def merge_list_of_dict(results: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """
    Merge a list of dictionaries into one dictionary of tensors.

    A single-element list is returned as its only dictionary, untouched. Otherwise, tensors that share a key
    are padded with `PAD_ID` on the last dimension up to the longest `OUT_GEN` entry and concatenated on
    dimension 1. A key that occurs in only one dictionary keeps its tensor without padding.

    :param List[Dict[str,torch.Tensor]] results:
        List of resulted dictionaries. Every dictionary must contain the key `OUT_GEN`,
        and tensors that are merged must be 3-dimensional.
    :rtype: Dict[str, torch.Tensor]
    :return: Dictionary of merged tensors
    """
    if len(results) == 1:
        return results[-1]

    # Group tensors by key and find the longest OUT_GEN sequence. Shape of each tensor [B, Ms, T]
    grouped = {}
    longest = 0
    for result in results:
        for name, value in result.items():
            grouped.setdefault(name, []).append(value)

        longest = max(longest, result[OUT_GEN].shape[-1])

    merged = {}
    for name, items in grouped.items():
        if len(items) == 1:
            # Nothing to merge with; keep the tensor as it is.
            merged[name] = items[0].contiguous()
            continue

        padded = []
        for item in items:
            batch_sz, beam_sz, length = item.shape
            missing = longest - length

            if missing > 0:
                # Extend the last dimension with PAD_ID so all tensors share the same length.
                filler = item.new_full((batch_sz, beam_sz, missing), PAD_ID)
                item = torch.cat([item, filler], dim=-1)

            padded.append(item)

        # Concatenate at the beam dimension.
        merged[name] = torch.cat(padded, dim=1).contiguous()

    return merged


def average_pooling_without_pad(encoded: torch.Tensor, dim: int = -1, pad: torch.Tensor = None) -> torch.Tensor:
    """
    Computes average hidden state of given sequence, without padded values.

    :param torch.Tensor encoded:
        Float Tensor which contains hidden state of encoded tokens.
        Shape [*1, D, *2, H], where D is the dimension for averaging and H = hidden dimension
    :param int dim:
        Dimension of `encoded` to apply average operation. -1 by default. Negative values are allowed.
    :param torch.Tensor pad:
        Bool Tensor whose True values indicates padded positions.
        Shape [*1, D, *2], i.e. the shape of `encoded` without the hidden dimension.
        When None, a plain mean over `dim` is returned.
    :rtype: torch.Tensor
    :return:
        Float Tensor of averaged hidden states. Shape [*1, *2, H].
        Entries that have no non-padded element along `dim` are NaN (zero divided by zero).
    """
    assert pad is None or encoded.shape[:-1] == pad.shape

    if pad is None:
        return encoded.mean(dim=dim)

    # Broadcast the padding mask over the hidden dimension so it has the shape of `encoded`.
    expand_shape = (-1,) * pad.dim() + (encoded.shape[-1],)
    pad_expand = pad.unsqueeze(dim=-1).expand(*expand_shape)

    # Sum of the non-padded values along `dim`.
    total = encoded.masked_fill(pad_expand, 0.0).sum(dim=dim)

    # Count non-padded entries on the expanded mask, so that `dim` addresses the same axis as in `encoded`
    # (also for negative `dim`) and the count has exactly the shape of `total`.
    count = (~pad_expand).sum(dim=dim)

    return total / count


def get_embedding_without_pad(embedding: Union[nn.Embedding, torch.Tensor],
                              tokens: torch.Tensor, ignore_index=PAD_ID) -> torch.Tensor:
    """
    Look up embedding vectors for a token tensor, returning zero vectors at ignored positions.

    Ignored positions are looked up as index 0 so that the lookup itself stays valid,
    and their vectors are overwritten with zeros afterwards. The given `tokens` tensor is not modified.

    :param Union[nn.Embedding,torch.Tensor] embedding:
        An embedding instance, or a weight matrix that is used through `F.embedding`.
    :param torch.Tensor tokens: A Long Tensor to build embedding vectors.
    :param int ignore_index: Index to be ignored. `PAD_ID` by default.
    :rtype: torch.Tensor
    :return: Embedding vector of given token tensor.
    """
    ignored = tokens == ignore_index
    has_ignored = bool(ignored.any())

    # Swap ignored indices for 0 on a copy, leaving the caller's tensor as it is.
    lookup_ids = tokens.masked_fill(ignored, 0) if has_ignored else tokens.clone()

    if isinstance(embedding, nn.Embedding):
        vectors = embedding(lookup_ids)
    else:
        vectors = F.embedding(lookup_ids, embedding)

    if has_ignored:
        # Erase whatever vector index 0 produced at ignored positions.
        vectors.masked_fill_(ignored.unsqueeze(-1), 0.0)

    return vectors.contiguous()


def varpos_to_varindex(tensor: torch.Tensor) -> torch.Tensor:
    """
    Convert, row by row, position-like entries into consecutive indices that start from 0.

    Within a row, a value that was not seen before receives the next unused index. After an entry at
    column i is resolved, i itself is registered with the same index, so a later entry whose value is i
    resolves to that index as well (chaining). Entries equal to `PAD_ID` are skipped and stay `PAD_ID`.

    :param torch.Tensor tensor: Tensor of positions, processed row by row. Expected to be 2-dimensional.
    :rtype: torch.Tensor
    :return: Tensor with the same shape and dtype as `tensor`, holding the assigned indices.
    """
    result = torch.full_like(tensor, PAD_ID)

    for row, positions in enumerate(tensor.tolist()):
        # Maps a value (or an already resolved column) to its assigned index.
        index_of = {}
        next_index = 0

        for col, pos in enumerate(positions):
            if pos == PAD_ID:
                continue

            if pos not in index_of:
                index_of[pos] = next_index
                next_index += 1

            assigned = index_of[pos]

            # Enable chaining across variables, i.e. [29]->[3]->[1]
            index_of[col] = assigned
            result[row, col] = assigned

    return result


def logsumexp(*tensors: torch.Tensor, dim=-1) -> torch.Tensor:
    """
    Compute log-sum-exp in a numerically stable way, keeping `NEG_INF` where every operand is `NEG_INF`.

    With two or more tensors, the result is the element-wise log(exp(t1) + exp(t2) + ...) and `dim` is not
    used. With a single tensor, the reduction runs over `dim` of that tensor and the reduced dimension is
    kept with size 1.

    :param torch.Tensor tensors:
        One tensor to reduce over `dim`, or several tensors to combine element-wise.
    :param int dim: Dimension to reduce when a single tensor is given. -1 by default.
    :rtype: torch.Tensor
    :return:
        Tensor of log-sum-exp values. Entries whose maximum equals `NEG_INF` are `NEG_INF`.
    """
    if len(tensors) > 1:
        # Element-wise maximum over all operands. torch.max only accepts two tensors at a time,
        # so fold it pairwise to support any number of operands.
        max_t = tensors[0]
        for other in tensors[1:]:
            max_t = torch.max(max_t, other)

        mask = max_t == NEG_INF

        # Shift by the maximum for stability; use 0 where the maximum is NEG_INF to avoid NEG_INF - NEG_INF.
        offset = max_t.masked_fill(mask, 0)
        summed = sum((t - offset).exp() for t in tensors).contiguous()
    else:
        tensor = tensors[0]
        max_t = tensor.max(dim=dim, keepdim=True).values
        mask = max_t == NEG_INF

        # Same shift as above, applied along the reduced dimension.
        offset = max_t.masked_fill(mask, 0)
        summed = (tensor - offset).exp().sum(dim=dim, keepdim=True)

    # At masked entries the sum is replaced by 1 (log = 0) and NEG_INF is written back as a constant,
    # so the result there is exactly NEG_INF.
    return summed.masked_fill(mask, 1).log() + max_t.masked_fill(mask, NEG_INF)


# Names exported when this module is imported with `from ... import *`.
__all__ = ['apply_module_dict', 'apply_across_dim', 'shift_target', 'var_pointers',
           'exp_accumulated_target', 'onehot_accumulated_target',
           'merge_list_of_dict', 'average_pooling_without_pad', 'get_embedding_without_pad',
           'varpos_to_varindex', 'logsumexp']
