from openai import AzureOpenAI
import json
import os
import tqdm
from multiprocessing import Pool, cpu_count
from functools import partial
import time
from prompt import PROMPT
import argparse
import pprint

# Base directory used as a prefix for every file location in this script: the
# dataset files ("../datasets/"), the BBH prompt file ("../bbh/prompts/") and
# the default output directory ("../outputs_gpt") are all built relative to it.
# NOTE(review): this is an absolute, machine-specific path; it has to be edited
# before the script can run on another machine.
PATH = "/Users/dxm/Project/MultiPoT/codes/"

# Maps the language names accepted by --lang to lowercase identifiers.
# NOTE(review): not referenced anywhere in the visible code; presumably meant
# as Markdown code-fence tags for each language -- confirm before relying on it.
ANNO = {
    "C++":"cpp",
    "Java":"java",
    "Python":"python",
    "Javascript":"javascript",
    "R":"r"
}


def init_openai():
    """Create and return an ``AzureOpenAI`` client.

    Connection settings are taken from the environment when available:

    * ``AZURE_OPENAI_API_VERSION`` - API version string
    * ``AZURE_OPENAI_ENDPOINT``    - Azure resource endpoint URL
    * ``AZURE_OPENAI_API_KEY``     - API key

    Each setting falls back to the value that used to be hard-coded here, so
    existing invocations keep working without any environment setup.

    Returns:
        A new ``AzureOpenAI`` client instance.
    """
    # SECURITY NOTE(review): the fallback key below is committed to source
    # control and must be treated as leaked. Rotate it, supply the new key via
    # AZURE_OPENAI_API_KEY, and then delete this fallback literal.
    client = AzureOpenAI(
        api_version=os.environ.get("AZURE_OPENAI_API_VERSION", "2023-07-01-preview"),
        azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT", "https://yfllm01.openai.azure.com/"),
        api_key=os.environ.get("AZURE_OPENAI_API_KEY", "5c870eb35151406180f137ab8e94c703"),
    )
    return client


def encode_prompt(q, name, num=0, lang="Python", prompt=None, examples="Classical"):
    """Build the prompt text for question ``q``.

    Three layouts are produced, depending on ``num`` and ``name``:

    * ``num == 0`` (zero-shot): the question, an instruction to write a
      ``lang`` program, and an opening code fence tagged with ``lang``.
    * ``name`` is ``"MWP"`` or ``"math"`` (few-shot): three worked examples
      taken from ``PROMPT[examples]``, followed by the new question and the
      same commented "before code" line that precedes each example answer.
      The number of examples is fixed at three; ``num`` only selects the mode.
    * any other ``name`` (few-shot): the prefix stored at
      ``prompt[name][lang]`` followed by the question and an instruction.

    Args:
        q: Question text.
        name: Dataset/task name; selects the few-shot layout.
        num: ``0`` selects zero-shot, any other value selects few-shot.
        lang: Target programming language name as used in the prompt text.
        prompt: Nested mapping ``prompt[name][lang]`` holding the few-shot
            prefix; only read for tasks other than ``"MWP"``/``"math"``.
        examples: Key into ``PROMPT`` choosing the example set.

    Returns:
        The assembled prompt string.
    """
    # Zero-shot: no examples, just the task and an opened code fence.
    if num == 0:
        return (
            f"{q}\nWrite a {lang} program to solve the problem. "
            f"Print and only print the answer at last.\n\n```{lang}\n"
        )

    # Few-shot for tasks whose prefix comes from the caller-supplied mapping.
    if name not in ("MWP", "math"):
        text = prompt[name][lang]
        text += f"Question: {q}\n\n"
        text += (
            f"Use the {lang} program to solve the problem. "
            "The reasoning progress is clearly shown in the program.\n\n"
        )
        return text

    # Few-shot for math word problems: three examples from the PROMPT table.
    bank = PROMPT[examples]
    pieces = []
    for idx in range(3):
        pieces.append("Question: " + bank["question"][idx] + "\n")
        # Python and R use '#' line comments; every other language gets '//'.
        marker = "//" if lang not in ("Python", "R") else "#"
        header = marker + bank["before_code"].format(lang=lang)
        pieces.append(header + "\n")
        pieces.append(bank["answer"][lang][idx].strip() + "\n\n\n")
    # The new question ends with the same header so the model continues with code.
    pieces.append("Question: " + q + "\n" + header + "\n")
    return "".join(pieces)


def worker(q, name="MWP", num=None, lang="python", generation_args=None, examples="Classical", prompt=None):
    """Request completions for a single dataset item, retrying on failure.

    A fresh client is created via ``init_openai`` and the prompt is built with
    ``encode_prompt`` from ``q["input"]``. The chat-completion request is
    attempted up to five times, waiting one second after each failure.

    Args:
        q: Dataset record; its ``"input"`` entry is the question text.
        name: Task name forwarded to ``encode_prompt``.
        num: Shot selector forwarded to ``encode_prompt`` (``0`` is zero-shot;
            any other value, including the default ``None``, is few-shot).
        lang: Target language forwarded to ``encode_prompt``.
            NOTE(review): the default is lowercase ``"python"`` while
            ``encode_prompt`` compares against ``"Python"``/``"R"`` -- callers
            should pass the capitalised name explicitly.
        generation_args: Extra keyword arguments for the completion request
            (e.g. temperature, top_p, n, stop). ``None`` means no extras.
        examples: Example-set key forwarded to ``encode_prompt``.
        prompt: Few-shot prefix mapping forwarded to ``encode_prompt``.

    Returns:
        ``response.choices`` from the first successful request, or ``""`` if
        all five attempts failed.
    """
    client = init_openai()
    prompt = encode_prompt(q["input"], name, num, lang, prompt, examples)

    # Unpacking None with ** raises TypeError; inside the retry loop that was
    # swallowed and retried five times, so treat a missing dict as "no extras".
    extra_args = generation_args or {}

    max_attempts = 5
    for _ in range(max_attempts):
        try:
            response = client.chat.completions.create(
                model="35turbo",
                messages=[{"role": "user", "content": prompt}],
                timeout=60,
                **extra_args
            )
        except Exception as e:
            # Deliberate best-effort boundary: any request failure is reported
            # and retried after a short pause.
            print(str(e))
            time.sleep(1)
        else:
            return response.choices

    print("Can't get data after 5 times!")
    return ""


def post_process(code):
    """Trim a raw completion down to its first chunk.

    Everything from the first run of three consecutive newlines onward is
    discarded, and surrounding whitespace is stripped from what remains.

    Args:
        code: Raw completion text.

    Returns:
        The text before the first ``"\\n\\n\\n"``, stripped.
    """
    head, _, _ = code.partition("\n\n\n")
    return head.strip()


def generation(
    output_dir = PATH + "../outputs_gpt",
    name = "MWP",
    num = 3,
    generate_cpus = 8,
    lang = "Python",
    t = 0,
    p = 1,
    return_num = 1,
    output_suffix = "",
    examples = "Classical",
    max_inputs = 10
):
    """Generate code for one dataset/language pair and save it as JSON.

    The run is resumable: records already present in the output file (matched
    on their ``"input"`` value) are skipped, and new results are appended to
    them before the file is rewritten.

    Args:
        output_dir: Directory that receives the output JSON file.
        name: Dataset name; ``{name}.jsonl`` is read from the datasets folder.
        num: Shot selector passed to the worker; also part of the output file
            name for ``"MWP"``/``"math"``.
        generate_cpus: Number of worker processes in the pool.
        lang: Target programming language name.
        t: Sampling temperature.
        p: Nucleus-sampling ``top_p``.
        return_num: Number of completions requested per question.
        output_suffix: Text appended to the output file name before ``.json``.
        examples: Example-set key used for ``"MWP"``/``"math"`` prompts.
        max_inputs: Maximum number of pending questions processed in this
            call. Defaults to 10 (the previously hard-coded cap); pass
            ``None`` to process every pending question.

    Returns:
        None. Results are written to the output file.
    """
    if name in ["MWP", "math"]:
        output_path = output_dir + f"/{name}_{lang}_{examples}_{num}{output_suffix}.json"
    else:
        output_path = output_dir + f"/{name}_{lang}{output_suffix}.json"
    print(f"resuls save at {output_path}!")

    # Create the output directory up front so a missing folder cannot make the
    # final write fail after all requests have already been paid for.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Load earlier results, if any, so finished questions are not regenerated.
    outputs = []
    if os.path.exists(output_path):
        with open(output_path, "r") as f:
            outputs = json.load(f)
    # A set gives O(1) membership tests while filtering the dataset.
    already = {d["input"] for d in outputs}

    inputs = []
    with open(PATH + f"../datasets/{name}.jsonl", "r") as f:
        for line in f:
            d = json.loads(line)
            if d["input"] not in already:
                inputs.append(d)
    # Slicing with None keeps the whole list, so max_inputs=None lifts the cap.
    inputs = inputs[:max_inputs]
    print(f"{len(inputs)} need to generate!")

    generation_args = {"temperature":t, "top_p":p, "n":return_num, "stop": ["\n\n\n", '```', "\n\nQuestion:", ]}

    with open(PATH + "../bbh/prompts/prompts.json", "r") as f:
        prompt = json.load(f)
    work = partial(worker, name=name, num=num, lang=lang, generation_args=generation_args, examples=examples, prompt=prompt)

    # Wrap the *results* iterator so the bar advances as requests complete
    # (wrapping the inputs only tracks how fast tasks are handed to the pool).
    # The bar starts at the number of already-finished records and its total
    # covers both finished and pending ones.
    with Pool(generate_cpus) as pool:
        progress_bar = tqdm.tqdm(
            pool.imap(work, inputs),
            total=len(outputs) + len(inputs),
            initial=len(outputs),
        )
        results = list(progress_bar)

    for inp, res in zip(inputs, results):
        # The worker returns "" when every attempt failed; such questions are
        # left out of the file so a later run retries them.
        if res is None or res == "":
            continue
        inp["code"] = []
        for r in res:
            if r.message.content:
                inp["code"].append(post_process(r.message.content))
            else:
                print(f"{inp['input']} failed!")
        outputs.append(inp)

    with open(output_path, "w") as f:
        json.dump(outputs, f, indent=4)


if __name__ == "__main__":
    # Command-line entry point: runs generation() once for every combination
    # of the comma-separated --name and --lang values.
    parser = argparse.ArgumentParser()
    parser.add_argument("--name", type=str, default="color",
                        help="dataset name(s), comma-separated")
    parser.add_argument("--lang", type=str, default="Python",
                        help="target language(s), comma-separated")
    # The default must match the "Classical" key used as the default by
    # generation() and encode_prompt(); it was previously misspelled.
    parser.add_argument("--examples", type=str, default="Classical",
                        help="example set used for MWP/math few-shot prompts")
    parser.add_argument("--output_suffix", type=str, default="",
                        help="text appended to the output file name")
    parser.add_argument("--num", type=int, default=3,
                        help="0 for zero-shot, any other value for few-shot")
    parser.add_argument("--t", type=float, default=0,
                        help="sampling temperature")
    parser.add_argument("--p", type=float, default=1,
                        help="nucleus sampling top_p")
    parser.add_argument("--return_num", type=int, default=1,
                        help="number of completions per question")
    args = parser.parse_args()

    argsdict = vars(args)
    print(pprint.pformat(argsdict))
    names = args.name.split(",")
    langs = args.lang.split(",")
    for name in names:
        for lang in langs:
            generation(
                name=name,
                num=args.num,
                generate_cpus=cpu_count(),
                lang=lang,
                t=args.t,
                p=args.p,
                return_num=args.return_num,
                output_suffix=args.output_suffix,
                examples=args.examples,
            )
        