import torch
from torch import Tensor
from torch.nn import Module, Parameter
from typing import List, Tuple, Dict, Optional, Any, Union, TypeVar, Type, Callable

# The alpha layers: interchangeable modules that produce an alpha tensor
# when queried with ``forward(idxs, dim)``.

class PerEntityAlpha(Module):
    """Alpha computed per entity from a learned embedding table.

    Each index selects one row of a ``(num_embeddings, shared_dim)``
    embedding matrix. A linear layer projects that row to ``box_dim`` values,
    which are then squashed into ``(alpha_min, alpha_max)`` with a scaled
    sigmoid.

    Args:
        num_embeddings: Number of rows in the embedding table.
        shared_dim: Width of each embedding row.
        box_dim: Number of alpha values produced per selected row.
        initial_value: Accepted for signature parity with the other alpha
            layers; currently unused by this class.
        alpha_min: Lower bound of the output range (default ``-10.0``).
        alpha_max: Upper bound of the output range (default ``10.0``).
        **kwargs: Ignored (presumably so every alpha layer can be built with
            the same keyword arguments — confirm against callers).

    Raises:
        ValueError: If ``alpha_max <= alpha_min``.
    """

    def __init__(
        self,
        num_embeddings: int,
        shared_dim: int,
        box_dim: int,
        initial_value: Union[Tensor, float] = 0.0,
        *,
        alpha_min: float = -10.0,
        alpha_max: float = 10.0,
        **kwargs,
    ):
        super(PerEntityAlpha, self).__init__()
        if alpha_max <= alpha_min:
            raise ValueError(
                f"alpha_max ({alpha_max}) must be greater than alpha_min ({alpha_min})"
            )
        # Keep the original creation order so seeded initialisation is unchanged.
        self.embeddings = Parameter(torch.randn((num_embeddings, shared_dim)))
        self.linear = torch.nn.Linear(shared_dim, box_dim)
        self.alpha_min = float(alpha_min)
        self.alpha_max = float(alpha_max)

    def normalize(self, inp: Tensor, _min: float = -10.0, _max: float = 10.0) -> Tensor:
        """Map ``inp`` elementwise into the open interval ``(_min, _max)``.

        Uses ``(_max - _min) * sigmoid(inp) + _min``, so ``inp == 0`` maps to
        the midpoint of the range.
        """
        return (_max - _min) * torch.sigmoid(inp) + _min

    def forward(self, idxs, dim) -> Tensor:
        """Return the bounded alpha values for the rows selected by ``idxs``.

        Args:
            idxs: Anything usable to index ``self.embeddings`` along dim 0.
            dim: Unused; kept so all alpha layers share one call signature.
        """
        unnormalised_alpha = self.linear(self.embeddings[idxs])
        return self.normalize(unnormalised_alpha, self.alpha_min, self.alpha_max)

class PerDimAlpha(Module):
    """Placeholder for a per-dimension alpha layer (not yet implemented).

    The constructor only initialises the underlying ``Module`` so instances
    are valid modules (parameters, ``.to()``, ``state_dict`` all work). No
    ``forward`` is defined, so calling an instance raises
    ``NotImplementedError`` via ``torch.nn.Module``'s default.

    Args:
        initial_value: Accepted for signature parity with the other alpha
            layers; currently unused.
        **kwargs: Ignored.
    """

    def __init__(
        self,
        initial_value: Union[Tensor, float] = 0.0,
        **kwargs):
        # Without this call Module's internal registries (_parameters,
        # _modules, ...) are never created and almost any use of the instance
        # raises AttributeError.
        super(PerDimAlpha, self).__init__()
        # TODO: implement per-dimension alpha parameters and forward(idxs, dim).


class GlobalLearnedAlpha(Module):
    """A single learned alpha value (or tensor) shared across all indices.

    Args:
        initial_value: Starting value of the learned alpha. A tensor input is
            copied (detached), so later training does not mutate the caller's
            tensor. A Python number becomes a 0-d tensor of the default
            floating dtype, which is required for it to be trainable.
        **kwargs: Ignored.
    """

    def __init__(self, initial_value: Union[Tensor, float], **kwargs):
        super(GlobalLearnedAlpha, self).__init__()
        if isinstance(initial_value, Tensor):
            value = initial_value.detach().clone()
        else:
            # Parameter() only accepts tensors, and an integer tensor cannot
            # require grad, so force the default floating dtype.
            value = torch.tensor(initial_value, dtype=torch.get_default_dtype())
        self.value = Parameter(value)

    def forward(self, idxs, dim) -> Tensor:
        """Return the learned value with a new leading axis of size 1.

        ``idxs`` and ``dim`` are ignored; the same value serves every index.
        """
        return self.value[None]


class ConstantAlpha(Module):
    """A fixed (non-trainable) alpha value shared across all indices.

    Subclasses ``Module`` so it can be called like the other alpha layers and
    follows ``.to(device)``. The value is stored as a non-persistent buffer,
    so it is not trained and does not appear in ``state_dict``.

    Args:
        initial_value: The constant. A tensor input is copied (detached);
            a Python number becomes a 0-d tensor of the default floating dtype.
        **kwargs: Ignored.
    """

    def __init__(self, initial_value: Union[Tensor, float], **kwargs):
        super(ConstantAlpha, self).__init__()
        if isinstance(initial_value, Tensor):
            value = initial_value.detach().clone()
        else:
            # torch.Tensor(<number>) does not build a scalar holding that
            # number; torch.tensor does.
            value = torch.tensor(initial_value, dtype=torch.get_default_dtype())
        self.register_buffer("value", value, persistent=False)

    def forward(self, idxs, dim) -> Tensor:
        """Return the constant with a new leading axis of size 1.

        ``idxs`` and ``dim`` are ignored; the same value serves every index.
        """
        return self.value[None]
