"""Implements adapter controller, a module that keeps multiple
layers of adapters and controls which adapter layer to use."""
import torch.nn as nn 
from .adapter_modeling import Adapter


class AdapterController(nn.Module):
    """Wrap a single adapter layer with optional layer norms and a residual add.

    The controller builds one ``Adapter``. ``forward`` then does four things
    in order:

    1. Optionally layer-normalizes the input.
    2. Runs the adapter.
    3. Optionally layer-normalizes the adapter output.
    4. Adds the original input back, unless ``remove_upsampling`` is set.

    Args:
        config: adapter configuration. It must expose
            ``add_layer_norm_before_adapter`` / ``add_layer_norm_after_adapter``
            (or their ``*_attn`` variants when ``is_attention`` is true). It is
            also passed through to ``Adapter``.
        input_dim: dimension of the hidden representation fed into the adapter.
            It is also the normalized size of both optional ``nn.LayerNorm``s.
        is_attention: if true, the ``*_attn`` layer-norm flags are read from
            ``config`` instead of the plain ones.
        remove_upsampling: forwarded to ``Adapter``. When true, ``forward``
            returns the adapter output without the residual connection.
    """

    def __init__(self, config, input_dim, is_attention=False, remove_upsampling=False):
        super().__init__()
        self.config = config
        self.input_dim = input_dim
        self.remove_upsampling = remove_upsampling

        # Attention-side adapters use their own pair of layer-norm switches.
        if is_attention:
            norm_before = config.add_layer_norm_before_adapter_attn
            norm_after = config.add_layer_norm_after_adapter_attn
        else:
            norm_before = config.add_layer_norm_before_adapter
            norm_after = config.add_layer_norm_after_adapter
        self.add_layer_norm_before_adapter = norm_before
        self.add_layer_norm_after_adapter = norm_after

        # Submodules are registered in the order: adapter, pre norm, post norm.
        self.adapter = self.construct_adapters()
        if norm_before:
            self.pre_layer_norm = nn.LayerNorm(input_dim)
        if norm_after:
            # NOTE(review): this norm is sized with input_dim even when
            # remove_upsampling is set. Confirm that Adapter's output still
            # has width input_dim in that mode.
            self.post_layer_norm = nn.LayerNorm(input_dim)

    def construct_adapters(self):
        """Build and return the single ``Adapter`` this controller wraps."""
        return Adapter(
            self.config,
            input_dim=self.input_dim,
            remove_upsampling=self.remove_upsampling,
        )

    def forward(self, inputs):
        """Run the (optionally normalized) adapter over ``inputs``.

        Returns the adapter output, after the optional post layer norm. Unless
        ``remove_upsampling`` is set, ``inputs`` is added back as a residual.
        That addition requires the adapter output to be shape-compatible with
        ``inputs``.
        """
        hidden = inputs
        if self.add_layer_norm_before_adapter:
            hidden = self.pre_layer_norm(hidden)
        hidden = self.adapter(hidden)
        if self.add_layer_norm_after_adapter:
            hidden = self.post_layer_norm(hidden)
        if self.remove_upsampling:
            # No residual path in this mode.
            return hidden
        return hidden + inputs
