import argparse
import pprint
import sys
import os
import re
from tqdm import tqdm
import torch
from vllm import LLM
from vllm import SamplingParams
import json
from math import ceil


# Root directory under which prompts.json, datasets/ and the output directory
# are looked up. Paths are built by plain string concatenation, so the
# trailing slash is required.
# Example of an absolute root (kept for reference):
# PATH = "/home/work/aigc-user-workspace/luoxianzhen/MultiPoT/"
PATH = "./"


# Select the torch backend name: CUDA when available, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The MPS backend, when it reports itself available, overrides the choice above.
# torch builds without `torch.backends.mps` raise AttributeError here; only
# that case is ignored, so Ctrl-C and genuine errors still propagate instead
# of being swallowed by a bare `except`.
try:
    if torch.backends.mps.is_available():
        device = "mps"
except AttributeError:
    pass


def read_problems(input_path, output_path):
    """Load the problems that still need generation, plus any saved results.

    Args:
        input_path: Path to a JSON-Lines file. Every non-blank line must be a
            JSON object containing an ``"input"`` key.
        output_path: Path to a JSON file holding a list of already finished
            records (each with an ``"input"`` key). The file may not exist.

    Returns:
        A ``(inputs, outputs)`` tuple. ``outputs`` is the list loaded from
        ``output_path`` (empty when that file is absent). ``inputs`` contains
        the records of ``input_path`` whose ``"input"`` does not already
        appear in ``outputs``; each one gets fresh, empty ``"code"`` and
        ``"loss"`` lists.

    Raises:
        json.JSONDecodeError: If a non-blank input line or the output file is
            not valid JSON.
        KeyError: If a record has no ``"input"`` key.
    """
    outputs = []
    if os.path.exists(output_path):
        with open(output_path, "r", encoding="utf-8") as f:
            outputs = json.load(f)
    # A set keeps the "already generated?" check O(1) per problem.
    # Assumes "input" values are hashable (strings in practice).
    already = {d["input"] for d in outputs}

    inputs = []
    with open(input_path, "r", encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            d = json.loads(line)
            if d["input"] in already:
                continue
            d["code"] = []
            d["loss"] = []
            inputs.append(d)
    return inputs, outputs


def encode_prompt(batch, lang="Python", dataset="penguin"):
    """Build one prompt string per problem in ``batch``.

    Each prompt is the prefix stored in ``prompts.json`` under
    ``[dataset][lang]``, followed by the problem's question and an instruction
    to solve it with a ``lang`` program. The first three prompts are printed
    for inspection.

    Args:
        batch: Iterable of problem dicts, each with an ``"input"`` key.
        lang: Key of the programming language inside the dataset's prompts.
        dataset: Top-level key in ``prompts.json``.

    Returns:
        A list of prompt strings, in the same order as ``batch``.

    Raises:
        KeyError: If ``dataset`` or ``lang`` is missing from ``prompts.json``.
    """
    with open(PATH + "prompts.json", "r", encoding="utf-8") as f:
        prompts = json.load(f)
    # Prefix and instruction depend only on (dataset, lang): build them once.
    prefix = prompts[dataset][lang]
    instruction = (
        f"Use the {lang} program to solve the problem. "
        "The reasoning progress is clearly shown in the program.\n\n"
    )

    prompt_batch = [
        prefix + f"Question: {item['input']}\n\n" + instruction
        for item in batch
    ]
    for i, p in enumerate(prompt_batch[:3]):
        print(f"Prompt Example {i}:\n{p}\n\n")
    return prompt_batch


def truncate(code):
    """Cut a generated completion at its first terminator.

    Everything from the first triple-backtick fence onward is dropped; what
    remains is then cut at the first run of three consecutive newlines. Text
    containing neither marker is returned unchanged.
    """
    before_fence, _, _ = code.partition("```")
    head, _, _ = before_fence.partition("\n\n\n")
    return head


def _save_outputs(outputs, output_path):
    """Write ``outputs`` as JSON to ``output_path`` via a temp file and rename."""
    tmp_path = output_path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(outputs, f, indent=4)
    # Swap the finished file into place so an interrupted run never leaves a
    # half-written checkpoint that the next resume would fail to parse.
    os.replace(tmp_path, output_path)


def generate(dataset, lang, llm, sampling_params, args):
    """Generate programs for every pending problem of ``dataset`` in ``lang``.

    Problems are read from ``PATH/datasets/<dataset>.jsonl``. Results are
    accumulated in
    ``PATH/<args.middle_dir>/<dataset>_<lang><args.output_suffix>.json``,
    which is rewritten after every batch so a later run can resume: problems
    already present in that file are skipped by ``read_problems``.

    Args:
        dataset: Dataset name (file stem under ``datasets/`` and prompt key).
        lang: Target programming language (prompt key, part of the file name).
        llm: Object exposing ``generate(prompts, sampling_params)`` whose
            results carry an ``outputs`` list of items with ``text`` and
            ``cumulative_logprob`` attributes (a vLLM ``LLM`` in this script).
        sampling_params: Passed through unchanged to ``llm.generate``.
        args: Namespace providing ``middle_dir``, ``output_suffix``,
            ``batch_size``, ``decoding_style``, ``N`` and
            ``num_seqs_per_iter``.

    Returns:
        None. Each processed record gains ``"code"`` (truncated completions)
        and ``"loss"`` (the matching ``cumulative_logprob`` values).
    """
    input_path = PATH + "datasets/" + dataset + ".jsonl"
    print("Input Path:", input_path)
    output_path = PATH + f"{args.middle_dir}/{dataset}_{lang}{args.output_suffix}.json"

    ds, outputs = read_problems(input_path, output_path)
    print("Number of samples in {} need to generate: {}".format(dataset, len(ds)))

    # Create the output directory up front so a missing folder fails (or is
    # fixed) before any generation time is spent.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Sampling repeats the call until N completions per problem are collected;
    # any other decoding style issues a single call per batch.
    if args.decoding_style == 'sampling':
        loops = ceil(args.N / args.num_seqs_per_iter)
    else:
        loops = 1

    for start in tqdm(range(0, len(ds), args.batch_size)):
        batch = ds[start : start + args.batch_size]
        prompts = encode_prompt(batch, lang=lang, dataset=dataset)

        for _ in range(loops):
            with torch.no_grad():
                print("Start generating...")
                completions = llm.generate(prompts, sampling_params)
            for d, completion in zip(batch, completions):
                for seq in completion.outputs[: args.num_seqs_per_iter]:
                    d["code"].append(truncate(seq.text.strip()))
                    d["loss"].append(seq.cumulative_logprob)
                    # Stop once this problem has exactly N completions; the
                    # last sampling round may produce more than are needed.
                    if len(d["code"]) == args.N:
                        break

        outputs.extend(batch)
        # Checkpoint after every batch so progress survives an interruption.
        _save_outputs(outputs, output_path)

    for i, d in enumerate(outputs[:3]):
        print(f"Output Example {i}:\n{d}\n\n")


# Script entry point: parse options, load the model once, then generate for
# every (dataset, language) combination.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Model path or name handed to vLLM.
    parser.add_argument('--model', type=str, default='/home/storages/gpu0193/disk3/xzluo/models/CodeLlama-34b-hf', help="")
    # Comma-separated dataset names and languages.
    parser.add_argument("--datasets", type=str, default="penguins")
    parser.add_argument('--lang', type=str, default="Python,R,Java,Javascript,C++")
    parser.add_argument('--temperature', type=float, default=0, help="")
    # Number of completions to keep per problem.
    parser.add_argument('--N', type=int, default=1, help="")
    # Maximum number of generated tokens per completion (SamplingParams.max_tokens).
    parser.add_argument('--max_len', type=int, default=2048, help="")
    # Used as vLLM's tensor_parallel_size.
    parser.add_argument('--num_gpus', type=int, default=8, help="")
    # 'sampling' repeats generation until N completions exist; any other
    # value performs a single generation call per batch.
    parser.add_argument('--decoding_style', type=str, default='greedy', help="")
    # Completions requested per prompt in one call (SamplingParams.n).
    parser.add_argument('--num_seqs_per_iter', type=int, default=1, help='')
    # Number of problems sent to the engine, and checkpointed, at a time.
    parser.add_argument('--batch_size', type=int, default=5000, help='')
    # Appended to the output file stem, before ".json".
    parser.add_argument('--output_suffix', type=str, default="", help='')
    parser.add_argument('--top_p', type=float, default=1, help="")
    # Directory (relative to PATH) that receives the output files.
    parser.add_argument('--middle_dir', type=str, default='outputs')
    args = parser.parse_args()

    argsdict = vars(args)
    print(pprint.pformat(argsdict))

    llm = LLM(model=args.model, tensor_parallel_size=args.num_gpus, trust_remote_code=True, max_model_len=8192)
    # The stop strings include the markers that truncate() cuts on, plus the
    # end-of-sequence token and the start of a following "Question:" block.
    sampling_params = SamplingParams(n=args.num_seqs_per_iter, temperature=args.temperature, top_p=args.top_p, max_tokens=args.max_len, stop=['</s>', '\n\n\n', '```', "\n\nQuestion:"])

    # Strip whitespace around each item and drop empty ones, so input such as
    # "Python, R" or a trailing comma does not produce names like " R" or ""
    # that would break the prompt lookup and the output file name.
    datasets = [name.strip() for name in args.datasets.split(",") if name.strip()]
    langs = [name.strip() for name in args.lang.split(",") if name.strip()]
    print(datasets)
    print(langs)
    for dataset in datasets:
        for lang in langs:
            print(f"start {dataset} {lang} {args.decoding_style} {args.output_suffix}")
            generate(dataset, lang, llm, sampling_params, args)
            print(f"{dataset} {lang} is done ~")
