import os
import argparse
import json
from tqdm import tqdm

from mcts_math.agents.utils import math_is_equiv


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the evaluation script.

    Returns:
        The parsed namespace with two string attributes:
        ``res_file`` (required) and ``ref_file`` (optional, defaults to
        ``../MARIO_EVAL/data/math_testset_annotation.json``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--res_file',
        type=str,
        required=True,
        help="result file.",
    )
    parser.add_argument(
        '--ref_file',
        type=str,
        default="../MARIO_EVAL/data/math_testset_annotation.json",
        help="reference file",
    )
    return parser.parse_args()


def direct_eval(res_file: str) -> float:
    """Compute top-1 accuracy using the answers stored in the result file.

    Each non-blank line of ``res_file`` is parsed as a JSON object. The
    ground truth is read from its ``"answer"`` key and the prediction from
    ``["react"]["solutions"][0]["final_answer"]``. A record whose solution
    list is empty is scored with an empty-string prediction. Equivalence is
    decided by ``math_is_equiv``.

    Args:
        res_file: Path to the JSON-lines result file.

    Returns:
        The fraction of records whose top-1 prediction matches the answer,
        or ``0.0`` when the file contains no records.
    """
    cnt, total = 0, 0
    with open(res_file, "r", encoding="utf-8") as f:
        for line in tqdm(f):
            line = line.strip()
            if not line:
                # Tolerate blank / trailing empty lines instead of crashing
                # in json.loads("").
                continue
            d = json.loads(line)
            answer = d["answer"]
            solutions = d["react"]["solutions"]
            # Only the first (top-1) solution is evaluated.
            prediction = solutions[0]["final_answer"] if solutions else ""
            if math_is_equiv(answer, prediction):
                cnt += 1
            total += 1
    if total == 0:
        # Keep the declared float return type and avoid dividing by zero.
        return 0.0
    return cnt / total


def eval_with_ref(res_file: str, ref_file: str) -> float:
    """Compute top-1 accuracy using answers looked up in a reference file.

    ``ref_file`` is loaded as one JSON document and iterated as a sequence of
    objects with ``"question"`` and ``"answer"`` keys. Each non-blank line of
    ``res_file`` is parsed as a JSON object. Its ``"question"`` selects the
    reference answer, and the prediction is read from
    ``["react"]["solutions"][0]["final_answer"]``. A record whose solution
    list is empty is scored with an empty-string prediction. Equivalence is
    decided by ``math_is_equiv``.

    Args:
        res_file: Path to the JSON-lines result file.
        ref_file: Path to the JSON reference file.

    Returns:
        The fraction of records whose top-1 prediction matches the reference
        answer, or ``0.0`` when the result file contains no records.

    Raises:
        KeyError: If a question in ``res_file`` is absent from ``ref_file``.
    """
    with open(ref_file, "r", encoding="utf-8") as f:
        ref_data = json.load(f)
    # If a question occurs more than once, the last answer wins.
    q2a = {d["question"]: d["answer"] for d in ref_data}

    cnt, total = 0, 0
    with open(res_file, "r", encoding="utf-8") as f:
        for line in tqdm(f):
            line = line.strip()
            if not line:
                # Tolerate blank / trailing empty lines instead of crashing
                # in json.loads("").
                continue
            d = json.loads(line)
            question = d["question"]
            try:
                answer = q2a[question]
            except KeyError:
                raise KeyError(
                    f"question not found in reference file {ref_file!r}: {question!r}"
                ) from None
            solutions = d["react"]["solutions"]
            # Only the first (top-1) solution is evaluated.
            prediction = solutions[0]["final_answer"] if solutions else ""
            if math_is_equiv(answer, prediction):
                cnt += 1
            total += 1
    if total == 0:
        # Mirror direct_eval: an empty result file scores 0 rather than
        # raising ZeroDivisionError.
        return 0.0
    return cnt / total


if __name__ == '__main__':
    cli_args = parse_args()

    # Default: the agent was run with qaf set to
    # MARIO_EVAL/data/math_testset_annotation.json (e.g. via `run_sbs.sh`),
    # so the answers stored in the result file are used directly.
    accuracy = direct_eval(cli_args.res_file)

    # If the agent was run with a different qaf, use the line below instead and
    # set ref_file to "MARIO_EVAL/data/math_testset_annotation.json":
    # accuracy = eval_with_ref(cli_args.res_file, cli_args.ref_file)

    print(accuracy)
