import ast
import itertools
from resp_type import all_type

class PermuteArgs(ast.NodeTransformer):
    """Reorder pieces of ``assert f(...) <op> expected`` statements.

    Every ``ast.Assert`` visited is rewritten in place according to
    ``permutation``: new position ``k`` receives old element
    ``permutation[k]``. Which pieces are reordered depends on:

    * ``op_resp`` -- ``True`` targets the right-hand side of the comparison
      (``node.test.comparators``); ``False`` targets the positional arguments
      of the call on the left-hand side (``node.test.left.args``).
    * ``len_resp`` -- the number of such pieces. When it is 1, the first piece
      must be a tuple/list literal and its *elements* are reordered instead.
      When it is 5 or more the node is left untouched (presumably to bound the
      factorial number of permutations a caller would generate).

    An assertion whose shape does not match the configuration is returned
    unchanged instead of raising. This covers a test that is not a comparison,
    a left side that is not a call, a wrong number of pieces, or a single piece
    that is not a tuple/list.
    """

    def __init__(self, op_resp: bool, permutation: list[int], len_resp: int) -> None:
        super().__init__()
        # True: permute the comparison's right-hand side; False: the call's args.
        self.op_resp = op_resp
        # New position k receives old element permutation[k].
        self.permutation = permutation
        # Number of pieces on the chosen side (1 means "one tuple/list literal").
        self.len_resp = len_resp

    def visit_Assert(self, node: ast.Assert) -> ast.Assert:
        """Permute the selected pieces of *node* in place and return it.

        The node is returned unchanged when ``len_resp >= 5`` or when its
        shape does not match what the configuration expects.
        """
        if self.len_resp >= 5:
            return node

        test = node.test
        if not isinstance(test, ast.Compare):
            return node  # e.g. ``assert f(x)``: no comparison to work on

        if self.op_resp:
            owner, field = test, "comparators"
        elif isinstance(test.left, ast.Call):
            owner, field = test.left, "args"
        else:
            return node  # left-hand side is not a call, so it has no args

        items = getattr(owner, field)
        if self.len_resp == 1:
            # A single response: reorder the elements of its tuple/list literal.
            if not items or not isinstance(items[0], (ast.Tuple, ast.List)):
                return node
            owner, field = items[0], "elts"
            items = owner.elts

        if len(items) != len(self.permutation):
            return node
        setattr(owner, field, [items[i] for i in self.permutation])
        return node

def _strip_spaces(text: str) -> str:
    """Return *text* with every space character removed (the dedupe key form)."""
    return text.replace(" ", "")


def _canonical_key(assertion: str) -> str:
    """Return a formatting-insensitive dedupe key for an assertion string.

    The string is round-tripped through ``ast`` so that it compares equal to
    the ``ast.unparse`` output produced for permuted variants (for example,
    quote style is normalised). Text that does not parse falls back to the raw
    string. Spaces are removed in both cases.
    """
    try:
        assertion = ast.unparse(ast.parse(assertion))
    except (SyntaxError, ValueError):
        pass
    return _strip_spaces(assertion)


def _response_nodes(assertion: str, formatting: bool) -> list:
    """Return the AST nodes a permutation would reorder in *assertion*.

    ``formatting=True`` selects the comparison's right-hand side
    (``test.comparators``). ``False`` selects the positional arguments of the
    call on the left-hand side (``test.left.args``). An empty list is returned
    when the first statement is not an ``assert`` of a comparison, or when
    ``formatting`` is False and the comparison's left side is not a call.
    """
    body = ast.parse(assertion).body
    if not body or not isinstance(body[0], ast.Assert):
        return []
    test = body[0].test
    if not isinstance(test, ast.Compare):
        return []
    if formatting:
        return test.comparators
    return test.left.args if isinstance(test.left, ast.Call) else []


def _common_tuple_length(assertions: list, formatting: bool) -> int:
    """Return the element count shared by every assertion's tuple/list response.

    Only the first response node of each assertion is inspected. Returns 0
    when any assertion has no such node, that node is not a tuple/list
    literal, or the element counts differ between assertions.
    """
    lengths = set()
    for assertion in assertions:
        nodes = _response_nodes(assertion, formatting)
        if not nodes or not isinstance(nodes[0], (ast.Tuple, ast.List)):
            return 0
        lengths.add(len(nodes[0].elts))
    return lengths.pop() if len(lengths) == 1 else 0


def _add_permuted_variants(
    assertions_updated: dict,
    seen: set,
    formatting: bool,
    resp_type: str,
    n_items: int,
    len_resp: int,
) -> None:
    """Add every not-yet-seen permutation of the existing variants, in place.

    Permutations are tried in ``itertools.permutations`` order. Each one is
    applied to every variant present after the previous permutation,
    including variants that permutation produced, so permutations compose. A
    variant is stored only when all of its assertions are unseen.
    """
    n_original = len(assertions_updated["original"])
    identity = tuple(range(n_items))
    prev = assertions_updated.copy()
    for permute in itertools.permutations(identity):
        if permute == identity:
            # The identity can only reproduce existing variants (at most with
            # re-normalised formatting), so it never yields a real new variant.
            continue
        transformer = PermuteArgs(formatting, list(permute), len_resp)
        for key, combination in prev.items():
            resp = []
            for assertion in combination:
                temp = ast.unparse(transformer.visit(ast.parse(assertion)))
                temp_key = _strip_spaces(temp)
                if temp_key not in seen:
                    # NOTE(review): recorded immediately, even if this variant is
                    # rejected below for being incomplete; confirm intended.
                    seen.add(temp_key)
                    resp.append(temp)
            if len(resp) == n_original and resp not in assertions_updated.values():
                new_key = f"{resp_type}-PermuteArgs{permute}--{_strip_spaces(str(key))}"
                assertions_updated[new_key] = resp
        prev = assertions_updated.copy()


def all_permute(assertions_updated):
    """Extend *assertions_updated* with argument/response-order permutations.

    ``assertions_updated`` maps a variant name to a list of assertion strings
    shaped like ``assert f(...) == expected``. It must contain an
    ``"original"`` entry. Its first value is used as the reference assertions
    (presumably the ``"original"`` entry; verify against the caller).

    Two passes are made:

    * ``"op_resp"`` reorders the expected-value side of each comparison;
    * ``"ip_resp"`` reorders the call's positional arguments.

    In each pass, if the reference assertion has exactly one such piece, the
    elements of that tuple/list literal are permuted. This requires the
    literal to have the same length in every assertion. Otherwise the pieces
    themselves are permuted. Only 2-4 items are ever permuted, because the
    number of permutations grows factorially. Each new variant whose
    assertions are all unseen (ignoring spaces and formatting) is added under
    ``"<pass>-PermuteArgs<permutation>--<source key without spaces>"``.

    The dict is modified in place and also returned.
    """
    if not assertions_updated:
        return assertions_updated
    assertions = next(iter(assertions_updated.values()))
    if not assertions:
        return assertions_updated

    # A set of canonical keys gives O(1) membership tests.
    seen = {
        _canonical_key(assertion)
        for combination in assertions_updated.values()
        for assertion in combination
    }

    for formatting, resp_type in ((True, "op_resp"), (False, "ip_resp")):
        nodes = _response_nodes(assertions[0], formatting)
        if len(nodes) == 1:
            # Single response: permute the elements of its tuple/list literal.
            len_resp = 1
            n_items = _common_tuple_length(assertions, formatting)
        else:
            # Several arguments/comparators: permute them directly.
            len_resp = n_items = len(nodes)
        # Fewer than 2 items have no non-trivial ordering. 5 or more would
        # explode factorially, and PermuteArgs leaves such assertions as-is anyway.
        if not 2 <= n_items < 5:
            continue
        _add_permuted_variants(assertions_updated, seen, formatting, resp_type, n_items, len_resp)

    return assertions_updated

def main():
    """Expand two sample assertions with all_type, then all_permute, and print the result."""
    samples = [
        "assert some_name([2, 3, 4, 9, 10, 11, 12]) == (2,[3])",
        "assert some_name([3, 7, 6, 7, 8,9,7]) == (1,[2,3])",
    ]
    print(all_permute(all_type(samples)))


if __name__ == "__main__":
    # Run the demo only when this file is executed directly, not on import.
    main()