import json
import os
import random
import sys
import time
import anthropic
import google.generativeai as genai
from google.generativeai import GenerationConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import copy

import numpy as np
import openai
import pandas as pd
import torch
import transformers
from evaluate import load
from matplotlib import pyplot as plt
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
)
import tqdm

# Seed Python's and transformers' RNGs so repeated runs draw the same values.
random.seed(42)
transformers.set_seed(42)
# Only let transformers report errors.
transformers.logging.set_verbosity_error()

# Number of few-shot examples requested per class (passed to
# insert_fiexd_few_shot from main).
sample_number = 3
# Prompt-optimisation settings. They are not read in the visible part of this
# file -- presumably consumed by get_optimized_prompt (TODO confirm).
# Smaller values used previously for these three settings: 2 / 2 / 2.
optimization_step = 3
best_prompt_choice = 4
generated_prompt_number = 4
# Number of feedback/rewrite rounds run by generate_refine_res.
refine_iter = 2

# Positional command-line arguments, in order:
#   task_name choice_type learning_name model_name
#   add_explanation optimize_prompt self_refinement
task_name, choice_type = sys.argv[1], sys.argv[2]
learning_name, model_name = sys.argv[3], sys.argv[4]
# A switch is on only when its argument is exactly the string "true".
add_explanation = sys.argv[5] == "true"
optimize_prompt = sys.argv[6] == "true"
self_refinement = sys.argv[7] == "true"

# task_name = "code-and-description"
# choice_type = "list"
# learning_name = "zero-shot"
# model_name = "gpt-4o-2024-05-13"
# add_explanation = True
# optimize_prompt = False
# self_refinement = True

# task_name = sys.argv[1]
# task_name = "plain"
# task_name = "code-only"
# task_name = "description-only"
# task_name = "code-and-description"

# choice_type = sys.argv[2]
# choice_type = "list"
# choice_type = "binary"

# learning_name = sys.argv[3]
# learning_name = "fixed-few-shot"
# learning_name = "zero-shot"

# model_name = sys.argv[4]
# model_name = "gpt-3.5-turbo-0613"
# model_name = "gpt-3.5-turbo-0125"
# model_name = "gpt-4-0613"
# model_name = "gpt-4o-2024-05-13"


# add_explanation = True
# add_explanation = False

# optimize_prompt = False
# optimize_prompt = True

#


# model_name = model_name.split("/")[-1]

# Credentials are taken from the environment; a missing variable yields None.
openai_api_key = os.environ.get("OPENAI_API_KEY")
claudi_api_key = os.environ.get("CLAUDI_API_KEY")
gemini_api_key = os.environ.get("GEMINI_API_KEY")
huggingface_access_token = os.environ.get("HUGGINGFACE_ACCESS_TOKEN")

# One output directory per experiment configuration:
#   task / choice type / learning mode / model basename / cot / optimize / refine
result_folder = os.path.join(
    "./result/ih/{}/{}/{}/{}".format(
        task_name,
        choice_type,
        learning_name,
        model_name.split("/")[-1],
    ),
    "cot" if add_explanation else "no_cot",
    "optimize" if optimize_prompt else "no_optimize",
    "self_refinement" if self_refinement else "no_refinement",
)
print(result_folder)

if not os.path.exists(result_folder):
    os.makedirs(result_folder)

result_file = os.path.join(result_folder, "prediction.jsonl")
prompt_file = os.path.join(result_folder, "prompt_change.jsonl")

# Predictions already written by an earlier run, keyed by thread id; main()
# skips every thread found here so an interrupted run can be resumed.
data = {}
if os.path.exists(result_file):
    with open(result_file, "r") as handle:
        for raw_line in handle:
            record = json.loads(raw_line.strip())
            data[record["id"]] = record

# Create the client / model object for the back-end selected by model_name.
# The first matching substring wins, in the order gpt, claude, gemini, llama.
if "gpt" in model_name:
    client = openai.OpenAI(api_key=openai_api_key)
elif "claude" in model_name:
    client = anthropic.Anthropic(api_key=claudi_api_key)
elif "gemini" in model_name:
    genai.configure(api_key=gemini_api_key)
    generation_config = GenerationConfig(temperature=0)
    # NOTE(review): the Gemini model id is fixed here and ignores model_name.
    model = genai.GenerativeModel("gemini-1.0-pro", generation_config=generation_config)
elif any(tag in model_name for tag in ("Llama", "llama")):
    # Arguments shared by the tokenizer and the model download.
    hf_common_kwargs = {
        "token": huggingface_access_token,
        "cache_dir": "/home/xiaobo/data/models",
    }
    tokenizer = AutoTokenizer.from_pretrained(model_name, **hf_common_kwargs)
    # 4-bit NF4 quantisation with double quantisation and bfloat16 compute.
    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.float16,
        quantization_config=nf4_config,
        **hf_common_kwargs,
    )

# Codebook CSV: column 0 is used as the code name, column 1 is compared with
# "Yes"/"No" to split the codes into positive and negative ones, and column 2
# is used as the code's description text.
code_book_file = "./data/ih_annotated/ih_codebook_2_20_24.csv"
code_book = {
    row_number: [record[0], record[1], record[2].strip()]
    for row_number, record in enumerate(pd.read_csv(code_book_file).values.tolist())
}
# The same rows, keyed by code name instead of row number.
code_book_based_on_code = {
    entry[0]: [entry[0], entry[1], entry[2].strip()] for entry in code_book.values()
}

# Row numbers of the positive ("Yes") and negative ("No") codes.
positive_code_index = [key for key, entry in code_book.items() if entry[1] == "Yes"]
negative_code_index = [key for key, entry in code_book.items() if entry[1] == "No"]

# Code names, descriptions and "name: description" strings per polarity.
positive_code = [code_book[key][0] for key in positive_code_index]
negative_code = [code_book[key][0] for key in negative_code_index]
positive_description = [code_book[key][2] for key in positive_code_index]
negative_description = [code_book[key][2] for key in negative_code_index]
positive_description_with_code = [
    "{}: {}".format(code_book[key][0], code_book[key][2])
    for key in positive_code_index
]
negative_description_with_code = [
    "{}: {}".format(code_book[key][0], code_book[key][2])
    for key in negative_code_index
]


# Annotated threads. Column positions are resolved once by header name because
# the table is converted to plain row lists below and indexed by position.
text_file = "./data/ih_annotated/annotation_both.csv"
text_data = pd.read_csv(text_file)
(
    thread_id_index,
    thread_title_index,
    submission_text_index,
    comment_1_index,
    comment_2_index,
    focal_post_index,
    melody_labels_index,
    discussed_labels_index,
    neil_labels_index,
) = (
    text_data.columns.get_loc(column_name)
    for column_name in (
        "thread_id",
        "thread_title",
        "submission_text",
        "comment_1",
        "comment_2",
        "focal_post",
        "melody_labels (comma separated)",
        "final_labels (comma separated)",
        "neil_labels (comma separated)",
    )
)
text_data = text_data.values.tolist()


def main():
    """Label every eligible thread with the configured LLM and append the results.

    For each row of ``text_data`` that is not already in ``data`` (the resume
    cache), is not rejected by ``skip_item`` and carries at least one codebook
    code from each of the two annotators, the function:

    1. asks the model the binary question and maps the reply to "Yes"/"No";
    2. asks the model for the codes (one request per code when
       ``choice_type == "binary"``, otherwise one request under the key
       produced by ``generate_system_prompt``);
    3. optionally runs ``generate_refine_res`` on each answer;
    4. appends one JSON line to ``result_file``.

    Items whose binary request returns ``None`` are skipped without a record.
    """
    few_shot_dict = generate_few_shot_samples()
    binary_prompt, code_prompt_dict = generate_system_prompt()

    if optimize_prompt:
        # Opening in "w" mode truncates any prompt log left by an earlier run.
        with open(prompt_file, mode="w") as f:
            pass
        # NOTE(review): total counts only the code prompts, but the bar is also
        # advanced once for the binary prompt, so it ends one past total.
        with tqdm.tqdm(total=len(code_prompt_dict), desc="prompt optimization") as pbar:
            binary_prompt = get_optimized_prompt(few_shot_dict, "Binary", binary_prompt)
            pbar.update(1)
            for code_name in code_prompt_dict.keys():
                code_prompt_dict[code_name] = get_optimized_prompt(
                    few_shot_dict, code_name, code_prompt_dict[code_name]
                )
                pbar.update(1)

    if self_refinement:
        # System prompts for the critique step and for the rewrite step.
        feedback_binary_system_prompt, feedback_code_prompt_dict = (
            generate_feedback_system_prompt()
        )
        refine_system_prompt, refine_code_prompt_dict = generate_refine_system_prompt()

    for item in tqdm.tqdm(text_data, desc="label data", dynamic_ncols=True):
        # Already labelled in a previous run.
        if item[thread_id_index] in data:
            continue
        if skip_item(item):
            continue

        # Split each annotator's comma-separated cell into a list of codes.
        melody_code_list = (
            item[melody_labels_index].replace("\r", "").replace("\n", "").split(",")
            if isinstance(item[melody_labels_index], str)
            else []
        )
        neil_code_list = (
            item[neil_labels_index].replace("\r", "").replace("\n", "").split(",")
            if isinstance(item[neil_labels_index], str)
            else []
        )
        # Keep the item only if each annotator used at least one codebook code.
        melody_check = False
        neil_check = False
        for code_name in melody_code_list:
            if code_name in positive_code or code_name in negative_code:
                melody_check = True
                break
        for code_name in neil_code_list:
            if code_name in positive_code or code_name in negative_code:
                neil_check = True
                break
        if not melody_check or not neil_check:
            continue
        # ---- binary question ----
        binary_generation = None
        messages = [{"role": "system", "content": binary_prompt}]
        text = generate_binary_text(item)
        if learning_name == "fixed-few-shot":
            samples = insert_fiexd_few_shot(
                few_shot_dict, item, "Binary", sample_number
            )
            # Each sample is a [user message, assistant message] pair.
            for sample in samples:
                messages.append(sample[0])
                messages.append(sample[1])
        messages.append({"role": "user", "content": text})
        binary_generation = get_llm_answer(messages)

        if binary_generation is None:
            continue
        # Substring test: any "Yes" anywhere in the reply counts as a Yes.
        if "Yes" in binary_generation:
            binary_result = "Yes"
        else:
            binary_result = "No"

        if self_refinement:
            binary_generation, binary_result = generate_refine_res(
                feedback_binary_system_prompt,
                refine_system_prompt,
                item,
                binary_generation,
                binary_result,
                "binary",
            )

        # ---- code question(s) ----
        code_result = list()
        code_explain = dict()
        for code_name, code_prompt in tqdm.tqdm(
            code_prompt_dict.items(), desc="label code", dynamic_ncols=True, leave=False
        ):
            code_generation = None
            messages = [{"role": "system", "content": code_prompt}]
            text = generate_code_text(item)
            if learning_name == "fixed-few-shot":
                samples = insert_fiexd_few_shot(
                    few_shot_dict, item, code_name, sample_number
                )
                for sample in samples:
                    messages.append(sample[0])
                    messages.append(sample[1])
            messages.append({"role": "user", "content": text})
            code_generation = get_llm_answer(messages)
            if code_generation is None:
                continue
            code_generation = code_generation.strip()
            code_res = "No"
            if choice_type == "binary":
                # "Yes" if the last paragraph contains it or the reply starts
                # with it.
                if "Yes" in code_generation.split("\n\n")[
                    -1
                ] or code_generation.startswith("Yes"):
                    code_res = "Yes"
                if self_refinement:
                    code_generation, code_res = generate_refine_res(
                        feedback_code_prompt_dict[code_name],
                        refine_code_prompt_dict[code_name],
                        item,
                        code_generation,
                        code_res,
                        "code",
                    )
                if code_res == "Yes":
                    code_result.append(code_name)
                # Redundant with the unconditional assignment at the end of
                # the loop body.
                code_explain[code_name] = code_generation
            else:
                # List mode: collect every codebook code whose name, or whose
                # "Label <row number>" alias, occurs in the reply.
                # NOTE(review): plain substring matching, so "Label 1" also
                # matches inside "Label 10" if the codebook has that many rows.
                for code_index in list(code_book.keys()):
                    if (
                        code_book[code_index][0] in code_generation
                        or "Label {}".format(code_index) in code_generation
                    ):
                        code_result.append(code_book[code_index][0])
                if self_refinement:
                    code_generation, code_res = generate_refine_res(
                        feedback_code_prompt_dict[code_name],
                        refine_code_prompt_dict[code_name],
                        item,
                        code_generation,
                        "\n".join(code_result),
                        "code",
                    )
                    # The refined result comes back newline-joined.
                    code_result = code_res.split("\n")
            code_explain[code_name] = code_generation
        # Joined here and split again below; an empty result becomes [].
        code_result = "\n".join(code_result)
        # Append one JSON line per item so progress survives an interruption.
        with open(result_file, "a") as f:
            f.write(
                json.dumps(
                    {
                        "id": item[thread_id_index],
                        "ih_binary_LLM": binary_result,
                        "ih_code_LLM": code_result.split("\n") if code_result else [],
                        "ih_code_melody": melody_code_list,
                        "ih_code_neil": neil_code_list,
                        "ih_code_discussed": (
                            item[discussed_labels_index]
                            .replace("\r", "")
                            .replace("\n", "")
                            .split(",")
                            if isinstance(item[discussed_labels_index], str)
                            else []
                        ),
                        "ih_binary_explain": binary_generation,
                        "ih_code_explain": code_explain,
                    },
                    ensure_ascii=False,
                )
                + "\n"
            )


def generate_refine_res(
    feedback_system_prompt,
    refine_system_prompt,
    item,
    original_generation,
    original_result,
    task_type,
):
    """Run ``refine_iter`` rounds of critique-then-rewrite on an LLM answer.

    Each round asks the model for feedback on the current answer, then asks
    it to produce a new answer from that feedback, and re-parses the result.

    Args:
        feedback_system_prompt: System prompt for the critique request.
        refine_system_prompt: System prompt for the rewrite request.
        item: The data row being labelled.
        original_generation: The model's current raw answer.
        original_result: The parsed result of that answer ("Yes"/"No", or
            newline-joined code names in list mode).
        task_type: "binary" for the binary question; any other value is
            treated as the code question.

    Returns:
        tuple: ``(generation, result)`` after the last successful round. If a
        round's feedback or rewrite request yields ``None``, that round is
        skipped and the previous generation/result are kept.
    """
    refine_generation = original_generation
    refine_result = original_result
    is_binary_task = task_type == "binary"
    for _ in tqdm.trange(refine_iter, desc="refine", dynamic_ncols=True, leave=False):
        # Step 1: critique the current answer.
        feedback_text = (
            generate_feedback_binary_text(item, refine_result, refine_generation)
            if is_binary_task
            else generate_feedback_code_text(item, refine_result, refine_generation)
        )
        feedback = get_llm_answer(
            [
                {"role": "system", "content": feedback_system_prompt},
                {"role": "user", "content": feedback_text},
            ]
        )
        if feedback is None:
            continue

        # Step 2: rewrite the answer using the critique.
        refine_text = (
            generate_refine_binary_text(item, feedback)
            if is_binary_task
            else generate_refine_code_text(item, feedback)
        )
        candidate = get_llm_answer(
            [
                {"role": "system", "content": refine_system_prompt},
                {"role": "user", "content": refine_text},
            ]
        )
        if candidate is None:
            # The rewrite request failed after all retries: keep the last good
            # answer instead of crashing on `"Yes" in None` or returning None.
            continue
        refine_generation = candidate

        # Step 3: parse the rewritten answer the same way main() does.
        if choice_type == "binary" or is_binary_task:
            refine_result = "Yes" if "Yes" in refine_generation else "No"
        else:
            # NOTE(review): plain substring matching, so "Label 1" also matches
            # inside "Label 10" if the codebook has that many rows.
            refine_result = "\n".join(
                entry[0]
                for code_index, entry in code_book.items()
                if entry[0] in refine_generation
                or "Label {}".format(code_index) in refine_generation
            )
    return refine_generation, refine_result


def get_llm_answer(feedback_messages):
    """Send a chat to the back-end chosen by ``model_name``, with retries.

    Up to five attempts are made. After an OpenAI error, an Anthropic API
    error or a ``ValueError`` the function sleeps for 1, 2, 4, 8 and 16
    seconds respectively (plus up to one second of random jitter) and tries
    again.

    Returns:
        The model's text reply, or ``None`` if every attempt failed or
        ``model_name`` matches no known back-end.
    """
    # The first matching substring decides the back-end.
    if "gpt" in model_name:
        ask_backend = get_GPT_answer
    elif "claude" in model_name:
        ask_backend = get_calude_answer
    elif "gemini" in model_name:
        ask_backend = get_gemini_answer
    elif "Llama" in model_name:
        ask_backend = get_LLAMA2_answer
    elif "llama" in model_name:
        ask_backend = get_LLAMA_answer
    else:
        return None

    for attempt in range(5):
        try:
            return ask_backend(feedback_messages)
        except (openai.OpenAIError, anthropic.APIError, ValueError):
            # Exponential back-off with jitter so parallel runs do not retry
            # in lock-step.
            jitter = random.randint(0, 1000) / 1000
            time.sleep(2**attempt + jitter)
    return None


def get_LLAMA2_answer(messages):
    """Generate a reply with the local model using the tokenizer's chat template.

    Returns the decoded text of the output sequence with the prompt blanked
    out (see the comment below).
    """
    prompt_ids = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to("cuda")
    generation = model.generate(
        prompt_ids,
        use_cache=True,
        max_new_tokens=400,
        return_dict_in_generate=True,
    )
    output_ids = generation.sequences[0]
    prompt_length = prompt_ids.shape[1]
    # Overwrite prompt positions 1..prompt_length-1 with the EOS id so that
    # decoding with skip_special_tokens=True drops them. Position 0 is left
    # as is -- presumably a BOS token that is skipped as well (TODO confirm).
    output_ids[1:prompt_length] = tokenizer.eos_token_id
    return tokenizer.decode(output_ids, skip_special_tokens=True)


def get_LLAMA_answer(messages):
    """Generate a reply with the local model from a plain concatenated prompt.

    The message contents are joined with blank lines (roles are ignored) and
    "The answer is:" is appended directly after the last message.
    """
    prompt = "\n\n".join(message["content"] for message in messages) + "The answer is:"
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
    generation = model.generate(
        prompt_ids,
        use_cache=True,
        max_new_tokens=800,
        return_dict_in_generate=True,
    )
    output_ids = generation.sequences[0]
    prompt_length = prompt_ids.shape[1]
    # Overwrite prompt positions 1..prompt_length-1 with the EOS id so that
    # decoding with skip_special_tokens=True drops them. Position 0 is left
    # as is -- presumably a BOS token that is skipped as well (TODO confirm).
    output_ids[1:prompt_length] = tokenizer.eos_token_id
    return tokenizer.decode(output_ids, skip_special_tokens=True)


def get_GPT_answer(messages):
    """Return the text of the first choice of an OpenAI chat completion.

    The request uses ``model_name`` and temperature 0.
    """
    completion = client.chat.completions.create(
        temperature=0,
        messages=messages,
        model=model_name,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content


def get_calude_answer(messages):
    """Send an OpenAI-style message list to the Anthropic Messages API.

    A leading ``system`` message is moved to the request's ``system``
    argument; all remaining messages are sent as the conversation.

    Args:
        messages: List of ``{"role": ..., "content": ...}`` dicts.

    Returns:
        The text of the first content block of the reply.
    """
    if messages[0]["role"] == "system":
        system_message = messages[0]["content"]
        conversation = messages[1:]
    else:
        # No system message: keep the first message in the conversation
        # (it was previously dropped by an unconditional messages[1:]).
        system_message = ""
        conversation = list(messages)
    response = client.messages.create(
        system=system_message,
        model=model_name,
        max_tokens=1024,
        messages=conversation,
        temperature=0,
    )
    return response.content[0].text


def get_gemini_answer(messages):
    """Send an OpenAI-style message list to the configured Gemini model.

    No separate system role is used here, so a leading ``system`` message is
    prepended (followed by a blank line) to every user turn. Assistant turns
    are sent with the ``model`` role.

    Args:
        messages: List of ``{"role": ..., "content": ...}`` dicts.

    Returns:
        The reply text, or ``None`` when the prompt feedback reports block
        reason 2.
    """
    safety_settings = {
        "HATE_SPEECH": "block_none",
        "HARASSMENT": "block_none",
        "SEXUAL": "block_none",
        "DANGEROUS": "block_none",
    }
    if messages[0]["role"] == "system":
        system_prefix = messages[0]["content"] + "\n\n"
        conversation = messages[1:]
    else:
        # No system message: nothing to prepend, and the first message stays
        # in the conversation. Previously this path raised on the unbound
        # system_message and would also have dropped the first message.
        system_prefix = ""
        conversation = messages

    processed_messages = []
    for message in conversation:
        if message["role"] == "user":
            processed_messages.append(
                {"role": "user", "parts": [system_prefix + message["content"]]}
            )
        else:
            processed_messages.append(
                {"role": "model", "parts": [message["content"]]}
            )

    response = model.generate_content(
        processed_messages, safety_settings=safety_settings
    )
    # NOTE(review): 2 is a raw enum value -- presumably BlockReason.OTHER;
    # confirm against the google.generativeai documentation.
    if response.prompt_feedback.block_reason == 2:
        return None
    return response.text


def insert_fiexd_few_shot(few_shot_dict, item, code_name, sample_number):
    """Pick a class-balanced set of few-shot examples for one item.

    The item being labelled is never used as its own example. The number of
    examples per class is capped by the number of usable positive examples,
    so that negatives do not outnumber the positives that are available.
    If there are no positive examples at all, the requested number of
    negatives is still returned.

    Args:
        few_shot_dict: Mapping ``code_name -> {"Yes": [...], "No": [...]}``
            whose entries are ``{"id": ..., "content": [user_msg, assistant_msg]}``.
        item: The data row being labelled.
        code_name: Key into ``few_shot_dict``.
        sample_number: Requested number of examples per class.

    Returns:
        list: Shuffled ``[user_msg, assistant_msg]`` pairs, at most
        ``sample_number`` per class.
    """
    item_id = item[thread_id_index]
    positive_pool = few_shot_dict[code_name]["Yes"]
    negative_pool = few_shot_dict[code_name]["No"]

    if positive_pool:
        held_out = any(sample["id"] == item_id for sample in positive_pool)
        sample_number = min(sample_number, len(positive_pool) - int(held_out))
    # A limit of zero (or less) means no examples of either class. The
    # previous loop tested the limit only after appending, so it still
    # emitted one negative example in that case.
    sample_number = max(sample_number, 0)

    samples = [
        sample["content"] for sample in positive_pool if sample["id"] != item_id
    ][:sample_number]
    samples += [
        sample["content"] for sample in negative_pool if sample["id"] != item_id
    ][:sample_number]
    random.shuffle(samples)
    return samples


def generate_system_prompt():
    """Build the system prompts for the binary question and the code question.

    Reads the module-level ``task_name``, ``choice_type`` and
    ``add_explanation`` settings and the codebook tables.

    Returns:
        tuple: ``(binary_prompt, code_prompt_dict)``.
        With ``choice_type == "binary"`` the dict holds one yes/no prompt per
        code name; otherwise it holds a single multi-label prompt under the
        key ``"All"``.

    Raises:
        ValueError: If ``task_name`` is not one of "plain", "code-only",
            "description-only" or "code-and-description". Previously this
            surfaced later as an unbound-variable or key error.
    """
    supported_tasks = (
        "plain",
        "code-only",
        "description-only",
        "code-and-description",
    )
    if task_name not in supported_tasks:
        raise ValueError(
            "Unsupported task_name {!r}; expected one of {}".format(
                task_name, ", ".join(supported_tasks)
            )
        )

    # ---- binary (intellectual humility yes/no) prompt ----
    binary_intro = "You are an expert in the domain of intellectual humility. Your task is to label the given texts from Reddit about religion. Please label the given texts as intellectual humility or not intellectual humility with `No` or `Yes`"
    if task_name == "plain":
        binary_prompt = binary_intro
    elif task_name == "code-only":
        # Fix: the second sentence used to read "cannot be described as ...
        # [positive codes] ... answer `Yes`", which inverted the instruction
        # relative to the description-based variants below.
        binary_prompt = binary_intro + '. If you think the given text can be described as at least of the following \n"{}",\n you should answer `No`. If the given text can be described as at least of the following \n"{}",\n you should answer `Yes`'.format(
            "\n".join(negative_code), "\n".join(positive_code)
        )
    else:
        if task_name == "description-only":
            negative_items, positive_items = negative_description, positive_description
        else:  # "code-and-description"
            negative_items = negative_description_with_code
            positive_items = positive_description_with_code
        binary_prompt = binary_intro + '. If you think the given text can be described as at least of the following:\n"{}",\nyou should answer `No`.\n If the given text can be described as at least of the following:\n"{}",\nyou should answer `Yes`'.format(
            "\n".join(negative_items), "\n".join(positive_items)
        )

    # ---- code prompt(s) ----
    task_intro = "Your task is to label the given text from Reddit about religion. The given text includes the Title, the content of the Submission, the content of the Comment, and the Target Text. "
    code_prompt_dict = {}
    if choice_type == "binary":
        # One yes/no prompt per code; only the way the code is described to
        # the model differs between the task variants.
        question = "If the Target Text can be described as `{}`, answer `Yes`. If it does not fit this description, answer `No`."
        for entry in code_book.values():
            if task_name in ("plain", "code-only"):
                # These two variants carry a trailing space.
                prompt = task_intro + question.format(entry[0]) + " "
            elif task_name == "description-only":
                prompt = task_intro + question.format(entry[2])
            else:  # "code-and-description"
                prompt = task_intro + question.format(
                    "{} : {}".format(entry[0], entry[2])
                )
            code_prompt_dict[entry[0]] = prompt

        if add_explanation:
            suffix = "\n\nYou must explain how you get the answer first then responding Yes or No."
        else:
            suffix = "\n\nPlease only answer Yes or No without exlanation."
        for code in code_prompt_dict:
            code_prompt_dict[code] += suffix
        binary_prompt += suffix
    else:
        # A single multi-label prompt. The final "\\n" is a literal
        # backslash-n shown to the model as the separator to use.
        label_names = "\n".join(entry[0] for entry in code_book.values())
        separator_hint = "Each sample might be labelled with multiple labels, please separate each label with \\n"
        if task_name == "plain":
            code_prompt = (
                task_intro
                + 'Please label the Target Text with one or more labels from the following list:"{}".\n\n'.format(
                    label_names
                )
                + separator_hint
            )
        elif task_name == "code-only":
            code_prompt = (
                task_intro
                + 'Please label the Target Text  with one or more labels from the following list:"{}".\n\n '.format(
                    label_names
                )
                + separator_hint
            )
        elif task_name == "code-and-description":
            code_prompt = (
                task_intro
                + 'Please label the Target Text  with one or more labels from the following list:"{}".\n\nThe definition of each label is as follow:"{}".\n\n'.format(
                    label_names,
                    "\n".join(
                        "{} : {}".format(entry[0], entry[2])
                        for entry in code_book.values()
                    ),
                )
                + separator_hint
            )
        else:  # "description-only": codes are hidden behind "Label <row>".
            code_prompt = (
                task_intro
                + 'Please label the Target Text with one or more labels from the following list:"{}".\n\nThe definition of each label is as follow:"{}".\n\n'.format(
                    "\n".join("Label {}".format(key) for key in code_book),
                    "\n".join(
                        "Label {} : {}".format(key, code_book[key][2])
                        for key in code_book
                    ),
                )
                + separator_hint
            )
        code_prompt_dict["All"] = code_prompt
        if add_explanation:
            code_prompt_dict[
                "All"
            ] += "\n\nExplain the reasons first and then respond with the labels. You should separate the explanation and the labels with a new line. The label part should only includes the name of the labels."
            binary_prompt += (
                "\n\nExplain the way you think about and then respond with Yes or No."
            )
    return binary_prompt, code_prompt_dict


def generate_few_shot_samples():
    """Build the pool of few-shot examples from the annotated data.

    Only rows that pass ``skip_item``, on which both annotators gave exactly
    the same set of labels, and which carry at least one codebook code are
    used.

    Returns:
        dict: ``name -> {"Yes": [...], "No": [...]}`` for every code name plus
        the key ``"Binary"``. Each entry is
        ``{"id": thread_id, "content": [user_msg, assistant_msg]}``. Every
        list is shuffled once before returning.
    """
    code_names = [entry[0] for entry in code_book.values()]
    sample_dict = {name: {"Yes": [], "No": []} for name in code_names}
    sample_dict["Binary"] = {"Yes": [], "No": []}

    def _agreement_key(cell):
        # Normalised, order-independent view of a comma-separated label cell.
        return sorted(
            part.replace("\r", "").replace("\n", "") for part in cell.split(",")
        )

    for item in text_data:
        if skip_item(item):
            continue
        if _agreement_key(item[melody_labels_index]) != _agreement_key(
            item[neil_labels_index]
        ):
            continue

        # The annotators agree, so one annotator's labels stand for both.
        annotated_labels = [
            part.strip().replace("\r", "")
            for part in item[melody_labels_index].split(",")
        ]
        if not any(
            label in positive_code or label in negative_code
            for label in annotated_labels
        ):
            continue

        # Fix: build a filtered list instead of calling remove() on the list
        # being iterated, which skipped the element after each removal and
        # could leave unknown labels behind.
        human_label_list = [
            label for label in annotated_labels if label in code_book_based_on_code
        ]
        positive_count = 0
        negative_count = 0
        for label in human_label_list:
            if label in positive_code:
                positive_count += 1
            elif label in negative_code:
                negative_count += 1

        # Fix: use the resolved thread-id column rather than a hard-coded
        # position 3, so the ids match what insert_fiexd_few_shot compares
        # against.
        item_id = item[thread_id_index]

        # A binary example is produced only when one polarity dominates.
        if positive_count != negative_count:
            binary_text = generate_binary_text(item)
            binary_label = "Yes" if positive_count > negative_count else "No"
            # Fix: start from the bare label so the "plain" task with
            # explanations enabled no longer hits an unassigned binary_answer.
            binary_answer = binary_label
            if add_explanation:
                if task_name == "code-only":
                    binary_answer = "The Target Text can be described as:\n{}.\n There are {} positive code, and {} negative code. Therefore the answer is {}".format(
                        ",".join(human_label_list),
                        positive_count,
                        negative_count,
                        binary_label,
                    )
                elif task_name == "description-only":
                    binary_answer = "The Target Text can be described as at:\n{}.\n There are {} positive code, and {} negative code. Therefore the answer is {}".format(
                        " ".join(
                            code_book_based_on_code[label][2]
                            for label in human_label_list
                        ),
                        positive_count,
                        negative_count,
                        binary_label,
                    )
                elif task_name == "code-and-description":
                    binary_answer = "The Target Text can be described as:\n{}.\n There are {} positive code, and {} negative code. Therefore the answer is {}".format(
                        " ".join(
                            "{}: {}".format(label, code_book_based_on_code[label][2])
                            for label in human_label_list
                        ),
                        positive_count,
                        negative_count,
                        binary_label,
                    )
            sample_dict["Binary"][binary_label].append(
                {
                    "id": item_id,
                    "content": [
                        {"role": "user", "content": binary_text},
                        {"role": "assistant", "content": binary_answer},
                    ],
                }
            )

        # One yes/no example per code. The user text is the same for every
        # code, so it is built once per item.
        code_text = generate_code_text(item)
        # Fix: exact membership in the parsed labels instead of a substring
        # test on the raw cell, which also matched codes whose name is
        # contained in another label.
        assigned_labels = set(annotated_labels)
        for label in code_names:
            code_answer = "Yes" if label in assigned_labels else "No"
            sample_dict[label][code_answer].append(
                {
                    "id": item_id,
                    "content": [
                        {"role": "user", "content": code_text},
                        {"role": "assistant", "content": code_answer},
                    ],
                }
            )

    # Shuffle each pool once, after all items have been collected; this used
    # to run inside the item loop, re-shuffling every pool for every item.
    for pools in sample_dict.values():
        random.shuffle(pools["Yes"])
        random.shuffle(pools["No"])
    return sample_dict


def skip_item(item):
    """Tell whether a data row must be left out.

    A row is skipped when either annotator cell (the columns at the
    module-level ``melody_labels_index`` / ``neil_labels_index``) is a
    non-string value for which ``np.isnan`` is true, or when either cell is
    exactly ``"Not about religion discourse"``.

    Args:
        item: indexable row; only the two annotator label columns are read.

    Returns:
        A truthy value when the row should be skipped, a falsy one otherwise.
    """
    for label_index in (melody_labels_index, neil_labels_index):
        cell = item[label_index]
        if not isinstance(cell, str):
            # Non-string cells are only rejected when they are NaN.
            missing = np.isnan(cell)
            if missing:
                return missing

    off_topic = "Not about religion discourse"
    return (
        item[melody_labels_index] == off_topic
        or item[neil_labels_index] == off_topic
    )


def generate_binary_text(item):
    """Build the user message for the binary intellectual-humility question.

    The row is read through the module-level column indexes
    (``focal_post_index``, ``thread_title_index``, ``submission_text_index``,
    ``comment_1_index``, ``comment_2_index``).  When the focal post is
    ``"comment_2"`` both comments are quoted and the second one is judged;
    otherwise only the first comment is quoted and judged.

    Args:
        item: indexable data row.

    Returns:
        str: the prompt text, ending with an answer-format instruction chosen
        by the module-level ``add_explanation`` flag.
    """
    if item[focal_post_index] == "comment_2":
        # Inside this branch the focal post is always the second comment, so
        # it is referenced directly instead of re-testing the focal column.
        text = "Here is a discussion with the title: '{}', and the content is as follows: '{}'. The first comment is: '{}'. The second comment is: '{}'. \n Based on the content do you think Comment: '{}' is intellectual humility or not.".format(
            item[thread_title_index],
            item[submission_text_index],
            item[comment_1_index],
            item[comment_2_index],
            item[comment_2_index],
        )
    else:
        text = "Here is a discussion with the title: '{}', and the content is as follows: '{}'. The comment is: '{}'. \nBased on the content do you think Comment: '{}' is intellectual humility or not.".format(
            item[thread_title_index],
            item[submission_text_index],
            item[comment_1_index],
            item[comment_1_index],
        )

    if add_explanation:
        text += (
            "\nPlease explain how you get the answer first then responding Yes or No."
        )
    else:
        text += "\nPlease only answer Yes or No."
    return text


def generate_code_text(item):
    """Build the user message asking the model to label the Target Text.

    The row is read through the module-level column indexes
    (``focal_post_index``, ``thread_title_index``, ``submission_text_index``,
    ``comment_1_index``, ``comment_2_index``).  When the focal post is
    ``"comment_2"`` both comments are quoted and the second one is the Target
    Text; otherwise only the first comment is quoted and it is the Target
    Text.

    Args:
        item: indexable data row.

    Returns:
        str: the prompt text.
    """
    if item[focal_post_index] == "comment_2":
        # Inside this branch the focal post is always the second comment, so
        # it is referenced directly instead of re-testing the focal column.
        text = """Here is a discussion with the Title: "{}",\n\nand the content of the Submission: "{}".\n\nThe first Comment is: "{}".\n\nThe second comment is: "{}".\n\nPlease label the Target Text: "{}".\n\nPlease analyze the given Target Text with a focus on both the implicit indicators and the explicit content.""".format(
            item[thread_title_index],
            item[submission_text_index],
            item[comment_1_index],
            item[comment_2_index],
            item[comment_2_index],
        )
    else:
        text = """Here is a discussion with the Title: "{}",\n\nand the content of the Submission: "{}".\n\nThe first Comment is: "{}".\n\nPlease label the Target Text: "{}".\n\nPlease analyze the given Target Text with a focus on both the implicit indicators and the explicit content.""".format(
            item[thread_title_index],
            item[submission_text_index],
            item[comment_1_index],
            item[comment_1_index],
        )
    return text


def generate_feedback_system_prompt():
    """Build the system prompts used when asking the model for feedback.

    The wording depends on the module-level ``task_name`` ("plain",
    "code-only", "description-only" or "code-and-description") and
    ``choice_type``.  Code names and descriptions are taken from ``code_book``
    (entry index 0 is the code name, entry index 2 its description).

    Returns:
        tuple: ``(binary_prompt, code_prompt_dict)``.  ``binary_prompt`` is the
        system prompt for the binary task.  When ``choice_type == "binary"``,
        ``code_prompt_dict`` maps every code name to its own prompt; for any
        other ``choice_type`` it holds one prompt under the key ``"All"``.

    Raises:
        ValueError: if ``task_name`` is not one of the four supported values
            (previously this surfaced as an ``UnboundLocalError``).
    """
    supported_tasks = (
        "plain",
        "code-only",
        "description-only",
        "code-and-description",
    )
    if task_name not in supported_tasks:
        raise ValueError(
            "Unsupported task_name {!r}; expected one of: {}".format(
                task_name, ", ".join(supported_tasks)
            )
        )

    # --- system prompt for the binary task -------------------------------
    binary_prompt = "You are an AI model providing feedback on a intellectual humility prediction task."
    if task_name != "plain":
        if task_name == "code-only":
            negative_items, positive_items = negative_code, positive_code
        elif task_name == "description-only":
            negative_items, positive_items = (
                negative_description,
                positive_description,
            )
        else:
            negative_items, positive_items = (
                negative_description_with_code,
                positive_description_with_code,
            )
        binary_prompt += " The following code is used to describe not intellectual humility: \n{}. \n\n The following code is described as intellectual humility: \n{}.".format(
            "\n".join(negative_items), "\n".join(positive_items)
        )

    # --- system prompt(s) for the code task -------------------------------
    described_as = "You are an AI model providing feedback on a prediction about whether the Given Text can be described as {}."
    code_prompt_dict = {}

    if choice_type == "binary":
        # One prompt per code; the templates differ slightly between tasks
        # and are kept exactly as they were.
        for entry in code_book.values():
            if task_name == "plain":
                prompt = described_as.format(entry[0])
            elif task_name == "code-only":
                prompt = "You are an AI model providing feedback on a prediction about whether the Given text can be described as {}. ".format(
                    entry[0]
                )
            elif task_name == "description-only":
                prompt = described_as.format(entry[2])
            else:
                prompt = "You are an AI model providing feedback on a prediction about whether the Given text can be described as {}:{}.".format(
                    entry[0], entry[2]
                )
            code_prompt_dict[entry[0]] = prompt
        return binary_prompt, code_prompt_dict

    # Any other choice_type: a single prompt listing all labels at once.
    with_definitions = (
        described_as + '\n\nThe definition of each label is as follow:"{}".'
    )
    if task_name in ("plain", "code-only"):
        code_prompt = described_as.format(
            "\n".join(entry[0] for entry in code_book.values())
        )
    elif task_name == "code-and-description":
        code_prompt = with_definitions.format(
            "\n".join(entry[0] for entry in code_book.values()),
            "\n".join(
                "{} : {}".format(entry[0], entry[2]) for entry in code_book.values()
            ),
        )
    else:
        # description-only: labels are referred to by their code_book key.
        code_prompt = with_definitions.format(
            "\n".join("Label {}".format(key) for key in code_book),
            "\n".join(
                "Label {} : {}".format(key, entry[2])
                for key, entry in code_book.items()
            ),
        )
    code_prompt_dict["All"] = code_prompt
    return binary_prompt, code_prompt_dict


def generate_feedback_binary_text(item, res, reason):
    """Build the user message requesting feedback on a binary prediction.

    Args:
        item: indexable data row, read through the module-level column
            indexes (focal post, thread title, submission text, comments).
        res: the result the comment was given; inserted verbatim.
        reason: the reasons given for that result; inserted verbatim.

    Returns:
        str: the discussion context followed by the request for feedback that
        challenges the existing result.
    """
    second_is_focal = item[focal_post_index] == "comment_2"

    pieces = [
        "Here is a discussion with the title: '{}', and the content is as follows: '{}'. The first comment is: '{}'."
    ]
    values = [
        item[thread_title_index],
        item[submission_text_index],
        item[comment_1_index],
    ]
    if second_is_focal:
        # The second comment is quoted as context and is also the one judged.
        pieces.append(" The second comment is: '{}'.")
        values.append(item[comment_2_index])
        judged_comment = item[comment_2_index]
    else:
        judged_comment = item[comment_1_index]

    pieces.append(
        " \n The Comment: '{}' is reagarded as {} with the following reasons: {}. Please provide a brief feedback especially challenging the existing result"
    )
    values.extend([judged_comment, res, reason])
    return "".join(pieces).format(*values)


def generate_feedback_code_text(item, res, reason):
    """Build the user message requesting feedback on a code prediction.

    Args:
        item: indexable data row, read through the module-level column
            indexes (focal post, thread title, submission text, comments).
        res: the labels the comment was given; inserted verbatim.
        reason: the reasons given for those labels; inserted verbatim.

    Returns:
        str: the discussion context followed by the request for feedback that
        challenges the existing result.
    """
    second_is_focal = item[focal_post_index] == "comment_2"

    pieces = [
        'Here is a discussion with the Title: "{}",\n\nand the content of the Submission: "{}".\n\nThe first Comment is: "{}".\n\n'
    ]
    values = [
        item[thread_title_index],
        item[submission_text_index],
        item[comment_1_index],
    ]
    if second_is_focal:
        # The second comment is quoted as context and is also the one judged.
        pieces.append('The second comment is: "{}".\n\n')
        values.append(item[comment_2_index])
        judged_comment = item[comment_2_index]
    else:
        judged_comment = item[comment_1_index]

    pieces.append(
        "The Comment: '{}' is reagarded as {} with the following reasons: {}. Please provide a brief feedback especially challenging the existing result"
    )
    values.extend([judged_comment, res, reason])
    return "".join(pieces).format(*values)


def generate_refine_system_prompt():
    """Build the system prompts for re-labelling a text after feedback.

    The wording depends on the module-level ``task_name`` ("plain",
    "code-only", "description-only" or "code-and-description"),
    ``choice_type`` and ``add_explanation``.  Code names and descriptions are
    taken from ``code_book`` (entry index 0 is the code name, entry index 2
    its description).

    Returns:
        tuple: ``(binary_prompt, code_prompt_dict)``.  ``binary_prompt`` is the
        system prompt for the binary task.  When ``choice_type == "binary"``,
        ``code_prompt_dict`` maps every code name to its own Yes/No prompt;
        for any other ``choice_type`` it holds one multi-label prompt under
        the key ``"All"``.

    Raises:
        ValueError: if ``task_name`` is not one of the four supported values
            (previously this surfaced as an ``UnboundLocalError`` or
            ``KeyError``).
    """
    supported_tasks = (
        "plain",
        "code-only",
        "description-only",
        "code-and-description",
    )
    if task_name not in supported_tasks:
        raise ValueError(
            "Unsupported task_name {!r}; expected one of: {}".format(
                task_name, ", ".join(supported_tasks)
            )
        )

    # --- system prompt for the binary task -------------------------------
    binary_prompt = "You are an expert in the domain of intellectual humility. Your task is to label the given texts from Reddit about religion based on the provided feedback. Please label the given texts as intellectual humility or not intellectual humility with `No` or `Yes`"
    if task_name == "code-only":
        # The second list holds the *positive* codes, so matching it must
        # lead to `Yes`.  The prompt used to say "cannot be described", which
        # inverted the instruction and disagreed with the description-based
        # variants below.
        binary_prompt += '. If you think the given text can be described as at least of the following \n"{}",\n you should answer `No`. If the given text can be described as at least of the following \n"{}",\n you should answer `Yes`'.format(
            "\n".join(negative_code), "\n".join(positive_code)
        )
    elif task_name != "plain":
        if task_name == "description-only":
            negative_items, positive_items = (
                negative_description,
                positive_description,
            )
        else:
            negative_items, positive_items = (
                negative_description_with_code,
                positive_description_with_code,
            )
        binary_prompt += '. If you think the given text can be described as at least of the following:\n"{}",\nyou should answer `No`.\n If the given text can be described as at least of the following:\n"{}",\nyou should answer `Yes`'.format(
            "\n".join(negative_items), "\n".join(positive_items)
        )

    # --- system prompt(s) for the code task -------------------------------
    # Every code prompt starts with the same task description.
    task_intro = "Your task is to label the given text from Reddit about religion based on the provided feedback. The given text includes the Title, the content of the Submission, the content of the Comment, and the Target Text. "
    code_prompt_dict = {}

    if choice_type == "binary":
        if add_explanation:
            answer_format = "\n\nYou must explain how you get the answer first then responding Yes or No."
        else:
            answer_format = "\n\nPlease only answer Yes or No without exlanation."

        for entry in code_book.values():
            if task_name in ("plain", "code-only"):
                # This variant has always carried a trailing space.
                question = "If the Target Text can be described as `{}`, answer `Yes`. If it does not fit this description, answer `No`. ".format(
                    entry[0]
                )
            elif task_name == "description-only":
                question = "If the Target Text can be described as `{}`, answer `Yes`. If it does not fit this description, answer `No`.".format(
                    entry[2]
                )
            else:
                question = "If the Target Text can be described as `{} : {}`, answer `Yes`. If it does not fit this description, answer `No`.".format(
                    entry[0], entry[2]
                )
            code_prompt_dict[entry[0]] = task_intro + question + answer_format
        binary_prompt += answer_format
        return binary_prompt, code_prompt_dict

    # Any other choice_type: a single multi-label prompt.  The templates are
    # kept character for character (including their spacing differences).
    if task_name == "plain":
        code_prompt = task_intro + 'Please label the Target Text with one or more labels from the following list:"{}".\n\nEach sample might be labelled with multiple labels, please separate each label with \\n'.format(
            "\n".join(entry[0] for entry in code_book.values())
        )
    elif task_name == "code-only":
        code_prompt = task_intro + 'Please label the Target Text  with one or more labels from the following list:"{}".\n\n Each sample might be labelled with multiple labels, please separate each label with \\n'.format(
            "\n".join(entry[0] for entry in code_book.values())
        )
    elif task_name == "code-and-description":
        code_prompt = task_intro + 'Please label the Target Text  with one or more labels from the following list:"{}".\n\nThe definition of each label is as follow:"{}".\n\nEach sample might be labelled with multiple labels, please separate each label with \\n'.format(
            "\n".join(entry[0] for entry in code_book.values()),
            "\n".join(
                "{} : {}".format(entry[0], entry[2]) for entry in code_book.values()
            ),
        )
    else:
        # description-only: labels are referred to by their code_book key.
        code_prompt = task_intro + 'Please label the Target Text with one or more labels from the following list:"{}".\n\nThe definition of each label is as follow:"{}".\n\nEach sample might be labelled with multiple labels, please separate each label with \\n'.format(
            "\n".join("Label {}".format(key) for key in code_book),
            "\n".join(
                "Label {} : {}".format(key, entry[2])
                for key, entry in code_book.items()
            ),
        )

    if add_explanation:
        code_prompt += "\n\nExplain the reasons first and then respond with the labels. You should separate the explanation and the labels with a new line. The label part should only includes the name of the labels."
        binary_prompt += (
            "\n\nExplain the way you think about and then respond with Yes or No."
        )
    code_prompt_dict["All"] = code_prompt
    return binary_prompt, code_prompt_dict


def generate_refine_binary_text(item, feed_back):
    """Build the user message for re-answering the binary question.

    Args:
        item: indexable data row, read through the module-level column
            indexes (focal post, thread title, submission text, comments).
        feed_back: feedback text inserted verbatim into the question.

    Returns:
        str: the discussion context, the feedback-aware question and an
        answer-format instruction chosen by the module-level
        ``add_explanation`` flag.
    """
    second_is_focal = item[focal_post_index] == "comment_2"

    values = [
        item[thread_title_index],
        item[submission_text_index],
        item[comment_1_index],
    ]
    if second_is_focal:
        template = "Here is a discussion with the title: '{}', and the content is as follows: '{}'. The first comment is: '{}'. The second comment is: '{}'.\nBased on the feedback {} and the content do you think Comment: '{}' is intellectual humility or not."
        values += [item[comment_2_index], feed_back, item[comment_2_index]]
    else:
        template = "Here is a discussion with the title: '{}', and the content is as follows: '{}'. The comment is: '{}'. \nBased on the feedback {} and the content do you think Comment: '{}' is intellectual humility or not."
        values += [feed_back, item[comment_1_index]]

    answer_format = (
        "\nPlease explain how you get the answer first then responding Yes or No."
        if add_explanation
        else "\nPlease only answer Yes or No."
    )
    return template.format(*values) + answer_format


def generate_refine_code_text(item, feed_back):
    """Build the user message for re-labelling the Target Text.

    Args:
        item: indexable data row, read through the module-level column
            indexes (focal post, thread title, submission text, comments).
        feed_back: feedback text inserted verbatim into the request.

    Returns:
        str: the discussion context followed by the feedback-aware labelling
        request.
    """
    second_is_focal = item[focal_post_index] == "comment_2"

    pieces = [
        'Here is a discussion with the Title: "{}",\n\nand the content of the Submission: "{}".\n\nThe first Comment is: "{}".\n\n'
    ]
    values = [
        item[thread_title_index],
        item[submission_text_index],
        item[comment_1_index],
    ]
    if second_is_focal:
        # The second comment is quoted as context and is also the target.
        pieces.append('The second comment is: "{}".\n\n')
        values.append(item[comment_2_index])

    pieces.append(
        'Based on the feedback {}, please label the Target Text: "{}".\n\nPlease analyze the given Target Text with a focus on both the implicit indicators and the explicit content.'
    )
    values.append(feed_back)
    values.append(item[comment_2_index] if second_is_focal else item[comment_1_index])
    return "".join(pieces).format(*values)


def analysis():
    """Score the saved predictions against the human annotations.

    Reads the JSON-lines file at the module-level ``result_file`` (one record
    per non-empty line; a later record with the same ``"id"`` replaces an
    earlier one), builds parallel human/predicted 0-1 lists for the binary
    task and for every code in ``code_book``, computes the metrics with
    ``calc_binary_performance`` / ``calc_code_performance`` and writes them
    out with ``record_binary_performance`` / ``record_code_performance``.

    Each record is read through the keys ``ih_code_melody``, ``ih_code_neil``,
    ``ih_code_discussed``, ``ih_binary_LLM`` and ``ih_code_LLM``.

    Returns:
        None.  The results are written under ``result_folder`` by the two
        ``record_*`` helpers.
    """

    predicted_result = {}
    with open(result_file, "r") as f:
        for line in f:
            line = line.strip()
            if line:
                record = json.loads(line)
                predicted_result[record["id"]] = record

    def binary_from_labels(label_list):
        # +1 for every positive code, -1 for every other code; the text is a
        # positive example only when the positives strictly outnumber the rest.
        balance = 0
        for label in label_list:
            balance += 1 if label in positive_code else -1
        return 1 if balance > 0 else 0

    def empty_code_res():
        return {
            code_book[x][0]: {"predicted": [], "human": []} for x in code_book
        }

    melody_binary_list = {"human": [], "predicted": []}
    neil_binary_list = {"human": [], "predicted": []}
    both_binary_list = {"human": [], "predicted": []}

    melody_code_res = empty_code_res()
    neil_code_res = empty_code_res()
    both_code_res = empty_code_res()
    discussed_code_res = empty_code_res()

    code_names = [entry[0] for entry in code_book.values()]

    for item in predicted_result.values():
        melody_label_list = item["ih_code_melody"]
        neil_label_list = item["ih_code_neil"]
        discussed_label_list = item["ih_code_discussed"]

        melody_binary = binary_from_labels(melody_label_list)
        neil_binary = binary_from_labels(neil_label_list)
        predict_label = 1 if item["ih_binary_LLM"] == "Yes" else 0

        has_melody = melody_label_list != []
        has_neil = neil_label_list != []
        has_discussed = discussed_label_list != []

        # ---- binary task ---------------------------------------------------
        if has_melody:
            melody_binary_list["human"].append(melody_binary)
            melody_binary_list["predicted"].append(predict_label)
        if has_neil:
            neil_binary_list["human"].append(neil_binary)
            neil_binary_list["predicted"].append(predict_label)
        if melody_binary == neil_binary and has_neil and has_melody:
            both_binary_list["human"].append(melody_binary)
            both_binary_list["predicted"].append(predict_label)

        # ---- per-code task -------------------------------------------------
        # Decide once per record which reference annotations it contributes
        # to; each target pairs a result table with its human label list.
        targets = []
        if has_melody:
            targets.append((melody_code_res, melody_label_list))
        if has_neil:
            targets.append((neil_code_res, neil_label_list))
        if (
            has_melody
            and has_neil
            and sorted(melody_label_list) == sorted(neil_label_list)
        ):
            # "both" only counts records where the two annotators chose
            # exactly the same labels, so either list can serve as reference.
            targets.append((both_code_res, melody_label_list))
        if has_discussed:
            # Recorded independently of annotator agreement.  The agreement
            # check used to `continue` the per-code loop and thereby dropped
            # these entries whenever the annotators disagreed.
            targets.append((discussed_code_res, discussed_label_list))

        if not targets:
            continue

        for code in code_names:
            predicted_flag = 1 if code in item["ih_code_LLM"] else 0
            for code_res, human_labels in targets:
                code_res[code]["human"].append(1 if code in human_labels else 0)
                code_res[code]["predicted"].append(predicted_flag)

    neil_binary_performance, melody_binary_performance, both_binary_performance = (
        calc_binary_performance(melody_binary_list, neil_binary_list, both_binary_list)
    )

    code_performance = calc_code_performance(
        code_book,
        melody_code_res,
        neil_code_res,
        both_code_res,
        discussed_code_res,
    )

    record_binary_performance(
        task_name,
        model_name,
        result_folder,
        neil_binary_performance,
        melody_binary_performance,
        both_binary_performance,
    )

    record_code_performance(
        task_name, model_name, result_folder, code_book, code_performance
    )


def record_binary_performance(
    task_name,
    model_name,
    result_folder,
    neil_binary_performance,
    melody_binary_performance,
    both_binary_performance,
):
    """Write the binary-task metrics to ``record_binary.csv``.

    The file gets a header line (``human`` followed by the metric names) and
    one line each for ``neil``, ``melody`` and ``both``.  The metric names and
    their column order come from ``both_binary_performance``.

    Args:
        task_name: accepted but not used by this function.
        model_name: accepted but not used by this function.
        result_folder: directory in which the CSV file is created.
        neil_binary_performance: mapping metric name -> value.
        melody_binary_performance: mapping metric name -> value.
        both_binary_performance: mapping metric name -> value.
    """
    metrics_name_list = list(both_binary_performance)
    binary_record_file = os.path.join(
        result_folder, "record_{}.csv".format("binary")
    )
    per_annotator = (
        ("neil", neil_binary_performance),
        ("melody", melody_binary_performance),
        ("both", both_binary_performance),
    )
    with open(binary_record_file, "w") as f:
        f.write("human,{}\n".format(",".join(metrics_name_list)))
        for annotator, performance in per_annotator:
            cells = [str(performance[metric]) for metric in metrics_name_list]
            f.write("{},{}\n".format(annotator, ",".join(cells)))


def record_code_performance(
    task_name, model_name, result_folder, code_book, code_performance
):
    """Write the per-code metrics to ``record_code.csv``.

    The file gets a header line (``code``, ``human`` and the metric names)
    followed by one line per code and per reference annotation found in
    ``code_performance[code]``.

    The metric names and their column order are read from the ``"both"``
    entry of one reference code: the code stored under key ``0`` of
    ``code_book`` when that key exists, otherwise the first code in
    ``code_book``.  An empty ``code_book`` produces a file that only holds
    ``code,human``.

    Args:
        task_name: accepted but not used by this function.
        model_name: accepted but not used by this function.
        result_folder: directory in which the CSV file is created.
        code_book: mapping whose values carry the code name at index 0.
        code_performance: mapping code name -> reference annotation ->
            metric name -> value.
    """
    code_names = [entry[0] for entry in code_book.values()]

    metrics_name = []
    if code_names:
        # Key 0 used to be hard-coded here and raised KeyError for code books
        # keyed differently; it is still preferred when it exists.
        reference_code = code_book[0][0] if 0 in code_book else code_names[0]
        metrics_name = list(code_performance[reference_code]["both"].keys())

    code_record_file = os.path.join(
        result_folder,
        "record_{}.csv".format("code"),
    )
    with open(code_record_file, "w") as f:
        f.write(",".join(["code", "human"] + metrics_name) + "\n")
        for code in code_names:
            for human, scores in code_performance[code].items():
                cells = ",".join(str(scores[metric]) for metric in metrics_name)
                f.write("{},{},{}\n".format(code, human, cells))


def _score_against_annotator(code_res, baseline, tn, fp, fn, tp):
    """Assemble the metric dictionary for one code and one reference annotation."""
    human = code_res["human"]
    predicted = code_res["predicted"]
    return {
        "acc": accuracy_score(human, predicted),
        "f1": f1_score(human, predicted, average="macro", zero_division=0),
        "Baseline F1": baseline,
        "precision": precision_score(
            human, predicted, average="macro", zero_division=0
        ),
        "recall": recall_score(human, predicted, average="macro", zero_division=0),
        "true negative": tn,
        "false negative": fn,
        "true positive": tp,
        "false positive": fp,
    }


def calc_code_performance(
    code_book,
    melody_code_res,
    neil_code_res,
    both_code_res,
    discussed_coed_res,
):
    """Compute the per-code metrics against each reference annotation.

    Args:
        code_book: mapping whose values carry the code name at index 0.
        melody_code_res: mapping code name -> ``{"human": [...],
            "predicted": [...]}`` for the first annotator.
        neil_code_res: same structure for the second annotator.
        both_code_res: same structure for the agreed annotations.
        discussed_coed_res: same structure for the discussed annotations.

    Returns:
        dict: code name -> ``{"neil": ..., "melody": ..., "both": ...}``, where
        each inner value holds accuracy, macro F1 / precision / recall, the
        value returned by ``get_distribution_baseline`` (``"Baseline F1"``)
        and the four counts returned by ``get_confusion_matrix``.
    """
    code_performance = {}
    for entry in code_book.values():
        code = entry[0]

        neil_counts = neil_tn, neil_fp, neil_fn, neil_tp = get_confusion_matrix(
            neil_code_res[code]
        )
        melody_counts = melody_tn, melody_fp, melody_fn, melody_tp = (
            get_confusion_matrix(melody_code_res[code])
        )
        both_counts = both_tn, both_fp, both_fn, both_tp = get_confusion_matrix(
            both_code_res[code]
        )
        # The "discussed" figures are still computed, but they are not part
        # of the returned report.
        discussed_tn, discussed_fp, discussed_fn, discussed_tp = get_confusion_matrix(
            discussed_coed_res[code]
        )

        melody_baseline = get_distribution_baseline(melody_code_res[code])
        neil_baseline = get_distribution_baseline(neil_code_res[code])
        both_baseline = get_distribution_baseline(both_code_res[code])
        discussed_baseline = get_distribution_baseline(discussed_coed_res[code])

        code_performance[code] = {
            "neil": _score_against_annotator(
                neil_code_res[code], neil_baseline, *neil_counts
            ),
            "melody": _score_against_annotator(
                melody_code_res[code], melody_baseline, *melody_counts
            ),
            "both": _score_against_annotator(
                both_code_res[code], both_baseline, *both_counts
            ),
        }

    return code_performance


def _binary_metrics(binary_list):
    """Summarise one set of binary predictions against its human labels.

    Args:
        binary_list: mapping with parallel "human" (gold) and "predicted"
            label sequences.

    Returns:
        dict with "acc", macro-averaged "f1" / "precision" / "recall", and
        "confusion_matrix" (the flattened matrix rendered as "a, b, c, d").
    """
    human = binary_list["human"]
    predicted = binary_list["predicted"]
    # zero_division=0 yields the same value as sklearn's default ("warn")
    # but without emitting a warning when a class is never predicted.
    macro = {"average": "macro", "zero_division": 0}
    flat_counts = confusion_matrix(human, predicted).reshape(-1).tolist()
    return {
        "acc": accuracy_score(human, predicted),
        "f1": f1_score(human, predicted, **macro),
        "precision": precision_score(human, predicted, **macro),
        "recall": recall_score(human, predicted, **macro),
        "confusion_matrix": ", ".join(str(count) for count in flat_counts),
    }


def calc_binary_performance(melody_binary_list, neil_binary_list, both_binary_list):
    """Compute binary-classification metrics for each annotator view.

    Args:
        melody_binary_list: mapping with "human" and "predicted" sequences.
        neil_binary_list: mapping with "human" and "predicted" sequences.
        both_binary_list: mapping with "human" and "predicted" sequences.

    Returns:
        Tuple ``(neil, melody, both)`` of metric dicts. Note the order: neil's
        metrics come first even though melody's input is the first parameter.
    """
    neil_binary_performance = _binary_metrics(neil_binary_list)
    melody_binary_performance = _binary_metrics(melody_binary_list)
    both_binary_performance = _binary_metrics(both_binary_list)

    return neil_binary_performance, melody_binary_performance, both_binary_performance


def get_confusion_matrix(res):
    """Return ``(tn, fp, fn, tp)`` for binary 0/1 labels.

    Args:
        res: mapping with parallel "human" (gold) and "predicted" sequences.
            Labels are assumed to be 0/1 (the sibling baseline helper sums
            ``res["human"]`` as a positive count) -- TODO confirm.

    Returns:
        Tuple ``(tn, fp, fn, tp)``.
    """
    # Pin the label set so the matrix is always 2x2. Without it, data that
    # contains a single class produces a 1x1 matrix; the previous code then
    # reported all zeros, silently dropping every correct prediction.
    cm = confusion_matrix(res["human"], res["predicted"], labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    return tn, fp, fn, tp


def get_distribution_baseline(res):
    """Macro-F1 of a random guesser that follows the gold label distribution.

    Predictions are drawn from {0, 1} with probabilities equal to the observed
    share of negatives / positives in ``res["human"]``, using NumPy's global
    random state.

    Args:
        res: mapping whose "human" entry is a sequence of 0/1 gold labels.

    Returns:
        The macro-averaged F1 of the random predictions, or 0.0 when there are
        no labels at all.
    """
    human = res["human"]
    total = len(human)
    if total == 0:
        # No labels: the class proportions are undefined (division by zero).
        return 0.0

    positive_count = sum(human)
    negative_count = total - positive_count
    distribution_predict = np.random.choice(
        [0, 1],
        total,
        p=[negative_count / total, positive_count / total],
    )
    baseline_f1 = f1_score(
        human, distribution_predict, average="macro", zero_division=0
    )
    return baseline_f1


def get_optimized_prompt(few_shot_dict, code_name, init_prompt):
    """Iteratively refine ``init_prompt`` for ``code_name`` and return the best prompt.

    Up to five "Yes" and five "No" few-shot examples are scored with the
    initial prompt. Then ``optimization_step`` rounds are run: every surviving
    prompt is expanded via ``get_optimized_prompt_one_step``, the candidates
    are ranked by accuracy, and the top ``best_prompt_choice`` survive and are
    appended to ``prompt_file`` as JSON lines.

    Each prompt record is a list
    ``[examples, code_name, prompt, history, accuracy]``; records produced by
    a refinement round additionally end with an id string ``"<step>-<parent>"``.

    Args:
        few_shot_dict: mapping ``code_name -> {"Yes": [...], "No": [...]}`` of
            examples; each example exposes ``["content"][0]`` (input message)
            and ``["content"][1]["content"]`` (gold label text).
        code_name: the code whose prompt is being optimised.
        init_prompt: the starting prompt.

    Returns:
        The best prompt found, or ``init_prompt`` when no candidate matches or
        beats its accuracy (or when there are no examples to score against).
    """
    curr_prompt = init_prompt
    # Shallow-copy each example: a "prediction" key is attached below and must
    # not leak back into the caller's few_shot_dict.
    sampled_data = [copy.copy(item) for item in few_shot_dict[code_name]["Yes"][:5]]
    sampled_data.extend(
        copy.copy(item) for item in few_shot_dict[code_name]["No"][:5]
    )
    if not sampled_data:
        # Nothing to score against, so no candidate can be ranked.
        return init_prompt
    random.shuffle(sampled_data)

    acc = 0
    for item in sampled_data:
        predicted_label, res = test_item_perfromance(curr_prompt, item, code_name)
        item["prediction"] = res
        label = item["content"][1]["content"]
        if predicted_label == label or predicted_label in label.split(" ")[-5:]:
            acc += 1
    acc /= len(sampled_data)
    existing_prompt_list = [
        [
            sampled_data,
            code_name,
            curr_prompt,
            ["The initial accuracy is {:.2f}".format(acc)],
            acc,
        ]
    ]

    with tqdm.tqdm(
        total=(optimization_step - 1) * best_prompt_choice + 1,
        desc="Optimizing {} Prompt".format(code_name),
        leave=False,
    ) as pbar:
        for step in range(optimization_step):
            optimization_list = []
            for i, prompt_info in enumerate(existing_prompt_list):
                pbar.set_description(
                    "{}: Current {:.2f} and Init: {:.2f}".format(
                        code_name, existing_prompt_list[0][4], acc
                    )
                )
                pbar.refresh()
                single_optimized_list = get_optimized_prompt_one_step(
                    prompt_info[0],
                    prompt_info[1],
                    prompt_info[2],
                    prompt_info[3],
                    step,
                )
                for it in single_optimized_list:
                    it.append("{}-{}".format(step, i))
                    optimization_list.append(it)
                pbar.update(1)
            if not optimization_list:
                # No candidates were generated; keep the current survivors.
                break
            optimization_list.sort(key=lambda x: x[4], reverse=True)
            existing_prompt_list = optimization_list[:best_prompt_choice]
            with open(prompt_file, mode="a") as f:
                for it in existing_prompt_list:
                    it_dict = {
                        "code": code_name,
                        "id": it[-1],
                        "prompt": it[2],
                        "accuracy": it[4],
                    }
                    f.write(json.dumps(it_dict, ensure_ascii=False) + "\n")
            # Stop early once a candidate is perfect on the sample. The
            # accuracy lives at index 4; the previous check looked at index -1
            # (the id string), so it could never be true. The step is logged
            # above first so the winning prompt is still recorded.
            if optimization_list[0][4] == 1:
                return optimization_list[0][2]

    # NOTE(review): the original code followed each branch below with a loop
    # over `existing_prompt_list[:0]` that was meant to append a final record
    # to prompt_file; the empty slice meant nothing was ever written. The dead
    # loops are removed here; write the record explicitly if one is wanted.
    if existing_prompt_list[0][4] < acc:
        return init_prompt
    return existing_prompt_list[0][2]


def get_optimized_prompt_one_step(
    sampled_data, code_name, init_prompt, history_list, step=0
):
    """Generate and score ``generated_prompt_number`` refinements of a prompt.

    The module-level ``client`` is first asked (temperature 0) to analyse how
    the current prompt performed on the examples. From that shared
    conversation each candidate is produced independently (temperature 0.7):
    a rewritten prompt plus a short summary of the change. Every candidate is
    then re-scored on the same examples.

    Args:
        sampled_data: examples, each exposing ``["content"][0]`` (input
            message), ``["content"][1]["content"]`` (gold label text) and
            ``["prediction"]`` (model output under the current prompt).
        code_name: the code being optimised; passed through to the scorer.
        init_prompt: the prompt to refine.
        history_list: list of strings summarising earlier refinements.
        step: index of the current optimisation round, quoted in the summary
            request.

    Returns:
        One record per candidate:
        ``[scored_examples, code_name, new_prompt, history, accuracy]`` where
        ``history`` is a copy of ``history_list`` extended with this
        candidate's summary. Empty when ``sampled_data`` is empty.
    """
    if not sampled_data:
        # Nothing to analyse and no accuracy to compute without examples.
        return []

    examples = []
    for i, item in enumerate(sampled_data):
        examples.append(
            """### Example {}
Input:{}
Output:{}
Label:{}
""".format(
                i,
                item["content"][0]["content"],
                item["prediction"],
                item["content"][1]["content"],
            )
        )
    messages = []
    curr_prompt = init_prompt
    system_message = "You are a helpful assistant."
    messages.append({"role": "system", "content": system_message})
    messages.append(
        {
            "role": "user",
            "content": """A prompt is a text paragraph that outlines the expected actions and instructs the model to generate a specific output. This prompt is concatenated with the input text, and the model then creates the required output.
                     
In our collaboration, we'll work together to refine a prompt. The process consists of two main steps:
                     
## Step 1
I will provide you with the current prompt, how the prompt is concatenated with the input text (i.e., "full template"), along with {} example(s) that are associated with this prompt. Each examples contains the input, the reasoning process generated by the model when the prompt is attached, the final answer produced by the model, and the ground-truth label to the input. Your task is to analyze the examples, determining whether the existing prompt is decsribing the task reflected by these examples precisely, and suggest changes to the prompt.
                     
## Step 2
Next, you will carefully review your reasoning in step 1, integrate the insights to craft a new, optimized prompt. Optionally, the history of refinements made to this prompt from past sessions will be included. Some extra instructions (e.g., the number of words you can edit) will be provided too.""".format(
                # Report the number of examples actually supplied below (the
                # previous code quoted the unrelated global sample_number).
                len(examples)
            ),
        }
    )
    messages.append(
        {
            "role": "assistant",
            "content": """Sure, I'd be happy to help you with this prompt engineering problem. 
Please provide me with the prompt engineering history, the current prompt, and the examples you have.""",
        }
    )
    messages.append(
        {
            "role": "user",
            "content": """## Prompt
{}

## Examples
{}

## Instructions
For some of these examples, the output does not match with the label. This may be due to the prompt being misleading or not describing the task precisely.

Please examine the examples carefully. Note that the ground-truth labels are __absolutely correct__, but the prompts (task descriptions) may be incorrect and need modification. For each example, provide reasoning according to the following template:

### Example <id>
Input: <input>
Output: <output>
Label: <label>
Is the output correct compared to the label: <yes or no, and your reasoning>
Is the output correctly following the given prompt: <yes or no, and your reasoning>
Is the prompt correctly describing the task shown by the input-label pair: <yes or no, and your reasoning>
To output the correct label, is it necessary to edit the prompt: <yes or no, and your reasoning>
If yes, provide detailed analysis and actionable suggestions to edit the prompt: <analysis and suggestions>

You must analyze all exaples provided.
""".format(
                curr_prompt, "\n\n".join(examples)
            ),
        }
    )
    # Step 1: deterministic analysis shared by every candidate.
    response = client.chat.completions.create(
        model=model_name, messages=messages, temperature=0
    )
    reasons = response.choices[0].message.content
    messages.append({"role": "assistant", "content": reasons})
    messages.append(
        {
            "role": "user",
            "content": """
Now please carefully review your reasoning in Step 1 and help with Step 2: refining the prompt. 
## Current Prompt
{}

## Prompt Refinement History from the Past
Note that higher accuracy means better. If some edits are useful in the past, it may be a good idea to make edits along the same direction.
{}

## Instructions
* Please help edit the prompt so that the updated prompt will not fail on these examples anymore.
* Reply with the prompt. Do not include other text.
""".format(
                curr_prompt, "\n\n".join(history_list)
            ),
        }
    )
    optimization_record = []
    for _ in tqdm.trange(
        generated_prompt_number, desc="Generating Prompts", leave=False
    ):
        # Each candidate continues from its own copy of the shared conversation.
        conversation = copy.deepcopy(messages)
        response = client.chat.completions.create(
            model=model_name, messages=conversation, temperature=0.7
        )
        new_prompt = response.choices[0].message.content
        conversation.append({"role": "assistant", "content": new_prompt})
        conversation.append(
            {
                "role": "user",
                "content": """
    Now please summarize what changes you've made to the prompt, in the following format. Make sure the summariy is concise and contains no more than 200 words.

    " * At step {}, the prompt has limitations such as <summary of limitations>. Changes to the prompt include <summary of changes>."

    Reply with the summarization. Do not include other text.
                """.format(
                    step
                ),
            }
        )
        response = client.chat.completions.create(
            model=model_name, messages=conversation, temperature=0.7
        )
        history = response.choices[0].message.content

        acc = 0
        new_sampled_data = []
        for item in sampled_data:
            # Score the candidate itself; the previous code re-scored
            # curr_prompt, so every candidate was ranked on the old prompt.
            predicted_label, res = test_item_perfromance(new_prompt, item, code_name)
            # Copy before annotating so candidates do not overwrite each
            # other's (or the parent's) "prediction" on shared example dicts.
            scored_item = copy.copy(item)
            scored_item["prediction"] = res
            new_sampled_data.append(scored_item)
            label = item["content"][1]["content"]
            if predicted_label == label or predicted_label in label.split(" ")[-5:]:
                acc += 1
        acc /= len(sampled_data)

        candidate_history = copy.deepcopy(history_list)
        candidate_history.append(
            "{}. After this change the accuracy is {:.2f}".format(history, acc)
        )
        optimization_record.append(
            [new_sampled_data, code_name, new_prompt, candidate_history, acc]
        )
    return optimization_record


def test_item_perfromance(code_prompt, item_prompt, code_name):
    """Run one example through the configured model and parse its label.

    The backend is picked from the module-level ``model_name``; calls that
    raise an API error or ``ValueError`` are retried with exponential back-off
    (1s, 2s, ... 512s, plus up to one second of jitter).

    Args:
        code_prompt: prompt text sent as the system message.
        item_prompt: example whose ``["content"][0]`` is the input message.
        code_name: unused here; kept for interface compatibility.

    Returns:
        ``(predicted_label, res)`` where ``res`` is the stripped raw model
        output. In "binary" mode the label is "Yes" or "No"; otherwise it is
        the ";"-joined, sorted names of every code in ``code_book`` that the
        output mentions by name or as "Label <index>".

    Raises:
        ValueError: ``model_name`` matches none of the supported backends.
        RuntimeError: no answer was obtained after all retries.
    """
    messages = [{"role": "system", "content": code_prompt}]
    messages.append(item_prompt["content"][0])

    if "gpt" in model_name:
        get_answer = get_GPT_answer
    elif "claude" in model_name:
        get_answer = get_calude_answer
    elif "gemini" in model_name:
        get_answer = get_gemini_answer
    elif "Llama" in model_name or "llama" in model_name:
        get_answer = get_LLAMA2_answer
    else:
        # Previously this fell through with res=None and crashed on .strip().
        raise ValueError("Unsupported model name: {}".format(model_name))

    res = None
    last_error = None
    for delay_secs in (2**x for x in range(0, 10)):
        try:
            res = get_answer(messages)
            break
        except (openai.OpenAIError, anthropic.APIError, ValueError) as e:
            last_error = e
            # Jitter keeps parallel runs from retrying in lock-step.
            randomness_collision_avoidance = random.randint(0, 1000) / 1000
            time.sleep(delay_secs + randomness_collision_avoidance)
    if res is None:
        # Surface the real cause instead of an AttributeError on None.strip().
        raise RuntimeError(
            "No answer obtained from {} after retries".format(model_name)
        ) from last_error

    res = res.strip()
    if choice_type == "binary":
        if "Yes" in res.split("\n\n")[-1] or res.startswith("Yes"):
            predicted_label = "Yes"
        else:
            predicted_label = "No"
    else:
        code_result = [
            names[0]
            for code_index, names in code_book.items()
            if names[0] in res or "Label {}".format(code_index) in res
        ]
        predicted_label = ";".join(sorted(code_result))
    return predicted_label, res


# Run the experiment and then the analysis only when executed as a script, so
# importing this module (or a spawned worker re-importing it) does not rerun
# the whole pipeline.
if __name__ == "__main__":
    main()
    analysis()
