import torch
from peft import get_peft_model, TaskType, LoraConfig, PeftConfig, PeftModel
import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, CodeGenForCausalLM, XGLMForCausalLM
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
import re

import platform
import pickle
from utils import *
from typing import *
import json
import os

#############
# Make sure to adjust whether to add the step to the prompt !!!
#############

# nohup python -u self_guided_step_input.py >> logs/350M_no_lora_self-guided-step-input.log &
# Process-wide environment switches, set before any model is loaded below.
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "0",
    "TOKENIZERS_PARALLELISM": "true",
    "CUDA_LAUNCH_BLOCKING": "1",
})


# Per-problem parallel lists filled by init(); the same index refers to the
# same problem in every list.
pids: List[str] = []  # problem["pid"] of each problem still to be evaluated
# Reference solutions; never populated in this file (the code that fills it
# is commented out in init()).
all_codes: List[List[str]] = []
all_difficulity: List[str] = []  # problem["difficulty"]
# problem["test_case"]: [ [{"input": ..., "output": ...}, ...], ... ]
all_test_cases: List[List[Dict[str, str]]] = []
all_nl: List[str] = []  # problem["nl"], used as the problem description
all_input_format: List[str] = []  # problem["input_format"]
all_output_format: List[str] = []  # problem["output_format"]
all_step: List[str] = []  # problem["step"] entries joined with newlines

# pids for which generate_one_code() reported an over-length condition
overlen_problems: List[str] = []


# Model selection: base checkpoint plus the adapter directory for this run.
model_type = "2B_self-guided"
model_id = "Salesforce/codegen-2B-multi"
# model_id = f"model/{model_type}"
adapter_id = "model/" + model_type

# Input problems are read from resources_dir; results are written to res_dir.
resources_dir = 'resources'
res_dir = "result/" + model_type
res_post_fix = "600"

# All three result files share the same suffix and differ only in prefix.
_res_suffix = res_post_fix + "_self-guided-step-input.json"
problem_res_path = os.path.join(res_dir, "problem_result" + _res_suffix)
hard_pass_res_path = os.path.join(res_dir, "hard_pass_result" + _res_suffix)
soft_pass_res_path = os.path.join(res_dir, "soft_pass_result" + _res_suffix)

# Load the tokenizer and the base model (half precision, moved to the GPU).
# This runs at import time, so importing this module triggers the full load.
print("Loading model ...")
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = CodeGenForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16).cuda()


# Alternative ways of building the base model, kept for reference:
# model = CodeGenForCausalLM.from_pretrained(config.base_model_name_or_path)
# model = XGLMForCausalLM.from_pretrained(config.base_model_name_or_path)

print("model struct before Lora:")
print(model)

# Attach the adapter stored under adapter_id.
# NOTE(review): this comment used to say "Uncomment the following if you want
# to use LoRA", but the code below is active; comment it out to evaluate the
# base model without the adapter.
model.load_adapter(adapter_id)
# The raw adapter weights are loaded a second time only so they can be
# printed; they are deleted again right after.
adapter_weight = torch.load(os.path.join(adapter_id, "adapter_model.bin"))
model.eval()
print(f"adapter weight: \n{adapter_weight}")
del adapter_weight
print(f"model struct after Lora:\n{model}")
# Dump every parameter (name, requires_grad flag and full tensor) to the log.
for name, param in model.named_parameters():
    print(
        f'Parameter: {name}, Requires grad: {param.requires_grad}\n Param: {param}')
print("Finish load model!")


def init(input_file: str):
    """Load the problem set and any results saved by a previous run.

    Reads ``{resources_dir}/{input_file}`` (a JSON list of problems) and
    appends every problem that has no entry yet in ``hard_pass_res_path`` to
    the module-level lists (``pids``, ``all_test_cases``, ``all_difficulity``,
    ``all_nl``, ``all_input_format``, ``all_output_format``, ``all_step``).

    Args:
        input_file: File name of the problem set, relative to ``resources_dir``.

    Returns:
        A tuple ``(already_res, already_hard_pass_at_k, already_soft_pass_at_k)``
        holding the contents of the three result files. All three are empty
        lists when ``hard_pass_res_path`` does not exist; when it exists but
        one of the other two files is missing, that entry is an empty list.
    """
    def _load_saved(path: str) -> list:
        """Return the parsed JSON content of *path*, or [] if it is missing."""
        if not os.path.exists(path):
            return []
        with open(path, 'r', encoding='utf-8') as saved:
            return json.load(saved)

    with open(f"{resources_dir}/{input_file}", 'r', encoding='utf-8') as f:
        data = json.load(f)

    # The hard-pass file is the marker of a resumable run: only when it exists
    # are earlier results loaded (each file is parsed once).
    already_res, already_hard_pass_at_k, already_soft_pass_at_k = [], [], []
    if os.path.exists(hard_pass_res_path):
        already_hard_pass_at_k = _load_saved(hard_pass_res_path)
        # These two used to be opened unconditionally, which raised
        # FileNotFoundError when only the hard-pass file was present.
        already_res = _load_saved(problem_res_path)
        already_soft_pass_at_k = _load_saved(soft_pass_res_path)
    already_gen_pids = {d["pid"] for d in already_hard_pass_at_k}

    for i, problem in enumerate(data):
        pid = problem['pid']
        if pid in already_gen_pids:
            # Already evaluated in an earlier run; do not queue it again.
            continue

        print(f"data{i}--{pid}")
        pids.append(pid)
        all_test_cases.append(problem['test_case'])
        all_difficulity.append(problem["difficulty"])
        all_nl.append(problem["nl"])
        all_input_format.append(problem["input_format"])
        all_output_format.append(problem["output_format"])
        all_step.append("\n".join(problem["step"]))

        # codes = []
        # for code in problem['code']:
        #     if code['lang'] == 'cpp':
        #         codes.append(code['code'])
        # all_codes.append(codes)
    return already_res, already_hard_pass_at_k, already_soft_pass_at_k


def reset():
    """Empty every per-problem module-level list so init() can refill them.

    The lists are cleared in place (no rebinding), so no ``global`` statement
    is needed. ``all_difficulity`` is cleared too: init() appends to it in
    lock-step with ``pids``, so leaving it populated would misalign the
    indices used by check_one_problem() after a reset.
    """
    per_problem_lists = (
        pids,
        all_test_cases,
        all_difficulity,
        all_nl,
        all_input_format,
        all_output_format,
        all_step,
    )
    for store in per_problem_lists:
        store.clear()
    # all_codes.clear()

def compile(code: str):
    """Write *code* to a per-thread C++ source file and compile it with g++.

    All artefacts are named after the current thread so that worker threads
    do not overwrite each other's files:
    ``temp/<thread>.cpp`` (source), ``temp/<thread>`` or ``temp/<thread>.exe``
    (binary) and ``temp/<thread>_compile_err.txt`` (g++ stderr).

    Args:
        code: The C++ source text to compile.

    Returns:
        The value returned by ``os.system`` for the g++ command; callers
        treat 0 as success.
    """
    cur_name = threading.current_thread().name
    # The scratch directory may not exist on a fresh checkout.
    os.makedirs('temp', exist_ok=True)
    with open(rf'temp/{cur_name}.cpp', 'w', encoding='utf-8') as f:
        f.write(code)
    if platform.uname().system == "Windows":
        # Previously this branch compiled the fixed file temp/one_code.cpp,
        # which is not the file written above, and wrote errors to a file
        # the caller never reads.
        cmd = (f"g++ -w temp/{cur_name}.cpp -o temp/{cur_name}.exe "
               f"2>temp/{cur_name}_compile_err.txt")
    else:  # Linux
        cmd = rf"g++ -w ./temp/{cur_name}.cpp -o ./temp/{cur_name} 2>temp/{cur_name}_compile_err.txt"
    return os.system(cmd)


def run_code():
    """Run the current thread's compiled program with redirected stdin/stdout.

    Input is taken from ``temp/<thread>_in.txt`` and output is written to
    ``temp/<thread>_out.txt``, matching the files check_test_cases() writes
    and reads.

    Returns:
        Whatever ``run_cmd`` returns for the command; callers compare it with
        the RUN_ERROR / UNKNOWN_ERR / TIME_OUT constants.
    """
    cur_name = threading.current_thread().name
    if platform.uname().system == "Windows":
        # Use the same per-thread names as on Linux (the fixed
        # one_code.exe/input.txt/output.txt names matched nothing else here).
        prefix = "temp\\" + cur_name
        cmd = f"{prefix}.exe < {prefix}_in.txt > {prefix}_out.txt"
    else:
        cmd = rf"./temp/{cur_name} < ./temp/{cur_name}_in.txt > ./temp/{cur_name}_out.txt"
    return run_cmd(cmd)


def check_test_cases(test_cases: List[Dict[str, str]]) -> Tuple[List, List]:
    """Run the current thread's compiled program against each test case.

    For every case the input text is written to the thread's input file,
    run_code() is called, and the program output is compared with the
    expected output after stripping trailing whitespace from both.

    Args:
        test_cases: Dicts with an ``'input'`` and an ``'output'`` string.

    Returns:
        ``(messages, passed)``: one message per case (``"pass"``, ``"fail"``,
        ``"timeout"`` or ``"run error or unknown error"``) and a parallel
        list of booleans that are True only for ``"pass"``.
    """
    thread_name = threading.current_thread().name
    input_path = rf"./temp/{thread_name}_in.txt"
    messages, outcomes = [], []

    # Mode 'r+' below requires an existing file, so create an empty one first.
    if not os.path.exists(input_path):
        with open(input_path, 'w', encoding='utf-8'):
            pass

    with open(input_path, 'r+', encoding='utf-8') as stdin_file:
        for case in test_cases:
            # Replace the previous case's input with the current one.
            stdin_file.seek(0)
            stdin_file.truncate()
            stdin_file.write(case['input'])
            stdin_file.flush()

            status = run_code()
            if status == RUN_ERROR or status == UNKNOWN_ERR:  # execution failed
                verdict = "run error or unknown error"
            elif status == TIME_OUT:
                verdict = "timeout"
            else:
                with open(rf"./temp/{thread_name}_out.txt", 'r', encoding='utf-8', errors='ignore') as stdout_file:
                    produced = stdout_file.read()
                    if produced.rstrip() == case['output'].rstrip():
                        verdict = "pass"
                    else:
                        verdict = "fail"

            messages.append(verdict)
            outcomes.append(verdict == "pass")
    return messages, outcomes


def process_result_text(prompt: str, result_text: str):
    """Strip the prompt from generated text, keeping only the program.

    The program is assumed to start at the first ``#include``. When no
    ``#include`` is present, the leading *prompt* is removed if the text
    starts with it; otherwise the text is returned unchanged.

    Args:
        prompt: The prompt that was fed to the model.
        result_text: The decoded model output, possibly prefixed by the prompt.

    Returns:
        The text from the first ``#include`` onwards, or the fallback
        described above.
    """
    begin_idx = result_text.find("#include")
    if begin_idx != -1:
        return result_text[begin_idx:]
    # str.find returned -1: slicing with it would keep only the last
    # character, so fall back to removing the prompt prefix instead.
    if prompt and result_text.startswith(prompt):
        return result_text[len(prompt):]
    return result_text


def generate_codes(index: int, prompt: str, n: int, max_to_generate: int = 512):
    """Sample ``n`` programs for one problem from the loaded model.

    Each sample is produced by ``generate_one_code`` with ``top_p=0.95`` and
    ``temperature=0.2``. The first time a sample reports the over-length
    flag, the problem's pid is appended to ``overlen_problems`` (at most once
    per call).

    Args:
        index: Position of the problem in ``pids``.
        prompt: Prompt text passed to the model.
        n: Number of programs to generate.
        max_to_generate: Forwarded to ``generate_one_code``.

    Returns:
        The list of ``n`` generated programs, in generation order.
    """
    generated = []
    overlen_recorded = False
    for sample_no in range(n):
        print(f"problem{index}-{pids[index]}-generate code {sample_no}\n", end='')
        sample, is_overlen = generate_one_code(
            model,
            tokenizer,
            prompt=prompt,
            max_to_generate=max_to_generate,
            top_p=0.95,
            temperature=0.2,
        )
        generated.append(sample)
        if not overlen_recorded:
            if is_overlen:
                overlen_problems.append(pids[index])
                overlen_recorded = True
    return generated


def check_one_problem(index: int, n, k: List[int]) -> Tuple[List, Dict, Dict, int]:
    """Generate, compile and test ``n`` programs for one problem.

    Args:
        index: Position of the problem in the module-level per-problem lists.
        n: Number of programs to generate for the problem.
        k: The k values forwarded to ``pass_at_k``.

    Returns:
        ``(one_pro_res, hard_pass_at_k, soft_pass_at_k, index)`` where
        ``one_pro_res`` has one dict per generated program (code, per-case
        messages and booleans, pass rate, compile error text),
        ``hard_pass_at_k`` / ``soft_pass_at_k`` map ``"hard_pass@<k>"`` /
        ``"soft_pass@<k>"`` to the values returned by ``pass_at_k``, and
        ``index`` is echoed back so the caller can identify the problem.
        A program counts as "hard" correct when all of its test cases pass
        and as "soft" correct when at least one passes.
    """
    pid = pids[index]
    difficulity = all_difficulity[index]
    # codes = all_codes[index]
    test_cases: List = all_test_cases[index]
    nl = all_nl[index]
    input_format = all_input_format[index]
    output_format = all_output_format[index]
    steps = all_step[index]

    # The first test case doubles as the worked example shown in the prompt;
    # the problem's steps are appended after "Answer:".
    prompt = f'Problem description:\n{nl}\nInput format:\n{input_format}\n' \
        f'Output format:\n{output_format}\nExamples:\n' \
        f'Input>>\n{test_cases[0]["input"]}\nOutput>>\n{test_cases[0]["output"]}\nAnswer:\n'
    prompt += steps

    codes = generate_codes(index=index, prompt=prompt, n=n)
    one_pro_res: List[Dict] = []
    num_hard_correct, num_soft_correct = 0, 0
    for i, code in enumerate(codes):
        print(f"problem{index}-{pid}-code{i}\n", end='')
        compile_res = compile(code)
        compile_err_msg = ""
        if compile_res != 0:
            # A program that does not compile fails every test case.
            cases_msg = ["compile error"] * len(test_cases)
            cases_res = [False] * len(test_cases)

            # Decode leniently, like the program-output read in
            # check_test_cases(): compiler diagnostics may echo bytes that
            # are not valid UTF-8.
            with open(rf"temp/{threading.current_thread().name}_compile_err.txt", 'r',
                      encoding='utf-8', errors='ignore') as f:
                compile_err_msg = f.read()
        else:
            cases_msg, cases_res = check_test_cases(test_cases)
        case_pass_rate = cases_res.count(True) / len(cases_res)
        one_pro_res.append(
            {"pid": pid, "difficulity": difficulity, "code": code, "result": cases_msg,
             "passed": cases_res, "pass_rate": case_pass_rate, "compile_err_msg": compile_err_msg})
        num_hard_correct += (1 if all(cases_res) else 0)
        num_soft_correct += (1 if any(cases_res) else 0)

    hard_pass_values = pass_at_k(n, num_hard_correct, k)
    soft_pass_values = pass_at_k(n, num_soft_correct, k)
    # Use a distinct loop name so the parameter ``k`` is not shadowed.
    hard_pass_at_k: Dict = {f"hard_pass@{k_val}": v
                            for k_val, v in zip(k, hard_pass_values)}
    soft_pass_at_k: Dict = {f"soft_pass@{k_val}": v
                            for k_val, v in zip(k, soft_pass_values)}
    return one_pro_res, hard_pass_at_k, soft_pass_at_k, index


def get_input_file():
    """Pick the problem-set file to evaluate from ``resources_dir``.

    Entries whose name contains ``"my_test"`` are skipped. The remaining
    names are sorted before the first is taken, because ``os.listdir``
    returns entries in arbitrary order and the choice would otherwise vary
    between machines and runs.

    Returns:
        The file name (not the full path) of the selected entry.

    Raises:
        IndexError: If ``resources_dir`` holds no eligible entry.
    """
    input_file_list = sorted(
        path for path in os.listdir(resources_dir) if "my_test" not in path)
    return input_file_list[0]


def get_difficulity_ids(difficulity: str):
    """Return the indices of all loaded problems with the given difficulty.

    Difficulty labels are compared after stripping surrounding whitespace
    from both sides.

    Args:
        difficulity: The difficulty label to look for.

    Returns:
        Indices into ``all_difficulity`` whose label matches, in ascending
        order.
    """
    return [
        position
        for position, label in enumerate(all_difficulity)
        if label.strip() == difficulity.strip()
    ]


def evaluate_functional_correctness(num_samples: int = 20, k: Optional[List[int]] = None, n_workers: int = 4):
    """Evaluate every pending problem and checkpoint results as they finish.

    Problems are loaded with init(), one check_one_problem() task per problem
    is submitted to a thread pool, and after each task completes the three
    result files are rewritten so an interrupted run can be resumed.

    Args:
        num_samples: Number of programs generated per problem.
        k: The k values for pass@k; defaults to ``[1, 5, 10]``.
        n_workers: Number of worker threads.

    Returns:
        ``(already_res, already_hard_pass_at_k, already_soft_pass_at_k)``:
        results loaded from earlier runs extended with this run's results.
    """
    # A None sentinel avoids sharing one mutable default list across calls.
    if k is None:
        k = [1, 5, 10]

    def _dump(path: str, payload) -> None:
        """Write *payload* to *path* as indented UTF-8 JSON."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)

    input_file = get_input_file()
    print(f"input_file: {input_file}")
    already_res, already_hard_pass_at_k, already_soft_pass_at_k = init(
        input_file)
    # *******************
    # global pids
    # pids = pids[3:7]
    # *******************
    print(f"Has {len(pids)} samples to be generated.")

    # Create the output directories up front: otherwise the first checkpoint
    # write fails only after generation has already been done.
    for out_path in (problem_res_path, hard_pass_res_path, soft_pass_res_path):
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    with ThreadPoolExecutor(max_workers=n_workers) as executor:
        futures = [
            executor.submit(check_one_problem, index, num_samples, k)
            for index in range(len(pids))
        ]

        print("Running test suites...")
        for future in tqdm.tqdm(as_completed(futures), total=len(futures)):
            one_pro_res, hard_pass_at_k, soft_pass_at_k, p_index = future.result()
            already_res.extend(one_pro_res)
            already_hard_pass_at_k.append(
                {"pid": pids[p_index], **hard_pass_at_k})
            already_soft_pass_at_k.append(
                {"pid": pids[p_index], **soft_pass_at_k})
            # Checkpoint after every finished problem.
            _dump(problem_res_path, already_res)
            _dump(hard_pass_res_path, already_hard_pass_at_k)
            _dump(soft_pass_res_path, already_soft_pass_at_k)
    return already_res, already_hard_pass_at_k, already_soft_pass_at_k


if __name__ == '__main__':
    # Run the full evaluation, then aggregate the per-problem scores.
    all_res, all_hard_pass_at_k, all_soft_pass_at_k = evaluate_functional_correctness(
        num_samples=20, k=[1, 5, 10], n_workers=4)
    avg_hard_pass_at_k = avg_pass_at_k(all_hard_pass_at_k)
    avg_soft_pass_at_k = avg_pass_at_k(all_soft_pass_at_k)
    avg_pass_rate = avg_test_case_pass_rate(all_res)

    # Write the final state of all three result files.
    final_outputs = (
        (problem_res_path, all_res),
        (hard_pass_res_path, all_hard_pass_at_k),
        (soft_pass_res_path, all_soft_pass_at_k),
    )
    for out_path, payload in final_outputs:
        with open(out_path, 'w', encoding='utf-8') as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)

    # Summary for the log.
    for summary in (avg_hard_pass_at_k, avg_soft_pass_at_k):
        print(summary)
    print(f"avg test case pass rate: {avg_pass_rate}")
    print(
        f"max_length >= 2048 number: {len(overlen_problems)}. The pids are: \n{overlen_problems}")
