from tqdm import tqdm
import math
from PIL import Image
import torch
from transformers import (
    CLIPProcessor,
    CLIPModel,
)


def batchify(data, batch_size=16):
    """Lazily split an iterable into consecutive lists of ``batch_size`` items.

    Args:
        data: Any iterable; it is consumed one item at a time.
        batch_size: Number of items per yielded list.

    Yields:
        Lists holding ``batch_size`` items each, in input order. A final,
        shorter list is yielded when items are left over; nothing is yielded
        for an empty iterable.
    """
    pending = []
    for item in data:
        pending.append(item)
        if len(pending) != batch_size:
            continue
        # The batch is full: hand it out and start a fresh list so the
        # yielded one is never mutated afterwards.
        yield pending
        pending = []
    if len(pending) > 0:
        yield pending


def _load_images(paths):
    """Open every path with PIL and return in-memory copies of the images.

    Each file is opened in a ``with`` block and copied before the block
    exits, so the underlying file handle is released immediately instead of
    staying open for as long as the image object lives.
    """
    images = []
    for path in paths:
        with Image.open(path) as image:
            images.append(image.copy())
    return images


def calculate_clipscore(
    model: CLIPModel, processor: CLIPProcessor, texts, images_path, batch_size=64
):
    """Return the mean scaled cosine similarity between paired texts and images.

    Texts and images are processed in aligned batches. For each pair, the
    text and image embeddings are L2-normalised, their dot product is taken,
    and the result is multiplied by ``model.logit_scale.exp()``. The mean over
    all pairs is returned.

    Args:
        model: CLIP model used to embed texts and images.
        processor: Processor that tokenises the texts and preprocesses the
            images for ``model``.
        texts: Sized collection of captions; ``texts[i]`` is paired with the
            i-th entry of ``images_path``.
        images_path: Iterable of image locations accepted by ``Image.open``,
            with the same number of entries as ``texts``.
        batch_size: Number of text/image pairs per forward pass.

    Returns:
        The mean score over all pairs, as a Python float.

    Raises:
        ValueError: If ``texts`` is empty, or if ``texts`` and ``images_path``
            do not contain the same number of entries.
    """
    if len(texts) == 0:
        raise ValueError("texts must contain at least one example")

    text_batches = batchify(texts, batch_size)
    path_batches = batchify(images_path, batch_size)
    device = model.device
    all_similarities = []

    # Scoring only: disabling autograd keeps the stored results from holding
    # on to a computation graph for every batch.
    with torch.no_grad():
        for text_batch in tqdm(
            text_batches, total=math.ceil(len(texts) / batch_size)
        ):
            path_batch = next(path_batches, [])
            if len(text_batch) != len(path_batch):
                raise ValueError(
                    "texts and images_path must have the same number of entries"
                )
            batch_inputs = processor(
                text=text_batch,
                images=_load_images(path_batch),
                return_tensors="pt",
                max_length=processor.tokenizer.model_max_length,
                padding="max_length",
                truncation=True,
            ).to(device)  # inputs must live on the same device as the model
            text_embeds = model.get_text_features(
                input_ids=batch_inputs["input_ids"],
                # Forward the mask (when the processor produced one) so the
                # padded positions are masked out by the text encoder.
                attention_mask=batch_inputs.get("attention_mask"),
            )
            image_embeds = model.get_image_features(
                pixel_values=batch_inputs["pixel_values"]
            )
            text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
            image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
            # Keep only the per-pair cosine similarities, not the embeddings.
            all_similarities.append((image_embeds * text_embeds).sum(-1))

        # A leftover image batch means images_path is longer than texts.
        if next(path_batches, None) is not None:
            raise ValueError(
                "texts and images_path must have the same number of entries"
            )

        clip_score = torch.concat(all_similarities) * model.logit_scale.exp()
        return clip_score.mean().item()
