import torch
import torch.nn as nn
from torch.autograd import Function

def z_norm(inputs, batch_dim=0, eps=1e-9):
    """
    Standardize inputs to zero mean and unit variance along batch_dim.

    Input:
        inputs: Tensor to normalize.
        batch_dim: dimension reduced when computing the statistics; it is kept
            (size 1) so the statistics broadcast back against inputs.
        eps: added to the variance before the square root to avoid division by
            zero. The default 1e-9 rounds to 0 in float16, so a constant slice
            of a half-precision tensor would yield NaN; pass a larger value
            (e.g. 1e-5) for low-precision inputs.
    Return:
        result: Tensor with the same shape as inputs.
    """
    mean = inputs.mean(batch_dim, keepdim=True)
    # Population variance (unbiased=False), as in the usual z-score definition.
    var = inputs.var(batch_dim, unbiased=False, keepdim=True)
    return (inputs - mean) / torch.sqrt(var + eps)

def getMeanAndVar(inputs, batch_dim=0):
    """
    Compute the statistics of inputs along batch_dim.

    Input:
        inputs: Tensor
        batch_dim: dimension to reduce over; kept with size 1 in the outputs
    Return:
        mean, variance: Both Tensor; variance is the population (biased,
        unbiased=False) variance.
    """
    centre = inputs.mean(dim=batch_dim, keepdim=True)
    spread = inputs.var(dim=batch_dim, unbiased=False, keepdim=True)
    return centre, spread

def normalizedFn(inputs, mean, var, eps=1e-9):
    """
    Normalize inputs with given mean and variance:
    (inputs - mean) / sqrt(var + eps).

    Input:
        inputs, mean, var: Both Tensor (mean/var must broadcast against inputs)
        eps: added to var before the square root to avoid division by zero.
            The default 1e-9 rounds to 0 in float16; use a larger value for
            half-precision tensors.
    Return:
        result: Tensor
    """
    return (inputs - mean) / torch.sqrt(var + eps)

def reconstructFn(inputs, mean, var, eps=1e-9):
    """
    Reconstruct inputs with given mean and variance:
    inputs * sqrt(var + eps) + mean.

    This inverts the normalization (x - mean) / sqrt(var + eps) when the same
    mean, var and eps are used.

    Input:
        inputs, mean, var: Both Tensor (mean/var must broadcast against inputs)
        eps: added to var before the square root. The default 1e-9 rounds to 0
            in float16.
    Return:
        result: Tensor
    """
    return inputs * torch.sqrt(var + eps) + mean


def lir(token_embed, sent_embed, c, return_proj_token_hidden=True):
    """
    Compute sent_embed @ c @ c^T and remove it from every token embedding.

    If the columns of c are orthonormal (not checked here), the subtracted
    term is the projection of sent_embed onto the span of c.

    Input:
        token_embed: [B, L, H]
        sent_embed: [B, H]
        c: [H, r]; cast to sent_embed's dtype and device before use
        return_proj_token_hidden: if True return (token_embed, subtract),
            otherwise return only subtract
    Output:
        token_embed: [B, L, H] (token_embed minus subtract, broadcast over L)
        subtract: [B, H]
    """
    basis = c.to(sent_embed)
    coeffs = torch.mm(sent_embed, basis)  # [B, r]
    subtract = torch.mm(coeffs, basis.t())  # [B, H]
    if not return_proj_token_hidden:
        return subtract
    # Insert a length axis so the same vector is removed from each token.
    return token_embed - subtract[:, None, :], subtract

class LinearClassifer(nn.Module):
    """
    Bias-free linear map from input_dim features to output_dim outputs.

    The single layer lives at self.model (an nn.Linear with bias=False).
    """

    def __init__(self, input_dim, output_dim) -> None:
        super(LinearClassifer, self).__init__()
        self.model = nn.Linear(
            in_features=input_dim,
            out_features=output_dim,
            bias=False,
        )

    def forward(self, x):
        layer = self.model
        return layer(x)


class GradReverse(Function):
    """
    Gradient reversal: identity in the forward pass, and the incoming
    gradient multiplied by -lambd in the backward pass.

    Usage: GradReverse.apply(x, lambd). Extra keyword arguments to forward
    are accepted and ignored.
    """

    @staticmethod
    def forward(ctx, x, lambd=1, **kwargs):
        # Only the scale is needed in backward; nothing from x is saved.
        ctx.lambd = lambd
        # view_as yields a distinct tensor object over the same data, the
        # usual idiom for returning an "identity" result from a Function.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        flipped = grad_output * -ctx.lambd
        # lambd receives no gradient.
        return flipped, None
