import os, sys
# Switch the working directory to sys.path[0] before anything else runs, so
# every relative path used below is resolved against that directory.
# NOTE(review): for a normal script invocation sys.path[0] is the script's own
# folder (per the Python docs), which means a relative --res_file is read
# relative to that folder rather than the shell's cwd — confirm this is intended.
os.chdir(sys.path[0])
import os.path as osp
import argparse
import json
from tqdm import tqdm

from typing import List, Dict, Any, Optional, Type, Tuple, Union

from math_evaluation import is_equiv

def remove_single_dollar(s):
    """Strip one enclosing pair of ``$`` delimiters from *s*, if present.

    Exactly one leading and one trailing character are dropped, and only when
    the string both starts and ends with ``$`` (so ``"$$x$$"`` becomes
    ``"$x$"`` and a lone ``"$"`` becomes ``""``). Any other string is returned
    unchanged.
    """
    wrapped = s.startswith("$") and s.endswith("$")
    return s[1:-1] if wrapped else s

def math_is_equiv(grt: Union[str, list[str]], prd: str):
    """Check a prediction against one or several reference answers.

    Both sides have a single enclosing ``$...$`` pair removed before they are
    handed to ``is_equiv``.

    Args:
        grt: The reference answer, or a list of acceptable reference answers.
        prd: The predicted answer.

    Returns:
        For a list of references, ``True`` if ``is_equiv`` accepts at least
        one of them and ``False`` otherwise; for a single reference, whatever
        ``is_equiv`` returns.
    """
    candidate = remove_single_dollar(prd)
    if not isinstance(grt, list):
        return is_equiv(remove_single_dollar(grt), candidate)
    # Any one matching reference is enough; stops at the first match.
    return any(is_equiv(remove_single_dollar(ref), candidate) for ref in grt)

def str2bool(v):
    """Convert a command-line value to a ``bool`` (usable as argparse ``type=``).

    Real booleans pass through unchanged. Other values are converted to a
    string, stripped and compared case-insensitively against the usual
    true/false spellings.

    Args:
        v: A ``bool`` or a textual boolean such as ``"True"``, ``"false"``,
            ``"1"`` or ``"no"``.

    Returns:
        The corresponding ``bool``.

    Raises:
        argparse.ArgumentTypeError: If *v* is not a recognised spelling.
            Raising (instead of silently returning ``None``) lets argparse
            report a clear usage error for a mistyped flag value.
    """
    if isinstance(v, bool):
        return v
    normalized = str(v).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {v!r}")

def parse_args():
    """Parse the command-line options of this evaluation script.

    Returns:
        argparse.Namespace: with ``res_file`` (``str``, default ``""``) and
        ``partial`` (parsed by ``str2bool``; ``None`` when the flag is
        omitted).
    """
    parser = argparse.ArgumentParser()
    # The script entry point lists this path and joins file names onto it, so
    # despite its name it is a directory, not a single file.
    parser.add_argument(
        '--res_file', type=str, default="",
        help="directory containing the checkpoint result folders; "
             "eval_res.json is read from and written to this directory.",
    )
    # NOTE(review): the __main__ block in this file does not read
    # args.partial (it always computes both accuracies) — confirm whether this
    # flag is still needed.
    parser.add_argument(
        '--partial', type=str2bool,
        help="True/False: whether to evaluate partial solutions.",
    )
    return parser.parse_args()


def direct_eval(res_file: str, partial) -> float:
    """Compute top-1 accuracy over a JSON-lines result file.

    Every non-blank line is parsed as a JSON object. The object's ``"answer"``
    is compared (via ``math_is_equiv``) with the ``"final_answer"`` of the
    first entry in ``d["react"]["partial_solutions"]`` when *partial* is
    truthy, or ``d["react"]["solutions"]`` otherwise. When that list is empty
    the prediction is taken to be the empty string.

    Args:
        res_file: Path of the JSON-lines file to score.
        partial: Truthy to score the partial solutions, falsy to score the
            full solutions.

    Returns:
        The fraction of records whose top-1 prediction matches the answer, or
        ``0.0`` when the file contains no records.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the fields read above.
    """
    solutions_key = "partial_solutions" if partial else "solutions"
    cnt, total = 0, 0
    with open(res_file, "r", encoding="utf-8") as f:
        for line in tqdm(f):
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) are not records and
                # would make json.loads raise.
                continue
            d = json.loads(line)
            candidates = d["react"][solutions_key]
            # Only the first (top-1) candidate is scored.
            prediction = candidates[0]["final_answer"] if candidates else ""
            if math_is_equiv(d["answer"], prediction):
                cnt += 1
            total += 1
    if total == 0:
        # Keep the declared float return type for the empty case as well.
        return 0.0
    return cnt / total


if __name__ == '__main__':
    args = parse_args()

    # Results from earlier runs are kept in eval_res.json inside the results
    # directory, so a re-run only evaluates files that are not recorded yet.
    save_file_path = osp.join(args.res_file, "eval_res.json")
    if osp.exists(save_file_path):
        with open(save_file_path, "r") as f:
            log_data = json.load(f)
        print("load from existing file")
    else:
        log_data = {}

    # Walk every entry of the results directory whose name contains "checkpoint".
    for ckpt_name in os.listdir(args.res_file):
        if "checkpoint" not in ckpt_name:
            continue
        file_path = osp.join(args.res_file, ckpt_name, "sbs")
        if not osp.isdir(file_path):
            # A stray file, or a checkpoint folder without an "sbs"
            # sub-folder: nothing to evaluate, and os.listdir would raise.
            continue
        ckpt_results = log_data.setdefault(ckpt_name, {})

        # List the data files of this checkpoint; only names containing
        # "test" are evaluated.
        for file in os.listdir(file_path):
            if "test" not in file:
                continue
            res_file = osp.join(file_path, file)
            # Skip sub-directories and empty files.
            if not osp.isfile(res_file) or osp.getsize(res_file) == 0:
                continue

            # The tag is built from the text before the first "_", the third
            # dot-separated field and the second-to-last dot-separated field
            # of the file name.
            name_parts = file.split(".")
            if len(name_parts) < 3:
                print(f"skip {res_file}: unexpected file name format")
                continue
            dataname = file.split("_")[0]
            method_name = name_parts[2]
            timestamp = name_parts[-2]
            tag_name = dataname + "_" + method_name + "_" + timestamp

            print(res_file)
            if tag_name in ckpt_results:
                # Already evaluated in a previous run.
                continue
            acc = direct_eval(res_file, False)
            partial_acc = direct_eval(res_file, True)
            ckpt_results[tag_name] = {"acc": acc, "partial_acc": partial_acc}
            print(f"{ckpt_name} {tag_name} acc: {acc} partial_acc: {partial_acc}")

    # Persist the (possibly extended) results.
    with open(save_file_path, "w") as f:
        json.dump(log_data, f, indent=4)
    

    
