import subprocess
import json
import os
import pdb
import logging
import requests
from pathlib import Path
from typing import List, Dict, Any, Tuple, Union
from common.registry import registry
import traceback
from beartype import beartype
from beartype.door import is_bearable
from playwright._impl._api_types import TimeoutError
from agents import load_agent
from agents.vanilla_agent import VanillaAgent
from llm import load_llm
from environment.browser_env.actions import Action, create_id_based_action, create_playwright_action, \
    ActionParsingError, create_none_action, ActionTypes, create_stop_action
from environment.browser_env.utils import StateInfo
from environment.browser_env.help_function import RenderHelper, map_url_to_real, extract_action, \
    log_progress_score, transform_format, get_action_description, early_stop
from environment.browser_env.evaluation_function import (
    evaluator_router,
    progress_evaluator_router,
)
from environment import load_environment
from utils.logging.logger import TaskLogger
from prompt_env9 import *

from .base_task import BaseTask


@registry.register_task("webarena")
class EvalWebBrowse(BaseTask):
    def __init__(self,
                 llm_name="gpt",
                 llm_config=None,
                 env_config=None,
                 agent_name="VanillaAgent",
                 agent_config=None,
                 max_num_steps=30,
                 parsing_failure_th=3,
                 repeating_action_failure_th=3,
                 task_name=None,
                 start_test_id=0,
                 test_case_count=20,
                 label_path=None,
                 llm=None,
                 baseline_dir=None,
                 log_path=None
                 ):
        """Build the agent and record run settings for the web-browsing task.

        Args:
            llm_name: Name passed to ``load_llm`` when ``llm`` is not supplied.
            llm_config: Config passed to ``load_llm``; also stored on the task.
            env_config: Mapping that must provide ``"render_screenshot"`` and
                ``"action_set_tag"`` (both are read here).
            agent_name: Name passed to ``load_agent``.
            agent_config: Config passed to ``load_agent``.
            max_num_steps: Maximum number of agent steps per episode.
            parsing_failure_th: Early-stop threshold stored under
                ``"parsing_failure"``.
            repeating_action_failure_th: Early-stop threshold stored under
                ``"repeating_action"``.
            task_name: Used to name the result directory.
            start_test_id: Index of the first test case to keep.
            test_case_count: Number of test cases to keep.
            label_path: Path of the JSONL file read by ``get_test_list``.
            llm: Pre-built LLM; when ``None`` one is loaded from ``llm_name``.
            baseline_dir: Stored as-is; not used in this method.
            log_path: Base output directory. It is joined with a sub-path
                below, so it must be a path string (``None`` fails there).
        """
        super().__init__()

        if llm is None:
            llm = load_llm(llm_name, llm_config)
        self.agent = load_agent(agent_name, agent_config, llm)

        self.max_num_steps = max_num_steps
        self.parsing_failure_th = parsing_failure_th
        self.repeating_action_failure_th = repeating_action_failure_th
        self.early_stop_thresholds = {
            "parsing_failure": self.parsing_failure_th,
            "repeating_action": self.repeating_action_failure_th,
        }

        self.task_name = task_name
        self.llm_config = llm_config
        self.output_log_dir = log_path
        # Per-task directory for trajectory artifacts (rendered pages, traces, error.txt).
        self.result_dir = os.path.join(log_path, f'logs/{task_name}_tracks')
        os.makedirs(self.result_dir, exist_ok=True)

        # Stored as plain values. A stray trailing comma used to turn both of
        # these into 1-tuples, which get_test_list then had to unwrap.
        self.start_test_id = start_test_id
        self.test_case_count = test_case_count
        self.label_path = label_path
        self.test_file_list = []
        self.difficulties = []

        self.env_config = env_config
        self.render_screenshot = self.env_config["render_screenshot"]
        self.render_helper = None
        self.action_set_tag = self.env_config["action_set_tag"]

        self.baseline_dir = baseline_dir

    def get_test_list(self):
        """Load test cases from the JSONL label file and keep the configured window.

        Every non-blank line of ``self.label_path`` is parsed as JSON, converted
        with ``transform_format`` and collected. ``self.test_file_list`` is then
        set to the slice ``[start_test_id, start_test_id + test_case_count)``.
        Both lists are rebuilt from scratch, so calling this method repeatedly
        does not accumulate duplicates.

        Raises:
            KeyError: If a record has no ``"difficulty"`` field.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        def as_int(value):
            # Accept the legacy 1-tuple form as well as a plain number.
            if isinstance(value, (tuple, list)):
                value = value[0]
            return int(value)

        start_test_id = as_int(self.start_test_id)
        test_case_count = as_int(self.test_case_count)

        test_files = []
        difficulties = []
        with open(self.label_path, 'r', encoding='utf-8') as infile:
            for raw_line in infile:
                record_text = raw_line.strip()
                if not record_text:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                jsonl_item = json.loads(record_text)
                test_files.append(transform_format(jsonl_item))
                difficulties.append(jsonl_item["difficulty"])

        # Slicing clamps at the end of the list, so no explicit min() is needed.
        self.test_file_list = test_files[start_test_id: start_test_id + test_case_count]
        # NOTE(review): difficulties is kept unsliced (as before), so it is
        # indexed by position in the label file, not by position in
        # test_file_list -- confirm what external readers expect.
        self.difficulties = difficulties

        print(f"Total {len(self.test_file_list)} tasks left")

    def evaluate_env(self, prompt_task_explain, model_name_testLLM, idx: int):
        """Run one browsing episode for test case ``idx`` and score it.

        The agent is stepped until it emits a STOP action, the early-stop
        check fires, the reward reaches 1.0, the reward stays unchanged for
        more than 16 consecutive steps, or ``self.max_num_steps`` is reached.

        Args:
            prompt_task_explain: Forwarded unchanged to ``self.agent.run``.
            model_name_testLLM: Forwarded unchanged to ``self.agent.run``.
            idx: Index into ``self.test_file_list`` of the case to run.

        Returns:
            An 8-tuple ``(progress_score, success, grounding_acc,
            score_change_record, step_count, success_failure,
            feedback_to_promptLLM, error_string)`` where ``step_count`` is
            ``step_id + 1`` and ``error_string`` is ``''`` on success.
        """
        config_file = self.test_file_list[idx]
        meta_data = {"action_history": ["None"]}
        template = """WINDOWED PAGE:{{
{observation}
}}
URL: {url}"""
        last_reward = 0.
        grounding_error_count = 0
        score_change_record = []
        error_happened = 0
        error_obs = None
        reset_session = 0
        step_id = 0
        # Bound up front so the post-loop logic is safe even if the loop body
        # never runs (max_num_steps == 0).
        success_failure = ''
        response = None
        env_act_feedback = ''

        rear_prompt_list = []       # rear prompt of every agent call
        response_total_list = []    # one entry per agent call (parsed action, or raw text if unparseable)
        env_act_feedback_list = []  # feedback string produced after every env step
        dict_not_update_rounds = 0  # consecutive steps with an unchanged reward

        print(f'query_time_limit: {self.max_num_steps}')

        while step_id < self.max_num_steps:
            obs_step = self.env.state["observation"]["text"]
            url_step = self.env.state["info"]["page"].url
            current_state = template.format(
                url=map_url_to_real(url_step),
                observation=obs_step,
            )
            parsed_response = None
            extracted = False

            try:
                early_stop_flag, action_invalid, _ = early_stop(
                    self.env.history, self.early_stop_thresholds
                )
                if early_stop_flag:
                    if action_invalid:
                        grounding_error_count += 1
                    break

                if step_id == 1 and not reset_session:
                    # Work around the "logout" problem: a sign-in page on a
                    # non-https URL triggers one environment re-preparation.
                    if "sign in" in obs_step.lower() and not url_step.startswith("https"):
                        try:
                            subprocess.run(["bash", "./scripts/prepare_webbrowse.sh"], check=True)
                            print("Script executed successfully!")
                            self.env.reset(options={"config_file": config_file})
                            reset_session = 1
                            # Re-read the fresh state; reset_session prevents a second retry.
                            continue
                        except subprocess.CalledProcessError as e:
                            print(f"Error occurred while executing script: {str(e)}")

                if step_id == 0:
                    self.agent.reset(self.env.goal, init_obs=current_state, init_act=None)
                elif isinstance(self.agent, VanillaAgent):
                    self.agent.update(action=response, state=current_state, env_feed=env_act_feedback)

                _, response, rear_prompt = self.agent.run(
                    prompt_task_explain,
                    model_name_testLLM,
                    init_prompt_dict=None,
                    available_action_space=f'\nValid actions in the current step: {self.env.get_action_space()}',
                )
                print(f'Raw response: {response}')
                rear_prompt_list.append(rear_prompt)

                try:
                    parsed_response = extract_action(response)
                    extracted = True
                    response_total_list.append(parsed_response)

                    if self.action_set_tag == "id_accessibility_tree":
                        action = create_id_based_action(parsed_response)
                    elif self.action_set_tag == "playwright":
                        action = create_playwright_action(parsed_response)
                    else:
                        raise ValueError(f"Unknown action type {self.action_set_tag}")
                    action["raw_prediction"] = response
                except ActionParsingError as e:
                    if not extracted:
                        # Keep response_total_list aligned with rear_prompt_list
                        # so "[-1]" always refers to the same round in both.
                        response_total_list.append(response)
                    error_happened = 1
                    error_obs = str(e)
                    action = create_none_action()
                    action["raw_prediction"] = response
            except ValueError as e:
                # Any ValueError raised while choosing the action ends the episode.
                action = create_stop_action(f"ERROR: {str(e)}")

            self.env.history.append(action)
            action_str = get_action_description(
                action,
                self.env.state["info"]["observation_metadata"],
                action_set_tag=self.action_set_tag,
            )
            self.render_helper.render(
                action, self.env.state, meta_data, self.render_screenshot
            )
            meta_data["action_history"].append(action_str)
            if action["action_type"] == ActionTypes.STOP:
                print('Action == stop!')
                break

            # NOTE(review): the return value was never used; the call is kept
            # in case the router has side effects -- confirm and drop if not.
            progress_evaluator_router(config_file)

            obs, _, _, _ = self.env.step(action)
            print(f'Parsed action: {parsed_response}')

            had_parse_error = bool(error_happened)
            if error_happened:
                # Surface our own parsing error as this step's observation.
                obs["text"] = error_obs
                error_happened = 0
                grounding_error_count += 1

            # `obs` is indexed with obs["text"], i.e. it is a mapping, so the
            # marker must be looked up in its text rather than among its keys.
            parse_failed = had_parse_error or 'Cannot parse action' in self._observation_text(obs)
            attempted_action = parsed_response if extracted else response
            action_space = self.env.get_action_space()
            if parse_failed:
                print('Failure! Cannot parse action!')
                print(f'\nAction not in space: {attempted_action}\n')
                env_act_feedback = f'Your assigned action {attempted_action} is not in the doable action list: {action_space}; \n'
            else:
                print('Success! Can parse action!')
                print(f'\nAction in space: {attempted_action}\n')
                env_act_feedback = ''
            print(f'self.env.get_action_space(): {action_space}')
            env_act_feedback_list.append(env_act_feedback)

            reward = self.env.reward
            if last_reward == reward:
                dict_not_update_rounds += 1
            else:
                dict_not_update_rounds = 0
            if dict_not_update_rounds > 16:  # initially is 8,
                success_failure = 'Stuck in the local loop.'
                # NOTE(review): the "no box is placed" wording looks copied from
                # another task; left untouched because it is fed to the prompt LLM.
                system_error_feedback_2 = 'It seems the LLM is stuck in the current situation, always repeating the same answer. The task is stuck too, no box is placed successfully in recent rounds.'
                feedback_to_promptLLM = self._recent_rounds_feedback(
                    rear_prompt_list, response_total_list, env_act_feedback_list,
                    system_error_feedback_2)
                break

            if reward > last_reward:
                score_change_record.append((step_id, reward))
            last_reward = reward

            step_id += 1
            if step_id == self.max_num_steps:
                # Out of budget: close the history with an explicit stop action.
                self.env.history.append(create_stop_action(""))
                break
            if last_reward == 1.0:  # early stop when success
                break

        if last_reward == 1.0:
            success_failure = 'success'
            feedback_to_promptLLM = 'The task is completed successfully.'
        elif success_failure == '':
            success_failure = 'failure over query time limit'
            system_error_feedback_3 = 'The task is not completed over the query time limit.'
            feedback_to_promptLLM = self._recent_rounds_feedback(
                rear_prompt_list, response_total_list, env_act_feedback_list,
                system_error_feedback_3)

        error_string = ''
        if success_failure != 'success':
            prev_prompt, last_prompt = self._last_two(rear_prompt_list)
            prev_response, last_response = self._last_two(response_total_list)
            prev_env_feedback, _ = self._last_two(env_act_feedback_list)
            if len(rear_prompt_list) <= 1:
                error_string = error_string_func_APO(last_prompt, last_response)
            else:
                error_string = error_string_func_APO(prev_prompt, prev_response,
                                                     prev_env_feedback,
                                                     last_prompt, last_response)

        evaluator = evaluator_router(config_file)
        success = evaluator(
            trajectory=self.env.history,
            config_file=config_file,
            page=self.env.page,
        )
        progress_score = self.env.progress_score
        # A perfect progress score counts as success regardless of the evaluator.
        success = 1.0 if progress_score == 1.0 else success

        grounding_acc = (step_id + 1 - grounding_error_count) / (step_id + 1)

        return progress_score, success, grounding_acc, score_change_record, step_id + 1, success_failure, feedback_to_promptLLM, error_string

    @staticmethod
    def _last_two(items, filler=''):
        """Return the last two entries of ``items`` as ``(previous, latest)``.

        Missing positions are filled with ``filler`` so that histories with
        zero or one entry never raise ``IndexError``.
        """
        padded = [filler, filler, *items]
        return padded[-2], padded[-1]

    @staticmethod
    def _observation_text(obs):
        """Return the text carried by an observation, or ``''`` if there is none."""
        text = obs.get("text") if isinstance(obs, dict) else obs
        return text if isinstance(text, str) else ''

    def _recent_rounds_feedback(self, prompts, responses, env_feedbacks, error_feedback):
        """Call ``feedback_to_promptLLM_func`` on the two most recent rounds."""
        prev_prompt, last_prompt = self._last_two(prompts)
        prev_response, last_response = self._last_two(responses)
        prev_env_feedback, last_env_feedback = self._last_two(env_feedbacks)
        return feedback_to_promptLLM_func(prev_prompt,
                                          prev_response,
                                          prev_env_feedback,
                                          last_prompt,
                                          last_response,
                                          last_env_feedback,
                                          error_feedback=error_feedback)

    def evaluate(self, prompt_task_explain, Saving_path_result, model_name_testLLM):
        """Run every selected test case and collect per-trial results.

        The browser session is re-prepared via the shell script, the
        environment is loaded, the test list is (re)built, and up to 20 test
        cases are run through ``evaluate_env``. For each completed trial three
        small text files are written under ``Saving_path_result``.

        Timeout and connection errors are logged to ``error.txt`` in the result
        directory and the loop moves on without recording scores for that
        trial. Any other exception is logged, recorded as a score of 0.0 and
        stops the remaining trials.

        Args:
            prompt_task_explain: Forwarded to ``evaluate_env``.
            Saving_path_result: Directory prefix for the per-trial text files.
            model_name_testLLM: Forwarded to ``evaluate_env``.

        Returns:
            An 8-tuple of lists: ``(scores, progress_scores, grounding_acc_avg,
            score_state_avg, success_failure_list, feedback_to_promptLLM_list,
            index_query_times_list, error_string_list)``.
        """
        print("Session Refreshing!")
        try:
            subprocess.run(["bash", "./scripts/prepare_webbrowse.sh"], check=True)
            print("Session Refreshed!")
        except Exception:
            # Best effort: evaluation continues even if the refresh script fails.
            print("Error: Session refreshed failed. Please check logs/webarena_tracks/error.txt for more details.")

        self.env = load_environment('BrowserEnv', self.env_config)
        traces_dir = Path(self.result_dir) / "traces"
        traces_dir.mkdir(parents=True, exist_ok=True)
        self.get_test_list()

        scores = []
        progress_scores = []
        grounding_acc_avg = []
        score_state_avg = []
        success_failure_list = []
        feedback_to_promptLLM_list = []
        index_query_times_list = []
        error_string_list = []

        error_log = Path(self.result_dir) / "error.txt"
        max_trials = 20  # hard upper bound on trials per call

        def record_error(label, exc):
            # Must be called from inside an ``except`` block so that
            # traceback.format_exc() sees the exception being handled.
            print(f"[Config file id]: {config_file['task_id']}\n")
            print(f"[{label}] {repr(exc)}\n")
            with open(error_log, "a") as f:
                f.write(f"[Config file id]: {config_file['task_id']}\n")
                f.write(f"[{label}] {repr(exc)}\n")
                f.write(traceback.format_exc())

        for idx, config_file in enumerate(self.test_file_list[:max_trials]):
            self.render_helper = None
            abort = False
            try:
                self.render_helper = RenderHelper(
                    config_file, self.result_dir, self.action_set_tag
                )
                self.env.reset(options={"config_file": config_file})
                task_id = self.env.env_config.get('task_id', None)

                (progress_score, success, grounding_acc, score_state, steps,
                 success_failure, feedback_to_promptLLM, error_string) = self.evaluate_env(
                    prompt_task_explain, model_name_testLLM, idx=idx)
                progress_scores.append(progress_score)
                scores.append(success)
                print('*Trial result*/'*10)
                print(f'Trial idx: {idx}, grounding_acc: {grounding_acc}, progress_score: {progress_score}, success: {success}')
                grounding_acc_avg.append(grounding_acc)
                score_state_avg.append(score_state)

                success_failure_list.append(success_failure)
                feedback_to_promptLLM_list.append(feedback_to_promptLLM)
                index_query_times_list.append(steps)

                if error_string != '':
                    error_string_list.append(error_string)

                # NOTE(review): `steps` is already step_id + 1 as returned by
                # evaluate_env, so the file holds one more than the value
                # appended to index_query_times_list -- confirm that is intended.
                per_trial_files = (
                    (f'/success_failure_{idx}.txt', success_failure),
                    (f'/feedback_to_promptLLM_{idx}.txt', feedback_to_promptLLM),
                    (f'/env_action_times_{idx}.txt', f'{steps + 1}'),
                )
                for suffix, content in per_trial_files:
                    # LLM-produced text may contain non-ASCII characters.
                    with open(Saving_path_result + suffix, 'w', encoding='utf-8') as f:
                        f.write(content)

                if self.env_config["save_trace_enabled"]:
                    self.env.save_trace(traces_dir / f"{task_id}.zip")

            except TimeoutError as e:
                record_error("Timeout Error", e)
            except requests.ConnectionError as e:
                record_error("Connection Error", e)
            except Exception as e:
                progress_scores.append(self.env.progress_score)
                scores.append(0.0)
                record_error("Unhandled Error", e)
                abort = True
            finally:
                # Always release the render helper, including on the abort path.
                if self.render_helper is not None:
                    self.render_helper.close()

            if abort:
                break

        self.env.close()

        print(f'Length of grounding_acc_avg: {len(grounding_acc_avg)}')

        return scores, progress_scores, grounding_acc_avg, score_state_avg, success_failure_list, feedback_to_promptLLM_list, index_query_times_list, error_string_list

    @classmethod
    def from_config(cls,
                    run_config,
                    llm_config,
                    agent_config,
                    env_config,
                    llm=None
                    ):
        llm_name = llm_config.get("name", "gpt")
        agent_name = agent_config.get("name", "VanillaAgent")
        max_num_steps = run_config.get("max_num_steps", 30)
        baseline_dir = run_config.get("baseline_dir", "data/baseline_results")
        log_path = run_config.get("log_path", None)
        task_name = env_config.get("name", 'webarena')
        parsing_failure_th = env_config.get("parsing_failure_th", 3)
        repeating_action_failure_th = env_config.get("repeating_action_failure_th", 3)
        label_path = env_config.get("label_path", './data/webarena/test.jsonl')
        start_test_id = env_config.get("start_test_id", 0)
        test_case_count = env_config.get("test_case_count", 5)

        return cls(llm_name=llm_name,
                   llm_config=llm_config,
                   env_config=env_config,
                   agent_name=agent_name,
                   agent_config=agent_config,
                   max_num_steps=max_num_steps,
                   parsing_failure_th=parsing_failure_th,
                   repeating_action_failure_th=repeating_action_failure_th,
                   task_name=task_name,
                   log_path=log_path,
                   start_test_id=start_test_id,
                   test_case_count=test_case_count,
                   label_path=label_path,
                   llm=llm,
                   baseline_dir=baseline_dir
                   )
