from utils import *
import evaluate
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoModel
import os
import json
from typing import *
import tqdm
import numpy as np
import argparse
import math

# Example usage (runs in the background and logs to logs/cal_scores_alpha0.6.log):
# nohup python -u cal_sws_f1_at_dif_alpha.py --model-type gpt3.5 --luogu-added --step-rate 0.0 --sws-alpha 0.6 > logs/cal_scores_alpha0.6.log &

# Metric objects shared across the module; all start as None and are populated by init_metrics().
bleu_calculator = rouge_calculator = sen_trans_tokenizer = sen_trans_model = None


def init_metrics(args):
    """Load the evaluation metrics and the sentence-embedding model into module globals.

    Sets ``bleu_calculator`` / ``rouge_calculator`` via ``evaluate.load`` on
    ``args.bleu_path`` / ``args.rouge_path``, and ``sen_trans_tokenizer`` /
    ``sen_trans_model`` from ``args.sen_trans_path``. The model is moved with
    ``.cuda()`` and switched to eval mode, so a CUDA device is required.
    """
    global bleu_calculator, rouge_calculator, sen_trans_tokenizer, sen_trans_model
    print("Init metrics...")

    bleu_calculator = evaluate.load(args.bleu_path)
    rouge_calculator = evaluate.load(args.rouge_path)

    model_dir = args.sen_trans_path
    sen_trans_tokenizer = AutoTokenizer.from_pretrained(model_dir)
    sen_trans_model = AutoModel.from_pretrained(model_dir).cuda()
    # Inference only: disable dropout and other training-time behavior.
    sen_trans_model.eval()

    print("Finish init metrics!")


def get_step_ground_truth(problem: Dict):
    """Return the reference step text: every ground-truth step on its own line.

    The entries of ``problem["step"]`` are joined with newlines and a trailing
    newline is appended, so an empty step list yields a single newline.
    """
    joined_steps = "\n".join(problem["step"])
    return f"{joined_steps}\n"


def steps_from_gen_res(problem: Dict, gen_res: str, args):
    """Build the predicted step text for one generation.

    The first ``ceil(len(problem["step"]) * args.step_rate)`` ground-truth steps
    are prepended, newline-joined and newline-terminated, to the generated text.
    The generated text is cut just before the first ``"Below is the code:"``
    marker; if the marker is absent, the whole generation is used.

    Args:
        problem: dict with a ``"step"`` list of ground-truth step strings.
        gen_res: raw model output.
        args: namespace providing ``step_rate`` (float).

    Returns:
        The ground-truth prefix steps followed by the generated steps.
    """
    steps = problem["step"]
    goal_num = math.ceil(len(steps) * args.step_rate)
    prefix_steps = steps[:goal_num]
    # Only separate the prefix from the generated part when a prefix actually exists;
    # otherwise a stray leading newline would be added to the prediction.
    truth_steps = "\n".join(prefix_steps) + "\n" if prefix_steps else ""

    # partition() returns everything before the marker, or the whole string if the
    # marker is missing. A marker at position 0 correctly yields no generated steps.
    gen_steps, _, _ = gen_res.partition("Below is the code:")

    return truth_steps + gen_steps


def gpt35_instruct_steps_from_gen_res(problem: Dict, gen_res: str, args):
    """Extract numbered steps from a gpt-3.5-instruct generation.

    The generation is cut at the first newline followed by ``code:`` (or, if
    that is absent, a newline followed by ``Code:``). A first line containing
    ``step``/``Step`` is treated as a header and dropped. A line starting with
    the next expected number (``1.``, ``2.``, ...) begins a new step with the
    number removed. Any other non-blank line is treated as a continuation of
    the current step and appended after a space.

    ``problem`` and ``args`` are unused. They keep the signature
    interchangeable with ``steps_from_gen_res``.

    Returns:
        The concatenated steps, each numbered step preceded by a newline.
    """
    end_index = gen_res.find("\ncode:")
    if end_index < 0:
        end_index = gen_res.find("\nCode:")
    # >= 0 (not > 0): a marker at position 0 means there are no steps at all.
    gen_steps = gen_res[:end_index] if end_index >= 0 else gen_res

    lines = gen_steps.strip().split('\n')
    if "step" in lines[0] or "Step" in lines[0]:
        lines = lines[1:]

    pieces = []
    cnt = 1
    for line in lines:
        stripped = line.strip()
        # Skip blank and whitespace-only lines; they would otherwise add stray spaces.
        if not stripped:
            continue
        number_prefix = f"{cnt}."
        if stripped.startswith(number_prefix):
            pieces.append('\n' + stripped[len(number_prefix):].lstrip())
            cnt += 1
        else:
            pieces.append(' ' + stripped)

    return "".join(pieces)


def cal_at_k(args):
    """Score every generated sample with SWS/F1 and write per-problem and averaged @k results.

    Reads the JSON at ``args.gen_res_path``. The top-level ``"allResult"`` list
    holds per-problem dicts that must provide ``"pid"``, ``"step"`` and
    ``"all_gen_res"``. Each problem dict is updated with:

    * ``sws@k`` / ``f1@k``: best score among the first ``k`` generations, for
      every ``k`` in ``args.k_list``;
    * ``sws_avg`` / ``f1_avg``: mean score over the first ``max(args.k_list)``
      generations.

    The top-level result gets ``sws_averageScore@k`` / ``f1_averageScore@k``
    (the mean of ``sws@k`` / ``f1@k`` across problems). Bulky per-problem input
    fields are dropped, and the result is written to ``args.save_path``.

    Requires ``init_metrics`` to have populated the module-level
    ``sen_trans_tokenizer`` / ``sen_trans_model``.

    Raises:
        IndexError: if a problem has fewer than ``max(args.k_list)`` generations.
    """
    k_list = args.k_list
    k_set = set(k_list)  # O(1) membership test inside the per-sample loop
    n_samples = max(k_list)  # only the first max(k) generations are scored
    extract_steps = (gpt35_instruct_steps_from_gen_res
                     if args.model_type == "gpt3.5-instruct"
                     else steps_from_gen_res)
    # Input-only fields removed before saving to keep the output file small.
    bulky_keys = ("nl", "input_format", "output_format", "step", "all_gen_res")

    def at_k_list(one_data: Dict) -> Dict:
        """Score the first ``n_samples`` generations of one problem and return its @k/avg metrics."""
        step_ground_truth = get_step_ground_truth(one_data)
        max_sws, max_f1 = 0.0, 0.0
        sws_list, f1_list = [], []
        at_k_list_res = {}

        for i in range(n_samples):
            gen_res = one_data["all_gen_res"][i]
            steps_res = extract_steps(one_data, gen_res, args)

            # An empty prediction scores zero without invoking the similarity model.
            if not steps_res.strip():
                sws, f1 = 0.0, 0.0
            else:
                sws, f1 = get_SWS_and_F1(prediction=steps_res, reference=step_ground_truth, threshold=args.sws_alpha,
                                         sim_tokenizer=sen_trans_tokenizer, sim_model=sen_trans_model)
            print(f"gen_res{i} sws:{sws}|f1: {f1}")

            sws_list.append(sws)
            f1_list.append(f1)

            # Running maximum: the @k score is the best among the first k samples.
            max_sws, max_f1 = max(max_sws, sws), max(max_f1, f1)
            if i + 1 in k_set:
                at_k_list_res[f"sws@{i + 1}"] = max_sws
                at_k_list_res[f"f1@{i + 1}"] = max_f1

        at_k_list_res["sws_avg"] = np.mean(sws_list)
        at_k_list_res["f1_avg"] = np.mean(f1_list)
        return at_k_list_res

    gen_res_path = args.gen_res_path
    save_path = args.save_path
    print(f"current file: {save_path}")
    with open(gen_res_path, 'r', encoding='utf-8') as f:
        experiment_result = json.load(f)
    all_problem_res_data = experiment_result["allResult"]

    for i, one_data in enumerate(all_problem_res_data):
        print(f"******************data{i}--{one_data['pid']}*******************")
        one_data.update(at_k_list(one_data))

    for k in k_list:
        experiment_result[f"sws_averageScore@{k}"] = np.mean(
            [one_data[f"sws@{k}"] for one_data in all_problem_res_data])
        experiment_result[f"f1_averageScore@{k}"] = np.mean(
            [one_data[f"f1@{k}"] for one_data in all_problem_res_data])

    # Drop input fields that are no longer needed; tolerate files that never had them.
    for one_data in all_problem_res_data:
        for key in bulky_keys:
            one_data.pop(key, None)

    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(experiment_result, f, indent=2, ensure_ascii=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="self-guided-cal-sws-f1-at-different-alpha")

    supported_model_types = ["codegen350m", "codegen2b", "gpt2", "gpt3.5", "gpt3.5-instruct", "starcoder1b"]
    # Validate with argparse `choices` rather than `assert`, which is stripped under `python -O`.
    parser.add_argument('--model-type', default="", required=True, type=str, choices=supported_model_types)
    # `type=bool` made every non-empty string (including "False") truthy; parse the text instead.
    parser.add_argument('--debug', default=False,
                        type=lambda s: s.strip().lower() in ("1", "true", "yes", "y"))
    # `type=list` split "--k-list 15" into ['1', '5']; accept one or more ints (e.g. --k-list 1 5 10).
    parser.add_argument('--k-list', type=int, nargs='+', default=[1, 5, 10, 15, 20])
    parser.add_argument('--lora', action="store_true")
    parser.add_argument('--step-rate', required=True, type=float)
    parser.add_argument('--luogu-added', action="store_true")
    parser.add_argument('--sws-alpha', required=True, type=float)
    parser.add_argument(
        '--bleu-path', default='/home/clw/hhk/CodeSecurity/Experiment/evaluate/metrics/bleu', type=str)
    parser.add_argument(
        '--rouge-path', default='/home/clw/hhk/CodeSecurity/Experiment/evaluate/metrics/rouge', type=str)
    parser.add_argument(
        '--sen-trans-path', default='/home/clw/hf_local_models/models--sentence-transformers--all-MiniLM-L6-v2', type=str)

    args = parser.parse_args()
    # k = 0 (or negative) would never produce an `sws@k` key and crash later with KeyError.
    if min(args.k_list) < 1:
        parser.error("--k-list values must be positive integers")

    args.lora_str = 'lora' if args.lora else 'no_lora'
    dataset_str = 'luogu_added' if args.luogu_added else 'only_cf'
    save_dir = f"result/{dataset_str}/step{args.step_rate}/{args.model_type}/{args.lora_str}"

    args.gen_res_path = os.path.join(
        save_dir, f"{args.lora_str}_gen_result1000.json")
    # Alternative input (human-processed steps):
    # args.gen_res_path = rf"/home/clw/hhk/LLMCodeGeneration/self-guided600/resources/steps_eval_human_processed.json"

    args.save_path = os.path.join(
        save_dir, f"alpha{args.sws_alpha}_sws_f1_gen_result.json")

    argsdict = vars(args)
    print(argsdict)
    init_metrics(args)
    cal_at_k(args)
