import argparse
import glob
import os
# disable tensorflow logging
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import json
import time

import openai
import torch
from tqdm import trange
from pathlib import Path
from ds1000 import DS1000Dataset, check_version

from excel_api.run import LLMClient
from run_test import test_helper

import numpy as np
import pandas as pd
import tensorflow as tf  

from run_retrieve import google_search
from run_retrieve1 import get_content,get_retrieve_prompt,get_stackoverflow_content
from run_inference import gpt4_call
 
# Maps the short Codex aliases that the ``--model`` command-line option accepts
# ("codex-cushman", "codex001", "codex002") to longer ``code-*`` identifiers.
# NOTE(review): presumably these are the API-side engine names; the mapping is
# not referenced anywhere in the visible part of this file -- confirm it is
# still needed.
codex_name_mapping = {
    "codex-cushman": "code-cushman-001",
    "codex002": "code-davinci-002",
    "codex001": "code-davinci-001",
}


def normalize(raw_text: str):
    """Decode the HTML entities ``&lt;``, ``&gt;`` and ``&amp;`` in *raw_text*.

    The three substitutions are applied one after another in exactly that
    order.  Because ``&amp;`` is handled last, an input such as ``&amp;lt;``
    is decoded a single level to ``&lt;`` instead of collapsing to ``<``.

    Returns the decoded string; text without any of the entities is returned
    unchanged.
    """
    entity_table = (
        ("&lt;", "<"),
        ("&gt;", ">"),
        ("&amp;", "&"),
    )
    decoded = raw_text
    for entity, symbol in entity_table:
        decoded = decoded.replace(entity, symbol)
    return decoded


# NOTE(review): CURRENT_KEY_ID is never read in the visible part of this file;
# presumably it is an index into a pool of API keys -- confirm before relying
# on it.
CURRENT_KEY_ID = 0


# Module-level model and tokenizer placeholders. The ``__main__`` section
# assigns them when ``--model incoder-1B`` is selected; otherwise they stay
# ``None``. The original note refers to `incoder_inference`, which is not
# defined in the visible part of this file.
model = None
tokenizer = None

# Client instance created once, at import time.
# NOTE(review): `llm_client` is not used in the visible part of this file.
llm_client = LLMClient()


def summarize_data(data, max_length=1000, input=False):
    """Build a short, human-readable text summary of *data*.

    NumPy arrays, pandas DataFrames, PyTorch tensors and TensorFlow tensors
    are rendered as their (possibly truncated) string form followed by the
    container kind, shape, dtype and, where they can be computed, basic
    statistics.

    Args:
        data: The object to summarise.
        max_length: Maximum number of characters of ``str(data)`` to keep;
            longer text is cut and suffixed with ``"..."``.
        input: When true, a tuple is summarised element by element (each
            element summary is prefixed with ``"\\r\\n"``).  The name shadows
            the builtin but is kept so keyword callers keep working.

    Returns:
        The summary string.  Objects of any other type (including tuples when
        *input* is false) are returned as plain ``str(data)`` without
        truncation.
    """
    # Detect the container type and collect the matching description parts.
    if isinstance(data, np.ndarray):
        data_type = "NumPy Array"
        shape_str = " x ".join(map(str, data.shape))
        try:
            stats_str = (
                f"Min: {np.nanmin(data):.2f}, Max: {np.nanmax(data):.2f}, "
                f"Mean: {np.nanmean(data):.2f}, Std: {np.nanstd(data):.2f}"
            )
        except Exception:
            # Statistics are best-effort: non-numeric or empty arrays simply
            # get no statistics line.
            stats_str = ""
        type_str = data.dtype

    elif isinstance(data, pd.DataFrame):
        data_type = "Pandas DataFrame"
        shape_str = " x ".join(map(str, data.shape))
        # Column labels are not necessarily strings, so convert explicitly.
        column_names = ", ".join(map(str, data.columns))
        stats_str = f"Columns: {column_names}"
        type_str = data.dtypes

    elif isinstance(data, torch.Tensor):
        data_type = "PyTorch Tensor"
        shape_str = " x ".join(map(str, data.shape))
        type_str = str(data.dtype)
        try:
            # Compute the statistics on a floating-point view, but keep `data`
            # itself untouched so the printed values still match the reported
            # dtype (e.g. an int64 tensor is not shown as floats).
            numeric = data if data.is_floating_point() else data.float()
            stats_str = (
                f"Min: {numeric.min().item()}, Max: {numeric.max().item()}, "
                f"Mean: {numeric.mean().item():.2f}, Std: {numeric.std().item():.2f}"
            )
        except Exception:
            stats_str = ""

    elif isinstance(data, tf.Tensor):
        data_type = "TensorFlow Tensor"
        shape_str = " x ".join(map(str, data.shape))
        type_str = str(data.dtype)

        # Statistics are only attempted for numeric, NumPy-compatible dtypes.
        stats_str = ""
        if tf.as_dtype(data.dtype).is_numpy_compatible and np.issubdtype(
            data.dtype.as_numpy_dtype, np.number
        ):
            try:
                numpy_data = data.numpy()
                stats_str = (
                    f"Min: {np.nanmin(numpy_data):.2f}, Max: {np.nanmax(numpy_data):.2f}, "
                    f"Mean: {np.nanmean(numpy_data):.2f}, Std: {np.nanstd(numpy_data):.2f}"
                )
            except Exception:
                # Same best-effort policy as the NumPy and PyTorch branches
                # (e.g. an empty tensor makes np.nanmin raise).
                stats_str = ""

    elif isinstance(data, tuple) and input:
        # Summarise every element of the tuple separately.
        return "".join(
            f"\r\n{summarize_data(item, max_length, input)}" for item in data
        )
    else:
        return str(data)

    # Render the data as text and shorten it when it is too long.
    data_str = str(data)
    if len(data_str) > max_length:
        data_str = data_str[:max_length] + "..."

    if isinstance(data, np.ndarray):
        data_str = f"array({data_str})"

    # Assemble the final summary.
    return (
        f"{data_str}\n"
        f"This is a {data_type} of shape ({shape_str})\n"
        f"Data Type: {type_str}\n"
        f"{stats_str}"
    )

def inference(ds1000: DS1000Dataset, args):
    """
    Execute the cached generations of every problem and retrieve web content
    for the errors they produce.

    For each library in `ds1000.libs` and each problem of that library:

    1. read every ``*.py`` sample under
       ``<output_dir>/<model>/<lib>/<mode>/q<problem_id>``;
    2. run the samples through `test_helper` and annotate each result with a
       summary of its input/output or, when ``exit_code == 1``, with the
       failing source line, the final stderr line and a search query;
    3. write the annotated results to ``result.json`` under
       ``<refine_output_dir>/<model>/<lib>/<mode>/q<problem_id>``;
    4. pass the distinct queries to `google_search` and, for each query, walk
       its first three hits until `get_retrieve_prompt` returns something
       other than ``None``.

    Input:
    `ds1000` is a `DS1000Dataset` object
    `args` is the parsed command-line namespace; `output_dir`,
    `refine_output_dir`, `model` and `mode` are read from it.
    """
    for lib in ds1000.libs:
        for problem_id in trange(len(ds1000[lib])):
            problem_dir = "q" + str(problem_id)
            generated_code_path = (
                Path(args.output_dir) / args.model / lib / args.mode / problem_dir
            )
            # All refinement artefacts of this problem live in one directory.
            refine_dir = (
                Path(args.refine_output_dir) / args.model / lib / args.mode / problem_dir
            )
            os.makedirs(refine_dir, exist_ok=True)

            code_list = []
            for generated_code_sample_path in glob.glob(
                str(generated_code_path / "*.py")
            ):
                # Context manager so the handle is closed deterministically.
                with open(generated_code_sample_path, "r", encoding="UTF-8") as f:
                    code_list.append(f.read())

            results = test_helper(
                (ds1000[lib][problem_id], code_list), return_result=True
            )
            search_query = []
            for res in results[0]:
                if lib == "Matplotlib":
                    figure_output = res["output"]
                    res["output"] = (
                        "The output is a figure, so we show the SVG "
                        "(Scalable Vector Graphics) of the figure here:\n"
                        f"{figure_output}\n"
                    )
                if res["exit_code"] == 1:
                    error_line, error = _extract_error_info(res["stderr"])
                    res["error_line"] = error_line
                    res["error"] = error
                    query = " ".join(["python", lib, error])
                    res["query"] = query
                    search_query.append(query)
                elif lib != "Matplotlib":
                    res["output"] = summarize_data(res["output"], max_length=10000)

                res["input"] = summarize_data(
                    res["input"], max_length=10000, input=True
                )
                # stderr arrives as bytes; decode it leniently so one invalid
                # byte cannot abort the whole run (and so it can be dumped as
                # JSON below).
                res["stderr"] = res["stderr"].decode("utf-8", errors="replace")

            with open(refine_dir / "result.json", "w", encoding="utf-8") as f:
                json.dump(results[0], f)
            if not search_query:
                continue

            # De-duplicate while keeping first-seen order: unlike set(), this
            # hands the queries over in the same order on every run.
            unique_queries = list(dict.fromkeys(search_query))
            error_search_res = google_search(refine_dir, unique_queries)
            for q_id, query in enumerate(error_search_res):
                for w_id, website in enumerate(error_search_res[query][:3]):
                    content_path = refine_dir / f"query{q_id}_website{w_id}.json"
                    # The fetchers are called for their side effects only; the
                    # return value was never used.
                    # NOTE(review): presumably they store the page under
                    # `content_path` for `get_retrieve_prompt` -- confirm.
                    if website["url"].startswith("https://stackoverflow.com/"):
                        get_stackoverflow_content(content_path, website["url"])
                    else:
                        get_content(content_path, website["url"])
                    output = get_retrieve_prompt(query, website, content_path)
                    if output is not None:
                        # Stop at the first hit that yields a result.
                        break


def _extract_error_info(stderr):
    """Return ``(error_line, error)`` extracted from the raw `stderr` bytes.

    `error_line` is the line following the last ``File "program.py", line ...``
    marker, and `error` is the last line of the stripped stderr.  Either value
    is ``""`` when it cannot be found, instead of raising ``IndexError``.
    """
    lines = stderr.strip().splitlines()
    error_line = ""
    # A marker on the very last line has no following line, so skip it.
    for pos, line in enumerate(lines[:-1]):
        if b"File \"program.py\"," in line and b"line" in line:
            error_line = lines[pos + 1].decode("utf-8", errors="replace")
    error = lines[-1].decode("utf-8", errors="replace") if lines else ""
    return error_line, error


def str2bool(value):
    """Convert a command-line token to a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty text -- including ``"False"`` -- becomes ``True``.  This
    converter accepts the usual spellings (case-insensitive) instead.  The
    empty string still maps to ``False``, as it did with ``type=bool``.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    # Script entry point: parse the options, optionally load the InCoder
    # model, load the DS-1000 dataset and run `inference` on it.
    check_version()
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        default="gpt4",
        choices=["codex-cushman", "codex001", "codex002", "incoder-1B", "gpt4","gpt-35-turbo-16k-0613","gpt-4-turbo","gpt-4-32k-0613"],
        help="Type of Codex Model to run",
    )
    parser.add_argument(
        "--mode", type=str, default="Completion", choices=["Insertion", "Completion"]
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="gpt4_ppo_outputs",
        help="Path to which the Codex responses will be cached at",
    )
    parser.add_argument(
        "--refine_output_dir",
        type=str,
        default="gpt4_refined_outputs_0110_ppo",
        help="Path to refined outputs",
    )
    parser.add_argument(
        "--source_dir",
        type=str,
        default="./ds1000_data",
        help="Path to the downloaded DS-1000 data",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=1,
        help="Number of Codex samples to draw for one question.",
    )
    parser.add_argument(
        "--overwrite_output_dir",
        action="store_true",
        default=False,
        help="By default will skip cached samples; Turn on this flag to overwrite the existing caches.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Temperature of the Codex sampling distribtuion.",
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.95,
        help="Top-p cutoff of the Codex sampling distribtuion",
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=1024,
        help="Number of maximum tokens for Codex to generate",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=20,
        help="Number of requests to issue at one time",
    )
    parser.add_argument(
        "--libs",
        type=str,
        nargs="+",
        default="all",
        help="Specify the subset of DS-1000 to run on. e.g., specifying `--libs Numpy Pandas` will only inference on the Numpy and Pandas split of DS-1000",
    )
    # `type=bool` would turn "--refine_error False" into True, so both boolean
    # options go through the explicit converter.
    parser.add_argument(
        "--refine_error",
        type=str2bool,
        default=True,
        help="Whether to refine the error code",
    )
    parser.add_argument(
        "--refine_exec",
        type=str2bool,
        default=True,
        help="Whether to refine the executed code",
    )
    args = parser.parse_args()
    args.output_dir = Path(args.output_dir)
    args.source_dir = Path(args.source_dir)
    if args.model == "incoder-1B":
        # Imported lazily so `transformers` is only needed for this model.
        from transformers import AutoTokenizer, AutoModelForCausalLM

        # Device index taken from LOCAL_RANK (0 when the variable is unset).
        rank = int(os.environ.get("LOCAL_RANK", 0))
        model_str = "facebook/incoder-1B"
        # These assignments rebind the module-level `tokenizer` and `model`.
        tokenizer = AutoTokenizer.from_pretrained(model_str)
        tokenizer.add_special_tokens({"pad_token": "<pad>"})
        tokenizer.padding_side = "right"
        model = AutoModelForCausalLM.from_pretrained(model_str)
        model.half()
        model.to(rank)
        model.eval()
    ds1000 = DS1000Dataset(source_dir=args.source_dir, libs=args.libs, mode=args.mode)
    print("loaded dataset")
    inference(ds1000, args)
