import json
import os
import re
from concurrent.futures import ThreadPoolExecutor, as_completed

import yaml
from openai import AzureOpenAI, OpenAI
from tenacity import retry, stop_after_attempt, wait_random_exponential

# Load the judge configuration from the YAML file in the current working directory.
# Later code reads its 'api', 'files', 'paths' and 'thread_pool' sections.
# Decoded as UTF-8 explicitly so non-ASCII values read the same on every platform
# instead of depending on the locale's default encoding.
with open('llm_judge_config.yaml', 'r', encoding='utf-8') as config_file:
    config = yaml.safe_load(config_file)

@retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(6))
def get_res(string):
    """Send ``string`` as a single user message to the configured chat model.

    The backend is chosen by ``config['api']['type']``: ``'azure'`` builds an
    ``AzureOpenAI`` client from the key/version/endpoint in the config; any
    other value builds a default ``OpenAI()`` client. In both cases the model
    name comes from ``config['api']['model']``.

    The tenacity decorator retries failures (up to 6 attempts, randomized
    exponential backoff capped at 10 seconds); once all attempts fail,
    ``tenacity.RetryError`` is raised wrapping the last error.

    Args:
        string: The full prompt text to send.

    Returns:
        The content of the first choice of the chat completion.
    """
    api_cfg = config['api']
    if api_cfg['type'] == 'azure':
        client = AzureOpenAI(
            api_key=api_cfg['key'],
            api_version=api_cfg['version'],
            azure_endpoint=api_cfg['endpoint'],
        )
    else:
        # Constructed without arguments, so it relies on the SDK's own defaults
        # (e.g. credentials from environment variables).
        client = OpenAI()
    # One shared request/return path for both backends. Previously the non-Azure
    # branch stored the result under a different name, so returning
    # `chat_completion` raised UnboundLocalError and every call was retried to failure.
    chat_completion = client.chat.completions.create(
        model=api_cfg['model'],
        messages=[{"role": "user", "content": string}],
    )
    return chat_completion.choices[0].message.content


# Matches the three template placeholders so they can be filled in a single pass.
_PLACEHOLDER_RE = re.compile(r"\[\[(Q|A|RES)\]\]")


def _fill_prompt(prompt_template, question, answer, response):
    """Return ``prompt_template`` with ``[[Q]]``, ``[[A]]`` and ``[[RES]]`` filled in.

    Substitution scans the template once, so placeholder-like text that appears
    inside a substituted value (e.g. a question containing "[[A]]") is inserted
    literally rather than being replaced again. Values are passed through
    ``str()`` so non-string fields (e.g. numeric answers) are accepted.
    """
    values = {'Q': str(question), 'A': str(answer), 'RES': str(response)}
    # A callable replacement also keeps backslashes in the values literal
    # (a plain replacement string would interpret them as escapes).
    return _PLACEHOLDER_RE.sub(lambda match: values[match.group(1)], prompt_template)


def process_item(entry, prompt_template, judge=None):
    """Judge one entry's English and Chinese responses and record the verdicts.

    Args:
        entry: Dict with keys ``en_question``, ``en_answer``, ``en_res``,
            ``ch_question``, ``ch_answer`` and ``ch_res``. It is mutated in place.
        prompt_template: Prompt text containing ``[[Q]]``, ``[[A]]`` and
            ``[[RES]]`` placeholders for question, reference answer and response.
        judge: Callable taking a prompt string and returning the judge's reply.
            Defaults to ``get_res``.

    Returns:
        The same ``entry`` object, with ``evaluation_en`` and ``evaluation_ch``
        set to the judge's replies.

    Raises:
        KeyError: If any of the six required keys is missing (checked before
            any judge call is made).
    """
    if judge is None:
        judge = get_res
    # Build both prompts before calling the judge, so a malformed entry fails
    # without spending any API calls.
    evaluated_prompt_en = _fill_prompt(
        prompt_template, entry['en_question'], entry['en_answer'], entry['en_res'])
    evaluated_prompt_ch = _fill_prompt(
        prompt_template, entry['ch_question'], entry['ch_answer'], entry['ch_res'])
    en_res = judge(evaluated_prompt_en)
    ch_res = judge(evaluated_prompt_ch)
    # Only record results once both calls succeeded, so the entry is never half-updated.
    entry['evaluation_en'] = en_res
    entry['evaluation_ch'] = ch_res
    return entry


if __name__ == "__main__":
    file_list = config['files']
    path = config['paths']['results']
    evaluation_path = config['paths']['evaluation']
    model = config['paths']['model']
    # NOTE(review): the literals for "...\"wrong\"." and "Here is the question..."
    # join with no space between them. Left byte-identical so judge prompts
    # match earlier runs; fix deliberately if re-baselining.
    prompt_template = (
        "As a helpful assistant, your task is now to help me assess the correctness of provided answers. I will present a question along with its correct answer. "
        "Subsequently, I will also provide you with the answer you need to evaluate. "
        "If the answer to be evaluated correctly expresses the same meaning as the correct answer or contains the "
        "correct answer, then it is right. Ignore case errors. Although there are some errors in certain explanations within the answer, as long as the core answer is correct, the response is considered correct."
        " Return me only one word: \"correct\" or \"wrong\"."
        "Here is the question and its correct answer: \nQuestion: [[Q]]\nAnswer: [[A]]\n"
        "Here is the answer you should evaluate: \n[[RES]]"
    )

    output_dir = os.path.join(evaluation_path, model)
    # open(..., 'w') does not create parent directories.
    os.makedirs(output_dir, exist_ok=True)

    for filename in file_list:
        input_file = os.path.join(path, model, filename + '_' + model + '.json')
        output_file = os.path.join(output_dir, filename + '_' + model + '_evaluated.json')

        # Explicit UTF-8: the data holds Chinese text and the output is written
        # with ensure_ascii=False, which fails under non-UTF-8 locale defaults.
        with open(input_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Only judge entries that have a non-empty response in both languages;
        # entries missing either key are skipped rather than raising KeyError.
        data = [el for el in data if el.get('ch_res') and el.get('en_res')]

        # The input file is already closed here, so it is not held open during the API calls.
        with ThreadPoolExecutor(max_workers=config['thread_pool']['max_workers']) as executor:
            save_data = list(executor.map(lambda item: process_item(item, prompt_template), data))

        with open(output_file, 'w', encoding='utf-8') as f2:
            json.dump(save_data, f2, indent=4, ensure_ascii=False)
