import random
import re, string, os
import json
import time
import tiktoken
from langchain import OpenAI, Wikipedia
from langchain.docstore.base import Docstore
from langchain.agents.react.base import DocstoreExplorer
from langchain.prompts import PromptTemplate
from collections import Counter

from pre_prompt import behavior_prompt_search, behavior_prompt_click
from request_llms import request_chatgpt
from get_google_search import google_rapidapi

# Shared tiktoken encoder ("cl100k_base") used by this module to count tokens
# (default tokenizer of truncate_scratchpad and BaseAgent.enc).
token_enc = tiktoken.get_encoding("cl100k_base")

# Module-level list; nothing in the visible part of this file reads or writes it.
all_datas = []

def parse_action(string):
    """Split an action string of the form ``Name[argument]`` into its parts.

    The whole string must match ``Name[...]`` exactly; surrounding double
    quotes are removed from the argument. Any string that does not match
    strictly is handed to ``fuzzy_parse_action``.

    Returns:
        A ``(action_type, argument)`` tuple.
    """
    strict = re.match(r'^(\w+)\[(.+)\]$', string)
    if strict is None:
        # Not a clean "Name[arg]" string: fall back to the lenient parser.
        return fuzzy_parse_action(string)

    name, raw_argument = strict.group(1), strict.group(2)
    return name, raw_argument.strip("\"")


def fuzzy_parse_action(text):
    """Leniently parse ``Name[argument]`` from the start of ``text``.

    Surrounding spaces and then surrounding periods are stripped first.
    Trailing text after the closing bracket is tolerated.

    Returns:
        ``(action_type, argument)`` when a bracketed action is found,
        otherwise ``(stripped_text, '')``.
    """
    cleaned = text.strip(' ').strip('.')
    found = re.match(r'^(\w+)\[(.+)\]', cleaned)
    if not found:
        return cleaned, ''
    return found.group(1), found.group(2)


def format_step(step: str) -> str:
    """Return ``step`` unchanged (no formatting is applied)."""
    formatted = step
    return formatted


def truncate_scratchpad(scratchpad: str, n_tokens: int = 1600, tokenizer=token_enc) -> str:
    """Shrink a scratchpad to at most ``n_tokens`` tokens by blanking observations.

    Lines starting with ``Observation`` are replaced, largest first (by token
    count), with ``<prefix>: [truncated wikipedia excerpt]`` until the whole
    text fits the budget or no untruncated observation is left.

    Args:
        scratchpad: Newline-separated scratchpad text.
        n_tokens: Token budget for the returned text.
        tokenizer: Object exposing ``encode(str)`` whose result has a length;
            used both for ranking observations and for measuring the total.

    Returns:
        The (possibly truncated) scratchpad. It may still exceed ``n_tokens``
        when the non-observation lines alone are over budget.
    """
    lines = scratchpad.split('\n')
    observations = [line for line in lines if line.startswith('Observation')]
    observations_by_tokens = sorted(observations, key=lambda x: len(tokenizer.encode(x)))
    # Stop once every observation has been replaced, instead of popping from
    # an empty list, and measure with the caller-supplied tokenizer.
    while observations_by_tokens and len(tokenizer.encode('\n'.join(lines))) > n_tokens:
        largest_observation = observations_by_tokens.pop()
        ind = lines.index(largest_observation)
        lines[ind] = largest_observation.split(':')[0] + ': [truncated wikipedia excerpt]'
    return '\n'.join(lines)


def normalize_answer(s):
    """Canonicalize an answer string for comparison.

    Steps, in order: lowercase, drop ASCII punctuation, replace the standalone
    words "a", "an" and "the" with a space, then collapse all whitespace runs
    to single spaces (also trimming the ends).
    """
    lowered = s.lower()

    punctuation = set(string.punctuation)
    without_punctuation = "".join(ch for ch in lowered if ch not in punctuation)

    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punctuation)

    return " ".join(without_articles.split())


def f1_score(prediction, ground_truth):
    """Compute token-level F1, precision and recall between two answers.

    Both strings are normalized with ``normalize_answer`` first. When either
    normalized string is one of "yes", "no" or "noanswer" and the two differ,
    the score is zero. Otherwise tokens are compared as multisets.

    Returns:
        ``(f1, precision, recall)``; ``(0, 0, 0)`` when there is no credit.
    """
    pred_norm = normalize_answer(prediction)
    truth_norm = normalize_answer(ground_truth)

    no_credit = (0, 0, 0)
    special_answers = ['yes', 'no', 'noanswer']

    # Special answers only score when they match exactly.
    if pred_norm != truth_norm and (pred_norm in special_answers
                                    or truth_norm in special_answers):
        return no_credit

    pred_tokens = pred_norm.split()
    truth_tokens = truth_norm.split()
    shared = Counter(pred_tokens) & Counter(truth_tokens)
    overlap = sum(shared.values())
    if overlap == 0:
        return no_credit

    precision = 1.0 * overlap / len(pred_tokens)
    recall = 1.0 * overlap / len(truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1, precision, recall


def EM(answer, key) -> bool:
    """Exact match: True when both strings are equal after ``normalize_answer``."""
    normalized_answer = normalize_answer(answer)
    normalized_key = normalize_answer(key)
    return normalized_answer == normalized_key


class BaseAgent:
    """Base class for an LLM-driven agent that searches and clicks results.

    Each step asks the LLM for an action (via ``forward``), runs a search for
    ``Search[...]`` actions, asks the LLM which result to click, and records
    the interaction in ``self.session``.

    Subclasses must implement ``forward`` and ``_build_agent_prompt`` and must
    set ``self.max_steps`` (read by ``is_halted``); this class does not define
    it.
    """

    def __init__(self,
                 profile: str,
                 context_len: int = 2000,
                 ) -> None:
        """Store the profile and context budget and reset the per-run state.

        Args:
            profile: Profile passed to the prompts and stored
                (JSON-encoded) in the session.
            context_len: Maximum token length of the click prompt before the
                agent is considered halted.
        """
        self.profile = profile
        self.agent_prompt = ""
        self.examples = ""
        self.context_len = context_len
        self.run_error = False
        self.name = "Base_Agent"
        self.query = ''
        self.titles = ''
        self.enc = token_enc
        self.__reset_agent()

    def run(self, index=0, reset=True) -> dict:
        """Run steps until the agent finishes, halts, or hits a run error.

        Args:
            index: Unused; kept for interface compatibility.
            reset: When true, per-run state is reset before running.

        Returns:
            The session dict with keys ``"profile"`` and ``"interaction"``.
        """
        if reset:
            self.__reset_agent()

        self.session["profile"] = json.dumps(self.profile)
        self.session["interaction"] = []
        while not self.is_halted() and not self.is_finished() and not self.run_error:
            self.step()

        return self.session

    def prompt_agent(self, now_act) -> str:
        """Send the prompt for ``now_act`` to the LLM and return its reply.

        Sets ``self.run_error`` when the reply signals a failed request.
        """
        generation = request_chatgpt(self._build_agent_prompt(now_act))
        self.check_run_error(generation)
        return format_step(generation)

    def check_run_error(self, text):
        """Flag a run error when ``text`` is the literal ``"No response"``."""
        if text in ["No response"]:
            self.run_error = True

    def is_finished(self) -> bool:
        """Return whether a ``Finish`` action has been taken."""
        return self.finished

    def reward(self) -> tuple:
        """Return ``f1_score(self.answer, self.key)``.

        NOTE(review): ``answer`` and ``key`` are not set anywhere in this
        class; they must be assigned elsewhere before calling this.
        """
        return f1_score(self.answer, self.key)

    def is_halted(self) -> bool:
        """Return whether the unfinished agent ran out of steps or context.

        The context check measures the token length of the "click" prompt.
        """
        return ((self.step_n > self.max_steps)
                or (len(self.enc.encode(self._build_agent_prompt("click"))) > self.context_len)
                ) and not self.finished

    def __reset_agent(self) -> None:
        """Reset the per-run state (step counter, scratchpad, session, click)."""
        self.step_n = 1
        self.finished = False
        self.scratchpad: str = ''
        self.session = {}
        # Initialized here so they always exist, even before the first search.
        self.click_id = None
        self.result_list = []

    def _action(self):
        """Ask the LLM for the next action and parse it.

        A ``Finish`` action is rejected and re-requested while fewer than
        three steps have been started.

        Returns:
            A ``(action_type, argument)`` tuple.
        """
        self.scratchpad += f"[Round {self.step_n}]"
        action = self.prompt_agent("query")
        action_type, argument = parse_action(action)
        # NOTE: this loop has no retry cap; it repeats as long as the LLM
        # keeps answering Finish during the first two steps.
        while action_type == 'Finish' and self.step_n < 3:
            print('Error!!! Early finish! Action again...')
            action = self.prompt_agent("query")
            action_type, argument = parse_action(action)

        print(f"{action_type}: {argument}")
        return action_type, argument

    @staticmethod
    def _parse_click_id(text, n_results):
        """Extract a 1-based result index (as a digit string) from ``text``.

        A number following "web page number" is preferred; otherwise the
        first number in the text is used. Returns ``None`` when no number is
        found or it does not index one of the ``n_results`` results.
        """
        found = re.search(r'web page number\s*(\d+)', text) or re.search(r'(\d+)', text)
        if found is None:
            return None
        candidate = found.group(1)
        if 1 <= int(candidate) <= n_results:
            return candidate
        return None

    def _click(self):
        """Ask the LLM which result to click and log it to the scratchpad.

        On success ``self.click_id`` holds the chosen 1-based index as a
        string. When the reply contains no usable index (for example a failed
        request or an out-of-range number), ``self.click_id`` is set to
        ``None`` and ``self.run_error`` is raised so that ``run`` stops
        instead of crashing.
        """
        reply = self.prompt_agent("click").strip('.')
        click_id = self._parse_click_id(reply, len(self.result_list))
        if click_id is None:
            print("Error!!! Click could not be parsed: ", reply)
            self.click_id = None
            self.run_error = True
            return

        self.click_id = click_id
        title = self.result_list[int(click_id) - 1]['title']
        self.scratchpad += f"\nquery: {self.query}"
        self.scratchpad += f"\nclicked: {click_id} (title: {title})\n"
        print(f"\nclicked: {click_id} (title: {title})\n")

    def delete_observation(self):
        """Drop the 11 scratchpad lines that precede the final line.

        The scratchpad is stripped first. When it has fewer than 12 lines,
        only the final line is kept.
        """
        lines = self.scratchpad.strip().split('\n')
        self.scratchpad = '\n'.join(lines[:-12] + [lines[-1]])

    def step(self) -> None:
        """Execute one agent step and record it in the session.

        ``Finish`` marks the agent finished. ``Search`` runs the search
        (retrying until it succeeds) and then asks for a click. Any other
        action is noted in the scratchpad as invalid and recorded with an
        empty observation and no click.
        """
        ret = self.forward()
        if ret:
            action_type, argument = ret[0], ret[1]
            self.query = argument
        else:
            # Keep both names bound so the step can still be recorded.
            action_type, argument = ret, ''

        if action_type == 'Finish':
            self.finished = True
            self.step_n += 1
            return

        # Defaults for steps that perform no search.
        result_list = []
        click_id = None

        if action_type == 'Search':
            # NOTE: retries indefinitely on any exception from the search call.
            while True:
                try:
                    results, result_list = google_rapidapi(argument)
                    break
                except Exception as e:
                    print(e)
            self.titles = format_step(results)
            self.result_list = result_list
            self._click()
            click_id = self.click_id
        else:
            self.scratchpad += 'Invalid Action. Valid Actions are Search[<topic>] and Finish[finish].'

        self.session["interaction"].append({"step": self.step_n, "action": {"type": action_type, "key": argument},
                                            "observation": result_list, "click": click_id})
        self.step_n += 1

    def _build_agent_prompt(self, now_act=None) -> str:
        """Build the prompt for ``now_act``; callers in this class pass "query" or "click"."""
        raise NotImplementedError

    def forward(self):
        """Produce the next ``(action_type, argument)`` pair; subclass hook."""
        raise NotImplementedError


class SearchAgent(BaseAgent):
    """Agent that issues search queries and clicks results, with a random round cap.

    The maximum number of rounds is drawn once at construction time: values
    from 2 to 8, with higher caps progressively less likely.
    """

    def __init__(self,
                 profile: str,
                 context_len: int = 2000,
                 ) -> None:
        """Set up prompt templates and draw the maximum number of rounds."""
        super().__init__(profile, context_len)

        self.examples = ""
        self.agent_prompt_search = behavior_prompt_search
        self.agent_prompt_click = behavior_prompt_click

        # One uniform draw; every bound strictly exceeded adds one round.
        # draw <= 0.5 -> 2, (0.5, 0.8] -> 3, (0.8, 0.9] -> 4, (0.9, 0.95] -> 5,
        # (0.95, 0.97] -> 6, (0.97, 0.99] -> 7, > 0.99 -> 8.
        draw = random.random()
        rounds = 2
        for bound in (0.5, 0.8, 0.9, 0.95, 0.97, 0.99):
            if draw > bound:
                rounds += 1
        self.max_exceeds_times = rounds

        self.max_steps = self.max_exceeds_times
        print("="*50, f"Max Round:{self.max_steps}", "="*50)
        self.name = "Search_Agent"

    def forward(self):
        """Return the next ``(action_type, argument)`` pair from ``_action``."""
        return self._action()

    def _build_agent_prompt(self, now_act) -> str:
        """Fill the search or click prompt template for ``now_act``.

        Returns ``None`` when ``now_act`` is neither "query" nor "click".
        """
        if now_act == "query":
            template = self.agent_prompt_search
            fields = {
                "profile": self.profile,
                "scratchpad": self.scratchpad,
                "max_exceeds_times": str(self.max_exceeds_times),
            }
        elif now_act == "click":
            template = self.agent_prompt_click
            fields = {
                "profile": self.profile,
                "scratchpad": self.scratchpad,
                "query": self.query,
                "titles": self.titles,
            }
        else:
            return None
        return template.format(**fields)

