import torch
from transformers import AutoModel, AutoTokenizer


def text_to_style(*, model, tokenizer, texts, device, max_length=512):
    """Tokenize ``texts`` and return one style embedding per text.

    Args:
        model: Encoder passed through to ``get_style_embedding``.
        tokenizer: Tokenizer called with PyTorch tensors, padding and
            truncation enabled.
        texts: Input texts, encoded together as one padded batch.
        device: Device the token ids and attention mask are moved to. The
            model itself is not moved here; the caller is presumably expected
            to have placed it on the same device.
        max_length: Truncation length handed to the tokenizer.

    Returns:
        A list with one embedding tensor per batch row, in input order.
    """
    encoded = tokenizer(
        texts,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    on_device = {name: tensor.to(device) for name, tensor in encoded.items()}
    pooled = get_style_embedding(
        model=model,
        attention_mask=on_device['attention_mask'],
        input_tokens=on_device['input_ids'],
    )
    # Iterating a tensor walks its first dimension, so this splits the
    # pooled batch into one tensor per text.
    return list(pooled)


def load_style_model(model_name='AnnaWegmann/Style-Embedding'):
    """Load the style-embedding encoder and its matching tokenizer.

    Args:
        model_name: Checkpoint identifier given to both ``from_pretrained``
            calls, so the model and the tokenizer always come from the same
            checkpoint. Defaults to the originally hard-coded checkpoint, so
            existing no-argument callers behave exactly as before.

    Returns:
        A ``(model, tokenizer)`` tuple. Note that the model comes first.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    return model, tokenizer


def mean_pooling(model_output, attention_mask):
    """Average token embeddings over positions, weighted by the attention mask.

    Args:
        model_output: Indexable model output whose element 0 holds the token
            embeddings (presumably ``(batch, seq, hidden)``; TODO confirm).
        attention_mask: Per-token weights broadcastable to the token
            embeddings after adding a trailing axis. Positions with 0 are
            excluded from the average.

    Returns:
        The masked sum over dimension 1 divided by the mask total. The
        denominator is floored at 1e-9 so an all-zero mask yields zeros
        instead of NaN.
    """
    hidden = model_output[0]
    # Give the mask a trailing axis and stretch it across the embedding size
    # so it can weight every component of each token vector.
    weights = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    summed = (hidden * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts


def get_style_embedding(
    *,
    model,
    inputs_embeds=None,
    input_tokens=None,
    attention_mask=None,
):
    """Run ``model`` on a batch and mean-pool its token embeddings.

    Only one input form is fed to the model. If both ``inputs_embeds`` and
    ``input_tokens`` are given, ``inputs_embeds`` wins and ``input_tokens``
    is ignored.

    Args:
        model: Callable encoder whose output's element 0 holds the per-token
            embeddings (as consumed by ``mean_pooling``).
        inputs_embeds: Pre-computed input embeddings, passed as the
            ``inputs_embeds`` keyword. The default mask shape is taken from
            all but its last dimension.
        input_tokens: Token ids, passed as the model's first positional
            argument.
        attention_mask: Mask weighting each position. Positions with 0 are
            ignored by pooling. Defaults to all ones (every position counts)
            and is always placed on the input's device.

    Returns:
        The mask-weighted mean of the model's token embeddings, as computed
        by ``mean_pooling``.

    Raises:
        ValueError: If neither ``inputs_embeds`` nor ``input_tokens`` is given.
    """
    # An explicit raise, unlike ``assert``, is not stripped under ``python -O``.
    if inputs_embeds is None and input_tokens is None:
        raise ValueError('Either inputs_embeds or input_tokens must be provided.')

    use_embeds = inputs_embeds is not None
    source = inputs_embeds if use_embeds else input_tokens
    # Embeddings carry a trailing hidden dimension that the mask must not have.
    mask_shape = source.shape[:-1] if use_embeds else source.shape

    if attention_mask is None:
        # Allocate directly on the target device instead of building on CPU
        # and copying. Long dtype matches tokenizer-produced masks;
        # mean_pooling casts to float, so the pooled result is unaffected.
        attention_mask = torch.ones(mask_shape, dtype=torch.long, device=source.device)
    else:
        attention_mask = attention_mask.to(source.device)

    if use_embeds:
        output = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
    else:
        output = model(input_tokens, attention_mask=attention_mask)
    return mean_pooling(output, attention_mask=attention_mask)
