import subprocess
from multiprocessing import Pool, cpu_count
import json
from functools import partial
import tqdm
import os
import argparse

# Root directory that all input, execution and output paths are built from.
PATH = "../"

# One row per supported language: (language name, command-line tool, source
# file extension). Keeping both attributes in one table ensures the two lookup
# dicts below always cover the same set of languages.
_LANGUAGES = (
    ("Python", "python3", "py"),
    ("C++", "g++", "cpp"),
    ("Java", "java", "java"),
    ("Javascript", "node", "js"),
    ("R", "Rscript", "r"),
)

# Language name -> executable invoked on the generated source file.
CMD = {language: tool for language, tool, _ in _LANGUAGES}

# Language name -> extension (without the dot) of the generated source file.
SUFFIX = {language: extension for language, _, extension in _LANGUAGES}

# Wall-clock limits (seconds) for running a candidate program and for
# compiling a C++ candidate.
_RUN_TIMEOUT_S = 10
_COMPILE_TIMEOUT_S = 60


def _build_source(code, lang):
    """Return *code* wrapped with the per-language prelude and epilogue.

    The prelude pulls in commonly used libraries; the epilogue prints the
    variable ``ans`` for the languages where the snippet does not print it
    itself (Python, R, Javascript).
    """
    prelude = ""
    if lang == "Java":
        prelude = "import java.util.*;\n"
    elif lang == "C++":
        prelude = "#include <bits/stdc++.h>\n"
    elif lang == "Python":
        prelude = (
            "from math import *\n"
            "import math\n"
            "import numpy as np\n"
            "import itertools\n"
            "import random\n"
            "import sympy\n"
            "from sympy import *\n"
            "import sympy as sym\n"
            "from functools import *\n"
        )

    epilogue = ""
    if lang in ("Python", "R"):
        epilogue = "print(ans)"
    elif lang == "Javascript":
        epilogue = "console.log(ans)"

    return prelude + code + "\n" + epilogue


def _parse_answer(stdout, lang):
    """Convert a program's stdout into a float, snapping near-integers."""
    ans = stdout.strip()
    if lang == "R":
        # R's print() is expected to emit something like "[1] 42"; keep the
        # token that follows the first space.
        ans = ans.split(" ")[1]
    ans = float(ans)
    # Values within 1e-3 of an integer are rounded to that integer.
    if abs(ans - round(ans)) < 1e-3:
        ans = round(ans)
    return float(ans)


def _run_candidate(save_path, lang):
    """Compile (C++ only) and run the source at *save_path*.

    Returns the parsed float answer on success, or the stderr text of the
    failing compile/run step. Timeouts and parse errors propagate as
    exceptions for the caller to record.
    """
    if lang == "C++":
        # The binary sits next to the source, without the ".cpp" extension.
        binary_path = os.path.splitext(save_path)[0]
        compile_proc = subprocess.run(
            [CMD[lang], save_path, "-o", binary_path],
            capture_output=True,
            text=True,
            timeout=_COMPILE_TIMEOUT_S,
        )
        if compile_proc.returncode != 0:
            return compile_proc.stderr
        exec_proc = subprocess.run(
            binary_path, capture_output=True, text=True, timeout=_RUN_TIMEOUT_S
        )
    else:
        exec_proc = subprocess.run(
            [CMD[lang], save_path],
            capture_output=True,
            text=True,
            timeout=_RUN_TIMEOUT_S,
        )

    if exec_proc.returncode != 0:
        return exec_proc.stderr
    return _parse_answer(exec_proc.stdout, lang)


def exec(d, lang, output_path):
    """Run every code candidate of sample *d* and record the results.

    Note: the name shadows the ``exec`` builtin; it is kept because callers
    reference it by this name.

    Args:
        d: Sample dict. Reads ``d["code"]`` (a sequence of source strings)
            and, when there is code to run, ``d["idx"]`` (used as the file
            name). Mutated in place.
        lang: Language name; must be a key of ``CMD`` and ``SUFFIX`` when
            there is code to run.
        output_path: Existing directory the source file (and, for C++, the
            compiled binary) is written to.

    Returns:
        The same dict *d* with ``d["exec_ans"]`` set to a list holding one
        entry per candidate: a float answer on success, otherwise a string
        (compiler/runtime stderr or the exception message). When there are
        no candidates the list is ``["No code"]``. The ``"prompt"`` and
        ``"generation"`` keys are removed if present.
    """
    codes = d["code"]
    if not codes:
        # Keep the marker inside a list so consumers that iterate over
        # exec_ans see the same shape as for executed samples.
        d["exec_ans"] = ["No code"]
    else:
        idx = d["idx"]
        save_path = output_path + f"/{idx}.{SUFFIX[lang]}"
        results = []
        for code in codes:
            # Every candidate reuses the same file; each run overwrites it.
            try:
                with open(save_path, "w", encoding="utf-8") as f:
                    f.write(_build_source(code, lang))
                results.append(_run_candidate(save_path, lang))
            except Exception as e:
                # Best-effort: timeouts, unparsable output, missing tools,
                # etc. are recorded as the result instead of aborting.
                results.append(str(e))
        d["exec_ans"] = results

    # These fields are dropped from the returned record.
    d.pop("prompt", None)
    d.pop("generation", None)
    return d


def compute_accurracy(outputs):
    """Mark each sample as passed/failed and print the overall accuracy.

    A sample passes when at least one entry of ``d["exec_ans"]`` is a float
    within 1e-3 of ``float(d["target"])``. Non-float entries (error strings)
    never match.

    Args:
        outputs: List of sample dicts, each with ``"exec_ans"`` and
            ``"target"`` keys. Each dict gets a boolean ``"passed"`` key
            written in place.

    Returns:
        The same *outputs* list.
    """
    total = len(outputs)
    correct = 0
    for d in outputs:
        # float(d["target"]) stays inside the condition so it is evaluated
        # only when there is a float answer to compare against.
        d["passed"] = any(
            isinstance(exec_ans, float)
            and abs(exec_ans - float(d["target"])) < 1e-3
            for exec_ans in d["exec_ans"]
        )
        if d["passed"]:
            correct += 1
    # Guard against an empty input, which would otherwise divide by zero.
    accuracy = correct / total if total else 0.0
    print(f"{accuracy}")
    return outputs


if __name__ == "__main__":
    # For every (dataset name, language) pair: load the generated samples,
    # execute their code in a process pool, and dump the results as JSON.
    parser = argparse.ArgumentParser()
    parser.add_argument("--names", type=str, default="gsm_train")
    parser.add_argument("--langs", type=str, default="C++")
    parser.add_argument("--output_suffix", type=str, default="")
    parser.add_argument("--num", type=int, default=3)
    parser.add_argument('--middle_dir', type=str, default='')
    parser.add_argument("--output_file", type=str, default="")
    parser.add_argument("--examples", type=str, default="Classcial")
    args = parser.parse_args()

    # --names and --langs accept comma-separated lists.
    names = args.names.strip().split(",")
    langs = args.langs.strip().split(",")
    num = args.num
    output_suffix = args.output_suffix
    # Never start more workers than there are cores (capped at 128): each
    # worker launches a child process that is subject to a wall-clock
    # timeout, so oversubscribing the CPU can make valid programs time out.
    num_cpus = min(128, cpu_count())

    for name in names:
        for lang in langs:
            # NOTE: any value passed via --output_file is overwritten here.
            args.output_file = f"{name}_{lang}_{args.examples}_{num}{output_suffix}"
            with open(PATH + f"{args.middle_dir}/{args.output_file}.json", "r") as f:
                ds = json.load(f)
            exec_path = PATH + f"execution/{args.middle_dir}/{args.output_file}"
            # The index doubles as the per-sample source file name.
            for i, d in enumerate(ds):
                d["idx"] = i
            os.makedirs(exec_path, exist_ok=True)
            exec_lang = partial(exec, lang=lang, output_path=exec_path)
            with Pool(num_cpus) as p:
                # Wrap the result iterator so the bar advances as samples
                # finish rather than as they are handed to the pool.
                res = list(tqdm.tqdm(p.imap(exec_lang, ds), total=len(ds)))
            # res = compute_accurracy(res)
            with open(PATH + f"{args.middle_dir}/{args.output_file}_ans_new_r.json", "w") as f:
                json.dump(res, f, indent=4)