import ast
import itertools


class Renamer(ast.NodeTransformer):
    """Rename functions, and every name that refers to them, using a mapping.

    Both the ``def`` name of function definitions (sync and async) and every
    ``ast.Name`` whose identifier is a key of ``mapping`` are rewritten, so a
    renamed function's definition, recursive calls and call sites stay
    consistent. As with any ``ast.NodeTransformer``, the visited tree is
    modified in place; visit a copy if the original must be preserved.
    """

    def __init__(self, mapping: dict[str, str]) -> None:
        # old identifier -> new identifier
        self.mapping = mapping

    def visit_Name(self, node: ast.Name) -> ast.Name:
        if node.id not in self.mapping:
            return node
        # Build a fresh node but keep the original source position, so the
        # rewritten tree can still be compiled (compile() requires lineno etc.).
        return ast.copy_location(
            ast.Name(id=self.mapping[node.id], ctx=node.ctx), node
        )

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST:
        node.name = self.mapping.get(node.name, node.name)
        return self.generic_visit(node)

    # ``async def`` must be renamed as well; otherwise its call sites
    # (plain ast.Name nodes) would be renamed while the definition is not.
    visit_AsyncFunctionDef = visit_FunctionDef

def renamings(code: str, assertions):
    """Yield candidate scripts in which some function is renamed to the expected name.

    The expected (ground-truth) function name is taken from ``assertions`` via
    ``function_name``. Then:

    * if ``code`` defines no top-level functions, or already defines one with
      the expected name, ``code`` itself is yielded once, unchanged;
    * otherwise, for every ordering (permutation) of the top-level functions,
      the first function of that ordering is renamed (its ``def`` and every
      reference to it) to the expected name, and a script consisting of the
      top-level imports, then the top-level plain assignments, then the
      functions in that order is yielded.

    This is a generator: nothing runs, not even the syntax check, until it is
    iterated. If ``code`` does not parse, a message is printed and nothing is
    yielded. The generator object itself is never ``None``.

    NOTE(review): top-level statements other than imports, plain assignments
    and ``def``s (e.g. classes, annotated assignments) are dropped from the
    generated scripts, and the number of scripts grows factorially with the
    number of functions. Confirm both are intended.
    """
    gt_name = function_name(assertions)
    try:
        tree = ast.parse(code)
    except SyntaxError:
        print("SyntaxError in parsing code")
        return  # a value returned from a generator is invisible to callers

    function_nodes = [n for n in tree.body if isinstance(n, ast.FunctionDef)]
    if not function_nodes:  # no functions in code
        yield code
        return
    if any(n.name == gt_name for n in function_nodes):
        # Target function already exists: nothing to rename.
        yield code
        return

    preamble = [
        ast.unparse(n)
        for n in tree.body
        if isinstance(n, (ast.Import, ast.ImportFrom))
    ]
    preamble += [ast.unparse(n) for n in tree.body if isinstance(n, ast.Assign)]
    # Pair each function's name with its normalized source once, instead of
    # re-parsing the source on every permutation just to read the name.
    functions = [(n.name, ast.unparse(n)) for n in function_nodes]

    for permutation in itertools.permutations(functions):
        old_name = permutation[0][0]
        renamer = Renamer({old_name: gt_name})
        # Re-parse from source so each permutation renames a fresh tree
        # (Renamer mutates the nodes it visits).
        renamed = [
            ast.unparse(renamer.visit(ast.parse(source)))
            for _, source in permutation
        ]
        yield "\n\n".join(preamble + renamed)

def _name_from_assertion(assertion: str) -> str:
    """Return the called name in an ``assert name(...)`` string.

    Takes the text before the first ``(`` and returns its second
    whitespace-separated word, e.g. ``"assert foo(3) == 2"`` -> ``"foo"``.
    Splitting on any whitespace (not a single space) tolerates repeated
    spaces or tabs after ``assert``.
    """
    return assertion.split("(")[0].split()[1]

def function_name(assertions):
    """Extract the expected function name from a set of assertions.

    ``assertions`` is either a list of assertion strings (a single set of
    test cases) or a mapping whose ``"original"`` entry is such a list
    (multiple sets of test cases). The name is read from the first
    assertion. If neither form yields a name, the error is printed and the
    placeholder ``"function_name"`` is returned.
    """
    try:  # original-single test cases
        return _name_from_assertion(assertions[0])
    except (KeyError, IndexError, TypeError, AttributeError):
        # Not a non-empty list of assertion strings: try the mapping form.
        # (No bare except: KeyboardInterrupt/SystemExit must propagate.)
        pass
    try:  # multiple sets of test cases, map of values
        return _name_from_assertion(assertions["original"][0])
    except Exception as e:
        print("Error in extracting function name from assertions:", e)
        return "function_name"

def main():
    """Demo: print every candidate script ``renamings`` yields for a sample."""
    # A syntactically valid sample (two functions, neither named like the
    # asserted ``some_name``), kept for manual experiments:
    # code = """
    # def check_equality(s):
    #     return (ord(s[0]) == ord(s[len(s) - 1]))

    # def some_name(s):
    #     result = 0
    #     n = len(s)
    #     for i in range(n):
    #         for j in range(1,n-i+1):
    #             if (check_equality(s[i:i+j])):
    #                 result+=1
    #     return result
    # """

    # NOTE(review): ``ord(s[0]]`` below is a syntax error, so this demo only
    # exercises the SyntaxError path of ``renamings``. Confirm that is intended.
    code = "import heapq\nR=3\ndef check_equality(s):\n    D=0\n    return (ord(s[0]] == ord(s[len(s) - 1]))"

    assertions = {
        "original": ["assert some_name(3) == 2", "assert some_name(7) == 1"],
        "op_type": ["assert some_name(3) == [2]", "assert some_name(7) == [1]"]
    }
    # ``renamings`` is a generator, so it never returns None. On a syntax
    # error it prints a message and simply yields nothing, so no separate
    # None check is needed.
    for script in renamings(code, assertions):
        print("```")
        print(script)
        print("```")


if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()