import json
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel


# Local checkpoint directory shared by the CLIP model and its processor; both
# are loaded once at import time and used as globals by calculate_vleu.
_CLIP_CHECKPOINT = '../models/clip-vit-base-patch16'
model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)

def calculate_vleu(path_func, prompts, num, start=0, temperature=0.01):
    """Compute the VLEU score of generated images against their prompts.

    The prompts ``prompts[start:start + num]`` and the images located at
    ``path_func(start)``, ..., ``path_func(start + num - 1)`` are embedded
    with the module-level CLIP ``model`` / ``processor``.  Every image gets a
    softmax distribution over the selected prompts (cosine similarity divided
    by ``temperature``); the score is ``exp(mean_i KL(P_i || P))`` where
    ``P`` is the average of the per-image distributions ``P_i``.

    Args:
        path_func: Callable mapping an integer index to an image file path.
        prompts: Sequence of prompt strings; only the slice
            ``[start:start + num]`` is used.
        num: Number of images to evaluate (and the maximum number of prompts).
        start: Index of the first prompt / image.
        temperature: Softmax temperature applied to the cosine similarities.

    Returns:
        The VLEU score, as returned by ``np.exp`` (a NumPy float64).

    Raises:
        ValueError: If ``num`` or ``temperature`` is not positive, or the
            prompt slice is empty.
    """
    # Validate cheaply before spending time on CLIP encoding.
    if num <= 0:
        raise ValueError(f'num must be positive, got {num}')
    if temperature <= 0:
        raise ValueError(f'temperature must be positive, got {temperature}')
    selected_prompts = prompts[start:start + num]
    if len(selected_prompts) == 0:
        raise ValueError(f'no prompts in range [{start}, {start + num})')

    with torch.no_grad():
        text_embs = _encode_texts(selected_prompts)
        image_paths = [path_func(i) for i in range(start, start + num)]
        img_embs = _encode_images(image_paths)
        return _vleu_from_embeddings(img_embs, text_embs, temperature)


def _encode_texts(texts):
    """Return L2-normalised CLIP text features stacked into one row per text."""
    rows = []
    for text in tqdm(texts):
        # One prompt per processor call, so no padding is introduced and
        # each embedding matches the original per-prompt encoding.
        inputs = processor([text], return_tensors='pt', truncation=True)
        feats = model.get_text_features(**inputs)
        rows.append(feats / feats.norm(dim=-1, keepdim=True))
    return torch.cat(rows, dim=0)


def _encode_images(paths):
    """Return L2-normalised CLIP image features stacked into one row per path."""
    rows = []
    for path in tqdm(paths):
        # Context manager guarantees the file handle is released.
        with Image.open(path) as image:
            inputs = processor(images=image, return_tensors='pt')
        feats = model.get_image_features(**inputs)
        rows.append(feats / feats.norm(dim=-1, keepdim=True))
    return torch.cat(rows, dim=0)


def _vleu_from_embeddings(img_embs, text_embs, temperature=0.01):
    """Return exp(mean KL(P(t|image) || P(t))) from normalised embeddings.

    ``img_embs`` is an ``(n_images, d)`` tensor and ``text_embs`` an
    ``(n_texts, d)`` tensor; rows are expected to be L2-normalised so the
    dot product is the cosine similarity.
    """
    if temperature <= 0:
        raise ValueError(f'temperature must be positive, got {temperature}')
    if img_embs.shape[0] == 0 or text_embs.shape[0] == 0:
        raise ValueError('need at least one image and one text embedding')

    logits = (img_embs @ text_embs.T) / temperature  # (n_images, n_texts)
    # Work in log space: with a small temperature some probabilities
    # underflow to exactly 0, and p * log(p / q) would then be 0 * -inf = nan.
    log_p = F.log_softmax(logits, dim=1)
    p = log_p.exp()
    n_images = log_p.shape[0]
    # log of the marginal P(t) = mean over images of P(t | image).
    log_marginal = torch.logsumexp(log_p, dim=0) - float(np.log(n_images))
    kl_per_image = (p * (log_p - log_marginal)).sum(dim=1)
    # Accumulate in Python floats, as the original per-image .item() sum did.
    mean_kl = sum(kl_per_image.tolist()) / n_images
    return np.exp(mean_kl)

def sd15_path(i):
    """Return the relative path of image ``i`` under ``val_teddy_bear/0/``."""
    return 'val_teddy_bear/0/{}.jpg'.format(i)

def sd20_path(i):
    """Return the relative path of image ``i`` under ``img_gen_2.0/``."""
    return 'img_gen_2.0/0_{}/0_0.jpg'.format(i)

def sd21_path(i):
    """Return the relative path of image ``i`` under ``img_gen_2.1/``."""
    return 'img_gen_2.1/0_{}/0_0.jpg'.format(i)

def sdxl_path(i):
    """Return the relative path of image ``i`` under ``img_gen_xl/``."""
    return 'img_gen_xl/0_{}/0_0.jpg'.format(i)

def dalle2_path(i):
    """Return the relative path of image ``i`` under ``gpt3.5_dalle2/``."""
    return 'gpt3.5_dalle2/{}.jpg'.format(i)

def dalle3_path(i):
    """Return the relative path of image ``i`` under ``gpt3.5_dalle3/``."""
    return 'gpt3.5_dalle3/{}.jpg'.format(i)


if __name__ == '__main__':
    # Load the prompt list and drop its first 75 entries.
    # NOTE(review): assumes the JSON file holds a list of prompt strings.
    with open('../teddy_bear.json', 'r', encoding='utf-8') as prompt_file:
        prompts = json.load(prompt_file)[75:]

    # Sweep configuration: every combination is scored and printed.
    sample_counts = [25]
    start_offsets = [0]
    path_builders = [sd15_path]

    for sample_count in sample_counts:
        for offset in start_offsets:
            for build_path in path_builders:
                vleu = calculate_vleu(build_path, prompts, sample_count, offset, 0.01)
                print(vleu)
