import argparse
import os
import re
from bs4 import BeautifulSoup
# Disable TensorFlow's C++ logging. This is set before the remaining imports,
# presumably so it is in effect if any of them loads TensorFlow.
os.environ.update(TF_CPP_MIN_LOG_LEVEL="3")
import json
import time

import openai
import torch
from tqdm import trange, tqdm
from pathlib import Path

from excel_api.run import LLMClient
from googlesearch import search
import tiktoken
from run_inference import gpt4_call,CODE_GENERATION_EXAMPLE
from datasets import load_dataset


# Short Codex aliases (as accepted by --model) -> OpenAI engine identifiers.
# Not referenced by any other code in this file.
codex_name_mapping = dict(
    [
        ("codex-cushman", "code-cushman-001"),
        ("codex002", "code-davinci-002"),
        ("codex001", "code-davinci-001"),
    ]
)

# with open("ds1000_stackoverflow_id.json", "r") as f:
#     DS_1000_stackoverflow_id=json.load(f)


def normalize(raw_text: str):
    """Unescape the HTML entities ``&lt;``, ``&gt;`` and ``&amp;``.

    ``&amp;`` is replaced last, so text such as ``&amp;lt;`` becomes the
    literal ``&lt;`` rather than ``<``. No other entities are touched.
    """
    for entity, char in (("&lt;", "<"), ("&gt;", ">"), ("&amp;", "&")):
        raw_text = raw_text.replace(entity, char)
    return raw_text


# Presumably an API-key rotation index; not read or written by any other code
# in this file.
CURRENT_KEY_ID = 0

# Module-level slots for a locally loaded model/tokenizer. `model_inference`
# declares them global, but nothing in this file assigns them.
model = tokenizer = None

# Shared LLM client, constructed once at import time (not used by the code in
# this file).
llm_client = LLMClient()


def _parse_gpt_choice(choice, is_azure):
    """Convert one response choice into the cached result dict (without ``prompt``).

    ``choice`` is ``None`` when the response carried no ``"choices"`` key; empty
    values are returned in that case. Azure choices are read from
    ``choice["message"]["content"]``; others from ``choice["text"]`` together
    with ``choice["logprobs"]["token_logprobs"]`` and ``["tokens"]``.
    """
    if is_azure:
        return {"trg_prediction": "" if choice is None else choice["message"]["content"]}
    if choice is None:
        return {"trg_prediction": "", "logprobs": [], "tokens": []}
    return {
        "trg_prediction": choice["text"],
        "logprobs": choice["logprobs"]["token_logprobs"],
        "tokens": choice["logprobs"]["tokens"],
    }


def gpt4_inference(output_dir: Path, prefix: str, system, suffix: str, args):
    """
    Request completions through ``gpt4_call`` and cache every sample on disk.

    Sample ``i`` (``0 <= i < args.num_samples``) is written to ``output_dir`` as:

    * ``{i}.py``   -- the raw predicted text;
    * ``{i}.json`` -- ``trg_prediction`` and ``prompt``, plus ``logprobs`` and
      ``tokens`` when ``args.is_azure`` is false.

    Samples are requested ``args.batch_size`` at a time. A batch is skipped
    when ``args.overwrite_output_dir`` is false and both files already exist
    for every sample in it. If the response has no ``"choices"`` key, empty
    predictions (and empty logprob/token lists) are cached instead.

    Args:
        output_dir: directory for the cached files; it is not created here.
        prefix: user prompt sent to the model and stored under ``"prompt"``.
        system: system prompt forwarded to ``gpt4_call``.
        suffix: unused; kept so every backend shares one call signature.
        args: parsed CLI namespace. This function reads ``num_samples``,
            ``batch_size``, ``overwrite_output_dir`` and ``is_azure``, and
            forwards the whole namespace to ``gpt4_call``.
    """
    for batch_start in range(0, args.num_samples, args.batch_size):
        n_in_batch = min(args.num_samples - batch_start, args.batch_size)
        sample_ids = range(batch_start, batch_start + n_in_batch)

        # Skip the API call when every sample of this batch is already cached.
        if not args.overwrite_output_dir and all(
            (output_dir / f"{i}.py").is_file() and (output_dir / f"{i}.json").is_file()
            for i in sample_ids
        ):
            continue

        response = gpt4_call(
            prefix,
            system,
            CODE_GENERATION_EXAMPLE,
            None,
            batch_size=n_in_batch,
            args=args,
        )
        has_choices = "choices" in response

        for batch_i, sample_id in enumerate(sample_ids):
            choice = response["choices"][batch_i] if has_choices else None
            # Parse once so the .py and .json files always hold the same text.
            result = _parse_gpt_choice(choice, args.is_azure)
            result["prompt"] = prefix

            with open(output_dir / f"{sample_id}.py", "w", encoding="UTF-8") as f:
                f.write(result["trg_prediction"])
            with open(output_dir / f"{sample_id}.json", "w", encoding="UTF-8") as f:
                json.dump(result, f)

import html2text as ht
# GPT-4 tokenizer, used to measure the length of retrieved prompts.
ENCODING = tiktoken.encoding_for_model("gpt-4")

# HTML -> Markdown converter for cached web pages: link targets are dropped,
# and tables are converted rather than passed through as raw HTML.
text_maker = ht.HTML2Text()
text_maker.ignore_links = True
text_maker.bypass_tables = False
def _stackoverflow_prompt(content, stackoverflow_id_filter):
    """Render a cached StackOverflow question and its first three answers as text.

    Returns ``None`` when the cache holds no question, or when the question id
    equals ``stackoverflow_id_filter`` (both compared as strings).
    """
    items = content["question"]["items"]
    if not items:
        return None
    question = items[0]
    if str(question["question_id"]) == str(stackoverflow_id_filter):
        print("Filter stackoverflow_id_filter -----------------------------------------------------------")
        return None
    # Convert the HTML `body`; `body_markdown` contains strange strings.
    question_md = text_maker.handle(str(question["body"]))
    answers_md = "\n\n".join(
        f"StackOverflow Answer with {answer['score']} score (is_accepted: {answer['is_accepted']}):"
        + "\n"
        + text_maker.handle(str(answer["body"]))
        for answer in content["answers"]["items"][:3]
    )
    return "\n".join([
        f"StackOverflow title: {question['title']}",
        f"StackOverflow question score: {question['score']}",
        f"StackOverflow answer count: {question['answer_count']}",
        "StackOverflow question:",
        question_md,
        "StackOverflow top answers:",
        answers_md,
    ])


def _find_main_content(html):
    """Return the main-content element of a cached web page, or ``None``.

    Tries, in order: the first ``<article>``, the first ``<section>``, the
    first ``<div class="body">`` and the first ``<div id="main">``.
    """
    soup = BeautifulSoup(html, "html.parser")
    # TODO: pages with more than one <article> only contribute the first one.
    candidates = (
        lambda: soup.select_one("article"),
        # e.g. https://scikit-learn.org/stable/modules/generated/sklearn.utils.Bunch.html
        lambda: soup.select_one("section"),
        # e.g. https://matplotlib.org/3.4.3/gallery/ticks_and_spines/major_minor_demo.html
        lambda: soup.find("div", class_="body"),
        # e.g. https://www.w3schools.com/python/ref_string_isalpha.asp
        lambda: soup.find("div", id="main"),
    )
    for find in candidates:
        element = find()
        if element is not None:
            return element
    return None


def get_retrieve_prompt(query, search_results, content_path, stackoverflow_id_filter=None, max_prompt_tokens=10000):
    """
    Build the background-information prompt for one cached search hit.

    Args:
        query: the search query; echoed in the prompt header.
        search_results: one search hit; only its ``"url"`` is read.
        content_path: JSON file holding the cached page. A page with
            ``"type": "stackoverflow"`` carries ``question``/``answers``
            payloads; any other page carries its raw ``"html"``.
        stackoverflow_id_filter: StackOverflow question id to exclude
            (compared as strings), presumably the question a benchmark item
            was derived from.
        max_prompt_tokens: prompts longer than this many ``ENCODING`` tokens
            are discarded.

    Returns:
        The prompt string, or ``None`` when the file is missing, the
        StackOverflow question is absent or filtered out, no main-content
        element is found, or the prompt is too long.
    """
    if not os.path.exists(content_path):
        return None
    # JSON text is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(content_path, "r", encoding="utf-8") as f:
        content = json.load(f)

    if content.get("type") == "stackoverflow":
        content_prompt = _stackoverflow_prompt(content, stackoverflow_id_filter)
    else:
        article = _find_main_content(content["html"])
        content_prompt = None if article is None else text_maker.handle(str(article))
    if content_prompt is None:
        return None

    prompt = "\n".join([
        f"For the query [{query}], there are the background information from the website [{search_results['url']}]:",
        content_prompt,
    ])
    # Scraped text may contain strings like "<|endoftext|>"; by default tiktoken
    # rejects those with ValueError, so count them as ordinary text instead.
    if len(ENCODING.encode(prompt, disallowed_special=())) > max_prompt_tokens:
        # TODO: extract the useful section of long pages (e.g.
        # https://pandas.pydata.org/docs/user_guide/indexing.html) instead of dropping them.
        return None
    return prompt

def model_inference(output_dir: Path, prefix: str, system=None, suffix: str = None, args=None):
    """
    Dispatch one prompt to the backend selected by ``args.model``.

    ``output_dir`` is created first (an existing directory is fine). Model
    names containing ``"gpt"`` are routed to ``gpt4_inference``; every other
    model is currently a no-op placeholder where a new backend can be added.
    """
    os.makedirs(output_dir, exist_ok=True)
    if "gpt" not in args.model:
        # Placeholder: put other models' inference interfaces here.
        return
    gpt4_inference(output_dir, prefix, system, suffix, args)


# System prompt for ClassEval generation with retrieved background information.
# Kept verbatim (typos included): any change alters what the model is shown.
_CLASSEVAL_SYSTEM_PROMPT = """
Provided below is a Python class skeleton PROBLEM with several functions inside. As a Python Expert, please complete the class in the subsequenct code. 
Note that previously there already had a round of conversation about this PROBLEM, and you made a PLAN of it and came up with a QUERY that needs to be searched. I've searched for the background information you might need. You can selectively refer to it when writing your code.

There are some rules that you must follow for writing the codes:
+ Background knowledge is for reference only and not all of the information you need to use in your code., please focus on code completion.
+ Your response must follow the format of ```python\n <code> \n```.
+ Do not change the skeleton of the class.
+ Please complete all required functions in the class. Do not leave a function blank. 
"""


def _collect_retrieved_context(queries, search_results, retrieve_path):
    """Return the retrieved background snippets for one problem, in query order.

    For each query, the first search hit whose cached page yields a prompt is
    selected. It is kept only if its URL was not already used by an earlier
    query, so each URL contributes at most once. Later hits for that query are
    not tried, even when its first usable URL turns out to be a duplicate.
    """
    snippets = []
    seen_urls = set()
    for q_id, query in enumerate(queries):
        for w_id, hit in enumerate(search_results[query]):
            snippet = get_retrieve_prompt(query, hit, retrieve_path / f"query{q_id}_website{w_id}.json")
            if snippet is None:
                continue
            if hit["url"] not in seen_urls:
                seen_urls.add(hit["url"])
                snippets.append(snippet)
            break
    return snippets


def inference(classeval, args):
    """
    Run retrieval-augmented generation for every ClassEval problem.

    For problem ``i`` this reads, from ``Path(args.retrieve_dir) / args.model / f"q{i}"``:
    ``query.json`` (the search queries), ``plan.json`` (inserted into the
    prompt as raw text), ``google_search_results.json`` (query -> list of hits,
    each with a ``"url"``) and the cached pages read by ``get_retrieve_prompt``.
    The assembled prompt is passed to ``model_inference``, which writes to
    ``args.output_dir / args.model / f"q{i}"``.

    Input:
    `classeval` is the ClassEval dataset as loaded by
    ``load_dataset("FudanSELab/ClassEval")``; only ``classeval['test']['skeleton']``
    is read.
    """
    # Fetch the column once: indexing a `datasets` split by column name can
    # rebuild the whole column on every access, which is quadratic inside the loop.
    skeletons = classeval["test"]["skeleton"]
    output_root = args.output_dir / args.model
    retrieve_root = Path(args.retrieve_dir) / args.model

    for problem_id in tqdm(range(len(skeletons))):
        skeleton = skeletons[problem_id]
        output_path = output_root / ("q" + str(problem_id))
        retrieve_path = retrieve_root / ("q" + str(problem_id))

        with open(retrieve_path / "query.json", "r", encoding="utf-8") as f:
            queries = json.load(f)
        # plan.json is inserted into the prompt as raw text, not parsed.
        with open(retrieve_path / "plan.json", "r", encoding="utf-8") as f:
            plan = f.read()
        with open(retrieve_path / "google_search_results.json", "r", encoding="utf-8") as f:
            search_results = json.load(f)

        retrieve_prompt = _collect_retrieved_context(queries, search_results, retrieve_path)

        prefix = "\n\n".join([
            "PROBLEM:\n",
            skeleton,
            "-----------END PROBLEM---------",
            "Here's the plan you made earlier and the query to search for:  ",
            plan,
            "-----------END PLAN---------",
            "I've searched for the background information you might need. You can selectively refer to it when writing your code, noting that not all of the information you need to use in your code. The following information is the markdown text of the main information on the corresponding website.",
            "\n\n".join(retrieve_prompt),
            "---------END SEARCHED INFO----------",
            "Again, the PROBLEM is as follows:",
            skeleton,
            "-----------END PROBLEM---------",
            "Now Answer the PROBLEM as the same format as your former answers.",
        ])

        model_inference(output_path, prefix, _CLASSEVAL_SYSTEM_PROMPT, "", args)


def _build_arg_parser():
    """
    Create the command-line parser for this script.

    argparse converts ``--output_dir`` and ``--source_dir`` to ``Path``,
    including their string defaults. ``--source_dir`` and ``--libs`` are
    parsed but not read anywhere else in this file. The sampling options
    (``--temperature``, ``--top_p``, ``--max_tokens``) are presumably read by
    ``gpt4_call`` through ``args``; confirm there.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model", type=str, default="gpt4",
        choices=["codex-cushman", "codex001", "codex002", "incoder-1B", "gpt4", "gpt-35-turbo-16k-0613", "gpt-4-turbo", "gpt-4-32k-0613"],
        help="Type of Codex Model to run",
    )
    parser.add_argument("--is_azure", action="store_true", default=False, help="Is model in azure?")
    parser.add_argument(
        "--output_dir", type=Path, default="gpt4_retrieve_outputs_1208_2",
        help="Path to which the Codex responses will be cached at",
    )
    parser.add_argument(
        "--retrieve_dir", type=str, default="gpt4_retrieve_outputs_1205_old",
        help="Path to retrieve results",
    )
    parser.add_argument(
        "--source_dir", type=Path, default="text2analysis_1220.csv",
        help="Path to the downloaded DS-1000 data",
    )
    parser.add_argument(
        "--num_samples", type=int, default=1,
        help="Number of Codex samples to draw for one question.",
    )
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", default=False,
        help="By default will skip cached samples; Turn on this flag to overwrite the existing caches.",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.0,
        help="Temperature of the Codex sampling distribtuion.",
    )
    parser.add_argument(
        "--top_p", type=float, default=0.95,
        help="Top-p cutoff of the Codex sampling distribtuion",
    )
    parser.add_argument(
        "--max_tokens", type=int, default=1024,
        help="Number of maximum tokens for Codex to generate",
    )
    parser.add_argument(
        "--batch_size", type=int, default=1,
        help="Number of requests to issue at one time",
    )
    parser.add_argument(
        "--libs", type=str, nargs="+", default="all",
        help="Specify the subset of DS-1000 to run on. e.g., specifying `--libs Numpy Pandas` will only inference on the Numpy and Pandas split of DS-1000",
    )
    return parser


if __name__ == "__main__":
    args = _build_arg_parser().parse_args()
    classeval = load_dataset("FudanSELab/ClassEval")
    print("loaded dataset")
    inference(classeval, args)
