import json
import os

import tqdm

from Scripts.Prompting import *
from Scripts.Utils import *


def filter_contradictory(assertions, info):
    """Drop conversion assertions that clash with already-known I/O types.

    A type field of ``info`` counts as "unset" when it equals ``[]`` or
    ``["MASK"]``.  When ``info["output_type"]`` is set, every key containing
    both ``'To'`` and ``'op_resp'`` is removed; when ``info["input_type"]`` is
    set, every key containing both ``'To'`` and ``'ip_resp'`` is removed.
    Keys containing ``'CategoryToBool'`` are always kept.

    Args:
        assertions: Mapping of assertion name -> assertion payload.
        info: Mapping providing the ``"output_type"`` and ``"input_type"``
            entries.

    Returns:
        ``assertions`` itself (same object) when both types are unset,
        otherwise a new dict with the surviving entries in original order.
    """
    unset = ([], ["MASK"])

    # Collect the response tags whose conversion assertions must go.
    blocked = []
    if info["output_type"] not in unset:
        blocked.append("op_resp")
    if info["input_type"] not in unset:
        blocked.append("ip_resp")

    # Nothing is pinned down: hand back the very same mapping, untouched.
    if not blocked:
        return assertions

    def keep(name):
        if "CategoryToBool" in name:
            return True
        if "To" not in name:
            return True
        return not any(tag in name for tag in blocked)

    return {name: check for name, check in assertions.items() if keep(name)}


def extract_code_cot(text):
    """Return the contents of the first ```python fenced block in ``text``.

    The text following the first "```python" marker is taken up to the next
    "```" (or to the end of the string when the fence is never closed) and
    stripped of surrounding whitespace.

    Args:
        text: Raw model response.

    Returns:
        The extracted code string, or ``None`` when ``text`` has no
        "```python" fence or is not a string at all.
    """
    try:
        pieces = text.split("```python")
    except (AttributeError, TypeError):
        # Not a str (e.g. None or bytes): there is nothing to extract.
        # Only these errors are handled so that KeyboardInterrupt/SystemExit
        # are no longer swallowed as they were by the former bare `except`.
        return None

    if len(pieces) < 2:
        # The opening fence never appeared.
        return None
    return pieces[1].strip().split("```")[0].strip()


def extract_code_basic(a):
    """Trim a raw completion down to the code that precedes any trailer tag.

    The text is cut at the first occurrence of each marker in turn and
    stripped of surrounding whitespace after every cut.  "[Code]" is handled
    last: everything from the first "[Code]" tag onwards is discarded.

    Args:
        a: Raw model response (a string).

    Returns:
        The stripped text located before the earliest marker.
    """
    # NOTE(review): a response that *starts* with "[Code]" is reduced to an
    # empty string by the final marker -- confirm this is intended.
    cut_markers = (
        "[Explanation]",
        "[Explain]",
        "[End]",
        "[Comment]",
        "[Summary]",
        "[End of Code]",
        "[Code]",
    )
    code = a
    for marker in cut_markers:
        head, _, _ = code.partition(marker)
        code = head.strip()
    return code


def get_completeness(
    benchmark, assertions, n_values
):
    """Execute the stored responses of a benchmark and record the pass ratio.

    Reads ``benchmark['completeness'][0]['system_code']['full_resp']``, whose
    first element is skipped (the prompt), whose last element is replaced by
    a summary dict, and whose middle elements are per-response dicts carrying
    an ``'orig_resp'`` entry.  Each response is handed to ``run`` (defined
    outside this block, expected to return a ``(result, code)`` pair) and the
    per-response dict is updated in place with ``code``, ``execution`` and
    ``solved_test_list``.

    Args:
        benchmark: Benchmark dict; mutated in place.
        assertions: Test cases forwarded to ``run``.
        n_values: Number of responses that were requested; used to report
            how many are missing.

    Returns:
        The same ``benchmark`` object, with the summary entry and the
        ``'completeness'`` ratio filled in.
    """
    system_code = benchmark['completeness'][0]['system_code']
    entries = system_code['full_resp'][1:-1]

    passed = 0
    for entry in entries:
        result, code = run(entry['orig_resp'], assertions)
        if result is None:
            continue
        if "execution" not in result:
            result["execution"] = "0.0"
            result["solved_test_list"] = []
        # Go through float() first: int("0.0") -- the fallback value set
        # just above -- raises ValueError.
        passed += int(float(result["execution"]))
        entry["code"] = code
        entry["execution"] = result["execution"]
        entry["solved_test_list"] = result["solved_test_list"]

    # No stored responses at all: report 0.0 rather than dividing by zero.
    ratio = passed / len(entries) if entries else 0.0

    system_code['full_resp'][-1] = {"completeness_full": ratio, "error_cases": n_values - len(entries)}
    system_code['completeness'] = ratio
    return benchmark


def old_get_completeness(
    response, assertions, n_values, type_, prompt
):
    """Extract, execute and score a batch of raw LLM responses.

    For ``type_`` "basic"/"only_tc" the code is extracted with
    ``extract_code_basic``; for "cot" with ``extract_code_cot``; any other
    value uses the response text unchanged.  Responses that yield no code, or
    for which ``run`` (defined outside this block, expected to return a
    ``(result, code)`` pair) gives a ``None`` result, are skipped.

    Args:
        response: List of raw response strings.
        assertions: Test cases forwarded to ``run``.
        n_values: Number of responses that were requested.
        type_: Prompt style, selecting the code extractor.
        prompt: Prompt text; stored as the first element of the trace.

    Returns:
        ``(comp, trace)`` where ``comp`` is the summed execution score
        divided by ``len(response)`` (left as ``0`` when nothing passed) and
        ``trace`` is ``[prompt, <scored response dicts>..., summary]``.
    """
    comp = 0
    temp = [prompt]
    for resp in response:
        if type_ in ("basic", "only_tc"):
            a = extract_code_basic(resp)
        elif type_ == "cot":
            a = extract_code_cot(resp)
        else:
            a = resp
        if not a:
            continue

        result, code = run(a, assertions)
        if result is None:
            continue
        if "execution" not in result:
            result["execution"] = "0.0"
            result["solved_test_list"] = []
        # Go through float() first: int("0.0") -- the fallback value set
        # just above -- raises ValueError.
        comp += int(float(result["execution"]))
        temp.append({"code": code, "execution": result["execution"], "solved_test_list": result["solved_test_list"], "orig_resp": resp})

    # Dividing only when something passed also keeps an empty `response`
    # list from triggering a ZeroDivisionError.
    if comp:
        comp /= len(response)
    temp.append(
        {"completeness_full": comp, "error_cases": n_values - len(response)}
    )

    return comp, temp


def save_response(
    response, assertions, n_values, prompt
):
    """Wrap raw responses in the stored trace layout without scoring them.

    Args:
        response: Iterable of raw response texts.
        assertions: Unused; kept for signature parity with the scoring
            helpers.
        n_values: Unused; kept for signature parity with the scoring helpers.
        prompt: Prompt text, stored as the first element of the trace.

    Returns:
        ``(None, trace)`` where ``trace`` is ``[prompt, <one placeholder dict
        per response>..., summary]`` and every field except ``orig_resp`` is
        ``None``.
    """
    placeholders = [
        {"code": None, "execution": None, "solved_test_list": None, "orig_resp": raw}
        for raw in response
    ]
    summary = {"completeness_full": None, "error_cases": None}
    return None, [prompt, *placeholders, summary]


def generator(raw_args=None, iter=0, error_margin=5, resp_margin=20, ip_text="text", extra=None
):
    """Sample code completions for each benchmark and checkpoint them to disk.

    Benchmarks are read from the ``dataset`` folder; if a previous
    ``*_gen.json`` output exists it is resumed and any input entry not yet
    present (matched on ``ip_text``) is appended.  For every benchmark that
    has no ``system_code`` entry yet, the ``system_code.md`` prompt is filled
    with the task text, the LLM is queried (up to three attempts) and the raw
    responses are stored unscored via ``save_response``.  The whole benchmark
    list is rewritten to the output file after each task.

    Args:
        raw_args: Forwarded to ``get_args`` (defined outside this block).
        iter: Index interpolated into the input/output file names.
        error_margin: Largest tolerated shortfall ``n_values - #responses``
            before an attempt is retried.
        resp_margin: When ``n_values != 1``, an attempt with fewer responses
            than this (and not exactly ``n_values``) is retried.
        ip_text: Benchmark key holding the task text; also used to match
            entries when resuming.
        extra: When truthy, overrides ``args.info_prompt``.
    """
    args = get_args(raw_args)
    print(args)
    if extra:
        args.info_prompt = extra
    folder = f"config-{args.sample}sample_{args.n_values}n_{args.threshold}threshold_{args.temp}temp"
    base_folder = os.path.join(
        "Responses", args.llm, args.info_prompt, "all_temp_variation", folder
    )
    input_folder = os.path.join(
        base_folder, "dataset", f"{args.dataset}_ambiguous_{iter}.json"
    )
    output_folder = os.path.join(
        base_folder, "completeness_checking", f"{args.dataset}_ambiguous_{iter}_gen.json"
    )

    try:
        with open(output_folder, "r") as f:
            benchmarks = json.load(f)  # load previously left off file
        with open(input_folder, "r") as f:
            default_file = json.load(f)  # load the input file
        for entry in default_file:
            if not any(entry[ip_text] == known[ip_text] for known in benchmarks):
                benchmarks.append(entry)
        print(f"loaded {len(benchmarks)} benchmarks")
    except Exception:
        # Deliberate best effort: any problem with the checkpoint (missing,
        # corrupt, unexpected schema) means starting over from the input.
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        print("load_from start", input_folder)
        with open(input_folder, "r") as f:
            benchmarks = json.load(f)  # load the input file

    llm_used = args.llm  # gpt4-o, gpt4-turbo, gpt3.5-turbo, mistral, phi3
    if args.llm == "gpt4-3.5-turbo":
        llm_used = "gpt3.5-turbo"  # sufficiency

    print("using basic code prompt")
    prompt_file = "system_code.md"
    prompt_key = prompt_file.split(".")[0]

    with open(os.path.join(prompt_folder, prompt_file), "r") as f:
        PROMPT = f.read()

    # The checkpoint is written after every task; make sure its folder exists
    # before spending an LLM call whose result could not be saved.
    os.makedirs(os.path.dirname(output_folder), exist_ok=True)

    for benchmark in tqdm.tqdm(benchmarks):
        if "completeness" not in benchmark:
            benchmark["completeness"] = []
        elif any(prompt_key in done for done in benchmark["completeness"]):
            continue  # already generated in a previous run
        if benchmark["task_id"] > args.sample:
            break

        assertions = benchmark["test_list"]
        # assertions = filter_contradictory(assertions, benchmark['info'])
        prompt = PROMPT.replace("{task}", benchmark[ip_text])

        tries = 3
        raw_responses = []
        while tries:
            raw_responses = get_llm_response(
                llm_used,
                prompt,
                max_tokens=1000,
                n=args.n_values,
                temperature=args.temp,
                stop=["[End]", "```", "Example", "---"],
                all_resp=1,
            )
            received = len(raw_responses)
            if args.n_values == 1:
                if received == 0:
                    print(f"retrying cause no response: #n_values: {args.n_values}")
                    tries -= 1
                    continue
            elif received < resp_margin and received != args.n_values:
                print(
                    f"retrying cause number of responses: #resp: {received}, #n_values: {args.n_values}"
                )
                tries -= 1
                continue

            # Count the shortfall on the raw responses.  Measuring the wrapped
            # list (which also holds the prompt and the summary) made this
            # check lenient by two and logged `None` as the error count.
            missing = args.n_values - received
            if missing <= error_margin:
                break
            print(
                f"retrying cause many responses could not be parsed: #error: {missing}, #resp: {received}, #n_values: {args.n_values}"
            )
            tries -= 1

        # Store whatever the last attempt produced -- including an empty
        # list, which previously crashed on `response[-1]`.
        comp, response = save_response(
            raw_responses,
            assertions,
            args.n_values,
            prompt,
        )

        resp = {"completeness": comp, "full_resp": response}
        benchmark["completeness"].append({prompt_key: resp})

        with open(output_folder, "w") as f:
            json.dump(benchmarks, f)


def old_checker(
    raw_args=None, iter=0, error_margin=5, resp_margin=20, ip_text="text", extra=None
):
    """Legacy single-pass pipeline: query the LLM and score responses at once.

    Benchmarks are read from the ``dataset`` folder (``*_regen.json`` files
    are used when ``ip_text`` is not ``"text"``); an existing output file is
    resumed and missing input entries (matched on ``ip_text``) are appended.
    For every benchmark lacking an entry for the active prompt, the LLM is
    queried (up to three attempts), the responses are scored with
    ``old_get_completeness`` and the benchmark list is rewritten to disk.

    Args:
        raw_args: Forwarded to ``get_args`` (defined outside this block).
        iter: Index interpolated into the input/output file names.
        error_margin: Largest tolerated ``error_cases`` value reported by
            ``old_get_completeness`` before an attempt is retried.
        resp_margin: When ``n_values != 1``, an attempt with fewer responses
            than this is retried.
        ip_text: Benchmark key holding the task text; also selects the
            ``_regen`` files and is used to match entries when resuming.
        extra: When truthy, overrides ``args.info_prompt``.
    """
    args = get_args(raw_args)
    print(args)
    if extra:
        args.info_prompt = extra
    folder = f"config-{args.sample}sample_{args.n_values}n_{args.threshold}threshold_{args.temp}temp"
    base_folder = os.path.join(
        "Responses", args.llm, args.info_prompt, "all_temp_variation", folder
    )
    suffix = "" if ip_text == "text" else "_regen"
    file_name = f"{args.dataset}_ambiguous_{iter}{suffix}.json"
    input_folder = os.path.join(base_folder, "dataset", file_name)
    output_folder = os.path.join(base_folder, "completeness_checking", file_name)

    try:
        with open(output_folder, "r") as f:
            benchmarks = json.load(f)  # load previously left off file
        with open(input_folder, "r") as f:
            default_file = json.load(f)  # load the input file
        for entry in default_file:
            if not any(entry[ip_text] == known[ip_text] for known in benchmarks):
                benchmarks.append(entry)
        print(f"loaded {len(benchmarks)} benchmarks")
    except Exception:
        # Deliberate best effort: any problem with the checkpoint (missing,
        # corrupt, unexpected schema) means starting over from the input.
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        print("load_from start", input_folder)
        with open(input_folder, "r") as f:
            benchmarks = json.load(f)  # load the input file

    llm_used = args.llm  # gpt4-turbo, gpt3.5-turbo, mistral, phi3
    if args.llm == "gpt4-3.5-turbo":
        llm_used = "gpt3.5-turbo"  # sufficiency

    # The checkpoint is written after every task; make sure its folder exists.
    os.makedirs(os.path.dirname(output_folder), exist_ok=True)

    args.code_prompt = ["basic"]  # , 'cot', 'only_tc']
    # Using LLM to generate multiple code and check of low execution match
    for code_prompt in args.code_prompt:
        if "basic" in code_prompt:
            print("using basic code prompt")
            prompt_file = "system_code.md"

        if "cot" in code_prompt:
            print("using cot code prompt")
            prompt_file = "system_code_cot.md"

        if "only_tc" in code_prompt:
            print("using only test cases")
            prompt_file = "only_tc.md"

        prompt_key = prompt_file.split(".")[0]
        with open(os.path.join(prompt_folder, prompt_file), "r") as f:
            PROMPT = f.read()

        # Chain-of-thought answers get a larger budget and their own stop
        # sequence; the other prompt styles share one configuration.
        if "cot" in code_prompt:
            max_tokens, stop = 3000, ["### Task Description:"]
        else:
            max_tokens, stop = 1000, ["[End]", "```", "Example", "---"]

        for benchmark in tqdm.tqdm(benchmarks):
            if "completeness" not in benchmark:
                benchmark["completeness"] = []
            elif any(prompt_key in done for done in benchmark["completeness"]):
                continue  # already processed in a previous run
            if benchmark["task_id"] > args.sample:
                break

            assertions = benchmark["test_list"]
            # assertions = filter_contradictory(assertions, benchmark['info'])

            # Build the prompt exactly once so that the prompt stored in the
            # trace is the one actually sent (the test-case-only prompt used
            # to be overwritten by the task-text prompt before being saved).
            if "only_tc" in code_prompt:
                prompt = PROMPT.replace("{task}", str(benchmark["test_list"]))
            else:
                prompt = PROMPT.replace("{task}", benchmark[ip_text])

            tries = 3
            raw_responses = []
            scored = None
            comp = 0
            while tries:
                scored = None
                raw_responses = get_llm_response(
                    llm_used,
                    prompt,
                    max_tokens=max_tokens,
                    n=args.n_values,
                    temperature=args.temp,
                    stop=stop,
                    all_resp=1,
                )

                if args.n_values == 1:
                    if len(raw_responses) == 0:
                        print(f"retrying cause no response: #n_values: {args.n_values}")
                        tries -= 1
                        continue
                elif len(raw_responses) < resp_margin:
                    print(
                        f"retrying cause number of responses: #resp: {len(raw_responses)}, #n_values: {args.n_values}"
                    )
                    tries -= 1
                    continue

                comp, scored = old_get_completeness(
                    raw_responses,
                    assertions,
                    args.n_values,
                    code_prompt,
                    prompt,
                )
                if scored[-1]["error_cases"] <= error_margin:
                    break
                print(
                    f"retrying cause many responses could not be parsed: #error: {scored[-1]['error_cases']}, #resp: {len(scored)}, #n_values: {args.n_values}"
                )
                tries -= 1

            if scored is None:
                # The last attempt was rejected before scoring; score what it
                # returned anyway.  This also covers an empty response list,
                # which previously crashed on `response[-1]`.
                comp, scored = old_get_completeness(
                    raw_responses,
                    assertions,
                    args.n_values,
                    code_prompt,
                    prompt,
                )

            resp = {"completeness": comp, "full_resp": scored}
            benchmark["completeness"].append({prompt_key: resp})

            with open(output_folder, "w") as f:
                json.dump(benchmarks, f)


def checker(
    raw_args=None, iter=0, error_margin=5, resp_margin=20, ip_text="text", extra=None
):
    """Score the responses stored in the ``*_gen.json`` file.

    Loads the generation output from the ``completeness_checking`` folder; an
    existing scored output file is resumed and missing entries (matched on
    ``ip_text``) are appended.  Every benchmark whose ``system_code`` entry
    does not have a non-``None`` ``completeness`` yet is passed to
    ``get_completeness`` and the benchmark list is rewritten to disk after
    each task.

    Args:
        raw_args: Forwarded to ``get_args`` (defined outside this block).
        iter: Index interpolated into the input/output file names.
        error_margin: Unused here; kept for signature parity with the other
            pipeline entry points.
        resp_margin: Unused here; kept for signature parity.
        ip_text: Benchmark key used to match entries when resuming.
        extra: When truthy, overrides ``args.info_prompt``.
    """
    args = get_args(raw_args)
    print(args)
    if extra:
        args.info_prompt = extra
    folder = f"config-{args.sample}sample_{args.n_values}n_{args.threshold}threshold_{args.temp}temp"
    base_folder = os.path.join(
        "Responses", args.llm, args.info_prompt, "all_temp_variation", folder
    )
    input_folder = os.path.join(
        base_folder, "completeness_checking", f"{args.dataset}_ambiguous_{iter}_gen.json"
    )
    output_folder = os.path.join(
        base_folder, "completeness_checking", f"{args.dataset}_ambiguous_{iter}.json"
    )

    try:
        with open(output_folder, "r") as f:
            benchmarks = json.load(f)  # load previously left off file
        with open(input_folder, "r") as f:
            default_file = json.load(f)  # load the input file
        for entry in default_file:
            if not any(entry[ip_text] == known[ip_text] for known in benchmarks):
                benchmarks.append(entry)
        print(f"loaded {len(benchmarks)} benchmarks")
    except Exception:
        # Deliberate best effort: any problem with the checkpoint (missing,
        # corrupt, unexpected schema) means starting over from the input.
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        print("load_from start", input_folder)
        with open(input_folder, "r") as f:
            benchmarks = json.load(f)  # load the input file

    prompt_file = "system_code.md"
    prompt_key = prompt_file.split(".")[0]
    for benchmark in tqdm.tqdm(benchmarks):
        if "completeness" not in benchmark:
            benchmark["completeness"] = []
        elif any(
            prompt_key in done and done[prompt_key]["completeness"] is not None
            for done in benchmark["completeness"]
        ):
            continue  # already scored in a previous run
        if benchmark["task_id"] > args.sample:
            break

        assertions = benchmark["test_list"]
        # assertions = filter_contradictory(assertions, benchmark['info'])
        # get_completeness updates `benchmark` in place (and returns it), so
        # the element held in `benchmarks` carries the scores.
        get_completeness(
            benchmark,
            assertions,
            args.n_values,
        )
        with open(output_folder, "w") as f:
            json.dump(benchmarks, f)
    return


# if __name__ == "__main__":
#     # try:
#     generator(raw_args=None)
#     checker(raw_args=None)
# # except Exception as e:
# #     import pdb
# #     pdb.post_mortem()
