import argparse
import os
import re

# disable tensorflow logging
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import json
import time

import openai
import torch
from tqdm import trange
from pathlib import Path

from excel_api.run import LLMClient
from googlesearch import search

from run_inference import gpt4_call
from datasets import load_dataset

# Short model aliases (the same spellings as the codex entries among the
# `--model` CLI choices) mapped to longer model identifiers; presumably the
# OpenAI engine names -- TODO confirm.
# NOTE(review): nothing in the visible part of this file reads this mapping.
codex_name_mapping = {
    "codex-cushman": "code-cushman-001",
    "codex002": "code-davinci-002",
    "codex001": "code-davinci-001",
}


def normalize(raw_text: str):
    """
    Decode the HTML entities ``&lt;``, ``&gt;`` and ``&amp;``.

    The substitutions run in that fixed order. ``&amp;`` is handled last, so
    an escaped entity such as ``&amp;lt;`` comes out as ``&lt;`` rather than
    being decoded twice.

    :param raw_text: text that may contain the escaped symbols
    :return: the text with those entities replaced by ``<``, ``>`` and ``&``
    """
    entity_table = (
        ("&lt;", "<"),
        ("&gt;", ">"),
        ("&amp;", "&"),
    )
    decoded = raw_text
    for entity, symbol in entity_table:
        if entity in decoded:
            decoded = decoded.replace(entity, symbol)
    return decoded


# Module-level counter initialised to 0; presumably the index of the API key
# currently in use -- TODO confirm, nothing in the visible code reads it.
CURRENT_KEY_ID = 0

# Global model / tokenizer slots for local-model inference. Both start out as
# None and nothing in the visible code assigns them; the `incoder_inference`
# helper the original comment pointed to is not defined in this file.
model = None
tokenizer = None


# Client from excel_api.run, instantiated once at import time.
# NOTE(review): no visible code in this file uses it -- confirm it is needed.
llm_client = LLMClient()


def get_query(path: Path, problem: str, args=None):
    """
    Ask the model to plan a solution for ``problem`` and return the search queries.

    The prompt (system message plus one worked user/assistant example) tells
    the model to wrap every query worth googling in ``<search>...</search>``
    tags. Those tags are extracted from the reply, stripped, de-duplicated
    (first occurrence wins, so the order is stable across runs) and written
    to ``query.json``; the raw reply is written to ``plan.json``.

    :param path: existing directory that receives ``query.json`` and ``plan.json``
    :param problem: problem text sent as the user message
    :param args: CLI namespace forwarded unchanged to ``gpt4_call``. When its
        ``is_azure`` attribute is truthy the reply text is read from
        ``choices[0]["message"]["content"]``, otherwise (including when
        ``args`` is None or lacks the attribute) from ``choices[0]["text"]``.
    :return: list of unique, non-empty query strings in order of appearance
    """
    system = """
Help me with the following problem, You need to write python code to solve the following problem. Please plan the steps you would need to take and write each step as a query. I can help you to search for relevant information online, if the query needs to be searchable, mark <search>. I can help you with google search through which you can search for real time information, python library document, error reporting information etc.    

Please return the queries that need to be searched in google.                   
+ First, [PLAN] plan the steps you would need to take and write each step as a query. Then, [SEARCH] list the query from [PLAN] that need to search.            
+ You only need to plan that can complete the code snippet. You do not need to plan the codes before BEGIN SOLUTION block.        
+ You can search for real-time information, python library documents, error messages, common usage, and other information.                   
+ Don't return duplicate query with similar semantics, return different queries.                          
+ Don't tag to search simple query that can be solved by yourself, return the most critical queries.   

For each problem given, there will be a class with several functions inside you need to write subsequenct code. Please follow the rules below when you [PLAN] and [SEARCH]:
+ Do not PLAN and SEARCH the function with name: __init__(self), this function has been initialized for you as the setting of the class.
+ For each function in the class you need to implement, only SEARCH the query that you are unsure of the implementation. 
+ For each function in the class you need to implement, you must limit the search up to 3 queries.

"""
    example = [
        {
            "role": "user",
            "content": """
PROBLEM:
import math

class AdvancedCalculator:
    '''
    This class is an advanced calculator that can calculate square root and exponential.
    '''

    def __init__(self):
        pass

    def calculator(self, number, base, exponent):
        '''
        Calculate the square root and exponential of a number.
        :param number: int, the number to calculate the square root
        :param base: int, the base of the exponential
        :param exponent: int, the exponent of the exponential
        :return: tuple, the square root and exponential of inputs
        >>> calculator = AdvancedCalculator()
        >>> calculator.calculator(4, 2, 3)
        (2, 8)
        '''
    
    def calculate_square_root(self, number):
        '''
        Calculate the square root of a number.
        :param number: int, the number to calculate the square root
        :return: int, the square root of the number
        >>> calculator = AdvancedCalculator()
        >>> calculator.calculate_square_root(4)
        2
        '''

    def calculate_exponential(self, base, exponent):
        '''
        Calculate the exponential of a number.
        :param base: int, the base of the exponential
        :param exponent: int, the exponent of the exponential
        :return: int, the exponential of the number
        >>> calculator = AdvancedCalculator()
        >>> calculator.calculate_exponential(2, 3)
        8
        ''' """
        },
        {
            "role": "assistant",
            "content": """
[PLAN]
1. Function: calculator 
1.1 call the function of calculate_square_root and calculate_exponential.
1.2 return the result as a tuple format.
2. Function: calculate_square_root
2.1 take in a number and return its square root using the math.sqrt function.
3. Function: calculate_exponential
3.1 take in a base and an exponent and return the result of raising the base to the power of the exponent using the math.pow function.

[SEARCH]
1. Function: calculator 
1.1 No need to search.
1.2 No need to search.
2. Function: calculate_square_root
2.1 Need to search, and the search query is <search> Python math module documentation for the sqrt function. </search>
3. Function: calculate_exponential
3.2 Need to search, and the search query is <search> how to calculate exponential using Python </search>
"""
        }]

    response = gpt4_call(problem, system=system, example=example, args=args)

    # The reply text lives in a different place depending on the endpoint.
    # getattr keeps the documented default ``args=None`` from raising here.
    first_choice = response["choices"][0]
    if getattr(args, "is_azure", False):
        content = first_choice["message"]["content"]
    else:
        content = first_choice["text"]

    stripped = (hit.strip() for hit in re.findall(r"<search>(.*?)</search>", content))
    # Skip empty tags and the literal placeholder "query"; dict.fromkeys
    # removes duplicates while keeping first-seen order, so query.json is
    # reproducible (a set would reorder the queries from run to run).
    queries = list(dict.fromkeys(q for q in stripped if q and q != "query"))

    with open(path / "query.json", "w") as f:
        json.dump(queries, f)
    with open(path / "plan.json", "w") as f:
        json.dump(content, f)
    return queries

def google_search(path, queries):
    """
    Run a Google search for every query and cache the hits as JSON.

    :param path: directory (joined with ``/``, so a ``Path``) that receives
        ``google_search_results.json``
    :param queries: iterable of query strings
    :return: dict mapping each query to a list holding the ``__dict__`` of
        every result yielded by ``search(query, advanced=True, num_results=3)``
    """
    # One search per query; each result object is stored as its attribute
    # dict so the whole mapping can be handed to json.dump.
    search_results = {
        query: [hit.__dict__ for hit in search(query, advanced=True, num_results=3)]
        for query in queries
    }
    with open(path / "google_search_results.json", "w") as f:
        json.dump(search_results, f)
    return search_results
    

def model_inference(output_dir: Path, prefix: str, suffix: str = None, args=None):
    """
    Dispatch one problem to the backend selected by ``args.model``.

    The output directory is always created. For model names containing "gpt",
    the search queries are generated from ``prefix`` and then looked up on
    Google; any other model is currently a no-op placeholder. ``suffix`` is
    accepted but not used here.
    """
    os.makedirs(output_dir, exist_ok=True)
    if "gpt" not in args.model:
        # put your model's interface here
        return
    planned_queries = get_query(output_dir, prefix, args)
    google_search(output_dir, planned_queries)
    # NOTE: a follow-up gpt4_inference(output_dir, prefix, suffix) step is disabled


def inference(classeval, args):
    """
    Generate search queries and Google results for every ClassEval problem.

    Problems whose output directory already holds both
    ``google_search_results.json`` and ``query.json`` are skipped unless
    ``args.overwrite_output_dir`` is set.

    :param classeval: dataset mapping whose ``'test'`` split exposes a
        ``'skeleton'`` column (one class skeleton string per problem)
    :param args: parsed CLI namespace; reads ``output_dir`` (a ``Path``),
        ``model`` and ``overwrite_output_dir``
    """
    # The column lookup is loop-invariant, so fetch it once rather than
    # re-evaluating classeval['test']['skeleton'] on every iteration.
    skeletons = classeval['test']['skeleton']
    model_dir = args.output_dir / args.model
    for problem_id in trange(len(skeletons)):
        output_path = model_dir / ("q" + str(problem_id))
        if (
            not args.overwrite_output_dir
            and os.path.exists(output_path / "google_search_results.json")
            and os.path.exists(output_path / "query.json")
        ):
            continue
        # Build the prompt only for problems that are actually processed.
        prefix = 'PROBLEM:\n' + skeletons[problem_id]
        model_inference(output_path, prefix, "", args)


if __name__ == "__main__":
    # Every CLI option as (flag, add_argument keyword options), registered in
    # this order so the --help listing is unchanged.
    option_table = [
        (
            "--model",
            dict(
                type=str,
                default="gpt4",
                choices=["codex-cushman", "codex001", "codex002", "incoder-1B", "gpt4", "gpt-35-turbo-16k-0613", "gpt-4-turbo", "gpt-4-32k-0613"],
                help="Type of Codex Model to run",
            ),
        ),
        (
            "--is_azure",
            dict(action="store_true", default=False, help="Is model in azure?"),
        ),
        (
            "--output_dir",
            dict(
                type=str,
                default="gpt4_retrieve_outputs_1205",
                help="Path to which the Codex responses will be cached at",
            ),
        ),
        (
            "--source_dir",
            dict(
                type=str,
                default="text2analysis_1220.csv",
                help="Path to the downloaded DS-1000 data",
            ),
        ),
        (
            "--num_samples",
            dict(
                type=int,
                default=1,
                help="Number of Codex samples to draw for one question.",
            ),
        ),
        (
            "--overwrite_output_dir",
            dict(
                action="store_true",
                default=False,
                help="By default will skip cached samples; Turn on this flag to overwrite the existing caches.",
            ),
        ),
        (
            "--temperature",
            dict(
                type=float,
                default=0.0,
                help="Temperature of the Codex sampling distribtuion.",
            ),
        ),
        (
            "--top_p",
            dict(
                type=float,
                default=0.95,
                help="Top-p cutoff of the Codex sampling distribtuion",
            ),
        ),
        (
            "--max_tokens",
            dict(
                type=int,
                default=1024,
                help="Number of maximum tokens for Codex to generate",
            ),
        ),
        (
            "--batch_size",
            dict(
                type=int,
                default=1,
                help="Number of requests to issue at one time",
            ),
        ),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # Both locations are converted to Path objects; `inference` joins
    # output_dir with the `/` operator.
    args.output_dir = Path(args.output_dir)
    args.source_dir = Path(args.source_dir)

    # The problems come from the ClassEval dataset, not from --source_dir.
    classeval = load_dataset("FudanSELab/ClassEval")
    print("loaded dataset")
    inference(classeval, args)
