import re
import fire
import json
from tqdm import tqdm
from openai_access import call_chatgpt


def input_prompt(
    question,
    answer,
    similar_modular,
):
    """Build the user prompt that pairs a question with related helper functions.

    Args:
        question: The task statement, inserted verbatim under "Python Question".
        answer: Accepted for call compatibility; it is NOT used in the prompt.
        similar_modular: Iterable of code snippets (hashable, presumably str).
            Duplicates are dropped and each unique snippet is wrapped in a
            ```python fence, separated by a blank line.

    Returns:
        The prompt string. It ends with the "### Correct Solution:" header,
        presumably so the model continues by writing the solution.
    """
    # dict.fromkeys drops duplicates while preserving first-occurrence order.
    # The previous set() ordering depended on per-process string hash
    # randomization, so the same record could yield a different prompt
    # (and different model output) on every run.
    unique_snippets = list(dict.fromkeys(similar_modular))
    snippets_md = "\n\n".join(f"```python\n{snippet}\n```" for snippet in unique_snippets)
    return f"""## New Task
### Python Question:
{question}

### Relevent Functions:
{snippets_md}

### Correct Solution:
"""


def _save_json(objs, path):
    """Write ``objs`` to ``path`` as JSON (indent=4), closing the file before returning."""
    with open(path, "w", encoding="utf-8") as outfile:
        json.dump(objs, outfile, indent=4)


def main(
    input_path="",
    output_path="",
    sys_path="",
    start_index=0,
    end_index=100000,
    api_key=None,
    temp=0.7,
    max_tokens=2048,
):
    """Generate a "modular" solution for each dataset record and save the results.

    Reads a JSON array from ``input_path`` and keeps items
    ``[start_index:end_index]``. Records without a ``"similar_modular"`` key
    are skipped. For every other record, a prompt is built with
    ``input_prompt`` and sent to ``call_chatgpt`` together with the system
    prompt read from ``sys_path``. If the ``similar_modular`` list is empty,
    no call is made and ``modular_output`` is ``""``.

    After each processed record, the full list of results collected so far is
    rewritten to ``output_path``. That file acts as a checkpoint if the run
    stops partway.

    Args:
        input_path: JSON file holding a list of dicts. Each processed dict
            must have the keys "instruction", "output", "similar_modular"
            and "task_id".
        output_path: Destination JSON file for the result list.
        sys_path: UTF-8 text file whose contents are passed as the system prompt.
        start_index: First record (inclusive) of the slice to process.
        end_index: Last record (exclusive) of the slice to process.
        api_key: Forwarded unchanged to ``call_chatgpt``.
        temp: Sampling temperature forwarded to ``call_chatgpt``.
        max_tokens: Token limit forwarded to ``call_chatgpt``.
    """
    with open(input_path, "r", encoding="utf-8") as infile:
        all_objs = json.load(infile)[start_index:end_index]
    with open(sys_path, "r", encoding="utf-8") as infile:
        sys_prompt = infile.read()

    evol_rep_objs = []
    for cur_obj in tqdm(all_objs):
        if "similar_modular" not in cur_obj:
            continue
        question = cur_obj["instruction"].strip()
        answer = cur_obj["output"].strip()
        similar_modular = cur_obj["similar_modular"]
        task_id = cur_obj["task_id"]
        if len(similar_modular) > 0:
            prompt = input_prompt(question, answer, similar_modular)
            ans = call_chatgpt(prompt, temp, max_tokens, api_key, sys_prompt)
        else:
            # No related functions to offer, so skip the API call entirely.
            ans = ""
        evol_rep_objs.append(
            # "output" keeps the original, unstripped text from the input record.
            {"task_id": task_id, "instruction": question, "output": cur_obj["output"], "modular_output": ans}
        )
        # Checkpoint after every record so completed work survives a crash.
        _save_json(evol_rep_objs, output_path)

    if not evol_rep_objs:
        # Nothing was processed; still produce an (empty) output file.
        _save_json(evol_rep_objs, output_path)

# Expose `main` as a command-line interface only when run as a script,
# so importing this module (e.g. to reuse input_prompt) has no side effects.
if __name__ == "__main__":
    fire.Fire(main)