import functools
import json
import os
import re

def format_fn(path):
    """Turn a ``<METHOD><path>`` string into a function-name string.

    Every ``{``, ``}``, ``/`` and ``-`` character is replaced by ``_``; all
    other characters are kept unchanged.
    """
    underscore_map = str.maketrans("{}/-", "____")
    return path.translate(underscore_map)

def get_type_schema(working_dir):
    """Concatenate the ``index.graphql`` files of every sub-directory of *working_dir*.

    Sub-directories are visited in sorted name order. ``os.listdir`` order is
    arbitrary, so without sorting the generated schema text (and every file
    derived from it) would differ between machines. Plain files directly inside
    *working_dir* are ignored.

    Args:
        working_dir: Directory whose immediate sub-directories each contain an
            ``index.graphql`` file.

    Returns:
        The concatenated schema text ("" when there are no sub-directories).

    Raises:
        FileNotFoundError: if a sub-directory has no ``index.graphql``.
    """
    subdirs = sorted(
        entry
        for entry in os.listdir(working_dir)
        if os.path.isdir(os.path.join(working_dir, entry))
    )
    parts = []
    for subdir in subdirs:
        schema_path = os.path.join(working_dir, subdir, "index.graphql")
        with open(schema_path, encoding="utf-8") as f:
            parts.append(f.read())
    type_schema = "".join(parts)
    # Report directories whose schema contains a ``results: [JSON]`` field
    # (presumably a field whose type could not be inferred -- verify).
    if "results: [JSON]" in type_schema:
        print(working_dir)
    return type_schema

@functools.lru_cache(maxsize=None)
def _load_openapi_spec(spec_path):
    """Load an OpenAPI JSON document once and memoize it by path.

    The ``get_function_spec_*`` functions run once per solution step. Without
    the cache, the same spec file would be re-read and re-parsed on every call.
    The returned dict is shared between calls and must be treated as read-only.
    """
    with open(spec_path, encoding="utf-8") as f:
        return json.load(f)


def _is_required(flag):
    """Interpret an OpenAPI ``required`` value: ``True`` or a case-insensitive ``"true"``."""
    if isinstance(flag, bool):
        return flag
    return isinstance(flag, str) and flag.lower() == "true"


def _resolve_parameter(param, spec):
    """Return *param*, following a ``$ref`` into ``spec["components"]["parameters"]``.

    Only the last ``/``-separated segment of the reference is used as the
    component name (``#/components/parameters/Foo`` -> ``Foo``).
    """
    if "$ref" in param:
        component = param["$ref"].split("/")[-1]
        return spec["components"]["parameters"][component]
    return param


def _describe_parameter(param, *, humanize_name, skip_listed_enum):
    """Build the ``{"type", "description"}`` property entry for one resolved parameter.

    The description comes from the schema's ``description``, then its
    ``title``, then the parameter name (with ``_`` turned into spaces when
    *humanize_name*). Enum values are appended as a ``Valid values:`` line,
    unless *skip_listed_enum* is set and the description already has one.
    """
    schema = param["schema"]
    if "description" in schema:
        description = schema["description"]
    elif "title" in schema:
        description = schema["title"]
    elif humanize_name:
        description = param["name"].replace("_", " ")
    else:
        description = param["name"]
    if "enum" in schema and not (skip_listed_enum and "Valid values:" in description):
        description += "\nValid values: " + ",".join(str(v) for v in schema["enum"])
    return {"type": schema["type"], "description": description}


def _build_function_spec(spec, cmd, api, *, include_path_params, humanize_name, skip_listed_enum):
    """Convert the OpenAPI operation ``cmd api`` into an OpenAI-style function spec.

    Returns ``(fn_spec, required_args)``. ``required_args`` is the same list
    object stored in ``fn_spec["parameters"]["required"]``, so appending to it
    also changes the spec.

    Raises:
        KeyError: if *api* is not a path in the spec or *cmd* is not one of its methods.
    """
    method = cmd.lower()
    path_item = spec["paths"][api]
    if method not in path_item:
        raise KeyError(f"{cmd} {api}: no such operation in the OpenAPI spec")
    operation = path_item[method]

    # Prefer the long description, then the summary, then "".
    fn_description = operation.get("description", operation.get("summary", ""))

    # Copy instead of aliasing: the spec is cached and shared, so it must not be mutated.
    params = list(path_item.get("parameters", [])) if include_path_params else []
    params.extend(operation.get("parameters", []))

    properties = {}
    required_args = []
    for raw_param in params:
        param = _resolve_parameter(raw_param, spec)
        name = param["name"]
        properties[name] = _describe_parameter(
            param, humanize_name=humanize_name, skip_listed_enum=skip_listed_enum
        )
        # A name can be declared at both path and operation level; list it once.
        if _is_required(param.get("required", False)) and name not in required_args:
            required_args.append(name)

    param_dict = {"type": "dict", "properties": properties, "required": required_args}
    fn_spec = {"name": format_fn(cmd + api), "description": fn_description, "parameters": param_dict}
    return fn_spec, required_args


def get_function_spec_spotify(cmd, api):
    """Build the function spec for a Spotify operation from ``openapi_specs/spotify_oas.json``.

    Only operation-level parameters are used. No ``Valid values:`` line is
    appended when the description already contains one.

    Args:
        cmd: HTTP method, any case (e.g. ``"GET"``).
        api: Path key in the spec's ``paths`` (e.g. ``"/search"``).

    Returns:
        ``(fn_spec, required_args)``: ``fn_spec`` has ``name``/``description``/
        ``parameters`` keys; ``required_args`` lists the required parameter
        names and is the same list as ``fn_spec["parameters"]["required"]``.

    Raises:
        KeyError: if the path or the method is missing from the spec.
    """
    spec = _load_openapi_spec("openapi_specs/spotify_oas.json")
    return _build_function_spec(
        spec, cmd, api,
        include_path_params=False,
        humanize_name=False,
        skip_listed_enum=True,
    )


def get_function_spec_tmdb(cmd, api):
    """Build the function spec for a TMDB operation from ``openapi_specs/tmdb_oas.json``.

    Path-level parameters are included, followed by operation-level ones.
    Enum values are always appended, and a parameter without description or
    title is described by its name with ``_`` turned into spaces.

    Args:
        cmd: HTTP method, any case (e.g. ``"GET"``).
        api: Path key in the spec's ``paths``.

    Returns:
        ``(fn_spec, required_args)``, shaped as in ``get_function_spec_spotify``.

    Raises:
        KeyError: if the path or the method is missing from the spec.
    """
    spec = _load_openapi_spec("openapi_specs/tmdb_oas.json")
    return _build_function_spec(
        spec, cmd, api,
        include_path_params=True,
        humanize_name=True,
        skip_listed_enum=False,
    )


with open("datasets/spotify_filtered.json", encoding="utf-8") as f:
    spotify_data = json.load(f)

# open(..., "w") does not create missing parent directories.
for out_dir in ("spotify_data/inputs", "spotify_data/GraphQL_schemas", "spotify_data/utterances"):
    os.makedirs(out_dir, exist_ok=True)

# One set of output files per example. Numbering is 1-based, matching the
# stepzen_import/test<N> directory names.
for example_num, obj in enumerate(spotify_data, start=1):
    utterance = obj["query"]
    sol = obj["solution"]
    type_schema = get_type_schema(f"spotify_data/stepzen_import/test{example_num}")
    model_input = "GraphQL schema:\n\n" + type_schema + "\nUtterance: " + utterance

    # Resolve each "<METHOD> <path>" solution step against the OpenAPI spec.
    # The specs and call sequence are only consumed by the disabled writes
    # below, but the lookup still raises on a step the spec does not define.
    fn_specs = []
    calls = []
    for step in sol:
        cmd, api = step.strip().split(" ", 1)
        api = api.strip()
        fn_spec, required_args = get_function_spec_spotify(cmd, api)
        fn_specs.append(fn_spec)
        calls.append(format_fn(cmd + api) + "(" + ", ".join(required_args) + ")\n")
    model_output = "".join(calls)

    # Disabled outputs; uncomment (and create the directories) to also emit them:
    # with open(f"spotify_data/function_specs/spec{example_num}.json", "w", encoding="utf-8") as f:
    #     json.dump(fn_specs, f, indent=4)
    # with open(f"spotify_data/script_outputs/output{example_num}.txt", "w", encoding="utf-8") as f:
    #     f.write(model_output)

    with open(f"spotify_data/inputs/input{example_num}.txt", "w", encoding="utf-8") as f:
        f.write(model_input)
    with open(f"spotify_data/GraphQL_schemas/schema{example_num}.txt", "w", encoding="utf-8") as f:
        f.write(type_schema)
    with open(f"spotify_data/utterances/utterance{example_num}.txt", "w", encoding="utf-8") as f:
        f.write(utterance)

with open("datasets/tmdb_filtered.json", encoding="utf-8") as f:
    tmdb_data = json.load(f)

# Captures the names inside "{...}" placeholders of an endpoint path.
PATH_PARAM_RE = re.compile(r"{([^{]*?)}")

# open(..., "w") does not create missing parent directories.
for out_dir in (
    "tmdb_data/function_specs",
    "tmdb_data/inputs",
    "tmdb_data/script_outputs",
    "tmdb_data/GraphQL_schemas",
    "tmdb_data/utterances",
):
    os.makedirs(out_dir, exist_ok=True)

# One set of output files per example. Numbering is 1-based, matching the
# stepzen_import/test<N> directory names.
for example_num, obj in enumerate(tmdb_data, start=1):
    utterance = obj["query"]
    sol = obj["solution"]
    type_schema = get_type_schema(f"tmdb_data/stepzen_import/test{example_num}")
    model_input = "GraphQL schema:\n\n" + type_schema + "\nUtterance: " + utterance

    fn_specs = []
    lines = []
    last_index = len(sol) - 1
    for index, step in enumerate(sol):
        cmd, api = step.strip().split(" ", 1)
        api = api.strip()
        fn_spec, required_args = get_function_spec_tmdb(cmd, api)
        fn_specs.append(fn_spec)
        query_name = cmd + api
        # Path placeholders are always call arguments. get_function_spec_tmdb
        # returns the same list it stores in fn_spec["parameters"]["required"],
        # so placeholders appended here also appear in the dumped spec.
        for arg in PATH_PARAM_RE.findall(query_name):
            if arg not in required_args:
                required_args.append(arg)
        call = format_fn(query_name) + "(" + ", ".join(required_args) + ")\n"
        # Every call except the last binds its result: the first call to
        # ``response_obj``, later ones to ``response_obj<index>``.
        if index == last_index:
            lines.append(call)
        elif index == 0:
            lines.append("response_obj = " + call)
        else:
            lines.append("response_obj" + str(index) + " = " + call)
    model_output = "".join(lines)

    with open(f"tmdb_data/function_specs/spec{example_num}.json", "w", encoding="utf-8") as f:
        json.dump(fn_specs, f, indent=4)
    with open(f"tmdb_data/inputs/input{example_num}.txt", "w", encoding="utf-8") as f:
        f.write(model_input)
    with open(f"tmdb_data/script_outputs/output{example_num}.txt", "w", encoding="utf-8") as f:
        f.write(model_output)
    with open(f"tmdb_data/GraphQL_schemas/schema{example_num}.txt", "w", encoding="utf-8") as f:
        f.write(type_schema)
    with open(f"tmdb_data/utterances/utterance{example_num}.txt", "w", encoding="utf-8") as f:
        f.write(utterance)
    

