from argparse import ArgumentParser

from qa.table_bert.hm_table import *
from preprocess.utils import *


def write_table2url(table_dict, record_dict, table_filtered_path=None, filtered=True, output_path=None):
    """Write an .xlsx index mapping each table to its local file, URL and domain.

    Args:
        table_dict: mapping whose values are table objects exposing ``table_id``,
            ``table_id_in_page`` and ``url``; rows are written in sorted key order.
        record_dict: mapping ``table_id -> list of records``; the first record's
            ``'subject'`` is written to the Domain column.
        table_filtered_path: directory holding the tables that passed filtering.
            Required when ``filtered`` is True.
        filtered: when True, tables with no file (by stem) in ``table_filtered_path``
            are written as plain text with a ``(--)`` marker instead of a hyperlink.
        output_path: where to save the workbook. Defaults to
            ``<args.root_dir>/<args.data_dir>/table_url_pairs.xlsx`` (module-level ``args``).

    Raises:
        ValueError: if ``filtered`` is True but ``table_filtered_path`` is None.
    """
    valid_table_names = set()
    if filtered:
        if table_filtered_path is None:
            # os.listdir(None) would not fail; it would silently list the CWD.
            raise ValueError("table_filtered_path is required when filtered=True")
        # Set membership keeps the per-row lookup below O(1).
        valid_table_names = {t.split('.')[0] for t in os.listdir(table_filtered_path)}

    wb = Workbook()
    ws = wb.active
    font = Font(bold=True)
    alignment = Alignment(horizontal='left')
    headers = ('Local .xlsx Path', 'Table Id', 'Table Id in Page', 'URL', 'Domain')
    for col, title in enumerate(headers, start=1):
        header_cell = ws.cell(1, col)
        header_cell.value = title
        header_cell.font = font

    for row, key in enumerate(sorted(table_dict), start=2):
        table = table_dict[key]
        table_id = table.table_id
        if filtered and str(table_id) not in valid_table_names:
            # Filtered-out table: plain text name with a "(--)" marker, no link.
            ws.cell(row, 1).value = f'{table_id}.xlsx(--)'
        else:
            ws.cell(row, 1).value = f'=HYPERLINK("spreadsheet/{table_id}.xlsx", "{table_id}.xlsx")'
        ws.cell(row, 1).alignment = alignment
        ws.cell(row, 2).value = table_id
        ws.cell(row, 2).alignment = alignment
        ws.cell(row, 3).value = "Table " + str(table.table_id_in_page)
        ws.cell(row, 4).value = table.url
        ws.cell(row, 4).hyperlink = table.url
        ws.cell(row, 5).value = record_dict[table_id][0]['subject']
    auto_fit_column_width(ws)
    if output_path is None:
        output_path = os.path.join(args.root_dir, args.data_dir, 'table_url_pairs.xlsx')
    wb.save(output_path)


def find_max_line_idx(root: TreeNode):
    """Return the largest ``line_idx`` found anywhere in the tree under *root*.

    The traversal is delegated to ``_find_max_line_idx``, which folds values
    into a one-element list. Nodes whose ``line_idx`` is None are ignored. The
    result is 0 when no node carries a non-negative index.
    """
    holder = [0]
    _find_max_line_idx(root, holder)
    (result,) = holder
    return result


def _find_max_line_idx(root: TreeNode, max_line_idx):
    """Fold the largest ``line_idx`` in the tree under *root* into ``max_line_idx[0]``.

    ``max_line_idx`` is a one-element list used as an accumulator and updated in
    place; nodes whose ``line_idx`` is None are skipped. Returns None.
    """
    pending = [root]
    while pending:
        node = pending.pop()
        idx = node.line_idx
        if idx is not None and idx > max_line_idx[0]:
            max_line_idx[0] = idx
        pending.extend(node.children)


def valid_headers(root: TreeNode):
    """Return True if *root* and every descendant have a valid ``name``.

    Each node's name is checked with ``valid_header`` (length/encoding rules).
    Checking stops at the first invalid node: the node itself first, then its
    children in order.
    """
    return valid_header(root.name) and all(valid_headers(child) for child in root.children)


def valid_header(header):
    """Return True if a single header value is acceptable.

    Non-string values are converted with ``str`` first. The header is rejected
    when ``valid_ascii_number`` finds more than two characters with code point
    >= 128, or when it has more than 20 whitespace-separated words.
    """
    text = header if isinstance(header, str) else str(header)
    return valid_ascii_number(text, 2) and len(text.split()) <= 20


def valid_ascii_number(s, max_ascii):
    """Return False as soon as *s* contains more than *max_ascii* non-ASCII characters.

    Despite the name, this counts characters whose code point is >= 128. A
    string with no such characters is always accepted, even if ``max_ascii``
    is negative.
    """
    non_ascii_chars = (ch for ch in s if ord(ch) >= 128)
    for seen, _ in enumerate(non_ascii_chars, start=1):
        if seen > max_ascii:
            return False
    return True


def _check_hm_table(hmt):
    """Raise ValueError if a parsed HMTable fails any of filters 3)-6) below."""
    # filter 3) left_header/top_header is empty
    if len(hmt.left_root.children) == 0 or len(hmt.top_root.children) == 0:
        raise ValueError("Left/Top header is empty.")
    n_rows = len(hmt.data_region)
    # Guarded so an empty data region fails filter 4 with a clear message
    # instead of an IndexError.
    n_cols = len(hmt.data_region[0]) if n_rows else 0
    # filter 4) hierarchy line_idx not match data region size
    if find_max_line_idx(hmt.top_root) != n_cols - 1 \
            or find_max_line_idx(hmt.left_root) != n_rows - 1:
        raise ValueError("Hierarchy line idx does not match data region.")
    # filter 5) rows/cols > 64 or < 2 -- both dimensions must be checked.
    if not (2 <= n_rows <= 64 and 2 <= n_cols <= 64):
        raise ValueError("Rows/Cols < 2 or > 64.")
    # filter 6) header tokens > 20 or non-ASCII header
    if not (valid_headers(hmt.left_root) and valid_headers(hmt.top_root)):
        raise ValueError("Header invalid.")


def main():
    """Copy every table JSON that passes all filters into the filtered directory.

    Reads files from ``<root_dir>/<data_dir>/<table_dir>`` and parses each one
    with ``HMTable.from_dict``. Tables that pass ``_check_hm_table`` are
    re-serialized with ``indent=2`` into ``<root_dir>/<data_dir>/<table_filtered_dir>``.
    A table that fails a filter is reported with ``print`` and skipped. A JSON
    file that cannot be read or decoded still aborts the run.
    """
    src_dir = os.path.join(args.root_dir, args.data_dir, args.table_dir)
    dst_dir = os.path.join(args.root_dir, args.data_dir, args.table_filtered_dir)
    # Without this, every write below would raise FileNotFoundError and each
    # table would be reported as filtered out.
    os.makedirs(dst_dir, exist_ok=True)

    valid_table_names = []
    for table_file in os.listdir(src_dir):
        name = table_file.split('.')[0]
        if '_' not in name:  # TODO: files whose stem has no '_' are skipped; reason not documented here
            continue
        print(f"processing {name}")
        with open(os.path.join(src_dir, table_file), encoding='utf-8') as f:
            table_info = json.load(f)
        try:
            # filter 1) left_header/top_header level > 3;
            #        2) too complex to extract hierarchy
            # These are expected to raise inside HMTable.from_dict. It gets a deep
            # copy, so the untouched original can be written out if the table passes.
            hmt = HMTable.from_dict(copy.deepcopy(table_info))
            _check_hm_table(hmt)
            with open(os.path.join(dst_dir, table_file), 'w', encoding='utf-8') as f:
                f.write(json.dumps(table_info, indent=2))
        except Exception as e:
            # Best-effort filtering: any failure just excludes this table.
            print(e)
        else:
            # Counted only once the filtered copy has actually been written.
            valid_table_names.append(name)
    print(f"Successfully init table: {len(valid_table_names)}")
    # # write table_url_pairs.xlsx
    # table_dict_raw, record_dict = \
    #     read_annotated(os.path.join(args.root_dir, args.data_dir, args.anno_file_name), args.num_descriptions)
    # write_table2url(table_dict_raw, record_dict, dst_dir, filtered=True)


if __name__ == '__main__':
    parser = ArgumentParser()
    # All active options are plain string paths with defaults.
    for _flag, _default in (
        ('--root_dir', '/data/home/hdd3000/USER/HMT/'),
        ('--data_dir', 'qa/data/'),
        ('--table_dir', 'raw_input/table'),
        ('--table_filtered_dir', 'raw_input/table_filtered'),
    ):
        parser.add_argument(_flag, type=str, default=_default)
    # Only needed by the commented-out table_url_pairs.xlsx export in main().
    # parser.add_argument('--anno_file_name', type=str, required=True)
    # parser.add_argument('--num_descriptions', type=int, required=True)
    args = parser.parse_args()

    main()
