import yaml
import os
import sys
import json
import argparse
import overall_utils
from langchain import OpenAI, Wikipedia
from langchain.chat_models import ChatOpenAI
import pandas as pd
from .agents_game_of_24 import Game_of_24_Agent
import logging


def _load_data(task_name,task_path,config):
    if task_name == "game24":
        data = list(pd.read_csv(task_path)['Puzzles'])[config["task_start_index"]:config["task_end_index"]]
        return data


def _sample_metrics(infos):
    """Summarise one sample's per-output results.

    Args:
        infos: List of dicts, each carrying an ``'r'`` score (as produced by
            ``agent.test_output``).

    Returns:
        tuple: ``(total, mean, any_correct)``. ``total`` is the sum of the
        ``'r'`` scores. ``mean`` is ``total / len(infos)``, or ``0`` when
        ``infos`` is empty rather than raising ``ZeroDivisionError``.
        ``any_correct`` is ``any`` of the scores.
    """
    accs = [item["r"] for item in infos]
    total = sum(accs)
    mean = total / len(accs) if accs else 0
    return total, mean, any(accs)


def run_game24(config):
    """Run Game-of-24 agents over the configured puzzle range and log results.

    Loads puzzles via ``_load_data`` and builds one agent per puzzle. It then
    runs agents ``[continue_leftoff, end)`` in order. After each agent it
    rewrites ``<save_path>/steps.json`` with all logs so far and checkpoints
    the agents via ``overall_utils.save_agents``. ``config`` itself is dumped
    to ``<save_path>/config.json``. Progress and a final accuracy summary are
    printed.

    Args:
        config: Experiment settings.
            Required keys:
            - ``dataset``, ``task_file_path``, ``task_start_index`` and
              ``task_end_index``.
            - ``strategy``: ``"bfs"``, ``"CoT"`` or ``"Reflexion"``.
            - ``llms``: at least two entries, each with ``model_name``; the
              first also needs ``temperature``.
            - ``steps`` and ``additional``.
            - The ``method_*`` / ``n_*_sample`` keys used in the result path.
            Optional keys:
            - ``same_first_step_as``: path to a JSON list indexed by sample
              position.
            - ``continue_leftoff``: sample index to resume from, reading the
              existing ``steps.json``.
            - ``end``: exclusive stop index; ``0`` or missing means all
              samples.

    Raises:
        ValueError: If ``config["strategy"]`` is not a supported strategy.
    """
    # load data
    dataset = config["dataset"]
    data = _load_data(dataset, config["task_file_path"], config)
    if "same_first_step_as" in config:
        # Use a context manager so the file handle is closed promptly.
        with open(config["same_first_step_as"], "r") as f:
            same_first_step_as = json.load(f)
    else:
        same_first_step_as = []
    chat_mode = True
    logging.info(f"Loaded data, {len(data)} samples.")
    continue_leftoff = config.get("continue_leftoff", 0)

    # strategy: every supported strategy currently uses the same agent class.
    # Fail fast on anything else instead of leaving `Agent` unbound.
    if config["strategy"] in ("bfs", "CoT", "Reflexion"):
        Agent = Game_of_24_Agent
    else:
        raise ValueError(f"Unsupported strategy: {config['strategy']!r}")

    # log path
    llm_configs = list(config["llms"].values())
    model_name = llm_configs[0]["model_name"]
    model_name2 = llm_configs[1]["model_name"]
    temperature = llm_configs[0]["temperature"]
    save_path = f'results/{dataset}/{model_name}_{model_name2}_{temperature}_{config["method_generate"]}{config["n_generate_sample"]}_{config["method_evaluate"]}{config["n_evaluate_sample"]}_{config["method_select"]}{config["n_select_sample"]}_start{config["task_start_index"]}_end{config["task_end_index"]}_{config["additional"]}'
    file = os.path.join(save_path, "steps.json")
    os.makedirs(save_path, exist_ok=True)
    config_file = os.path.join(save_path, 'config.json')
    with open(config_file, "w") as f:
        json.dump(config, f, indent=4)
    logging.info(f"Saving logs to {file}...")

    # initialize agents
    logging.info("Initializing agents...")
    agents = [
        Agent(puzzle, steps=config["steps"], chat=chat_mode, configs=config)
        for puzzle in data
    ]

    cost = 0
    logs, cnt_avg, cnt_any = [], 0, 0
    if continue_leftoff != 0:
        with open(file, 'r') as f:
            logs = json.load(f)
        if logs:
            cost = logs[-1]["usage_so_far"]
        # Fold already-completed samples into the running metrics so the final
        # summary covers the whole range, not only this session. Entries at or
        # past `continue_leftoff` will be re-run below, so they are skipped here.
        for entry in logs:
            if entry.get("idx", 0) >= continue_leftoff:
                continue
            _, prev_mean, prev_any = _sample_metrics(entry.get("infos", []))
            cnt_avg += prev_mean
            cnt_any += prev_any
        logging.info(f"continue left off at {continue_leftoff}")

    # `end` of 0 (or missing) means "all samples"; never index past the agents.
    end = config.get("end") or len(agents)
    end = min(end, len(agents))
    print(f"end at {end}")
    for i in range(continue_leftoff, end):
        agent = agents[i]
        same_first_step_as_i = same_first_step_as[i] if same_first_step_as else None
        logging.info(f'running: agent{i}')
        ys, info = agent.run(same_first_step_as_i)
        # log
        infos = [agent.test_output(y) for y in ys]
        cost += agent.cost
        info.update({'idx': i, 'ys': ys, 'infos': infos, 'usage_so_far': cost, "prompt_tokens": agent.prompt_token_used, "completion_tokens": agent.completion_token_used})
        logs.append(info)
        # Rewrite the whole log each iteration so progress survives a crash.
        with open(file, 'w') as f:
            json.dump(logs, f, indent=4)

        # log main metric
        total, mean, any_correct = _sample_metrics(infos)
        cnt_avg += mean
        cnt_any += any_correct
        print(i, 'sum(accs)', total, 'cnt_avg', cnt_avg, 'cnt_any', cnt_any, '\n')
        overall_utils.save_agents(agents, os.path.join(save_path, 'agents'))

    n = config["task_end_index"] - config["task_start_index"]
    if n > 0:
        print(cnt_avg / n, cnt_any / n)
    else:
        logging.warning("Empty task range; skipping accuracy summary.")
    print('usage_so_far', cost)
    print(file)


if __name__ == '__main__':
    # Running this file directly does nothing: the work lives in
    # run_game24(config), which presumably is invoked by a separate driver
    # with an already-loaded config dict.
    # NOTE(review): the relative import `from .agents_game_of_24 ...` means this
    # module only imports as part of a package (e.g. `python -m pkg.module`).
    # yaml/argparse/sys are imported but unused in the visible code, so a CLI
    # that loads a YAML config may have been intended here. Confirm.
    pass