from prompt import *
import json, re
from lxml import etree, html
from bs4 import BeautifulSoup
import bs4
#from AutoCrawl.utils.html_utils import simplify_html, find_common_ancestor

def find_common_ancestor(html_content: str, xpath: str) -> str:
    """Serialise the nearest common ancestor element of the nodes *xpath* selects.

    Used to narrow a page down to the region that contains a rule's matches.
    For element matches only strict ancestors count (``ancestor::*``), so a
    single matched element yields its parent. ``text()`` / ``@attr`` matches
    (lxml "smart" strings) are mapped to the element that contains them, and
    that element is itself a candidate ancestor.

    Args:
        html_content: HTML document to search.
        xpath: XPath expression that should select a node-set.

    Returns:
        str: the ancestor element serialised by ``etree.tostring``
        (pretty-printed, unicode).

    Raises:
        ValueError: if the document does not parse into a tree, *xpath*
            selects nothing (or evaluates to a scalar), or the matches share
            no ancestor element.
    """
    tree = etree.HTML(html_content)
    if tree is None:
        raise ValueError("HTML content could not be parsed into a document tree")
    nodes = tree.xpath(xpath)
    # Scalar expressions (count(), string(), ...) do not return a list.
    if not isinstance(nodes, list) or not nodes:
        raise ValueError(f"XPath {xpath!r} does not select any nodes")

    # Collect, for every match, the set of elements that contain it.
    ancestor_sets = []
    for node in nodes:
        if isinstance(node, str):
            # text()/@attr results are smart strings: getparent() returns the
            # element holding the text or attribute.
            get_parent = getattr(node, 'getparent', None)
            container = get_parent() if get_parent is not None else None
            if container is not None and getattr(node, 'is_tail', False):
                # Tail text follows getparent()'s element, so it lives inside
                # that element's parent.
                container = container.getparent()
            if container is None:
                raise ValueError(f"Cannot locate the element containing XPath result {node!r}")
            ancestor_sets.append(set(container.xpath('ancestor-or-self::*')))
        else:
            ancestor_sets.append(set(node.xpath('ancestor::*')))

    common_ancestors = set.intersection(*ancestor_sets)
    if not common_ancestors:
        raise ValueError(f"Nodes selected by XPath {xpath!r} share no ancestor element")

    # All common ancestors lie on one root-to-node chain; the deepest one is the nearest.
    nearest_common_ancestor = max(
        common_ancestors,
        key=lambda element: element.getroottree().getpath(element).count('/'),
    )
    return etree.tostring(nearest_common_ancestor, pretty_print=True, encoding='unicode')

def simplify_html(html, reserve_attrs=('class',)):
    """Strip an HTML document down to the parts useful for rule generation.

    Removes HTML comments and every ``<script>`` / ``<style>`` element, and
    drops every tag attribute whose name is not in *reserve_attrs*.

    Args:
        html (str): HTML text to simplify.
        reserve_attrs (iterable of str, optional): Attribute names to keep on
            each tag. Defaults to ``('class',)``.

    Returns:
        str: The simplified document, re-serialised by BeautifulSoup.
    """
    soup = BeautifulSoup(html, 'html.parser')
    for comment in soup.find_all(string=lambda text: isinstance(text, bs4.Comment)):
        comment.extract()
    for tag in soup(['script', 'style']):
        tag.extract()

    keep = set(reserve_attrs)
    for tag in soup.find_all():
        # Keep *every* reserved attribute present on the tag, in the tag's own
        # attribute order (the previous version kept only the last one found).
        tag.attrs = {name: value for name, value in tag.attrs.items() if name in keep}
    return str(soup)

class AdaptiveCrawler:
    """LLM-based agent that generates information-extraction rules for web pages.

    Depending on ``pattern``, the agent asks an LLM for an XPath expression, a
    CSS selector, or a Python snippet that extracts what an instruction
    describes. It can refine the rule through a reflection loop and can
    synthesise one rule from several seed pages of the same website.
    """

    def __init__(self,
                 pattern='xpath',
                 simplify=True,
                 verbose=True,
                 api=None,
                 error_max_times=5):
        """Initialise the crawler and pick the prompt set for ``pattern``.

        Args:
            pattern (str, optional): Kind of rule to generate. One of
                ``'xpath'``, ``'selector'`` or ``'code'``. Defaults to ``'xpath'``.
            simplify (bool, optional): Whether to run ``simplify_html`` on pages
                before prompting and extracting. Defaults to True.
            verbose (bool, optional): Whether to print the intermediate steps.
                Defaults to True.
            api (callable): Function that takes a prompt string and returns the
                LLM's response string. Required.
            error_max_times (int, optional): Maximum number of LLM calls that
                ``request_parse`` makes before giving up. Defaults to 5.

        Raises:
            ValueError: If no ``api`` is given.
            AssertionError: If ``pattern`` is not a supported rule pattern
                (kept as AssertionError for backward compatibility).
        """
        if api is None:
            raise ValueError("No api has been assigned!!")
        self.api = api

        if pattern not in ['xpath', 'selector', 'code']:
            raise AssertionError("Pattern must be one of the following selection: xpath, selector, code")
        self.rule_pattern = pattern
        self.is_simplify = simplify
        self.verbose = verbose
        if self.rule_pattern == 'xpath':
            self.prompter = Xpath_prompter()
        elif self.rule_pattern == 'selector':
            self.prompter = Selector_prompter()
        else:
            self.prompter = Code_prompter()
        self.error_max_times = error_max_times

    def request_parse(self,
                      query: str,
                      html: str,
                      keys: "list[str] | None" = None) -> dict[str, str]:
        """Call the LLM until its answer contains a JSON object with all ``keys``.

        Each ``{`` in the response is tried as the start of a JSON object. As a
        result, objects whose string values contain braces (e.g. generated
        code) are parsed, and an unparseable brace-delimited fragment does not
        hide a valid object later in the same response. The first object that
        contains every key in ``keys`` is returned.

        Args:
            query (str): Prompt sent to ``self.api``.
            html (str): HTML the prompt refers to. Not used here; kept for
                interface compatibility.
            keys (list[str], optional): Keys the returned object must contain.
                Defaults to no required keys.

        Returns:
            dict: The parsed object, or ``{key: "" for key in keys}`` when no
            suitable object was produced within ``self.error_max_times`` calls.
        """
        required = [] if keys is None else list(keys)
        decoder = json.JSONDecoder()
        for _ in range(self.error_max_times):
            response = self.api(query)
            for brace in re.finditer(r'\{', response):
                try:
                    candidate, _end = decoder.raw_decode(response[brace.start():])
                except json.JSONDecodeError:
                    continue
                if isinstance(candidate, dict) and all(key in candidate for key in required):
                    return candidate
        return {key: "" for key in required}

    def adaptive_generate(self, res, instruction, html_content, max_steps=5):
        """Iteratively judge and refine an XPath rule with the LLM (experimental).

        Only the ``'xpath'`` pattern is supported. At each step the current rule
        and its extraction result are appended to a history, and the LLM picks
        an action:

        * ``'Accept'``: stop, the rule is accepted.
        * ``'Re-generate'``: ask for a new rule on the same HTML.
        * ``'Re-thinking'``: narrow the HTML to the nearest common ancestor of
          the current rule's matches (``find_common_ancestor``) and ask for a
          new rule on that fragment.

        Any other action leaves the rule unchanged. This includes ``''``,
        which ``request_parse`` returns for an unparseable answer.

        Args:
            res (dict): Initial LLM answer with keys ``'value'``, ``'thought'``
                and ``'xpath'``.
            instruction (str): Task description.
            html_content (str): HTML the rule is applied to.
            max_steps (int, optional): Maximum number of judge/refine steps.
                Defaults to 5.

        Returns:
            tuple[bool, list[tuple[str, str]]]: Whether a rule was accepted, and
            the ``(xpath, action)`` steps taken (replayable with
            ``extract_with_seq``).

        Raises:
            ValueError: If the crawler's pattern is not ``'xpath'``.
        """
        if self.rule_pattern != 'xpath':
            raise ValueError("adaptive_generate only supports the 'xpath' pattern")
        xpath_history = []
        histories = []
        for _ in range(max_steps):
            histories.append({
                'expected value': res['value'],
                'thought': res['thought'],
                'xpath': res['xpath'],
                'result': str(self.extract_with_xpath(html_content, res['xpath'])),
            })
            history_json = json.dumps(histories, indent=4)
            if self.verbose:
                print(histories)

            query = self.prompter.simple_reflection_prompt.format(instruction, history_json)
            if self.verbose:
                print(query)
            action = self.request_parse(query, html_content, ['thought', 'action'])['action']
            if self.verbose:
                print(action)
            xpath_history.append((res['xpath'], action))

            if action == 'Accept':
                return True, xpath_history
            if action == 'Re-generate':
                query = f"{self.prompter.role_prompt}\n{self.prompter.crawl_w_history.format(instruction, history_json, html_content)}"
                res = self.request_parse(query, html_content, ['thought', 'value', self.rule_pattern])
            elif action == 'Re-thinking':
                new_html_content = find_common_ancestor(html_content, res['xpath'])
                if self.verbose:
                    print(len(html_content))
                    print(len(new_html_content))
                    print(new_html_content)
                query = f"{self.prompter.role_prompt}\n{self.prompter.crawl_w_history.format(instruction, history_json, new_html_content)}"
                res = self.request_parse(query, new_html_content, ['thought', 'value', self.rule_pattern])
                html_content = new_html_content
        return False, xpath_history

    def generate_rule(self,
                      instruction: str,
                      html: str,
                      repeat_times = 3,
                      with_reflection=True,
                      reflection_times=3,
                      with_adaptive=False) -> str:
        """Generate a rule by asking the LLM with an instruction and HTML code.

        Each of the ``repeat_times`` trials starts from the same generation
        prompt. If reflection is enabled, the trial's rule is then refined until
        the LLM judges its result consistent. A trial whose rule is never judged
        consistent contributes ``''``. With several trials, the LLM picks the
        best rule among them.

        Args:
            instruction (str): Task description.
            html (str): HTML code.
            repeat_times (int, optional): Number of independent generation
                trials. Defaults to 3.
            with_reflection (bool, optional): Whether to refine each rule with
                the reflection loop. Defaults to True.
            reflection_times (int, optional): Maximum reflection rounds per
                trial. Defaults to 3.
            with_adaptive (bool, optional): Also run the experimental
                ``adaptive_generate`` procedure on each trial's first answer and
                print its result when verbose (diagnostic only; it does not
                change the returned rule). XPath pattern only. Defaults to False.

        Returns:
            str: The generated rule (an XPath, a CSS selector or code), or
            ``''`` if no acceptable rule was produced.

        Raises:
            ValueError: If ``repeat_times`` is smaller than 1.
        """
        if repeat_times < 1:
            raise ValueError("repeat_times must be at least 1")
        if self.is_simplify:
            html = simplify_html(html)

        # Kept in its own variable so reflection prompts never leak into the
        # next trial: every trial must start from the generation prompt.
        generation_query = f"{self.prompter.role_prompt}\n{self.prompter.crawler_prompt.format(instruction, html)}"
        reflected = with_reflection and reflection_times > 0

        rule_list = []
        for index in range(repeat_times):
            if self.verbose:
                print('*'*100)
                print(f'Trial {index + 1} for generating {self.rule_pattern}.')
                print()

            res = self.request_parse(generation_query, html, ['thought', 'value', self.rule_pattern])

            if with_adaptive:
                adaptive_result = self.adaptive_generate(res, instruction, html)
                if self.verbose:
                    print(adaptive_result)

            if reflected:
                res = self._reflect(instruction, html, res, reflection_times)

            # Without reflection there is no consistency verdict: keep the rule as-is.
            accepted = not reflected or self._is_consistent(res)
            rule_list.append(res[self.rule_pattern] if accepted else '')

        # Choose the best rule among the different generation trials.
        if repeat_times > 1:
            ret_dict = {rule: self._extract(html, rule) for rule in rule_list}
            query = self.prompter.comparison_prompt.format(instruction, json.dumps(ret_dict, ensure_ascii=False, indent=4))
            if self.verbose:
                print('-'*50)
                print(f'Choose one of the best {self.rule_pattern} for a single HTML:')
            res = self.request_parse(query, html, ['thought', self.rule_pattern])
            rule = res[self.rule_pattern]
        else:
            rule = rule_list[0]
        if self.verbose:
            print(f'Generated {self.rule_pattern} for the webpage')
            print(self.rule_pattern, ':', rule)
        return rule

    def rule_synthesis(self,
                       website_name: str,
                       seed_html_set: list[str],
                       instruction: str,
                       per_page_repeat_time=1) -> str:
        """Generate one rule per seed page and let the LLM merge them into one.

        Args:
            website_name (str): Name of the website (used only in verbose output).
            seed_html_set (list[str]): HTML of the seed pages.
            instruction (str): Task description.
            per_page_repeat_time (int, optional): ``repeat_times`` passed to
                ``generate_rule`` for each page. Defaults to 1.

        Returns:
            str: The synthesised rule, or the single page's rule when only one
            seed page is given.

        Raises:
            ValueError: If ``seed_html_set`` is empty.
        """
        if not seed_html_set:
            raise ValueError("seed_html_set must contain at least one HTML page")

        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # synthesis prompt is reproducible (set() order varies between runs).
        rule_list = list(dict.fromkeys(
            self.generate_rule(instruction, page, repeat_times=per_page_repeat_time)
            for page in seed_html_set
        ))
        if len(seed_html_set) == 1:
            return rule_list[0]

        # Apply every candidate rule to every seed page.
        extract_result = {
            rule: [self._extract(page, rule) for page in seed_html_set]
            for rule in rule_list
        }
        if self.verbose:
            print('+' * 100)
            print(f"Systhesis rule for the website {website_name}")
            print(json.dumps(extract_result, ensure_ascii=False, indent=4))

        query = self.prompter.synthesis_prompt.format(instruction, json.dumps(extract_result, indent=4))
        res = self.request_parse(query, seed_html_set[0], ['thought', self.rule_pattern])
        if self.verbose:
            print('Systhesis rule:')
            print(res)
        return res[self.rule_pattern]

    def extract_with_xpath(self,
                           html_content: str,
                           xpath: str) -> list[str]:
        """Apply an XPath rule to a page.

        Args:
            html_content (str): HTML text.
            xpath (str): XPath expression; a blank string yields ``[]``.

        Returns:
            list[str]: For a node-set, each string result as-is and each
            element's ``.text`` (which may be ``None``); for a scalar
            expression (``string()``, ``count()``...), a one-item list.
        """
        if self.is_simplify:
            html_content = simplify_html(html_content)
        if not xpath.strip():
            return []
        return self._xpath_texts(etree.HTML(html_content), xpath)  # type: ignore

    def extract_with_seq(self,
                         html_content: str,
                         xpath_seq: list[tuple[str, str]]) -> list[str]:
        """Replay a ``(xpath, action)`` sequence from ``adaptive_generate``.

        ``'Re-thinking'`` steps narrow the HTML with ``find_common_ancestor``.
        The first ``'Accept'`` step's XPath is applied to the current HTML.
        Other actions are skipped.

        Returns:
            list[str]: The accepted XPath's results, or ``[]`` if the sequence
            contains no ``'Accept'`` step or the accepted XPath is blank.
        """
        if self.is_simplify:
            html_content = simplify_html(html_content)
        for xpath, action in xpath_seq:
            if action == 'Accept':
                if not xpath.strip():
                    return []
                return self._xpath_texts(etree.HTML(html_content), xpath)
            if action == 'Re-thinking':
                html_content = find_common_ancestor(html_content, xpath)
        return []

    def extract_with_selector(self,
                              html_content: str,
                              selector: str) -> list[str]:
        """Apply a CSS-selector rule to a page.

        Args:
            html_content (str): HTML text.
            selector (str): CSS selector; a blank string yields ``[]``.

        Returns:
            list[str]: ``.text`` of each matched element (may be ``None``).
        """
        if self.is_simplify:
            html_content = simplify_html(html_content)
        if not selector.strip():
            return []
        tree = html.fromstring(html_content)
        return [item if isinstance(item, str) else item.text for item in tree.cssselect(selector)]

    def extract_with_code(self,
                          html_content: str,
                          code: str) -> list[str]:
        """Run LLM-generated code that defines ``extract_value(html)`` on a page.

        SECURITY: ``code`` comes from the LLM and is executed with full
        interpreter privileges; only use this with a trusted model/sandbox.

        Args:
            html_content (str): HTML text passed to ``extract_value``.
            code (str): Python source that must define ``extract_value``;
                a blank string yields ``[]``.

        Returns:
            list[str]: Whatever ``extract_value(html_content)`` returns.

        Raises:
            NameError: If ``code`` does not define a callable ``extract_value``.
        """
        if self.is_simplify:
            html_content = simplify_html(html_content)
        if not code.strip():
            return []
        # Execute in a private copy of the module namespace, so the snippet still
        # sees this module's imports but neither pollutes the module nor picks up
        # an `extract_value` left over from an earlier snippet.
        namespace = dict(globals())
        namespace.pop('extract_value', None)
        exec(code, namespace)
        extractor = namespace.get('extract_value')
        if not callable(extractor):
            raise NameError("Generated code does not define a callable 'extract_value'")
        return extractor(html_content)

    def _extract(self, html_content, rule):
        """Apply ``rule`` with the extractor matching ``self.rule_pattern``."""
        if self.rule_pattern == 'xpath':
            return self.extract_with_xpath(html_content, rule)
        if self.rule_pattern == 'selector':
            return self.extract_with_selector(html_content, rule)
        return self.extract_with_code(html_content, rule)

    @staticmethod
    def _is_consistent(res):
        """Return True when the LLM answer's ``'consistent'`` field is 'yes' (any case)."""
        return str(res.get('consistent', '')).strip().lower() == 'yes'

    @staticmethod
    def _xpath_texts(tree, xpath):
        """Evaluate ``xpath`` on a parsed tree and normalise the result to a list."""
        if tree is None:  # etree.HTML may return None for empty input
            return []
        result = tree.xpath(xpath)
        if not isinstance(result, list):
            # string()/count()/boolean() return a single str/float/bool;
            # iterating a str would split it into characters.
            return [result if isinstance(result, str) else str(result)]
        return [item if isinstance(item, str) else item.text for item in result]

    def _reflect(self, instruction, html, res, reflection_times):
        """Run the reflection loop on one generated rule and return the last LLM answer."""
        histories = []
        for reflection_index in range(reflection_times):
            rule = res[self.rule_pattern]
            history = {
                'expected value': res['value'],
                'thought': res['thought'],
                self.rule_pattern: rule,
                'result': str(self._extract(html, rule)),
            }
            if self.verbose:
                print(f'Reflection {reflection_index}:')
                print(json.dumps(history, indent=4))
                print()

            histories.append(history)
            reflection_query = self.prompter.reflection_prompt.format(instruction, json.dumps(histories, indent=4), html)
            res = self.request_parse(reflection_query, html, ['thought', 'consistent', 'value', self.rule_pattern])
            if self.verbose:
                print(json.dumps(res, indent=4))
            if self._is_consistent(res):
                break
        return res
        

if __name__ == '__main__':
    import os

    import requests

    # LLM proxy endpoint and sample page for this manual smoke run; override
    # them through the environment instead of editing the file.
    ip = os.environ.get('AUTOCRAWL_API_URL', 'http://10.176.64.117:8080')
    sample_page = os.environ.get(
        'AUTOCRAWL_SAMPLE_PAGE',
        '/mnt/data122/harryhuang/swde/sourceCode/movie/movie-boxofficemojo(2000)/0435.htm',
    )

    def ms_chatgpt(query):
        """POST *query* to the LLM proxy and return the 'response' field of its JSON reply."""
        # Generous timeout: LLM calls are slow, but a dead server must not hang forever.
        reply = requests.post(f'{ip}/query', json={
            'query': query
        }, timeout=600)
        # Surface HTTP errors here instead of failing later on an unexpected body.
        reply.raise_for_status()
        return json.loads(reply.content.decode('utf-8'))['response']

    xe = AdaptiveCrawler(api=ms_chatgpt)
    # errors='replace' keeps the run going on pages that are not valid UTF-8.
    with open(sample_page, encoding='utf-8', errors='replace') as f:
        html_content = f.read()

    instruction = "Here's a webpage with detail information of a movie. Please extract the title of the movie. It's worth noticing that the candidate attribute values are the non-empty strings contained in text nodes in the corresponding DOM tree, and one page may contain multiple distinct values that correspond to an attribute."
    rule = xe.generate_rule(instruction, html_content, repeat_times=1)
    print(rule)