import os
import json
import numpy as np
import pandas as pd
from typing import List, Dict
from argparse import ArgumentParser
import copy
from openpyxl.utils import get_column_letter, column_index_from_string

from processing.base.crawler import Table
from processing.base.structure import cellmatrix2structure, Cell
from processing.base.utils import *


# Scan window (1-based, inclusive) used by the border/title/column searches in
# this module: rows past MAX_ROW_THRESHOLD and columns past MAX_COL_THRESHOLD
# are never inspected.
MAX_ROW_THRESHOLD, MAX_COL_THRESHOLD = 150, 100

def enhance_borders_on_both_sides(ws, max_row=None, max_col=None):
    """Mirror every styled border onto the neighbouring cell that shares the edge.

    For each cell in rows ``1..max_row`` and columns ``1..max_col``:

    * a styled bottom border gives the cell below a thin top border,
    * a styled top border gives the cell above a thin bottom border,
    * a styled left border gives the cell to the left a thin right border,
    * a styled right border gives the cell to the right a thin left border.

    After this pass, later code only needs to look at one side of an edge.
    Neighbours outside the scanned window are never touched.

    Args:
        ws: worksheet, modified in place.
        max_row: last row to scan (inclusive); defaults to MAX_ROW_THRESHOLD.
        max_col: last column to scan (inclusive); defaults to MAX_COL_THRESHOLD.
    """
    if max_row is None:
        max_row = MAX_ROW_THRESHOLD
    if max_col is None:
        max_col = MAX_COL_THRESHOLD
    border_type = Side(style='thin')

    def _set_thin_side(target, side_name):
        # Replace the target's border, keeping copies of its other three sides
        # and forcing ``side_name`` to a thin line.
        current = target.border
        sides = {name: copy.deepcopy(getattr(current, name))
                 for name in ('left', 'right', 'top', 'bottom') if name != side_name}
        sides[side_name] = border_type
        target.border = Border(**sides)

    for row in range(1, max_row + 1):
        for col in range(1, max_col + 1):
            cell = ws.cell(row, col)
            if cell.border.bottom.style is not None and row + 1 <= max_row:
                _set_thin_side(ws.cell(row + 1, col), 'top')
            if cell.border.top.style is not None and row - 1 >= 1:
                _set_thin_side(ws.cell(row - 1, col), 'bottom')
            if cell.border.left.style is not None and col - 1 >= 1:
                _set_thin_side(ws.cell(row, col - 1), 'right')
            # Bounded like the other three directions so no cell beyond the
            # scan window (column max_col + 1) is created or restyled.
            if cell.border.right.style is not None and col + 1 <= max_col:
                _set_thin_side(ws.cell(row, col + 1), 'left')


def find_title_and_footer(ws, max_row=None):
    """Locate the title, the footer start and the description start in column A.

    Layout assumed from the scan logic:

    * The title is made of the column-A rows above the first row whose column-A
      cell has a top border. Non-empty values are joined with single spaces.
    * The description block starts at the row whose column-A value is exactly
      'table descriptive sentence id:'.
    * The footer starts right after the last row above the description whose
      column-A cell has a bottom border.

    Args:
        ws: worksheet to scan.
        max_row: last row searched; defaults to MAX_ROW_THRESHOLD.

    Returns:
        ``((title, title_end), ('', footer_start), desc_start)``. Returns None
        when no bordered table start or no description row is found within
        ``max_row`` rows.
    """
    if max_row is None:
        max_row = MAX_ROW_THRESHOLD

    # Title: everything above the first top border in column A. Bounded so a
    # sheet without borders cannot make this loop run forever.
    row = 1
    title_list = []
    while ws.cell(row, 1).border.top.style is None:
        if row > max_row:
            return None
        value = ws.cell(row, 1).value
        if value is not None:  # blank title rows would break str.join
            title_list.append(str(value))
        row += 1
    title_end = row - 1
    title = ' '.join(title_list)

    # Description start.
    while row <= max_row and ws.cell(row, 1).value != 'table descriptive sentence id:':
        row += 1
    if row > max_row:  # no annotation
        return None
    desc_start = row

    # Footer: walk back to the last bottom border. Never go above the title
    # (row 0 does not exist in a worksheet).
    row -= 1
    while row > title_end and ws.cell(row, 1).border.bottom.style is None:
        row -= 1
    footer_start = row + 1
    return (title, title_end), ('', footer_start), desc_start


def find_table_max_columns(ws, table_start_row, table_end_row, max_col=None):
    """Estimate how many columns the table spans.

    For each of the last (up to) three table rows, find the right-most column
    holding a non-blank value. The most frequent such column wins; ties go to
    the smaller column index.

    Args:
        ws: worksheet to scan.
        table_start_row: first table row (1-based).
        table_end_row: last table row (1-based, inclusive).
        max_col: right-most column inspected; defaults to MAX_COL_THRESHOLD.

    Returns:
        The 1-based index of the table's last column, as a plain ``int``.

    Raises:
        ValueError: if none of the inspected rows contains a non-blank value.
    """
    if max_col is None:
        max_col = MAX_COL_THRESHOLD
    end_col_list = []
    for row in range(max(table_start_row, table_end_row - 2), table_end_row + 1):
        for col in range(max_col, 0, -1):
            value = ws.cell(row, col).value
            if value is not None and str(value).strip() != '':
                end_col_list.append(col)
                break
    if not end_col_list:
        raise ValueError(
            f"no non-empty cell found in rows {table_start_row}-{table_end_row} "
            f"within the first {max_col} columns")
    # argmax returns the first maximum, i.e. the smallest column among ties.
    return int(np.argmax(np.bincount(end_col_list)))


def is_merged_cell(ws, ws_cell):
    """Return True when ``ws_cell`` lies inside any merged range of ``ws``."""
    return any(ws_cell.coordinate in merged_range for merged_range in ws.merged_cell_ranges)


def line_index_after_removed_empty_lines(line_idx, empty_lines):
    """Translate a row/column index into its index once ``empty_lines`` are deleted.

    An index that itself names a removed line is first moved back to the
    nearest preceding kept line. The result is then shifted left by the number
    of removed lines before it.
    """
    kept_idx = line_idx
    while kept_idx in empty_lines:
        kept_idx -= 1
    removed_before = sum(1 for removed in empty_lines if removed < kept_idx)
    return kept_idx - removed_before


def parse_merged_region_to_table_index(ws, ws_cell, empty_rows, empty_cols, table_start_row):
    """Return the table-coordinate bounds and text of the merged region holding ``ws_cell``.

    Returns:
        ``((start_row, end_row, start_col, end_col), corecell_string)``. The
        bounds are zero-based table indices computed after empty rows/columns
        are removed. The text is the value of the region's top-left worksheet
        cell.
    """
    region = next((rng for rng in ws.merged_cell_ranges if ws_cell.coordinate in rng), None)
    assert region is not None
    corecell_string = ws.cell(region.min_row, region.min_col).value

    def to_table(index, removed, origin):
        # compact away removed lines, then make the index relative to ``origin``
        return line_index_after_removed_empty_lines(index, removed) - origin

    bounds = (
        to_table(region.min_row, empty_rows, table_start_row),
        to_table(region.max_row, empty_rows, table_start_row),
        to_table(region.min_col, empty_cols, 1),
        to_table(region.max_col, empty_cols, 1),
    )
    return bounds, corecell_string


def find_empty_lines(ws, table_start_row, table_end_row, table_max_columns):
    """Find blank rows/columns in the table and move their borders to neighbours.

    A cell counts as blank when its value is None, or when its stripped string
    form is '' or 'i'. A row (column) is empty when every cell in the table
    area is blank.

    For each empty row, border information that would otherwise disappear is
    relocated within the table:

    * a bottom border becomes a thin bottom border on the row above,
    * a top border becomes a thin top border on the row below.

    ``ws`` is modified in place. No lines are actually deleted.

    Returns:
        ``(empty_rows, empty_cols)`` as sorted lists of 1-based indices.
    """
    rows = range(table_start_row, table_end_row + 1)
    cols = range(1, table_max_columns + 1)

    def has_content(r, c):
        value = ws.cell(r, c).value
        return value is not None and str(value).strip() not in ['', 'i']

    empty_rows = sorted({r for r in rows if not any(has_content(r, c) for c in cols)})
    empty_cols = sorted({c for c in cols if not any(has_content(r, c) for r in rows)})

    thin = Side(style='thin')

    def with_thin_side(target, side_name):
        # new Border keeping copies of the three untouched sides
        current = target.border
        parts = {name: copy.deepcopy(getattr(current, name))
                 for name in ('left', 'right', 'top', 'bottom') if name != side_name}
        parts[side_name] = thin
        return Border(**parts)

    for r in empty_rows:
        for c in cols:
            source = ws.cell(r, c)
            if source.border.bottom.style is not None and r - 1 >= table_start_row:
                above = ws.cell(r - 1, c)
                above.border = with_thin_side(above, 'bottom')
            if source.border.top.style is not None and r + 1 <= table_end_row:
                below = ws.cell(r + 1, c)
                below.border = with_thin_side(below, 'top')
    return empty_rows, empty_cols


def build_cell_matrix(ws, table_start_row, table_end_row, table_max_columns, empty_rows, empty_cols):
    """Convert the table area of ``ws`` into a matrix of ``Cell`` objects.

    Rows ``table_start_row..table_end_row`` and columns ``1..table_max_columns``
    are visited. Those listed in ``empty_rows`` / ``empty_cols`` are skipped.
    Every cell of a merged region receives the text of the region's top-left
    worksheet cell and the region's span in table coordinates. Only the cell
    at the region's top-left table position is flagged ``iscorecell``.

    Returns:
        ``(cell_matrix, coord_to_cell_map)``: the row-major list of ``Cell``
        rows, and a dict mapping each kept worksheet coordinate (e.g. 'B7') to
        its zero-based ``(row, col)`` position in ``cell_matrix``.
    """
    kept_rows = [r for r in range(table_start_row, table_end_row + 1) if r not in empty_rows]
    kept_cols = [c for c in range(1, table_max_columns + 1) if c not in empty_cols]

    cell_matrix, coord_to_cell_map = [], {}
    for ws_row in kept_rows:
        # zero-based row inside the cleaned table
        row_in_table = line_index_after_removed_empty_lines(ws_row, empty_rows) - table_start_row
        matrix_row = []
        for ws_col in kept_cols:
            ws_cell = ws.cell(ws_row, ws_col)
            # zero-based column inside the cleaned table
            col_in_table = line_index_after_removed_empty_lines(ws_col, empty_cols) - 1
            if is_merged_cell(ws, ws_cell):
                (start_row, end_row, start_col, end_col), text = parse_merged_region_to_table_index(
                    ws, ws_cell, empty_rows, empty_cols, table_start_row)
                iscorecell = row_in_table == start_row and col_in_table == start_col
            else:
                start_row = end_row = row_in_table
                start_col = end_col = col_in_table
                iscorecell = True
                text = ws_cell.value
            text = '' if text is None else text
            matrix_row.append(Cell(
                text=str(text).strip(),
                alignment=ws_cell.alignment,
                font=ws_cell.font,
                fill=ws_cell.fill,
                border=ws_cell.border,
                number_format=ws_cell.number_format,
                indent_length=ws_cell.alignment.indent,
                bold=ws_cell.font.bold,
                color=ws_cell.fill.fgColor,
                patternType=ws_cell.fill.patternType,
                cell_span=[start_row, start_col, end_row, end_col, row_in_table, col_in_table],
                iscorecell=iscorecell,
                coord_in_ws=ws_cell.coordinate,
            ))
            coord_to_cell_map[ws_cell.coordinate] = (row_in_table, col_in_table)
        cell_matrix.append(matrix_row)
    return cell_matrix, coord_to_cell_map


def print_cell_matrix(cell_matrix, rows_to_print):
    '''Debug helper: dump the text and span of each cell in the first rows.

    For each of the first ``rows_to_print`` rows, prints ``cell_string`` and then
    ``span`` on separate lines for every column (column count taken from row 0),
    followed by a dashed separator line.
    '''
    for r in range(min(rows_to_print, len(cell_matrix))):
        row_cells = cell_matrix[r]
        for c in range(len(cell_matrix[0])):
            current = row_cells[c]
            print(current.cell_string, current.span, sep='\n')
        print('-------------------------------------')


def detect_header_region(cell_matrix):
    """Detect the number of top header rows and left header columns.

    Top header rows: candidate rows are ``1 .. min(8, num_rows) - 1``. The first
    candidate ``row`` that satisfies either rule below gives
    ``num_topheaderrows = row``:

    1. Every cell has a top border, the data part (columns ``1..``) has no
       characters other than ',' and '.', and column 0 is non-empty (a parent
       header over an empty data row).
    2. Every cell has a top border, more than 80% of the data-part characters
       (ignoring ',' and '.') are digits, and the column-0 cell is unmerged
       (a typical first data row).
    3. Fallback: the first row among the first ``min(8, num_rows)`` whose cells
       all have a bottom border ends the header
       (``num_topheaderrows = row + 1``).

    Left header columns: scan columns from 0. Stop at the first column that
    contains a horizontal-only merge (col_span > 1, row_span == 1). Otherwise
    stop at the first column without a vertical merge (row_span > 1) below the
    top header. A result of ``num_cols - 1`` is reset to 0, and 1 is added at
    the end.

    Args:
        cell_matrix: non-empty 2-D list of cells exposing ``border``,
            ``cell_string`` and a ``span`` dict with 'row_span' / 'col_span'.

    Returns:
        ``(num_topheaderrows, num_leftheadercols)``.

    Raises:
        ValueError: if none of the top-header rules matches.
    """
    num_rows, num_cols = len(cell_matrix), len(cell_matrix[0])
    max_topheaderrows = 8
    header_limit = min(max_topheaderrows, num_rows)

    # top header: rule#1 and rule#2
    num_topheaderrows = None
    for row in range(1, header_limit):
        cells = cell_matrix[row]
        top_line_flag = all(cells[col].border.top.style is not None for col in range(num_cols))
        num_chars, num_digit_chars = 0, 0
        for col in range(1, num_cols):
            for ch in cells[col].cell_string:
                if ch == ',' or ch == '.':
                    continue
                num_chars += 1
                if ch.isdigit():
                    num_digit_chars += 1
        row_digit_proportion = num_digit_chars / num_chars if num_chars != 0 else None
        if top_line_flag and row_digit_proportion is None:  # parent header with empty data row
            if cells[0].cell_string.strip() != '':
                num_topheaderrows = row
                break
        elif top_line_flag and row_digit_proportion > 0.8:  # typical first data row
            if cells[0].span['row_span'] == 1 and cells[0].span['col_span'] == 1:
                num_topheaderrows = row
                break

    # rule#3 (backup): first fully bottom-bordered row closes the header
    if num_topheaderrows is None:
        for row in range(header_limit):
            if all(cell_matrix[row][col].border.bottom.style is not None for col in range(num_cols)):
                num_topheaderrows = row + 1
                break
    # Explicit check instead of ``assert``: the old ``None <= int`` comparison
    # surfaced as a confusing TypeError and asserts vanish under ``python -O``.
    if num_topheaderrows is None:
        raise ValueError("could not detect the top header region: "
                         "no header/data separator border found")

    # left header
    num_leftheadercols = 0
    for j in range(num_cols):
        ifhorizontalmerge = False
        ifverticalmerge = False
        for i in range(num_rows):
            cell = cell_matrix[i][j]
            if cell is None:
                continue
            if cell.span["col_span"] > 1 and cell.span["row_span"] == 1:
                ifhorizontalmerge = True
            if i > num_topheaderrows and cell.span["row_span"] > 1:
                ifverticalmerge = True
        if ifhorizontalmerge:
            break
        if not ifverticalmerge:
            num_leftheadercols = j
            break
        num_leftheadercols = j
    # reset to 0 if there's no horizontal merge
    if num_leftheadercols == num_cols - 1:
        num_leftheadercols = 0
    num_leftheadercols += 1

    return num_topheaderrows, num_leftheadercols


def merge_empty_header_cells(cell_matrix, num_top_header_rows, num_left_header_cols):
    """ Merge empty header(top) cells with header cells beneath it.

    For every cell in the first ``num_top_header_rows`` rows whose ``cell_string``
    is '', scan down the same column to the first non-empty cell, or to the last
    top-header row if none is non-empty. That cell is the "core" cell.

    Two in-place updates follow:

    * The span of every cell inside the core cell's own span is rewritten so
      that it starts at the empty cell's row.
    * The cells of this column, from the empty cell down to the core cell, are
      replaced by new ``Cell`` objects carrying the core cell's text.

    ``cell_matrix`` is mutated; nothing is returned.

    Args:
        cell_matrix: 2-D list of ``Cell`` objects (rows x columns).
        num_top_header_rows: number of leading rows treated as top header.
        num_left_header_cols: not used by this function.
    """
    for row in range(num_top_header_rows):
        for col in range(len(cell_matrix[0])):
            cell = cell_matrix[row][col]
            if cell.cell_string != '':
                continue
            # search downward (same column) for the core cell
            merge_start_row = i = row
            merge_end_row = None
            core_cell = None  # core_cell is the cell with non-empty cell string, not left-top cell
            while i < num_top_header_rows:
                if cell_matrix[i][col].cell_string != '' or i == num_top_header_rows - 1:
                    core_cell = cell_matrix[i][col]
                    merge_end_row = i
                    break
                i += 1
            # horizontal extent is inherited from the core cell's span
            merge_start_col = core_cell.span['start_col']
            merge_end_col = core_cell.span['end_col']

            # if core_cell itself is a merged cell, modify span of cells in the merged cell
            # (span dicts are mutated in place; `i`, `j` and `cell` are reused here)
            for i in range(core_cell.span['start_row'], core_cell.span['end_row'] + 1):
                for j in range(core_cell.span['start_col'], core_cell.span['end_col'] + 1):
                    cell = cell_matrix[i][j]
                    cell.span['start_row'] = merge_start_row
                    cell.span['row_span'] = cell.span['end_row'] - cell.span['start_row'] + 1

            # merge empty cells with core_cell
            # NOTE(review): only column `col` is replaced below, even when the core
            # cell spans several columns (merge_start_col..merge_end_col).
            for i in range(merge_start_row, merge_end_row+1):
                # ic(i, col)
                iscorecell = (i == merge_start_row and col == merge_start_col)  # if left-top cell
                # NOTE(review): the kwargs are read back from attributes named
                # indent_level / font_bold / fill_rgb / fill_patternType; presumably
                # that is how Cell stores indent_length / bold / color / patternType.
                # TODO confirm.
                # NOTE(review): coord_in_ws receives the Cell object itself, whereas
                # build_cell_matrix passes a worksheet coordinate string. This looks
                # like a bug; confirm the Cell attribute name before fixing.
                cell_matrix[i][col] = Cell(
                                        text=core_cell.cell_string,  # only core cell has cell value
                                        alignment=cell_matrix[i][col].alignment,
                                        font=cell_matrix[i][col].font,
                                        fill=cell_matrix[i][col].fill,
                                        border=cell_matrix[i][col].border,
                                        number_format=cell_matrix[i][col].number_format,
                                        indent_length=cell_matrix[i][col].indent_level,
                                        bold=cell_matrix[i][col].font_bold,
                                        color=cell_matrix[i][col].fill_rgb,
                                        patternType=cell_matrix[i][col].fill_patternType,
                                        cell_span=[merge_start_row, merge_start_col,
                                                   merge_end_row, merge_end_col,
                                                   cell_matrix[i][col].span['cur_row'], cell_matrix[i][col].span['cur_col']],
                                        iscorecell=iscorecell,
                                        coord_in_ws=cell_matrix[i][col],
                                        row_merge_offset=merge_start_row-i,
                                        col_merge_offset=merge_start_col-col
                                        )


def rewrite_coord(coord, cell_matrix, coord_to_cell_map, title_end, empty_rows, empty_cols):
    """Map a coordinate of the original sheet to the cleaned layout, e.g. B15 -> A14.

    The mapping accounts for:

    1. the title collapsing to row 1 (table row 0 lands on sheet row 3),
    2. removed empty rows/columns,
    3. the row offset stored on merged header cells.

    A reference to the title (column A at or above ``title_end``), or to any
    coordinate outside the table, maps to 'A1'.
    """
    row_in_ws = int(find_row(coord)[0])
    col_in_ws = int(column_index_from_string(find_column(coord)[0]))
    points_at_title = row_in_ws <= title_end and col_in_ws == 1
    if points_at_title or coord not in coord_to_cell_map:  # link to title
        return 'A1'

    table_row, table_col = coord_to_cell_map[coord]
    merge_offset = cell_matrix[table_row][table_col].row_merge_offset
    # "+ 2": row 1 holds the title and table row 0 is written to sheet row 3
    new_row = line_index_after_removed_empty_lines(row_in_ws, empty_rows) - title_end + 2
    if merge_offset is not None:
        new_row += merge_offset
    # NOTE(review): col_merge_offset is never applied; confirm whether it should be.
    new_col = line_index_after_removed_empty_lines(col_in_ws, empty_cols)
    return get_column_letter(new_col) + str(new_row)


def rewrite_formula_cell(formula, cell_matrix, coord_to_cell_map, title_end, empty_rows, empty_cols):
    """Rewrite every cell reference in ``formula`` into cleaned-sheet coordinates.

    Each reference (letters followed by digits, e.g. 'B15') is passed through
    ``rewrite_coord``; all other text is kept unchanged.

    Returns:
        The rewritten formula string.
    """
    # Only standalone references qualify. The look-arounds exclude:
    #   * pieces of longer identifiers,
    #   * function names such as LOG10(,
    #   * sheet qualifiers such as Sheet1!.
    coord_pattern = re.compile(r'(?<![A-Za-z0-9_])[A-Za-z]+[0-9]+(?![A-Za-z0-9_(!])')

    def _rewrite(match):
        return rewrite_coord(match.group(0), cell_matrix, coord_to_cell_map,
                             title_end, empty_rows, empty_cols)

    # A single substitution pass replaces each match at its own position. A
    # find-and-splice loop can hit an already rewritten reference instead
    # (e.g. '=B16+B15' with B16 -> B15 would rewrite the first 'B15' twice).
    return coord_pattern.sub(_rewrite, formula)


def split_formula_items(formula, pattern):  # deprecated
    """Return every substring of ``formula`` matched by ``pattern``.

    e.g. '=SUM(A15:B16)' -> ['A15', 'B16'] for a cell-reference pattern.
    """
    return list(re.findall(pattern, formula))


def rewrite_formulas(ws, cell_matrix, coord_to_cell_map, title_end, desc_start, empty_rows, empty_cols):
    """Rewrite the '=' formulas of the description area in place for the new coordinates.

    Rows are scanned from ``desc_start``, and scanning stops after three
    consecutive rows with a blank column A. Within a row, cells are read from
    column 1 until the first blank cell. Every string value starting with '=' is
    rewritten with ``rewrite_formula_cell``.
    """
    row = desc_start
    blank_streak = 0
    while True:
        head = ws.cell(row, 1).value
        if head is None or str(head).strip() == '':
            blank_streak += 1
            if blank_streak > 2:  # EOF
                break
            row += 1
            continue
        blank_streak = 0
        col = 1
        while True:
            value = ws.cell(row, col).value
            if value is None or str(value).strip() == '':
                break
            if isinstance(value, str) and value.startswith('='):
                ws.cell(row, col).value = rewrite_formula_cell(
                    value, cell_matrix, coord_to_cell_map, title_end, empty_rows, empty_cols)
            col += 1
        row += 1


def save_file_in_uniform_format(ws, cell_matrix, title, desc_start,
                                table_id, page_id, file_name, new_nsf_dir):
    """Write the cleaned annotation to a new workbook that 'read_annotation' can parse.

    Layout of the 'labeling' sheet:

    * row 1 holds the title;
    * table row ``r`` / column ``c`` (zero-based) goes to sheet row ``r + 3``,
      column ``c + 1``; only core cells are written, and their spans are
      re-merged;
    * the description area copied from ``ws`` starts four rows below the table
      (blank rows are kept until three consecutive blanks end the copy).

    Returns:
        Path of the saved workbook:
        ``new_nsf_dir/<table_id>_<page_id>_<file stem>.xlsx``.
    """
    wb_write = Workbook()
    ws_write = wb_write.active
    ws_write.title = 'labeling'
    ws_write.cell(1, 1).value = title

    def copy_into(dst, value, src):
        # value plus shallow copies of the style attributes
        dst.value = value
        dst.font = copy.copy(src.font)
        dst.fill = copy.copy(src.fill)
        dst.border = copy.copy(src.border)
        dst.alignment = copy.copy(src.alignment)
        dst.number_format = copy.copy(src.number_format)

    max_rows, max_cols = len(cell_matrix), len(cell_matrix[0])
    for r, matrix_row in enumerate(cell_matrix):
        for c in range(max_cols):
            cell = matrix_row[c]
            if cell is None or not cell.iscorecell:
                continue
            top, left = r + 3, c + 1
            copy_into(ws_write.cell(top, left), cell.cell_string, cell)
            colspan = cell.span["col_span"]
            rowspan = cell.span["row_span"]
            if colspan > 1 or rowspan > 1:
                bottom = r + 2 + rowspan if rowspan > 1 else top
                right = c + colspan if colspan > 1 else left
                ws_write.merge_cells(start_row=top, start_column=left, end_row=bottom, end_column=right)

    src_row = desc_start
    dst_row = 2 + max_rows + 4  # 2 for title, 4 for empty lines between table and desc
    blank_run = 0
    while True:
        head = ws.cell(src_row, 1).value
        if head is None or str(head).strip() == '':
            blank_run += 1
            if blank_run > 2:  # EOF
                break
        else:
            blank_run = 0
            col = 1
            while True:
                src = ws.cell(src_row, col)
                if src.value is None or str(src.value).strip() == '':
                    break
                copy_into(ws_write.cell(dst_row, col), src.value, src)
                col += 1
        src_row += 1
        dst_row += 1

    nsf_table_id = f"{table_id}_{page_id}_{file_name.split('.')[0]}.xlsx"
    final_filename = os.path.join(new_nsf_dir, nsf_table_id)
    wb_write.save(final_filename)
    return final_filename


def infer_metric_by_div_num(question, num):
    """Infer the magnitude word that applies after dividing by ``num``.

    e.g. a question in 'billion' divided by 1000 is now in 'million'.

    Args:
        question: question text. The first of 'thousand', 'million', 'billion'
            (checked in that order) found in it is the current magnitude.
        num: divisor taken from the formula, e.g. 1000.

    Returns:
        The new magnitude word: 'million', 'thousand', or '' when the result is
        in plain units. Also returns '' when the question carries no magnitude
        word (e.g. '500 dollars'), since there is nothing to rewrite.

    Raises:
        KeyError: for a (divisor, magnitude) pair that has no defined conversion.
    """
    old_metric = None  # no explicit metric, e.g. 500 dollars
    for metric in ('thousand', 'million', 'billion'):
        if metric in question:
            old_metric = metric
            break
    if old_metric is None:
        # Previously the question itself was returned here, which is not a
        # metric word; '' keeps callers' str.replace calls harmless.
        return ''
    simple_metric_conversion_map = {
        '1000_billion': 'million',
        '1000_million': 'thousand',
        '1000_thousand': '',
        '1000000_billion': 'thousand',
        '1000000_million': '',
        '1000000000_billion': ''
    }
    key = f"{num}_{old_metric}"
    if key not in simple_metric_conversion_map:
        raise KeyError(f"unsupported metric conversion: '{old_metric}' divided by {num}")
    return simple_metric_conversion_map[key]


def rewrite_qa_with_metric_conversion(formula, question, aggr_type, title):
    """Drop a '/10...0' unit conversion from ``formula`` and adjust question/aggregation.

    The first '/10+' match in the formula (e.g. '/1000') is removed everywhere
    it occurs. The magnitude word in the question is replaced by:

    * the magnitude named in the normalized title, if any;
    * otherwise the one inferred from the divisor.

    A 'div' aggregation without any remaining '/' becomes 'none'. Formulas
    without such a divisor are returned unchanged.

    Returns:
        ``(new_formula, new_question, new_aggr_type)``.
    """
    matches = re.findall(r'/10+', formula)
    if not matches:
        return formula, question, aggr_type
    divisor_text = matches[0]
    # '/' appears only at the start of divisor_text, so occurrences never
    # overlap and a plain replace removes all of them.
    stripped_formula = formula.replace(divisor_text, '')

    title = normalize(title)
    new_metric = next((m for m in ('thousand', 'million', 'billion') if m in title), None)
    if new_metric is None:
        new_metric = infer_metric_by_div_num(question, int(divisor_text[1:]))
    new_question = question
    for word in ('billion', 'million', 'thousand'):
        new_question = new_question.replace(word, new_metric)

    if aggr_type == 'div' and '/' not in stripped_formula:
        new_aggr_type = 'none'
    else:
        new_aggr_type = aggr_type
    return stripped_formula, new_question, new_aggr_type


def rewrite_qas(ws, title, desc_start):
    """Convert answer 'B15/1000' to 'B15', and modify question/aggregation type correspondingly.

    Rows are scanned from ``desc_start`` until three consecutive rows have a
    blank column A. On each row labelled 'answer (formula):', every '='
    formula from column 2 onward (up to the first blank cell) is processed
    with ``rewrite_qa_with_metric_conversion``. The question is read from
    column 2 of the row above and the aggregation type from column 2 of the
    row below; all three cells are overwritten in place.
    """
    row = desc_start
    blank_run = 0
    while True:
        head = ws.cell(row, 1).value
        if head is None or str(head).strip() == '':
            blank_run += 1
            if blank_run > 2:  # EOF
                break
        else:
            blank_run = 0
            if head == 'answer (formula):':
                col = 2
                while True:
                    value = ws.cell(row, col).value
                    if value is None or str(value).strip() == '':
                        break
                    if isinstance(value, str) and value.startswith('='):
                        question_cell = ws.cell(row - 1, 2)
                        aggr_cell = ws.cell(row + 1, 2)
                        new_formula, new_question, new_aggr_type = rewrite_qa_with_metric_conversion(
                            value, question_cell.value, aggr_cell.value, title)
                        ws.cell(row, col).value = new_formula
                        question_cell.value = new_question
                        aggr_cell.value = new_aggr_type
                    col += 1
        row += 1



def process(file_path, file_name, page_id, table_id, new_nsf_dir):
    """
    Process one nsf annotation workbook:
    1. remove empty lines, and change formulas correspondingly;
    2. construct the cell matrix;
    3. merge top headers;
    4. extract the hierarchy (cellmatrix2structure);
    5. rewrite formulas and QA metrics in the in-memory worksheet.

    Args:
        file_path: original workbook, e.g. '/mnt/USER/Apr13/data_nsf/annotations_nsf/128/tab1.xlsx'
        file_name: e.g. 'tab1.xlsx'; used to build output file names
        page_id: e.g. '128'; used to build output file names
        table_id: (int) global table id; used to build output file names
        new_nsf_dir: output directory; the hierarchy path handed to
            cellmatrix2structure and the returned annotation path live here

    Returns:
        ``(table_structure, final_anno_filename)``. NOTE: save_file_in_uniform_format
        is currently not called, so ``final_anno_filename`` is only the path the
        cleaned workbook would be saved to; the rewrites stay in memory.

    Raises:
        ValueError: if the sheet has no detectable table region or no
            'table descriptive sentence id:' row.
    """
    wb = load_workbook(file_path)
    # wb.sheetnames / wb[name] replace the deprecated get_sheet_names() / get_sheet_by_name()
    ws = wb[wb.sheetnames[0]]
    enhance_borders_on_both_sides(ws)
    regions = find_title_and_footer(ws)
    if regions is None:
        raise ValueError(f"{file_path}: could not locate the table region or the "
                         f"'table descriptive sentence id:' row")
    (title, title_end), (_footer, footer_start), desc_start = regions

    table_start_row, table_end_row = title_end + 1, footer_start - 1
    table_max_columns = find_table_max_columns(ws, table_start_row, table_end_row)
    empty_rows, empty_cols = find_empty_lines(ws, table_start_row, table_end_row, table_max_columns)

    cell_matrix, coord_to_cell_map = build_cell_matrix(
        ws, table_start_row, table_end_row, table_max_columns, empty_rows, empty_cols)
    num_top_header_rows, num_left_header_cols = detect_header_region(cell_matrix)
    merge_empty_header_cells(cell_matrix, num_top_header_rows, num_left_header_cols)

    table_structure = cellmatrix2structure(
        cell_matrix, num_top_header_rows, num_left_header_cols, len(cell_matrix), len(cell_matrix[0]),
        os.path.join(new_nsf_dir, f"testhier_{table_id}_{page_id}_{file_name}"))
    table_structure['num_top_header_rows'] = num_top_header_rows
    table_structure['num_left_header_cols'] = num_left_header_cols
    table_structure['caption'] = title

    rewrite_formulas(ws, cell_matrix, coord_to_cell_map, title_end, desc_start, empty_rows, empty_cols)
    rewrite_qas(ws, title, desc_start)
    nsf_table_id = f"{table_id}_{page_id}_{file_name.split('.')[0]}.xlsx"
    final_anno_filename = os.path.join(new_nsf_dir, nsf_table_id)
    return table_structure, final_anno_filename


def main():
    """Process every annotation workbook under root_dir/data_dir/nsf_annotation_dir.

    Page directories and files whose stem ends with 'N' are skipped, as are
    lock files ('~...') and figures ('fig' in the name). Outputs go to
    root_dir/data_dir/nsf_save_dir, which is created if missing. Each failure
    is printed, counted, and followed by a 3 s pause; totals are reported at
    the end.
    """
    annotation_root = os.path.join(args.root_dir, args.data_dir, args.nsf_annotation_dir)
    # process() requires an output directory; it was previously never passed,
    # so every call failed with a TypeError.
    save_dir = os.path.join(args.root_dir, args.data_dir, args.nsf_save_dir)
    os.makedirs(save_dir, exist_ok=True)

    global_table_id = 0
    error_case = 0
    for page_id in sorted(os.listdir(annotation_root)):
        if page_id.endswith('N'):    # e.g. '12N'
            continue
        page_dir = os.path.join(annotation_root, page_id)
        if not os.path.isdir(page_dir):
            continue
        for file in sorted(os.listdir(page_dir)):
            if file.split('.')[0].endswith('N') or file.startswith('~') or 'fig' in file:
                continue
            print(f"Processing {file} on page {page_id}.")
            try:
                process(
                    file_path=os.path.join(page_dir, file),
                    file_name=file, page_id=page_id, table_id=global_table_id,
                    new_nsf_dir=save_dir,
                )
                global_table_id += 1
            except Exception as e:  # top-level boundary: report and keep going
                print(f"Error in {file} on page {page_id}: {e}")
                error_case += 1
                time.sleep(3)
    ic(error_case)
    ic(global_table_id)


if __name__ == '__main__':
    # Command-line options are all string path fragments. main() reads them
    # through the module-level ``args``.
    parser = ArgumentParser()
    for option, default_value in (
        ('--root_dir', '/data/home/hdd3000/USER/HMT/'),
        ('--data_dir', 'qa/data/'),
        ('--nsf_annotation_dir', 'annotations_nsf/'),
        ('--nsf_save_dir', 'annotations_nsf_clean/'),
        ('--structure_path', 'spreadsheet_stru/'),
        ('--table_dir', 'raw_input/table/'),
    ):
        parser.add_argument(option, type=str, default=default_value)
    args = parser.parse_args()

    main()
