import torch
from torch.nn import Module, Parameter
import torch.nn.functional as F
from torch import Tensor
from .box_wrapper import (
    SigmoidBoxTensor,
    BoxTensor,
    TBoxTensor,
    DeltaBoxTensor,
    BoxTensorLearntTemp,
)
from .pooling_utils import PerEntityAlpha
from typing import List, Tuple, Dict, Optional, Any, Union, TypeVar, Type, Callable


# Registry mapping a box class's name to the class itself. The pooling modules
# below look up ``box.__class__.__name__`` here and call ``from_zZ`` on the
# result to rebuild a pooled box of the same type.
box_types = dict(
    SigmoidBoxTensor=SigmoidBoxTensor,
    DeltaBoxTensor=DeltaBoxTensor,
    BoxTensor=BoxTensor,
    BoxTensorLearntTemp=BoxTensorLearntTemp,
)


class LearntPooling(Module):
    """Pool a set of boxes into a single box using learnt softmax weights.

    ``alpha_layer(ids, dim)`` supplies a temperature ``alpha``. The lower corner
    ``z`` is averaged with weights ``softmax(box.z * alpha)`` and the upper
    corner ``Z`` with weights ``softmax(box.Z * -alpha)``, both along ``dim``.
    For positive ``alpha`` this favours larger ``z`` and smaller ``Z``.
    """

    def __init__(self, alpha_layer: Module):
        """Store the module that produces the pooling temperature ``alpha``."""
        super().__init__()
        self.alpha_layer = alpha_layer

    def forward(self, ids: Tensor, box: TBoxTensor, dim: int = 1):
        """Pool ``box`` along ``dim``.

        Args:
            ids: Passed to ``self.alpha_layer`` together with ``dim`` to obtain
                ``alpha``. NOTE(review): presumably entity ids; confirm.
            box: Boxes to pool. Its class name must be a key of ``box_types``.
            dim: Axis reduced by the pooling.

        Returns:
            A box of the same registered type as ``box`` with ``dim`` reduced.
            For ``BoxTensorLearntTemp``, ``int_temp`` and ``vol_temp`` are
            plain means along ``dim``.

        Raises:
            RuntimeError: If the pooled corners contain NaN.
            KeyError: If ``type(box).__name__`` is not in ``box_types``.
        """
        box_type = box.__class__.__name__
        alpha = self.alpha_layer(ids, dim)

        # Softmax weights along ``dim``. They already sum to 1, so the weighted
        # sums below are weighted means and need no extra normalisation.
        z_weight = F.softmax(box.z * alpha, dim=dim)
        Z_weight = F.softmax(box.Z * -alpha, dim=dim)

        z = torch.sum(box.z * z_weight, dim=dim)
        Z = torch.sum(box.Z * Z_weight, dim=dim)

        if torch.isnan(z).any() or torch.isnan(Z).any():
            raise RuntimeError(
                "LearntPooling produced NaN box corners "
                f"(z has NaN: {bool(torch.isnan(z).any())}, "
                f"Z has NaN: {bool(torch.isnan(Z).any())})"
            )

        if box_type == "BoxTensorLearntTemp":
            int_temp = torch.mean(box.int_temp, dim=dim)
            vol_temp = torch.mean(box.vol_temp, dim=dim)
            return box_types[box_type].from_zZ(z, Z, int_temp, vol_temp)

        return box_types[box_type].from_zZ(z, Z)


class MaskedLearntPooling(LearntPooling):
    """``LearntPooling`` variant that excludes masked-out entries.

    Positions where ``mask`` is 0 do not contribute to the pooled ``z`` / ``Z``
    corners. Non-zero mask values additionally scale each entry's weight.
    """

    def forward(self, ids: Tensor, box: TBoxTensor, mask: torch.Tensor, dim: int = 1):
        """Pool ``box`` along ``dim``, ignoring entries whose ``mask`` is 0.

        Args:
            ids: Passed to ``self.alpha_layer`` together with ``dim`` to obtain
                the temperature ``alpha``.
            box: Boxes to pool. Its class name must be a key of ``box_types``.
            mask: Per-entry weights. They are unsqueezed on the last axis and
                broadcast against ``box.z`` / ``box.Z``. NOTE(review): assumes
                ``mask`` has the shape of ``box.z`` minus its last axis, and
                that ``dim`` is the pooled (sequence) axis; confirm with callers.
            dim: Axis reduced by the pooling.

        Returns:
            A box of the same registered type as ``box`` with ``dim`` reduced.
            Rows whose mask is entirely 0 pool to corners of 0.

        Raises:
            RuntimeError: If the pooled corners contain NaN or inf.
            KeyError: If ``type(box).__name__`` is not in ``box_types``.
        """
        box_type = box.__class__.__name__
        alpha = self.alpha_layer(ids, dim)
        mask_b = mask.unsqueeze(-1)
        excluded = mask_b == 0

        # Fill masked logits with the most negative finite value *before* the
        # softmax. Otherwise a masked entry with a large logit dominates the
        # normaliser, the unmasked weights underflow to ~0, and the pooled corner
        # silently collapses toward 0. A finite fill (not -inf) keeps fully masked
        # rows NaN-free: they become uniform here and are zeroed by ``* mask_b``.
        # NOTE: alpha could be different for Z & z.
        z_logits = box.z * alpha
        z_logits = z_logits.masked_fill(excluded, torch.finfo(z_logits.dtype).min)
        Z_logits = box.Z * -alpha
        Z_logits = Z_logits.masked_fill(excluded, torch.finfo(Z_logits.dtype).min)

        z_weight = F.softmax(z_logits, dim=dim) * mask_b
        Z_weight = F.softmax(Z_logits, dim=dim) * mask_b

        # Weighted mean sum(x * w) / sum(w). The tiny epsilon keeps fully masked
        # rows (sum(w) == 0) finite; they evaluate to 0.
        z = torch.sum(box.z * z_weight, dim=dim) / (torch.sum(z_weight, dim=dim) + 1e-19)
        Z = torch.sum(box.Z * Z_weight, dim=dim) / (torch.sum(Z_weight, dim=dim) + 1e-19)

        if torch.isnan(z).any() or torch.isnan(Z).any():
            raise RuntimeError("MaskedLearntPooling produced NaN box corners")
        if torch.isinf(z).any() or torch.isinf(Z).any():
            raise RuntimeError("MaskedLearntPooling produced infinite box corners")

        if box_type == "BoxTensorLearntTemp":
            # NOTE(review): these means include masked-out entries; confirm
            # whether a masked mean is intended here.
            int_temp = torch.mean(box.int_temp, dim=dim)
            vol_temp = torch.mean(box.vol_temp, dim=dim)
            return box_types[box_type].from_zZ(z, Z, int_temp, vol_temp)

        return box_types[box_type].from_zZ(z, Z)
