import json
import os

# GPU usage must be disabled explicitly for execution; otherwise
# multiprocessing TensorFlow programs may hang indefinitely.
# NOTE(review): these assignments sit above the remaining imports, presumably
# so they are in place before TensorFlow gets loaded - confirm before moving.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# Suppress TensorFlow logging.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
from collections import defaultdict
from typing import List, Tuple, Union

from tqdm import tqdm, trange
from pathlib import Path

from argparse import ArgumentParser
import glob
from text2analysis import Text2AnalysisDataset
from Evaluation.checker import Checker, get_L2_anatype, FORECASTING_GT
import csv
from Common.taxonomy import AnaTaxonomy



def _read_text(path) -> str:
    """Return the UTF-8 text content of `path`, closing the file afterwards."""
    with open(path, "r", encoding="UTF-8") as f:
        return f.read()


def _is_placeholder_code(code: str) -> bool:
    """Tell whether a refined sample carries no usable code of its own.

    A sample counts as a placeholder when it is a single comment line,
    contains the literal ``<original code>`` marker, is completely empty, or
    starts with one of the ``<|im_end|>`` / ``<|im_sep|>`` tokens.
    """
    stripped = code.strip()
    only_a_comment = "\n" not in stripped and stripped.startswith("#")
    return (
        only_a_comment
        or "<original code>" in code
        or code == ""
        or stripped.startswith(("<|im_end|>", "<|im_sep|>"))
    )


def _forecast_score(forecast):
    """Return the scalar used for ranking: the ``R2`` entry of a metric dict, else the value itself."""
    return forecast["R2"] if isinstance(forecast, dict) else forecast


def _json_default(value):
    """Fallback for ``json.dump``: unwrap scalar wrappers that expose ``item()``.

    NOTE(review): added in case the checker reports metrics as NumPy scalars
    that are not plain Python numbers - confirm against ``Checker``.
    """
    item = getattr(value, "item", None)
    if callable(item):
        return item()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def _write_summary(evaluation_dict, eval_path) -> None:
    """Print the aggregated scores and append them as rows to the CSV at `eval_path`."""
    metric_names = ["CORR", "R2", "MSE", "RMSE", "MAE", "MedAE"]
    # The file is opened in append mode, so every run adds its own header row.
    with open(eval_path, mode='a', newline='', encoding="utf-8") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["task", "unclear_taxonomy", "given_query_taxonomy", "total","executable", "exact_match", 'CORR', 'R2', 'MSE', 'RMSE', 'MAE', 'MedAE', "forecast_metric_total"])
        print(evaluation_dict)
        for task, by_unclear in evaluation_dict.items():
            for unclear, by_para in by_unclear.items():
                for para, evals in by_para.items():
                    # Each entry is (executable, exact_match, forecast).
                    total = len(evals)
                    executable = sum(e[0] for e in evals)
                    exact_match = sum(e[1] for e in evals)
                    print(f"task: {task} ,unclear_taxonomy: {unclear}, given_para: {para}")
                    print(f"total: {total}, executable: {executable / total}, exact_match: {exact_match / total}")
                    if task != AnaTaxonomy.Forecasting:
                        writer.writerow(
                            [task.value, unclear, para, total, executable, exact_match, "", "", "", "", "", "", ""])
                        continue
                    # Only evaluations that produced a metric dict contribute
                    # to the forecasting sums; the printed averages still
                    # divide by `total`, as before.
                    forecasts = [e[2] for e in evals if isinstance(e[2], dict)]
                    metric_sums = [sum(f[name] for f in forecasts) for name in metric_names]
                    print(", ".join(
                        f"{name}: {value / total}" for name, value in zip(metric_names, metric_sums)))
                    writer.writerow(
                        [task.value, unclear, para, total, executable, exact_match, *metric_sums, len(forecasts)])


def test_model(
        text2analysis: Text2AnalysisDataset,
        model: str,
        output_dir: Union[str, Path] = "codex_greedy_outputs",
        args=None,
):
    """
    Score a model's generated code for every problem and dump the results.

    For each problem index ``i`` the generated samples are read from
    `output_dir` / `model` / q<i> / *.py (in sorted order). The first sample is
    checked with `Checker` against every reference solution in
    ``problem.python``; the evaluation with the highest
    executable + exact_match + forecast score is kept.

    `text2analysis` is the dataset; it is indexed by position and must support ``len``
    `model` is a str, naming the sub-directory holding that model's generated codes
    `output_dir` is the path (str or Path) where the generated codes are stored
    `args` must provide ``output_dir4refined`` whenever `output_dir` contains
        "refined": placeholder samples are then replaced by the file at the
        same relative location under that directory

    Side effects:
      * one JSON record per problem in `output_dir` / `model`_result_cache / <problem_id>.json
      * summary rows appended to ``evaluation_result.csv`` in the working directory

    Forecasting problems whose original query is not in ``FORECASTING_GT``,
    problems without any generated sample, and problems without reference
    solutions are reported on stdout and skipped.

    Returns an empty ``defaultdict(list)``; nothing in this function fills it
    (kept for backward compatibility).
    """
    score = defaultdict(list)
    evaluation_dict = defaultdict(
            lambda: defaultdict(lambda: defaultdict(list)))  # task -> unclear_taxo -> given_parameter -> [eval]
    # `output_dir` may be a Path; the substring test needs a str.
    is_refined_run = "refined" in str(output_dir)

    problem_code_pairs = []
    for problem_id in trange(len(text2analysis)):
        problem = text2analysis[problem_id]
        if problem.task == AnaTaxonomy.Forecasting and not any(
                FORECASTING_GT["query"].isin([problem.ori_query.replace(" ", "").lower()])):
            print(f"[Not find ori forecasting query]: {problem.ori_query}")
            continue
        problem_dir_name = "q" + str(problem_id)
        generated_code_path = Path(output_dir) / model / problem_dir_name
        code_list = []
        # Sorted so that "the first sample" is the same file on every run.
        for sample_path in sorted(glob.glob(str(generated_code_path / "*.py"))):
            code = _read_text(sample_path)
            # TODO (carried over from the original inline note): decide whether
            # Matplotlib problems should also fall back to the original code.
            if is_refined_run and _is_placeholder_code(code):
                # Build the fallback location explicitly instead of string
                # replacing the directory name inside the globbed path.
                fallback_path = (
                        Path(args.output_dir4refined) / model / problem_dir_name
                        / os.path.basename(sample_path)
                )
                code = _read_text(fallback_path)
            code_list.append(code.replace("tables", "table"))
        problem_code_pairs.append((problem, code_list))

    result_cache_path = Path(output_dir) / f"{model}_result_cache"
    os.makedirs(result_cache_path, exist_ok=True)
    for problem, code_list in tqdm(problem_code_pairs):
        if not code_list:
            print(f"[No generated code found]: {problem.problem_id}")
            continue
        eva_list = []
        for idx, py in enumerate(problem.python):
            print(f"------------------------{problem.problem_id}------------------------")
            checker = Checker(problem.table, code_list[0], py, problem.python_res[idx], problem.ana_taxonomy,
                                problem.ori_query, code_pre_result=problem.code_result[idx] if problem.code_result else None)
            eva_list.append(checker.check_result())
        if not eva_list:
            print(f"[No reference solution found]: {problem.problem_id}")
            continue
        # Keep the first evaluation with the highest combined score.
        best = max(eva_list, key=lambda e: e[1] + e[2] + _forecast_score(e[3]))
        for unclear in problem.query_taxonomy:
            for para in problem.given_parameter:
                evaluation_dict[problem.task][unclear][para].append(best[1:])
        with open(result_cache_path / f"{problem.problem_id}.json", "w") as f:
            json.dump(
                {
                    "id": problem.problem_id,
                    "executable": best[1],
                    "exact_match": best[2],
                    # Index 3 holds the forecast result; index 2 is exact_match.
                    "forecast": best[3],
                    "task": problem.task.value,
                    "unclear_taxonomy": [i.value for i in problem.query_taxonomy],
                    "given_query_taxonomy": problem.given_parameter},
                f,
                default=_json_default)

    _write_summary(evaluation_dict, "evaluation_result.csv")
    return score


if __name__ == "__main__":
    # Script entry point: read the command-line options, build the dataset
    # from --source_dir with is_eval=True, then score the selected model.
    supported_models = [
        "codex-cushman",
        "codex001",
        "codex002",
        "incoder-1B",
        "gpt4",
        "gpt-35-turbo-16k-0613",
        "gpt-4-turbo",
        "gpt-4-32k-0613",
    ]
    cli = ArgumentParser()
    cli.add_argument(
        "--model",
        type=str,
        default="gpt-4-32k-0613",
        choices=supported_models,
        help="Type of Codex Model to run",
    )
    # The remaining options are plain strings: (flag, default, help text).
    string_options = (
        ("--output_dir", "gpt_retrieve_outputs_0124_2",
         "Path to which the Codex responses will be cached at"),
        ("--output_dir4refined", "gpt_outputs_0117",
         "Path to which the original Codex responses will be cached at for refinement"),
        ("--source_dir", "text2analysis_1220.csv",
         "Path to the downloaded DS-1000 data"),
    )
    for flag, default_value, help_text in string_options:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    args = cli.parse_args()

    dataset = Text2AnalysisDataset(source_dir=args.source_dir, is_eval=True)
    print("loaded dataset")
    test_model(
        text2analysis=dataset,
        model=args.model,
        output_dir=args.output_dir,
        args=args,
    )
