import json
from agents import load_agent
from environment import load_environment
from llm import load_llm
from common.registry import registry
import copy
from prompt_env9 import *

from utils.logging.logger import TaskLogger
from utils.logging.agent_logger import AgentLogger
logger = AgentLogger(__name__)


from .base_task import BaseTask


# Maps an ALFWorld game-name prefix (matched with str.startswith on the game
# name) to the key under the prompt file's "examples" entry holding the
# in-context examples for that kind of game. Insertion order is the order in
# which prefixes are tried.
prefixes = dict(
    pick_and_place='put',
    pick_clean_then_place='clean',
    pick_heat_then_place='heat',
    pick_cool_then_place='cool',
    look_at_obj='examine',
    pick_two_obj='puttwo',
)



@registry.register_task("alfworld")
class Evalalfworld(BaseTask):
    """ALFWorld evaluation task.

    Drives the agent built by ``load_agent`` through ALFWorld games, records
    the per-step prompts, parsed actions and action-space feedback, and turns
    them into success/failure labels, feedback text
    (``feedback_to_promptLLM_func``) and error strings
    (``error_string_func_APO``) that are returned to the caller.
    """

    def __init__(self,
                 llm_config=None,
                 agent_name='agent_name',
                 max_num_steps=30,
                 num_exams=134,
                 init_prompt_path='prompts/alfworld_base.json',
                 agent_config=None,
                 env_config=None,
                 llm = None,
                 baseline_dir = None,
                 log_path = None
                 ):
        """Build the agent and load the prompt file.

        Args:
            llm_config: Config dict for ``load_llm``; its ``"name"`` key
                (default ``"gpt"``) selects the LLM. Required if ``llm`` is None.
            agent_name: Name passed to ``load_agent``.
            max_num_steps: Maximum number of agent/environment steps per game.
            num_exams: Number of games to evaluate (further capped in ``evaluate``).
            init_prompt_path: Path of a JSON prompt file; its ``"examples"``
                entry is indexed by the values of the module-level ``prefixes``.
            agent_config: Config passed to ``load_agent``.
            env_config: Config passed to ``load_environment`` in ``evaluate``.
            llm: Pre-built LLM; when given, ``llm_config`` is not used.
            baseline_dir: Stored as ``self.baseline_dir``.
            log_path: Accepted for interface compatibility; not used here.

        Raises:
            ValueError: If both ``llm`` and ``llm_config`` are None.
        """
        super().__init__()

        ####################  initialize llm and agent ##################
        if llm is None:
            if llm_config is None:
                raise ValueError("Either `llm` or `llm_config` must be provided.")
            llm = load_llm(llm_config.get("name", "gpt"), llm_config)
        self.agent = load_agent(agent_name, agent_config, llm)
        #################################################################

        with open(init_prompt_path, 'r', encoding='utf-8') as f:
            self.prompts = json.load(f)
        self.env_cfg = env_config
        self.max_num_steps = max_num_steps
        self.num_exams = num_exams

        self.baseline_dir = baseline_dir

    def parseAction(self, action):
        """Normalize a raw agent action string.

        Strips surrounding whitespace; for actions containing "put", rewrites
        every " in " (or, failing that, every " on ") to " in/on "; finally
        drops one trailing period.

        Args:
            action: Raw action text returned by the agent.

        Returns:
            The normalized action string.
        """
        action = action.strip()
        if "put" in action:
            # Presumably matches the "in/on" wording of the environment's
            # action space, which parsed actions are compared against.
            if " in " in action:
                action = action.replace(" in ", ' in/on ')
            elif " on " in action:
                action = action.replace(" on ", ' in/on ')
        if action.endswith('.'):
            action = action[:-1].strip()
        return action

    @staticmethod
    def _aligned_steps(rear_prompt_list, response_total_list, env_act_feedback_list):
        """Trim the per-step records to the steps that have all three entries.

        ``rear_prompt_list`` is appended before the agent's success check, so a
        failed ``agent.run`` leaves it one entry longer than the other two
        lists; trimming keeps prompt i paired with response/feedback i.
        """
        n = min(len(rear_prompt_list), len(response_total_list), len(env_act_feedback_list))
        return rear_prompt_list[:n], response_total_list[:n], env_act_feedback_list[:n]

    @classmethod
    def _build_feedback(cls, rear_prompt_list, response_total_list, env_act_feedback_list, error_feedback):
        """Build the prompt-LLM feedback from the last two completed steps.

        Falls back to ``error_feedback`` alone when fewer than two steps
        completed, since ``feedback_to_promptLLM_func`` is always given two.
        """
        prompts, responses, feedbacks = cls._aligned_steps(
            rear_prompt_list, response_total_list, env_act_feedback_list)
        if len(responses) < 2:
            return error_feedback
        return feedback_to_promptLLM_func(prompts[-2],
                                          responses[-2],
                                          feedbacks[-2],
                                          prompts[-1],
                                          responses[-1],
                                          feedbacks[-1],
                                          error_feedback=error_feedback)

    @classmethod
    def _build_error_string(cls, rear_prompt_list, response_total_list, env_act_feedback_list):
        """Build the error string from the last one or two completed steps.

        Returns '' when no step completed (nothing to report).
        """
        prompts, responses, feedbacks = cls._aligned_steps(
            rear_prompt_list, response_total_list, env_act_feedback_list)
        if not responses:
            return ''
        if len(responses) == 1:
            return error_string_func_APO(prompts[-1], responses[-1])
        return error_string_func_APO(prompts[-2], responses[-2],
                                     feedbacks[-2],
                                     prompts[-1], responses[-1])

    def evaluate_env(self,  prompt_task_explain, model_name_testLLM, index, ob='', examples=None):
        """Run the agent on one game until success, a stall, agent failure or the step limit.

        Args:
            prompt_task_explain: Task-explanation prompt forwarded to ``agent.run``.
            model_name_testLLM: Model name forwarded to ``agent.run``.
            index: Game index, used for logging only.
            ob: Initial observation; line 0 is used as the initial observation
                and line 1 must contain "Your task is to: <goal>".
            examples: Stored under ``"examples"`` in a deep copy of
                ``self.prompts`` passed to the agent.

        Returns:
            ``(progress_rate, done, grounding_acc, score_change_record, steps,
            success_failure, feedback_to_promptLLM, error_string)``, where
            ``steps`` is the index of the last step entered (-1 if none) and
            ``grounding_acc`` is the count of in-space actions divided by
            ``steps + 1`` (0.0 if no step ran).

        Raises:
            Exception: If the observation says "Task accomplished!" while the
                reward is below 1.0.
        """
        ob_lines = ob.split('\n')
        init_ob = ob_lines[0]
        goal = ob_lines[1].split("Your task is to:")[1].strip()

        self.agent.reset(goal=goal, init_obs=init_ob)
        logger.goal("Example {} | Goal: {}".format(index, self.agent.goal))
        init_prompt_dict = copy.deepcopy(self.prompts)
        init_prompt_dict['examples'] = examples
        reward = 0.
        last_reward = 0.
        done = False
        grounding_acc_count = 0
        score_change_record = []
        success_failure = ''
        feedback_to_promptLLM = ''
        logger.info("Step {:02} - Message: {}".format(0, init_ob))

        rear_prompt_list = []       # prompt returned by agent.run at every step
        response_total_list = []    # parsed action of every step whose run succeeded
        env_act_feedback_list = []  # action-space feedback for each of those steps
        dict_not_update_rounds = 0  # consecutive steps with an unchanged reward

        step = -1  # stays -1 if max_num_steps <= 0
        print(f'query_time_limit: {self.max_num_steps}')
        for step in range(self.max_num_steps):
            # Queried once per step; assumes get_action_space() has no side
            # effects and only changes after env.step() -- TODO confirm.
            action_space = self.env.get_action_space()
            action_space_prompt = f'\nValid actions in the current step: {action_space}'
            print(f'available_action_space: {action_space_prompt}')
            success, action, rear_prompt = self.agent.run(prompt_task_explain, model_name_testLLM, init_prompt_dict=init_prompt_dict, available_action_space=action_space_prompt)
            rear_prompt_list.append(rear_prompt)

            if not success:
                # NOTE(review): this is later labelled as a query-time-limit
                # failure, since success_failure is still ''.
                break

            action = self.parseAction(action)
            response_total_list.append(action)
            print(action_space_prompt)
            if action in action_space:
                grounding_acc_count += 1.0
                print(f'\nAction in space: {action}\n')
                env_act_feedback = ''
            else:
                print(f'\nAction not in space: {action}\n')
                env_act_feedback = f'Your assigned action {action} is not in the doable action list: {action_space}; \n'
            env_act_feedback_list.append(env_act_feedback)

            logger.info("Step {:02} - Action: {}".format(step, action))

            observation, reward, done, _info = self.env.step(action)
            logger.info("Step {:02} - Observation: {}".format(step, observation))

            if last_reward == reward:
                dict_not_update_rounds += 1
            else:
                dict_not_update_rounds = 0
            # Give up after more than 8 consecutive steps without a reward change.
            if dict_not_update_rounds > 8:
                success_failure = 'Stuck in the local loop.'
                # NOTE(review): the wording below mentions boxes, which looks
                # copied from another task; kept as-is since it is fed to the prompt LLM.
                system_error_feedback_2 = 'It seems the LLM is stuck in the current situation, always repeating the same answer. The task is stuck too, no box is placed successfully in recent rounds.'
                feedback_to_promptLLM = self._build_feedback(rear_prompt_list,
                                                             response_total_list,
                                                             env_act_feedback_list,
                                                             system_error_feedback_2)
                break

            if "Task accomplished!" in observation and reward < 1.0:
                raise Exception("Task accomplished error")

            logger.info("Step {:02} - Progress Rate: {}\n".format(step, reward))

            print(f'Step: {str(step)} Action: {action}\nObservation: {observation}\nEnv_act_feedback: {env_act_feedback}')
            print(f"reward: {reward}, isdone: {done}")

            if reward > last_reward:
                score_change_record.append((step, reward))
            last_reward = reward
            self.agent.update(action=action, state=observation, env_feed = env_act_feedback)
            if done:
                break

        if done:
            success_failure = 'success'
            feedback_to_promptLLM = 'The task is completed successfully.'
        elif success_failure == '':
            success_failure = 'failure over query time limit'
            system_error_feedback_3 = 'The task is not completed over the query time limit.'
            feedback_to_promptLLM = self._build_feedback(rear_prompt_list,
                                                         response_total_list,
                                                         env_act_feedback_list,
                                                         system_error_feedback_3)

        error_string = ''
        if success_failure != 'success':
            error_string = self._build_error_string(rear_prompt_list,
                                                    response_total_list,
                                                    env_act_feedback_list)

        progress_rate = reward
        grounding_acc = grounding_acc_count / (step + 1) if step >= 0 else 0.0

        return progress_rate, done, grounding_acc, score_change_record, step, success_failure, feedback_to_promptLLM, error_string

    @staticmethod
    def _write_text(path, text):
        """Write ``text`` to ``path`` as UTF-8, replacing any existing file."""
        with open(path, 'w', encoding='utf-8') as f:
            f.write(text)

    def evaluate(self, prompt_task_explain, Saving_path_result, model_name_testLLM, max_exams=25):
        """Evaluate the agent on up to ``min(self.num_exams, max_exams)`` games.

        For every game whose name matches a key of ``prefixes``, runs
        ``evaluate_env`` and writes ``success_failure_<i>.txt``,
        ``feedback_to_promptLLM_<i>.txt`` and ``env_action_times_<i>.txt``
        into ``Saving_path_result``. Games matching no prefix are logged and
        skipped (no files written, nothing appended).

        Args:
            prompt_task_explain: Forwarded to ``evaluate_env``.
            Saving_path_result: Existing directory for the per-game result files.
            model_name_testLLM: Forwarded to ``evaluate_env``.
            max_exams: Upper bound on the number of games (default 25).

        Returns:
            ``(srs, scores, grounding_accs, score_state_records,
            success_failure_list, feedback_to_promptLLM_list,
            index_query_times_list, error_string_list)``; ``srs`` holds 1.0/0.0
            per evaluated game and ``error_string_list`` only non-empty strings.
        """
        import os

        self.env = load_environment('alfworld', self.env_cfg)
        scores = []
        score_state_records = []
        grounding_accs = []
        srs = []

        success_failure_list = []
        feedback_to_promptLLM_list = []
        index_query_times_list = []
        error_string_list = []

        for idx in range(min(self.num_exams, max_exams)):

            ob, info = self.env.reset()
            # Drop the first paragraph of the reset observation.
            ob = '\n'.join(ob[0].split('\n\n')[1:])
            name = '/'.join(info['extra.gamefile'][0].split('/')[-3:-1])

            # No key of `prefixes` is a prefix of another, so at most one matches.
            prefix_key = next((k for k in prefixes if name.startswith(k)), None)
            if prefix_key is None:
                # Previously the result files were written with stale values
                # from the previous game (or a NameError on the first one).
                logger.info("Example {} | No prompt examples for game '{}'; skipped\n".format(idx, name))
                continue

            examples = "".join(self.prompts['examples'][prefixes[prefix_key]])

            score, is_done, grounding_acc, score_change_record, steps, success_failure, feedback_to_promptLLM, error_string\
                = self.evaluate_env(prompt_task_explain, model_name_testLLM, ob=ob, examples=examples, index=idx)
            srs.append(1.0 if is_done else 0.0)
            scores.append(score)
            grounding_accs.append(grounding_acc)
            score_state_records.append(score_change_record)

            success_failure_list.append(success_failure)
            feedback_to_promptLLM_list.append(feedback_to_promptLLM)
            index_query_times_list.append(steps)

            if error_string != '':
                error_string_list.append(error_string)

            logger.finish("Example {} | Success: {} , Progress Rate: {} , Steps: {}\n".format(idx, is_done, score, steps))

            self._write_text(os.path.join(Saving_path_result, f'success_failure_{idx}.txt'), success_failure)
            self._write_text(os.path.join(Saving_path_result, f'feedback_to_promptLLM_{idx}.txt'), feedback_to_promptLLM)
            # `steps` is a 0-based step index, so the number of steps is steps + 1.
            self._write_text(os.path.join(Saving_path_result, f'env_action_times_{idx}.txt'), f'{steps + 1}')

        return srs, scores, grounding_accs, score_state_records, success_failure_list, feedback_to_promptLLM_list, index_query_times_list, error_string_list

    def _grounding_fn(self, action):
        """Return ``action`` if it is in the current action space, else "check valid actions".

        Uses ``get_action_space()``, the same validity source as ``evaluate_env``.
        """
        if action in self.env.get_action_space():
            return action
        print(f"The wrong action is: {action}")
        return "check valid actions"

    @classmethod
    def from_config(cls,
                    run_config,
                    llm_config,
                    agent_config,
                    env_config,
                    llm = None
                    ):
        """Construct the task from run/LLM/agent/env config dicts.

        Reads ``name`` and ``init_prompt_path`` from ``agent_config``, and
        ``max_num_steps``, ``baseline_dir``, ``num_exam`` and ``log_path``
        from ``run_config``, with the defaults shown below.
        """
        agent_name = agent_config.get("name", "GPTAgent")
        init_prompt_path = agent_config.get("init_prompt_path", 'prompts/alfworld_in_context_learning.json')
        max_num_steps = run_config.get("max_num_steps", 30)
        baseline_dir = run_config.get("baseline_dir", "data/baseline_results")
        num_exams = run_config.get("num_exam", 134)
        log_path = run_config.get("log_path", None)
        return cls(
                   llm_config=llm_config,
                   agent_name=agent_name,
                   max_num_steps=max_num_steps,
                   num_exams=num_exams,
                   init_prompt_path=init_prompt_path,
                   agent_config=agent_config,
                   env_config=env_config,
                   llm = llm,
                   baseline_dir = baseline_dir,
                   log_path = log_path
                   )