import persist_to_disk as ptd

import models
import utils.nlg as unlg


@ptd.persistf(
    skip_kwargs=["device"], expand_dict_kwargs=["groups"], switch_kwarg="cache"
)
def _get_attentions(
    model_name: str, prompt: str, generation: str, groups: dict, device=0
):
    """Summarize per-head attention of ``generation`` conditioned on ``prompt``.

    Results are persisted to disk by ``ptd.persistf``. ``device`` is listed in
    ``skip_kwargs``, ``groups`` in ``expand_dict_kwargs``, and ``cache`` is the
    ``switch_kwarg``. Presumably this means ``device`` is left out of the cache
    key and ``cache=...`` toggles caching. TODO: confirm against the
    persist_to_disk docs.

    Args:
        model_name: Identifier forwarded to ``models.load_model_and_tokenizer``.
        prompt: Prompt text, tokenized with the tokenizer's default special
            tokens.
        generation: Continuation text, tokenized with
            ``add_special_tokens=False``.
        groups: Forwarded unchanged to ``summ_attn_by_head(groups=...)``.
        device: GPU id, resolved through ``utils.gpuid_to_device``.

    Returns:
        Whatever ``AttnGenerationBatched.summ_attn_by_head`` returns for the
        single (index 0) example, with ``layer=None`` and ``normalize=True``.
    """
    # Bind `utils` locally; the module-level `import utils.nlg as unlg` binds
    # only the name `unlg`.
    import utils

    model, tokenizer = models.load_model_and_tokenizer(
        model_name, device=utils.gpuid_to_device(device)
    )

    def _encode_row(text, **encode_kwargs):
        # `[0]` keeps the first row of the "pt" tensor. This assumes a
        # HuggingFace-style tokenizer returning shape (1, seq_len); TODO confirm.
        return tokenizer.encode(text, return_tensors="pt", **encode_kwargs)[0]

    prompt_ids = _encode_row(prompt)
    generation_ids = _encode_row(generation, add_special_tokens=False)

    # A batch of exactly one (prompt, generation) pair; index 0 below selects it.
    batch = unlg.AttnGenerationBatched(
        [prompt_ids], [generation_ids], model, tokenizer, debug=False
    )
    summary = batch.summ_attn_by_head(0, groups=groups, layer=None, normalize=True)
    return summary
