from typing import Tuple

import torch
from torch import nn
from transformers.modeling_bert import gelu_new as gelu_bert

from page.const import NEG_INF


class Gelu(nn.Module):
    """Module wrapper around the ``gelu_new`` function from ``transformers.modeling_bert``.

    Wrapping the function in an ``nn.Module`` lets the activation be used
    wherever a layer object is expected (for example inside ``nn.Sequential``).
    """

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return the result of applying ``gelu_bert`` to ``tensor``."""
        activated = gelu_bert(tensor)
        return activated


class AveragePooling(nn.Module):
    """Layer that averages its input along a fixed dimension."""

    def __init__(self, dim: int = -1, keepdim: bool = False):
        """Store the reduction settings.

        :param dim: Dimension to average over (defaults to the last one).
        :param keepdim: Whether the reduced dimension is retained with size 1.
        """
        super().__init__()
        self.dim = dim
        self.keepdim = keepdim

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return the mean of ``tensor`` over ``self.dim``."""
        return torch.mean(tensor, dim=self.dim, keepdim=self.keepdim)

    def extra_repr(self):
        """Describe the configured ``dim`` and ``keepdim`` for ``repr()`` output."""
        return f'dim={self.dim}, keepdim={self.keepdim}'


class Squeeze(nn.Module):
    """Layer that squeezes a fixed dimension of its input."""

    def __init__(self, dim: int = -1):
        """Store the dimension to squeeze.

        :param dim: Dimension passed to ``squeeze`` (defaults to the last one).
        """
        super().__init__()
        self.dim = dim

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return ``tensor`` after squeezing ``self.dim``."""
        return torch.squeeze(tensor, dim=self.dim)

    def extra_repr(self):
        """Describe the configured ``dim`` for ``repr()`` output."""
        return f'dim={self.dim}'


class LogSoftmax(nn.LogSoftmax):
    """``nn.LogSoftmax`` variant that pre-processes rows containing infinities.

    Before delegating to ``nn.LogSoftmax`` the input is shifted by its maximum
    along ``self.dim`` (when that maximum is finite), and any row made up
    entirely of infinite values is replaced by zeros so that it does not
    produce NaN.
    """

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Compute log-softmax of ``tensor`` along ``self.dim``.

        :param tensor: Input scores; entries may be infinite.
        :return: Tensor of the same shape holding log-probabilities. A row whose
            entries are all infinite is treated as a row of zeros.
        """
        # Shift each row by its maximum. A non-finite maximum is replaced by
        # zero so that the subtraction never evaluates ``inf - inf``.
        max_t = tensor.max(dim=self.dim, keepdim=True).values
        tensor = tensor - max_t.masked_fill(~torch.isfinite(max_t), 0.0)

        # Rows consisting solely of infinities are set to zeros to avoid NaN.
        # The fill is applied unconditionally: with an all-False mask it leaves
        # the values untouched, and it avoids a data-dependent Python branch
        # (``.any().item()``) that would force a host/device synchronisation.
        # NOTE(review): rows mixing +inf with finite values are not altered
        # here; confirm callers never pass such rows.
        all_inf_mask = torch.isinf(tensor).all(dim=self.dim, keepdim=True)
        tensor = tensor.masked_fill(all_inf_mask, 0.0)

        # Delegate the actual computation to nn.LogSoftmax.
        return super().forward(tensor)


class InfinityAwareLinear(nn.Linear):
    """``nn.Linear`` variant that keeps non-finite inputs out of the matmul.

    Non-finite input entries are zeroed before the linear transform. Every
    output vector whose input vector (along the last dimension) contained at
    least one non-finite entry is then overwritten with ``NEG_INF``, the
    constant imported from ``page.const``.
    """

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Apply the linear map, propagating non-finite inputs as ``NEG_INF``.

        :param tensor: Input whose last dimension is the feature dimension.
        :return: Linear output; positions fed by a non-finite input vector are
            filled with ``NEG_INF``.
        """
        # Marks every entry that is not finite (infinities and NaN alike).
        non_finite = ~torch.isfinite(tensor)
        # One flag per input vector: does it contain any non-finite entry?
        tainted = non_finite.any(dim=-1, keepdim=True)

        projected = super().forward(tensor.masked_fill(non_finite, 0.0))
        return projected.masked_fill(tainted, NEG_INF)


class ApplyNormalize(nn.Module):
    """Run a wrapped layer and rescale its output by its norm along one dimension."""

    def __init__(self, layer: nn.Module, eps: float = 1e-12, dim: int = -1):
        """Store the wrapped layer and normalisation settings.

        :param layer: Module whose output will be normalised.
        :param eps: Lower bound applied to the norm before dividing.
        :param dim: Dimension along which the norm is computed.
        """
        super().__init__()
        self.layer = layer
        self.eps = eps
        self.dim = dim

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return ``layer(tensor)`` divided by its norm, clamped below by ``eps``."""
        projected = self.layer(tensor)
        magnitude = projected.norm(dim=self.dim, keepdim=True)
        # Clamping keeps the division well-defined for all-zero vectors.
        safe_magnitude = magnitude.clamp_min(self.eps)
        return projected / safe_magnitude


class HeadHiddenCombineLayer(nn.Module):
    """Combine a hidden-state tensor with per-head weights via two linear maps.

    Both inputs are projected to ``output_dim`` with ``InfinityAwareLinear``
    layers and the projections are summed with broadcasting.
    """

    def __init__(self, num_heads: int, hidden_dim: int, output_dim: int, bias: bool = True, keepdim: bool = False):
        """Build the two projection layers.

        :param num_heads: Size of the last dimension of the head-weight input.
        :param hidden_dim: Size of the last dimension of the hidden-state input.
        :param output_dim: Output size shared by both projections.
        :param bias: Whether the projections carry a bias term.
        :param keepdim: If False, a trailing output dimension of size 1 is squeezed.
        """
        super().__init__()
        self.num_heads, self.hidden_dim, self.keepdim = num_heads, hidden_dim, keepdim

        # Layers are created in this order (head first, then hidden) so that
        # parameter initialisation draws random numbers in the same sequence.
        self.head_linear = InfinityAwareLinear(num_heads, output_dim, bias=bias)
        self.hidden_linear = InfinityAwareLinear(hidden_dim, output_dim, bias=bias)

    def forward(self, pair: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """Project and add the two members of ``pair``.

        :param pair: ``(hidden, head_weights)``. Per the original shape notes,
            hidden is [B, T, H] and head_weights is [B, T, X, HEAD].
        :return: Sum of both projections, [B, T, X, O]; the last dimension is
            squeezed away when it has size 1 and ``keepdim`` is False.
        """
        hidden, head_weights = pair

        # [B, T, H] -> [B, T, O] -> [B, T, 1, O] so it broadcasts over X.
        hidden_term = self.hidden_linear(hidden).unsqueeze(2)
        # [B, T, X, HEAD] -> [B, T, X, O]
        head_term = self.head_linear(head_weights)

        combined = hidden_term + head_term
        if self.keepdim or combined.shape[-1] != 1:
            return combined

        return combined.squeeze(-1)


# Names exported by ``from <this module> import *``.
__all__ = [
    'Gelu',
    'AveragePooling',
    'Squeeze',
    'LogSoftmax',
    'InfinityAwareLinear',
    'HeadHiddenCombineLayer',
    'ApplyNormalize',
]
