import tqdm

from Scripts.Prompting import *
from Scripts.Utils import *


def check_complete(data, args):
    """Return the completeness flag recorded for the active prompt style.

    Entries of ``data["completeness"]`` are searched for the key
    ``"system_code"`` when ``args.prompt == "basic"`` and for
    ``"system_code_cot"`` for any other prompt. The ``"completeness"`` value
    of the LAST matching entry is returned; when nothing matches, ``1`` is
    returned, so callers that skip on a falsy result keep the sample.
    """
    wanted = "system_code" if args.prompt == "basic" else "system_code_cot"
    flags = [
        record[wanted]["completeness"]
        for record in data["completeness"]
        if wanted in record.keys()
    ]
    return flags[-1] if flags else 1


def temp_variation(df, output_file, cache_file, iteration, extra_file, args):
    """Generate (or reuse cached) masked-info text rewrites for each sample.

    For every sample in ``df`` that passes :func:`check_complete`, the JSON
    list stored at ``cache_file`` is searched for an entry with the same
    ``"info"`` (and, when the cached entry has one, the same ``"maskinfo"``);
    the last such entry is reused. Otherwise ``masking_info`` produces masked
    variants of the sample's info, the LLM rewrites the original text for each
    variant, and the new entry (with a ``"text_options"`` list of
    ``[mask_info, mask_key, response]``) is appended to the cache, which is
    re-saved immediately. The accumulated results are rewritten to
    ``output_file`` after every sample (presumably as a checkpoint).

    Args:
        df: list of sample dicts. Samples carrying a ``"completeness"`` key
            are mutated: that key is deleted once the sample passes.
        output_file: path the list of selected/generated entries is dumped to.
        cache_file: path to an existing JSON list of cached entries; it is
            rewritten whenever a new entry is generated.
        iteration: forwarded to ``masking_info`` as ``max_count``.
        extra_file: currently unused (the struct-check export is disabled).
        args: parsed CLI arguments; ``args.llm`` and ``args.prompt`` are read.
    """
    with open(cache_file, "r") as f:
        c_file = json.load(f)
    gen1 = []

    # Only gpt3.5-turbo is used as-is; every other model (gpt4-turbo,
    # gpt4-3.5-turbo, mistral, phi3, ...) falls back to gpt4-turbo for the
    # text-rewriting step.
    llm_used = args.llm if args.llm == "gpt3.5-turbo" else "gpt4-turbo"

    for data in tqdm.tqdm(df):
        if "completeness" in data:
            if not check_complete(data, args):
                continue
            del data["completeness"]

        # Cache lookup: same info, and same maskinfo when the cached entry
        # records one. The sample itself may have no "maskinfo" (see below),
        # so check for it instead of indexing blindly (previously a KeyError).
        new_d = None
        for c in c_file:
            if c["info"] != data["info"]:
                continue
            if "maskinfo" not in c or (
                "maskinfo" in data and c["maskinfo"] == data["maskinfo"]
            ):
                new_d = c  # keep scanning: the last match wins

        if new_d is not None:
            gen1.append(new_d)
        else:
            # Not cached: build every masked variant, regenerate its text,
            # then persist the new entry to the cache.
            info_orig = data["info"]
            text = data.get("orig_text", data["text"])
            info = data.get("maskinfo", info_orig)
            maskNL = []
            for mask_info in masking_info(info, max_count=iteration):
                mask_key = masking_key(mask_info)
                if mask_key == []:
                    continue  # presumably nothing was masked in this variant
                prompt = (
                    sys_prompt_omit.replace("{{originalNL}}", text)
                    .replace("{{info}}", str(info_orig))
                    .replace("{{maskinfo}}", str(mask_info))
                )
                resp = get_llm_response(llm_used, prompt)[0]
                # Cut the response at each stop marker listed in `st`.
                for s in st:
                    resp = resp.split(s)[0].strip()
                maskNL.append([mask_info, mask_key, resp])

            new_d = data.copy()
            new_d["text_options"] = maskNL
            if maskNL:
                c_file.append(new_d)
                with open(cache_file, "w") as f:
                    json.dump(c_file, f)
                gen1.append(new_d)

        if not gen1:
            continue
        # NOTE: the struct-check export to `extra_file` is disabled; only the
        # candidate list is written here.
        with open(output_file, "w") as f:
            json.dump(gen1, f)
    return


def load_basefile(df, output_file, base_file, args):
    """Collect, for each complete sample, its matching entry from a base file.

    ``base_file[0]`` is a JSON list of previously produced entries. For every
    sample that passes :func:`check_complete`, the last base entry with the
    same ``"info"`` (and the same ``"maskinfo"`` when the base entry records
    one) is appended to the result. The result list is written to
    ``output_file`` if at least one sample was processed.

    Args:
        df: list of sample dicts.
        output_file: path the selected entries are dumped to as JSON.
        base_file: list whose first element is the path of the base JSON file.
        args: parsed CLI arguments forwarded to :func:`check_complete`.

    Raises:
        ValueError: if a sample needs the base file but ``base_file`` is empty
            or ``None``.
    """
    gen1 = []
    # Loaded lazily, exactly once: previously it was re-read (and its handle
    # leaked) for every sample.
    b_file = None
    for data in tqdm.tqdm(df):
        if "completeness" in data and not check_complete(data, args):
            continue

        if b_file is None:
            if not base_file:
                raise ValueError(
                    "load_basefile needs base_file: a list whose first element "
                    "is the path of the base JSON file"
                )
            with open(base_file[0], "r") as f:
                b_file = json.load(f)

        new_d = None
        for b in b_file:
            if b["info"] != data["info"]:
                continue
            # The sample may lack "maskinfo"; check rather than raise KeyError.
            if "maskinfo" not in b or (
                "maskinfo" in data and b["maskinfo"] == data["maskinfo"]
            ):
                new_d = b  # keep scanning: the last match wins
        if new_d is not None:
            gen1.append(new_d)

    # Written once at the end (same final content as the former per-sample
    # rewrite); skipped when no sample was processed, as before.
    if b_file is not None:
        with open(output_file, "w") as f:
            json.dump(gen1, f)
    return


def min_spec(
    df, output_file, cache_file, iteration, extra_file, args, base_file=None, baseline=0
):
    """Route to the generation path or the baseline path for min-spec candidates.

    A truthy ``baseline`` copies matching entries out of ``base_file`` via
    ``load_basefile``. Otherwise variants are generated (or reused from
    ``cache_file``) via ``temp_variation``.
    """
    if baseline:
        load_basefile(df, output_file, base_file, args)
        return
    temp_variation(df, output_file, cache_file, iteration, extra_file, args)


def mask(raw_args=None, iteration=0, base_folder=None, baseline=0):
    """Resolve the file paths for one masking iteration and run :func:`min_spec`.

    Reads ``<b_folder>/completeness_checking/<dataset>_ambiguous_<iteration>.json``
    and targets ``<b_folder>/min_spec_candidates/<dataset>_ambiguous_<iteration+1>.json``.
    ``b_folder`` is the ``all_temp_variation`` tree normally, or the
    dataset/prompt tree when ``baseline`` is truthy. Returns early (printing a
    message) if the required cache file is missing or the output already
    exists.

    Args:
        raw_args: forwarded to ``get_args``.
        iteration: index of the input file; outputs use ``iteration + 1``.
        base_folder: optional folder holding precomputed results to reuse.
        baseline: truthy to copy from ``base_folder`` instead of generating.
    """
    args = get_args(raw_args)
    folder = f"config-{args.sample}sample_{args.n_values}n_{args.threshold}threshold_{args.temp}temp"
    if not baseline:
        b_folder = os.path.join(
            "Responses", args.llm, args.info_prompt, "all_temp_variation", folder
        )
    else:
        b_folder = os.path.join(
            "Responses", args.llm, args.info_prompt, args.dataset, args.prompt, folder
        )
    next_name = f"{args.dataset}_ambiguous_{iteration+1}.json"
    input_file = os.path.join(
        b_folder, "completeness_checking", f"{args.dataset}_ambiguous_{iteration}.json"
    )
    output_file = os.path.join(b_folder, "min_spec_candidates", next_name)

    extra_file = cache_file = None
    if not baseline:
        extra_file = os.path.join(b_folder, "struct_check", next_name)
        # Portable join; the former literal r"Responses\cache" only worked on
        # Windows (on POSIX it names one entry containing a backslash).
        cache_file = os.path.join("Responses", "cache", next_name)
        if not os.path.exists(cache_file):
            print("cache file not found")
            return
    if os.path.exists(output_file):
        print("already exists")
        return
    with open(input_file, "r") as f:
        df = json.load(f)

    base_file = None
    if base_folder:
        base_file = [os.path.join(base_folder, "min_spec_candidates", next_name)]
        if not baseline:
            base_file.append(os.path.join(base_folder, "struct_check", next_name))

    # The writers open output_file directly, so ensure its folder exists
    # before any (potentially expensive) generation work starts.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    min_spec(
        df,
        output_file,
        cache_file,
        iteration + 1,
        extra_file,
        args,
        base_file,
        baseline,
    )


if __name__ == "__main__":
    # Single masking pass at iteration 0. raw_args defaults to None, so
    # get_args uses its default argument source (presumably sys.argv).
    # For post-mortem debugging, wrap this call in try/except and call
    # pdb.post_mortem() in the handler.
    mask()
