import glob
import json, os, sys
from tqdm import tqdm
from utils.html_utils import *
from module.stepback_crawler import StepbackCrawler
from module.prompt import *
from run_swde.task_prompt import swde_prompt
from lxml import html

# Attributes evaluated for each SWDE vertical; each website's rule file maps
# these attribute names to the xpath sequence learned for them.
SCHEMA = {
    'auto': ['model', 'price', 'engine', 'fuel_economy'],
    'book': ['title', 'author', 'isbn_13', 'publisher', 'publication_date'],
    'camera': ['model', 'price', 'manufacturer'],
    'job': ['title', 'company', 'location', 'date_posted'],
    'movie': ['title', 'director', 'genre', 'mpaa_rating'],
    'nbaplayer': ['name', 'team', 'height', 'weight'],
    'restaurant': ['name', 'address', 'phone', 'cuisine'],
    'university': ['name', 'phone', 'website', 'type']
}

# Root of the raw SWDE pages, laid out as <DATA_HOME>/<vertical>/<website>/<page>.
# The default is machine-specific, so allow overriding it from the environment.
DATA_HOME = os.environ.get('SWDE_DATA_HOME', '/mnt/data122/harryhuang/swde/sourceCode')
# Directory holding result_overall.json and the per-website rule files
# <OUTPUT_HOME>/<vertical>/<website>_<pattern>.json.
OUTPUT_HOME = os.environ.get('SWDE_OUTPUT_HOME', 'dataset/swde/codellama/seq')
# Suffix used in the rule-file names above.
pattern = 'seq'

def compress_with_sequence(html_content: str,
                           sequence: list):
    """Narrow *html_content* by applying every xpath in *sequence* except the last.

    Each xpath before the final one is passed to ``find_common_ancestor``
    in order, and each result feeds the next call. The final entry is
    never applied; presumably it is the target xpath itself (confirm
    against StepbackCrawler's output format).

    Args:
        html_content: HTML text to compress.
        sequence: Ordered list of xpath strings. ``None`` and empty
            sequences are treated alike.

    Returns:
        The narrowed HTML text, or ``None`` when *sequence* is empty or
        ``None``. A single-element sequence returns *html_content*
        unchanged.
    """
    if not sequence:
        return None
    for xpath in list(sequence)[:-1]:
        try:
            html_content = find_common_ancestor(html_content, xpath)
        except Exception:
            # Best-effort: if one step fails, keep the content from the
            # previous step and try the next xpath. Catching Exception
            # (not a bare except) lets Ctrl-C / SystemExit propagate.
            continue
    return html_content

def get_max_depth(element, current_depth=0):
    """
    Return the maximum depth reached in the subtree rooted at *element*.

    *element* itself sits at ``current_depth``, and each level of children
    adds one. Children are whatever iterating an element yields. A leaf
    therefore returns ``current_depth`` unchanged.
    """
    deepest = current_depth
    # Depth-first walk with an explicit stack of (node, depth) pairs.
    pending = [(element, current_depth)]
    while pending:
        node, depth = pending.pop()
        if depth > deepest:
            deepest = depth
        pending.extend((child, depth + 1) for child in node)
    return deepest

# Per-(page, attribute) ratios of compressed HTML to original HTML, for
# character length and for tree depth respectively.
length_ratio = []
height_ratio = []

with open(os.path.join(OUTPUT_HOME, 'result_overall.json')) as f:
    # Presumably indexed as [vertical][website][attribute]['F1'] with the
    # extraction F1 of the learned rule; verify against the producer.
    result_overall = json.load(f)

for field in SCHEMA:
    for website_path in glob.glob(os.path.join(DATA_HOME, field, '*')):
        website_name = os.path.basename(website_path).split('(')[0]
        print(website_name)
        rule_path = os.path.join(OUTPUT_HOME, field, website_name) + f'_{pattern}.json'
        if not os.path.exists(rule_path):
            continue
        # sorted() returns a new list; the result must be kept, otherwise
        # [:1] below picks whichever page glob happened to list first.
        webpage_list = sorted(glob.glob(os.path.join(website_path, '*')))
        with open(rule_path, 'r') as f:
            xpath_rule = json.load(f)

        # Only the first page of each website is evaluated.
        for webpage in tqdm(webpage_list[:1]):
            with open(webpage, 'r') as f:
                html_content = simplify_html(f.read())
            if not html_content:
                continue  # nothing to parse, and len() would be a zero divisor
            origin_height = get_max_depth(html.fromstring(html_content))
            if origin_height == 0:
                continue  # would make the height ratio divide by zero

            for item in SCHEMA[field]:
                # Only measure attributes whose rule extracted perfectly.
                if result_overall[field][website_name][item]['F1'] != 1.0:
                    continue
                new_html_content = compress_with_sequence(html_content, xpath_rule[item])
                if new_html_content:
                    new_height = get_max_depth(html.fromstring(new_html_content))
                    length_ratio.append(len(new_html_content) / len(html_content))
                    height_ratio.append(new_height / origin_height)

print(len(length_ratio))
if length_ratio:
    print('Length compress ratio:', sum(length_ratio) / len(length_ratio))
    print('Height compress ratio:', sum(height_ratio) / len(height_ratio))
else:
    print('No compression ratios collected; nothing to average.')