from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration
import torch
import transformers

# Hugging Face model identifier handed to both ``from_pretrained`` calls in
# load_llava().
PRETRAINED_PATH = "llava-hf/llava-1.5-7b-hf"
# Default prompt for llava(); ``<image>`` is the placeholder for the picture.
PROMPT = "<image> What is author's intention of this image? Answer: "
# Device the processed inputs are moved to before generation.
DEVICE = 'cuda'
# Passed to ``model.generate`` as ``max_new_tokens``.
MAX_ANSWER_TOKENS = 200
# Used as ``torch_dtype`` when loading the model and as the dtype the
# processed inputs are cast to.
DTYPE = torch.float16
# NOTE(review): despite its name this flag is passed as ``load_in_8bit`` in
# load_llava(), i.e. it toggles 8-bit loading rather than half precision.
FP16 = True
# Default ``answer_offset`` for llava(): the marker after which the decoded
# text is treated as the answer.
# NOTE(review): this differs in case from the "Answer:" used in PROMPT --
# confirm which spelling is intended.
ANSWER_OFFSET = "ANSWER:"

def load_llava():
    """Load the LLaVA processor and model named by ``PRETRAINED_PATH``.

    The model is loaded with ``torch_dtype=DTYPE`` and ``load_in_8bit=FP16``.
    When 8-bit loading is disabled the model is additionally moved to
    ``DEVICE`` so that it lives on the same device llava() sends its inputs
    to.

    Returns:
        A ``(processor, model)`` tuple ready to be passed to llava().
    """
    processor = AutoProcessor.from_pretrained(PRETRAINED_PATH)
    model = LlavaForConditionalGeneration.from_pretrained(
        PRETRAINED_PATH,
        load_in_8bit=FP16,
        torch_dtype=DTYPE,
    )
    if not FP16:
        # A plain (non-quantized) load is not placed on DEVICE by
        # from_pretrained, while llava() always moves its inputs there.
        # The 8-bit path is left alone: device placement of quantized
        # models is handled by the loader, and they do not support ``.to()``.
        model = model.to(DEVICE)
    return processor, model

def _extract_answer(generated_text, answer_offset):
    """Return the part of *generated_text* that follows *answer_offset*.

    The first exact occurrence of the marker is preferred; if there is none,
    the first case-insensitive occurrence is used. When the marker is empty
    or cannot be found at all, the text is returned unchanged.
    """
    import re

    if not answer_offset:
        return generated_text

    _, marker, answer = generated_text.partition(answer_offset)
    if marker:
        return answer

    # Fall back to a case-insensitive match so that e.g. an "Answer:" prompt
    # suffix is still recognised when the marker is spelled "ANSWER:".
    match = re.search(re.escape(answer_offset), generated_text, re.IGNORECASE)
    if match is not None:
        return generated_text[match.end():]
    return generated_text


def llava(processor, model, image_path, prompt=PROMPT, answer_offset=ANSWER_OFFSET):
    """Run the model on one image and return the text after the answer marker.

    Args:
        processor: Processor returned by load_llava(); used to build the
            model inputs and to decode the generated token ids.
        model: Model returned by load_llava().
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        prompt: Text prompt passed to the processor together with the image.
        answer_offset: Marker separating the prompt from the answer in the
            decoded text. Everything after its first occurrence is returned;
            if it does not occur (even ignoring case) the whole decoded text
            is returned instead of raising.

    Returns:
        The decoded answer string (not stripped of surrounding whitespace).
    """
    # Keep the file handle open only while the processor reads the pixels.
    with Image.open(image_path) as image:
        inputs = processor(text=prompt, images=image, return_tensors="pt").to(
            device=DEVICE, dtype=DTYPE
        )

    # Greedy decoding (do_sample=False), capped at MAX_ANSWER_TOKENS new tokens.
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=MAX_ANSWER_TOKENS,
        do_sample=False,
    )

    # Only a single prompt/image pair is submitted, so take the first entry.
    generated_text = processor.batch_decode(
        generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

    return _extract_answer(generated_text, answer_offset)
