import torch

from sentence_transformers import SentenceTransformer
from utils.config import device

# Style-embedding checkpoint (presumably a Hugging Face Hub id — verify),
# loaded once at import time onto the device chosen in utils.config.
STYLE_MODEL_NAME = 'AnnaWegmann/Style-Embedding'
model = SentenceTransformer(model_name_or_path=STYLE_MODEL_NAME, device=device)

# Define get_*_embedding function
def get_style_embedding(texts):
    embeddings = model.encode(texts, device=device, convert_to_tensor=True, normalize_embeddings=True)
    average_embedding = torch.nn.functional.normalize(torch.mean(embeddings, axis=0), dim=0)
    return average_embedding.detach().cpu().numpy()