import torch
from torch import Tensor
from torch.nn import Module, Parameter
import torch.nn.functional as F
from .box_wrapper import SigmoidBoxTensor, BoxTensor, TBoxTensor, DeltaBoxTensor
from .box_wrapper import BoxTensorLearntTemp
from typing import List, Tuple, Dict, Optional, Any, Union, TypeVar, Type, Callable
from allennlp.modules.seq2vec_encoders import pytorch_seq2vec_wrapper
from allennlp.modules.token_embedders import Embedding
import logging
import numpy as np

# Module-level logger named after this module (appears unused in this module).
logger: logging.Logger = logging.getLogger(__name__)

# Generic type variable bounded by torch.Tensor; the bound is written as a
# forward-reference string. Appears unused in this module.
TTensor = TypeVar("TTensor", bound="torch.Tensor")

# TBoxTensor = TypeVar("TBoxTensor", bound="BoxTensor")


def _uniform_init_using_minmax(weight, emb_dim, param1, param2, box_type):
    """Initialise ``weight`` in place with random boxes inside ``[param1, param2]``.

    Two independent uniform samples are drawn per box dimension. Their
    element-wise minimum becomes the lower corner ``z`` and their maximum the
    upper corner ``Z``, so ``z <= Z`` always holds. The corners are converted
    with ``box_type.get_wW(z, Z)``. The two returned tensors are written to
    ``weight[..., :emb_dim]`` and ``weight[..., emb_dim:2 * emb_dim]``.

    Only the first ``2 * emb_dim`` columns are written. Weights with extra
    trailing columns (e.g. ``4 * emb_dim``) are therefore supported, and those
    extra columns keep their previous values.

    Args:
        weight: Tensor whose last dimension is at least ``2 * emb_dim``;
            modified in place.
        emb_dim: Number of box dimensions.
        param1: Lower bound of the uniform sampling interval.
        param2: Upper bound of the uniform sampling interval.
        box_type: Class exposing ``get_wW(z, Z)`` that returns the two
            tensors to store.
    """
    with torch.no_grad():
        # Sample only the 2 * emb_dim columns that are initialised. For a weight
        # of exactly 2 * emb_dim columns, this draws the same random numbers in
        # the same order as sampling a full copy of the weight.
        temp = weight.new_empty((*weight.shape[:-1], 2 * emb_dim))
        torch.nn.init.uniform_(temp, param1, param2)
        first, second = temp[..., :emb_dim], temp[..., emb_dim:]
        z, Z = torch.min(first, second), torch.max(first, second)
        w, W = box_type.get_wW(z, Z)
        weight[..., :emb_dim] = w
        # Explicit upper bound: a bare ``emb_dim:`` slice breaks for weights
        # wider than 2 * emb_dim.
        weight[..., emb_dim : 2 * emb_dim] = W


def _uniform_small(weight, emb_dim, param1, param2, box_type):
    """Initialise ``weight`` in place with small boxes of side length 0.1.

    Lower corners are drawn uniformly from ``[1e-7, 0.9 - 1e-7]``, so every
    upper corner ``lower + 0.1`` stays below 1. Both corners are passed through
    ``box_type.get_wW`` before being stored in ``weight[..., :emb_dim]`` and
    ``weight[..., emb_dim:2 * emb_dim]``. Any further columns are left untouched.

    ``param1`` and ``param2`` are accepted so the signature matches the other
    initialisers, but they are not used.
    """
    side = 0.1
    with torch.no_grad():
        # A full weight-shaped sample is drawn, but only its first emb_dim
        # columns are used as lower corners.
        sample = torch.empty_like(weight).uniform_(0.0 + 1e-7, 1.0 - side - 1e-7)
        lower = sample[..., :emb_dim]
        upper = lower + side
        lower_param, upper_param = box_type.get_wW(lower, upper)
        weight[..., emb_dim : emb_dim * 2] = upper_param
        weight[..., :emb_dim] = lower_param


def _uniform_big(weight, emb_dim, param1, param2, box_type):
    """Initialise ``weight`` in place with large boxes of side length 0.9.

    Each lower corner ``z`` is the element-wise minimum of two uniform samples
    from ``[1e-7, 0.01]``, and the upper corner is ``z + 0.9``. Both corners
    are passed through ``box_type.get_wW`` and stored in
    ``weight[..., :emb_dim]`` and ``weight[..., emb_dim:2 * emb_dim]``. Any
    further columns (e.g. for a ``4 * emb_dim`` weight) are left untouched.

    ``param1`` and ``param2`` are accepted so the signature matches the other
    initialisers, but they are not used.
    """
    with torch.no_grad():
        # Sample exactly the 2 * emb_dim columns needed. Slicing
        # ``temp[..., emb_dim:]`` from a wider weight would produce a tensor
        # that does not broadcast against ``temp[..., :emb_dim]``.
        temp = weight.new_empty((*weight.shape[:-1], 2 * emb_dim))
        torch.nn.init.uniform_(temp, 0 + 1e-7, 0.01)
        z = torch.min(temp[..., :emb_dim], temp[..., emb_dim : 2 * emb_dim])
        Z = z + 0.9
        w, W = box_type.get_wW(z, Z)
        weight[..., :emb_dim] = w
        weight[..., emb_dim : 2 * emb_dim] = W


class BoxEmbedding(Embedding):
    """AllenNLP ``Embedding`` whose lookups are returned as box tensors.

    Each looked-up row of the weight matrix is converted into a box via the
    ``from_split`` method of the class selected by ``box_type``. The weight
    matrix has ``2 * box_embedding_dim`` columns, or ``4 * box_embedding_dim``
    for ``"BoxTensorLearntTemp"``.
    """

    # Maps the accepted ``box_type`` strings to box tensor classes.
    box_types = {
        "SigmoidBoxTensor": SigmoidBoxTensor,
        "DeltaBoxTensor": DeltaBoxTensor,
        "BoxTensor": BoxTensor,
        "BoxTensorLearntTemp": BoxTensorLearntTemp,
    }

    def init_weights(self):
        """(Re)initialise ``self.weight`` in place.

        When ``self.old_init`` is false (the default), the first
        ``2 * box_embedding_dim`` columns are set by ``_uniform_small``
        (boxes of side 0.1). When it is true:

        * the first ``box_embedding_dim`` columns are drawn from
          ``U(-init_interval_center, init_interval_center)``;
        * the remaining columns are drawn from
          ``U(-0.4, -0.4 + init_interval_delta)`` for ``"SigmoidBoxTensor"``,
          and from ``U(-init_interval_delta, init_interval_delta)`` otherwise.
        """
        dim = self.box_embedding_dim
        if not self.old_init:
            _uniform_small(self.weight, dim, 0.0 + 1e-7, 1.0 - 1e-7, self.box)
            return

        # Same for every box type; drawn before the second half so the
        # random-number order matches the legacy initialisation.
        torch.nn.init.uniform_(
            self.weight[..., :dim],
            -self.init_interval_center,
            self.init_interval_center,
        )
        if self.box_type == "SigmoidBoxTensor":
            low, high = -0.4, -0.4 + self.init_interval_delta
        else:
            low, high = -self.init_interval_delta, self.init_interval_delta
        torch.nn.init.uniform_(self.weight[..., dim:], low, high)

    def __init__(
        self,
        num_embeddings: int,
        box_embedding_dim: int,
        box_type="SigmoidBoxTensor",
        weight: torch.FloatTensor = None,
        padding_index: int = None,
        trainable: bool = True,
        max_norm: float = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        vocab_namespace: str = None,
        pretrained_file: str = None,
        init_interval_center=0.25,
        init_interval_delta=0.1,
        old_init: bool = False,
    ) -> None:
        """Similar to AllenNLP embeddings, but returns a box tensor built by
        splitting the output of the usual embedding lookup.

        Arguments:
            num_embeddings: Number of rows (entities) in the embedding table.
            box_embedding_dim: Number of box dimensions. The underlying
                weight has ``2 * box_embedding_dim`` columns, or
                ``4 * box_embedding_dim`` for ``"BoxTensorLearntTemp"``.
            box_type: Key of ``BoxEmbedding.box_types`` selecting the box
                class used to convert rows into boxes.
            weight: Optional initial weight matrix. When given, it is kept
                as-is; otherwise ``init_weights`` initialises the table.
            init_interval_center: Half-width of the interval for the first
                half of the weight; only used when ``old_init`` is true.
            init_interval_delta: Width parameter for the second half of the
                weight; only used when ``old_init`` is true.
            old_init: Use the legacy interval-based initialisation instead of
                ``_uniform_small``.
            Remaining arguments are forwarded unchanged to AllenNLP's
            ``Embedding``.

        Raises:
            ValueError: if ``box_type`` is not a key of ``box_types``.
        """
        # Validate before allocating the weight matrix.
        try:
            box = self.box_types[box_type]
        except KeyError as ke:
            raise ValueError("Invalid box type {}".format(box_type)) from ke

        columns_per_dim = 4 if box_type == "BoxTensorLearntTemp" else 2
        vector_emb_dim = box_embedding_dim * columns_per_dim

        super().__init__(
            num_embeddings,
            vector_emb_dim,
            weight=weight,
            padding_index=padding_index,
            trainable=trainable,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
            vocab_namespace=vocab_namespace,
            pretrained_file=pretrained_file,
        )
        self.old_init = old_init
        self.box_type = box_type
        self.init_interval_delta = init_interval_delta
        self.init_interval_center = init_interval_center
        self.box = box
        self.box_embedding_dim = box_embedding_dim
        # Only randomly initialise when no explicit weight was supplied;
        # otherwise caller-provided (e.g. pretrained) values would be
        # silently overwritten.
        if weight is None:
            self.init_weights()

    def forward(self, inputs: torch.LongTensor) -> TBoxTensor:
        """Look up ``inputs`` and convert the rows into boxes with
        ``self.box.from_split``."""
        emb = super().forward(inputs)  # last dim: vector_emb_dim
        box_emb = self.box.from_split(emb)
        return box_emb

    def get_volumes(self, temp: Union[float, torch.Tensor]) -> torch.Tensor:
        """Return ``log_soft_volume(temp=temp)`` of every box in the table."""
        return self.all_boxes.log_soft_volume(temp=temp)

    @property
    def all_boxes(self) -> TBoxTensor:
        """Boxes for every index ``0 .. num_embeddings - 1``, built on the
        weight's device."""
        all_index = torch.arange(
            0, self.num_embeddings, dtype=torch.long, device=self.weight.device
        )
        all_ = self.forward(all_index)

        return all_
