import re
import fire
import json
from tqdm import tqdm
from openai_access import call_chatgpt

def extract_code_segment(result, keyword, all_segments=True):
    """Extract the code found inside Markdown fences tagged with ``keyword``.

    Args:
        result: Text to search, e.g. a model response.
        keyword: Fence tag such as ``"python"``. It is interpolated into the
            regular expression unescaped, so any regex metacharacters in it
            are interpreted as pattern syntax.
        all_segments: If true, all closed fenced segments are joined with
            newlines. If false, only the last closed segment is used (the
            same "last one wins" choice as the unclosed-fence fallback).

    Returns:
        The extracted code. ``''`` is returned when no opening fence tagged
        with ``keyword`` exists. When a fence is opened but never closed,
        everything after the last opening fence is returned. If the code
        contains ``# Test``, it is cut at the first occurrence and stripped.
    """
    # Raw strings: "\s" and "\`" are invalid escapes in normal literals.
    opening = r'```\s*{}'.format(keyword)
    codes = re.findall(opening + r'((?:.|\n)*?)```', result)
    if codes:
        # Previously `code` stayed unbound when all_segments was false,
        # which raised UnboundLocalError further down.
        code = '\n'.join(codes) if all_segments else codes[-1]
    else:
        # No closed segment: fall back to the text after the last opener.
        openings = list(re.finditer(opening, result))
        if not openings:
            return ''
        code = result[openings[-1].end():]
    marker = code.find('# Test')
    if marker != -1:
        # Drop trailing test scaffolding appended after the solution.
        code = code[:marker].strip()
    return code


def input_prompt(
    question,
    solution,
):
    """Build the user prompt that presents a question and a candidate solution.

    Args:
        question: Text placed under the "Python Question" heading.
        solution: Code placed inside a ```python fence under the
            "Potential Solution" heading.

    Returns:
        The prompt string, ending with the "### RESPONSE:" heading and a
        trailing newline.
    """
    prompt_lines = (
        "## New Task",
        "### Python Question:",
        f"{question}",
        "",
        "### Potential Solution:",
        "```python",
        f"{solution}",
        "```",
        "",
        "### RESPONSE:",
        "",  # yields the final newline after the heading
    )
    return "\n".join(prompt_lines)


def extract_module(data):
    """Attach the documented function modules found in each raw response.

    For every record, fenced blocks (``` or ```python followed by a newline)
    are pulled out of ``record["raw_modular"]``. Blocks containing both
    ``def `` and a triple-double-quote are stored, in order, under
    ``record["extract_modular"]``.

    Args:
        data: Iterable of dicts, each with a ``"raw_modular"`` string.

    Returns:
        A new list holding the same dict objects, which are updated in place.
    """
    fence = re.compile(r'(```python|```)\n([\s\S]*?)```')

    def annotate(record):
        # Keep only blocks that look like a function with a docstring.
        record["extract_modular"] = [
            body
            for _opener, body in fence.findall(record["raw_modular"])
            if "def " in body and '"""' in body
        ]
        return record

    return [annotate(record) for record in data]


def main(
    input_path="",
    output_path="",
    sys_path="",
    start_index=0,
    end_index=100000,
    api_key=None,
    temp=0.7,
    max_tokens=2048,
):
    """Ask the chat model to modularize each solution and save the results.

    Args:
        input_path: JSON file containing a list of objects with the keys
            ``"instruction"``, ``"output"`` and ``"task_id"``.
        output_path: Destination JSON file for the processed objects.
        sys_path: Text file whose content is passed to ``call_chatgpt`` as
            the system prompt.
        start_index: Start of the slice of input objects to process.
        end_index: End (exclusive) of the slice of input objects to process.
        api_key: Forwarded unchanged to ``call_chatgpt``.
        temp: Forwarded unchanged to ``call_chatgpt`` (presumably the
            sampling temperature -- confirm against ``openai_access``).
        max_tokens: Forwarded unchanged to ``call_chatgpt`` (presumably the
            completion length limit -- confirm against ``openai_access``).

    Each written object has the keys ``"task_id"``, ``"instruction"``,
    ``"output"``, ``"raw_modular"`` and ``"extract_modular"``.
    """
    # Context managers close the handles deterministically; the output in
    # particular is flushed before the process exits.
    with open(input_path, "r", encoding="utf-8") as infile:
        all_objs = json.load(infile)[start_index:end_index]
    with open(sys_path, 'r', encoding='utf-8') as infile:
        sys_prompt = infile.read()

    decom_objs = []
    for cur_obj in tqdm(all_objs):
        question = cur_obj['instruction'].strip()
        # Only the fenced python code of the stored answer goes in the prompt.
        solution = extract_code_segment(cur_obj["output"], "python")
        task_id = cur_obj["task_id"]
        prompt = input_prompt(question, solution)
        # NOTE(review): extract_module runs a regex over this value, so
        # call_chatgpt is expected to return a string -- confirm.
        modulars = call_chatgpt(prompt, temp, max_tokens, api_key, sys_prompt)
        decom_objs.append(
            {
                "task_id": task_id,
                "instruction": question,
                "output": cur_obj["output"],
                "raw_modular": modulars,
            }
        )

    decom_objs = extract_module(decom_objs)
    with open(output_path, "w", encoding="utf-8") as outfile:
        json.dump(decom_objs, outfile, indent=4)

# Start the CLI only when executed as a script, so that importing this module
# (e.g. to reuse its helper functions) does not launch a run.
if __name__ == "__main__":
    fire.Fire(main)