import argparse
import glob
import os

# disable tensorflow logging
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import json
import time

import openai
import torch
from tqdm import trange
from pathlib import Path

from excel_api.run import LLMClient

from run_retrieve2 import get_retrieve_prompt
from run_inference import gpt4_call,CODE_GENERATION_EXAMPLE
from datasets import load_dataset
# from text2analysis import Text2AnalysisDataset

# Short model aliases mapped to longer model identifiers. The three keys are
# also listed among the `--model` choices of the command line interface below.
# NOTE(review): the values look like OpenAI Codex engine names -- confirm.
# NOTE(review): nothing in the visible part of this file reads this mapping;
# verify it is still needed before relying on it.
codex_name_mapping = {
    "codex-cushman": "code-cushman-001",
    "codex002": "code-davinci-002",
    "codex001": "code-davinci-001",
}


def normalize(raw_text: str):
    """
    Decode the HTML entities of '<', '>' and '&' inside `raw_text`.

    The substitutions run in a fixed order (`&lt;`, `&gt;`, then `&amp;`), so a
    doubly escaped entity such as `&amp;lt;` is decoded one level only and
    comes back as `&lt;`.

    Input:
    `raw_text` is the string to clean up

    Returns the string with the three entities replaced.
    """
    # `&amp;` must stay last, otherwise `&amp;lt;` would collapse to `<`
    entity_table = (
        ("&lt;", "<"),
        ("&gt;", ">"),
        ("&amp;", "&"),
    )
    decoded = raw_text
    for entity, symbol in entity_table:
        decoded = decoded.replace(entity, symbol)
    return decoded


# NOTE(review): presumably the index of the API key currently in use; it is not
# read anywhere in the visible part of this file -- confirm it is still needed.
CURRENT_KEY_ID = 0


# global model, tokenizer for inference, see `incoder_inference`
# Both stay `None` unless the script is run with `--model incoder-1B`, in which
# case the `__main__` block below assigns them.
# NOTE(review): `incoder_inference` is not defined in the visible part of this
# file -- the reference above may be stale.
model = None
tokenizer = None


# Created as an import-time side effect of loading this module.
# NOTE(review): no code in the visible part of this file uses `llm_client`.
llm_client = LLMClient()

def gpt4_inference(output_dir: Path, prefix: str,system, suffix: str,code_list=None,args=None):
    """
    Query the model through `gpt4_call` and cache every sampled completion.

    Samples are requested in batches of at most `args.batch_size` until
    `args.num_samples` have been drawn. Sample `i` is stored as
    `<output_dir>/<i>.py` (the raw completion text) and `<output_dir>/<i>.json`
    (the completion, the prompt and, for non-Azure responses, the token
    logprobs). A batch whose files all exist already is skipped unless
    `args.overwrite_output_dir` is set.

    When the response is `None` or carries no "choices" entry, empty
    predictions are written instead of raising.

    Input:
    `output_dir` is the `Path` of an existing directory receiving the files
    `prefix` is the user prompt forwarded to `gpt4_call`
    `system` is the system prompt forwarded to `gpt4_call`
    `suffix` is forwarded to `gpt4_call` only when `args.mode == "Insertion"`
    `code_list` is unused here and only kept for interface compatibility
    `args` is the parsed command line namespace; `num_samples`, `batch_size`,
    `overwrite_output_dir`, `mode` and `is_azure` are read from it
    """
    for batch_start in range(0, args.num_samples, args.batch_size):
        # the last batch may be smaller than `args.batch_size`
        current_batch_size = min(args.num_samples - batch_start, args.batch_size)
        sample_ids = range(batch_start, batch_start + current_batch_size)

        # if outputs are already cached, we skip this batch
        all_cached = not args.overwrite_output_dir and all(
            (output_dir / f"{sample_id}.py").is_file()
            and (output_dir / f"{sample_id}.json").is_file()
            for sample_id in sample_ids
        )
        if all_cached:
            continue

        response = gpt4_call(
            prefix,
            system,
            CODE_GENERATION_EXAMPLE,
            suffix if args.mode == "Insertion" else None,
            batch_size=current_batch_size,
            args=args,
        )
        # NOTE(review): an earlier draft of this function guarded against a
        # `None` response, so `gpt4_call` presumably returns `None` on failure.
        # Treat it like a payload without "choices" instead of crashing.
        has_choices = response is not None and "choices" in response

        # storing responses
        for batch_i, sample_id in enumerate(sample_ids):
            choice = response["choices"][batch_i] if has_choices else None

            # Azure (chat) payloads and plain completion payloads keep the
            # generated text under different keys
            if not has_choices:
                prediction = ""
            elif args.is_azure:
                prediction = choice["message"]["content"]
            else:
                prediction = choice["text"]

            with open(output_dir / f"{sample_id}.py", "w", encoding="UTF-8") as f:
                f.write(prediction)

            result = dict()
            result["trg_prediction"] = prediction
            if not args.is_azure:
                # logprobs are only read from non-Azure responses
                result["logprobs"] = (
                    choice["logprobs"]["token_logprobs"] if has_choices else []
                )
                result["tokens"] = choice["logprobs"]["tokens"] if has_choices else []
            result["prompt"] = prefix

            with open(output_dir / f"{sample_id}.json", "w") as f:
                json.dump(result, f)


def model_inference(output_dir: Path, prefix: str,system, suffix: str = None,code_list=None,args=None):
    """
    Dispatch one inference request to the backend selected by `args.model`.

    `output_dir` is created first (including missing parents). Model names
    containing "gpt" are handed to `gpt4_inference`; every other model name is
    currently a no-op.

    Input:
    `output_dir` is the directory the backend writes its results to
    `prefix`, `system`, `suffix`, `code_list` are passed through unchanged
    `args` is the parsed command line namespace; only `args.model` is read here
    """
    global model, tokenizer
    os.makedirs(output_dir, exist_ok=True)

    is_gpt_backend = "gpt" in args.model
    if not is_gpt_backend:
        # put your model's interface here
        return

    gpt4_inference(output_dir, prefix, system, suffix, code_list, args)


# System prompt sent with every refinement request; it is the same for all
# problems, so it is built once instead of once per loop iteration.
_REFINE_SYSTEM_PROMPT = (
    "\n"
    "You are given a Python class skeleton PROBLEM with several functions inside. As a Python Expert, help me rewrite the code. \n"
    "I will provide the PROBLEM description, the code for this PROBLEM, and the execution result of this code. \n"
    "Help me rewrite it into the correct code to solve this PROBLEM.\n"
    "\n"
    "There are some rules that you must follow for rewriting the code:\n"
    "+ Your output must follow the format of ```python\n <code> \n``` and do not leave any other comment outside the code block. \n"
    "+ Given the input and the ouput of the code, you need to determine if the code execution result the right answer to the PROBLEM? \n"
    "    + If not, rewrite the code to make it correct\n"
    "    + if yes, simply use the original code as your code output.\n"
    "+ If you need to rewrite the code:\n"
    "    + You will given extra information including input,output, error information from previous code. Based on this youu must learn why the original code is incorrect.\n"
    "    + After you understand why original code is incorrect, rewrite the code to make it correct.\n"
    "    + Your re-written code should be in the same format as the original code and the given example.\n"
)


def _execution_feedback_prompts(results_all, args):
    """Turn the execution records loaded from `result.json` into prompt paragraphs."""
    prompts = []
    for results in results_all.values():
        for res in results:
            if res["exit_code"] == 0 and args.refine_exec:  # No error
                prompts.append(
                    "\n"
                    "    Upon executing the above code, the following results were obtained:\n"
                    f"    {res['output']}\n"
                    "    Note the output <The test case does not return an output> means that the test case does not return an output. \n"
                    "    "
                )
            elif res["exit_code"] == 1 and args.refine_error:  # Error
                prompts.append(
                    "\n"
                    "    After running the above code, it raises such error:\n"
                    f"    {res['error']}\n"
                    "\n"
                    "    It seems that this line\n"
                    f"    {res['error_line']}\n"
                    "    has bugs.\n"
                    "\n"
                    "    "
                )
    return prompts


def _retrieved_context_prompts(retrieve_path):
    """Build one prompt per distinct url from `google_search_results.json`, if present."""
    search_file = retrieve_path / "google_search_results.json"
    if not os.path.exists(search_file):
        return []
    with open(search_file, "r") as f:
        search_results = json.load(f)

    prompts = []
    seen_urls = set()
    for q_id, query in enumerate(search_results):
        for w_id, website in enumerate(search_results[query]):
            content_prompt = get_retrieve_prompt(
                query, website, retrieve_path / f"query{q_id}_website{w_id}.json"
            )
            if content_prompt is None:
                continue
            # If the same url for different query, only use the first one
            if website["url"] not in seen_urls:
                seen_urls.add(website["url"])
                prompts.append(content_prompt)
            # only the first website that yields a prompt is used per query
            break
    return prompts


def _build_refine_prefix(problem_skeleton, code, result_prompts, retrieve_prompts):
    """Assemble the user prompt from the problem, the draft code and the collected feedback."""
    sections = [
        "PROBLEM:",
        problem_skeleton,
        "--------------------",
        "Here is a code snippet that may contain errors in solving the above PROBLEM:",
        code,
        "--------------------",
        "This is the code that GPT4 generated for me, here are the inputs as well as the execution results. You need to determine if the code is correct and suggest changes if it is not.",
    ]
    sections += result_prompts
    sections += [
        (
            "--------------------\n"
            "I've searched for the background information you might need. You can selectively refer to it when writing your code, noting that not all of the information you need to use in your code. The following information is the markdown text of the main information on the corresponding website."
        ) if retrieve_prompts else "",
        "\n\n".join(retrieve_prompts),
        (
            "                         \n"
            "-------------------\n"
            "Again, the PROBLEM is as follows:\n"
        ),
        problem_skeleton,
        "Answer the PROBLEM as the same format as your former answers.",
    ]
    return "\n".join(sections)


def inference(classeval , args):
    """
    Refine previously generated solutions, one request per problem.

    For each of the first 100 problems (fewer when the split is smaller) the
    function reads the generated code from
    `<args.output_dir>/<args.model>/q<id>/*.py`, the execution results from
    `<args.output_retrieve_dir>/<args.model>/q<id>/result.json` and, when
    present, the cached search results next to it. These are combined into one
    prompt and passed to `model_inference`, which writes into
    `<args.refine_output_dir>/<args.model>/q<id>`.

    Problems without a retrieve directory or without any generated code are
    skipped.

    Input:
    `classeval` is a mapping whose `['test']['skeleton']` entry holds one class
    skeleton string per problem
    `args` is the parsed command line namespace
    """
    # fetch the column once instead of on every iteration
    skeletons = classeval['test']['skeleton']
    for problem_id in trange(min(100, len(skeletons))):
        problem_dir = "q" + str(problem_id)
        problem_skeleton = 'PROBLEM:\n' + skeletons[problem_id]
        generated_code_path = Path(args.output_dir) / args.model / problem_dir
        refine_dir = Path(args.refine_output_dir) / args.model / problem_dir
        os.makedirs(refine_dir, exist_ok=True)

        # glob returns files in arbitrary order; sort so that the sample used
        # as `code_list[0]` is the same on every run
        code_list = []
        for generated_code_sample_path in sorted(
                glob.glob(str(generated_code_path / "*.py"))
        ):
            with open(generated_code_sample_path, "r", encoding="UTF-8") as f:
                code_list.append(f.read())

        # Refinement
        retrieve_path = Path(args.output_retrieve_dir) / args.model / problem_dir
        if not os.path.exists(retrieve_path):
            continue
        if not code_list:
            # nothing to refine; indexing `code_list[0]` below would raise
            print(f"Skipping {problem_dir}: no generated code found in {generated_code_path}")
            continue
        with open(retrieve_path / "result.json", 'r') as f:
            results_all = json.load(f)
        result_promt = _execution_feedback_prompts(results_all, args)

        # Retrieve
        retrieve_prompt = _retrieved_context_prompts(retrieve_path)

        prefix = _build_refine_prefix(
            problem_skeleton, code_list[0], result_promt, retrieve_prompt
        )
        model_inference(
            refine_dir,
            prefix,
            _REFINE_SYSTEM_PROMPT,
            "",
            code_list,
            args,
        )


def parse_bool_flag(value):
    """
    Convert the text of a boolean command line option into a real `bool`.

    `type=bool` cannot be used with argparse because `bool("False")` is `True`:
    every non-empty string would enable the option. The empty string is still
    accepted as false, so the only value that disabled such an option before
    keeps working.

    Input:
    `value` is the raw option text (or an actual `bool`, returned unchanged)

    Raises `argparse.ArgumentTypeError` for text that is not a boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        default="gpt4",
        choices=["codex-cushman", "codex001", "codex002", "incoder-1B", "gpt4","gpt-35-turbo-16k-0613","gpt-4-turbo","gpt-4-32k-0613"],
        help="Type of Codex Model to run",
    )
    parser.add_argument(
        "--is_azure",
        action="store_true",
        default=False,
        help="Is model in azure?",
    )
    parser.add_argument(
        "--mode", type=str, default="Completion", choices=["Insertion", "Completion"]
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="gpt4_ppo_outputs",
        help="Path to which the Codex responses will be cached at",
    )
    parser.add_argument(
        "--output_retrieve_dir",
        type=str,
        default="gpt4_refined_outputs_0110_ppo",
        help="Path to refined outputs",
    )
    parser.add_argument(
        "--refine_output_dir",
        type=str,
        default="gpt4_refined_outputs_1130",
        help="Path to refined outputs",
    )
    parser.add_argument(
        "--source_dir",
        type=str,
        default="text2analysis_1220.csv",
        help="Path to the downloaded DS-1000 data",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=1,
        help="Number of Codex samples to draw for one question.",
    )
    parser.add_argument(
        "--overwrite_output_dir",
        action="store_true",
        default=False,
        help="By default will skip cached samples; Turn on this flag to overwrite the existing caches.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Temperature of the Codex sampling distribtuion.",
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.95,
        help="Top-p cutoff of the Codex sampling distribtuion",
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=2048,
        help="Number of maximum tokens for Codex to generate",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=20,
        help="Number of requests to issue at one time",
    )
    parser.add_argument(
        "--libs",
        type=str,
        nargs="+",
        default="all",
        help="Specify the subset of DS-1000 to run on. e.g., specifying `--libs Numpy Pandas` will only inference on the Numpy and Pandas split of DS-1000",
    )
    # The two refine switches accept an optional value: `--refine_error False`
    # disables the feature, a bare `--refine_error` enables it.
    parser.add_argument(
        "--refine_error",
        type=parse_bool_flag,
        nargs="?",
        const=True,
        default=True,
        help="Whether to refine the error code",
    )
    parser.add_argument(
        "--refine_exec",
        type=parse_bool_flag,
        nargs="?",
        const=True,
        default=True,
        help="Whether to refine the executed code",
    )
    args = parser.parse_args()
    args.output_dir = Path(args.output_dir)
    args.source_dir = Path(args.source_dir)
    if args.model == "incoder-1B":
        # imported lazily so that transformers is only needed for this model
        from transformers import AutoTokenizer, AutoModelForCausalLM

        rank = int(os.environ.get("LOCAL_RANK", 0))
        model_str = "facebook/incoder-1B"
        tokenizer = AutoTokenizer.from_pretrained(model_str)
        tokenizer.add_special_tokens({"pad_token": "<pad>"})
        tokenizer.padding_side = "right"
        model = AutoModelForCausalLM.from_pretrained(model_str)
        model.half()
        model.to(rank)
        model.eval()
    classeval = load_dataset("FudanSELab/ClassEval")
    print("loaded dataset")
    inference(classeval, args)
