import torch
from peft import get_peft_model, TaskType, LoraConfig, PeftConfig, PeftModel
import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, CodeGenForCausalLM, XGLMForCausalLM
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
import re

import platform
import pickle
from utils import *
from typing import *
import json
import os

# Usage (detached run, stdout redirected to a log file):
#   nohup python -u cal_codebleu.py > logs/cal_codebleu.log &

# Process-wide environment settings, applied as soon as this module is imported.
os.environ.update({
    "CUDA_VISIBLE_DEVICES": '1',      # restrict CUDA to device index 1
    "TOKENIZERS_PARALLELISM": "true",
    "CUDA_LAUNCH_BLOCKING": "1",      # synchronous kernel launches
})

# Per-problem containers declared at module level.
# NOTE(review): none of these are read by the functions in this script —
# confirm they are unused elsewhere before removing them.
pids: List[str] = []
# NOTE(review): an older shape note here described
# [[{"input": ..., "output": ...}, ...], ...], which disagrees with the
# List[List[str]] annotation — confirm which is intended.
all_codes: List[List[str]] = []
all_difficulity: List[str] = []
all_test_cases: List[List[Dict[str, str]]] = []
all_nl: List[str] = []
all_input_format: List[str] = []
all_output_format: List[str] = []
all_step: List[str] = []

# ---- Experiment configuration ----
# k values reported as codebleu@k; max(k_list) samples are scored per problem.
k_list = [1, 5, 10, 15, 20]
model_type = "starcoderbase1B_luogu_added_no-steps"
resources_dir = 'resources'
res_dir = f'result/{model_type}'
res_post_fix = "1000"
# Problems file holding the ground-truth solutions.
input_file_path = f"{resources_dir}/test_luogu_added.json"
# Generated-code input and CodeBLEU report output, both under res_dir.
problem_res_path, codebleu_res_path = (
    os.path.join(res_dir, f"problem_result{res_post_fix}.json"),
    os.path.join(res_dir, f"codebleu_result{res_post_fix}.json"),
)

# Shared mutable state, filled in by init() and cal_at_k().
experiment_result = {"allResult": []}
pid_code_map = {}  # pid -> code of the first ground-truth solution


def init():
    """Load generated samples and ground-truth solutions into module state.

    Reads ``problem_res_path`` (a JSON list of records with at least ``"pid"``
    and ``"code"`` keys) and groups the generated code by problem id, keeping
    file order. It then appends one ``{"pid": ..., "all_gen_res": [...]}``
    entry per problem to ``experiment_result["allResult"]``.

    It then reads ``input_file_path`` and rebuilds ``pid_code_map`` so that
    each pid maps to ``p["code"][0]["code"]``, the first listed solution.

    Raises:
        ValueError: if a problem does not have exactly ``max(k_list)``
            generated samples; ``cal_at_k`` scores that many per problem.
    """
    global pid_code_map  # rebound at the end of this function

    with open(problem_res_path, 'r', encoding='utf-8') as f:
        problem_res_data: List[Dict] = json.load(f)

    # Group generated programs by problem id; dicts keep first-seen pid order.
    pro_res_pid_codes_map = {}
    for pro_res in problem_res_data:
        pro_res_pid_codes_map.setdefault(pro_res["pid"], []).append(pro_res["code"])

    # Tie the expected sample count to k_list instead of a magic number, and
    # use a real exception so the check is not stripped under `python -O`.
    expected_samples = max(k_list)
    for pid, codes in pro_res_pid_codes_map.items():
        if len(codes) != expected_samples:
            raise ValueError(
                f"problem {pid!r} has {len(codes)} generated samples, "
                f"expected {expected_samples}")
        experiment_result["allResult"].append(
            {"pid": pid, "all_gen_res": codes})

    with open(input_file_path, 'r', encoding='utf-8') as f:
        test_data = json.load(f)
    pid_code_map = {p["pid"]: p["code"][0]["code"] for p in test_data}


def get_code_ground_truth(problem: Dict):
    """Return the reference solution for *problem*.

    Args:
        problem: a result entry; only its ``"pid"`` key is read.

    Returns:
        The ground-truth code stored for that pid in ``pid_code_map``
        (built by ``init``).

    Raises:
        KeyError: if no ground truth is registered for the pid.
    """
    # Read-only access to the module-level map, so no `global` is needed.
    problem_id = problem["pid"]
    return pid_code_map[problem_id]


def cal_at_k():
    """Score every loaded problem with CodeBLEU and write the aggregate report.

    For each entry in ``experiment_result["allResult"]``, the first
    ``max(k_list)`` generated programs are compared against the problem's
    ground truth (``get_code_ground_truth``). Each entry is updated in place
    with two kinds of metric:

    - ``codebleu@k``: the best score among the first k samples, for every k
      in ``k_list``.
    - ``codebleu_avg``: the mean score over all scored samples.

    ``experiment_result`` also receives ``codebleu_averageScore@k``, the mean
    of ``codebleu@k`` across problems. The whole structure is then dumped as
    JSON to ``codebleu_res_path``.

    Raises:
        ValueError: if no results are loaded (call ``init()`` first).
    """
    # Explicit import rather than relying on `np` leaking out of `from utils import *`.
    import numpy as np

    if not experiment_result["allResult"]:
        raise ValueError("experiment_result['allResult'] is empty; call init() first")

    k_set = set(k_list)
    n_samples = max(k_list)

    def at_k_list(one_data: Dict) -> Dict:
        """Return the codebleu@k values and codebleu_avg for one problem."""
        code_ground_truth = get_code_ground_truth(one_data)

        max_codebleu = 0.0
        codebleu_list = []
        at_k_list_res = {}
        for i in range(n_samples):
            # Indexed access keeps failing loudly if a problem has too few samples.
            code_res = one_data["all_gen_res"][i]

            # Empty / whitespace-only generations score 0 without calling the metric.
            if not code_res.strip():
                codebleu = 0.0
            else:
                codebleu = get_code_bleu(
                    prediction=code_res, reference=code_ground_truth)
            print(f"gen_res{i} codebleu:{codebleu}")

            codebleu_list.append(codebleu)
            max_codebleu = max(max_codebleu, codebleu)
            # codebleu@k is the running best over the first k samples.
            if i + 1 in k_set:
                at_k_list_res[f"codebleu@{i+1}"] = max_codebleu

        # The mean does not depend on k, so compute it exactly once.
        at_k_list_res["codebleu_avg"] = np.mean(codebleu_list)
        return at_k_list_res

    print(f"current file: {problem_res_path}")
    all_problem_res_data = experiment_result["allResult"]
    for i, one_data in enumerate(all_problem_res_data):
        print(
            f"******************data{i}--{one_data['pid']}*******************")
        one_data.update(at_k_list(one_data))

    for k in k_list:
        experiment_result[f"codebleu_averageScore@{k}"] = np.mean(
            [one_data[f"codebleu@{k}"] for one_data in all_problem_res_data])
    with open(codebleu_res_path, 'w', encoding='utf-8') as f:
        json.dump(experiment_result, f, indent=2, ensure_ascii=False)


if __name__ == "__main__":
    init()      # load generated samples and ground-truth solutions
    cal_at_k()  # score them and write the CodeBLEU report
