import numpy as np
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from torch.nn.functional import embedding

from ..init import assign_tensor


class Embedding(nn.Module):
    r"""A simple lookup table that stores embeddings of a fixed dictionary and size.

    This module is often used to store word embeddings and retrieve them using indices.
    The input to the module is a tensor of indices, and the output is the corresponding
    word embeddings.

    Args:
        num_embeddings (int): size of the dictionary of embeddings
        embedding_dim (int): the size of each embedding vector
        init_embedding (Tensor, optional): If given, the embedding will be initialized with the given tensor.
        freeze (boolean, optional): If ``True``, the weight is not updated during learning
            (``requires_grad`` is set to ``False``). Freezing requires ``init_embedding``;
            freezing a randomly initialized table raises ``Warning`` as an exception.
        padding_idx (int, optional): If given, the row at this index is zeroed on (re)initialization
            and receives no gradient during training.
        max_norm (float, optional): If given, looked-up embeddings are renormalized so their norm is
            at most this value.
        norm_type (float, optional): The p of the p-norm to compute for the max_norm option
        scale_grad_by_freq (boolean, optional): if given, this will scale gradients by the frequency of
                                                the words in the mini-batch.
        sparse (boolean, optional): if True, gradient w.r.t. weight matrix will be a sparse tensor. See Notes for
                                    more details regarding sparse gradients.
    Attributes:
        weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
        frozen (bool): whether ``weight`` has been excluded from gradient updates
    Shape:
        - Input: LongTensor of arbitrary shape `(*)` containing the indices to look up,
          e.g. `(N1, N2, ..., Nm, W)`
        - Output: `(*, embedding_dim)`, e.g. `(N1, N2, ..., Nm, W, embedding_dim)`
    Notes:
        Keep in mind that only some optimizers support sparse gradients
        (e.g. `optim.SGD`, `optim.SparseAdam`, `optim.Adagrad`); check the
        PyTorch documentation for the version in use.
    """

    def __init__(self, num_embeddings, embedding_dim, init_embedding=None, freeze=False, padding_idx=None,
                 max_norm=None, norm_type=2, scale_grad_by_freq=False, sparse=False):
        super(Embedding, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        # Uninitialized storage; filled in by reset_parameters below.
        self.weight = Parameter(torch.empty(num_embeddings, embedding_dim))
        self.frozen = freeze
        self.sparse = sparse

        self.reset_parameters(init_embedding)

    def reset_parameters(self, init_embedding=None):
        """(Re)initialize ``weight``.

        Args:
            init_embedding (Tensor, optional): values to load into ``weight`` via
                ``assign_tensor``. If ``None``, entries are drawn uniformly from
                ``[-sqrt(3 / embedding_dim), sqrt(3 / embedding_dim)]``.

        Raises:
            Warning: if the module is frozen and ``init_embedding`` is ``None``.
        """
        if init_embedding is None:
            # U(-a, a) has variance a^2 / 3, so this gives each entry variance 1 / embedding_dim.
            scale = np.sqrt(3.0 / self.embedding_dim)
            self.weight.data.uniform_(-scale, scale)
        else:
            # NOTE(review): presumably copies init_embedding into weight; see ..init.assign_tensor.
            assign_tensor(self.weight, init_embedding)
        if self.padding_idx is not None:
            # The padding row starts at zero even when a pretrained table was supplied.
            self.weight.data[self.padding_idx].fill_(0)

        if self.frozen:
            if init_embedding is None:
                # This raises (it does not merely warn): a frozen, randomly
                # initialized table would never learn anything useful.
                raise Warning('Freeze embeddings which are randomly initialized.')
            self.weight.requires_grad = False

    def freeze(self):
        """Stop gradient updates to ``weight`` and mark the module as frozen."""
        self.weight.requires_grad = False
        self.frozen = True

    def forward(self, input):
        """Look up the embedding vector of every index in ``input``.

        Args:
            input (LongTensor): indices of arbitrary shape ``(*)``.

        Returns:
            Tensor of shape ``(*, embedding_dim)``.
        """
        # F.embedding accepts indices of any rank and returns (*, embedding_dim)
        # directly. No manual flattening is needed; flattening with ``view``
        # also failed on non-contiguous inputs.
        # ``self.padding_idx`` is passed unchanged: ``None`` means "no padding
        # row". Substituting -1 would be read as a negative index, silently
        # turning the LAST row into a padding row that never gets gradient.
        return embedding(input, self.weight,
                         padding_idx=self.padding_idx,
                         max_norm=self.max_norm,
                         norm_type=self.norm_type,
                         scale_grad_by_freq=self.scale_grad_by_freq,
                         sparse=self.sparse)

    def __repr__(self):
        """Return e.g. ``Embedding(5, 3, padding_idx=0)``; default-valued options are omitted."""
        parts = ['{}, {}'.format(self.num_embeddings, self.embedding_dim)]
        if self.padding_idx is not None:
            parts.append('padding_idx={}'.format(self.padding_idx))
        if self.max_norm is not None:
            parts.append('max_norm={}'.format(self.max_norm))
        if self.norm_type != 2:
            parts.append('norm_type={}'.format(self.norm_type))
        if self.scale_grad_by_freq is not False:
            parts.append('scale_grad_by_freq={}'.format(self.scale_grad_by_freq))
        if self.sparse is not False:
            parts.append('sparse=True')
        return '{}({})'.format(self.__class__.__name__, ', '.join(parts))
