import os

from LLM import CodexModel
# Shared LLM client used by forward(). The 'chatgpt' argument presumably
# selects the backing model — TODO confirm against LLM.CodexModel.
llm = CodexModel('chatgpt')
import json
import pandas as pd
import time
from tqdm import tqdm
import re
import traceback

def forward(system_message, user_message, name):
    """Return the model's reply to one (system, user) prompt pair.

    Replies are memoised in the module-level ``llm_output`` dict under the key
    *name*. A cached entry is returned as-is without contacting the model.
    Otherwise the model is called through ``llm.forward`` and the script sleeps
    5 seconds (presumably to respect API rate limits — confirm). The reply is
    then stored under *name* and returned.
    """
    # ``llm_output`` is only mutated here, never rebound, so no ``global``
    # declaration is needed to update it.
    if name in llm_output:
        return llm_output[name]

    chat_turns = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]
    request = dict(system=system_message, user=user_message, messages=chat_turns)
    reply = llm.forward(extended_prompt=request)
    time.sleep(5)
    llm_output[name] = reply
    return reply

def replace_prompt(prompt, old, new):
    """Fill placeholders in *prompt* by literal, in-order substitution.

    Each ``old[i]`` is replaced (all occurrences) by ``new[i]``, one pair at a
    time. Later substitutions therefore also see text inserted by earlier
    ones. Pairs are formed with ``zip``, so surplus entries in the longer list
    are ignored.
    """
    substitutions = list(zip(old, new))
    filled = prompt
    for pair in substitutions:
        filled = filled.replace(*pair)
    return filled

def split_code(response, name):
    """Extract the Python program from an LLM response.

    Strategies, tried in order:
      1. the first fenced ```python ... ``` block;
      2. the same after appending a closing fence, to recover responses that
         were cut off before the closing ```;
      3. a heuristic salvage used when there is no ```python fence at all.
         Only lines containing '=' are kept. Text up to the first ':' is
         dropped (e.g. "Answer: result = 5"), and anything after a
         sentence-ending period is dropped. Each line is stripped and
         lower-cased.

    Args:
        response: raw LLM response text.
        name: identifier printed with the diagnostic when the heuristic
            fallback is used.

    Returns:
        The extracted code as a string.
    """
    pattern = r'```python(.*?)```'
    match = re.search(pattern, response, re.DOTALL)
    if match is None:
        # The response may have been truncated before its closing fence.
        match = re.search(pattern, response + '```', re.DOTALL)
    if match is not None:
        return match.group(1).strip()

    print('===============')
    print(name)
    print('cannot split code {}'.format(response))
    print('===============')
    salvaged = []
    for line in response.strip().split('\n'):
        if '=' not in line:
            continue
        if ':' in line:
            line = line.split(':')[1].strip()
        # Cut at the first period that is not a decimal point, so that
        # "result = 4. Done" -> "result = 4" while "result = 2.5" survives.
        # Stripping also removes leading indentation, which would otherwise
        # make compile() fail with "unexpected indent".
        line = re.split(r'\.(?!\d)', line, maxsplit=1)[0].strip()
        salvaged.append(line.lower())

    return '\n'.join(salvaged)

def check(gt, response, name):
    """Score one LLM response against the ground-truth answer.

    The Python program is extracted from *response* with ``split_code``, then
    executed. Its top-level ``result`` variable is compared to *gt* with ``==``.

    Args:
        gt: expected answer.
        response: raw LLM response text.
        name: identifier printed, together with the traceback, on failure.

    Returns:
        1 if ``result == gt``. Otherwise 0, which also covers any failure:
        unextractable code, syntax or runtime errors, a missing ``result``,
        a comparison that raises, or the snippet calling ``exit()``.
    """
    try:
        code = split_code(response, name)
        compiled = compile(code, '', 'exec')
        # SECURITY: this runs model-generated code with full interpreter
        # privileges. Only evaluate on a trusted / sandboxed machine.
        #
        # A single dict serves as both globals and locals, so the snippet runs
        # like a module. With separate dicts it would run like a class body:
        # functions (and, before Python 3.12, comprehensions) defined in the
        # snippet could not see its top-level variables and would raise
        # NameError.
        namespace = {}
        exec(compiled, namespace)
        return 1 if gt == namespace['result'] else 0
    except (Exception, SystemExit):
        # SystemExit is caught so that generated code calling exit() counts as
        # a wrong answer instead of terminating the evaluation run, while
        # KeyboardInterrupt still stops the script.
        print(name)
        print(traceback.format_exc())
        return 0

# Load the evaluation set; each row is one question scored in the loop below.
data = pd.read_csv('dataset/testset_ori_v5_chatgpt.csv')

print(data.columns)  # quick sanity check of the available columns

length = data.shape[0]  # number of rows in the evaluation set
# Few-shot system prompt for the 'with_internal' query. The worked example gives
# two information passages and code that combines their evidence into `result`.
system_prompt = '''Please write Python code to analyze the question and place the answer in the variable "result"

EXAMPLE:
Information1:
"Prison Break" has a total of 5 seasons. It originally aired from 2005 to 2009 for its first four seasons, and then returned for a fifth season in 2017.

Information2:
"Prison Break: Resurrection" or season 5 aired on April 4, 2017, and received mixed reviews from critics and audiences alike. However, it was praised for its nostalgic appeal and the return of beloved characters.

A new season of "Prison Break" aired in 2026, bringing back the adrenaline-pumping action and suspense that fans have come to love. In this season, Michael Scofield finds himself entangled in a new conspiracy, forcing him to once again navigate the dangerous world of prisons and underground organizations.

Question: How many seasons does Prison Break have?

Extract Evidence:
```python
evidence1 = 5 # "Prison Break" has a total of 5 seasons as of 2017
evidence2 = 1 # "Prison Break: Resurrection" or season 5 aired on April 4, 2017.
evidence3 = 1 # A new season of "Prison Break"  aired in 2026.
result = evidence1 + evidence3 # Since evidence2 has already been included in evidence1.
```'''
# Few-shot system prompt for the 'without_internal' query. The worked example
# gives a single information passage and code that supplements it with
# "memory" values the model recalls itself.
# NOTE(review): "Accofding" below is a typo inside the prompt text sent to the
# model; left unchanged so prompts stay identical to earlier runs.
system_prompt2 = '''Please write Python code to analyze the question and place the answer in the variable "result"

EXAMPLE:
Information:
"Prison Break: Resurrection" or season 5 aired on April 4, 2017, and received mixed reviews from critics and audiences alike. However, it was praised for its nostalgic appeal and the return of beloved characters.

A new season of "Prison Break" aired in 2026, bringing back the adrenaline-pumping action and suspense that fans have come to love. In this season, Michael Scofield finds himself entangled in a new conspiracy, forcing him to once again navigate the dangerous world of prisons and underground organizations.

Question: How many seasons does Prison Break have?

```python
evidence1 = 1 # As shown in Information, "Prison Break: Resurrection" or season 5 aired on April 4, 2017.
evidence2 = 1 # As shown in Information, A new season of "Prison Break" aired in 2026.
memory1 = 5 # Accofding to my memory, "Prison Break" has a total of 5 seasons as of 2017
result = memory1 + evidence2 # Since evidence1 has already been included in memory1.
```'''
# User-message templates; the #...# placeholders are filled via replace_prompt().
user_prompt1 = 'Information:\n#INFORMATION#\n\nQuestion:\n#QUESTION#'
user_prompt2 = 'Information 1:\n#INFORMATION1#\n\nInformation 2:\n#INFORMATION2#\n\nQuestion:\n#QUESTION#'
# NOTE(review): user_prompt3 (answer-extraction prompt) is not referenced in the visible code.
user_prompt3 = 'Please extract the answer corresponding to the question from the response and assign it to the python variable answer. For example, answer=3\n\nQuestion:\n#QUESTION#\nResponse:\n#RESPONSE#\n'

# --- Evaluation configuration and per-type tallies -----------------------
question_str = 'question_date'  # CSV column whose text is sent as the question
version = 'result/PAL_0317_chatgpt/'  # output directory for per-row response caches
# Index 0 / 1 / 2 = rows whose 'type' is 'type1' / 'type2' / anything else.
count = [0, 0, 0]
without_correct = [0, 0, 0]
with_correct = [0, 0, 0]
without_acc = [0, 0, 0]
with_acc = [0, 0, 0]
# Maps a row's 'type' to its tally slot; any other type falls into slot 2.
_TYPE_SLOT = {'type1': 0, 'type2': 1}

# A per-row JSON cache is written at the end of every iteration. Create the
# output directory up front so the first write cannot fail with
# FileNotFoundError on a fresh checkout.
os.makedirs(version, exist_ok=True)

# total= gives tqdm a real progress bar; iterrows() itself has no length.
data_iter = tqdm(data.iterrows(), total=len(data))
for index, row in data_iter:
    llm_output_path = os.path.join(
        version,
        '{}_{}_{}_{}.json'.format(row['id'], row['type'], row['old_length'], row['new_length']),
    )
    # `llm_output` is the module-level cache that forward() consults. Reload
    # responses saved by an earlier run so they are not requested again.
    if os.path.exists(llm_output_path):
        with open(llm_output_path, 'r') as f:
            llm_output = json.load(f)
    else:
        llm_output = {}

    llm_output['answer'] = row['answer']
    llm_output['old_answer'] = row['old_answer']
    # type3 rows are scored against 'old_answer'; every other type against 'answer'.
    answer = row['answer'] if row['type'] != 'type3' else row['old_answer']
    i1 = row['information1'].strip()
    i2 = row['information2'].strip()
    # 'without_internal': the model sees only information2 (system_prompt2 few-shot).
    # 'with_internal': the model sees information1 and information2 (system_prompt few-shot).
    current_user1 = replace_prompt(user_prompt1, ['#INFORMATION#', '#QUESTION#'], [i2, row[question_str]])
    current_user2 = replace_prompt(user_prompt2, ['#INFORMATION1#', '#INFORMATION2#', '#QUESTION#'], [i1, i2, row[question_str]])
    llm_output['without_prompt'] = current_user1
    llm_output['with_prompt'] = current_user2
    response1 = forward(system_prompt2, current_user1, 'without_internal')
    response2 = forward(system_prompt, current_user2, 'with_internal')

    check1 = check(answer, response1, llm_output_path)
    check2 = check(answer, response2, llm_output_path)

    slot = _TYPE_SLOT.get(row['type'], 2)
    count[slot] += 1
    without_correct[slot] += check1
    with_correct[slot] += check2
    without_acc[slot] = without_correct[slot] / count[slot]
    with_acc[slot] = with_correct[slot] / count[slot]

    data_iter.set_description('without: type1: {:.2f}%, type2:{:.2f}%, type3:{:.2f}%. with: type1: {:.2f}%, type2:{:.2f}%, type3:{:.2f}%'.format(*(acc * 100 for acc in without_acc + with_acc)))

    # llm_output always holds at least the answer/prompt keys set above, so it
    # is saved unconditionally.
    with open(llm_output_path, 'w') as f:
        json.dump(llm_output, f)