from prompts import *

import argparse
import json
from tqdm import tqdm
import time
import re
import random
import sys

import openai


def prepare_data(dataset_name):
    """Load question ids and question texts from a dataset's test split.

    Args:
        dataset_name: "webqsp" selects the WebQSP test file; any other value
            falls back to the CWQ test file.

    Returns:
        A pair of parallel lists ``(question_ids, questions)`` taken from the
        "Id" and "Question" fields of each JSON line, in file order.
    """
    if dataset_name == "webqsp":
        dataset_path = "./data/webqsp/simple-WebQSP_test.jsonl"
    else:
        dataset_path = "./data/cwq/simple-CWQ_test.jsonl"
    question_ids = []
    questions = []
    with open(dataset_path, "r", encoding="utf-8") as dataset_file:
        # Stream the file line by line. Skip whitespace-only lines (e.g. a
        # trailing newline at EOF), which json.loads would reject.
        for line in dataset_file:
            if not line.strip():
                continue
            line_data = json.loads(line)
            question_ids.append(line_data["Id"])
            questions.append(line_data["Question"])
    return question_ids, questions


def reason_with_llm(prompt, api_key, input_token_num, output_token_num, *,
                    model="gpt-3.5-turbo", max_tokens=100):
    """Send ``prompt`` to a chat-completion model and return its reply text.

    Uses the ``openai.ChatCompletion.create`` interface (pre-1.0 openai SDK)
    against the hard-coded API base URL. Failed requests are retried
    indefinitely, sleeping 2-4 seconds between attempts.

    Side effects: assigns the module-global ``openai.api_key`` and
    ``openai.api_base``.

    Args:
        prompt: Content of the user message.
        api_key: API key for the endpoint.
        input_token_num: Running total of prompt tokens; incremented by this call.
        output_token_num: Running total of completion tokens; incremented by this call.
        model: Chat model name (defaults to the previously hard-coded value).
        max_tokens: Maximum number of completion tokens requested.

    Returns:
        ``(response_message, input_token_num, output_token_num)``, where
        ``response_message`` is the reply text, or the string "null" when the
        reply carries no content.
    """
    openai.api_key = api_key
    openai.api_base = "https://api.xty.app/v1"
    input_message = [
        {"role": "system", "content": "You are an AI assistant with the ability to provide insightful responses and make informed judgments."},
        {"role": "user", "content": prompt}
    ]
    while True:
        try:
            response = openai.ChatCompletion.create(
                model=model,
                messages=input_message,
                max_tokens=max_tokens,
            )
            break
        except Exception as e:
            # Any API/network failure: report it, back off with jitter, retry.
            print("Exception in LLM:")
            print(e)
            time.sleep(2 + 2 * random.random())
    message = response["choices"][0]["message"]
    # "content" may be missing OR present with a None value. Normalize both to
    # "null" so callers always receive a string (re.findall fails on None).
    response_message = message.get("content")
    if response_message is None:
        response_message = "null"
    input_token_num += response.usage.prompt_tokens
    output_token_num += response.usage.completion_tokens
    return response_message, input_token_num, output_token_num


def generate_answer(question, args, llm_input_len, llm_output_len):
    """Ask the LLM one question and extract the brace-wrapped answer entities.

    Args:
        question: Natural-language question text.
        args: Parsed CLI namespace; reads ``prompt_type`` and ``llm_api_key``.
        llm_input_len: Running total of prompt tokens.
        llm_output_len: Running total of completion tokens.

    Returns:
        ``(response, answer_entity, llm_input_len, llm_output_len)``.
        ``answer_entity`` is the distinct ``{...}`` spans of the response,
        stripped of surrounding whitespace and joined with "; " in order of
        first appearance. It is "" when the response has no such span.
    """
    if args.prompt_type == "io":
        prompt = prompt_answer_io.format(question)
    else:  # "cot" and "sc" both use the chain-of-thought prompt
        prompt = prompt_answer_cot.format(question)
    response, llm_input_len, llm_output_len = reason_with_llm(prompt, args.llm_api_key, llm_input_len, llm_output_len)
    # Extract every {answer_entity} from the response. This assumes the
    # prompts ask the model to wrap answers in braces; verify in prompts.py.
    answer_pattern = r'\{([^}]*)\}'
    # `or ""` guards against a None response content.
    matches = re.findall(answer_pattern, response or "")
    # Order-preserving dedup (unlike set()), so the joined string is
    # deterministic across runs. SC voting compares these strings for equality.
    answers = list(dict.fromkeys(m.strip() for m in matches if m.strip()))
    answer_entity = "; ".join(answers)
    return response, answer_entity, llm_input_len, llm_output_len


def handle_question(question, args, llm_input_len, llm_output_len, sample_num=5):
    """Answer one question according to ``args.prompt_type``.

    For "io"/"cot" a single generation is made. For "sc", ``sample_num``
    generations are drawn and majority-voted on the extracted answer string.
    Empty (unparseable) answers take part in the vote only if every sample is
    empty. Ties are broken uniformly at random.

    Args:
        question: Natural-language question text.
        args: Parsed CLI namespace; reads ``prompt_type`` here and more in
            ``generate_answer``.
        llm_input_len: Running total of prompt tokens.
        llm_output_len: Running total of completion tokens.
        sample_num: Number of samples drawn for "sc" (must be >= 1).

    Returns:
        ``(answer_entity, response, llm_input_len, llm_output_len)``.

    Raises:
        ValueError: if ``prompt_type`` is "sc" and ``sample_num`` < 1.
    """
    from collections import Counter

    if args.prompt_type != "sc":  # "io" or "cot": one generation suffices
        response, answer_entity, llm_input_len, llm_output_len = generate_answer(question, args, llm_input_len, llm_output_len)
        return answer_entity, response, llm_input_len, llm_output_len

    if sample_num < 1:
        raise ValueError("sample_num must be >= 1, got {}".format(sample_num))

    samples = []  # (answer, response) pairs
    for _ in range(sample_num):
        response, answer, llm_input_len, llm_output_len = generate_answer(question, args, llm_input_len, llm_output_len)
        samples.append((answer, response))

    # Count only non-empty answers, so an unparseable sample can never tie
    # with (and randomly beat) a real answer.
    votes = Counter(answer for answer, _ in samples if answer != "")
    if votes:
        max_num = max(votes.values())
        consistent = [s for s in samples if votes.get(s[0], 0) == max_num]
    else:
        consistent = samples  # every sample was empty
    answer_entity, response = random.choice(consistent)
    return answer_entity, response, llm_input_len, llm_output_len


def write_to_file(args, question_id, question, answer, response):
    """Append one result record as a JSON line to this run's output file.

    The file is ``./output/gpt3.5-<Dataset>_test-<PromptType>.jsonl``, named
    from ``args.dataset`` and ``args.prompt_type``. The ``./output`` directory
    is created if it does not exist.

    Args:
        args: Parsed CLI namespace; reads ``dataset`` ("webqsp"/"cwq") and
            ``prompt_type`` ("io"/"cot"/"sc").
        question_id: Value stored under "Id".
        question: Value stored under "Question".
        answer: Value stored under "Answer".
        response: Value stored under "Response".

    Raises:
        KeyError: if ``args.dataset`` or ``args.prompt_type`` is unknown.
    """
    import os

    map_dict = {"io": "IO", "cot": "CoT", "sc": "SC", "webqsp": "WebQSP", "cwq": "CWQ"}
    output_path = "./output/gpt3.5-{}_test-{}.jsonl".format(map_dict[args.dataset], map_dict[args.prompt_type])
    # Create the output directory on first use instead of failing with
    # FileNotFoundError on a fresh checkout.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    output_content = {
        "Id": question_id,
        "Question": question,
        "Answer": answer,
        "Response": response
    }
    # Append mode: one line per question, so partial progress survives a crash.
    with open(output_path, "a", encoding="utf-8") as output_file:
        output_file.write(json.dumps(output_content) + "\n")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="webqsp")
    # Required: reason_with_llm assigns this to openai.api_key, and a None key
    # would make its unbounded retry loop spin forever on auth errors.
    parser.add_argument("--llm_api_key", type=str, required=True)
    parser.add_argument("--prompt_type", type=str)
    args = parser.parse_args()

    # Validate up front (with a non-zero exit status) so a bad value fails
    # before any paid LLM call. An unknown dataset would otherwise only fail
    # with a KeyError in write_to_file after the first question.
    if args.prompt_type not in ["io", "cot", "sc"]:
        print("wrong prompt_type")
        sys.exit(1)
    if args.dataset not in ["webqsp", "cwq"]:
        print("wrong dataset")
        sys.exit(1)

    print("Preparing data...")
    question_ids, questions = prepare_data(args.dataset)

    print("Start reasoning on {} ({})...".format(args.dataset, args.prompt_type))
    llm_input_len = 0
    llm_output_len = 0

    for question_id, question in tqdm(zip(question_ids, questions), total=len(questions)):
        answer, response, llm_input_len, llm_output_len = handle_question(question, args, llm_input_len, llm_output_len)
        # Write each result immediately so partial progress is kept.
        write_to_file(args, question_id, question, answer, response)

    print("llm_input_tokens_num = {}, llm_output_tokens_num = {}".format(llm_input_len, llm_output_len))
    print("Finish reasoning")