import json
import os
import random

import tqdm

from Scripts.Prompting import *
from Scripts.Utils import *


def _resample_completeness(val, n_values):
    """Return the completeness ratio for one result record, resampling if needed.

    ``val["full_resp"]`` is a list of per-response dicts followed by one
    trailing summary dict. When the number of responses already equals
    ``n_values`` the stored ``completeness_full`` from that summary is reused.
    Otherwise up to ``n_values`` responses are drawn at random, the ratio is
    recomputed, and ``val["full_resp"]`` is replaced in place by the sampled
    responses plus a fresh summary dict.

    Args:
        val: Result record holding a ``"full_resp"`` list; mutated in place
            when resampling happens.
        n_values: Target number of responses per benchmark.

    Returns:
        The stored or recomputed ``completeness_full`` value.
    """
    full_resp = val["full_resp"]
    if len(full_resp) - 1 == n_values:
        return full_resp[-1]["completeness_full"]

    # Exclude the trailing summary entry; clamp at zero so an empty list does
    # not hand random.sample a negative sample size.
    n_responses = max(len(full_resp) - 1, 0)
    picked = random.sample(range(n_responses), min(n_values, n_responses))
    sampled_resp = [full_resp[i] for i in picked]

    # Responses without an "execution" entry contribute nothing to the count.
    passed = sum(
        int(resp["execution"]) for resp in sampled_resp if "execution" in resp
    )
    # The denominator is the requested n, not the number actually sampled.
    completeness_full = passed / n_values

    sampled_resp.append(
        {
            "completeness_full": completeness_full,
            "error_cases": n_values - (len(sampled_resp)),
        }
    )
    val["full_resp"] = sampled_resp
    return completeness_full


def checker_config(raw_args=None, iter=0, base_folder=None) -> None:
    """Re-derive completeness labels for a new n/threshold configuration.

    Loads the completeness-checking results found under ``base_folder``, keeps
    only the entries whose ``(task_id, subtask_id)`` also appear in this
    configuration's dataset file, recomputes each entry's ``completeness``
    flag from ``args.n_values`` and ``args.threshold``, and writes the result
    to this configuration's ``completeness_checking`` folder.

    Args:
        raw_args: Forwarded unchanged to ``get_args``.
        iter: Iteration index used in the JSON file names. (The name shadows
            the builtin but is kept for caller compatibility.)
        base_folder: Folder containing the ``completeness_checking`` results
            to start from. When ``None`` the function does nothing.
    """
    args = get_args(raw_args)
    if base_folder is None:
        return

    file_name = f"{args.dataset}_ambiguous_{iter}.json"
    config_name = (
        f"config-{args.sample}sample_{args.n_values}n_"
        f"{args.threshold}threshold_{args.temp}temp"
    )
    config_root = os.path.join(
        "Responses", args.llm, args.info_prompt, args.dataset, args.prompt, config_name
    )
    input_path = os.path.join(base_folder, "completeness_checking", file_name)
    reference_path = os.path.join(config_root, "dataset", file_name)
    output_path = os.path.join(config_root, "completeness_checking", file_name)

    with open(input_path, "r", encoding="utf-8") as f:
        benchmarks = json.load(f)
    with open(reference_path, "r", encoding="utf-8") as f:
        default = json.load(f)

    # Position and key of the result record inside benchmark["completeness"].
    if args.prompt == "cot":
        slot, code_key = 1, "system_code_cot"
    else:
        slot, code_key = 0, "system_code"

    # Keep only benchmarks that also exist in this configuration's dataset.
    # NOTE(review): assumes task_id / subtask_id are hashable (str or int).
    wanted = {(d["task_id"], d["subtask_id"]) for d in default}
    benchmarks = [
        benchmark
        for benchmark in tqdm.tqdm(benchmarks)
        if (benchmark["task_id"], benchmark["subtask_id"]) in wanted
    ]

    # Update the completeness flag based on the new n and threshold values.
    for benchmark in tqdm.tqdm(benchmarks):
        val = benchmark["completeness"][slot][code_key]
        completeness_full = _resample_completeness(val, args.n_values)
        # val is the record stored inside benchmark, so this updates it in place.
        val["completeness"] = 1 if completeness_full >= args.threshold else 0

    # Printed for manual inspection: the two sizes differ when some dataset
    # entries have no matching benchmark.
    print(f"!!! {len(benchmarks)} and {len(default)} !!!")

    # The configuration's completeness_checking folder may not exist yet.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(benchmarks, f)


if __name__ == "__main__":
    # Script entry point. raw_args is left at its default of None, so argument
    # handling is entirely up to get_args.
    # NOTE(review): base_folder is also left at its default of None, in which
    # case checker_config returns right after parsing arguments — confirm that
    # this is the intended behaviour when running the file directly.
    # For post-mortem debugging, wrap the call in try/except Exception and
    # call pdb.post_mortem() in the handler.
    checker_config()
