from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import argparse
import random
import torch
import json


def build_token_pool(token_list, unique=True, exclude=()):
    """Flatten token-id sequences into the pool that random prompts are drawn from.

    Args:
        token_list: Iterable of token-id sequences (e.g. the parsed token JSON).
        unique: When True, each id appears once, giving uniform sampling over
            distinct ids. When False, duplicates are kept, so ids are sampled
            in proportion to their frequency.
        exclude: Token ids to drop from the pool (e.g. the tokenizer's special
            ids), so they can never be sampled into the body of a prompt.

    Returns:
        A list of token ids. When ``unique`` is True the list is in first-seen
        order, so a seeded run is reproducible.
    """
    excluded = set(exclude)
    flat = [tok for seq in token_list for tok in seq if tok not in excluded]
    if unique:
        # dict.fromkeys de-duplicates while preserving first-seen order
        # (unlike list(set(...))), keeping seeded sampling deterministic.
        flat = list(dict.fromkeys(flat))
    return flat


def sample_sequence(token_list, pool, bos_id, eos_id, max_length=None, rng=random):
    """Build one random prompt framed by ``bos_id`` ... ``eos_id``.

    The total length copies the length of a randomly chosen reference
    sequence from ``token_list``. The reference sequences presumably already
    include BOS/EOS, hence the ``- 2`` for the body. The length is capped at
    ``max_length`` when given.

    Args:
        token_list: Non-empty sequence of reference token-id sequences.
        pool: Token ids to draw the body from (with replacement).
        bos_id: Id placed at the start of the prompt.
        eos_id: Id placed at the end of the prompt.
        max_length: Optional upper bound on the total prompt length.
        rng: Object providing ``choice``/``choices`` (the ``random`` module or
            a ``random.Random`` instance).

    Returns:
        A list of token ids ``[bos_id, *body, eos_id]``.

    Raises:
        IndexError: If ``token_list`` is empty, or if ``pool`` is empty while
            a non-empty body is needed.
    """
    length = len(rng.choice(token_list))
    if max_length is not None:
        length = min(length, max_length)
    # Clamp at 0 so reference sequences shorter than 2 still yield [BOS, EOS]
    # instead of a prompt whose length disagrees with its attention mask.
    n_body = max(length - 2, 0)
    return [bos_id] + rng.choices(pool, k=n_body) + [eos_id]


def main():
    """Print the CLIP cosine similarity between one image and N random token-id prompts."""
    parser = argparse.ArgumentParser(
        description='Score random token-id prompts against an image with CLIP.')
    parser.add_argument('--clip_model', type=str, required=True,
                        help='Name or path passed to CLIPModel/CLIPProcessor.from_pretrained.')
    parser.add_argument('--token_json', type=str, required=True,
                        help='JSON file holding a list of token-id sequences.')
    parser.add_argument('--image_path', type=str, required=True,
                        help='Image to score the prompts against.')
    parser.add_argument('--n_prompt', type=int, required=True,
                        help='Number of random prompts to score.')
    parser.add_argument('--no_unique', default=False, action='store_true',
                        help='Keep duplicate ids in the pool (frequency-weighted sampling).')
    parser.add_argument('--seed', type=int, default=None,
                        help='Seed for prompt sampling; omit for a non-reproducible run.')
    args = parser.parse_args()

    model = CLIPModel.from_pretrained(args.clip_model)
    processor = CLIPProcessor.from_pretrained(args.clip_model)
    tokenizer = processor.tokenizer

    with open(args.token_json, 'r', encoding='utf-8') as f:
        token_list = json.load(f)
    if not token_list:
        parser.error(f'{args.token_json} contains no token sequences')

    # Keep special tokens (BOS/EOS/PAD...) out of the body: CLIP's text model
    # pools at the EOS position, so a stray EOS mid-prompt would change which
    # hidden state becomes the text embedding.
    pool = build_token_pool(token_list, unique=not args.no_unique,
                            exclude=tokenizer.all_special_ids)
    if not pool:
        parser.error(f'{args.token_json} contains no non-special tokens to sample')

    # Position embeddings bound the prompt length (77 for standard CLIP).
    max_length = model.config.text_config.max_position_embeddings

    # Context manager closes the underlying file once pixels are extracted.
    with Image.open(args.image_path) as image:
        inputs = processor(images=[image.convert('RGB')], return_tensors='pt')

    rng = random.Random(args.seed)
    # Inference only: skip building autograd graphs for every forward pass.
    with torch.no_grad():
        for _ in range(args.n_prompt):
            seq = sample_sequence(token_list, pool, tokenizer.bos_token_id,
                                  tokenizer.eos_token_id, max_length=max_length, rng=rng)
            inputs['input_ids'] = torch.LongTensor([seq])
            # Derive the mask from the actual ids so their shapes always agree.
            inputs['attention_mask'] = torch.ones_like(inputs['input_ids'])
            outputs = model(**inputs)
            score = torch.cosine_similarity(outputs.text_embeds, outputs.image_embeds).sum().item()
            print(score)


if __name__ == '__main__':
    main()
