# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Embedding layer whose lookup table is factorized into a low-rank decomposition.
"""

import torch.nn as nn
import torch.nn.functional as F


class FactorizedEmbedding(nn.Embedding):
    """Embedding layer with a low-rank (factorized) lookup table.

    The inherited ``weight`` parameter stores one ``rank``-sized code per
    row, i.e. it has shape ``(num_embeddings, rank)``. A bias-free linear
    layer (``self.projection``) then maps each looked-up code to
    ``embedding_dim`` features. The effective dense table is
    ``weight @ projection.weight.T``, represented with
    ``num_embeddings * rank + rank * embedding_dim`` parameters.

    Note: because ``rank`` is what gets passed to ``nn.Embedding`` as its
    embedding size, the inherited ``self.embedding_dim`` attribute equals
    ``rank`` (it describes ``weight``), not the size of ``forward``'s output.

    Args:
        num_embeddings: number of rows in the lookup table.
        rank: inner (low-rank) dimension of the factorization.
        embedding_dim: feature size of the vectors returned by ``forward``.
        padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight:
            forwarded unchanged to ``nn.Embedding``. They act on the
            rank-space table, so ``_weight`` (if given) must have shape
            ``(num_embeddings, rank)``.
    """

    def __init__(self, num_embeddings, rank, embedding_dim,
                 padding_idx=None, max_norm=None, norm_type=2.,
                 scale_grad_by_freq=False, sparse=False,
                 _weight=None):
        # The parent class owns the low-rank table: ``rank`` fills the slot
        # that nn.Embedding calls ``embedding_dim``.
        super().__init__(
            num_embeddings=num_embeddings,
            embedding_dim=rank,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
            _weight=_weight,
        )

        # Up-projection from rank space to the requested output size. It has
        # no bias, so an all-zero code (e.g. a zeroed padding row) stays zero.
        self.projection = nn.Linear(
            in_features=rank,
            out_features=embedding_dim,
            bias=False,
        )

    def forward(self, x, mask=None):
        """Look up ``x`` in the low-rank table and project the result.

        Args:
            x: index tensor accepted by ``F.embedding``.
            mask: optional multiplier applied element-wise to the rank-space
                codes before projection ("dynamic pruning"). Presumably
                broadcastable to ``(..., rank)``; confirm with callers.

        Returns:
            The projected embeddings, with trailing dimension ``embedding_dim``.
        """
        codes = F.embedding(
            x,
            self.weight,
            padding_idx=self.padding_idx,
            max_norm=self.max_norm,
            norm_type=self.norm_type,
            scale_grad_by_freq=self.scale_grad_by_freq,
            sparse=self.sparse,
        )
        # Pruning happens in rank space, before the up-projection.
        codes = codes if mask is None else codes * mask
        return self.projection(codes)
