import torch
import clip
from PIL import Image


# Device used by every model and tensor in this module.
# First pick the GPU when CUDA is available...
device = "cuda" if torch.cuda.is_available() else "cpu"
# ...but this next line unconditionally overrides that choice, so the module
# always runs on the CPU and the line above has no effect.
# NOTE(review): looks like a debugging/memory override. Delete it to re-enable
# CUDA, and confirm the downstream code behaves the same on GPU (e.g. model dtype).
device="cpu"

def load_clip(model_name="ViT-B/32"):
  """Load a CLIP model together with its image preprocessing transform.

  Args:
    model_name: architecture name forwarded to ``clip.load``. Defaults to
      "ViT-B/32", the model this module was originally hard-wired to, so
      existing ``load_clip()`` calls behave exactly as before.

  Returns:
    The ``(model, preprocess)`` pair returned by ``clip.load``, loaded onto
    the module-level ``device``.
  """
  model, preprocess = clip.load(model_name, device=device)
  return model, preprocess

def clip_encode_docs(model, texts):
  """Encode each text with CLIP and L2-normalise every embedding.

  Texts are tokenized (with truncation) and encoded one at a time, so the
  result holds one entry per input text, in input order. No gradients are
  recorded.

  Args:
    model: CLIP model exposing ``encode_text`` (e.g. from ``load_clip``).
    texts: iterable of strings.

  Returns:
    list of feature tensors, each divided by its norm along the last
    dimension. Presumably each is shaped (1, D), since a single string is
    tokenized per call (TODO confirm against the clip version in use).
  """
  def _embed(doc):
    # Tokenize a single document and move the tokens to the shared device.
    tokens = clip.tokenize(doc, truncate=True).to(device)
    feats = model.encode_text(tokens)
    # Unit-normalise in place so dot products are cosine similarities.
    feats /= feats.norm(dim=-1, keepdim=True)
    return feats

  with torch.no_grad():
    encoded = [_embed(doc) for doc in texts]
  return encoded

def clip_hard_neg(model, preprocess, img, text_features, k=600, threshold=0.24):
  """Return indices of "hard negative" texts for an image.

  The image is embedded with CLIP and scored against every entry of
  ``text_features`` by dot product. Both sides are L2-normalised, so the
  score is a cosine similarity. Of the ``k`` highest-scoring texts, only
  those whose similarity is still below ``threshold`` are kept. Presumably
  these are texts that are close to the image but not true matches.

  Args:
    model: CLIP model exposing ``encode_image``.
    preprocess: image transform matching ``model`` (as from ``clip.load``).
    img: path or file object accepted by ``PIL.Image.open``.
    text_features: sequence of normalised text embeddings, e.g. the output
      of ``clip_encode_docs``. Each must yield exactly one score against the
      image feature; assumed shape (1, D) — TODO confirm.
    k: number of top-scoring texts to consider. Clamped to the number of
      texts, so corpora with fewer than ``k`` entries no longer raise.
    threshold: similarity ceiling; candidates at or above it are dropped.

  Returns:
    1-D tensor of indices into ``text_features``, ordered from most to least
    similar. It may be empty.
  """
  # ``with`` guarantees the underlying file handle is closed.
  with Image.open(img) as pil_image, torch.no_grad():
    rgb = pil_image if pil_image.mode == 'RGB' else pil_image.convert('RGB')
    pixels = preprocess(rgb).unsqueeze(0).to(device)
    image_feature = model.encode_image(pixels)
    image_feature /= image_feature.norm(dim=-1, keepdim=True)
    # One similarity value per text feature.
    scores = [torch.einsum('bc,bc->b', image_feature, text_feature).item()
              for text_feature in text_features]

  # float32 CPU tensor, same as the legacy FloatTensor(list) conversion.
  similarities = torch.tensor(scores, dtype=torch.float32)
  # topk raises when k exceeds the number of elements, so clamp it.
  top_k = min(k, similarities.numel())
  values, indices = similarities.topk(top_k)
  return indices[values < threshold]