from utils import *
import evaluate
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoModel
import os
import json
from typing import *
import tqdm
import numpy as np
import argparse
import math

# nohup python -u cal_score.py --model-type codegen2b --lora --luogu-added --step-rate 0.25 > logs/cal_scores.log &
# nohup python -u cal_score.py --model-type starcoder1b --lora --luogu-added --step-rate 1.0 > logs/cal_scores.log &
# nohup python -u cal_score.py --model-type codegen350m --luogu-added --step-rate 1.0 > logs/cal_scores.log &
# nohup python -u cal_score.py --model-type gpt3.5-instruct --luogu-added --step-rate 0.0 > logs/cal_scores.log &

# Metric backends shared by the scoring code; they stay None until
# init_metrics() assigns them.
bleu_calculator = None
rouge_calculator = None
sen_trans_tokenizer = None
sen_trans_model = None

# Reference solution per problem id: {pid1: code[0], ...}
pid_code_map = {}


def init_metrics(args):
    """Load the metric backends into the module-level globals.

    Reads ``args.bleu_path`` and ``args.rouge_path`` (handed to
    ``evaluate.load``) and ``args.sen_trans_path`` (handed to the
    ``from_pretrained`` loaders of the tokenizer and the model). The loaded
    model is moved with ``.cuda()`` and switched to eval mode.
    """
    global bleu_calculator, rouge_calculator, sen_trans_tokenizer, sen_trans_model

    print("Init metrics...")

    # Text-overlap metrics.
    bleu_calculator = evaluate.load(args.bleu_path)
    rouge_calculator = evaluate.load(args.rouge_path)

    # Tokenizer/model pair that is later passed to get_similarity().
    encoder_path = args.sen_trans_path
    sen_trans_tokenizer = AutoTokenizer.from_pretrained(encoder_path)
    encoder = AutoModel.from_pretrained(encoder_path)
    sen_trans_model = encoder.cuda()
    sen_trans_model.eval()

    print("Finish init metrics!")


def get_step_ground_truth(problem: Dict):
    """Return the reference steps of ``problem`` as one block of text.

    The entries of ``problem["step"]`` are joined with newlines and a single
    newline is appended, so an empty step list yields a lone newline.
    """
    reference_steps = problem["step"]
    joined = "\n".join(reference_steps)
    return f"{joined}\n"


def get_code_ground_truth(problem: Dict):
    """Return the reference code stored for ``problem["pid"]``.

    The lookup goes through the module-level ``pid_code_map``; a pid without
    an entry raises ``KeyError``.
    """
    return pid_code_map[problem["pid"]]


def steps_from_gen_res(problem: Dict, gen_res: str, args):
    """Rebuild the complete step text for one generation.

    The first ``ceil(len(problem["step"]) * args.step_rate)`` reference steps
    are used as a prefix, followed by the part of ``gen_res`` that precedes
    the "Below is the code:" marker (or all of ``gen_res`` when the marker is
    absent).

    Args:
        problem: Problem record; ``problem["step"]`` is the list of reference
            steps.
        gen_res: Raw model output for this problem.
        args: Namespace providing ``step_rate`` (float).

    Returns:
        The reference-step prefix concatenated with the generated steps.
    """
    step_num = len(problem["step"])
    goal_num = math.ceil(step_num * args.step_rate)
    truth_steps = "\n".join(problem["step"][:goal_num])
    if args.step_rate != 0.0:
        truth_steps += '\n'

    # str.find returns -1 when the marker is missing. Index 0 is a real hit
    # (the generation contains no steps at all) and must not be confused with
    # "not found", otherwise the code would be scored as steps.
    end_index = gen_res.find("Below is the code:")
    gen_steps = gen_res[:end_index] if end_index >= 0 else gen_res

    return truth_steps + gen_steps


def gpt35_instruct_steps_from_gen_res(problem: Dict, gen_res: str, args):
    """Extract the numbered steps from one generation and strip the numbering.

    The step section is everything before the first line that starts with
    ``code:`` or ``Code:`` (the whole text when neither occurs). An optional
    heading line mentioning "step"/"Step" is dropped. Each line that starts
    with the next expected number ("1.", "2.", ...) opens a new step; any
    other non-blank line is appended to the current step with a space.

    Args:
        problem: Unused; kept so the signature matches steps_from_gen_res.
        gen_res: Raw model output.
        args: Unused; kept so the signature matches steps_from_gen_res.

    Returns:
        The steps as text in which every step is preceded by a newline, or
        an empty string when no step text was found.
    """
    truth_steps = ""

    # Cut at the earliest code marker. str.find returns -1 for "absent", so a
    # hit at index 0 (output that begins directly with the code section) is
    # still a valid cut position.
    marker_hits = [
        pos
        for pos in (gen_res.find("\ncode:"), gen_res.find("\nCode:"))
        if pos >= 0
    ]
    gen_steps = gen_res[:min(marker_hits)] if marker_hits else gen_res

    gen_step_list = gen_steps.strip().split('\n')

    # Drop a heading such as "Steps:", but never the first numbered step just
    # because its text happens to contain the word "step".
    first_line = gen_step_list[0].strip()
    mentions_step = "step" in first_line or "Step" in first_line
    if mentions_step and not first_line.startswith("1."):
        gen_step_list = gen_step_list[1:]

    pieces = []
    cnt = 1
    for step in gen_step_list:
        striped_step = step.strip()
        if not striped_step:
            # Blank or whitespace-only lines carry no step text.
            continue
        prefix = f"{cnt}."
        if striped_step.startswith(prefix):
            pieces.append('\n' + striped_step[len(prefix):].lstrip())
            cnt += 1
        else:
            # Not the next number: continuation of the current step.
            pieces.append(' ' + striped_step)

    return truth_steps + "".join(pieces)


def code_from_gen_res(gen_res: str):
    """Extract the generated program from one generation.

    The code is taken to start at the first "#include" and run to the end of
    ``gen_res``; Markdown code fences (three backticks) are removed.

    Args:
        gen_res: Raw model output.

    Returns:
        The extracted code, or an empty string when ``gen_res`` contains no
        "#include".
    """
    begin_index = gen_res.find("#include")
    # str.find signals "absent" with -1; index 0 means the output starts
    # directly with the code and must be kept.
    if begin_index < 0:
        return ""
    return gen_res[begin_index:].replace("```", "")


def cal_at_k(args):
    """Score every problem in the result file and write the scores back.

    For each problem in ``experiment_result["allResult"]`` the first
    ``max(args.k_list)`` entries of ``"all_gen_res"`` are scored. The best
    score among the first k generations is stored as ``<metric>@k`` for every
    k in ``args.k_list``, and the mean over all scored generations as
    ``<metric>_avg``. The per-problem ``<metric>@k`` values are then averaged
    into ``<metric>_averageScore@k`` on the top-level result, and the whole
    structure is written back to ``args.save_path``.

    Args:
        args: Namespace providing ``k_list`` (iterable of ints),
            ``test_set_path`` (JSON list of problems with "pid" and "code")
            and ``save_path`` (JSON result file that is read and rewritten).

    Raises:
        IndexError: If a problem has fewer than ``max(args.k_list)``
            generations.
    """
    k_list = args.k_list
    # Fixed order so the keys are emitted in the same order for every problem.
    metric_names = ("bleu", "rouge1", "rouge2", "rougeL", "sim", "codebleu")

    def at_k_list(one_data: Dict) -> Dict:
        """Return the ``<metric>@k`` and ``<metric>_avg`` entries for one problem."""
        step_ground_truth = get_step_ground_truth(one_data)
        code_ground_truth = get_code_ground_truth(one_data)

        best = dict.fromkeys(metric_names, 0.0)
        history = {name: [] for name in metric_names}
        at_k_list_res = {}

        for i in range(max(k_list)):
            gen_res = one_data["all_gen_res"][i]
            # NOTE: the step extractor is fixed to the gpt3.5-instruct output
            # format; use steps_from_gen_res(one_data, gen_res, args) for
            # generations that end their steps with "Below is the code:".
            steps_res = gpt35_instruct_steps_from_gen_res(one_data, gen_res, args)
            code_res = code_from_gen_res(gen_res)

            # Empty extractions score 0 without calling the metric backends.
            scores = dict.fromkeys(metric_names, 0.0)
            if steps_res.strip():
                scores["bleu"] = get_bleu(prediction=steps_res,
                                          reference=step_ground_truth,
                                          bleu=bleu_calculator)
                scores["rouge1"], scores["rouge2"], scores["rougeL"] = get_rouge(
                    prediction=steps_res, reference=step_ground_truth,
                    rouge=rouge_calculator)
                scores["sim"] = get_similarity(prediction=steps_res,
                                               reference=step_ground_truth,
                                               tokenizer=sen_trans_tokenizer,
                                               model=sen_trans_model)
            if code_res.strip():
                scores["codebleu"] = get_code_bleu(prediction=code_res,
                                                   reference=code_ground_truth)

            for name in metric_names:
                history[name].append(scores[name])
                best[name] = max(best[name], scores[name])

            # After i + 1 generations the running maximum is the @k value.
            if i + 1 in k_list:
                for name in metric_names:
                    at_k_list_res[f"{name}@{i+1}"] = best[name]

        # The averages do not depend on k, so they are computed once.
        for name in metric_names:
            # float() keeps the value JSON-serializable whatever NumPy scalar
            # type np.mean returns.
            at_k_list_res[f"{name}_avg"] = float(np.mean(history[name]))
        return at_k_list_res

    global pid_code_map
    with open(args.test_set_path, 'r', encoding='utf-8') as f:
        test_data = json.load(f)
    # Only the first reference solution of each problem is used.
    pid_code_map = {p["pid"]: p["code"][0]["code"] for p in test_data}

    save_path = args.save_path
    print(f"current file: {save_path}")
    with open(save_path, 'r', encoding='utf-8') as f:
        experiment_result = json.load(f)
    all_problrm_res_data = experiment_result["allResult"]
    for i, one_data in enumerate(all_problrm_res_data):
        print(
            f"******************data{i}--{one_data['pid']}*******************")
        one_data.update(at_k_list(one_data))

    for k in k_list:
        for name in metric_names:
            experiment_result[f"{name}_averageScore@{k}"] = float(np.mean(
                [one_data[f"{name}@{k}"] for one_data in all_problrm_res_data]))

    # Serialize before opening the file for writing: opening with 'w'
    # truncates it, so a serialization error must not happen afterwards or the
    # generation results would be lost.
    serialized = json.dumps(experiment_result, indent=2, ensure_ascii=False)
    with open(save_path, 'w', encoding='utf-8') as f:
        f.write(serialized)


def str2bool(value):
    """Parse a command-line boolean such as "true"/"false" or "1"/"0".

    ``type=bool`` is unsuitable for argparse because ``bool("False")`` is
    True: every non-empty string is truthy.

    Args:
        value: The raw argument string (a bool is returned unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognized spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value, got {value!r}")


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description="self-guided-cal-score")

    parser.add_argument('--model-type', default="", required=True, type=str)
    parser.add_argument('--debug', default=False, type=str2bool)
    # nargs='+' with type=int yields a list of ints ("--k-list 1 5 10");
    # type=list would split the argument string into single characters.
    parser.add_argument('--k-list', type=int, nargs='+',
                        default=[1, 5, 10, 15, 20])
    parser.add_argument('--lora', action="store_true")
    parser.add_argument('--step-rate', required=True, type=float)
    parser.add_argument('--luogu-added', action="store_true")
    parser.add_argument(
        '--bleu-path', default='/home/clw/hhk/CodeSecurity/Experiment/evaluate/metrics/bleu', type=str)
    parser.add_argument(
        '--rouge-path', default='/home/clw/hhk/CodeSecurity/Experiment/evaluate/metrics/rouge', type=str)
    parser.add_argument(
        '--sen-trans-path', default='/home/clw/hf_local_models/models--sentence-transformers--all-MiniLM-L6-v2', type=str)

    args = parser.parse_args()
    args.lora_str = 'lora' if args.lora else 'no_lora'
    dataset_str = 'luogu_added' if args.luogu_added else 'only_cf'
    save_dir = rf"result/{dataset_str}/step{args.step_rate}/{args.model_type}/{args.lora_str}"

    args.test_set_path = rf"resources/test_{dataset_str}.json"
    # Path to the result file generated by main.py; cal_at_k reads it and
    # writes the scores back into the same file.
    args.save_path = os.path.join(
        save_dir, f"{args.lora_str}_gen_result_test.json")
    argsdict = vars(args)
    print(argsdict)
    init_metrics(args)
    cal_at_k(args)
