import re, string, os
from typing import List, Union, Literal
from enum import Enum
import tiktoken
from langchain import OpenAI, Wikipedia
from langchain.llms.base import BaseLLM
from langchain.agents.react.base import DocstoreExplorer
from langchain.docstore.base import Docstore
from langchain.prompts import PromptTemplate
from langchain.schema import (
    ChatMessage,
    ChatResult,
    AIMessage,
    HumanMessage,
    SystemMessage,
)

import sys
import copy
sys.path.append("..")
from .hotpotqa_prompts import reflect_prompt, react_agent_prompt, react_reflect_agent_prompt, REFLECTION_HEADER, LAST_TRIAL_HEADER, REFLECTION_AFTER_LAST_TRIAL_HEADER, HINT_HEADER
from .hotpotqa_prompts import cot_agent_prompt, cot_reflect_agent_prompt, cot_reflect_prompt, COT_INSTRUCTION, COT_REFLECT_INSTRUCTION
from .hotpotqa_MAR import MAR_system, MAR_first,levels
from .hotpotqa_fewshots import WEBTHINK_SIMPLE6, REFLECTIONS, COT, COT_REFLECT
from overall_utils import num_tokens_from_messages

class ReflexionStrategy(Enum):
    """Ways an agent can reuse a previous trial when starting the next one.

    The member values are the short string identifiers of each strategy.
    """

    # No reflection at all.
    NONE = "base"
    # Put the last reasoning trace into the context of the next trial.
    LAST_ATTEMPT = "last_trial"
    # Apply reflexion to the next reasoning trace.
    REFLEXION = "reflexion"
    # Put the last reasoning trace into the context and also apply reflexion.
    LAST_ATTEMPT_AND_REFLEXION = "last_trial_and_reflexion"


class CoTAgent:
    """Chain-of-thought question-answering agent with optional reflection.

    Each call to ``run`` makes one attempt at ``question``: the action LLM is
    prompted for a thought and then an action (or only an action when
    ``strategy == "IO"``), the parsed answer is compared with ``key`` via
    ``EM``, and a snapshot of the agent's logs is appended to ``infos``.
    From the second call onwards ``run`` first reflects on the previous
    scratchpad unless the reflexion strategy is ``ReflexionStrategy.NONE``.

    Args:
        question: Question text inserted into the prompts.
        context: Context text inserted into the prompts.
        key: Reference answer used by ``is_correct``.
        agent_prompt: Template for the reasoning prompt.
        reflect_prompt: Template for the reflection prompt.
        cot_examples: Few-shot examples for the reasoning prompt.
        reflect_examples: Few-shot examples for the reflection prompt.
        self_reflect_llm: Callable LLM used to generate reflections.
        action_llm: Callable LLM used to generate thoughts and actions.
        chat: When true, prompts are wrapped as system + human messages and
            their token count is accumulated in ``token_used``.
        cheat: When true, the observation written to the scratchpad reveals
            whether the answer matched ``key``; otherwise it always reads
            'Answer is INCORRECT'.
        strategy: ``"IO"`` (answer without a thought), ``"CoT_hint"`` (feed
            previous answers back as hints), or anything else for a plain
            thought + action step.
    """

    # NOTE: the default LLM instances below are created once, when this class
    # body is executed, so OPENAI_API_KEY must be set at import time.
    def __init__(self,
                    question: str,
                    context: str,
                    key: str,
                    agent_prompt: PromptTemplate = cot_reflect_agent_prompt,
                    reflect_prompt: PromptTemplate = cot_reflect_prompt,
                    cot_examples: str = COT,
                    reflect_examples: str = COT_REFLECT,
                    self_reflect_llm: BaseLLM = OpenAI(
                                            temperature=0,
                                            max_tokens=250,
                                            model_name="text-davinci-003",
                                            model_kwargs={"stop": "\n"},
                                            openai_api_key=os.environ['OPENAI_API_KEY']),
                    action_llm: BaseLLM = OpenAI(
                                            temperature=0,
                                            max_tokens=250,
                                            model_name="text-davinci-003",
                                            model_kwargs={"stop": "\n"},
                                            openai_api_key=os.environ['OPENAI_API_KEY']),
                    chat=False,cheat=True,strategy=None) -> None:
        self.question = question
        self.context = context
        self.key = key
        self.agent_prompt = agent_prompt
        self.reflect_prompt = reflect_prompt
        self.cot_examples = cot_examples
        self.reflect_examples = reflect_examples
        self.self_reflect_llm = self_reflect_llm
        self.action_llm = action_llm
        self.reflections: List[str] = []
        self.reflections_str = ''
        self.answer = ''
        self.step_n: int = 0
        self.chat = chat
        self.token_used = 0
        self.cheat = cheat
        self.answers = []
        self.strategy = strategy

        # Logging state, snapshotted into ``infos`` at the end of every run.
        self.infos = []
        self.reasoning_prompts = []
        self.reflect_prompts = []
        self.thoughts = []
        self.actions = []
        self.is_corrects = []

        self.reset()

    def run(self,
            reflexion_strategy: ReflexionStrategy = ReflexionStrategy.REFLEXION) -> None:
        """Run one trial and append a log snapshot to ``infos``.

        Args:
            reflexion_strategy: How to reflect on the previous trial; ignored
                on the first trial and when it is ``ReflexionStrategy.NONE``.
        """
        # Reflection runs on every trial after the first, regardless of
        # whether the previous answer was correct.
        if self.step_n > 0 and reflexion_strategy != ReflexionStrategy.NONE:
            self.reflect(reflexion_strategy)
        self.reset()
        if self.strategy == "IO":
            self.io_step()
        elif self.strategy == "CoT_hint":
            append_thought = ""
            if self.step_n > 0:
                # Hints are stored as reflections. Copy the list: aliasing
                # ``answers`` would let a later ``reflect`` call append
                # reflection text to the recorded answers.
                self.reflections = list(self.answers)
                append_thought = ("Previous Answers: " + HINT_HEADER + "- "
                                  + '\n- '.join(r.strip() for r in self.reflections))
            self.step(append_thought=append_thought)
        else:
            self.step()
        self.step_n += 1

        self.infos.append({
            'step': self.step_n,
            'x': self.question,
            "answer": self.key,
            "sample_reasoning_prompt": self.reasoning_prompts[-1:],
            "sample_reflect_prompt": self.reflect_prompts[-1:],
            "thoughts": copy.deepcopy(self.thoughts),
            "actions": copy.deepcopy(self.actions),
            "reflections": copy.deepcopy(self.reflections),
            "is_correct": copy.deepcopy(self.is_corrects),
        })

    def io_step(self) -> None:
        """Ask the LLM for an action directly (no thought) and record the answer."""
        self.scratchpad += '\nAction:'
        action = self.prompt_agent()
        self.scratchpad += ' ' + action
        action_type, argument = parse_action(action)
        print(self.scratchpad.split('\n')[-1])
        self.actions.append(action)

        self._record_answer(action_type, argument)

    def step(self, append_thought="") -> None:
        """Ask the LLM for a thought and then an action, and record the answer.

        Args:
            append_thought: Text added on its own line of the scratchpad
                before the 'Thought:' marker (used to inject hints).
        """
        # Think
        # Log the prompt without going through _build_agent_prompt, which
        # would record an unsent prompt and count its tokens.
        print(f"log | inside step: {self._format_agent_prompt()}")
        self.scratchpad += f"\n{append_thought}"
        self.scratchpad += '\nThought:'
        thought = self.prompt_agent()
        self.scratchpad += ' ' + thought
        print(self.scratchpad.split('\n')[-1])
        self.thoughts.append(thought)

        # Act
        self.scratchpad += '\nAction:'
        action = self.prompt_agent()
        self.scratchpad += ' ' + action
        action_type, argument = parse_action(action)
        print(self.scratchpad.split('\n')[-1])
        self.actions.append(action)

        self._record_answer(action_type, argument)

    def _record_answer(self, action_type, argument) -> None:
        """Store the answer, grade it and write the observation to the scratchpad."""
        self.scratchpad += '\nObservation: '
        if action_type != 'Finish':
            # A malformed action still counts as this trial's answer, tagged
            # so it is recognisable in the logs.
            print('Invalid action type, please try again.')
            argument = "[Incorrect Format]" + argument
        self.answer = argument
        self.answers.append(self.answer)
        correct = bool(self.is_correct())
        self.is_corrects.append(correct)
        # Without ``cheat`` the agent is never told that it was right.
        if self.cheat and correct:
            self.scratchpad += 'Answer is CORRECT'
        else:
            self.scratchpad += 'Answer is INCORRECT'
        self.finished = True

    def reflect(self,
                strategy: ReflexionStrategy) -> None:
        """Update ``reflections`` and ``reflections_str`` from the current scratchpad.

        Raises:
            NotImplementedError: If ``strategy`` is not one of the three
                reflecting strategies.
        """
        print('Running Reflexion strategy...')
        if strategy == ReflexionStrategy.LAST_ATTEMPT:
            self.reflections = [self.scratchpad]
            self.reflections_str = format_last_attempt(self.question , self.reflections[0])
        elif strategy == ReflexionStrategy.REFLEXION:
            self.reflections += [self.prompt_reflection()]
            self.reflections_str = format_reflections(self.reflections)
        elif strategy == ReflexionStrategy.LAST_ATTEMPT_AND_REFLEXION:
            self.reflections_str = format_last_attempt(self.question , self.scratchpad)
            self.reflections = [self.prompt_reflection()]
            self.reflections_str += '\n'+ format_reflections(self.reflections, header = REFLECTION_AFTER_LAST_TRIAL_HEADER)
        else:
            raise NotImplementedError(f'Unknown reflection strategy: {strategy}')
        print(self.reflections_str)

    def prompt_reflection(self) -> str:
        """Query the reflection LLM and return its formatted output."""
        return format_step(self.self_reflect_llm(self._build_reflection_prompt()))

    def reset(self) -> None:
        """Clear the scratchpad and the finished flag before a new trial."""
        self.scratchpad: str = ''
        self.finished = False

    def prompt_agent(self) -> str:
        """Query the action LLM and return its formatted output."""
        return format_step(self.action_llm(self._build_agent_prompt()))

    def _format_agent_prompt(self) -> str:
        """Render the reasoning prompt from the current state, without side effects."""
        return self.agent_prompt.format(
                            examples = self.cot_examples,
                            reflections = self.reflections_str,
                            context = self.context,
                            question = self.question,
                            scratchpad = self.scratchpad)

    def _build_agent_prompt(self,output_str = False) -> Union[str, list]:
        """Build the reasoning prompt, log it and (in chat mode) count its tokens.

        Returns the prompt string, or in chat mode a [system, human] message
        list unless ``output_str`` is true, in which case the human message
        text is returned.
        """
        prompt = self._format_agent_prompt()
        self.reasoning_prompts.append(prompt)
        if self.chat:
            prompt = [
                SystemMessage(content="You are a helpful and rational assistant that can solve question answering tasks."),
                HumanMessage(content=prompt)
            ]
            self.token_used += num_tokens_from_messages(prompt)
            if output_str:
                prompt = prompt[1].content

        return prompt

    def _build_reflection_prompt(self,output_str=False) -> Union[str, list]:
        """Build the reflection prompt, log it and (in chat mode) count its tokens.

        Return value follows the same rules as ``_build_agent_prompt``.
        """
        prompt = self.reflect_prompt.format(
                            examples = self.reflect_examples,
                            context = self.context,
                            question = self.question,
                            scratchpad = self.scratchpad)
        self.reflect_prompts.append(prompt)
        if self.chat:
            prompt = [
                SystemMessage(content="You are a helpful and rational assistant that can solve question answering tasks."),
                HumanMessage(content=prompt)
            ]
            self.token_used += num_tokens_from_messages(prompt)
            if output_str:
                prompt = prompt[1].content

        return prompt

    def is_finished(self) -> bool:
        """Return whether the current trial has produced an answer."""
        return self.finished

    def is_correct(self) -> bool:
        """Return whether the latest answer matches ``key`` according to ``EM``."""
        return EM(self.answer, self.key)
    
class CoTAggregateAgent:
    """Chain-of-thought agent that samples several follow-up answers in one run.

    ``run`` first makes a normal thought + action attempt and then, depending
    on ``aggregate_strategy``:

    - ``"CoT_reflexion_aggregate_reflexion"``: repeats ``NUM_SAMPLES`` times
      "sample a reflection with ``sample_llm``, then answer again".
    - ``"CoT_reflexion_aggregate_answer"``: reflects once, then samples
      ``NUM_SAMPLES`` answers with ``sample_llm``.

    Every answer is kept in ``answers`` and ``is_correct`` reports success as
    soon as any of them matches ``key``. Both aggregate strategies call
    ``sample_llm``, so it must be provided when one of them is selected.

    Args:
        question: Question text inserted into the prompts.
        context: Context text inserted into the prompts.
        key: Reference answer used by ``is_correct``.
        agent_prompt: Template for the reasoning prompt.
        reflect_prompt: Template for the reflection prompt.
        cot_examples: Few-shot examples for the reasoning prompt.
        reflect_examples: Few-shot examples for the reflection prompt.
        self_reflect_llm: Callable LLM used for the non-sampled reflection.
        action_llm: Callable LLM used for non-sampled thoughts and actions.
        sample_llm: Callable LLM temporarily swapped in for the sampled calls.
        chat: When true, prompts are wrapped as system + human messages and
            their token count is accumulated in ``token_used``.
        aggregate_strategy: One of the two strategy names above; any other
            value leaves ``run`` with just the initial attempt.
        cheat: When true, the observation written to the scratchpad reveals
            whether an answer matched ``key``; otherwise it always reads
            'Answer is INCORRECT'.
    """

    # Number of sampled reflections / answers per run.
    NUM_SAMPLES = 5

    # NOTE: the default LLM instances below are created once, when this class
    # body is executed, so OPENAI_API_KEY must be set at import time.
    def __init__(self,
                    question: str,
                    context: str,
                    key: str,
                    agent_prompt: PromptTemplate = cot_reflect_agent_prompt,
                    reflect_prompt: PromptTemplate = cot_reflect_prompt,
                    cot_examples: str = COT,
                    reflect_examples: str = COT_REFLECT,
                    self_reflect_llm: BaseLLM = OpenAI(
                                            temperature=0,
                                            max_tokens=250,
                                            model_name="text-davinci-003",
                                            model_kwargs={"stop": "\n"},
                                            openai_api_key=os.environ['OPENAI_API_KEY']),
                    action_llm: BaseLLM = OpenAI(
                                            temperature=0,
                                            max_tokens=250,
                                            model_name="text-davinci-003",
                                            model_kwargs={"stop": "\n"},
                                            openai_api_key=os.environ['OPENAI_API_KEY']),
                    sample_llm = None,
                    chat=False,
                    aggregate_strategy = None,
                    cheat=True
                    ) -> None:
        self.question = question
        self.context = context
        self.key = key
        self.agent_prompt = agent_prompt
        self.reflect_prompt = reflect_prompt
        self.cot_examples = cot_examples
        self.reflect_examples = reflect_examples
        self.self_reflect_llm = self_reflect_llm
        self.action_llm = action_llm
        self.reflections: List[str] = []
        self.reflections_str = ''
        self.answer = ''
        self.step_n: int = 0
        self.chat = chat
        self.token_used = 0
        self.aggregate_strategy = aggregate_strategy
        self.answers = []
        self.sample_llm = sample_llm
        self.cheat = cheat

        # Logging state, snapshotted into ``infos`` after every attempt.
        self.infos = []
        self.reasoning_prompts = []
        self.reflect_prompts = []
        self.thoughts = []
        self.actions = []
        self.is_corrects = []

        self.reset()

    def run(self,
            reflexion_strategy: ReflexionStrategy = ReflexionStrategy.REFLEXION) -> None:
        """Make the initial attempt, then the sampled follow-up attempts.

        Args:
            reflexion_strategy: Strategy passed to ``reflect``; with
                ``ReflexionStrategy.NONE`` only the initial attempt is made.
        """
        self.reset()
        self.step()
        self._log_info()

        og_reflect_llm = self.self_reflect_llm
        og_action_llm = self.action_llm
        # Reflect after the first step; this agent does all its work in one run.
        if reflexion_strategy != ReflexionStrategy.NONE:
            if self.aggregate_strategy == "CoT_reflexion_aggregate_reflexion":
                for i in range(self.NUM_SAMPLES):
                    print(f"\ninside aggregate reflexion: {i}")
                    # Sample the reflection, then answer with the regular
                    # action LLM. Restore the LLM even if the call fails.
                    self.self_reflect_llm = self.sample_llm
                    try:
                        self.reflect(reflexion_strategy)
                    finally:
                        self.self_reflect_llm = og_reflect_llm
                    self.local_step()
                    self._log_info()

            elif self.aggregate_strategy == "CoT_reflexion_aggregate_answer":
                self.reflect(reflexion_strategy)
                for i in range(self.NUM_SAMPLES):
                    print(f"\ninside aggregate answer: {i}")
                    # One shared reflection, several sampled answers.
                    self.action_llm = self.sample_llm
                    try:
                        self.local_step()
                    finally:
                        self.action_llm = og_action_llm
                    self._log_info()

        self.step_n += 1

    def _log_info(self) -> None:
        """Append a snapshot of the current logging state to ``infos``."""
        self.infos.append({
            'step': self.step_n,
            'x': self.question,
            "answer": self.key,
            "sample_reasoning_prompt": self.reasoning_prompts[:1],
            "sample_reflect_prompt": self.reflect_prompts[:1],
            "thoughts": copy.deepcopy(self.thoughts),
            "actions": copy.deepcopy(self.actions),
            "reflections": copy.deepcopy(self.reflections),
            "is_correct": copy.deepcopy(self.is_corrects),
        })

    def step(self) -> None:
        """Make one thought + action attempt, extending the scratchpad."""
        self._think_and_act()

    def local_step(self) -> None:
        """Make one thought + action attempt and then roll the scratchpad back.

        The answer, thought and action are still recorded; only the
        scratchpad text added by this attempt is discarded so that the next
        sample starts from the same state.
        """
        initial_length = len(self.scratchpad)
        try:
            self._think_and_act(echo_action=True)
        finally:
            # Roll back even on failure so a broken sample cannot leak a
            # partial thought/action into the following prompts.
            self.scratchpad = self.scratchpad[:initial_length]

    def _think_and_act(self, echo_action: bool = False) -> None:
        """Prompt for a thought, then an action, and record a 'Finish' answer."""
        # Think
        self.scratchpad += '\nThought:'
        thought = self.prompt_agent()
        self.scratchpad += ' ' + thought
        print(self.scratchpad.split('\n')[-1])
        self.thoughts.append(thought)

        # Act
        self.scratchpad += '\nAction:'
        action = self.prompt_agent()
        if echo_action:
            print(f"action: {action}")
        self.scratchpad += ' ' + action
        action_type, argument = parse_action(action)
        print(self.scratchpad.split('\n')[-1])
        self.actions.append(action)

        self.scratchpad += '\nObservation: '
        if action_type == 'Finish':
            self._record_finish(argument)
        else:
            # Malformed actions are reported but not recorded as answers.
            print('Invalid action type, please try again.')

    def _record_finish(self, argument) -> None:
        """Store a finished answer, grade it and write the observation."""
        self.answer = argument
        self.answers.append(argument)
        correct = bool(self.is_correct())
        self.is_corrects.append(correct)
        # Without ``cheat`` the agent is never told that it was right.
        if self.cheat and correct:
            self.scratchpad += 'Answer is CORRECT'
        else:
            self.scratchpad += 'Answer is INCORRECT'
        self.finished = True

    def reflect(self,
                strategy: ReflexionStrategy) -> None:
        """Update ``reflections`` and ``reflections_str`` from the current scratchpad.

        With ``ReflexionStrategy.REFLEXION`` every reflection is kept in
        ``reflections`` but only the newest one is formatted into the prompt.

        Raises:
            NotImplementedError: If ``strategy`` is not one of the three
                reflecting strategies.
        """
        if strategy == ReflexionStrategy.LAST_ATTEMPT:
            self.reflections = [self.scratchpad]
            self.reflections_str = format_last_attempt(self.question , self.reflections[0])
        elif strategy == ReflexionStrategy.REFLEXION:
            self.reflections += [self.prompt_reflection()]
            self.reflections_str = format_reflections(self.reflections[-1:])
        elif strategy == ReflexionStrategy.LAST_ATTEMPT_AND_REFLEXION:
            self.reflections_str = format_last_attempt(self.question , self.scratchpad)
            self.reflections = [self.prompt_reflection()]
            self.reflections_str += '\n'+ format_reflections(self.reflections, header = REFLECTION_AFTER_LAST_TRIAL_HEADER)
        else:
            raise NotImplementedError(f'Unknown reflection strategy: {strategy}')
        print(f"reflection str: {self.reflections_str}")

    def prompt_reflection(self) -> str:
        """Query the current reflection LLM and return its formatted output."""
        return format_step(self.self_reflect_llm(self._build_reflection_prompt()))

    def reset(self) -> None:
        """Clear the scratchpad and the finished flag."""
        self.scratchpad: str = ''
        self.finished = False

    def prompt_agent(self) -> str:
        """Query the current action LLM and return its formatted output."""
        return format_step(self.action_llm(self._build_agent_prompt()))

    def _build_agent_prompt(self,output_str = False) -> Union[str, list]:
        """Build the reasoning prompt, log it and (in chat mode) count its tokens.

        Returns the prompt string, or in chat mode a [system, human] message
        list unless ``output_str`` is true, in which case the human message
        text is returned.
        """
        prompt = self.agent_prompt.format(
                            examples = self.cot_examples,
                            reflections = self.reflections_str,
                            context = self.context,
                            question = self.question,
                            scratchpad = self.scratchpad)
        self.reasoning_prompts.append(prompt)
        if self.chat:
            prompt = [
                SystemMessage(content="You are a helpful and rational assistant that can solve question answering tasks."),
                HumanMessage(content=prompt)
            ]
            self.token_used += num_tokens_from_messages(prompt)
            if output_str:
                prompt = prompt[1].content

        return prompt

    def _build_reflection_prompt(self,output_str=False) -> Union[str, list]:
        """Build the reflection prompt, log it and (in chat mode) count its tokens.

        Return value follows the same rules as ``_build_agent_prompt``.
        """
        prompt = self.reflect_prompt.format(
                            examples = self.reflect_examples,
                            context = self.context,
                            question = self.question,
                            scratchpad = self.scratchpad)
        self.reflect_prompts.append(prompt)
        if self.chat:
            prompt = [
                SystemMessage(content="You are a helpful and rational assistant that can solve question answering tasks."),
                HumanMessage(content=prompt)
            ]
            self.token_used += num_tokens_from_messages(prompt)
            if output_str:
                prompt = prompt[1].content

        return prompt

    def is_finished(self) -> bool:
        """Return whether an attempt has produced a 'Finish' answer."""
        return self.finished

    def is_correct(self) -> bool:
        """Return True if any recorded answer matches ``key`` according to ``EM``."""
        for answer in self.answers:
            if EM(answer, self.key):
                return True
        return False
    

class CoTMultiFeedbackAgent:
    """Chain-of-thought agent that answers again after collecting several reflections.

    ``run`` makes a first thought + action attempt and then, depending on
    ``aggregate_strategy``:

    - ``"CoT_reflexion_multifeedback"``: samples ``NUM_SAMPLES`` reflections
      with ``sample_llm`` and answers once more with all of them in the prompt.
    - ``"CoT_reflexion_multifeedback_MAR"``: produces one regular reflection,
      extends it with ``NUM_SAMPLES - 1`` debate rounds via ``MAR`` and then
      answers once more.

    All the work happens in the first ``run`` call; later calls return
    immediately. ``is_correct`` reports success if any recorded answer
    matches ``key``.

    Args:
        question: Question text inserted into the prompts.
        context: Context text inserted into the prompts.
        key: Reference answer used by ``is_correct``.
        agent_prompt: Template for the reasoning prompt.
        reflect_prompt: Template for the reflection prompt.
        cot_examples: Few-shot examples for the reasoning prompt.
        reflect_examples: Few-shot examples for the reflection prompt.
        self_reflect_llm: Callable LLM used for regular reflections and for
            the debate rounds.
        action_llm: Callable LLM used for thoughts and actions.
        sample_llm: Callable LLM temporarily swapped in to sample reflections
            in the multifeedback strategy.
        chat: When true, prompts are wrapped as system + human messages and
            their token count is accumulated in ``token_used``. ``MAR`` only
            does anything in chat mode.
        aggregate_strategy: One of the two strategy names above; any other
            value leaves ``run`` with just the initial attempt.
        cheat: When true, the observation written to the scratchpad reveals
            whether an answer matched ``key``; otherwise it always reads
            'Answer is INCORRECT'.
    """

    # Total number of reflections gathered before the final answer.
    NUM_SAMPLES = 5

    # NOTE: the default LLM instances below are created once, when this class
    # body is executed, so OPENAI_API_KEY must be set at import time.
    def __init__(self,
                    question: str,
                    context: str,
                    key: str,
                    agent_prompt: PromptTemplate = cot_reflect_agent_prompt,
                    reflect_prompt: PromptTemplate = cot_reflect_prompt,
                    cot_examples: str = COT,
                    reflect_examples: str = COT_REFLECT,
                    self_reflect_llm: BaseLLM = OpenAI(
                                            temperature=0,
                                            max_tokens=250,
                                            model_name="text-davinci-003",
                                            model_kwargs={"stop": "\n"},
                                            openai_api_key=os.environ['OPENAI_API_KEY']),
                    action_llm: BaseLLM = OpenAI(
                                            temperature=0,
                                            max_tokens=250,
                                            model_name="text-davinci-003",
                                            model_kwargs={"stop": "\n"},
                                            openai_api_key=os.environ['OPENAI_API_KEY']),
                    sample_llm = None,
                    chat=False,
                    aggregate_strategy = None,
                    cheat=True
                    ) -> None:
        self.question = question
        self.context = context
        self.key = key
        self.agent_prompt = agent_prompt
        self.reflect_prompt = reflect_prompt
        self.cot_examples = cot_examples
        self.reflect_examples = reflect_examples
        self.reflections: List[str] = []
        self.reflections_str = ''
        self.answer = ''
        self.step_n: int = 0
        self.chat = chat
        self.token_used = 0
        self.aggregate_strategy = aggregate_strategy
        self.answers = []
        self.self_reflect_llm = self_reflect_llm
        self.action_llm = action_llm
        self.sample_llm = sample_llm
        self.cheat = cheat

        # Logging state, snapshotted into ``infos`` after the final answer.
        self.infos = []
        self.reasoning_prompts = []
        self.reflect_prompts = []
        self.thoughts = []
        self.actions = []
        self.is_corrects = []

        self.reset()

    def run(self,
            reflexion_strategy: ReflexionStrategy = ReflexionStrategy.REFLEXION) -> None:
        """Make the initial attempt, gather reflections and answer once more.

        Args:
            reflexion_strategy: Strategy passed to ``reflect``; with
                ``ReflexionStrategy.NONE`` only the initial attempt is made.
        """
        if self.step_n == 1:
            # Both attempts happen inside the first run, so any further run
            # is a no-op.
            return

        # First attempt. ``step`` itself records a 'Finish' answer in
        # ``answers``, so it is not appended again here.
        self.step()
        og_reflect_llm = self.self_reflect_llm

        if reflexion_strategy != ReflexionStrategy.NONE:
            if self.aggregate_strategy == "CoT_reflexion_multifeedback":
                # Collect the sampled feedbacks.
                for i in range(self.NUM_SAMPLES):
                    print(f"\ninside CoT_reflexion_multifeedback: {i}")
                    self.self_reflect_llm = self.sample_llm
                    try:
                        self.reflect(reflexion_strategy)
                    finally:
                        # Restore the regular LLM even if sampling fails.
                        self.self_reflect_llm = og_reflect_llm
                # Answer again with all feedbacks in the prompt.
                self.step()
                self._log_info()

            elif self.aggregate_strategy == "CoT_reflexion_multifeedback_MAR":
                # First feedback comes from a regular reflection.
                print(f"\ninside CoT_reflexion_multifeedback_debate: {0}")
                self.reflect(reflexion_strategy)

                # Remaining feedbacks come from debate rounds.
                for i in range(self.NUM_SAMPLES - 1):
                    print(f"\ninside CoT_reflexion_multifeedback_debate: {i+1}")
                    self.MAR(i)
                # Answer again with all feedbacks in the prompt.
                self.step()
                self._log_info()

        self.step_n += 1

    def _log_info(self) -> None:
        """Append a snapshot of the current logging state to ``infos``."""
        self.infos.append({
            'step': self.step_n,
            'x': self.question,
            "answer": self.key,
            "sample_reasoning_prompt": self.reasoning_prompts[:1],
            "sample_reflect_prompt": self.reflect_prompts[:1],
            "thoughts": copy.deepcopy(self.thoughts),
            "actions": copy.deepcopy(self.actions),
            "reflections": copy.deepcopy(self.reflections),
            "is_correct": copy.deepcopy(self.is_corrects),
        })

    def step(self) -> None:
        """Prompt for a thought, then an action, and record a 'Finish' answer."""
        # Think
        self.scratchpad += '\nThought:'
        thought = self.prompt_agent()
        self.scratchpad += ' ' + thought
        self.thoughts.append(thought)

        # Act
        self.scratchpad += '\nAction:'
        action = self.prompt_agent()
        self.scratchpad += ' ' + action
        action_type, argument = parse_action(action)
        self.actions.append(action)

        self.scratchpad += '\nObservation: '
        if action_type != 'Finish':
            # Malformed actions are reported but not recorded as answers.
            print('Invalid action type, please try again.')
            return

        self.answer = argument
        self.answers.append(self.answer)
        correct = bool(self.is_correct())
        self.is_corrects.append(correct)
        # Without ``cheat`` the agent is never told that it was right.
        if self.cheat and correct:
            self.scratchpad += 'Answer is CORRECT'
        else:
            self.scratchpad += 'Answer is INCORRECT'
        self.finished = True

    def reflect(self,
                strategy: ReflexionStrategy) -> None:
        """Update ``reflections`` and ``reflections_str`` from the current scratchpad.

        With ``ReflexionStrategy.REFLEXION`` the new reflection is appended
        and all reflections gathered so far are formatted into the prompt.

        Raises:
            NotImplementedError: If ``strategy`` is not one of the three
                reflecting strategies.
        """
        if strategy == ReflexionStrategy.LAST_ATTEMPT:
            self.reflections = [self.scratchpad]
            self.reflections_str = format_last_attempt(self.question , self.reflections[0])
        elif strategy == ReflexionStrategy.REFLEXION:
            # Log the prompt through the pure formatter: going through
            # _build_reflection_prompt here would record an unsent prompt
            # and count its tokens a second time.
            print(f"INSIDE reflect reflection prompt: {self._format_reflection_prompt()}")
            print("END reflect reflection prompt")
            self.reflections += [self.prompt_reflection()]
            self.reflections_str = format_reflections(self.reflections[:])
        elif strategy == ReflexionStrategy.LAST_ATTEMPT_AND_REFLEXION:
            self.reflections_str = format_last_attempt(self.question , self.scratchpad)
            self.reflections = [self.prompt_reflection()]
            self.reflections_str += '\n'+ format_reflections(self.reflections, header = REFLECTION_AFTER_LAST_TRIAL_HEADER)
        else:
            raise NotImplementedError(f'Unknown reflection strategy: {strategy}')
        print(f"reflection str: {self.reflections_str}")

    def MAR(self, i):
        """Run debate round ``i`` and append the resulting reflection.

        The message consists of the debate system prompt, a first human
        message built from the context, question, scratchpad and the first
        reflection, followed by the reflections of the ``i`` earlier rounds.
        Does nothing unless ``chat`` is true. The debate message is not
        added to ``token_used``.
        """
        if self.chat:
            # NOTE(review): levels[2] selects a fixed debate level defined in
            # hotpotqa_MAR; its meaning is not visible here.
            system_prompt = MAR_system.format(debate_level=levels[2])
            first_prompt = MAR_first.format(context = self.context,
                            question = self.question,
                            scratchpad = self.scratchpad,
                            reflection = self.reflections[0])
            message = [SystemMessage(content=system_prompt), HumanMessage(content=first_prompt)]
            for j in range(i):
                # Adding earlier rounds as AIMessage may pose issues.
                another_debate = AIMessage(content=self.reflections[j+1])
                message.append(another_debate)
            print(f"logging | inside MAR: {message}")
            new_reflection = format_step(self.self_reflect_llm(message))
            self.reflections += [new_reflection]

            self.reflections_str = format_reflections(self.reflections[:])
            print(f"logging | inside MAR {i}: {self.reflections_str}")

    def multi_agent_debate(self):
        """Append one regular reflection to ``reflections``."""
        self.reflections += [self.prompt_reflection()]

    def prompt_reflection(self) -> str:
        """Query the current reflection LLM and return its formatted output."""
        return format_step(self.self_reflect_llm(self._build_reflection_prompt()))

    def reset(self) -> None:
        """Clear the scratchpad and the finished flag."""
        self.scratchpad: str = ''
        self.finished = False

    def prompt_agent(self) -> str:
        """Query the action LLM and return its formatted output."""
        return format_step(self.action_llm(self._build_agent_prompt()))

    def _build_agent_prompt(self,output_str = False) -> Union[str, list]:
        """Build the reasoning prompt, log it and (in chat mode) count its tokens.

        Returns the prompt string, or in chat mode a [system, human] message
        list unless ``output_str`` is true, in which case the human message
        text is returned.
        """
        prompt = self.agent_prompt.format(
                            examples = self.cot_examples,
                            reflections = self.reflections_str,
                            context = self.context,
                            question = self.question,
                            scratchpad = self.scratchpad)
        self.reasoning_prompts.append(prompt)
        if self.chat:
            prompt = [
                SystemMessage(content="You are a helpful and rational assistant that can solve question answering tasks."),
                HumanMessage(content=prompt)
            ]
            self.token_used += num_tokens_from_messages(prompt)
            if output_str:
                prompt = prompt[1].content

        return prompt

    def _format_reflection_prompt(self) -> str:
        """Render the reflection prompt from the current state, without side effects."""
        return self.reflect_prompt.format(
                            examples = self.reflect_examples,
                            context = self.context,
                            question = self.question,
                            scratchpad = self.scratchpad)

    def _build_reflection_prompt(self,output_str=False) -> Union[str, list]:
        """Build the reflection prompt, log it and (in chat mode) count its tokens.

        Return value follows the same rules as ``_build_agent_prompt``.
        """
        prompt = self._format_reflection_prompt()
        self.reflect_prompts.append(prompt)
        if self.chat:
            prompt = [
                SystemMessage(content="You are a helpful and rational assistant that can solve question answering tasks."),
                HumanMessage(content=prompt)
            ]
            self.token_used += num_tokens_from_messages(prompt)
            if output_str:
                prompt = prompt[1].content

        return prompt

    def is_finished(self) -> bool:
        """Return whether an attempt has produced a 'Finish' answer."""
        return self.finished

    def is_correct(self) -> bool:
        """Return True if any recorded answer matches ``key`` according to ``EM``."""
        for answer in self.answers:
            if EM(answer, self.key):
                return True
        return False



class ReactAgent:
    """Question-answering agent that runs a Thought / Action / Observation loop.

    Every step prompts the LLM twice (once for a thought, once for an
    action), parses the action, and appends the resulting observation to
    ``self.scratchpad``. The loop ends when a ``Finish`` action is issued or
    when the agent is halted (too many steps or too long a prompt).

    NOTE: the ``docstore`` and ``react_llm`` defaults are built once, when
    the class body is executed, so ``OPENAI_API_KEY`` must be present in the
    environment at import time.
    """

    def __init__(self,
                 question: str,
                 key: str,
                 max_steps: int = 6,
                 agent_prompt: PromptTemplate = react_agent_prompt,
                 docstore: Docstore = Wikipedia(),
                 react_llm: BaseLLM = OpenAI(
                                            temperature=0,
                                            max_tokens=100,
                                            model_name="text-davinci-003",
                                            model_kwargs={"stop": "\n"},
                                            openai_api_key=os.environ['OPENAI_API_KEY']),
                 chat=False) -> None:
        """Store the task and collaborators, then reset the per-trial state.

        Args:
            question: The question to answer.
            key: The reference answer used by ``is_correct``.
            max_steps: Step count after which ``is_halted`` reports True.
            agent_prompt: Template formatted with ``examples``, ``question``
                and ``scratchpad``.
            docstore: Backing store wrapped in a ``DocstoreExplorer`` for the
                Search and Lookup actions.
            react_llm: Callable LLM used for both thoughts and actions.
            chat: When true, prompts are built as a system + human message
                list instead of a plain string.
        """
        self.question = question
        self.answer = ''
        self.key = key
        self.max_steps = max_steps
        self.agent_prompt = agent_prompt
        self.react_examples = WEBTHINK_SIMPLE6

        self.docstore = DocstoreExplorer(docstore) # Search, Lookup
        self.llm = react_llm

        # Tokenizer used by is_halted to measure the prompt length.
        self.enc = tiktoken.encoding_for_model("text-davinci-003")
        self.chat = chat
        self.answers=[]
        self.__reset_agent()

    def run(self, reset = True, reflexion_strategy=None) -> None:
        """Step the agent until it finishes or is halted.

        Args:
            reset: When true, the step counter, finished flag and scratchpad
                are cleared before the loop starts.
            reflexion_strategy: Accepted but not used by this class.
        """
        if reset:
            self.__reset_agent()

        while not self.is_halted() and not self.is_finished():
            self.step()

    def step(self) -> None:
        """Run one Thought / Action / Observation round and extend the scratchpad."""
        # Think
        self.scratchpad += f'\nThought {self.step_n}:'
        self.scratchpad += ' ' + self.prompt_agent()
        print(self.scratchpad.split('\n')[-1])

        # Act
        self.scratchpad += f'\nAction {self.step_n}:'
        action = self.prompt_agent()
        self.scratchpad += ' ' + action
        action_type, argument = parse_action(action)
        print(self.scratchpad.split('\n')[-1])

        # Observe
        self.scratchpad += f'\nObservation {self.step_n}: '

        if action_type == 'Finish':
            # Finish ends the trial; the observation is the correctness verdict.
            self.answer = argument
            if self.is_correct():
                self.scratchpad += 'Answer is CORRECT'
            else:
                self.scratchpad += 'Answer is INCORRECT'
            self.finished = True
            self.step_n += 1
            return

        if action_type == 'Search':
            try:
                self.scratchpad += format_step(self.docstore.search(argument))
            except Exception as e:
                # Best effort: any search failure is reported to the model as
                # a missing page so that it can try a different query.
                print(e)
                self.scratchpad += 'Could not find that page, please try again.'

        elif action_type == 'Lookup':
            try:
                self.scratchpad += format_step(self.docstore.lookup(argument))
            except ValueError:
                self.scratchpad += 'The last page Searched was not found, so you cannot Lookup a keyword in it. Please try one of the similar pages given.'

        else:
            self.scratchpad += 'Invalid Action. Valid Actions are Lookup[<topic>] Search[<topic>] and Finish[<answer>].'

        print(self.scratchpad.split('\n')[-1])

        self.step_n += 1

    def prompt_agent(self) -> str:
        """Query the LLM with the current prompt and return the cleaned-up reply."""
        return format_step(self.llm(self._build_agent_prompt()))

    def _build_agent_prompt(self,output_str=False) -> str:
        """Render the agent prompt from the examples, question and scratchpad.

        Outside chat mode the rendered string is returned. In chat mode the
        result is a ``[SystemMessage, HumanMessage]`` list, or the human
        message text when ``output_str`` is true.
        """
        prompt = self.agent_prompt.format(
                            examples = self.react_examples,
                            question = self.question,
                            scratchpad = self.scratchpad)
        if self.chat:
            prompt = [
                SystemMessage(content="You are a helpful and rational assistant that can solve question answering tasks."),
                HumanMessage(content=prompt)
            ]
            if output_str:
                prompt = prompt[1].content
        return prompt

    def is_finished(self) -> bool:
        """Return True once a Finish action has been processed."""
        return self.finished

    def is_correct(self) -> bool:
        """Return True if the current answer exactly matches the key (via ``EM``)."""
        return EM(self.answer, self.key)

    def is_halted(self) -> bool:
        """Return True if the trial must stop without having finished.

        The agent is halted when it has exceeded ``max_steps`` or when the
        prompt is longer than 3896 tokens, and it has not finished.
        """
        # output_str=True: the tokenizer needs text, and in chat mode the
        # default return value is a list of message objects.
        # NOTE(review): 3896 presumably leaves completion headroom below the
        # model's context window; confirm the intended budget.
        return ((self.step_n > self.max_steps) or (len(self.enc.encode(self._build_agent_prompt(output_str=True))) > 3896)) and not self.finished

    def __reset_agent(self) -> None:
        """Clear the per-trial state: step counter, finished flag, scratchpad."""
        self.step_n = 1
        self.finished = False
        self.scratchpad: str = ''

    def set_qa(self, question: str, key: str) -> None:
        """Replace the question and its reference answer."""
        self.question = question
        self.key = key

class ReactReflectAgent(ReactAgent):
    """ReAct agent that reflects on a failed trial before starting the next one.

    Reflections are kept in ``self.reflections`` and rendered into
    ``self.reflections_str``, which is injected into the agent prompt.

    NOTE: the ``docstore``, ``react_llm`` and ``reflect_llm`` defaults are
    built once, when the class body is executed, so ``OPENAI_API_KEY`` must
    be present in the environment at import time.
    """

    def __init__(self,
                 question: str,
                 key: str,
                 max_steps: int = 6,
                 agent_prompt: PromptTemplate = react_reflect_agent_prompt,
                 reflect_prompt: PromptTemplate = reflect_prompt,
                 docstore: Docstore = Wikipedia(),
                 react_llm: BaseLLM = OpenAI(
                                             temperature=0,
                                             max_tokens=100,
                                             model_name="text-davinci-003",
                                             model_kwargs={"stop": "\n"},
                                             openai_api_key=os.environ['OPENAI_API_KEY']),
                 reflect_llm: BaseLLM = OpenAI(
                                               temperature=0,
                                               max_tokens=250,
                                               model_name="text-davinci-003",
                                               openai_api_key=os.environ['OPENAI_API_KEY']),
                 chat=False) -> None:
        """Initialise the base agent and the reflection state.

        Args:
            question: The question to answer.
            key: The reference answer.
            max_steps: Step count after which the agent is halted.
            agent_prompt: Template formatted with ``examples``,
                ``reflections``, ``question`` and ``scratchpad``.
            reflect_prompt: Template formatted with ``examples``,
                ``question`` and ``scratchpad`` to request a reflection.
            docstore: Backing store for the Search and Lookup actions.
            react_llm: Callable LLM used for thoughts and actions.
            reflect_llm: Callable LLM used to produce reflections.
            chat: When true, prompts are built as message lists.
        """
        # chat is forwarded so the base class stores it directly.
        super().__init__(question, key, max_steps, agent_prompt, docstore, react_llm, chat)
        self.reflect_llm = reflect_llm
        self.reflect_prompt = reflect_prompt
        self.reflect_examples = REFLECTIONS
        self.reflections: List[str] = []
        self.reflections_str: str = ''

    def run(self, reset = True, reflect_strategy: ReflexionStrategy = ReflexionStrategy.REFLEXION) -> None:
        """Reflect on the previous trial if it failed, then run a trial.

        Reflection happens only when the previous trial ended (finished or
        halted) without a correct answer.

        Args:
            reset: Passed on to ``ReactAgent.run``.
            reflect_strategy: Strategy handed to ``reflect``.
        """
        if (self.is_finished() or self.is_halted()) and not self.is_correct():
            self.reflect(reflect_strategy)

        super().run(reset)

    def reflect(self,
                strategy: ReflexionStrategy) -> None:
        """Update ``reflections`` and ``reflections_str`` from the last trial.

        Args:
            strategy: ``LAST_ATTEMPT`` stores the previous scratchpad,
                ``REFLEXION`` appends a newly generated reflection, and
                ``LAST_ATTEMPT_AND_REFLEXION`` combines the previous
                scratchpad with a single new reflection.

        Raises:
            NotImplementedError: For any other strategy value.
        """
        print('Reflecting...')
        if strategy == ReflexionStrategy.LAST_ATTEMPT:
            self.reflections = [self.scratchpad]
            self.reflections_str = format_last_attempt(self.question, self.reflections[0])
        elif strategy == ReflexionStrategy.REFLEXION:
            self.reflections += [self.prompt_reflection()]
            self.reflections_str = format_reflections(self.reflections)
        elif strategy == ReflexionStrategy.LAST_ATTEMPT_AND_REFLEXION:
            self.reflections_str = format_last_attempt(self.question, self.scratchpad)
            self.reflections = [self.prompt_reflection()]
            self.reflections_str += format_reflections(self.reflections, header = REFLECTION_AFTER_LAST_TRIAL_HEADER)
        else:
            raise NotImplementedError(f'Unknown reflection strategy: {strategy}')
        print(self.reflections_str)

    def prompt_reflection(self) -> str:
        """Query the reflection LLM and return the cleaned-up reply."""
        return format_step(self.reflect_llm(self._build_reflection_prompt()))

    def _build_reflection_prompt(self,output_str=False) -> str:
        """Render the reflection prompt from a truncated copy of the scratchpad.

        Outside chat mode the rendered string is returned. In chat mode the
        result is a ``[SystemMessage, HumanMessage]`` list, or the human
        message text when ``output_str`` is true.
        """
        # The result must be bound to `prompt`: every later line reads it.
        prompt = self.reflect_prompt.format(
                            examples = self.reflect_examples,
                            question = self.question,
                            scratchpad = truncate_scratchpad(self.scratchpad, tokenizer=self.enc))
        if self.chat:
            prompt = [
                SystemMessage(content="You are a helpful and rational assistant that can solve question answering tasks."),
                HumanMessage(content=prompt)
            ]
            if output_str:
                prompt = prompt[1].content
        return prompt

    def _build_agent_prompt(self,output_str=False) -> str:
        """Render the agent prompt, including the accumulated reflections.

        Outside chat mode the rendered string is returned. In chat mode the
        result is a ``[SystemMessage, HumanMessage]`` list, or the human
        message text when ``output_str`` is true.
        """
        prompt = self.agent_prompt.format(
                            examples = self.react_examples,
                            reflections = self.reflections_str,
                            question = self.question,
                            scratchpad = self.scratchpad)
        if self.chat:
            prompt = [
                SystemMessage(content="You are a helpful and rational assistant that can solve question answering tasks."),
                HumanMessage(content=prompt)
            ]
            if output_str:
                prompt = prompt[1].content
        return prompt
   








### String Stuff ###
# Module-level tokenizer, used as the default by truncate_scratchpad and by
# format_last_attempt when trimming a scratchpad.
# NOTE(review): the name says "gpt2", but the encoding is whatever tiktoken
# maps "text-davinci-003" to; confirm before relying on the name.
gpt2_enc = tiktoken.encoding_for_model("text-davinci-003")

def parse_action(string):
    """Split an action of the form ``Name[argument]`` into its two parts.

    Returns:
        ``(action_type, argument)`` when the text matches, otherwise
        ``("Incorrect Format", string)`` with the input passed through.
    """
    parsed = re.match(r'^(\w+)\[(.+)\]$', string)
    if parsed is None:
        return "Incorrect Format", string

    # The pattern has exactly two groups: the action name and its argument.
    return parsed.groups()

def format_step(step: str) -> str:
    """Reduce an LLM reply to a single line of text.

    An ``AIMessage`` contributes its ``content`` and a list contributes the
    ``text`` of its first element; anything else is used as is. Surrounding
    whitespace is stripped and all remaining newlines are removed.
    """
    if isinstance(step, AIMessage):
        text = step.content
    elif isinstance(step, list):
        text = step[0].text
    else:
        text = step

    return text.strip().replace('\n', '')

def format_reflections(reflections: List[str],
                        header: str = REFLECTION_HEADER) -> str:
    """Render reflections as a bulleted list placed after ``header``.

    An empty list yields an empty string (no header either).
    """
    if reflections == []:
        return ''

    bullets = '\n- '.join(entry.strip() for entry in reflections)
    return header + 'Reflections:\n- ' + bullets

def format_last_attempt(question: str,
                        scratchpad: str,
                        header: str = LAST_TRIAL_HEADER):
    """Render the previous trial (question + truncated scratchpad) for the prompt.

    The scratchpad is shortened with ``truncate_scratchpad`` and stripped of
    surrounding whitespace before being placed between the question line and
    the end-of-trial marker.
    """
    previous_trace = truncate_scratchpad(scratchpad, tokenizer=gpt2_enc).strip()
    question_line = f'Question: {question}\n'
    return header + question_line + previous_trace + '\n(END PREVIOUS TRIAL)\n'

def truncate_scratchpad(scratchpad: str, n_tokens: int = 1600, tokenizer = gpt2_enc) -> str:
    """Shorten a scratchpad by blanking out its largest observations.

    While the scratchpad is longer than ``n_tokens``, the observation line
    with the most tokens is replaced by a short placeholder. Lines that do
    not start with ``Observation`` are never touched.

    Args:
        scratchpad: Newline-separated trace of thoughts, actions and
            observations.
        n_tokens: Token budget for the returned text.
        tokenizer: Object with an ``encode(str)`` method returning a sized
            sequence; used both to rank observations and to measure the
            total length.

    Returns:
        The possibly shortened scratchpad. If it is still over budget once
        every observation has been replaced, it is returned as is.
    """
    lines = scratchpad.split('\n')
    observations_by_tokens = sorted(
        (line for line in lines if line.startswith('Observation')),
        key=lambda line: len(tokenizer.encode(line)))

    # Measure with the caller's tokenizer, and stop once there is nothing
    # left to truncate instead of popping from an empty list.
    while observations_by_tokens and len(tokenizer.encode('\n'.join(lines))) > n_tokens:
        largest_observation = observations_by_tokens.pop()
        ind = lines.index(largest_observation)
        # Keep the "Observation N" label and drop the excerpt itself.
        lines[ind] = largest_observation.split(':')[0] + ': [truncated wikipedia excerpt]'
    return '\n'.join(lines)

def normalize_answer(s):
    """Canonicalise an answer string for exact-match comparison.

    The text is lower-cased, stripped of ASCII punctuation, stripped of the
    articles "a", "an" and "the", and has its whitespace collapsed to single
    spaces, in that order.
    """
    lowered = s.lower()
    # One pass that deletes every character listed in string.punctuation.
    without_punctuation = lowered.translate(str.maketrans('', '', string.punctuation))
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punctuation)
    return " ".join(without_articles.split())

def EM(answer, key) -> bool:
    """Exact match: compare ``answer`` and ``key`` after ``normalize_answer``."""
    normalized_answer = normalize_answer(answer)
    normalized_key = normalize_answer(key)
    return normalized_answer == normalized_key


