import argparse
import glob
import os

# from text2analysis import Text2AnalysisDataset
# disable tensorflow logging
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import json
import time

import traceback
import re

import openai
import torch
from tqdm import trange
from pathlib import Path

from excel_api.run import LLMClient

import numpy as np
import pandas as pd

from run_retrieve import google_search
from run_retrieve1 import get_content,get_retrieve_prompt,get_stackoverflow_content
from run_inference import gpt4_call
from datasets import load_dataset
# from Common.taxonomy import AnaTaxonomy,QueryTaxonomy
# from Evaluation.checker import Checker, get_L2_anatype, FORECASTING_GT
# from run_test import killed_id

# Short aliases for Codex models (the same strings appear among the `--model`
# CLI choices in the `__main__` block) mapped to longer model identifiers.
# NOTE(review): presumably the values are OpenAI engine names; the mapping is
# not referenced in the visible part of this file — confirm it is still needed.
codex_name_mapping = {
    "codex-cushman": "code-cushman-001",
    "codex002": "code-davinci-002",
    "codex001": "code-davinci-001",
}

def eval_python(code, testpart, testcase_id = 0):
    """Execute generated code together with one test case and report the outcome.

    The code under test is taken from the first ```python fenced block of
    ``code`` (or from ``code`` itself, stripped of leading/trailing dots, when
    it contains no fence). The test case is the ``testcase_id``-th section of
    ``testpart`` that follows a ``##`` separator. Both are concatenated and run
    with ``exec``; the test case is expected to bind a variable named
    ``result``.

    SECURITY: this runs arbitrary, model-generated code in the current process
    via ``exec``. Only call it inside a sandboxed / disposable environment.

    Args:
        code: Raw model response containing the code under test.
        testpart: Test-case text whose cases are separated by ``##``.
        testcase_id: Zero-based index of the test case to run.

    Returns:
        On success: ``{"exit_code": 0, "input": <testcase>, "output": <result>}``
        where ``output`` is the value bound to ``result`` or the placeholder
        ``'<The test case does not return an output>'`` if it was never bound.
        On any failure (including failures while extracting the code or the
        test case): ``{"exit_code": 1, "input": <testcase or None>,
        "output": None, "error": <last traceback line>, "error_line": <source
        line of the first "<string>" frame, or None>}``.
    """
    testcase = None
    function = None
    local_var = {}
    try:
        function = code.split("```python")[1].split("```")[0] if "```" in code else code.strip(".")
        testcase = testpart.split("##")[testcase_id + 1]
        function = function + "\n" + testcase

        # Top-level imports must be visible as *globals*: functions defined by
        # the exec'd code resolve free names through their globals, not through
        # the separate locals mapping used below.
        import_lib = [
            line for line in function.split("\n")
            if line.startswith(("import ", "from "))
        ]
        lib_vars = {}
        exec("\n".join(import_lib), globals(), lib_vars)
        # Publish only the imported names (not this function's own locals) so
        # the module namespace is not polluted with `code`, `testpart`, etc.
        global_vars = globals()
        global_vars.update(lib_vars)

        exec(function, global_vars, local_var)
    except BaseException:
        # Deliberately broad: the executed code is untrusted and may raise
        # anything, including SystemExit; a failing test case must never abort
        # the evaluation loop.
        exc_info = traceback.format_exc()
        error_line = None
        match = re.search(r'File "<string>", line (\d+)', exc_info)
        if match and function is not None:
            source_lines = function.split("\n")
            line_index = int(match.group(1)) - 1
            # The frame may belong to the import-only snippet, whose line
            # numbers do not necessarily exist in `function`.
            if 0 <= line_index < len(source_lines):
                error_line = source_lines[line_index]
        return {
            "exit_code": 1,
            "input": testcase,
            'output': None,
            "error": exc_info.strip().splitlines()[-1],
            "error_line": error_line,
        }

    # Check for the result explicitly instead of catching KeyError, so that a
    # KeyError raised by the tested code is reported as a real failure.
    if "result" in local_var:
        output = local_var["result"]
    else:
        output = '<The test case does not return an output>'
    return {"exit_code": 0, "input": testcase, "output": output}
    
def normalize(raw_text: str):
    """Decode the HTML entities for '<', '>' and '&' found in ``raw_text``.

    The entities are replaced in a fixed order with ``&amp;`` last, so a
    doubly-escaped sequence such as ``&amp;lt;`` is decoded only one level
    (to ``&lt;``).
    """
    entity_table = (
        ("&lt;", "<"),
        ("&gt;", ">"),
        ("&amp;", "&"),
    )
    for entity, symbol in entity_table:
        if entity in raw_text:
            raw_text = raw_text.replace(entity, symbol)
    return raw_text


# NOTE(review): presumably the index of the API key currently in use; it is
# not referenced in the visible part of this file — confirm it is still needed.
CURRENT_KEY_ID = 0


# Global model and tokenizer for local inference. Both stay None unless the
# `__main__` block is run with `--model incoder-1B`, which loads them.
# NOTE(review): the original comment pointed to `incoder_inference`, which is
# not defined in the visible part of this file.
model = None
tokenizer = None

# Client object created once at import time (constructing it is a module-level
# side effect). NOTE(review): not used in the visible part of this file.
llm_client = LLMClient()


def summarize_data(data, max_length=1000, input=False):  
    # 检测数据类型并生成相应的summary  
    if isinstance(data, np.ndarray):  
        data_type = "NumPy Array"  
        shape_str = " x ".join(map(str, data.shape))  
        try:
            stats_str = f"Min: {np.nanmin(data):.2f}, Max: {np.nanmax(data):.2f}, Mean: {np.nanmean(data):.2f}, Std: {np.nanstd(data):.2f}"  
        except:  
            stats_str = ""  
        type_str = data.dtype
  
    elif isinstance(data, pd.DataFrame):  
        data_type = "Pandas DataFrame"  
        shape_str = " x ".join(map(str, data.shape))  
        # 确保列名是字符串类型  
        column_names = ', '.join(map(str, data.columns))  
        stats_str = f"Columns: {column_names}"  
        type_str = data.dtypes
  
    elif isinstance(data, torch.Tensor):  
        data_type = "PyTorch Tensor"  
        shape_str = " x ".join(map(str, data.shape))  
        type_str = str(data.dtype)  
        # Convert the tensor to a floating point type if it's not already  
        try:
            if not data.is_floating_point():  
                data = data.float()  
            stats_str = f"Min: {data.min().item()}, Max: {data.max().item()}, Mean: {data.mean().item():.2f}, Std: {data.std().item():.2f}"    
        except:
            stats_str = ""
  
    elif isinstance(data, tuple) and input==True:
        summary=""
        for i in range(len(data)):
            summary_i = summarize_data(data[i], max_length,input)
            summary+=f"\r\n{summary_i}"
        return summary
    else:  
        return str(data)
  
    # 转换数据为字符串形式  
    data_str = str(data)  
  
    # 如果字符串太长，缩短它  
    if len(data_str) > max_length:  
        data_str = data_str[:max_length] + "..."  
    
    if isinstance(data, np.ndarray):
        data_str=f"array({data_str})"
      
    # 创建摘要字符串  
    summary = \
        f"{data_str}\n" + \
        f"This is a {data_type} of shape ({shape_str})\n"  + \
        f"Data Type: {type_str}\n" + \
        f"{stats_str}"  
    
    return summary

def inference(classeval, args, num_problems=17):
    """Run the generated code against its test cases and retrieve help for errors.

    For every problem index ``i`` in ``range(num_problems)`` this function:

    1. reads the generated ``*.py`` samples from
       ``<args.output_dir>/<args.model>/q<i>/`` and the generated test cases
       from ``<args.testcase_source_dir>/<args.model>/q<i>/0.py``;
    2. runs the first sample (in sorted filename order) against three test
       cases per method listed in the problem's ``methods_info`` via
       ``eval_python``;
    3. writes the per-method results to
       ``<args.refine_output_dir>/<args.model>/q<i>/result.json``;
    4. if any test case failed, passes the distinct error messages to
       ``google_search`` and, for each query, walks up to three returned
       websites, fetching their content and stopping at the first one for
       which ``get_retrieve_prompt`` returns a non-None value.

    Args:
        classeval: Loaded dataset object. It is not used by this function and
            is kept only for interface compatibility.
        args: Parsed command-line arguments. Reads ``data_source_dir``,
            ``output_dir``, ``testcase_source_dir``, ``refine_output_dir``
            and ``model``.
        num_problems: Number of problems (starting at index 0) to process.
            Defaults to 17, the value that was previously hard-coded.
    """
    testcases_per_method = 3
    max_websites_per_query = 3

    with open(args.data_source_dir, 'r', encoding="utf-8") as input_file:
        data = json.load(input_file)

    for problem_id in trange(num_problems):
        results = {}
        search_query = []
        problem_dir = "q" + str(problem_id)
        generated_code_path = Path(args.output_dir) / args.model / problem_dir
        generated_test_cases_path = (
            Path(args.testcase_source_dir) / args.model / problem_dir / "0.py"
        )
        refine_dir = Path(args.refine_output_dir) / args.model / problem_dir
        os.makedirs(refine_dir, exist_ok=True)

        # Sort the sample paths: glob order is arbitrary, and only the first
        # sample is evaluated below.
        code_list = []
        for generated_code_sample_path in sorted(
            glob.glob(str(generated_code_path / "*.py"))
        ):
            with open(generated_code_sample_path, "r", encoding="UTF-8") as code_file:
                code_list.append(code_file.read())

        with open(generated_test_cases_path, "r", encoding="UTF-8") as test_file:
            testcode = test_file.read()

        for method_id, method_info in enumerate(data[problem_id]['methods_info']):
            method_name = method_info['method_name']
            # The test file holds one ```python fenced block per method, in the
            # same order as `methods_info`.
            testpart = testcode.split("```python")[method_id+1].split("```")[0] if "```" in testcode else testcode.strip(".")
            function_results = []
            for testcase_id in range(testcases_per_method):
                result = eval_python(code_list[0], testpart, testcase_id)
                if result["exit_code"] == 1:
                    search_query.append(result["error"])
                function_results.append(result)
            results[method_name] = function_results

        with open(refine_dir / "result.json", 'w', encoding="utf-8") as f:
            # default=str: test outputs may be arbitrary, non-JSON objects.
            json.dump(results, f, default=str)

        if not search_query:
            continue

        error_search_res = google_search(refine_dir, list(set(search_query)))
        for q_id, query in enumerate(error_search_res):
            for w_id, website in enumerate(
                error_search_res[query][:max_websites_per_query]
            ):
                content_path = refine_dir / f"query{q_id}_website{w_id}.json"
                # The fetch helpers are called for their effect on
                # `content_path`; presumably they store the page content there
                # for get_retrieve_prompt — TODO confirm.
                if website["url"].startswith("https://stackoverflow.com/"):
                    get_stackoverflow_content(content_path, website["url"])
                else:
                    get_content(content_path, website["url"])
                output = get_retrieve_prompt(query, website, content_path)
                if output is not None:
                    break


def str2bool(value):
    """Parse a command-line boolean such as ``True``/``False``/``yes``/``0``.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``"False"``, as ``True``. This converter maps the usual spellings
    explicitly.

    Args:
        value: The raw argument string (a ``bool`` is returned unchanged).

    Returns:
        The parsed boolean. The empty string parses as ``False``.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Script entry point: parse the CLI options, optionally load the local
# InCoder model, load the ClassEval dataset and run `inference`.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        default="gpt4",
        choices=["codex-cushman", "codex001", "codex002", "incoder-1B", "gpt4","gpt-35-turbo-16k-0613","gpt-4-turbo","gpt-4-32k-0613"],
        help="Type of Codex Model to run",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="gpt4_ppo_outputs",
        help="Path to which the Codex responses will be cached at",
    )
    parser.add_argument(
        "--data_source_dir",
        type=str,
        default="data/ClassEval_data.json",
        help="Path to the data source of ClassEval",
    )
    parser.add_argument(
        "--testcase_source_dir",
        type=str,
        default="gpt_testcases_0204",
        help="Path to the testcase source of ClassEval",
    )
    parser.add_argument(
        "--refine_output_dir",
        type=str,
        default="gpt4_refined_outputs_0110_ppo",
        help="Path to refined outputs",
    )
    parser.add_argument(
        "--source_dir",
        type=str,
        default="text2analysis_1220.csv",
        help="Path to the downloaded DS-1000 data",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=1,
        help="Number of Codex samples to draw for one question.",
    )
    parser.add_argument(
        "--overwrite_output_dir",
        action="store_true",
        default=False,
        help="By default will skip cached samples; Turn on this flag to overwrite the existing caches.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Temperature of the Codex sampling distribution.",
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.95,
        help="Top-p cutoff of the Codex sampling distribution",
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=1024,
        help="Number of maximum tokens for Codex to generate",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=20,
        help="Number of requests to issue at one time",
    )
    # str2bool instead of type=bool, so that "--refine_error False" really
    # disables the option.
    parser.add_argument(
        "--refine_error",
        type=str2bool,
        default=True,
        help="Whether to refine the error code",
    )
    parser.add_argument(
        "--refine_exec",
        type=str2bool,
        default=True,
        help="Whether to refine the executed code",
    )
    args = parser.parse_args()
    args.output_dir = Path(args.output_dir)
    args.source_dir = Path(args.source_dir)
    if args.model == "incoder-1B":
        # Imported lazily: transformers is only needed for the local model.
        from transformers import AutoTokenizer, AutoModelForCausalLM

        # Device index for the model, taken from LOCAL_RANK (default 0).
        rank = int(os.environ.get("LOCAL_RANK", 0))
        model_str = "facebook/incoder-1B"
        tokenizer = AutoTokenizer.from_pretrained(model_str)
        tokenizer.add_special_tokens({"pad_token": "<pad>"})
        tokenizer.padding_side = "right"
        model = AutoModelForCausalLM.from_pretrained(model_str)
        model.half()
        model.to(rank)
        model.eval()
    classeval = load_dataset("FudanSELab/ClassEval")
    print("loaded dataset")
    inference(classeval, args)
