import torch.nn as nn

class ConvNet(nn.Module):
    """1-D convolutional network with global max pooling and a linear head.

    Three stacked ``Conv1d -> ReLU -> Dropout`` stages (32, 64 and 96 output
    channels; kernel sizes 24, 16 and 8; no padding, stride 1) are followed
    by a global max pool over the length axis and a single linear layer.

    Args:
        input_channels: Number of features per position, i.e. the size of the
            last axis of the input passed to ``forward``.
        out_dim: Size of the output vector produced per sample.
        dropout: Dropout probability applied after each convolution stage.
            Defaults to 0.1.
    """

    def __init__(self, input_channels: int, out_dim: int, dropout: float = 0.1):
        super().__init__()

        # Submodule attribute names are kept as-is so state_dict keys
        # (conv1/conv2/conv3/fc1) and any external references remain valid.
        self.conv1 = nn.Conv1d(input_channels, 32, kernel_size=24)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = nn.Conv1d(32, 64, kernel_size=16)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.conv3 = nn.Conv1d(64, 96, kernel_size=8)
        self.relu3 = nn.ReLU()
        self.dropout3 = nn.Dropout(dropout)

        # Reduce the remaining length axis to a single value per channel.
        self.global_max_pooling = nn.AdaptiveMaxPool1d(1)

        self.fc1 = nn.Linear(96, out_dim)

    def forward(self, x):
        """Map a batch of sequences to output vectors.

        Args:
            x: 3-D tensor laid out as ``(batch, length, input_channels)``.
                It is permuted to the ``(batch, channels, length)`` layout
                that ``nn.Conv1d`` expects. Because the convolutions are
                unpadded, ``length`` must be at least 46 (24 + 16 + 8 - 2).

        Returns:
            Tensor of shape ``(batch, out_dim)``.
        """
        # (batch, length, channels) -> (batch, channels, length)
        x = x.permute(0, 2, 1)

        x = self.conv1(x)
        x = self.relu1(x)
        x = self.dropout1(x)

        x = self.conv2(x)
        x = self.relu2(x)
        x = self.dropout2(x)

        x = self.conv3(x)
        x = self.relu3(x)
        x = self.dropout3(x)

        x = self.global_max_pooling(x)  # (batch, 96, 1)
        # Squeeze only the pooled length axis so a batch of size 1 keeps its
        # batch dimension.
        x = x.squeeze(-1)  # (batch, 96)
        return self.fc1(x)
    