import torch.nn as nn
from .utils import Activations

class Adapter(nn.Module):
    """Bottleneck adapter layer.

    Projects the last dimension of its input from ``input_dim`` down to
    ``input_dim // config.reduction_factor``. Unless ``remove_upsampling`` is
    set, it then applies the nonlinearity named by ``config.nonlinearity``
    (lower-cased) and projects back up to ``input_dim``.

    Args:
        config: Object exposing ``reduction_factor`` (int divisor for the
            bottleneck width) and ``nonlinearity`` (activation name, only
            read when up-sampling is kept).
        input_dim: Size of the feature dimension fed to the adapter.
        remove_upsampling: If True, only the down-projection is built and
            ``forward`` returns its raw (pre-activation) output.
    """

    def __init__(self, config, input_dim, remove_upsampling=False):
        super().__init__()
        self.remove_upsampling = remove_upsampling
        self.config = config
        self.input_dim = input_dim

        bottleneck = input_dim // config.reduction_factor
        self.down_sample_size = bottleneck
        # Built first so parameter-init order and state_dict key order are
        # down_sampler -> (activation) -> up_sampler.
        self.down_sampler = nn.Linear(input_dim, bottleneck)

        if remove_upsampling:
            # Down-projection only: no activation, no up-projection.
            return

        self.activation = Activations(config.nonlinearity.lower())
        self.up_sampler = nn.Linear(bottleneck, input_dim)

    def forward(self, x):
        """Apply the adapter to ``x`` (projection acts on the last dim).

        Returns a tensor whose last dimension is ``down_sample_size`` when
        up-sampling is removed, otherwise ``input_dim``.
        """
        hidden = self.down_sampler(x)
        if self.remove_upsampling:
            return hidden
        return self.up_sampler(self.activation(hidden))
