import torch
import torch.nn as nn

from transformers import AutoModel, AutoConfig, AutoTokenizer

class CmdCaliper(nn.Module):
    """Transformer encoder wrapper that returns masked mean-pooled embeddings.

    The wrapped model is built through ``transformers.AutoModel`` and stored
    as ``self.transformer``. ``forward`` averages the first element of its
    output over dimension 1, counting only positions where ``attention_mask``
    is non-zero.
    """

    def __init__(
        self, device: str, load_from_pretrained: bool, path_to_model_weight: str
    ) -> None:
        """Build the wrapped transformer.

        Args:
            device: Stored as ``self.device``. This class does not itself move
                the module to that device.
            load_from_pretrained: Forwarded to ``initialize_base_model``.
            path_to_model_weight: Forwarded to ``initialize_base_model``.
        """
        super().__init__()
        # Only recorded here; nothing in this class calls `.to(device)`.
        self.device = device
        self.initialize_base_model(load_from_pretrained, path_to_model_weight)

    def initialize_base_model(
        self, load_from_pretrained: bool, path_to_model_weight: str
    ) -> None:
        """Create ``self.transformer`` from weights or from the config alone.

        Args:
            load_from_pretrained: When true, load the model with
                ``AutoModel.from_pretrained`` (passing ``use_cache=False``).
                Otherwise read only the configuration from
                ``path_to_model_weight`` and build the model with
                ``AutoModel.from_config``.
            path_to_model_weight: Identifier or path handed to the
                ``from_pretrained`` loaders.
        """
        if load_from_pretrained:
            self.transformer = AutoModel.from_pretrained(
                path_to_model_weight, use_cache=False
            )
        else:
            config = AutoConfig.from_pretrained(path_to_model_weight)
            self.transformer = AutoModel.from_config(config)

    @staticmethod
    def get_tokenizer(path_to_model_weight: str, **kwargs):
        """Return ``AutoTokenizer.from_pretrained(path_to_model_weight, **kwargs)``."""
        return AutoTokenizer.from_pretrained(path_to_model_weight, **kwargs)

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Encode the inputs and mean-pool over the unmasked positions.

        Args:
            input_ids: Passed to the wrapped transformer unchanged.
            attention_mask: Passed to the wrapped transformer and reused as
                the pooling weights. Assumed to be (batch, seq_len) alongside
                a (batch, seq_len, hidden) encoder output -- TODO confirm
                against the caller.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            The masked mean over dimension 1 of the encoder's first output,
            in that output's dtype. Rows whose mask is entirely zero come out
            as zeros.
        """
        # Only the first output element is consumed, so the per-layer hidden
        # states are not requested.
        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # For Hugging Face base models the first element is expected to be the
        # last hidden state.
        token_embeddings = outputs[0]

        # Pool in at least float32. In float16 the token sum can overflow, and
        # the 1e-9 floor below underflows to 0, which turns a fully masked row
        # into 0/0 = NaN. float64 inputs keep their precision.
        pool_dtype = torch.promote_types(token_embeddings.dtype, torch.float32)
        mask = attention_mask.unsqueeze(-1).to(pool_dtype)

        summed = (token_embeddings.to(pool_dtype) * mask).sum(dim=1)
        # Floor the token count so an all-zero mask cannot divide by zero.
        counts = mask.sum(dim=1).clamp(min=1e-9)

        return (summed / counts).to(token_embeddings.dtype)
