import os
import time
import argparse

import requests
from pathlib import Path
import json
from collections import defaultdict
from bs4 import BeautifulSoup
from tqdm import tqdm
from run_retrieve2 import get_retrieve_prompt
from datasets import load_dataset


# Browser-like user-agent strings; each one becomes a single-entry header
# dict in HEADERS so callers can pick a header set by index.
_USER_AGENTS = (
    'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36',
    'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36',
    'Mozilla/5.0 (Macintosh; U; Intel Mac OS X 10_6_8; en-us) AppleWebKit/534.50 (KHTML, like Gecko) Version/5.1 Safari/534.50',
    'Mozilla/5.0 (Macintosh; Intel Mac OS X 10.6; rv2.0.1) Gecko/20100101 Firefox/4.0.1',
    'Opera/9.80 (Macintosh; Intel Mac OS X 10.6.8; U; en) Presto/2.8.131 Version/11.11',
    'Opera/9.80 (Windows NT 6.1; U; en) Presto/2.8.131 Version/11.11',
    'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_0) AppleWebKit/535.11 (KHTML, like Gecko) Chrome/17.0.963.56 Safari/535.11',
)

# List of request-header dicts, one per user agent, in the order above.
HEADERS = [{'user-agent': agent} for agent in _USER_AGENTS]

def get_content(path, url, timeout=30):
    """Return the main text of a web page, caching the raw HTML on disk.

    When ``path`` already exists, the HTML is read from it and no request is
    made. Otherwise the page at ``url`` is downloaded; on an HTTP 200
    response the HTML is written to ``path`` as ``{"html": <text>}``.
    URLs containing ``www.cse.wustl.edu`` are never requested.

    Args:
        path: Location of the JSON cache file for this page.
        url: Address of the page to download on a cache miss.
        timeout: Seconds to wait for the server before giving up, so a
            stalled host cannot hang the whole crawl.

    Returns:
        The text of the first content container found (``<article>``,
        ``<section>``, ``div.body`` or ``div#main``, in that order), or
        ``None`` when the download failed or no such container exists.
    """
    content = ''
    if os.path.exists(path):
        with open(path, 'r') as f:
            content = json.load(f)["html"]
    elif "www.cse.wustl.edu" not in url:
        # Each attempt uses the header set at the same index in HEADERS.
        max_attempts = 1
        for attempt in range(max_attempts):
            try:
                response = requests.get(url, headers=HEADERS[attempt], timeout=timeout)
            except requests.RequestException as err:
                # A network failure on one page must not abort the crawl.
                print(f"Request failed with exception: {err}", attempt, url)
                continue
            if response.status_code == 200:
                content = response.text
                with open(path, 'w') as f:
                    json.dump({"html": content}, f)
                break
            print('请求失败，状态码为：', response.status_code, attempt, url)

    soup = BeautifulSoup(content, 'html.parser')
    # Locate the main body of the page by trying progressively less specific
    # containers; the first one that exists wins.
    # TODO: what about more than 1 article?
    locators = (
        lambda: soup.select_one('article'),
        # https://scikit-learn.org/stable/modules/generated/sklearn.utils.Bunch.html
        lambda: soup.select_one('section'),
        # https://matplotlib.org/3.4.3/gallery/ticks_and_spines/major_minor_demo.html
        lambda: soup.find('div', class_='body'),
        # https://www.w3schools.com/python/ref_string_isalpha.asp
        lambda: soup.find('div', id='main'),
    )
    for locate in locators:
        article = locate()
        if article is not None:
            return article.get_text()
    return None

def _fetch_stackexchange_json(api_url, timeout=30):
    """GET a Stack Exchange API URL and return its decoded JSON, or None on any failure."""
    try:
        response = requests.get(api_url, headers=HEADERS[1], timeout=timeout)
        if response.status_code == 200:
            return response.json()
    except (requests.RequestException, ValueError) as err:
        # ValueError covers a 200 response whose body is not valid JSON.
        print(f"Error: request to {api_url} failed: {err}")
    finally:
        # Pause after every call, successful or not.
        time.sleep(2) # https://api.stackexchange.com/docs/throttle
    return None


def get_stackoverflow_content(path,url):
    """Return the question and answers of a Stack Overflow page, cached on disk.

    When ``path`` already exists its JSON content is returned as-is.
    Otherwise the question id is taken from ``url`` and the answers and the
    question are requested from the Stack Exchange API. The result is
    written to ``path`` only when both requests succeed.

    Args:
        path: Location of the JSON cache file for this question.
        url: Stack Overflow question URL of the form
            ``https://stackoverflow.com/questions/<id>/...``.

    Returns:
        A dict ``{"type": "stackoverflow", "question": ..., "answers": ...}``
        holding the raw API responses, or ``None`` when the URL is not a
        question URL or either API request failed.
    """
    if os.path.exists(path):
        with open(path,'r') as f:
            return json.load(f)

    prefix = "https://stackoverflow.com/questions/"
    question_id = url[len(prefix):].split("/")[0] if url.startswith(prefix) else ""
    # Reject non-numeric ids (e.g. ".../questions/tagged/python") up front
    # instead of sending a request that can only fail.
    if not (question_id.isascii() and question_id.isdigit()):
        print(f"Wrong stackoverflow url: {url}")
        return None

    content={"type":"stackoverflow", "question":None, "answers":None}
    # See more in https://api.stackexchange.com/docs
    answer_url = f"https://api.stackexchange.com/2.3/questions/{question_id}/answers?order=desc&sort=votes&site=stackoverflow&filter=!nNPvSNe7D9"
    content["answers"] = _fetch_stackexchange_json(answer_url)
    if content["answers"] is None:
        print(f"Error: Unable to retrieve answers: {answer_url}")
        # The result would be discarded anyway, so skip the second request.
        return None

    question_url=f"https://api.stackexchange.com/2.3/questions/{question_id}?order=desc&sort=activity&site=stackoverflow&filter=!nNPvSNP3wf"
    content["question"] = _fetch_stackexchange_json(question_url)
    if content["question"] is None:
        print(f"Error: Unable to retrieve question: {question_url}")
        return None

    with open(path, 'w') as f:
        json.dump(content, f)
    return content

def print_list(url):
    """Print a per-site breakdown of the given URLs.

    URLs are grouped by their first three ``/``-separated parts (scheme and
    host). Groups are listed from largest to smallest; for each one the
    cumulative share of all URLs, the group size and the total are printed,
    followed by the first URL of the group as an example.

    Args:
        url: Sequence of URL strings to summarize.
    """
    total = len(url)

    groups = {}
    for link in url:
        site = "/".join(link.split("/")[:3])
        groups.setdefault(site, []).append(link)

    # Largest groups first; ties keep their first-seen order.
    ranked = sorted(groups.items(), key=lambda item: len(item[1]), reverse=True)

    seen_so_far = 0
    for site, members in ranked:
        size = len(members)
        print(site, f"{(size+seen_so_far)/total} ({size} / {total})")
        print(members[0])
        seen_so_far += size
    
def main():
    """Fetch and cache the top search-result pages for every ClassEval problem.

    For each problem that has a ``google_search_results.json`` file under
    ``<output_dir>/<model>/q<problem_id>/``, the first three websites of each
    query are downloaded in turn until ``get_retrieve_prompt`` returns a
    non-``None`` result for one of them. A per-site summary of the URLs with
    and without a result is printed at the end.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        default="gpt4",
        choices=["codex-cushman", "codex001", "codex002", "incoder-1B", "gpt4","gpt-35-turbo-16k-0613","gpt-4-turbo","gpt-4-32k-0613"],
        help="Type of Codex Model to run",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="gpt4_retrieve_outputs_1205",
        help="Path to which the Codex responses will be cached at",
    )
    args = parser.parse_args()
    model_dir = Path(args.output_dir) / args.model

    classeval = load_dataset("FudanSELab/ClassEval")

    url = []               # URLs for which get_retrieve_prompt returned None
    url_with_article = []  # URLs for which get_retrieve_prompt returned a result
    query_with_content = 0

    for problem_id in tqdm(range(len(classeval['test']['skeleton']))):
        problem_dir = model_dir / ("q" + str(problem_id))
        results_path = problem_dir / "google_search_results.json"
        if not results_path.exists():
            continue
        with open(results_path) as f:
            retrieve_res = json.load(f)

        for q_id, query in enumerate(retrieve_res):
            # Only the first three websites of each query are considered.
            for w_id, website in enumerate(retrieve_res[query][:3]):
                content_path = problem_dir / f"query{q_id}_website{w_id}.json"
                # The fetchers are called for their on-disk caching effect;
                # their return values are not used here.
                # NOTE(review): get_retrieve_prompt presumably reads the file
                # at content_path - confirm.
                if website["url"].startswith("https://stackoverflow.com/"):
                    get_stackoverflow_content(content_path, website["url"])
                else:
                    get_content(content_path, website["url"])

                output = get_retrieve_prompt(query, website, content_path)
                if output is not None:
                    url_with_article.append(website["url"])
                    query_with_content += 1
                    # One usable website per query is enough.
                    break
                url.append(website["url"])

    print_list(url)
    print("___________________________")
    print_list(url_with_article)

    print(f"query_with_content: {query_with_content}")


if __name__ == "__main__":
    main()