import json
import os
import time
from argparse import ArgumentParser

import numpy as np
from substrate import *

from Scripts.GetExecutionMatch import *
from Scripts.Prompting import *
from Scripts.permute_func_name import renamings

# Shared client used by get_substrate_response(); it is constructed once at import time.
# NOTE(review): LLMClient is presumably provided by `from substrate import *` -- confirm.
SUBSTRATE_LLM_CLIENT = LLMClient()
# Relative folder holding the prompt templates. The path (and the file names appended to it
# in this module) is written with backslash separators.
prompt_folder = r"Scripts\config_experiment\prompts"
# The "omit" system prompt is read eagerly, so importing this module fails if the file is missing.
with open(prompt_folder + r"\omit.md", "r") as f:
    sys_prompt_omit = f.read()
# Stop markers: get_info() cuts every LLM response at the first occurrence of each of these
# strings (applied in list order) and keeps only the text before it.
st = [
    "\n",
    "[Explanation]",
    r'"""',
    r"'''",
    r"`",
    r"```",
    "`query`",
    "`info`",
    "WritingPrompt",
    "HTTPSConnectionPool",
    "('Connection ",
    "'choices'",
    "---",
    "json",
    "Aikidenesse",
    "OP:",
    "JSON",
]


def iou_score(string1, string2):
    """Return True when the word-level intersection-over-union of two strings is >= 0.5.

    Both strings are split on single spaces and compared as *sets* of words, so a
    word repeated inside one string is counted only once. (Counting repeats in the
    intersection but not in the union could push the ratio above 1 and report
    unrelated strings as a match.)

    Args:
        string1: first string.
        string2: second string.

    Returns:
        bool: True if |words1 & words2| / |words1 | words2| >= 0.5, else False.
    """
    words1 = set(string1.split(" "))
    words2 = set(string2.split(" "))
    # str.split(" ") always yields at least one element, so the union is never empty.
    union = words1 | words2
    return len(words1 & words2) / len(union) >= 0.5


def combine_keys(mp):
    """
    Merge "similar" keys of the count dictionary mp in place and return it.

    Two keys are similar when one is a substring of the other or when iou_score()
    accepts the pair. The key with the strictly larger count survives (on a tie the
    key visited first in the outer loop survives) and receives the sum of both
    counts; the other key is deleted. Passes are repeated until a full pass
    performs no merge.
    """
    merged_something = True
    while merged_something:
        merged_something = False
        snapshot = list(mp)
        # keys that already took part in a merge during this pass
        consumed = set()
        for left in snapshot:
            if left in consumed:
                continue
            for right in snapshot:
                if right == left or right in consumed:
                    continue
                similar = left in right or right in left or iou_score(left, right)
                if not similar:
                    continue
                merged_something = True
                consumed.update((left, right))
                if mp[right] > mp[left]:
                    keep, drop = right, left
                else:
                    keep, drop = left, right
                mp[keep] += mp[drop]
                del mp[drop]
                # `left` is consumed now, so no further partner may be merged with it
                break
    return mp


def _parse_llm_json(response):
    """Parse one raw LLM response string as JSON, applying quote-repair fallbacks.

    Raises:
        ValueError: if the response still cannot be parsed after the repairs.
    """
    # Drop a dangling "{" at the end of the response, ignoring trailing whitespace.
    trimmed = response.rstrip()
    if trimmed.endswith("{"):
        response = trimmed[:-1]
    try:
        return json.loads(response)
    except Exception as first_error:
        if response.count('"') < response.count("'"):
            # Mostly single-quoted output: swap the two quote characters and retry.
            response = response.replace("'", "|").replace('"', "'").replace("|", '"')
            try:
                return json.loads(response)
            except Exception as swap_error:
                print(swap_error, "ERROR IN PARSING THE RESPONSE")
                print(response, "**")
                raise ValueError(
                    f"could not parse LLM response as JSON: {response!r}"
                ) from swap_error
        try:
            # Remove stray apostrophes and backslashes, then keep the text before
            # the first "- " separator.
            response = response.replace("'", "").replace("\\", "")
            response = response.split("- ")[0]
            return json.loads(response)
        except Exception:
            print(first_error, "ERROR IN PARSING THE RESPONSE")
            print(response, "**")
            raise ValueError(
                f"could not parse LLM response as JSON: {response!r}"
            ) from first_error


def _load_edge_cases(text):
    """Turn a stringified edge-case mapping back into a dict; return None if impossible."""
    import ast

    swapped = text.replace("'", "|").replace('"', "'").replace("|", '"')
    # `text` comes from str(dict), i.e. a Python literal, so literal_eval is the last
    # resort when neither JSON attempt works (e.g. values containing apostrophes).
    attempts = ((json.loads, text), (json.loads, swapped), (ast.literal_eval, text))
    for parser, candidate in attempts:
        try:
            parsed = parser(candidate)
        except Exception:
            continue
        return parsed if isinstance(parsed, dict) else None
    return None


def soft_matching(full_responses):
    """
    takes in a list of JSON strings (one per sampled LLM response)
    for each key of the first response: counts the values over all responses, merges
    similar values with combine_keys (soft matching) and keeps
      - "task": the single most frequent value,
      - "edge_cases": the union of all mappings supported by at least half the responses,
      - any other key: the list of values supported by at least half the responses.
    returns (str(aggregated dict), counts) where counts holds the support of every kept
    value (or len(responses) for a key that had no values at all)

    Raises:
        ValueError: if a response cannot be parsed as JSON.
    """
    counts = []
    responses = [_parse_llm_json(response) for response in full_responses]
    if not responses:
        # nothing to aggregate
        return str({}), counts

    temp = {}
    for key in responses[0].keys():
        mp = {}  # map of values to their counts
        for response in responses:
            if key not in response:
                continue
            vals = response[key]
            if key == "task" or key == "edge_cases":  # string or dict
                mp[str(vals)] = mp.get(str(vals), 0) + 1
            else:  # list
                for val in vals:
                    mp[val] = mp.get(val, 0) + 1

        mp = combine_keys(mp)

        if key == "task":
            temp[key] = ""
        elif key == "edge_cases":
            temp[key] = {}
        else:
            temp[key] = []

        if not mp:
            counts.append(len(responses))
            continue

        if key == "task":
            # save the most frequent task; its support is recorded exactly once
            temp[key] = max(mp, key=mp.get)
            counts.append(mp[temp[key]])
            continue

        # save all with frequency >= 1/2
        majority = len(responses) / 2
        for k, support in mp.items():
            if support < majority:
                continue
            if key == "edge_cases":
                cases = _load_edge_cases(k)
                if cases is None:
                    # cannot be merged into the result; skip instead of crashing
                    print("could not parse edge cases:", k)
                    continue
                temp[key].update(cases)
            else:
                temp[key].append(k)
            counts.append(support)

    return str(temp), counts


def _cut_at_stop_markers(resp):
    """Truncate resp at the first occurrence of every marker in `st`, applied in order."""
    for marker in st:
        resp = resp.split(marker)[0].strip()
    return resp


def _build_info_prompt(prompt_name, args, text, test_cases, iteration):
    """Load the info_<prompt_name>.md template and fill in its placeholders.

    The "_examples" variant (and the test cases) are only used on iteration 0.
    """
    uses_examples = "examples" in args.info_prompt
    if iteration and uses_examples:
        prompt_name = prompt_name.replace("_examples", "")
    with open(prompt_folder + rf"\info_{prompt_name}.md", "r") as f:
        template = f.read()
    prompt = template.replace("{{originalNL}}", text)
    if uses_examples and iteration == 0:
        prompt = prompt.replace("{{test cases}}", str(test_cases))
    return prompt


def get_info(text, args, test_cases, iteration):
    """Extract structured information about `text` with an LLM.

    The strategy is selected by args.info_prompt:
      * no "agg": one LLM call, the single response is used (info_count == 1);
      * "agg" and "comp": up to 5 attempts to sample 25 responses, aggregated per
        component with soft_matching (info_count == "<mean support>_<n responses>");
      * "agg" only: 25 sampled responses, the most frequent one is used
        (info_count == "<its frequency>_<n responses>").

    Args:
        text: natural-language description substituted for {{originalNL}}.
        args: namespace providing `info_prompt`.
        test_cases: substituted for {{test cases}} on iteration 0 of "examples" prompts.
        iteration: current iteration number.

    Returns:
        (info, info_count); (None, 0) when no usable response was obtained or the
        chosen response could not be parsed into a non-empty object.
    """
    from collections import Counter

    # NOTE(review): get_llm_response is presumably supplied by Scripts.Prompting and
    # returns a list of completion strings -- confirm.
    llm_used = "gpt4-turbo"  # info default to gpt4

    if "agg" not in args.info_prompt:
        infoprompt = _build_info_prompt(
            args.info_prompt, args, text, test_cases, iteration
        )
        resp = _cut_at_stop_markers(get_llm_response(llm_used, infoprompt)[0])
        info_count = 1
    elif "comp" in args.info_prompt:  # doing softmatch for each component
        infoprompt = _build_info_prompt(
            args.info_prompt.split("comp_")[1], args, text, test_cases, iteration
        )
        responses = []
        for _ in range(5):
            batch = get_llm_response(
                llm_used,
                infoprompt,
                temperature=0.5,
                n=25,
                all_resp=1,
                stop=["---"],
            )
            if len(batch) > 20:
                responses = batch
                break
            if len(batch) > len(responses):
                responses = batch  # keep the largest batch seen so far
        if not responses:
            return None, 0
        responses = [_cut_at_stop_markers(r) for r in responses]
        resp, counts = soft_matching(responses)
        info_count = f"{np.mean(counts)}_{len(responses)}"
    else:
        infoprompt = _build_info_prompt(
            args.info_prompt.split("agg_")[1], args, text, test_cases, iteration
        )
        responses = get_llm_response(
            llm_used, infoprompt, temperature=0.5, n=25, all_resp=1
        )
        if not responses:
            return None, 0
        responses = [_cut_at_stop_markers(r) for r in responses]
        # most frequent response; ties go to the one that appeared first
        resp, frequency = Counter(responses).most_common(1)[0]
        info_count = f"{frequency}_{len(responses)}"

    info = {}
    try:
        info = json.loads(resp)
    except Exception:
        try:
            # retry with single and double quotes swapped
            info = json.loads(
                resp.replace("'", "|").replace('"', "'").replace("|", '"')
            )
        except Exception as e:
            print(e, "ERROR IN PARSING THE INFO RESPONSE")
            print(resp, "**")
    if info == {}:
        return None, 0

    return info, info_count


def get_args(raw_args=None):
    """Parse the experiment's command-line options.

    Args:
        raw_args: optional list of argument strings; None makes argparse read sys.argv.

    Returns:
        argparse.Namespace with dataset, sample, temp, n_values, threshold, max_iter,
        prompt, info_prompt and llm.
    """
    # (flag, type, default, help) in declaration order
    option_table = (
        ("--dataset", str, "mbpp-updated", "input dataset name: mbpp"),
        ("--sample", int, 50, "number of samples to do the check on"),
        ("--temp", float, 0.4, "temp for code generation"),
        ("--n_values", int, 25, "number of code generation"),
        (
            "--threshold",
            float,
            0.0,
            "threshold above which code is considered complete",
        ),
        ("--max_iter", int, 10, "iteration of experiment"),
        ("--prompt", str, "basic", "code generation prompt"),
        (
            "--info_prompt",
            str,
            "agg_comp_imp_fewshot_updated",
            "information extraction prompt",
        ),
        ("--llm", str, "gpt4-turbo", "llm model"),
    )
    parser = ArgumentParser()
    for flag, caster, fallback, description in option_table:
        parser.add_argument(flag, type=caster, default=fallback, help=description)
    return parser.parse_args(raw_args)


def get_substrate_response(message_hist, model) -> str:
    """Send one deterministic chat request through the shared substrate client.

    Args:
        message_hist: chat messages forwarded as the request's "messages".
        model: model identifier passed to the client.

    Returns:
        Whatever SUBSTRATE_LLM_CLIENT.send_request returns for the "chat" endpoint.
        NOTE(review): get_chat_response indexes the result like a mapping, so the
        `str` annotation looks inaccurate -- confirm.
    """
    stop_sequences = ["#END", "# END", "---"]
    payload = dict(
        messages=message_hist,
        max_tokens=1000,
        temperature=0.0,
        top_p=1,
        n=1,
        stream=False,
        stop=stop_sequences,
    )
    return SUBSTRATE_LLM_CLIENT.send_request(model, payload, endpoint="chat")


def get_chat_response(
    prompt, model="dev-gpt-4-turbo-chat-completions", budget_sleep_calls=0
) -> str:
    """Return the text of the first chat completion for `prompt`, or None.

    Args:
        prompt: message history forwarded to get_substrate_response.
        model: model identifier; models outside the built-in list are not handled
            here and yield None.
        budget_sleep_calls: number of extra attempts allowed after a rate-limit
            error (each one preceded by a 120 s pause).

    Returns:
        The completion text, or None when the model is unsupported, no response was
        obtained, or the response lacks the expected structure.
    """
    response = None
    openai_models = [
        "dev-gpt-4-turbo-chat-completions",
        "dev-chat-completion-gpt-35-turbo-16k",
    ]
    if model not in openai_models:
        ###########################################################
        #  ADD YOUR MODEL HERE SO THAT IT RETURNS THE RESPONSE
        ###########################################################
        return response

    for calls_left in reversed(range(0, budget_sleep_calls + 1)):
        try:
            response = get_substrate_response(prompt, model)
            if response:
                break
        except SubstrateRateLimitError as err:
            print(err)
            if calls_left == 0:
                # no retry will follow, so waiting would only delay the caller
                break
            seconds = 120
            print(f"Sleeping {seconds}s due to rate limit (tries left={calls_left})")
            time.sleep(seconds)

    try:
        text_return = response["choices"][0]["message"]["content"]
    except Exception:
        # missing / malformed response (e.g. None): best effort, report "no text"
        text_return = None
    return text_return


def create_folder_if_not_exists(folder_path):
    """Create folder_path (including parents) unless it already exists; print the outcome."""
    if os.path.exists(folder_path):
        print("Folder already exists.")
        return
    os.makedirs(folder_path)
    print("Folder created successfully.")


def run(code, assertions):
    """Execute the renamed variants of `code` against `assertions`.

    Returns a two-element list [result, code_used]:
      * no variant produced by renamings(): a failed result plus the original code;
      * first variant whose result has a truthy "execution": that result and variant;
      * otherwise: a failed result plus the first variant.
    """
    candidates = list(renamings(code, assertions))
    if not candidates:
        failure = {"execution": False, "solved_test_list": ["SyntaxError which parsing code"]}
        return [failure, code]
    for candidate in candidates:
        outcome = executing_code(candidate, assertions)
        if outcome["execution"]:
            return [outcome, candidate]
    return [{"execution": False, "solved_test_list": []}, candidates[0]]


def mask_dict_helper(d, max_count, mask_count):
    """Recursively enumerate copies of d in which entries are replaced by "MASK".

    Each recursion level masks one more entry (a list element or an "edge_cases"
    value; "task" is never masked) until mask_count reaches max_count. The same
    entry may be masked again at a deeper level, so the result can contain
    duplicates and dicts with fewer than max_count distinct masks.

    Returns:
        list of dicts, in depth-first order.
    """
    if mask_count == max_count:  # max 1-max_count can be masked at a time
        return [d]

    expanded = []
    for field, value in d.items():
        if field == "task":
            continue
        if field == "edge_cases":
            variants = []
            for case_name in value.keys():
                masked_cases = value.copy()
                masked_cases[case_name] = "MASK"
                variants.append(masked_cases)
        else:
            variants = [
                value[:pos] + ["MASK"] + value[pos + 1 :] for pos in range(len(value))
            ]
        for variant in variants:
            candidate = d.copy()
            candidate[field] = variant
            expanded.extend(mask_dict_helper(candidate, max_count, mask_count + 1))
    return expanded


def remove_NA(info):
    """Replace "N/A" placeholders in info (in place) by empty values and return info.

    * a list whose first element is "N/A" becomes [];
    * a dict containing the pair "N/A": "N/A" becomes {};
    * any other value equal to "N/A" becomes "".
    """
    placeholder = "N/A"
    for field, value in info.items():
        if isinstance(value, list):
            if value and value[0] == placeholder:
                info[field] = []
        elif isinstance(value, dict):
            if any(
                item == placeholder and name == placeholder
                for name, item in value.items()
            ):
                info[field] = {}
        elif value == placeholder:
            info[field] = ""
    return info


def masking_info(input_dict: dict, max_count=1):
    """Return the distinct masked variants of input_dict produced by mask_dict_helper.

    A variant is kept only when its string form contains "MASK" exactly max_count
    times; duplicates are dropped while the first-seen order is preserved.
    """
    distinct = []
    for candidate in mask_dict_helper(input_dict, max_count, 0):
        exact_mask_count = str(candidate).count("MASK") == max_count
        if exact_mask_count and candidate not in distinct:
            distinct.append(candidate)
    return distinct


def masking_key(masked_info: dict):
    """List the keys of masked_info that hold "MASK" entries.

    A key is listed once per masked entry found in its list or dict value; the
    "task" key and values of any other type are ignored.
    """
    hits = []
    for field, value in masked_info.items():
        if field == "task":
            continue
        if isinstance(value, list):
            entries = value
        elif isinstance(value, dict):
            entries = value.values()
        else:
            continue
        hits.extend(field for entry in entries if entry == "MASK")
    return hits


def remove_mask(info):
    """Drop every "MASK" entry from the list and dict values of info.

    List elements and dict values are removed only when they are exactly equal to
    "MASK", so non-string values (e.g. numbers) and strings that merely contain
    the substring "MASK" are kept. Other values are left untouched.

    Args:
        info: dict of extracted information; its list/dict values are replaced in place.

    Returns:
        A shallow copy of the cleaned info.
    """
    # remove parameters that are MASK in json info
    for key, value in info.items():
        if isinstance(value, list):
            info[key] = [item for item in value if item != "MASK"]
        elif isinstance(value, dict):
            info[key] = {
                name: item for name, item in value.items() if item != "MASK"
            }
    return dict(info)


# def nan_filling(info):
#     for key, values in info.items():
#         if isinstance(values, str):  # task
#             if values == "":
#                 info[key] = "N/A"
#         elif isinstance(values, dict):  # edge_cases
#             if values == {}:
#                 info[key] = {"N/A": "N/A"}
#         else:  # list
#             if values == []:
#                 info[key] = ["N/A"]
#     return info


# def extract_funcsign(code):
#     if "def" not in code:
#         return code, ""
#     code = code.split("def")[1]
#     code = code.split(":")[0]
#     func_name = code.split("(")[0].strip()
#     return code.strip(), func_name
