import pdb
from agents.base_agent import BaseAgent
from common.registry import registry
# from rouge import Rouge
import json
import random
import re


@registry.register_agent("VanillaAgent")
class VanillaAgent(
    BaseAgent):  # the agent should receive goal, state and action, then return the next state
    def __init__(self,
                 llm_model,
                 memory_size=100,
                 # set this to a very large number if you want to keep all history till context length limit
                 examples=[],
                 instruction="",
                 init_prompt_path=None,
                 system_message="You are a helpful assistant.",
                 need_goal=False,
                 check_actions=None,
                 check_inventory=None,
                 use_parser=True,
                 ):
        super().__init__()
        self.use_parser = use_parser
        self.llm_model = llm_model
        self.memory_size = memory_size
        self.goal = None
        self.init_obs = None
        if init_prompt_path is not None:  # load from file
            self.init_prompt_dict = json.load(open(init_prompt_path, 'r'))
            self.instruction = self.init_prompt_dict["instruction"]
            self.examples = self.init_prompt_dict["examples"]
        else:

            self.instruction = instruction
            self.examples = examples

            # self.reset(goal, init_obs)
            self.init_prompt_dict = {
                "examples": examples,
                "instruction": instruction,
                "system_msg": system_message
            }

        self.max_context_length = self.llm_model.context_length
        self.need_goal = need_goal
        self.check_actions = check_actions
        self.check_inventory = check_inventory

        self.example_prompt = None

        if "claude" in self.llm_model.engine:
            self.split = self.llm_model.xml_split
        else:
            self.split = {"example": [""],
                          "text": [""],
                          "rule": [""],
                          "system_msg": [""],
                          "instruction": [""],
                          "goal": [""]}

    def get_example_prompt(self): #return the prompt for an interaction turn
        return self.example_prompt
    
    def log_example_prompt(self, prompt):
        self.example_prompt = prompt

    def reset(self, goal, init_obs, init_act=None):
        self.goal = goal
        self.init_obs = init_obs
        self.memory = [("Action", init_act), ('Observation', self.init_obs), ('Environment feedback', '')] if init_act \
            else [
            ('Observation', self.init_obs)]  # list of [('State', "xxx"), ('Action', "xxx"), ...]
        self.steps = 0
        self.done = False

    def update(self, action, state, env_feed):
        self.steps += 1

        self.memory.append(("Action", action))
        self.memory.append(("Observation", state))
        self.memory.append(("Environment feedback", env_feed))

    def make_prompt(self, Starting_prompt_task_explain, need_goal=False, system_message='', available_action_space = ''):
        # Alfworld
        '''
        query = ""
        query += self.split["instruction"][0] + self.instruction + self.split["instruction"][-1]

        if isinstance(self.examples, str):
            self.examples = [self.examples]

        if len(self.examples) > 0:
            query += "\nHere are examples:\n" + self.split["example"][0]
            for example in self.examples:
                query += example + "\n"
            query += self.split["example"][-1]
        '''

        rear_prompt = ''
        if need_goal:
            rear_prompt += "\nYou should perform actions to accomplish the goal: " + self.goal + "\n" + \
                     self.split["goal"][-1]

        history = self.memory[-self.memory_size:]
        rear_prompt += '\nYour previous observation, action, and environment feedback history are as follows: \n' + "\n".join([item[0] + ": " + item[1] for item in history])
        rear_prompt += f'\n{available_action_space}\n'
        rear_prompt += "\nAction: "

        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": Starting_prompt_task_explain + rear_prompt}
        ]
        num_of_tokens = self.llm_model.num_tokens_from_messages(messages)

        while num_of_tokens > self.max_context_length - self.llm_model.max_tokens:
            print(f'Exceeded token num of prompt: {num_of_tokens - self.max_context_length + self.llm_model.max_tokens}')
            history = history[1:]
            rear_prompt = ''
            if need_goal:
                rear_prompt += "\nYou should perform actions to accomplish the goal: " + self.goal + "\n" + \
                               self.split["goal"][-1]

            rear_prompt += '\nYour previous observation, action, and environment feedback history are as follows: \n' + "\n".join(
                [item[0] + ": " + item[1] for item in history])
            rear_prompt += f'\n{available_action_space}\n'
            rear_prompt += "\nAction: "

            messages = [
                {"role": "system", "content": system_message},
                {"role": "user", "content": rear_prompt}
            ]
            num_of_tokens = self.llm_model.num_tokens_from_messages(messages)
            if num_of_tokens < 8100:
                break

        return rear_prompt

    def action_parser_for_special_llms(self, action):
        
        '''
        This function is used to parse the action for special llms, e.g. codellama-13b, codellama-34b, llama, lemur, vicuna, etc.
        These llms often struggle to generate the format of the action correctly, so we need to parse the action to make it executable.
        '''
        
        origin_action = action
        if 'action' in action.lower():
            action_temp = action.split('\n')
            for act in action_temp:
                if "next action" in act and ':' in act: # zzh: in Claude will return "Here is the next action to take:"
                    idx = action_temp.index(act)
                    while idx + 1 < len(action_temp):
                        if action_temp[idx + 1]:
                            action = action_temp[idx + 1]
                            break
                        idx += 1
                if act.split(':')[0].lower().endswith('with action input'): # chang: in case parse tool output
                    action = act
                    break
                if 'action' in act.lower() and ':' in act:
                    action_temp = ':'.join(act.split(':')[1:])
                    if action_temp != "":
                        action = action_temp
                        break
                if 'action' in act.lower() and 'is to' in act:
                    action_temp = act.split('is to')[1]
                    if action_temp != "":
                        action = action_temp
                        break
                        
        if action.strip() == "":
            action = origin_action.split('\n')[0]   # temperary comment this line for codellama
        action = action.strip()
        action = action.strip("'/")
        action = action.split('\n')[0]
        return action

    def run(self, Starting_prompt_task_explain, model_name_testLLM, init_prompt_dict=None, available_action_space = ''):
        print('Run func Line 173 normal!')
        if init_prompt_dict is not None:
            self.init_prompt_dict = init_prompt_dict
            self.instruction = init_prompt_dict['instruction']
            self.examples = init_prompt_dict['examples']
        system_message = self.init_prompt_dict['system_msg']; print('Run func Line 178 normal!')
        rear_prompt = self.make_prompt(Starting_prompt_task_explain,
                                        need_goal=self.need_goal,
                                        system_message=system_message,
                                        available_action_space = available_action_space)
        print('Run func Line 183 normal!')
        input_prompt = Starting_prompt_task_explain + rear_prompt
        self.log_example_prompt(input_prompt)
        print('Run func Line 186 normal!')
        #print(f'\nsystem_message: {system_message}\n')
        #print(f'input_prompt: {input_prompt}\n\n\n')
        #print('#################################')

        success, action = self.llm_model.generate(system_message, input_prompt, model_name_testLLM)
        print('Run func Line 192 normal!')

        print(f'Action: {action}\n')
        print(f'Success: {success}\n')
        if success and self.use_parser:
            action = self.action_parser_for_special_llms(action)

        return success, action, rear_prompt

    @classmethod
    def from_config(cls, llm_model, config):
        memory_size = config.get("memory_size", 100)
        instruction = config.get("instruction", "")
        examples = config.get("examples", [])
        init_prompt_path = config.get("init_prompt_path", None)
        system_message = config.get("system_message", "You are a helpful assistant.")
        check_actions = config.get("check_actions", None)
        check_inventory = config.get("check_inventory", None)
        use_parser = config.get("use_parser", True)
        need_goal = config.get("need_goal", False)
        return cls(llm_model, memory_size, examples, instruction, init_prompt_path, system_message, 
                   need_goal, check_actions, check_inventory, use_parser)