# -*- coding: utf-8 -*-

import torch
import torch.nn.functional as F

from ..jit import ScriptModule, script_method


class FeatureDropout(ScriptModule):
    """
    Feature-level dropout for inputs of size batch x len x num_features.

    Each feature (last dimension) is dropped with probability p, independently
    for every batch element (first dimension). A dropped feature is zeroed
    across the full len dimension of that batch element; surviving features
    are rescaled by 1 / (1 - p), as done by ``F.dropout``.

    Args:
        p: probability of dropping a feature; must lie in [0, 1].
        inplace: passed through to ``F.dropout``. It only applies to the
            internally created mask; the input tensor is never modified.

    Raises:
        ValueError: if p is outside [0, 1] (or is NaN).
    """

    __constants__ = ['p', 'inplace']

    def __init__(self, p=0.5, inplace=False):
        super().__init__()

        # The chained comparison also rejects NaN, which `p < 0 or p > 1`
        # would silently accept.
        if not 0 <= p <= 1:
            raise ValueError(f'dropout probability has to be between 0 and 1, but got {p}')
        # Store a float so the TorchScript constant matches F.dropout's
        # `p: float` parameter even when an int such as 0 or 1 is passed.
        self.p = float(p)
        self.inplace = inplace

    def extra_repr(self):
        # Report both settings, like torch.nn.Dropout does.
        return f'p={self.p}, inplace={self.inplace}'

    @script_method
    def forward(self, input):
        # One mask entry per (batch element, feature). `new_ones` inherits the
        # input's dtype and device, so half/bfloat16 inputs are not promoted
        # to float32 by the multiplication below.
        noise = input.new_ones([input.shape[0], input.shape[2]])
        noise = F.dropout(noise, self.p, self.training, self.inplace)
        # (batch, 1, features) broadcasts the same mask over dimension 1.
        return input * noise.unsqueeze(1)

    @classmethod
    def reload(cls, p, inplace, training=True):
        """
        Rebuild a module from the state captured by ``__reduce__``.

        ``training`` defaults to True so that pickles which only recorded
        ``p`` and ``inplace`` load exactly as before.
        """
        module = cls(p, inplace)
        module.train(training)
        return module

    def __reduce__(self):
        # Pickle via the constructor arguments plus the train/eval flag, so an
        # unpickled module stays in the same mode as the original.
        return self.__class__.reload, (self.p, self.inplace, self.training)
