import os
import pdb
import json
import re
import time
from llm import load_llm
from agents import load_agent
import random
from environment import load_environment
import jsonlines
from common.registry import registry
from prompt_env9 import *

from utils.logging.logger import TaskLogger
from utils.logging.agent_logger import AgentLogger
logger = AgentLogger(__name__)

from .base_task import BaseTask


@registry.register_task("scienceworld")
class EvalScienceworld(BaseTask):
    def __init__(self,
                 llm_name="gpt",
                 llm_config=None,
                 agent_name="GPTAgent",
                 agent_config=None,
                 env_config=None,
                 run_config=None,
                 llm=None,
                 baseline_dir = None,
                 log_path = None
                 ):
        """Build the agent and store the run settings for a ScienceWorld evaluation.

        Args:
            llm_name: Name handed to ``load_llm`` when ``llm`` is not supplied.
            llm_config: Mapping of LLM settings; ``context_length`` is read
                from it. ``None`` is treated as an empty mapping.
            agent_name: Name handed to ``load_agent``.
            agent_config: Passed to ``load_agent`` unchanged.
            env_config: Mapping of environment settings. The keys
                ``"simplefied"`` (spelling as used by the config) and
                ``"seed"`` are read here; the mapping itself is kept as
                ``self.env_cfg``. ``None`` is treated as an empty mapping.
            run_config: Mapping of run settings; ``max_num_steps`` is read
                from it (default 30). ``None`` is treated as an empty mapping.
            llm: Pre-built LLM object; when given, ``load_llm`` is skipped.
            baseline_dir: Stored as ``self.baseline_dir``.
            log_path: Accepted for interface compatibility; not used here.
        """
        super().__init__()

        # All three configs default to None in the signature; normalise them
        # so the .get() lookups below fall back to their defaults instead of
        # raising AttributeError on None.
        llm_config = {} if llm_config is None else llm_config
        env_config = {} if env_config is None else env_config
        run_config = {} if run_config is None else run_config

        if llm is None:
            llm = load_llm(llm_name, llm_config)
        self.agent = load_agent(agent_name, agent_config, llm)
        self.simplefied = env_config.get("simplefied", False)
        seed = env_config.get("seed", 42)
        self.set_seed(seed)
        self.simplification_str = self.build_simplification_str()
        self.env_cfg = env_config

        # change the name max_episode to max_num_step for consistency
        self.max_num_steps = run_config.get("max_num_steps", 30)
        self.context_length = llm_config.get("context_length")

        self.baseline_dir = baseline_dir

        
    def build_simplification_str(self):

        simplifications = list()
        simplifications.append("selfWateringFlowerPots")
        simplifications.append("openContainers")
        simplifications.append("openDoors")
        simplifications.append("noElectricalAction")

        return ",".join(simplifications)

    def set_seed(self, seed):
        random.seed(seed)

    def evaluate_env(self, prompt_task_explain, model_name_testLLM, index, task_name, var, modified_goal):
        """Run the agent on one task variation and summarise the episode.

        Args:
            prompt_task_explain: Forwarded unchanged to ``self.agent.run``.
            model_name_testLLM: Forwarded unchanged to ``self.agent.run``.
            index: Position of the task in the caller's loop (not used here).
            task_name: Task name handed to ``self.env.load``.
            var: Task variation handed to ``self.env.load``.
            modified_goal: Goal text given to the agent on reset.

        Returns:
            An 8-tuple:
            ``progress_rate`` - reward from the last ``env.step`` (0.0 if none ran);
            ``isDone`` - done flag from the last ``env.step``;
            ``grounding_acc`` - share of attempted steps whose action was in
            ``env.get_action_space(abstract=False)`` (0.0 if no step ran);
            ``score_change_record`` - ``(step, reward)`` pairs where the reward rose;
            ``last_step`` - index of the last attempted step (-1 if none);
            ``success_failure`` - ``'success'``, ``'Stuck in the local loop.'`` or
            ``'failure over query time limit'``;
            ``feedback_to_promptLLM`` - feedback text for the prompt LLM;
            ``error_string`` - error summary, empty on success.
        """
        self.env.load(task_name, var, simplificationStr=self.simplification_str)
        initialObs, _ = self.env.reset()
        init_obs = initialObs + f"\n{self.env.inventory()}"
        self.agent.reset(goal=modified_goal, init_obs=init_obs)
        reward = 0.
        last_reward = 0.
        logger.info("Step {:02} - Observation: {}".format(0, init_obs))
        grounding_acc_count = 0
        score_change_record = []
        isDone = False

        # Per-round history; the three lists are kept index-aligned.
        rear_prompt_list = []  # The record list of all the rear prompts
        response_total_list = []  # The record list of all the responses
        env_act_feedback_list = []  # The record list of env act feedbacks
        dict_not_update_rounds = 0

        # Bound before the loop so an empty step budget cannot leave them undefined.
        success_failure = ''
        feedback_to_promptLLM = ''
        i = -1

        def recent_rounds_feedback(error_feedback):
            # The helper needs the last two rounds; with fewer than two on
            # record, fall back to the plain error message instead of indexing
            # past the start of the history.
            if len(rear_prompt_list) < 2:
                return error_feedback
            return feedback_to_promptLLM_func(rear_prompt_list[-2],
                                              response_total_list[-2],
                                              env_act_feedback_list[-2],
                                              rear_prompt_list[-1],
                                              response_total_list[-1],
                                              env_act_feedback_list[-1],
                                              error_feedback=error_feedback)

        print(f'query_time_limit: {self.max_num_steps}')

        for i in range(self.max_num_steps):
            success_failure = ''
            # Build the action-space text once; it is both printed and sent to the agent.
            action_space_prompt = f'\nValid actions in the current step: {self.env.get_action_space()}'
            print(f'available_action_space: {action_space_prompt}')
            success, action, rear_prompt = self.agent.run(prompt_task_explain, model_name_testLLM, available_action_space = action_space_prompt)
            logger.info("Step {:02} - Action: {}".format(i, action))
            rear_prompt_list.append(rear_prompt)
            response_total_list.append(action)

            if not success:
                # The environment is not stepped for this round; record an
                # empty feedback so the history lists stay aligned.
                # NOTE(review): this case is reported below under the
                # 'failure over query time limit' label, as before.
                env_act_feedback_list.append('')
                break

            observation, reward, isDone, info = self.env.step(action)
            # NOTE(review): the action space is queried after the step has
            # been applied - confirm this is the intended reference state.
            if action in self.env.get_action_space(abstract=False):
                grounding_acc_count += 1
                print(f'\nAction in space: {action}\n')
                env_act_feedback = ''
            else:
                print(f'\nAction not in space: {action}\n')
                env_act_feedback = f'Your assigned action {action} is not in the doable action list: {self.env.get_action_space()}; \n'
            env_act_feedback_list.append(env_act_feedback)

            logger.info("Step {:02} - Observation: {}".format(i, observation))
            logger.info("Step {:02} - Progress Rate: {}\n".format(i, reward))

            # Count consecutive rounds in which the reward did not change.
            if last_reward == reward:
                dict_not_update_rounds += 1
            else:
                dict_not_update_rounds = 0
            if dict_not_update_rounds > 16:  # initially is 8,
                success_failure = 'Stuck in the local loop.'
                # NOTE(review): the wording mentions boxes, which looks copied
                # from another environment; kept unchanged because it is
                # consumed downstream as prompt text.
                system_error_feedback_2 = 'It seems the LLM is stuck in the current situation, always repeating the same answer. The task is stuck too, no box is placed successfully in recent rounds.'
                feedback_to_promptLLM = recent_rounds_feedback(system_error_feedback_2)
                break

            if reward > last_reward:
                score_change_record.append((i, reward))
            last_reward = reward
            if isDone:
                break

            self.agent.update(action=action,
                              state=observation,
                              env_feed = env_act_feedback)

        if isDone:
            success_failure = 'success'
            feedback_to_promptLLM = 'The task is completed successfully.'
        elif success_failure == '':
            success_failure = 'failure over query time limit'
            system_error_feedback_3 = 'The task is not completed over the query time limit.'
            feedback_to_promptLLM = recent_rounds_feedback(system_error_feedback_3)

        error_string = ''
        if success_failure != 'success':
            if len(rear_prompt_list) == 1:
                error_string = error_string_func_APO(rear_prompt_list[-1], response_total_list[-1])
            elif len(rear_prompt_list) >= 2:
                try:
                    error_string = error_string_func_APO(rear_prompt_list[-2], response_total_list[-2],
                                                         env_act_feedback_list[-2],
                                                         rear_prompt_list[-1], response_total_list[-1])
                except Exception:
                    print('Length of rear_prompt_list: ', len(rear_prompt_list))
                    print('Length of response_total_list: ', len(response_total_list))
                    print('Length of env_act_feedback_list: ', len(env_act_feedback_list))
                    # Re-raise the original exception with its traceback intact.
                    raise

        progress_rate = reward
        steps_attempted = i + 1
        grounding_acc = grounding_acc_count / steps_attempted if steps_attempted else 0.0

        return progress_rate, isDone, grounding_acc, score_change_record, i, success_failure, feedback_to_promptLLM, error_string

    def evaluate(self, prompt_task_explain, model_name_testLLM, Saving_path_result, max_tasks=20):
        """Evaluate the agent on the first ``max_tasks`` labelled tasks.

        Loads the ``"scienceworld"`` environment with ``self.env_cfg``, runs
        ``evaluate_env`` for each entry of ``self.env.labels`` in order, and
        writes three text files per task into ``Saving_path_result``:
        ``success_failure_<index>.txt``, ``feedback_to_promptLLM_<index>.txt``
        and ``env_action_times_<index>.txt``.

        Args:
            prompt_task_explain: Forwarded to ``evaluate_env``.
            model_name_testLLM: Forwarded to ``evaluate_env``.
            Saving_path_result: Directory for the per-task result files;
                created if it does not exist.
            max_tasks: Maximum number of tasks to run (default 20, the value
                previously hard-coded). ``None`` runs every labelled task.

        Returns:
            A tuple of parallel per-task lists ``(srs, scores, grounding_accs,
            score_state_records, success_failure_list,
            feedback_to_promptLLM_list, index_query_times_list,
            error_string_list)``. ``error_string_list`` holds only the
            non-empty error strings, so it may be shorter than the others.
        """
        self.env = load_environment("scienceworld", self.env_cfg)
        labels = self.env.labels

        # Make sure the output directory exists before any episode runs, so a
        # finished episode is not lost to a failed write.
        os.makedirs(Saving_path_result, exist_ok=True)

        scores = []
        score_state_records = []
        grounding_accs = []
        srs = []

        success_failure_list = []
        feedback_to_promptLLM_list = []
        index_query_times_list = []
        error_string_list = []

        for index, v in enumerate(labels.values()):
            if max_tasks is not None and index >= max_tasks:
                break

            task_name = v["task_name"]
            var = v["var"]
            modified_goal = v["modified_goal"]

            (score, done, grounding_acc, score_change_record, num_steps,
             success_failure, feedback_to_promptLLM, error_string) = self.evaluate_env(
                prompt_task_explain, model_name_testLLM, index, task_name, var, modified_goal)

            srs.append(1.0 if done else 0.0)
            scores.append(score)
            grounding_accs.append(grounding_acc)
            score_state_records.append(score_change_record)

            success_failure_list.append(success_failure)
            feedback_to_promptLLM_list.append(feedback_to_promptLLM)
            index_query_times_list.append(num_steps)

            if error_string != '':
                error_string_list.append(error_string)

            # num_steps is the zero-based index of the last step, hence the +1.
            per_task_outputs = (
                ('success_failure', success_failure),
                ('feedback_to_promptLLM', feedback_to_promptLLM),
                ('env_action_times', f'{num_steps + 1}'),
            )
            for stem, text in per_task_outputs:
                # The with-block closes the file; no explicit close() is needed.
                with open(os.path.join(Saving_path_result, f'{stem}_{index}.txt'), 'w') as f:
                    f.write(text)

        return  srs, scores, grounding_accs, score_state_records, success_failure_list, feedback_to_promptLLM_list, index_query_times_list, error_string_list

    def _grounding_fn(self, action):
        valid_actions = self.env.GetValidActions()
        return "check valid actions" if action not in valid_actions else action

    @classmethod
    def from_config(cls,
                    run_config,
                    llm_config,
                    agent_config,
                    env_config,
                    llm=None
                    ):
        llm_name = llm_config.get("name", "gpt")
        agent_name = agent_config.get("name", "GPTAgent")
        baseline_dir = run_config.get("baseline_dir", "data/baseline_results")
        log_path = run_config.get("log_path", None)
                
        return cls(llm_name=llm_name,
                   llm_config=llm_config,
                   agent_name=agent_name,
                   agent_config=agent_config,
                   env_config=env_config,
                   run_config=run_config,
                   llm=llm,
                   baseline_dir=baseline_dir,
                   log_path = log_path
                   )


