from utils import *
from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2Tokenizer, GPT2LMHeadModel
import os
import json
from typing import *
from concurrent.futures import ThreadPoolExecutor, as_completed
import tqdm
import numpy as np
import argparse
from peft import get_peft_config, get_peft_model, TaskType, LoraConfig
from peft import PeftModel, PeftConfig
import torch
import math
# nohup python -u main.py --model-type codegen2b --model-path Salesforce/codegen-2B-multi --step-rate 1.0 --adapter-path model/2B_luogu_added_self-guided --device cuda:0 --lora --batch-size 7 --luogu-added > logs/2B_step1.0_lora_luogu_added.log &
# nohup python -u main.py --model-type starcoder1b --model-path /home/clw/hhk/LLMCodeGeneration/local_models/starcoderbase-1b --step-rate 0.25 --adapter-path model/starcoderbase1B_luogu_added_self-guided --device cuda:0 --lora --batch-size 20 --luogu-added > logs/starcoder1b_step0.25_lora_luogu_added.log &
# nohup python -u main.py --model-type codegen350m --model-path model/350M_no_lora_luogu_added_self-guided --step-rate 0.0 --luogu-added --device cuda:2 > logs/350M_step0.0_no_lora_luogu_added.log &
# nohup python -u main.py --model-type gpt2 --model-path model/gpt2_finetuned_luogu_added_steps_generate --num-samples 10 --step-rate 0.0 --luogu-added --device cuda:0 > logs/gpt2_finetuned_luogu_added_steps_generate.log &


# Problems loaded from the resource file; reassigned by init_goal_data_and_model().
# Shape: [{ "pid":..., "nl":...}, { ... }]
all_data: List[Dict] = []
# Expose only GPU 0 to this process.
# NOTE(review): with a single visible GPU, a --device value other than
# cuda:0 (e.g. the cuda:2 usage example above) presumably cannot be
# resolved -- confirm before relying on --device.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Both are populated by model_load().
MODEL = None
TOKENIZER = None


def model_load(args):
    """Load the tokenizer and model selected by ``args`` into the module globals.

    Reads ``args.model_type``, ``args.model_path``, ``args.lora``,
    ``args.adapter_path`` and ``args.device``. On return ``TOKENIZER`` and
    ``MODEL`` are set and the model is in eval mode.

    Raises:
        AssertionError: if ``args.model_type`` is not a supported name, or a
            starcoder model is requested without ``args.lora``.
    """
    global MODEL, TOKENIZER
    supported_types = ("codegen350m", "codegen2b", "starcoder1b", "gpt2")
    assert args.model_type in supported_types
    variant = 'lora' if args.lora else 'no_lora'
    print(f"Loading {args.model_type}-{variant}...")

    # GPT-2 uses its dedicated tokenizer class; every other type goes through
    # the Auto* factory.
    tokenizer_cls = GPT2Tokenizer if args.model_type == "gpt2" else AutoTokenizer
    TOKENIZER = tokenizer_cls.from_pretrained(args.model_path)
    # Pad on the left and reuse the EOS token as the padding token.
    TOKENIZER.padding_side = 'left'
    TOKENIZER.pad_token = TOKENIZER.eos_token

    half_precision = torch.float16
    if "codegen" in args.model_type:
        base_model = AutoModelForCausalLM.from_pretrained(
            args.model_path, torch_dtype=half_precision)
        MODEL = base_model.cuda(args.device)
        if args.lora:
            MODEL.load_adapter(args.adapter_path)
    elif "starcoder" in args.model_type:
        # The starcoder path is only implemented for LoRA adapters.
        assert args.lora == True
        peft_config = PeftConfig.from_pretrained(args.adapter_path)
        # The base checkpoint comes from the adapter config, not args.model_path.
        base_model = AutoModelForCausalLM.from_pretrained(
            peft_config.base_model_name_or_path,
            torch_dtype=half_precision,
            # provide this in case you're planning to fine-tune an already fine-tuned checkpoint
            ignore_mismatched_sizes=True,
        )
        # Wrap the base model with the LoRA adapter weights.
        lora_model = PeftModel.from_pretrained(base_model, args.adapter_path)
        MODEL = lora_model.cuda(args.device)
    elif "gpt2" in args.model_type:
        base_model = GPT2LMHeadModel.from_pretrained(
            args.model_path, torch_dtype=half_precision)
        MODEL = base_model.cuda(args.device)

    MODEL.eval()
    # Dump every parameter (name, grad flag and values) to the log.
    for param_name, tensor in MODEL.named_parameters():
        print(
            f'Parameter: {param_name}, Requires grad: {tensor.requires_grad}\n Param: {tensor}')
    print(f"Finish loading {args.model_type}")


def model_batch_gen(model_type, texts, max_to_generate, top_p, temperature, device: str) -> List[str]:
    """Generate one completion per prompt in ``texts`` with the loaded model.

    Picks the generation helper that matches ``model_type`` and calls it with
    the module-level ``TOKENIZER`` and ``MODEL`` plus the sampling settings.
    The helpers are not defined in this file (presumably they come from the
    ``utils`` star import).

    An unrecognised ``model_type`` prints an error and terminates the process
    via ``exit(-1)``.
    """
    if model_type in ("codegen350m", "codegen2b"):
        generate = codegen_generate
    elif model_type == "starcoder1b":
        generate = starcoderbase_generate
    elif model_type == "gpt2":
        generate = gpt2_generate
    else:
        print(f"model type \"{model_type}\" error")
        exit(-1)

    return generate(
        tokenizer=TOKENIZER,
        model=MODEL,
        texts=texts,
        max_to_generate=max_to_generate,
        top_p=top_p,
        temperature=temperature,
        device=device,
    )


def get_prompt(problem: Dict, args):
    """Build the generation prompt for one problem.

    The prompt holds the problem description, the input/output formats and the
    first test case, followed by an ``Answer:`` header. After the header, the
    first ``ceil(len(problem["step"]) * args.step_rate)`` steps are appended,
    one per line. A trailing newline follows the steps whenever
    ``args.step_rate`` is non-zero.
    """
    header = (
        f'Problem description:\n{problem["nl"]}\n'
        f'Input format:\n{problem["input_format"]}\n'
        f'Output format:\n{problem["output_format"]}\n'
        'Examples:\n'
        f'Input>>\n{problem["test_case"][0]["input"]}\n'
        f'Output>>\n{problem["test_case"][0]["output"]}\n'
        'Answer:\n'
    )

    all_steps = problem["step"]
    revealed_count = math.ceil(len(all_steps) * args.step_rate)
    revealed = "\n".join(all_steps[:revealed_count])
    if args.step_rate != 0.0:
        revealed += '\n'
    return header + revealed


def process_one_problem(one_problem: Dict, args):
    """Generate ``args.num_samples`` completions for a single problem.

    ``one_problem`` is a problem record that also carries its original index
    under ``"ids"``. The same prompt is sent ``args.num_samples`` times, in
    batches of ``args.batch_size``.

    Returns a copy of the problem without ``"ids"`` and ``"test_case"``, with
    every generated text collected under ``"all_gen_res"``.
    """
    origin_index = one_problem["ids"]
    # Copy everything except the bookkeeping index into the result record.
    result = {key: value for key, value in one_problem.items() if key != "ids"}

    sample_count = args.num_samples
    prompt = get_prompt(one_problem, args)
    prompts = [prompt] * sample_count
    batch_size = args.batch_size

    generated = []
    result["all_gen_res"] = generated
    for start in range(0, sample_count, batch_size):
        one_batch = prompts[start:start + batch_size]
        print(
            f"data{origin_index}-{one_problem['pid']} genetate batch {start}-{start+len(one_batch)-1}")
        generated.extend(
            model_batch_gen(model_type=args.model_type, texts=one_batch,
                            max_to_generate=args.max_to_generate,
                            top_p=0.95, temperature=0.8, device=args.device))

    # The test case was only needed to build the prompt.
    result.pop("test_case")
    return result


def run_for_goal_data(args, goal_data: List[Dict], n_workers):
    """Generate results for every problem in ``goal_data`` and checkpoint them.

    Results already stored at ``args.save_path`` are loaded and extended;
    otherwise a fresh ``{"allResult": []}`` container is started. Each problem
    is submitted to a thread pool of ``n_workers`` workers running
    ``process_one_problem``. After each problem finishes, the full result set
    is written back to ``args.save_path``.

    The checkpoint is written to a temporary sibling file and then moved over
    ``args.save_path`` with ``os.replace``, so an interruption during the
    write cannot leave a truncated JSON file behind.

    Any exception raised by a worker propagates out of this function.
    """
    save_path = args.save_path
    if os.path.exists(save_path):
        with open(save_path, 'r', encoding='utf-8') as f:
            experiment_result = json.load(f)
    else:
        experiment_result = {"allResult": []}

    tmp_path = f"{save_path}.tmp"
    with ThreadPoolExecutor(max_workers=n_workers) as executor:
        futures = [
            executor.submit(process_one_problem, data, args)
            for data in goal_data
        ]

        print("Running test suites...")
        for future in tqdm.tqdm(as_completed(futures), total=len(futures)):
            experiment_result["allResult"].append(future.result())

            # Checkpoint after every finished problem so a later run can resume.
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(experiment_result, f, indent=2, ensure_ascii=False)
            os.replace(tmp_path, save_path)


def init_goal_data_and_model(args):
    """Load the problems that still need generation, then load the model.

    Reads every problem from ``args.resource_path`` into the module-level
    ``all_data``. If a checkpoint exists at ``args.save_path``, problems whose
    ``pid`` already appears in its ``"allResult"`` list are skipped. Each
    remaining problem is copied with its index in ``all_data`` stored under
    ``"ids"``, stripped of keys that generation does not use, and reduced to
    its first test case.

    Returns the list of problems to generate for.
    """
    global all_data

    with open(args.resource_path, 'r', encoding='utf-8') as resource_file:
        all_data = json.load(resource_file)

    if os.path.exists(args.save_path):
        with open(args.save_path, 'r', encoding='utf-8') as checkpoint_file:
            checkpoint = json.load(checkpoint_file)
        finished_pids = {entry["pid"] for entry in checkpoint["allResult"]}
        goal_data: List[Dict] = [
            {**problem, "ids": idx}
            for idx, problem in enumerate(all_data)
            if problem["pid"] not in finished_pids
        ]
    else:
        goal_data = [{**problem, "ids": idx}
                     for idx, problem in enumerate(all_data)]

    # Drop the keys that generation never reads.
    unused_keys = ("title", "difficulty", "nl_cn", "summarization", "code")
    for problem in goal_data:
        for key in unused_keys:
            del problem[key]
        # Only the first test case is needed.
        problem["test_case"] = [problem["test_case"][0]]

    print("problem to be generate:")
    for problem in goal_data:
        print(f"{problem['ids']}--{problem['pid']}")

    model_load(args)
    return goal_data


def parse_bool_flag(value):
    """Convert a command-line spelling of a boolean into a real ``bool``.

    ``argparse`` with ``type=bool`` turns every non-empty string, including
    "False", into True; this converter interprets the text instead. Matching
    is case-insensitive. The empty string maps to False, which is what
    ``bool("")`` gave before.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    if text in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value, got {value!r}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="self-guided600")

    parser.add_argument('--model-type', required=True, type=str)
    parser.add_argument('--model-path', required=True, type=str)
    parser.add_argument('--n-workers', default=1, type=int)
    parser.add_argument('--lora', action="store_true")
    # The number of times to generate results for each problem
    parser.add_argument('--num-samples', default=20, type=int)
    # Accepts "--debug", "--debug True" and "--debug False".
    parser.add_argument('--debug', default=False, type=parse_bool_flag,
                        nargs='?', const=True)
    parser.add_argument('--adapter-path', type=str)
    parser.add_argument('--step-rate', required=True, type=float)
    parser.add_argument('--device', default='cuda:0', type=str)
    parser.add_argument('--luogu-added', action="store_true")
    parser.add_argument('--batch-size', default=10, type=int)

    args = parser.parse_args()

    # Validate explicitly: assert statements are removed under "python -O".
    allowed_step_rates = (0.0, 0.25, 0.5, 0.75, 1.0)
    if args.step_rate not in allowed_step_rates:
        parser.error(f"--step-rate must be one of {allowed_step_rates}")
    if args.batch_size < 1:
        parser.error("--batch-size must be at least 1")
    if args.batch_size > args.num_samples:
        parser.error("--batch-size must not exceed --num-samples")

    # Generation budget depends on the share of steps already in the prompt;
    # gpt2 always uses its own smaller budget.
    if args.model_type == "gpt2":
        args.max_to_generate = 256
    elif args.step_rate <= 0.5:
        args.max_to_generate = 832
    else:
        args.max_to_generate = 768

    args.resource_path = "resources/my_test.json" if args.debug else "resources/test_luogu_added.json"

    lora_str = 'lora' if args.lora else 'no_lora'
    dataset_str = 'luogu_added' if args.luogu_added else 'only_cf'
    save_dir = f"result/{dataset_str}/step{args.step_rate}/{args.model_type}/{lora_str}"
    os.makedirs(save_dir, exist_ok=True)
    args.save_path = os.path.join(
        save_dir, f"finetuned_{lora_str}_gen_result1000.json")

    print("*****args*****")
    for arg in vars(args):
        print(f'{arg}: {getattr(args, arg)}')

    goal_data = init_goal_data_and_model(args)
    run_for_goal_data(args, goal_data, n_workers=args.n_workers)
