import torch
from torch import nn as nn, Tensor
from torch.autograd.function import InplaceFunction
from torch.jit import ScriptModule, script_method
from torch.nn import Dropout


class FeatureDropoutFunction(InplaceFunction):
    """Autograd function that drops whole features of a 3-D input.

    One keep/drop decision is drawn per (batch element, feature) pair and is
    shared by every position along the middle dimension, so a dropped feature
    is zeroed for the whole sequence of that batch element.
    """

    @staticmethod
    def forward(ctx, input, p=0.5, train=False, inplace=False):
        """Apply feature dropout to ``input``.

        Args:
            ctx: autograd context; receives ``p``, ``train``, ``inplace`` and,
                when dropout is active, the sampled ``noise``.
            input: tensor with exactly three dimensions, unpacked as
                (batch, sentence length, features).
            p: probability of dropping a feature, in ``[0, 1]``.
            train: dropout is only applied when this is true.
            inplace: when true, ``input`` itself is modified and returned;
                otherwise a clone is returned.

        Returns:
            A tensor with the same shape as ``input``.

        Raises:
            ValueError: if ``p`` is outside ``[0, 1]`` (or if ``input`` does
                not have exactly three dimensions, from the shape unpacking).
        """
        batch_size, _, feature_count = input.shape
        if p < 0 or p > 1:
            raise ValueError("dropout probability has to be between 0 and 1, "
                             "but got {}".format(p))

        ctx.p = p
        ctx.train = train
        ctx.inplace = inplace

        if ctx.inplace:
            ctx.mark_dirty(input)
            output = input
        else:
            output = input.clone()

        if ctx.p > 0 and ctx.train:
            noise_shape = (batch_size, 1, feature_count)
            if ctx.p == 1:
                # Everything is dropped; avoid dividing by (1 - p) == 0.
                noise = input.new_zeros(noise_shape)
            else:
                # Kept features are rescaled by 1 / (1 - p).
                noise = input.new_empty(noise_shape)
                noise.bernoulli_(1 - ctx.p).div_(1 - ctx.p)
            # The size-1 middle dimension broadcasts over the sequence axis,
            # so the mask never has to be materialised at full size.
            ctx.noise = noise
            output.mul_(noise)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(input, p, train, inplace)``.

        Only ``input`` is differentiable: its gradient is ``grad_output``
        multiplied by the noise sampled in ``forward`` when dropout was
        active, and ``grad_output`` unchanged otherwise.
        """
        if ctx.p > 0 and ctx.train:
            grad_input = grad_output.mul(ctx.noise)
        else:
            grad_input = grad_output
        # One entry per forward argument after ctx.
        return grad_input, None, None, None


class FeatureDropout(nn.Module):
    """
    Feature-level dropout module backed by FeatureDropoutFunction.

    The input is handed to FeatureDropoutFunction, which unpacks it as
    (batch, len, num_features). Each feature is dropped with probability p,
    and a dropped feature is dropped across the full portion of the input
    that corresponds to a single batch element.

    Args:
        p: dropout probability; must lie in [0, 1].
        inplace: forwarded to FeatureDropoutFunction as its in-place flag.

    Raises:
        ValueError: if p is smaller than 0 or larger than 1.
    """

    def __init__(self, p=0.5, inplace=False):
        super(FeatureDropout, self).__init__()
        out_of_range = p < 0 or p > 1
        if out_of_range:
            raise ValueError("dropout probability has to be between 0 and 1, "
                             "but got {}".format(p))
        self.p, self.inplace = p, inplace

    def forward(self, input):
        """Run feature dropout; only active while the module is in training mode."""
        dropout_args = (input, self.p, self.training, self.inplace)
        return FeatureDropoutFunction.apply(*dropout_args)


class FeatureDropout2(ScriptModule):
    """
    TorchScript feature-level dropout.

    Takes an input indexed as (batch, len, num_features) and drops each
    feature with probability p. A feature is dropped across the full portion
    of the input that corresponds to a single batch element: one mask entry
    is drawn per (batch element, feature) and shared along the len axis.

    Args:
        p: dropout probability; must lie in [0, 1].
        inplace: stored on the module for interface parity with
            FeatureDropout; forward() does not read it and always returns a
            new tensor.

    Raises:
        ValueError: if p is smaller than 0 or larger than 1.
    """

    def __init__(self, p=0.5, inplace=False):
        super().__init__()
        if p < 0 or p > 1:
            raise ValueError("dropout probability has to be between 0 and 1, "
                             "but got {}".format(p))
        self.p = p
        self.inplace = inplace
        self.dropout = Dropout(p)

    @script_method
    def forward(self, input: Tensor):
        # Sample one mask value per (batch element, feature); the size-1
        # middle dimension broadcasts across the sequence axis.
        noise = self.dropout(torch.ones([input.shape[0], 1, input.shape[2]],
                                        device=input.device))
        # Keep floating-point inputs in their own dtype: multiplying a
        # float16/bfloat16 input by a default-dtype mask would silently
        # promote the result. Non-floating inputs keep the previous
        # promotion behaviour.
        if input.is_floating_point():
            noise = noise.to(input.dtype)
        return input * noise


class LayerNormalization(ScriptModule):
    """
    Layer normalization over the last dimension, compiled with TorchScript.

    The input is centred by its mean over the last dimension and divided by
    (std + eps), where std is what torch.std returns for that dimension.
    With affine enabled the result is scaled by the learned gain a_2 and
    shifted by the learned bias b_2. Inputs whose last dimension has size 1
    are returned unchanged.

    Args:
        d_hid: length of the gain and bias vectors.
        eps: constant added to the standard deviation before dividing.
        affine: whether to create and apply a_2 and b_2.
    """

    __constants__ = ["eps", "affine"]

    def __init__(self, d_hid, eps=1e-3, affine=True):
        super().__init__()

        self.eps, self.affine = eps, affine
        if affine:
            # Gain starts at one and bias at zero, so the affine step is
            # initially the identity.
            self.a_2 = nn.Parameter(torch.ones(d_hid))
            self.b_2 = nn.Parameter(torch.zeros(d_hid))

    @script_method
    def forward(self, z):
        # A single-element last dimension is passed through untouched.
        if z.size(-1) == 1:
            return z

        centered = z - z.mean(dim=-1, keepdim=True)
        scale = z.std(dim=-1, keepdim=True) + self.eps
        normalized = centered / scale
        if self.affine:
            normalized = normalized * self.a_2 + self.b_2

        # NOTE(nikita): the t2t code instead multiplies the centred input by
        # rsqrt(variance + eps), with variance = mean((z - mu) ** 2) over the
        # last dimension and eps=1e-6. There is currently no reason to
        # believe that this difference in implementation matters.
        return normalized
