import os
import sys
import json
import numpy as np
import logging

sys.path.append("..")
import overall_utils
from .qa_agents import *

def _read_jsonl(path: str):
    with open(path) as fh:
        return [json.loads(line) for line in fh.readlines() if line]


def _load_data(dataset, data_path,n=100):
    np.random.seed(0)
    if dataset == 'hotpotQA':
        questions = _read_jsonl(data_path)
        return questions
    elif dataset == 'CSQA':
        questions = _read_jsonl(data_path)
        return questions

        
def _init_models(configs, key, server):
    """Build the player and feedback LLMs from their role sub-configs.

    The shared ``key`` and ``server`` values are written *in place* into both
    ``configs["llm_agent_player"]`` and ``configs["llm_agent_feedback"]``, so
    the caller's dict is mutated. Each model is then loaded with
    ``overall_utils._load_model``.

    Returns:
        ``[player_model, feedback_model]``, in that order.
    """
    role_names = ("llm_agent_player", "llm_agent_feedback")
    # Fetch both role configs up front so a missing role fails before any mutation.
    role_configs = [configs[role_name] for role_name in role_names]
    for role_config in role_configs:
        role_config["key"] = key
        role_config["server"] = server
    return [overall_utils._load_model(role_config) for role_config in role_configs]

def get_agent(strategy):
    """Return the QA agent class that implements ``strategy``.

    Several zero-shot / CoT variants share the generic feedback agent;
    MAD, MAD-with-judge and SelfDiscover have dedicated classes.

    Raises:
        NotImplementedError: If ``strategy`` is not a recognised name.
    """
    feedback_strategies = (
        'ZeroShotReflexion',
        'ZeroShot_Feedback_JID',
        'ZeroShot_CoT',
        'Evaluate',
        'CoT_RecursiveCoT',
        'CoT_Already_Evaluate',
    )
    if strategy == 'ZeroShotMAD':
        return QA_ZEROSHOTMAD
    if strategy == 'ZeroShotMAD_JUDGE':
        return QA_ZEROSHOTMADJUDGE
    if strategy in feedback_strategies:
        return QA_ZEROSHOT_FEEDBACK
    if strategy == "SelfDiscover":
        return QA_ZEROSHOT_SELFDISCOVER
    raise NotImplementedError

def run_qa(config):
    """Run one QA agent per dataset question and persist logs and agents.

    Loads the dataset, builds one agent per question with the class chosen by
    ``config["strategy"]``, and runs them in order. If
    ``config["continue_leftoff"]`` is non-zero, it resumes from that index and
    reloads the existing ``steps.json``. Everything is written under a run
    directory derived from the config:

    * ``config.json`` -- the run configuration (written before models load),
    * ``steps.json``  -- per-question step logs, rewritten after every question,
    * ``agents``      -- the agents with model/tokenizer handles cleared, saved
      via ``overall_utils.save_agents``.

    Args:
        config: Run configuration dict. Keys read here: ``data_path``,
            ``dataset``, ``continue_leftoff`` (optional), ``llms``,
            ``save_path``, ``strategy``, ``multi_run`` (optional),
            ``num_agents``, ``num_rounds``, ``additional``, ``key``, ``server``.
    """
    # load data
    data_path = config["data_path"]
    data_file_name = os.path.basename(data_path)
    data = _load_data(config["dataset"], data_path)
    continue_leftoff = config.get("continue_leftoff", 0)
    logging.info(f"Loaded data, {len(data)} samples.")
    if data:
        print(data[0])

    # log path
    llm_configs = list(config["llms"].values())
    model_name = llm_configs[0]["model_name"]
    if model_name == "":
        # Fall back to the second LLM's name when the first one is blank.
        model_name = llm_configs[1]["model_name"]
    model_name = model_name.replace("/", "-")
    save_path = os.path.join(config["save_path"], config["dataset"], config["strategy"])
    if "multi_run" in config:
        save_path = os.path.join(save_path, config["multi_run"])

    temperature = llm_configs[0]["temperature"]
    run_dir_name = (
        f"{model_name}_temp{temperature}_agent{config['num_agents']}"
        f"_round{config['num_rounds']}_datafile_{data_file_name}_{config['additional']}"
    )
    save_path = os.path.join(save_path, run_dir_name)
    file = os.path.join(save_path, 'steps.json')
    config_file = os.path.join(save_path, 'config.json')
    os.makedirs(save_path, exist_ok=True)
    # NOTE(review): this writes the whole config to disk, including
    # config["key"]. That value is forwarded to the model loaders as "key" and
    # is likely an API credential. Consider redacting it before writing.
    with open(config_file, "w") as f:
        json.dump(config, f, indent=4)
    logging.info(f"Saving logs to {file}...")

    # initialize agents
    logging.info("Initializing agents...")
    Agent = get_agent(config["strategy"])
    models = _init_models(config["llms"], config["key"], config["server"])
    agents = [
        Agent(
            models_tokenizers=models,
            question=sample["question"],
            answer=sample["answer"],
            example=sample,
            config=config,
        )
        for sample in data
    ]

    logs, cnt_avg, cnt_any = [], 0, 0
    cost = 0
    if continue_leftoff != 0:
        # Resume: reload the step logs from the earlier run and carry its
        # cumulative cost forward.
        with open(file, 'r') as f:
            logs = json.load(f)
        cost = logs[-1]["usage_so_far"] if logs else 0
        logging.info(f"continue left off at {continue_leftoff}")
    end = len(agents)
    for i in range(continue_leftoff, end):
        logging.info(f'running: question/agent{i}')
        agent = agents[i]
        success = agent.run()

        # log
        steps = agent.infos
        cost += agent.cost
        logs.append({
            "steps": steps,
            'idx': i,
            'usage_so_far': cost,
            "prompt_token_used": agent.prompt_token_used,
            "completion_token_used": agent.completion_token_used,
        })
        # Rewrite the full log after every question so a later run can resume
        # from it via continue_leftoff.
        with open(file, 'w') as f:
            json.dump(logs, f, indent=4)

        # log main metric
        cnt_avg += float(success)
        cnt_any += success
        print(i, 'cnt_avg', cnt_avg, 'cnt_any', cnt_any, '\n')

    # The counters only cover questions run in this session. len(logs) also
    # includes entries reloaded when resuming, so it is the wrong denominator.
    num_run = end - continue_leftoff
    if num_run > 0:
        print(cnt_avg / num_run, cnt_any / num_run)
    else:
        logging.warning("No questions were run in this session; skipping accuracy.")
    print('usage_so_far', cost)
    print(file)
    # Clear model/tokenizer references before saving. Presumably this keeps
    # heavy or unserializable model handles out of the saved agents.
    for agent in agents:
        agent.models_tokenizers = None
        agent.llm_agent_player = None
        agent.llm_agent_feedback = None
        for player in agent.players:
            player.model = None
            player.tokenizer = None
    overall_utils.save_agents(agents, os.path.join(save_path, 'agents'))


# No command-line entry point. This module uses a relative import
# (.qa_agents), so it is meant to be imported as part of its package;
# executing it directly (e.g. via `python -m`) does nothing.
if __name__ == '__main__':
    pass