import ast
import json
from typing import Any

class ListToTuple(ast.NodeTransformer):
    """Rewrite a lone list literal inside an ``assert`` into a tuple literal.

    With ``op_resp=True`` the comparator of the assertion (the expected value,
    e.g. ``== [3, 4]``) is rewritten. With ``op_resp=False`` the positional
    argument of the call on the left-hand side is rewritten. Only a single
    operand is touched; when there are several, nothing changes.

    Assertions of any other shape (``assert f(x)``, or ``assert x == [1]`` in
    argument mode) are returned unchanged instead of raising AttributeError.
    """

    def __init__(self, op_resp: bool) -> None:
        # True: rewrite the comparators; False: rewrite the call's positional args.
        self.op_resp = op_resp

    def _operands(self, node: ast.Assert) -> "list[ast.expr] | None":
        """Return the node's own operand list to rewrite in place, or None if unsupported."""
        test = node.test
        if self.op_resp:
            return test.comparators if isinstance(test, ast.Compare) else None
        left = getattr(test, "left", None)
        return left.args if isinstance(left, ast.Call) else None

    def visit_Assert(self, node: ast.Assert) -> Any:
        operands = self._operands(node)
        if operands is not None and len(operands) == 1:
            value = operands[0]
            if isinstance(value, ast.List):
                # ``operands`` is the node's own list, so this edits the tree in place.
                operands[0] = ast.Tuple(elts=value.elts, ctx=value.ctx)
        return node

class TupleToList(ast.NodeTransformer):
    """Rewrite a lone tuple literal inside an ``assert`` into a list literal.

    With ``op_resp=True`` the comparator of the assertion (the expected value)
    is rewritten. With ``op_resp=False`` the positional argument of the call on
    the left-hand side is rewritten. Only a single operand is touched; when
    there are several, nothing changes.

    Assertions of any other shape (``assert f(x)``, or ``assert x == (1, 2)``
    in argument mode) are returned unchanged instead of raising AttributeError.
    """

    def __init__(self, op_resp: bool) -> None:
        # True: rewrite the comparators; False: rewrite the call's positional args.
        self.op_resp = op_resp

    def _operands(self, node: ast.Assert) -> "list[ast.expr] | None":
        """Return the node's own operand list to rewrite in place, or None if unsupported."""
        test = node.test
        if self.op_resp:
            return test.comparators if isinstance(test, ast.Compare) else None
        left = getattr(test, "left", None)
        return left.args if isinstance(left, ast.Call) else None

    def visit_Assert(self, node: ast.Assert) -> Any:
        operands = self._operands(node)
        if operands is not None and len(operands) == 1:
            value = operands[0]
            if isinstance(value, ast.Tuple):
                # ``operands`` is the node's own list, so this edits the tree in place.
                operands[0] = ast.List(elts=value.elts, ctx=value.ctx)
        return node

class StrToNum(ast.NodeTransformer):
    """Rewrite string literals that ``int()`` accepts (``'7'``) as int literals (``7``).

    With ``op_resp=True`` the comparators of the assertion are rewritten. With
    ``op_resp=False`` the positional arguments of the call on the left-hand
    side are rewritten.

    A single operand is converted whenever it is a convertible str constant.
    Several operands are only touched when they are all constants of the same
    Python type. Strings that ``int()`` rejects stay as they are.

    Assertions with no operands (``assert f() == '5'`` in argument mode) or of
    an unsupported shape (``assert f('5')``) are returned unchanged.
    """

    def __init__(self, op_resp: bool) -> None:
        # True: rewrite the comparators; False: rewrite the call's positional args.
        self.op_resp = op_resp

    def _operands(self, node: ast.Assert) -> "list[ast.expr] | None":
        """Return the node's own operand list to rewrite in place, or None if unsupported."""
        test = node.test
        if self.op_resp:
            return test.comparators if isinstance(test, ast.Compare) else None
        left = getattr(test, "left", None)
        return left.args if isinstance(left, ast.Call) else None

    @staticmethod
    def _same_type_constants(values: "list[ast.expr]") -> bool:
        """True if every node is an ast.Constant and all their values share one type."""
        if not all(isinstance(value, ast.Constant) for value in values):
            return False
        first_type = type(values[0].value)
        return all(type(value.value) is first_type for value in values)

    @staticmethod
    def _convert(value: ast.expr) -> ast.expr:
        """Return an int Constant for a str Constant int() accepts, else *value* itself."""
        if isinstance(value, ast.Constant) and isinstance(value.value, str):
            try:
                return ast.Constant(value=int(value.value))
            except ValueError:
                pass  # not an integer literal: keep the string
        return value

    def visit_Assert(self, node: ast.Assert) -> Any:
        operands = self._operands(node)
        if not operands:
            # Unsupported shape, or nothing to rewrite (e.g. a zero-argument call).
            return node
        if len(operands) > 1 and not self._same_type_constants(operands):
            return node
        # Slice assignment keeps the node's own list object, editing the tree in place.
        operands[:] = [self._convert(value) for value in operands]
        return node

class NumToStr(ast.NodeTransformer):
    """Rewrite int literals (``7``) as string literals (``'7'``); bools are left alone.

    With ``op_resp=True`` the comparators of the assertion are rewritten. With
    ``op_resp=False`` the positional arguments of the call on the left-hand
    side are rewritten.

    A single operand is converted whenever it is a non-bool int constant.
    Several operands are only touched when they are all constants of the same
    Python type. Bools are excluded in both cases; previously the
    multi-operand path turned ``True`` into ``'True'``.

    Assertions with no operands or of an unsupported shape are returned unchanged.
    """

    def __init__(self, op_resp: bool) -> None:
        # True: rewrite the comparators; False: rewrite the call's positional args.
        self.op_resp = op_resp

    def _operands(self, node: ast.Assert) -> "list[ast.expr] | None":
        """Return the node's own operand list to rewrite in place, or None if unsupported."""
        test = node.test
        if self.op_resp:
            return test.comparators if isinstance(test, ast.Compare) else None
        left = getattr(test, "left", None)
        return left.args if isinstance(left, ast.Call) else None

    @staticmethod
    def _same_type_constants(values: "list[ast.expr]") -> bool:
        """True if every node is an ast.Constant and all their values share one type."""
        if not all(isinstance(value, ast.Constant) for value in values):
            return False
        first_type = type(values[0].value)
        return all(type(value.value) is first_type for value in values)

    @staticmethod
    def _convert(value: ast.expr) -> ast.expr:
        """Return a str Constant for a non-bool int Constant, else *value* itself."""
        # bool subclasses int, but True/False are not treated as numbers here.
        if (
            isinstance(value, ast.Constant)
            and isinstance(value.value, int)
            and not isinstance(value.value, bool)
        ):
            try:
                return ast.Constant(value=str(value.value))
            except ValueError:
                pass  # int exceeds the interpreter's int->str digit limit
        return value

    def visit_Assert(self, node: ast.Assert) -> Any:
        operands = self._operands(node)
        if not operands:
            # Unsupported shape, or nothing to rewrite (e.g. a zero-argument call).
            return node
        if len(operands) > 1 and not self._same_type_constants(operands):
            return node
        # Slice assignment keeps the node's own list object, editing the tree in place.
        operands[:] = [self._convert(value) for value in operands]
        return node

class StrToList(ast.NodeTransformer):
    """Rewrite a string literal holding a JSON array (``'[1, 2]'``) as a list literal.

    With ``op_resp=True`` the single comparator of the assertion is examined.
    With ``op_resp=False`` the single positional argument of the call on the
    left-hand side is examined.

    Only a str constant starting with ``[`` is considered. If it is not valid
    JSON (e.g. ``'[abc'``) it is left as a string. Previously this case crashed
    with NameError on the undefined name ``br``.
    """

    def __init__(self, op_resp: bool) -> None:
        # True: rewrite the comparators; False: rewrite the call's positional args.
        self.op_resp = op_resp

    def _operands(self, node: ast.Assert) -> "list[ast.expr] | None":
        """Return the node's own operand list to rewrite in place, or None if unsupported."""
        test = node.test
        if self.op_resp:
            return test.comparators if isinstance(test, ast.Compare) else None
        left = getattr(test, "left", None)
        return left.args if isinstance(left, ast.Call) else None

    def visit_Assert(self, node: ast.Assert) -> Any:
        operands = self._operands(node)
        if operands is None or len(operands) != 1:
            return node
        value = operands[0]
        if not (
            isinstance(value, ast.Constant)
            and isinstance(value.value, str)
            and value.value.startswith("[")
        ):
            return node
        try:
            # JSON text starting with "[" can only decode to a list.
            items = json.loads(value.value)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return node
        operands[0] = ast.List(
            elts=[ast.Constant(value=item) for item in items], ctx=ast.Load()
        )
        return node

class SingleStrToList(ast.NodeTransformer):
    """Wrap a lone scalar literal in a one-element list (``7`` -> ``[7]``, ``'a'`` -> ``['a']``).

    With ``op_resp=True`` the single comparator of the assertion is examined.
    With ``op_resp=False`` the single positional argument of the call on the
    left-hand side is examined.

    Only non-bool ints and non-empty strings that do not start with ``[`` are
    wrapped. Strings starting with ``[`` are left for StrToList. Assertions of
    an unsupported shape are returned unchanged.
    """

    def __init__(self, op_resp: bool) -> None:
        # True: rewrite the comparators; False: rewrite the call's positional args.
        self.op_resp = op_resp

    def _operands(self, node: ast.Assert) -> "list[ast.expr] | None":
        """Return the node's own operand list to rewrite in place, or None if unsupported."""
        test = node.test
        if self.op_resp:
            return test.comparators if isinstance(test, ast.Compare) else None
        left = getattr(test, "left", None)
        return left.args if isinstance(left, ast.Call) else None

    def visit_Assert(self, node: ast.Assert) -> Any:
        operands = self._operands(node)
        if operands is None or len(operands) != 1:
            return node
        value = operands[0]
        if not isinstance(value, ast.Constant):
            return node
        scalar = value.value
        if isinstance(scalar, bool):
            # bool subclasses int, but True/False are never wrapped.
            return node
        wrap = isinstance(scalar, int) or (
            isinstance(scalar, str) and bool(scalar) and not scalar.startswith("[")
        )
        if wrap:
            operands[0] = ast.List(elts=[ast.Constant(value=scalar)], ctx=ast.Load())
        return node


def all_type(assertions):
    """Generate type-shifted variants of a list of ``assert`` statement strings.

    Every transformer is applied first with ``op_resp=True`` (rewriting the
    comparators) and then with ``op_resp=False`` (rewriting the call
    arguments). Each transformer runs over every combination recorded so far,
    including those produced by earlier transformers, so rewrites can chain.

    A new combination is recorded only when every assertion in it became a
    string (compared with spaces removed) not seen before, and the resulting
    list is not already recorded.

    Args:
        assertions: source strings of ``assert`` statements.

    Returns:
        dict mapping a name to a list of assertion strings. ``"original"`` maps
        to ``assertions``. Other keys have the form
        ``"<op_resp|ip_resp>-<TransformerName>--<parent key>"`` and record the
        chain of rewrites that produced the combination.
    """

    def normalize(text):
        # Spaces are ignored when deciding whether an assertion is new.
        return text.replace(" ", "")

    # Normalized form of every assertion produced so far, originals included.
    # Previously generated variants were stored un-normalized while lookups
    # used the normalized form, so they were never recognised as seen.
    seen = {normalize(a) for a in assertions}
    assertions_updated = {"original": assertions}
    prev_assertions_updated = assertions_updated.copy()

    translator_groups = [
        [ListToTuple, TupleToList],
        [StrToNum, NumToStr],
        [StrToList],
        [SingleStrToList],
    ]
    for formatting, resp_type in zip([True, False], ["op_resp", "ip_resp"]):
        for option in translator_groups:
            for translate in option:
                transformer = translate(formatting)
                for key, combination in prev_assertions_updated.items():
                    resp = []
                    for assertion in combination:
                        temp = ast.unparse(transformer.visit(ast.parse(assertion)))
                        normalized = normalize(temp)
                        if normalized not in seen:
                            seen.add(normalized)
                            resp.append(temp)
                    if len(resp) == len(assertions) and resp not in assertions_updated.values():
                        assertions_updated[f"{resp_type}-{translate.__name__}--{key}"] = resp
                # The next transformer also sees what this one produced.
                prev_assertions_updated = assertions_updated.copy()

    return assertions_updated

def main():
    """Demo: print every variant ``all_type`` derives from a small sample suite."""
    # Alternative sample suites:
    #   ["assert some_name([2, 3,0]) == [2, 3]", "assert some_name([2, 3,9,8]) == [1,7]"]
    #   ["assert some_name(3, [2]) == (2,3)", "assert some_name(1,[3]) == (1,2,3)"]
    sample_suite = [
        "assert some_name('7') == True",
        "assert some_name('0') == False",
    ]
    variants = all_type(sample_suite)
    print(variants)


# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()