import argparse
import base64
import json
import os
import random
import re
import time

import jsonlines
import requests
import tqdm

# OpenAI API key. Read from the environment so the secret does not have to be
# committed with the source; falls back to the original placeholder when
# OPENAI_API_KEY is not set.
api_key = os.environ.get("OPENAI_API_KEY", "xxxxxxx")

def encode_image(image_path):
    """Return the contents of the file at *image_path* as a base64 text string.

    The file is read in binary mode and the base64 bytes are decoded as UTF-8,
    which makes the result suitable for embedding in a ``data:`` URL.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode('utf-8')







def filter_question(image_id, questions, standards,
                    image_dir="/home/user/llavafinetune/images",
                    max_attempts=3, timeout=60):
    """Ask the vision model to pick the most uncertainty-inducing question for an image.

    Args:
        image_id: Image file stem; the image is read from ``{image_dir}/{image_id}.jpg``.
        questions: Candidate questions text (the script passes
            ``"<type>: <question>"`` lines).
        standards: Selection criteria text (the script passes
            ``"<type>: <criteria>"`` lines).
        image_dir: Directory that contains the JPEG images.
        max_attempts: Number of API attempts before giving up.
        timeout: Per-request timeout in seconds.

    Returns:
        The model's raw reply text (requested as two lines, ``Type: ...`` and
        ``Question: ...``), or ``None`` if every attempt failed.

    Raises:
        OSError: If the image file cannot be read. This is not retried.
    """
    image_path = os.path.join(image_dir, f"{image_id}.jpg")
    # Read outside the retry loop: a missing image is not a transient failure.
    base64_image = encode_image(image_path)
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    # Use explicit "\n" separators. Backslash line-continuations inside a string
    # literal drop the newline and leak indentation, which turned the requested
    # two-line "Type/Question" format into a single line. The caller parses the
    # reply for the "Type:" and "Question:" fields.
    prompt = (
        "Given the image and the corresponding questions listed below, please "
        "select a question that is most likely to lead to additional "
        "clarification questions, a response of 'I cannot answer this question "
        "from the image,' or challenges to false assumptions. The selected "
        "question should elicit uncertainty or a need for more information.\n"
        "Here are the standards you should refer to when you are making "
        "decisions: \n"
        + standards
        + "\nHere are the questions you should consider: \n"
        + questions
        + "\nDon't give me additional explanation, just give me the question "
        "type and the question itself. The format should be exactly two "
        "lines:\n"
        "Type: xxx\n"
        "Question: xxx"
    )
    payload = {
        "model": "gpt-4-turbo-2024-04-09",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
                    },
                ],
            }
        ],
        "max_tokens": 300,
    }

    for attempt in range(1, max_attempts + 1):
        try:
            response = requests.post(
                "https://api.openai.com/v1/chat/completions",
                headers=headers,
                json=payload,
                timeout=timeout,
            )
            # Surface HTTP errors (rate limit, auth) directly instead of failing
            # later with a KeyError on the error body.
            response.raise_for_status()
            return response.json()['choices'][0]['message']['content']
        except (requests.RequestException, ValueError, KeyError, IndexError, TypeError) as err:
            print('Attempt', attempt, 'failed:', err)
            if attempt < max_attempts:
                time.sleep(2 ** attempt)  # exponential backoff before retrying
    print(f"Failed to filter question for image {image_id}.")
    return None






# Directory holding the source .jpg images.
# NOTE(review): this module-level name shadows the builtin ``dir()``.
dir = '/home/user/llavafinetune/images'
# Listing of the image directory. It is not used further in this script;
# the random-sampling step that consumed it is disabled.
files = os.listdir(path=dir)

# Tally of how many questions have been selected per category; all start at zero.
type_counts = dict.fromkeys(
    [
        'Subject Ambiguity',
        'Unclear User Background',
        'Subjective Interpretations',
        'Unanswerable Questions',
        'False Premise',
    ],
    0,
)



# Selection criteria for each question type; the text is inserted into the
# prompt sent to the model. Each value is built with implicit string-literal
# concatenation so that no source indentation leaks into the prompt text (a
# backslash continuation inside a literal keeps the next line's leading spaces).
standard_map = {
    "Subject Ambiguity": (
        "Questions that do not specify which object or which person is being "
        "asked about when the image has multiple people or objects of the same "
        "type are considered subject ambiguous. Questions that can be directly "
        "answered without further clarification should be considered bad. "
        "For example, 'Is he wearing a hat?' where there are two men in the "
        "image is a good question. However, 'Which person is wearing a hat?' "
        "is a bad one since it can be directly answered."
    ),
    "Unclear User Background": (
        "Questions that compare the scenes or persons with you when no "
        "information about you is provided."
    ),
    "Subjective Interpretations": (
        "Questions that rely on subjective judgment without clear criteria or "
        "on specific human preferences when no preferences are given. "
        "Questions that use subjective words whose definition most humans "
        "would not agree on should be considered good."
    ),
    "Unanswerable Questions": (
        "Questions that completely cannot be answered or inferred based on the "
        "image alone, even with clarification questions. Questions that may "
        "raise uncertainty or ambiguity on whether they can be answered should "
        "be considered bad."
    ),
    "False Premise": (
        "Tricky questions that give false premises or incorrect assumptions. "
        "Questions that are tricky enough to fool the model so that it will "
        "fail to point out the false assumptions should be considered good."
    ),
}


# Accumulators for the selection loop: the kept results and how many there are.
total_result = []
count = 0

# Load every record of the candidate-questions JSON Lines file into memory.
with jsonlines.open('/home/user/llavafinetune/original_question_output.jsonl') as reader:
    src_data = list(reader)



# Once a question type has been selected this many times, its candidates are
# no longer offered to the model.
_MAX_PER_TYPE = 1960
# Results are checkpointed to this JSON Lines file, and written once more at the end.
_OUTPUT_PATH = '/home/user/llavafinetune/filtered_output.jsonl'
# Checkpoint after every this-many successful selections.
_SAVE_EVERY = 100

# Tolerant field extractors: these match whether the model puts the fields on
# separate lines or on one line.
_TYPE_RE = re.compile(r'Type:\s*(.*?)\s*(?:Question:|$)', re.MULTILINE)
_QUESTION_RE = re.compile(r'Question:\s*(.*?)\s*$', re.MULTILINE)


def _parse_selection(reply):
    """Return ``(question_type, question)`` parsed from a model reply, or ``None``.

    The type is matched case-insensitively against the keys of ``type_counts``,
    after surrounding markdown/quote characters are stripped. The reply is
    rejected if either field is missing or the type is unknown.
    """
    type_match = _TYPE_RE.search(reply)
    question_match = _QUESTION_RE.search(reply)
    if type_match is None or question_match is None:
        return None
    raw_type = type_match.group(1).strip(" *'\"`").casefold()
    question = question_match.group(1).strip(" *")
    known_types = {name.casefold(): name for name in type_counts}
    question_type = known_types.get(raw_type)
    if question_type is None or not question:
        return None
    return question_type, question


def _save_results(path, results):
    """Overwrite *path* with *results*, one JSON object per line."""
    with open(path, 'w') as f:
        f.writelines(json.dumps(item) + "\n" for item in results)


for obj in tqdm.tqdm(src_data):
    image_id = obj.get('image_id')
    if image_id is None:
        print("Skipping record without 'image_id'.")
        continue

    # Offer only real candidates whose type has not reached its cap yet.
    filtered_questions = []
    filtered_standard = []
    for key, value in obj.items():
        if key not in type_counts or value == 'N/A':
            continue
        if type_counts[key] >= _MAX_PER_TYPE:
            continue
        filtered_questions.append(f"{key}: {value}\n")
        filtered_standard.append(f"{key}: {standard_map[key]}\n")
    if not filtered_questions:
        continue

    questions = "\n".join(filtered_questions)
    standards = "\n".join(filtered_standard)
    try:
        output = filter_question(image_id, questions, standards)
    except Exception as err:  # per-item boundary: log it and keep the batch running
        print(f"Failed to filter question for image {image_id}: {err}")
        continue
    if output is None:
        continue

    # Parse into fresh locals so a malformed reply can never reuse the
    # previous image's type/question.
    parsed = _parse_selection(output)
    if parsed is None:
        print(f"Unparseable model reply for image {image_id}: {output!r}")
        continue
    question_type, question = parsed

    type_counts[question_type] += 1
    total_result.append(
        {'Image ID': image_id, 'Question Type': question_type, 'Question': question}
    )
    count += 1
    if count % _SAVE_EVERY == 0:
        _save_results(_OUTPUT_PATH, total_result)

# Final save so results gathered after the last checkpoint are not lost.
_save_results(_OUTPUT_PATH, total_result)

