import json
import re
import sys
import sympy
import numpy as np
from collections import defaultdict
sys.path.append("..")
import overall_utils
# Model identifier strings. GPT3 is the default `model` of compute_accuracy_budget,
# which forwards it to overall_utils.num_tokens_from_string.
GPT3 = "gpt-3.5-turbo-0301"
# Not referenced elsewhere in this file as shown; presumably passed by callers as `model`.
GPT4 = "gpt-4-0613"
def normalize(problem_data,output: str):
    """Extract the candidate expression from the last line of a model output.

    The last line of ``output`` is lower-cased, every ``'answer: '``
    occurrence is removed, and everything from the first ``'='`` onward is
    discarded. The digit runs found in what remains are compared, as a
    multiset, against the digit runs found in ``problem_data``.

    Returns:
        A ``(numbers, expression)`` tuple. ``numbers`` is the list of digit
        strings found in the expression, or an empty list when they do not
        match the digit strings of ``problem_data``. ``expression`` is the
        extracted text with surrounding whitespace removed.
    """
    # Text after the final newline (the whole string when there is none).
    final_line = output.strip().rpartition('\n')[2]
    without_prefix = final_line.lower().replace('answer: ', '')
    # Keep only the left-hand side of an "lhs = rhs" answer.
    lhs = without_prefix.partition('=')[0]

    used_numbers = re.findall(r'\d+', lhs)
    expected_numbers = re.findall(r'\d+', problem_data)
    uses_exactly_the_inputs = sorted(used_numbers) == sorted(expected_numbers)
    return (used_numbers if uses_exactly_the_inputs else []), lhs.strip()
def test_output(expression: str):
    """Return 1 when ``expression`` simplifies to 24 under sympy, else 0.

    Any exception raised while simplifying or comparing (for example an
    expression sympy cannot parse) is treated as a wrong answer and yields 0.
    """
    try:
        reaches_target = sympy.simplify(expression) == 24
        return 1 if reaches_target else 0
    except Exception:
        # Best effort: an unparsable or otherwise failing expression counts as incorrect.
        return 0
def compute_accuracy_budget(file, number_of_rounds = 4, breadth=5, n=100, model=GPT3,CoT=False,noagents=False):
    """Aggregate accuracy and cost statistics from a saved JSON run log.

    The file must contain a JSON list of records. From each of the first
    ``n`` records this function reads:

    * ``record["steps"]`` - ``steps[0]["x"]`` is the problem text; when
      ``noagents`` is true, ``len(step["new_steps"])`` is summed over all
      steps as the proposer node count.
    * ``record["infos"]`` - a list whose items carry a reward ``"r"``; the
      first index with ``r == 1`` marks the record as solved from that
      index onward for best-at-k.
    * ``record["ys"]`` - the final solutions, resampled for the
      self-consistency (majority vote) estimate.
    * ``record["agents"]`` (only when ``noagents`` is false) - a mapping
      from agent name to a list of pairs whose element ``[0][1]`` is the
      prompt and whose remaining elements hold a completion at index 1.

    Args:
        file: Path of the JSON log to read.
        number_of_rounds: Number of proposer/evaluator rounds to tally; also
            the width of ``best_at_k`` and ``SC`` when ``CoT`` is true.
        breadth: Number of ``proposer{round}_{k}`` agents looked up for every
            round after the first; also the width of ``best_at_k`` and
            ``SC`` when ``CoT`` is false.
        n: Number of leading records to process. Indexing raises
            ``IndexError`` if the file holds fewer than ``n`` records.
        model: Forwarded to ``overall_utils.num_tokens_from_string``.
        CoT: Selects ``number_of_rounds`` (true) or ``breadth`` (false) as
            the number of best-at-k / SC columns.
        noagents: When true, skip the per-agent token accounting.

    Returns:
        A dict with keys ``"tokens_count"`` (list of four per-round arrays:
        proposer generated, proposer encoded, evaluator generated, evaluator
        encoded), ``"nodes_count"`` (dict of counts divided by the number of
        records in the file), ``"best_at_k"`` and ``"SC"`` (arrays divided
        by ``n``).

    Side effects:
        Seeds NumPy's global RNG with 0 and prints the record count.
    """
    np.random.seed(0)
    # Context manager so the log file is closed as soon as it has been parsed.
    with open(file, "r") as log_file:
        data = json.load(log_file)
    data_size = len(data)
    print(f"data_size: {data_size}")
    total_sample = 100
    n_columns = number_of_rounds if CoT else breadth
    best_at_k = np.zeros(n_columns)
    SC = np.zeros((total_sample, n_columns))
    proposer_tokens_generated_by_round = np.zeros(number_of_rounds)
    proposer_tokens_encoded_by_round = np.zeros(number_of_rounds)
    evaluator_tokens_generated_by_round = np.zeros(number_of_rounds)
    evaluator_tokens_encoded_by_round = np.zeros(number_of_rounds)
    proposer_nodes_count = 0
    evaluator_nodes_count = 0
    # test_output is deterministic per expression, so each distinct winning
    # expression is evaluated once instead of once per resample.
    verdict_by_expression = {}
    for i in range(n):
        record = data[i]
        steps = record["steps"]
        infos = record["infos"]
        final_solutions = record["ys"]
        problem = steps[0]["x"]

        # compute best of k: solved at index j means solved for every k >= j
        for j, info in enumerate(infos):
            if info["r"] == 1:
                best_at_k[j:] += 1
                break

        # compute SC: majority vote over random subsets of the final solutions
        n_final_solutions = len(final_solutions)
        # Normalisation does not depend on the draw, so do it once per record.
        normalized_solutions = [normalize(problem, solution) for solution in final_solutions]
        for sample_idx in range(total_sample):
            for solution_num in range(1, n_final_solutions + 1):
                picked = np.random.choice(n_final_solutions, solution_num, replace=False)
                SC_dict = defaultdict(int)
                for solution_idx in picked:
                    numbers, expression = normalized_solutions[solution_idx]
                    # Only expressions built from exactly four matching numbers get a vote.
                    if len(numbers) == 4:
                        SC_dict[expression] += 1
                if SC_dict:
                    # max() keeps the first-inserted expression among ties,
                    # matching a stable descending sort by vote count.
                    winner = max(SC_dict.items(), key=lambda item: item[1])[0]
                    if winner not in verdict_by_expression:
                        verdict_by_expression[winner] = test_output(winner)
                    SC[sample_idx, solution_num - 1] += verdict_by_expression[winner]

        if noagents:
            proposer_nodes_count += sum(len(step["new_steps"]) for step in steps)
            continue

        agents = record["agents"]
        # going through all the agents
        proposer_nodes_count += 1
        for proposer_idx in range(number_of_rounds):
            if proposer_idx == 0:
                proposer_names = [f"proposer{proposer_idx}_{0}"]
            else:
                proposer_names = [f"proposer{proposer_idx}_{k}" for k in range(breadth)]
            for proposer_name in proposer_names:
                if proposer_name not in agents:
                    # if it generates fewer candidates than breadth, then it will not be in the agents
                    continue
                proposer = agents[proposer_name]
                prompt = proposer[0][1]
                completions = proposer[1:]
                if not completions:
                    continue
                # The prompt is the same for every completion: count its tokens once.
                proposer_token_encoded = overall_utils.num_tokens_from_string(prompt, model=model)
                for completion in completions:
                    completion = completion[1]
                    # Each line of a completion counts as one proposed node.
                    proposer_nodes_count += len(completion.split("\n"))
                    proposer_token_generated = overall_utils.num_tokens_from_string(completion, model=model)
                    proposer_tokens_generated_by_round[proposer_idx] += proposer_token_generated
                    # The prompt cost is charged once per completion.
                    proposer_tokens_encoded_by_round[proposer_idx] += proposer_token_encoded

            # for evaluator
            # NOTE(review): substring match, so "evaluator1" would also match an
            # "evaluator10..." name if more than ten rounds are ever used.
            round_tag = f"evaluator{proposer_idx}"
            evaluator_names_for_this_round = [name for name in agents if round_tag in name]
            for evaluator_name in evaluator_names_for_this_round:
                evaluator = agents[evaluator_name]
                prompt = evaluator[0][1]
                completions = [entry[1] for entry in evaluator[1:]]
                evaluator_nodes_count += len(completions)
                evaluator_token_generated = sum(
                    overall_utils.num_tokens_from_string(completion, model=model)
                    for completion in completions
                )
                evaluator_token_encoded = overall_utils.num_tokens_from_string(prompt, model=model)
                evaluator_tokens_generated_by_round[proposer_idx] += evaluator_token_generated
                evaluator_tokens_encoded_by_round[proposer_idx] += evaluator_token_encoded
    tokens_count = [
        proposer_tokens_generated_by_round,
        proposer_tokens_encoded_by_round,
        evaluator_tokens_generated_by_round,
        evaluator_tokens_encoded_by_round,
    ]
    SC = np.mean(SC, axis=0)
    # NOTE(review): node counts are averaged over every record in the file
    # (data_size) while best_at_k and SC are averaged over n; the two agree
    # only when n == data_size. Confirm which is intended.
    nodes_count_dict = {
        "proposer_nodes_count": proposer_nodes_count / data_size,
        "evaluator_nodes_count": evaluator_nodes_count / data_size,
        "total_nodes_count": (proposer_nodes_count + evaluator_nodes_count) / data_size,
    }
    return {"tokens_count": tokens_count, "nodes_count": nodes_count_dict, "best_at_k": best_at_k / n, "SC": SC / n}

# file = "results/game24/gpt-4-0613_gpt-4-0613_0.7_propose1_value3_greedy5_start900_end1000_ToT_3shots_gpt4value/steps.json"
# gpt4_tot3shot_b5_gpt4val = compute_accuracy_budget(file, number_of_rounds=4, breadth=5,n=100,CoT=False,noagents=False)