"""Example logic-form tests that exercise HMTExecutor operations on a sample table."""

from qa.table_bert.hm_table import HMTable
from qa.nsm.execution.executor import HMTExecutor, AGGREGATION2ID


def _start_scenario(title, kg):
    """Print a scenario header and build a fresh table and executor from ``kg``.

    Every scenario rebuilds the table and executor rather than reusing the
    objects from the previous scenario.

    Args:
        title: Scenario label, printed as ``test <title>`` under a separator.
        kg: Table description passed to ``HMTable.from_dict``.

    Returns:
        Tuple ``(hmt, hmt_executor)``. The executor's op region has already
        been initialised from ``hmt.get_init_op_region()`` and printed once.
    """
    print(f"----------------------\ntest {title}")
    hmt = HMTable.from_dict(kg)
    hmt_executor = HMTExecutor(hmt.build_kg_info())
    hmt_executor.init_op_region(hmt.get_init_op_region())
    print(hmt_executor.op_region)
    return hmt, hmt_executor


def _print_top_sum(hmt):
    """Print the entries stored under ``'sum'`` in ``hmt.top_aggr`` on one line.

    Each entry is printed as ``<value><left_coord>``, or ``None`` for an empty
    slot. Every entry is followed by ``", "``, and the line ends with a newline.
    """
    # NOTE(review): top_aggr is presumably populated by the executor's
    # ``sum('<TOP>')`` call, which every caller makes first — confirm in HMTExecutor.
    for data_node in hmt.top_aggr[AGGREGATION2ID['sum']]:
        if data_node is not None:
            print(f"{data_node.value}{data_node.left_coord}", end=", ")
        else:
            print(None, end=", ")
    print()


def _test_filter_header_str_contain(kg):
    """Chain three ``filter_header_str_contain`` calls, printing the op region after each."""
    _, hmt_executor = _start_scenario("filter_header_str_contain", kg)
    hmt_executor.filter_header_str_contain('area_a', '<LEFT>')
    print(hmt_executor.op_region)
    hmt_executor.filter_header_str_contain('male', '<TOP>')
    print(hmt_executor.op_region)
    hmt_executor.filter_header_str_contain('married', '<TOP>')
    print(hmt_executor.op_region)


def _test_filter_header_str_not_contain(kg):
    """Mix ``filter_header_str_not_contain`` and ``filter_header_str_contain`` calls."""
    _, hmt_executor = _start_scenario("filter_header_str_not_contain", kg)
    hmt_executor.filter_header_str_not_contain('area_b', '<LEFT>')
    print(hmt_executor.op_region)
    hmt_executor.filter_header_str_contain('west_city_a', '<LEFT>')
    print(hmt_executor.op_region)
    hmt_executor.filter_header_str_not_contain('married', '<TOP>')
    print(hmt_executor.op_region)


def _test_argmin(kg):
    """Filter, then ``argmin``, e.g. "which city has the least married male persons?"."""
    _, hmt_executor = _start_scenario("argmin", kg)
    hmt_executor.filter_header_str_contain('married', '<TOP>')
    print(hmt_executor.op_region)
    hmt_executor.filter_header_str_contain('male', '<TOP>')
    print(hmt_executor.op_region)
    # NOTE(review): meaning of the argument ``2`` is defined by HMTExecutor — confirm.
    result = hmt_executor.argmin(2)
    print(hmt_executor.op_region)
    print(result)


def _test_aggregation(kg):
    """Filter, then ``sum`` over ``<TOP>``, e.g. "list number of single persons of each city"."""
    hmt, hmt_executor = _start_scenario("aggregation", kg)
    hmt_executor.filter_header_str_contain('single', '<TOP>')
    print(hmt_executor.op_region)
    hmt_executor.sum('<TOP>')
    print(hmt_executor.op_region)
    _print_top_sum(hmt)


def _test_aggregation_and_argmax(kg):
    """Filter, ``sum``, then ``argmax``, e.g. "which city has the most single persons?"."""
    hmt, hmt_executor = _start_scenario("aggregation and argmax", kg)
    hmt_executor.filter_header_str_contain('single', '<TOP>')
    print(hmt_executor.op_region)
    hmt_executor.sum('<TOP>')
    print(hmt_executor.op_region)
    _print_top_sum(hmt)
    # NOTE(review): meaning of the argument ``2`` is defined by HMTExecutor — confirm.
    result = hmt_executor.argmax(2)
    print(hmt_executor.op_region)
    print("result: ", result)


def _test_hop(kg):
    """Filter, ``sum`` over ``<TOP>``, then ``hop`` along ``<LEFT>``."""
    hmt, hmt_executor = _start_scenario("hop", kg)
    hmt_executor.filter_header_str_contain('area_a', '<LEFT>')
    print(hmt_executor.op_region)
    hmt_executor.sum('<TOP>')
    print(hmt_executor.op_region)
    _print_top_sum(hmt)
    result = hmt_executor.hop('<LEFT>')
    print(hmt_executor.op_region)
    print(result)


def _test_union(kg):
    """Union two filter results, then ``sum`` and ``difference``.

    Example question: "difference between single persons of west_city_a and mid_city_b".
    """
    hmt, hmt_executor = _start_scenario("union", kg)
    # The filter calls' return values are kept and passed to ``union`` below.
    v0 = hmt_executor.filter_header_str_contain('west_city_a', '<LEFT>')
    print(hmt_executor.op_region)
    print(f"v0: {v0}")
    v1 = hmt_executor.filter_header_str_contain('mid_city_b', '<LEFT>')
    print(hmt_executor.op_region)
    print(f"v1: {v1}")
    hmt_executor.union(v0, v1)
    print(hmt_executor.op_region)
    hmt_executor.filter_header_str_contain('single', '<TOP>')
    print(hmt_executor.op_region)
    hmt_executor.sum('<TOP>')
    print(hmt_executor.op_region)
    _print_top_sum(hmt)
    result = hmt_executor.difference()
    print(hmt_executor.op_region)
    print("result: ", result)


def main():
    """Run every logic-form scenario against the example table ``KG``.

    Scenarios run in a fixed order. Each one prints a header, then the op
    region after every operation. Each scenario rebuilds its own table and
    executor.
    """
    # Imported inside main, so importing this module does not load the example data.
    from qa.examples.example_data import KG

    _test_filter_header_str_contain(KG)
    _test_filter_header_str_not_contain(KG)
    _test_argmin(KG)
    _test_aggregation(KG)
    _test_aggregation_and_argmax(KG)
    _test_hop(KG)
    _test_union(KG)


# Run all example scenarios only when this file is executed as a script.
if __name__ == '__main__':
    main()
