
import os

from LLM import CodexModel
llm = CodexModel('chatgpt')
import json
import pandas as pd
import time
from tqdm import tqdm

def forward(system_message, user_message, name):
    """Return the LLM reply stored under ``name``, querying the model if absent.

    The module-level ``llm_output`` dict acts as a per-sample cache: when
    ``name`` is already a key, the stored value is returned and the model is
    not called. Otherwise a system + user chat request is sent through
    ``llm.forward``, the reply is stored under ``name`` and returned.

    Args:
        system_message: Content of the system-role message.
        user_message: Content of the user-role message.
        name: Cache key inside ``llm_output`` for this reply.

    Returns:
        The cached or freshly obtained reply.
    """
    global llm_output
    if name not in llm_output:
        chat = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": user_message},
        ]
        reply = llm.forward(extended_prompt={'messages': chat})
        # Pause after every real model call (cached hits skip this).
        time.sleep(5)
        llm_output[name] = reply
    return llm_output[name]

def replace_prompt(prompt, old, new):
    """Fill a prompt template by substituting placeholders one after another.

    ``old`` and ``new`` are walked in lockstep; each placeholder in ``old``
    is replaced (every occurrence) by the value at the same position in
    ``new``. Substitutions are applied sequentially, so text inserted by an
    earlier pair is visible to later pairs. Extra items in the longer of the
    two sequences are ignored.

    Args:
        prompt: Template text.
        old: Placeholder strings to look for.
        new: Replacement strings, aligned with ``old``.

    Returns:
        The template with all substitutions applied.
    """
    filled = prompt
    for placeholder, value in zip(old, new):
        filled = filled.replace(placeholder, value)
    return filled

def check(gt, response):
    """Execute answer-extraction code and score it against the ground truth.

    ``response`` is treated as Python source that assigns a variable named
    ``answer``. The source is compiled and executed, and the resulting
    ``answer`` is compared with ``gt`` using ``==``.

    Args:
        gt: Ground-truth value.
        response: Python source text to execute.

    Returns:
        1 if the code runs and its ``answer`` equals ``gt``; 0 otherwise
        (code that does not compile, raises, exits, or never defines
        ``answer`` all score 0).
    """
    # SECURITY: this runs model-generated text with full interpreter
    # privileges. Only use it where executing untrusted code is acceptable.
    #
    # One dict is used as both globals and locals. With two separate dicts
    # exec() behaves as if the code were a class body, so functions defined by
    # the response cannot see its other top-level names (helpers, imports)
    # and otherwise-correct code fails with NameError.
    namespace = {}
    try:
        exec(compile(response, '', 'exec'), namespace)
        return 1 if gt == namespace['answer'] else 0
    except (Exception, SystemExit):
        # Best-effort scoring: any failure counts as a wrong answer.
        # SystemExit is included so generated code calling exit() cannot stop
        # the evaluation; KeyboardInterrupt is deliberately left to propagate.
        return 0

# Evaluation driver.
#
# For every row of the test set: build the two-information chain-of-thought
# prompt, query the model, ask the model to turn its reply into Python code
# that assigns `answer`, score that code with check(), and keep a running
# accuracy per row type. All model replies for a row are cached in one JSON
# file so a rerun reuses them instead of calling the model again.
#
# Columns read from each row: id, type, old_length, new_length, answer,
# old_answer, information1, information2 and the column named by
# `question_str`.
data = pd.read_csv('dataset/testset_ori_v5_reverse_chatgpt.csv')

print(data.columns)

length = len(data)
system_prompt = 'You are a helpful assistant.'
# Single-information prompt; only needed by the disabled "without" arm below.
user_prompt1 = 'Information:\n#INFORMATION#\n\nQuestion:\n#QUESTION#\nLet\'s think step by step.'
# Two-information prompt used by the active "with" arm.
user_prompt2 = 'Information 1:\n#INFORMATION1#\n\nInformation 2:\n#INFORMATION2#\n\nQuestion:\n#QUESTION#\nLet\'s think step by step.'
# Follow-up prompt asking the model to express its answer as Python code.
user_prompt3 = 'Please extract the answer corresponding to the question from the response and assign it to the python variable answer. For example, answer=3\n\nQuestion:\n#QUESTION#\nResponse:\n#RESPONSE#\n'

question_str = 'question_date'
version = 'result/COTreverse_0322_chatgpt/'
# Make sure the cache directory exists before any model call is made;
# otherwise the first write below fails only after the replies were fetched.
os.makedirs(version, exist_ok=True)

# Per-type tallies. Slot 0 = 'type1', slot 1 = 'type2', slot 2 = any other type.
type_slots = {'type1': 0, 'type2': 1}
count = [0, 0, 0]
without_correct = [0, 0, 0]
with_correct = [0, 0, 0]
without_acc = [0, 0, 0]
with_acc = [0, 0, 0]

# iterrows() has no length, so pass the row count to get a real progress bar.
data_iter = tqdm(data.iterrows(), total=length)
for index, row in data_iter:
    llm_output_path = '{}/{}_{}_{}_{}.json'.format(version, row['id'], row['type'], row['old_length'], row['new_length'])
    # `llm_output` is the module-level cache that forward() reads and fills.
    if os.path.exists(llm_output_path):
        with open(llm_output_path, 'r') as f:
            llm_output = json.load(f)
    else:
        llm_output = {}

    llm_output['answer'] = row['answer']
    llm_output['old_answer'] = row['old_answer']
    # 'type3' rows are scored against old_answer, every other type against answer.
    answer = row['answer'] if row['type'] != 'type3' else row['old_answer']
    i1 = row['information1'].strip()
    i2 = row['information2'].strip()

    # NOTE: the "without" arm (information2 only, built from user_prompt1) is
    # currently disabled, so without_correct is never incremented and the
    # "without" percentages in the progress description stay at 0.
    current_user2 = replace_prompt(user_prompt2, ['#INFORMATION1#', '#INFORMATION2#', '#QUESTION#'], [i1, i2, row[question_str]])
    llm_output['with_prompt'] = current_user2
    response2 = forward(system_prompt, current_user2, 'with_internal')

    extract_prompt2 = replace_prompt(user_prompt3, ['#QUESTION#', '#RESPONSE#'], [row[question_str], response2])
    # The spelling of this key is part of the cache lookup; changing it would
    # miss replies already stored under it.
    extract2 = forward(system_prompt, extract_prompt2, 'with_interanl_final')
    check2 = check(answer, extract2)

    slot = type_slots.get(row['type'], 2)
    count[slot] += 1
    with_correct[slot] += check2
    without_acc[slot] = without_correct[slot] / count[slot]
    with_acc[slot] = with_correct[slot] / count[slot]

    data_iter.set_description('without: type1: {:.2f}%, type2:{:.2f}%, type3:{:.2f}%. with: type1: {:.2f}%, type2:{:.2f}%, type3:{:.2f}%'.format(without_acc[0]*100, without_acc[1]*100, without_acc[2]*100, with_acc[0]*100, with_acc[1]*100, with_acc[2]*100))

    # llm_output always holds at least the two answer fields at this point,
    # so the cache file is written for every row.
    with open(llm_output_path, 'w') as f:
        json.dump(llm_output, f)