import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel


class Mean_Pooling_Model(nn.Module):
    """Pretrained encoder followed by masked mean pooling and a linear head.

    The encoder's token-level hidden states are averaged over the sequence
    dimension, weighted by the attention mask (positions whose mask value is
    0 contribute nothing), and the pooled vector is projected to
    ``num_labels`` outputs.
    """

    def __init__(self, path: str, dropout: float, num_labels: int) -> None:
        """Load the pretrained encoder and build the projection head.

        Args:
            path: Model identifier or directory handed to
                ``AutoConfig.from_pretrained`` and
                ``AutoModel.from_pretrained``.
            dropout: Currently unused: no dropout layer is created and the
                encoder config is left untouched.
                NOTE(review): confirm whether this was meant to be applied
                to the head or to the encoder config.
            num_labels: Output size of the linear head.
        """
        super().__init__()

        config = AutoConfig.from_pretrained(path)
        self.model = AutoModel.from_pretrained(path, config=config)
        self.linear = nn.Linear(config.hidden_size, num_labels)

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Encode the batch, mean-pool over unmasked tokens, and project.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Mask with one entry per token; it is forwarded
                to the encoder and also used as the pooling weight.
                Presumably shaped (batch, seq_len) with 0/1 values — it is
                unsqueezed on the last axis and summed over axis 1.

        Returns:
            The output of the linear head applied to the pooled embeddings.
        """
        # Pass the mask by keyword: AutoModel can resolve to architectures
        # whose second positional parameter is not the attention mask.
        outputs = self.model(input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs[0]

        # A trailing singleton axis lets the mask broadcast across the
        # hidden dimension without materialising an expanded copy. Pooling
        # is done in float32 regardless of the encoder's output dtype.
        mask = attention_mask.unsqueeze(-1).to(torch.float32)
        sum_embeddings = (last_hidden_state * mask).sum(dim=1)

        # Clamp so a fully masked row divides by a tiny positive number
        # instead of zero.
        token_counts = mask.sum(dim=1).clamp(min=1e-9)
        mean_embeddings = sum_embeddings / token_counts

        # Match the head's parameter dtype so a reduced-precision module
        # does not fail on a float32 input; a no-op for float32 weights.
        mean_embeddings = mean_embeddings.to(self.linear.weight.dtype)
        return self.linear(mean_embeddings)
