import numpy as np
import sys
import json
import re, string, os
from collections import defaultdict, Counter
import math
sys.path.append("../../gsm8k/evaluations/")
from eval_MATH import estimate_pass_at_k, GPT3, GPT4,entropy
sys.path.append("../..")
import overall_utils

from CSQA_eval import normalize_answer_CSQA, grade_answer_CSQA, parse_answer_noreact_CSQA
from hotpotQA_eval import normalize_answer_hotpot, grade_answer_hotpot, parse_answer_noreact_hotpot


def Count_Answers_given_prediction_pairs(prediction_pairs, normalize_answer,return_complexity=False):
    """Group prediction pairs by normalized answer and rank answers by vote count.

    Args:
        prediction_pairs: iterable of tuples whose first three items are
            ``(prediction_text, answer, confidence)``; extra items are ignored.
            ``None`` entries are skipped. A ``None`` answer is NOT dropped: it
            is counted as its own bucket like any other answer.
        normalize_answer: callable applied to every non-``None`` answer.
        return_complexity: when True, also return one
            ``(answer, len(prediction_text), number_of_sentences)`` tuple per
            counted pair, in input order.

    Returns:
        A list of ``(answer, [confidence, ...])`` tuples sorted by number of
        votes, most votes first; ties keep first-seen order. When no pair was
        counted, the empty (default)dict is returned instead of a list. With
        ``return_complexity=True`` the result is always a
        ``(ranking, complexity)`` tuple, including the empty case.
    """
    confidence_dict = defaultdict(list)  # answer -> list of confidences
    complexity_list = []  # (answer, prediction length, sentence count) per pair
    for pair in prediction_pairs:
        if pair is None:
            continue
        pred, pred_answer, confidence = pair[0], pair[1], pair[2]
        if pred_answer is not None:
            pred_answer = normalize_answer(pred_answer)
        confidence_dict[pred_answer].append(confidence)
        # Count sentences by splitting on periods that are neither preceded by
        # a digit nor followed by a digit / end of text (so "3.5" is not split).
        num_sentences = len(re.split(r'(?<!\d)\.(?!\d|$)', pred))
        complexity_list.append((pred_answer, len(pred), num_sentences))
    if len(confidence_dict) == 0:
        # Keep returning the (empty) dict itself for backward compatibility.
        ranking = confidence_dict
    else:
        # sorted() is stable and reverse=True keeps equal keys in original order,
        # so vote ties are broken by first appearance.
        ranking = sorted(confidence_dict.items(), key=lambda x: len(x[1]), reverse=True)
    if return_complexity:
        # Previously the empty case ignored this flag and returned a bare dict,
        # breaking callers that unpack two values.
        return ranking, complexity_list
    return ranking

### COT ##############################################

def extract_CoT_predictions(agents,normalize_answer,parse_answer_noreact,ReAct=False,model=GPT3):
    """Collect every answer produced by the chain-of-thought agents of one problem.

    Only agents whose name contains ``"thought_agent"`` are read. For each such
    agent, ``agents[name][0][1]`` is treated as the prompt and every later entry
    ``agents[name][k][1]`` (k >= 1) as one generated solution.

    Args:
        agents: mapping from agent name to its list of conversation entries;
            element ``[1]`` of each entry is the message text.
        normalize_answer: callable applied to every parsed, non-``None`` answer.
        parse_answer_noreact: callable returning a 2-tuple whose second item is
            the extracted answer (or ``None``).
        ReAct: unused; kept for signature compatibility.
        model: forwarded to ``overall_utils.num_tokens_from_string``.

    Returns:
        A list of ``(solution_text, answer_or_None, 1, solution_tokens,
        prompt_tokens)`` tuples, in agent-iteration order. Every solution of an
        agent carries that agent's single prompt token count.
    """
    predictions_pairs = []
    for agent_name, agent_convs in agents.items():
        if "thought_agent" not in agent_name:
            continue
        prompt_token = overall_utils.num_tokens_from_string(agent_convs[0][1], model)
        for conv in agent_convs[1:]:
            pred = conv[1]
            pred_token = overall_utils.num_tokens_from_string(pred, model)
            _, pred_answer = parse_answer_noreact(pred)
            if pred_answer is not None:
                pred_answer = normalize_answer(pred_answer)
            # The constant 1 is the per-solution confidence/vote weight.
            predictions_pairs.append((pred, pred_answer, 1, pred_token, prompt_token))
    return predictions_pairs

def extract_SD_predictions(agents,normalize_answer,parse_answer_noreact,ReAct=False,model=GPT3):
    """Collect thought-agent answers plus the one-off SD planning cost of one problem.

    If ``"SELECT"`` is a key of ``agents``, the ``SELECT``, ``ADAPT`` and
    ``IMPLEMENT`` agents are all read (a missing one raises ``KeyError``); the
    token counts of each one's first prompt (``[0][1]``) and first reply
    (``[1][1]``) are summed into a per-problem overhead.

    Thought-agent answers are then collected from every agent whose name
    contains ``"thought_agent"``: entry ``[0][1]`` is the prompt, each later
    entry ``[k][1]`` (k >= 1) is one solution.

    Args:
        agents: mapping from agent name to its list of conversation entries.
        normalize_answer: callable applied to every parsed, non-``None`` answer.
        parse_answer_noreact: callable returning a 2-tuple whose second item is
            the extracted answer (or ``None``).
        ReAct: unused; kept for signature compatibility.
        model: forwarded to ``overall_utils.num_tokens_from_string``.

    Returns:
        ``(predictions_pairs, [task_prompt_tokens, task_generation_tokens])``,
        where ``predictions_pairs`` holds ``(solution_text, answer_or_None, 1,
        solution_tokens, prompt_tokens)`` tuples.
    """
    task_prompt_token = 0
    task_generation_token = 0
    if "SELECT" in agents:
        for stage_name in ("SELECT", "ADAPT", "IMPLEMENT"):
            stage_convs = agents[stage_name]
            task_prompt_token += overall_utils.num_tokens_from_string(stage_convs[0][1], model)
            task_generation_token += overall_utils.num_tokens_from_string(stage_convs[1][1], model)

    predictions_pairs = []
    for agent_name, agent_convs in agents.items():
        if "thought_agent" not in agent_name:
            continue
        prompt_token = overall_utils.num_tokens_from_string(agent_convs[0][1], model)
        for conv in agent_convs[1:]:
            pred = conv[1]
            pred_token = overall_utils.num_tokens_from_string(pred, model)
            _, pred_answer = parse_answer_noreact(pred)
            if pred_answer is not None:
                pred_answer = normalize_answer(pred_answer)
            # The constant 1 is the per-solution confidence/vote weight.
            predictions_pairs.append((pred, pred_answer, 1, pred_token, prompt_token))
    return predictions_pairs, [task_prompt_token,task_generation_token]
    
def evaluate_CoT_SC(file_name, task, strategy="COT", max_agents=5, sample_times=100,seed=0, ReAct=True,shifted=0,model=GPT3):
    """Evaluate chain-of-thought (or SD) self-consistency by majority vote over sampled solutions.

    For each problem, the predictions are extracted and cut to the window
    ``[shifted, shifted + max_agents)``. Then, for every subset size ``k`` in
    ``1..max_agents``, ``sample_times`` random subsets of size ``k`` (without
    replacement) are majority-voted with ``Count_Answers_given_prediction_pairs``.

    Args:
        file_name: path to a JSON list of problems; each item's ``["steps"][0]``
            must hold ``"answer"`` and ``"agents"``.
        task: ``"CSQA"`` or ``"hotpotQA"``; selects the grading / normalizing /
            parsing helpers.
        strategy: ``"SD"`` uses ``extract_SD_predictions`` and adds its
            per-problem planning tokens; anything else uses
            ``extract_CoT_predictions``.
        max_agents: largest subset size. Every problem needs at least this many
            predictions after shifting, otherwise ``np.random.choice`` raises.
        sample_times: number of random subsets drawn per subset size.
        seed: passed to ``np.random.seed`` before evaluation.
        ReAct: forwarded to the extractor.
        shifted: offset of the first prediction used.
        model: forwarded to the extractor for token counting.

    Returns:
        A dict with the following keys:

        * ``"MV"``: ``{"result", "tokens_generated", "tokens_encoded"}``, each of
          shape ``(sample_times, max_agents)``. ``result`` is the majority-vote
          accuracy averaged over problems; the token arrays are totals summed
          over problems.
        * ``"best_at_k"``: ``estimate_pass_at_k`` per k, averaged over problems,
          plus token totals.
        * ``"each_agent_acc"``: accuracy of the i-th (shifted) prediction alone.
        * ``"avg_entropy"``: per subset size, the answer entropy averaged over
          problems and samples.
        * ``"problems_get_more_wrong"`` / ``"problems_get_strictly_more_wrong"``:
          MV accuracy restricted to problems where a wrong answer gets at least
          as many / strictly more votes than the correct one. These are NaN
          (NumPy warns) when no such problem exists.
        * ``"more_wrong_problem_indexes"``: the problem indexes for the strict
          case.
        * ``"solutions_correct_percent"`` / ``"solutions_correct_percent_None"``:
          one ``(answer, vote_share, correctness)`` tuple per distinct answer of
          the full prediction set.

    Raises:
        ValueError: if ``task`` is not supported.
    """
    task_functions = {
        "CSQA": (grade_answer_CSQA, normalize_answer_CSQA, parse_answer_noreact_CSQA),
        "hotpotQA": (grade_answer_hotpot, normalize_answer_hotpot, parse_answer_noreact_hotpot),
    }
    if task not in task_functions:
        # Previously an unknown task surfaced later as an UnboundLocalError.
        raise ValueError(f"Unsupported task {task!r}; expected one of {list(task_functions)}")
    grade_answer, normalize_answer, parse_answer_noreact = task_functions[task]
    np.random.seed(seed)
    with open(file_name, "r") as f:
        response_dict = json.load(f)
    data_size = len(response_dict)
    print(f"data size: {data_size}")
    shape = (sample_times, max_agents)
    MV = np.zeros(shape); MV_tokens_generated = np.zeros(shape); MV_tokens_encoded = np.zeros(shape)
    best_at_k = np.zeros(max_agents); best_at_k_tokens_generated = np.zeros(max_agents); best_at_k_tokens_encoded = np.zeros(max_agents)
    each_agent_acc = np.zeros(max_agents)
    avg_entropy = np.zeros(shape)
    problems_get_more_wrong = np.zeros(shape); problem_counts = 0
    problems_get_strictly_more_wrong = np.zeros(shape); problem_strictly_counts = 0
    problem_indexes = []
    solutions_correct_percent = []; solutions_correct_percent_None = []
    for i, problem in enumerate(response_dict):
        content = problem["steps"][0]
        gold = content["answer"]
        agents = content["agents"]
        # get predictions, restricted to the window [shifted, shifted + max_agents)
        if strategy == "SD":
            predictions_pairs, additional_task_tokens = extract_SD_predictions(agents,normalize_answer,parse_answer_noreact,ReAct=ReAct,model=model)
            predictions_pairs = predictions_pairs[shifted:max_agents+shifted]
            # SD's planning cost is paid once per problem, so it is added to every
            # cell regardless of how many solutions are sampled.
            MV_tokens_encoded += additional_task_tokens[0]
            MV_tokens_generated += additional_task_tokens[1]
            best_at_k_tokens_encoded += additional_task_tokens[0]
            best_at_k_tokens_generated += additional_task_tokens[1]
        else:
            predictions_pairs = extract_CoT_predictions(agents,normalize_answer,parse_answer_noreact,ReAct=ReAct,model=model)[shifted:max_agents+shifted]
        temp_problems_get_more_wrong = np.zeros(shape)
        for sample_idx in range(sample_times):
            for agent_num in range(1,max_agents+1):
                col = agent_num - 1
                # random sample of agent_num solutions without replacement
                chosen = np.random.choice(len(predictions_pairs), agent_num, replace=False)
                sampled_predictions_pairs = [predictions_pairs[idx] for idx in chosen]
                sampled_predictions_answers = [pair[1] for pair in sampled_predictions_pairs]
                avg_entropy[sample_idx, col] += entropy(sampled_predictions_answers)
                sorted_by_mv = Count_Answers_given_prediction_pairs(sampled_predictions_pairs,normalize_answer)
                if len(sorted_by_mv) != 0:
                    correctness = grade_answer(sorted_by_mv[0][0], gold)
                    MV[sample_idx, col] += correctness
                    temp_problems_get_more_wrong[sample_idx, col] += correctness
                MV_tokens_generated[sample_idx, col] += sum([pair[3] for pair in sampled_predictions_pairs])
                MV_tokens_encoded[sample_idx, col] += np.sum([pair[4] for pair in sampled_predictions_pairs])

                # Everything below is deterministic and only done on the first sample round.
                if sample_idx != 0:
                    continue

                # each agent acc: the agent_num-th prediction graded on its own
                temp_pred_pair = predictions_pairs[col]
                if temp_pred_pair is not None:
                    each_agent_acc[col] += grade_answer(temp_pred_pair[1], gold)

                if agent_num != max_agents:
                    continue

                # best at k, oracle setting (once per problem)
                total = [max_agents]
                correct = [sum(1 for pair in predictions_pairs if pair is not None and grade_answer(pair[1], gold) == 1)]
                avg_tokens_generated = np.mean([pair[3] for pair in predictions_pairs])
                avg_tokens_encoded = np.mean([pair[4] for pair in predictions_pairs])
                for k in range(max_agents):
                    best_at_k[k] += estimate_pass_at_k(total, correct, k+1)
                    best_at_k_tokens_generated[k] += avg_tokens_generated*(k+1)
                    best_at_k_tokens_encoded[k] += avg_tokens_encoded

                # calibration by percentage of correct answers (this draw holds all
                # max_agents predictions). NOTE(review): the vote ranking also keeps
                # None answers, so both lists appear to carry the same shares unless
                # normalize_answer is not idempotent -- confirm intent.
                answer_counts = Counter(sampled_predictions_answers)
                None_total = len(sampled_predictions_answers)
                for answer, count in answer_counts.items():
                    solutions_correct_percent_None.append((answer, count/None_total, grade_answer(answer, gold)))
                no_None_total = sum([len(votes) for _, votes in sorted_by_mv])
                for answer, votes in sorted_by_mv:
                    solutions_correct_percent.append((answer, len(votes)/no_None_total, grade_answer(answer, gold)))

        # Check whether wrong answers out-vote the correct one. After the loops,
        # sorted_by_mv is the ranking of the last draw (agent_num == max_agents),
        # i.e. a vote over all max_agents predictions of this problem.
        longest_length = len(sorted_by_mv[0][1])
        correct_answer_exist = False; correct_answer_length = 9999999999
        longest_but_wrong = False
        for answer, votes in sorted_by_mv:
            cur_length = len(votes)
            if grade_answer(answer, gold) == 1:
                correct_answer_exist = True
                correct_answer_length = cur_length
            elif cur_length == longest_length:
                longest_but_wrong = True
        if correct_answer_exist:
            if correct_answer_length < longest_length:
                # 1) the correct answer has fewer votes than the top answer
                problems_get_more_wrong += temp_problems_get_more_wrong
                problems_get_strictly_more_wrong += temp_problems_get_more_wrong
                problem_counts += 1
                problem_strictly_counts += 1
                problem_indexes.append(i)
            elif correct_answer_length == longest_length and longest_but_wrong:
                # 2) the correct answer ties with a wrong answer for the most votes
                problems_get_more_wrong += temp_problems_get_more_wrong
                problem_counts += 1
    problems_get_more_wrong = problems_get_more_wrong/problem_counts; problems_get_strictly_more_wrong = problems_get_strictly_more_wrong/problem_strictly_counts
    print(f"problem gets more wrong counts: {problem_counts}, problem gets strictly more wrong counts: {problem_strictly_counts}")
    MV = MV/data_size; best_at_k = best_at_k/data_size; each_agent_acc = each_agent_acc/data_size; avg_entropy = avg_entropy/data_size
    MV_result = {"result": MV, "tokens_generated":MV_tokens_generated, "tokens_encoded":MV_tokens_encoded}
    best_at_k_result = {"result":best_at_k, "tokens_generated":best_at_k_tokens_generated, "tokens_encoded":best_at_k_tokens_encoded}
    return {"MV": MV_result,  "best_at_k":best_at_k_result,"each_agent_acc":each_agent_acc,"avg_entropy":np.mean(avg_entropy,axis=0),"problems_get_more_wrong":problems_get_more_wrong,
            "problems_get_strictly_more_wrong":problems_get_strictly_more_wrong,"more_wrong_problem_indexes":problem_indexes,"solutions_correct_percent":solutions_correct_percent,
            "solutions_correct_percent_None":solutions_correct_percent_None}

### MAD ##############################################

def extract_MAD_predictions(agents,agent_num,round_num,normalize_answer,parse_answer_noreact,ReAct=False,model=GPT3,partial=False):
    """Collect each debate player's per-round answers and token costs.

    For player ``i``, the conversation is read from ``agents["player{i}_round0"]``
    when that key exists, else from ``agents["player{i}"]``. Entry ``2*r`` is
    the round-``r`` prompt and entry ``2*r+1`` the round-``r`` reply.

    Token accounting:

    * Prompt tokens are cumulative over rounds for a player. Round ``r``'s count
      therefore includes every earlier prompt (and summarizer prompt) of that
      player.
    * If ``agents["summarize_player{i}_round{r+1}"]`` exists, its prompt is added
      to the prompt tokens. The slice of that prompt after the first
      ``"Answer:"`` (and before any second one), cut at the ``Summarize the
      answer`` instruction, is added to the generated tokens. That slice is the
      original, unsummarized solution. The summarizer's own reply is not
      counted again.

    Args:
        agents: mapping from agent name to its list of conversation entries;
            element ``[1]`` of each entry is the message text.
        agent_num: number of players.
        round_num: number of rounds read per player.
        normalize_answer: callable applied to every parsed, non-``None`` answer.
        parse_answer_noreact: callable returning a 2-tuple whose second item is
            the extracted answer (or ``None``).
        ReAct, partial: unused; kept for signature compatibility.
        model: forwarded to ``overall_utils.num_tokens_from_string``.

    Returns:
        ``agent_num`` lists, one per player, each holding ``round_num`` tuples
        ``(reply_text, answer_or_None, 1, generated_tokens, prompt_tokens)``.
    """
    predictions_pairs_agents = []
    for i in range(agent_num):
        agent_name = f"player{i}_round0"
        if agent_name not in agents:
            agent_name = f"player{i}"
        agent_convs = agents[agent_name]
        player_predictions = []
        prompt_token = 0  # cumulative across this player's rounds
        for conv_idx in range(round_num):
            prompt = agent_convs[conv_idx*2][1]
            pred = agent_convs[conv_idx*2+1][1]
            prompt_token += overall_utils.num_tokens_from_string(prompt, model)
            pred_token = overall_utils.num_tokens_from_string(pred, model)
            summary_agent_name = f"summarize_player{i}_round{conv_idx+1}"
            if summary_agent_name in agents:
                summary_prompt = agents[summary_agent_name][0][1]
                prompt_token += overall_utils.num_tokens_from_string(summary_prompt, model)
                # add the original generated solution token count
                original_pred = summary_prompt.split("Answer:")[1]
                original_pred = original_pred.split("\n\nSummarize the answer ")[0]
                pred_token += overall_utils.num_tokens_from_string(original_pred, model)
            _, pred_answer = parse_answer_noreact(pred)
            if pred_answer is not None:
                pred_answer = normalize_answer(pred_answer)
            player_predictions.append((pred, pred_answer, 1, pred_token, prompt_token))
        predictions_pairs_agents.append(player_predictions)
    return predictions_pairs_agents

def evaluate_MAD_SC(file_name, task,number_of_agent=3, number_of_round=2,max_queries=6, sample_times=100,seed=0,model=GPT3,indexes=None,partial=False):
    """Evaluate multi-agent-debate (MAD) outputs under several voting budgets.

    Each problem's predictions form a grid of ``number_of_agent`` players by
    ``number_of_round`` rounds, as returned by ``extract_MAD_predictions``. For
    each budget of ``q`` queries (``1..max_queries``), the following are
    computed:

    * ``MV``: majority vote over ``q`` predictions sampled without replacement
      from every round up to ``ceil(q / number_of_agent)``.
    * ``MV2``: like ``MV``, but all predictions of the earlier rounds are always
      included and only the last round is sampled.
    * ``expand_agent_wise_SC`` / ``expand_round_wise_SC``: deterministic
      majority vote over the first ``q`` predictions, in player-major /
      round-major order.
    * ``round_wise_SC``, ``round_wise_entropy``, ``round_wise_bestatk``: the
      per-round majority vote, the answer entropy, and whether any player is
      correct in that round.
    * ``oracle_result``:
      - ``result[k]`` is the fraction of problems whose first correct
        prediction, in round-major order, is among the first ``k+1``.
      - Token sums of the first ``k+1`` predictions are accumulated only for
        ``k`` up to and including that first correct prediction.

    Accuracy arrays are averaged over problems; token arrays are totals summed
    over problems.

    Args:
        file_name: path to a JSON list of problems; each item's ``["steps"][0]``
            must hold ``"answer"`` and ``"agents"``.
        task: ``"CSQA"`` or ``"hotpotQA"``.
        number_of_agent, number_of_round: debate grid size.
        max_queries: largest budget. Must not exceed
            ``number_of_agent * number_of_round``.
        sample_times: number of random draws per budget.
        seed: passed to ``np.random.seed``.
        model: forwarded to ``extract_MAD_predictions`` for token counting.
        indexes: optional positions of the problems to keep. Any sequence works,
            including a NumPy array.
        partial: forwarded to ``extract_MAD_predictions``.

    Raises:
        ValueError: for an unsupported ``task`` or a too-large ``max_queries``.
    """
    task_functions = {
        "CSQA": (grade_answer_CSQA, normalize_answer_CSQA, parse_answer_noreact_CSQA),
        "hotpotQA": (grade_answer_hotpot, normalize_answer_hotpot, parse_answer_noreact_hotpot),
    }
    if task not in task_functions:
        # Previously an unknown task surfaced later as an UnboundLocalError.
        raise ValueError(f"Unsupported task {task!r}; expected one of {list(task_functions)}")
    grade_answer, normalize_answer, parse_answer_noreact = task_functions[task]
    if max_queries > number_of_agent * number_of_round:
        # Previously this failed with an IndexError deep inside the sampling loop.
        raise ValueError(f"max_queries={max_queries} exceeds number_of_agent*number_of_round={number_of_agent*number_of_round}")
    np.random.seed(seed)
    with open(file_name, "r") as f:
        response_dict = json.load(f)
    # `is not None` (not `!= None`) so a NumPy index array does not raise an ambiguity error.
    if indexes is not None:
        response_dict = [response_dict[idx] for idx in indexes]
    data_size = len(response_dict)
    shape = (sample_times, max_queries)
    MV = np.zeros(shape); MV_tokens_generated = np.zeros(shape); MV_tokens_encoded = np.zeros(shape)
    MV2 = np.zeros(shape); MV2_tokens_generated = np.zeros(shape); MV2_tokens_encoded = np.zeros(shape)
    oracle_at_k = np.zeros(max_queries); oracle_at_k_tokens_generated = np.zeros(max_queries); oracle_at_k_tokens_encoded = np.zeros(max_queries)
    expand_agent_wise_SC = np.zeros(max_queries) # expand on agent first, so 3q = 3 agents on round 1
    expand_round_wise_SC = np.zeros(max_queries)
    round_wise_SC = np.zeros(number_of_round) # first round SC, 2nd round SC,...
    round_wise_bestatk = np.zeros(number_of_round)
    round_wise_entropy = np.zeros(number_of_round)
    for problem in response_dict:
        content = problem["steps"][0]
        gold = content["answer"]
        agents = content["agents"]
        predictions_pairs_agents = extract_MAD_predictions(agents,number_of_agent,number_of_round,normalize_answer,parse_answer_noreact,model=model,partial=partial)
        # Flattened views of the (player, round) grid, built once per problem.
        agent_major_pairs = [predictions_pairs_agents[agent_idx][round_idx] for agent_idx in range(number_of_agent) for round_idx in range(number_of_round)]
        round_major_pairs = [predictions_pairs_agents[agent_idx][round_idx] for round_idx in range(number_of_round) for agent_idx in range(number_of_agent)]
        for query_num in range(1,max_queries+1):
            col = query_num - 1
            # Rounds needed to cover query_num predictions, and the pool they span.
            sample_round = math.ceil(query_num/number_of_agent)
            predictions_pairs = round_major_pairs[:sample_round*number_of_agent]
            previous_rounds = (sample_round-1)*number_of_agent
            prediction_pairs_pool = predictions_pairs[previous_rounds:]
            last_round_sample_num = query_num - previous_rounds
            for sample_idx in range(sample_times):
                # MV: random sample by round order
                chosen = np.random.choice(len(predictions_pairs), query_num, replace=False)
                sampled_predictions_pairs = [predictions_pairs[idx] for idx in chosen]
                MV_tokens_generated[sample_idx, col] += sum(pair[3] for pair in sampled_predictions_pairs)
                MV_tokens_encoded[sample_idx, col] += sum(pair[4] for pair in sampled_predictions_pairs)
                sorted_by_mv = Count_Answers_given_prediction_pairs(sampled_predictions_pairs,normalize_answer)
                if len(sorted_by_mv) != 0:
                    MV[sample_idx, col] += grade_answer(sorted_by_mv[0][0], gold)

                # MV2: previous rounds always included, only the last round sampled
                chosen = np.random.choice(len(prediction_pairs_pool), last_round_sample_num, replace=False)
                sampled_predictions_pairs = predictions_pairs[:previous_rounds] + [prediction_pairs_pool[idx] for idx in chosen]
                MV2_tokens_generated[sample_idx, col] += sum(pair[3] for pair in sampled_predictions_pairs)
                MV2_tokens_encoded[sample_idx, col] += sum(pair[4] for pair in sampled_predictions_pairs)
                sorted_by_mv = Count_Answers_given_prediction_pairs(sampled_predictions_pairs,normalize_answer)
                if len(sorted_by_mv) != 0:
                    MV2[sample_idx, col] += grade_answer(sorted_by_mv[0][0], gold)

                # deterministic prefix votes (player-major / round-major), once per budget
                if sample_idx == 0:
                    expand_agent_wise_sorted_by_mv = Count_Answers_given_prediction_pairs(agent_major_pairs[:query_num],normalize_answer)
                    expand_round_wise_sorted_by_mv = Count_Answers_given_prediction_pairs(round_major_pairs[:query_num],normalize_answer)
                    if len(expand_agent_wise_sorted_by_mv) != 0:
                        expand_agent_wise_SC[col] += grade_answer(expand_agent_wise_sorted_by_mv[0][0], gold)
                    if len(expand_round_wise_sorted_by_mv) != 0:
                        expand_round_wise_SC[col] += grade_answer(expand_round_wise_sorted_by_mv[0][0], gold)

                # round-wise & oracle at k, once per problem
                if sample_idx == 0 and query_num == max_queries:
                    for round_idx in range(number_of_round):
                        round_wise_prediction_pairs = [predictions_pairs_agents[agent_idx][round_idx] for agent_idx in range(number_of_agent)]
                        round_wise_pred_answers = [pair[1] for pair in round_wise_prediction_pairs]
                        round_wise_entropy[round_idx] += entropy(round_wise_pred_answers)
                        round_wise_sorted_by_mv = Count_Answers_given_prediction_pairs(round_wise_prediction_pairs,normalize_answer)
                        if len(round_wise_sorted_by_mv) != 0:
                            round_wise_SC[round_idx] += grade_answer(round_wise_sorted_by_mv[0][0], gold)
                        if any(grade_answer(answer, gold) == 1 for answer in round_wise_pred_answers):
                            round_wise_bestatk[round_idx] += 1

                    for k in range(max_queries):
                        oracle_at_k_tokens_generated[k] += sum(pair[3] for pair in round_major_pairs[:k+1])
                        oracle_at_k_tokens_encoded[k] += sum(pair[4] for pair in round_major_pairs[:k+1])
                        if grade_answer(round_major_pairs[k][1], gold) == 1:
                            # solved within k+1 queries, hence also for every larger budget
                            oracle_at_k[k:] += 1
                            break

    MV = MV/data_size; MV2 = MV2/data_size; expand_agent_wise_SC = expand_agent_wise_SC/data_size; expand_round_wise_SC = expand_round_wise_SC/data_size; round_wise_SC = round_wise_SC/data_size; oracle_at_k = oracle_at_k/data_size; round_wise_entropy = round_wise_entropy/data_size; round_wise_bestatk = round_wise_bestatk/data_size
    MV_result = {"result": MV, "tokens_generated":MV_tokens_generated, "tokens_encoded":MV_tokens_encoded}
    MV2_result = {"result": MV2, "tokens_generated":MV2_tokens_generated, "tokens_encoded":MV2_tokens_encoded}
    oracle_result = {"result":oracle_at_k, "tokens_generated":oracle_at_k_tokens_generated, "tokens_encoded":oracle_at_k_tokens_encoded,"round_wise_bestatk":round_wise_bestatk}
    return {"MV": MV_result, "MV2":MV2_result,"oracle_result":oracle_result,"expand_agent_wise_SC":expand_agent_wise_SC, "expand_round_wise_SC":expand_round_wise_SC,"round_wise_SC":round_wise_SC,"round_wise_entropy":round_wise_entropy}


### Reflexion ##############################################

def extract_Reflexion_predictions(agents,n_trials, normalize_answer,parse_answer_noreact,ReAct=False,model=GPT3):
    """Collect the proposer's answer and token cost for each Reflexion trial.

    Trial ``t`` reads ``agents["round{t}_proposer"]``: entry ``[0][1]`` is the
    prompt and entry ``[1][1]`` is the proposal. For every trial after the
    first, the prompt/reply tokens of ``agents["round{t}_reflector"]`` are added
    to that trial's prompt/generated tokens. Token counts are per trial, not
    cumulative.

    Args:
        agents: mapping from agent name to its list of conversation entries.
        n_trials: number of trials to read; a missing proposer/reflector raises
            ``KeyError``.
        normalize_answer: callable applied to every parsed, non-``None`` answer.
        parse_answer_noreact: callable returning a 2-tuple whose second item is
            the extracted answer (or ``None``).
        ReAct: unused; kept for signature compatibility.
        model: forwarded to ``overall_utils.num_tokens_from_string``.

    Returns:
        ``n_trials`` tuples ``(proposal_text, answer_or_None, 1,
        generated_tokens, prompt_tokens)`` in trial order.
    """
    predictions_pairs = []
    for trial_idx in range(n_trials):
        proposer = agents[f"round{trial_idx}_proposer"]
        prompt = proposer[0][1]
        pred = proposer[1][1]
        prompt_token = overall_utils.num_tokens_from_string(prompt, model)
        pred_token = overall_utils.num_tokens_from_string(pred, model)
        if trial_idx != 0:
            # Trials after the first are preceded by a reflection step; charge it here.
            reflector = agents[f"round{trial_idx}_reflector"]
            prompt_token += overall_utils.num_tokens_from_string(reflector[0][1], model)
            pred_token += overall_utils.num_tokens_from_string(reflector[1][1], model)
        _, pred_answer = parse_answer_noreact(pred)
        if pred_answer is not None:
            pred_answer = normalize_answer(pred_answer)
        predictions_pairs.append((pred, pred_answer, 1, pred_token, prompt_token))
    return predictions_pairs

def evaluate_REFLEXION_SC(file_name, task,n_trials=4, sample_times=100,seed=0,model=GPT3,multireflect=False,indexes=None):
    """Evaluate Reflexion trials with majority voting, oracle stopping and random stopping.

    ``extract_Reflexion_predictions`` yields one prediction per trial for each
    problem.

    * ``MV``: for ``k`` in ``1..n_trials``, the first ``k`` trials are
      majority-voted. Each of the ``sample_times`` draws is a random
      permutation of those ``k`` predictions, which only affects vote
      tie-breaking.
    * ``oracle_at_k``:
      - ``result[k]`` is the fraction of problems whose first correct trial has
        0-based index ``<= k``.
      - ``oracle_stop`` is the fraction of problems with any correct trial.
      - ``oracle_stop_length[i]`` holds that index, or ``n_trials`` when no
        trial is correct.
      - ``tokens_*[j]`` adds the cumulative tokens of the first ``j+1`` trials
        for every problem that had not stopped before trial ``j``.
    * ``random_stops``: for each of ``sample_times`` draws, stop at a uniformly
      random trial and grade its answer. Cumulative tokens up to that trial are
      summed.

    Accuracies are averaged over problems; token arrays are totals summed over
    problems.

    Args:
        file_name: path to a JSON list of problems; each item's ``["steps"][0]``
            must hold ``"answer"`` and ``"agents"``.
        task: ``"CSQA"`` or ``"hotpotQA"``.
        n_trials: number of Reflexion trials per problem.
        sample_times: number of random draws.
        seed: passed to ``np.random.seed``.
        model: forwarded to ``extract_Reflexion_predictions``.
        multireflect: unused; kept for signature compatibility.
        indexes: optional positions of the problems to keep. Any sequence works,
            including a NumPy array.

    Raises:
        ValueError: if ``task`` is not supported.
    """
    task_functions = {
        "CSQA": (grade_answer_CSQA, normalize_answer_CSQA, parse_answer_noreact_CSQA),
        "hotpotQA": (grade_answer_hotpot, normalize_answer_hotpot, parse_answer_noreact_hotpot),
    }
    if task not in task_functions:
        # Previously an unknown task surfaced later as an UnboundLocalError.
        raise ValueError(f"Unsupported task {task!r}; expected one of {list(task_functions)}")
    grade_answer, normalize_answer, parse_answer_noreact = task_functions[task]
    np.random.seed(seed)
    with open(file_name, "r") as f:
        response_dict = json.load(f)
    # `is not None` (not `!= None`) so a NumPy index array does not raise an ambiguity error.
    if indexes is not None:
        response_dict = [response_dict[idx] for idx in indexes]
    data_size = len(response_dict)
    random_stops = np.zeros(sample_times); random_stops_tokens_generated = np.zeros(sample_times); random_stops_tokens_encoded = np.zeros(sample_times)
    oracle_stop = 0; oracle_stop_length = np.ones(data_size)*n_trials; oracle_tokens_encoded = np.zeros(n_trials); oracle_tokens_generated = np.zeros(n_trials)
    shape = (sample_times, n_trials)
    MV = np.zeros(shape); MV_tokens_generated = np.zeros(shape); MV_tokens_encoded = np.zeros(shape)
    for i, problem in enumerate(response_dict):
        content = problem["steps"][0]
        gold = content["answer"]
        agents = content["agents"]
        all_predictions_pairs = extract_Reflexion_predictions(agents,n_trials,normalize_answer,parse_answer_noreact,model=model)
        for agent_num in range(1,n_trials+1):
            col = agent_num - 1
            predictions_pairs = all_predictions_pairs[:agent_num]
            for sample_idx in range(sample_times):
                # MV over a random permutation of the first agent_num trials
                chosen = np.random.choice(len(predictions_pairs), agent_num, replace=False)
                sampled_predictions_pairs = [predictions_pairs[idx] for idx in chosen]
                MV_tokens_generated[sample_idx, col] += sum(pair[3] for pair in sampled_predictions_pairs)
                MV_tokens_encoded[sample_idx, col] += sum(pair[4] for pair in sampled_predictions_pairs)
                sorted_by_mv = Count_Answers_given_prediction_pairs(sampled_predictions_pairs,normalize_answer)
                if len(sorted_by_mv) != 0:
                    MV[sample_idx, col] += grade_answer(sorted_by_mv[0][0], gold)

                if sample_idx == 0 and agent_num == n_trials:
                    # oracle stop: stop at the first correct trial
                    for pair_idx, pair in enumerate(predictions_pairs):
                        oracle_tokens_encoded[pair_idx] += sum(p[4] for p in predictions_pairs[:pair_idx+1])
                        oracle_tokens_generated[pair_idx] += sum(p[3] for p in predictions_pairs[:pair_idx+1])
                        if grade_answer(pair[1], gold) == 1:
                            oracle_stop += 1
                            oracle_stop_length[i] = pair_idx
                            break

                    # random stop: must stay after the MV draw above to keep the RNG sequence
                    for j in range(sample_times):
                        random_stop_idx = np.random.randint(len(predictions_pairs))
                        random_stops_tokens_generated[j] += sum(p[3] for p in predictions_pairs[:random_stop_idx+1])
                        random_stops_tokens_encoded[j] += sum(p[4] for p in predictions_pairs[:random_stop_idx+1])
                        random_stops[j] += grade_answer(predictions_pairs[random_stop_idx][1], gold)

    MV = MV/data_size; random_stops = random_stops/data_size; oracle_stop = oracle_stop/data_size
    oracle_at_k = np.array([np.sum(oracle_stop_length<=k)/data_size for k in range(0,n_trials)])
    oracle_at_k_result = {"result":oracle_at_k,  "oracle_stop": oracle_stop,"oracle_stop_length":oracle_stop_length,"tokens_encoded":oracle_tokens_encoded,"tokens_generated":oracle_tokens_generated}
    MV_result = {"result": MV, "tokens_generated":MV_tokens_generated, "tokens_encoded":MV_tokens_encoded}
    random_stops_result = {"result":random_stops, "tokens_generated":random_stops_tokens_generated, "tokens_encoded":random_stops_tokens_encoded}
    return {"MV": MV_result, "oracle_at_k":oracle_at_k_result, "random_stops": random_stops_result}

