import ast
from resp_type import all_type

class RemoveArgs(ast.NodeTransformer):
    """Trim the call in ``assert f(a, b, ...) == x`` down to ``f(a)``."""

    def visit_Assert(self, node: ast.Assert) -> ast.Assert:
        """Keep only the first positional argument of the asserted call.

        The node is modified in place and returned. Asserts whose test is
        not a comparison with a call on its left-hand side, or whose call
        has no positional arguments, are returned unchanged instead of
        raising.
        """
        test = node.test
        if not isinstance(test, ast.Compare) or not isinstance(test.left, ast.Call):
            return node
        call = test.left
        if call.args:
            # Drop everything after the first positional argument (e.g. the
            # number that follows the list in ``f([1, 2, 3], 3)``).
            call.args = call.args[:1]
        return node

def _has_list_and_length_args(assertion):
    """Return True if *assertion* is ``assert f([...], n) == ...`` with n == len of the list."""
    body = ast.parse(assertion).body
    if not body or not isinstance(body[0], ast.Assert):
        return False
    test = body[0].test
    if not isinstance(test, ast.Compare) or not isinstance(test.left, ast.Call):
        return False
    args = test.left.args
    return (
        len(args) == 2
        and isinstance(args[0], ast.List)
        and isinstance(args[1], ast.Constant)
        and isinstance(args[1].value, int)
        and len(args[0].elts) == args[1].value
    )


def all_remove(assertions_updated):
    """Add variants of each assertion list with the list-length argument removed.

    The first entry of *assertions_updated* is used as the reference: only
    when every one of its assertions has the form ``assert f([...], n) == x``
    with ``n`` equal to the number of list elements are new variants
    produced. In that case ``RemoveArgs`` is applied to the assertions of
    every entry, and a new entry keyed
    ``"ip_resp-RemoveArgs--<key without spaces>"`` is added when all of its
    rewritten assertions are new (compared with spaces removed).

    Args:
        assertions_updated: dict mapping a variant name to a list of
            assertion source strings. The ``"original"`` entry supplies the
            expected number of assertions; if it is absent the first entry
            is used instead.

    Returns:
        The same dict object, updated in place.

    Raises:
        SyntaxError: if an assertion string cannot be parsed.
    """
    if not assertions_updated:
        return assertions_updated

    reference = next(iter(assertions_updated.values()))
    if not reference or not all(_has_list_and_length_args(a) for a in reference):
        return assertions_updated

    # Whitespace-insensitive fingerprints of every assertion known so far.
    seen = {
        assertion.replace(" ", "")
        for combination in assertions_updated.values()
        for assertion in combination
    }
    expected_len = len(assertions_updated.get("original", reference))

    # Iterate over a snapshot: new entries are added to the dict in the loop.
    for key, combination in list(assertions_updated.items()):
        resp = []
        for assertion in combination:
            rewritten = ast.unparse(RemoveArgs().visit(ast.parse(assertion)))
            fingerprint = rewritten.replace(" ", "")
            if fingerprint not in seen:
                # Store the same normalized form that is used for the lookup.
                seen.add(fingerprint)
                resp.append(rewritten)
        if len(resp) == expected_len and resp not in assertions_updated.values():
            # Built outside an f-string: reusing double quotes inside an
            # f-string expression is only valid syntax on Python 3.12+.
            new_key = "ip_resp-RemoveArgs--" + str(key).replace(" ", "")
            assertions_updated[new_key] = resp

    return assertions_updated

def main():
    """Run sample assertions through all_type, then all_remove, and print the result."""
    sample_assertions = [
        "assert find_missing([1,2,3,5],4) == 4",
        "assert find_missing([1,3,4,5],4) == 2",
        "assert find_missing([1,2,3,5,6,7],6) == 4",
    ]
    print(all_remove(all_type(sample_assertions)))

    # Alternative sample inputs kept for manual experiments; substitute one
    # for sample_assertions above to try it:
    #   ["assert some_name([2, 3, 4, 9, 10, 11, 12], 8) == True",
    #    "assert some_name([3, 7, 6, 7, 8], 5) == False"]
    #   ["assert some_name([2, 3, 4, 9, 10, 11, 12], 7) == True",
    #    "assert some_name([3, 7, 6, 7, 8], 1) == False"]


# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()