
import os
import sys
import json
import copy
import pandas as pd
import logging
import joblib
import numpy as np

from langchain import OpenAI, Wikipedia
from langchain.chat_models import ChatOpenAI

from .hotpotqa_agent import CoTAgent, ReflexionStrategy,CoTAggregateAgent, CoTMultiFeedbackAgent
from .hotpotqa_agent import ReactReflectAgent, ReactAgent, ReflexionStrategy
from .util import summarize_trial, log_trial
sys.path.append("..")
import overall_utils
from .hotpotqa_prompts import cot_simple_reflect_agent_prompt, cot_simple_reflect_prompt, cot_simple_agent_prompt,cot_reflect_agent_prompt,cot_reflect_prompt,io_agent_prompt
from .hotpotqa_prompts import react_reflect_agent_prompt, react_agent_prompt, reflect_prompt
from .hotpotqa_fewshots import COTQA_SIMPLE6, COT_SIMPLE_REFLECTION, COT, COT_REFLECT
from .hotpotqa_fewshots import IOQA_SIMPLE6, IO
from .hotpotqa_fewshots import CONTROL_COTQA_SIMPLE3, CONTROL_COTQA_REFLECT



def _load_prompts(strategy, context):
    """Pick prompt templates, few-shot examples and reflexion mode for a strategy.

    Args:
        strategy: Strategy name such as "CoT", "CoT_reflexion",
            "ReAct_last_trial" or "IO".
        context: When truthy, the context-aware templates/examples are used
            for the strategies that distinguish between the two variants.

    Returns:
        A 5-tuple ``(agent_prompt, examples, reflect_prompt,
        reflect_examples, reflexion_strategy)``. The ReAct strategies return
        empty strings for both example slots, and plain "ReAct" returns an
        empty string for the reflect prompt as well.

    Raises:
        ValueError: If ``strategy`` is not one of the known names.
    """

    def cot_family(mode):
        # Standard CoT bundle; only the reflexion mode differs between callers.
        if context:
            return cot_reflect_agent_prompt, COT, cot_reflect_prompt, COT_REFLECT, mode
        return (
            cot_simple_agent_prompt,
            COTQA_SIMPLE6,
            cot_simple_reflect_prompt,
            COT_SIMPLE_REFLECTION,
            mode,
        )

    def react_reflect_family(mode):
        # ReAct with reflection: no few-shot example strings are supplied.
        return react_reflect_agent_prompt, "", reflect_prompt, "", mode

    # --- CoT strategies sharing the standard templates ---
    if strategy in ("CoT", "CoT_hint"):
        return cot_family(ReflexionStrategy.NONE)
    if strategy in (
        "CoT_reflexion",
        "CoT_reflexion_aggregate_reflexion",
        "CoT_reflexion_aggregate_answer",
        "CoT_reflexion_multifeedback",
        "CoT_reflexion_multifeedback_MAR",
    ):
        return cot_family(ReflexionStrategy.REFLEXION)
    if strategy == "CoT_last_trial":
        return cot_family(ReflexionStrategy.LAST_ATTEMPT)
    if strategy == "CoT_last_trial_reflexion":
        return cot_family(ReflexionStrategy.LAST_ATTEMPT_AND_REFLEXION)

    # --- ReAct strategies ---
    if strategy == "ReAct":
        return react_agent_prompt, "", "", "", ReflexionStrategy.NONE
    if strategy == "ReAct_reflexion":
        return react_reflect_family(ReflexionStrategy.REFLEXION)
    if strategy == "ReAct_last_trial":
        return react_reflect_family(ReflexionStrategy.LAST_ATTEMPT)
    if strategy == "ReAct_last_trial_reflexion":
        return react_reflect_family(ReflexionStrategy.LAST_ATTEMPT_AND_REFLEXION)

    # --- "Fake context" variants: always the context-free templates ---
    if strategy == "CoT_reflexion_fake_context":
        # Context is supplied, but the template examples do not contain contexts.
        return (
            cot_simple_agent_prompt,
            COTQA_SIMPLE6,
            cot_simple_reflect_prompt,
            COT_SIMPLE_REFLECTION,
            ReflexionStrategy.REFLEXION,
        )
    if strategy == "CoT_reflexion_fake_context_control":
        # Same as above, but with the dedicated control few-shot examples.
        return (
            cot_simple_agent_prompt,
            CONTROL_COTQA_SIMPLE3,
            cot_simple_reflect_prompt,
            CONTROL_COTQA_REFLECT,
            ReflexionStrategy.REFLEXION,
        )

    # --- Direct input/output prompting ---
    if strategy == "IO":
        # Reflections are not needed for IO, so the CoT reflection templates
        # are reused to fill those slots.
        if context:
            return io_agent_prompt, IO, cot_reflect_prompt, COT_REFLECT, ReflexionStrategy.NONE
        return (
            io_agent_prompt,
            IOQA_SIMPLE6,
            cot_simple_reflect_prompt,
            COT_SIMPLE_REFLECTION,
            ReflexionStrategy.NONE,
        )

    raise ValueError("Invalid strategy")



# When logging, we save two things:
# 1. the recorded steps, including prompts and everything else
# 2. a joblib of the agents
def run_hotpotqa(config):
    """Run the configured HotpotQA agents for several trials and save the results.

    Loads the sampled HotpotQA questions, builds one agent per question for
    ``config['strategy']``, runs every agent once per trial, and writes the
    per-step logs, a text summary and the agents below
    ``results/hotpotqa/<strategy>/...``.

    Args:
        config: Mapping providing ``'strategy'``, ``'context'``,
            ``'n_trials'``, ``'cheat'`` and ``'openai'`` (which must contain
            ``'model_name'``). It is also forwarded to
            ``overall_utils._load_model``. The mapping is not modified.

    Raises:
        ValueError: If ``config['strategy']`` is rejected by ``_load_prompts``.
    """
    # load data
    logging.debug(f"current dir: {os.getcwd()}")
    hotpot = joblib.load('./data/hotpotqa/hotpot-qa-distractor-sample.joblib').reset_index(drop=True)
    strategy = config['strategy']
    context = config['context']
    n = config["n_trials"]
    cheat = config["cheat"]
    model_name = config['openai']['model_name']

    # Always bind chat_mode; it used to be left undefined for model names
    # without "3.5", which crashed every non-ReAct strategy.
    # NOTE(review): only "3.5" models are treated as chat models here, as in
    # the original check -- confirm whether other chat models should count.
    chat_mode = "3.5" in model_name

    if context:
        _attach_supporting_paragraphs(hotpot)

    # load agents
    llm = overall_utils._load_model(model_name, config)
    sample_llm = None
    if "aggregate" in strategy or "multifeedback" in strategy:
        # Copy the nested "openai" section as well, so that the sampling
        # overrides do not leak into the caller's config, and load the
        # sampling model from the overridden copy.
        sample_config = config.copy()
        sample_config["openai"] = copy.copy(config["openai"])
        sample_config["openai"]["temperature"] = 0.7
        sample_config["openai"]["n"] = 1
        sample_llm = overall_utils._load_model(model_name, sample_config)

    agent_prompt, cot_examples, reflect_template, reflect_examples, reflect_strat = _load_prompts(strategy, context)
    agents = _build_hotpotqa_agents(
        hotpot,
        strategy=strategy,
        context=context,
        cheat=cheat,
        chat_mode=chat_mode,
        llm=llm,
        sample_llm=sample_llm,
        agent_prompt=agent_prompt,
        cot_examples=cot_examples,
        reflect_template=reflect_template,
        reflect_examples=reflect_examples,
    )

    # save logs and agents
    print(reflect_strat)
    root = "results"
    cheat_str = "cheat" if cheat else "nocheat"
    save_path = os.path.join(root, "hotpotqa", strategy, f"trial{n}_{cheat_str}_later", "context" if context else 'no_context')
    print(f"save_path: {save_path}")
    os.makedirs(save_path, exist_ok=True)
    steps_path = os.path.join(save_path, "steps.json")

    trial = 0
    log_parts = []
    all_logs = []
    for i in range(n):
        print(f"Trial {i} ......")
        for j, agent in enumerate(agents):
            print(f"\nlogging | Question: {agent.question}")
            agent.run(reflect_strat)
            info = {
                "steps": agent.infos,
                'trial': i,
                'agent_idx': j,
                'usage_so_far': overall_utils.gpt_usage(agent.token_used, model_name),
            }
            # Snapshot the entry: agent.infos may keep changing in later trials.
            all_logs.append(copy.deepcopy(info))
            # The whole file is rewritten after every agent run, presumably
            # so that progress is kept if the run is interrupted.
            with open(steps_path, 'w', encoding="utf-8") as f:
                json.dump(all_logs, f, indent=4)
            print(f'\nlogging | Actual Answer: {agent.key}')
        trial += 1
        log_parts.append(log_trial(agents, trial))
        correct, incorrect = summarize_trial(agents)
        total_token = sum(agent.token_used for agent in agents)
        print(f"used token: {total_token}, roughly: ${total_token/1000*0.0015}")
        print(f'Finished Trial {trial}, Correct: {len(correct)}, Incorrect: {len(incorrect)}\n')

    # Explicit UTF-8 so non-ASCII text in the log cannot fail on platforms
    # whose default encoding is narrower.
    with open(os.path.join(save_path, f'{len(agents)}_questions_{n}_trials.txt'), 'w', encoding="utf-8") as f:
        f.write(''.join(log_parts))

    overall_utils.save_agents(agents, os.path.join(save_path, 'agents'))


def _attach_supporting_paragraphs(hotpot):
    """Add a 'supporting_paragraphs' column to ``hotpot`` in place.

    For every row, the sentences of each article named in
    ``row['supporting_facts']['title']`` are concatenated into one paragraph,
    and the paragraphs are joined with blank lines.
    """
    hotpot['supporting_paragraphs'] = None
    for ind, row in hotpot.iterrows():
        supporting_articles = row['supporting_facts']['title']
        articles = row['context']['title']
        sentences = row['context']['sentences']
        # `articles == article` must be an element-wise comparison for
        # np.where to work -- assumes both fields are NumPy arrays; TODO confirm.
        paragraphs = [
            ''.join(sentences[np.where(articles == article)][0])
            for article in supporting_articles
        ]
        hotpot.at[ind, 'supporting_paragraphs'] = '\n\n'.join(paragraphs)


def _build_hotpotqa_agents(hotpot, *, strategy, context, cheat, chat_mode, llm, sample_llm,
                           agent_prompt, cot_examples, reflect_template, reflect_examples):
    """Create one agent per row of ``hotpot``; the class is chosen from ``strategy``."""
    if "ReAct" in strategy:
        if "reflexion" in strategy:
            return [
                ReactReflectAgent(question=row['question'],
                                  key=row['answer'],
                                  agent_prompt=agent_prompt,
                                  reflect_prompt=reflect_template,
                                  reactLLM=llm,
                                  reflectLLM=llm)
                for _, row in hotpot.iterrows()
            ]
        return [
            ReactAgent(question=row['question'],
                       key=row['answer'],
                       agent_prompt=agent_prompt,
                       reactLLM=llm)
            for _, row in hotpot.iterrows()
        ]

    def row_context(row):
        # Without context the agents receive an empty context string.
        return row['supporting_paragraphs'] if context else ''

    if "aggregate" in strategy or "multifeedback" in strategy:
        # Both agent classes take the same arguments; "aggregate" wins when a
        # strategy name contains both markers, matching the original branch order.
        agent_cls = CoTAggregateAgent if "aggregate" in strategy else CoTMultiFeedbackAgent
        return [
            agent_cls(question=row['question'],
                      context=row_context(row),
                      key=row['answer'],
                      agent_prompt=agent_prompt,
                      cot_examples=cot_examples,
                      reflect_prompt=reflect_template,
                      reflect_examples=reflect_examples,
                      self_reflect_llm=llm,
                      action_llm=llm,
                      sample_llm=sample_llm,
                      chat=chat_mode,
                      aggregate_strategy=strategy,
                      cheat=cheat)
            for _, row in hotpot.iterrows()
        ]

    return [
        CoTAgent(question=row['question'],
                 context=row_context(row),
                 key=row['answer'],
                 agent_prompt=agent_prompt,
                 cot_examples=cot_examples,
                 reflect_prompt=reflect_template,
                 reflect_examples=reflect_examples,
                 self_reflect_llm=llm,
                 action_llm=llm,
                 chat=chat_mode,
                 cheat=cheat,
                 strategy=strategy)
        for _, row in hotpot.iterrows()
    ]


# Executing this file directly is a no-op; the experiment is started by
# calling run_hotpotqa(config).
if __name__ == '__main__':
    pass

