import torch
import torch.nn as nn

from fairseq.models.transformer import TransformerEncoderLayer
from fairseq.modules import LayerDropModuleList, LayerNorm

class Adapter(nn.Module):
    """Stack of Transformer encoder layers applied to existing features.

    The module builds ``layer_num`` layers via :meth:`build_layer`, runs the
    input through them in order and, when ``args.encoder_normalize_before``
    is truthy, applies a final ``LayerNorm`` over ``embed_dim``.

    Args:
        args: fairseq-style argument namespace. This class reads
            ``encoder_layerdrop`` and ``encoder_normalize_before`` from it
            and forwards the whole object to :meth:`build_layer`.
        layer_num: number of encoder layers to stack.
        padding_idx: token id treated as padding when the padding mask has
            to be derived from ``src_tokens``.
        embed_dim: size passed to the optional final ``LayerNorm``.
    """

    def __init__(self, args, layer_num, padding_idx, embed_dim):
        super().__init__()
        self.encoder_layerdrop = args.encoder_layerdrop

        # A positive layerdrop rate selects fairseq's LayerDropModuleList
        # container (see the fairseq docs for its drop semantics); otherwise
        # a plain ModuleList is used.
        if self.encoder_layerdrop > 0.0:
            self.layers = LayerDropModuleList(p=self.encoder_layerdrop)
        else:
            self.layers = nn.ModuleList([])
        self.layers.extend([
            self.build_layer(args) for _ in range(layer_num)
        ])

        self.padding_idx = padding_idx

        # ``layer_norm`` stays None when no final normalization is requested,
        # which ``forward`` checks before applying it.
        if args.encoder_normalize_before:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def build_layer(self, args):
        """Create one encoder layer; override to use a different layer type."""
        return TransformerEncoderLayer(args)

    def forward(self, src_tokens, x, encoder_padding_mask=None):
        """Run ``x`` through every layer and the optional final LayerNorm.

        Args:
            src_tokens: token-id tensor. Only used when
                ``encoder_padding_mask`` is None, to derive the mask.
            x: features handed to the first layer.
                NOTE(review): presumably laid out as fairseq encoder layers
                expect (seq_len, batch, embed_dim) -- confirm with callers.
            encoder_padding_mask: optional precomputed padding mask. When
                omitted it is computed as ``src_tokens.eq(self.padding_idx)``,
                i.e. True wherever a token equals the padding id.

        Returns:
            The output of the last layer, passed through ``self.layer_norm``
            when that is not None.
        """
        if encoder_padding_mask is None:
            encoder_padding_mask = src_tokens.eq(self.padding_idx)

        # Iterate the container itself: the layer index is not needed.
        for layer in self.layers:
            x = layer(x, encoder_padding_mask)

        if self.layer_norm is not None:
            x = self.layer_norm(x)
        return x
