import os
import json
import numpy as np
import pandas as pd
from typing import List, Dict
from argparse import ArgumentParser
import copy as syscopy

from preprocess.crawler import Table
from preprocess.utils import *


# Lift pandas' display limits so printed DataFrames are not truncated
# (``None`` means "no limit" for both options).
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

# Stripped data-cell strings that save_as_json() replaces with '#'
# before the data region is converted to numbers.
INVALID_DATA_TOKENS = {'x', 'X', 'E', 'F', '.', '..', '...', ''}


class Cell:
    """A single grid position of a table: text, styling, merge span and tree links.

    ``cell_span`` is indexed as ``[start_row, start_col, end_row, end_col,
    cur_row, cur_col]``; the derived ``row_span`` / ``col_span`` are inclusive
    extents of the merged region.  ``color`` may be an object exposing ``.rgb``
    or a plain value, in which case it is stored as-is.
    """

    def __init__(self, text, alignment, font, fill, border, number_format, indent_length, bold, color, patternType,
                 cell_span, iscorecell, coord_in_ws=None, row_merge_offset=None, col_merge_offset=None):
        # --- content and raw style objects -------------------------------
        self.cell_string = text
        self.alignment = alignment
        self.font = font
        self.fill = fill
        self.border = border
        self.number_format = number_format

        # --- flattened style features used by the header heuristics ------
        self.indent_level = indent_length
        self.font_bold = bold
        self.fill_rgb = getattr(color, 'rgb', color)
        self.fill_patternType = patternType

        # --- geometry ----------------------------------------------------
        first_row, first_col = cell_span[0], cell_span[1]
        last_row, last_col = cell_span[2], cell_span[3]
        self.span = {
            "start_row": first_row,
            "start_col": first_col,
            "end_row": last_row,
            "end_col": last_col,
            "cur_row": cell_span[4],
            "cur_col": cell_span[5],
            "row_span": last_row - first_row + 1,
            "col_span": last_col - first_col + 1,
        }
        self.iscorecell = iscorecell
        self.coord_in_ws = coord_in_ws
        self.row_merge_offset = row_merge_offset
        self.col_merge_offset = col_merge_offset

        # --- role flags and hierarchy depth, filled in by later passes ---
        self.mark_topheader = False
        self.mark_leftheader = False
        self.mark_datacell = False
        self.top_level = 0
        self.left_level = 0

        # --- tree links --------------------------------------------------
        self.parent_cell = []
        self.child_cells = []

        # Data column directly indexed by this cell when it is a top node (-1: none).
        self.topnode_datacolumn = -1
        # Data row directly indexed by this cell when it is a left node (-1: none).
        self.leftnode_datarow = -1

        # Neighbour to the right in the same matrix row (set by the callers).
        self.right_cell = None


def extract_header(cell_matrix):
    """Estimate the size of the top and left header regions of a cell matrix.

    Top header: leading rows containing at least one cell with a solid
    ``'EEEEEEEE'`` fill, trimmed from the bottom while the last candidate row
    has at most one non-blank core cell.  Left header: leading columns scanned
    until one contains a purely horizontal merge or lacks a vertical merge
    below the top header.

    :param cell_matrix: non-empty rectangular list of rows of cells (or None).
    :return: ``(num_topheaderrows, num_leftheadercols)``, both at least 1.
    """
    row_count, col_count = len(cell_matrix), len(cell_matrix[0])

    def has_header_fill(cell):
        return cell is not None and cell.fill_rgb == 'EEEEEEEE' and cell.fill_patternType == "solid"

    def is_nonblank_core(cell):
        return cell is not None and cell.cell_string.strip() != "" and cell.iscorecell

    # Index of the last leading row that carries the header fill colour.
    last_header_row = 0
    for r in range(row_count):
        if not any(has_header_fill(cell_matrix[r][c]) for c in range(col_count)):
            break
        last_header_row = r

    # Walk back up while the candidate row holds at most one non-blank core cell.
    while True:
        filled = sum(1 for c in range(col_count) if is_nonblank_core(cell_matrix[last_header_row][c]))
        if filled > 1 or last_header_row < 1:
            break
        last_header_row -= 1
    num_topheaderrows = last_header_row + 1

    # Index of the last leading column that still looks like a left header.
    last_left_col = 0
    for c in range(col_count):
        column = [cell_matrix[r][c] for r in range(row_count)]
        spans_only_columns = any(
            cell is not None and cell.span["col_span"] > 1 and cell.span["row_span"] == 1
            for cell in column
        )
        if spans_only_columns:
            break
        last_left_col = c
        spans_rows_below_header = any(
            cell is not None and cell.span["row_span"] > 1
            for r, cell in enumerate(column)
            if r > num_topheaderrows
        )
        if not spans_rows_below_header:
            break

    # The scan reaching the last column means no horizontal merge was found: reset.
    if last_left_col == col_count - 1:
        last_left_col = 0

    return num_topheaderrows, last_left_col + 1


def extract_tophierarchy(cell_matrix):
    """Link the cells of a top-header region into a parent/child hierarchy.

    A core cell becomes the parent of every core cell that starts on a lower
    row and overlaps it horizontally; each cell's ``parent_cell`` ends up being
    the overlapping cell with the greatest ``start_row``.  Core cells without
    a parent are collected under ``"virtual_root"``.  All non-None cells are
    flagged ``mark_topheader`` and ``top_level`` is propagated to children.

    :param cell_matrix: rectangular list of rows of cells (or None) covering
        only the top-header region; ``right_cell`` links must already be set.
    :return: ``{"virtual_root": [root cells]}``; the cells are mutated in place.
    """
    cell_tree = {"virtual_root": []}
    if len(cell_matrix) == 0 or len(cell_matrix[0]) == 0:
        return cell_tree
    num_rows, num_cols = len(cell_matrix), len(cell_matrix[0])
    cell_list = [cell_matrix[i][j] for i in range(num_rows) for j in range(num_cols)]
    # Only anchor (core) cells take part in linking; filtering once keeps the
    # pair loop small while preserving the row-major visiting order.
    core_cells = [cell for cell in cell_list if cell is not None and cell.iscorecell]

    # link parent and child
    for upper in core_cells:
        for lower in core_cells:
            if upper.span["start_row"] >= lower.span["start_row"]:
                continue
            if upper.span["start_col"] > lower.span["end_col"] or upper.span["end_col"] < lower.span["start_col"]:
                continue
            upper.child_cells.append(lower)
            # Keep the overlapping ancestor that starts on the lowest row.
            if len(lower.parent_cell) == 0 or lower.parent_cell[0].span["start_row"] < upper.span["start_row"]:
                lower.parent_cell = [upper]

    # add virtual node
    for cell in cell_list:
        if cell is None:
            continue
        cell.mark_topheader = True
        if not cell.iscorecell:
            continue
        if len(cell.parent_cell) == 0:
            cell_tree["virtual_root"].append(cell)
            cell.parent_cell.append("virtual_root")

    # align child merge cells: when a merged child starts left of its parent,
    # step right along the row until reaching the piece under the parent.
    for cell in cell_list:
        if cell is None:
            continue
        for j, child in enumerate(cell.child_cells):
            aligned = child
            while aligned is not None and aligned.span["cur_col"] < cell.span["cur_col"] \
                    and aligned.span["start_col"] + aligned.span["col_span"] - 1 >= cell.span["cur_col"]:
                aligned = aligned.right_cell
            if aligned is None:
                # The walk ran off the row (missing neighbour); keep the
                # original child instead of failing on ``None.parent_cell``.
                aligned = child
            cell.child_cells[j] = aligned
            aligned.parent_cell = [cell]

    # update cell level (repeated passes let levels propagate down the tree)
    for _ in range(5):
        for cell in core_cells:
            for child in cell.child_cells:
                child.top_level = max(child.top_level, cell.top_level + 1)
    return cell_tree


def extract_lefthierarchy(cell_matrix):
    """Build the left-header hierarchy, choosing the strategy by column count.

    Regions wider than one column go to ``extract_multilefthierarchy``;
    single-column regions go to ``extract_singlelefthierarchy``.  An empty
    region yields a tree with an empty ``"virtual_root"`` list.
    """
    if len(cell_matrix) == 0 or len(cell_matrix[0]) == 0:
        return {"virtual_root": []}

    if len(cell_matrix[0]) > 1:
        return extract_multilefthierarchy(cell_matrix)
    return extract_singlelefthierarchy(cell_matrix)


def extract_singlelefthierarchy(cell_matrix):
    """Build a left-header tree from a single-column header region.

    Rows are visited top to bottom while a stack holds the chain of open
    ancestors.  A row becomes a child of the stack top when
    ``compair_leftattributes`` ranks the stack top above it (result ``1``);
    otherwise the stack is popped and the same row is retried.  Rows that end
    up without a parent are attached to ``"virtual_root"``.

    :param cell_matrix: list of rows; only column 0 of each row is read.
    :return: ``{"virtual_root": [root cells]}``; the cells are mutated in place.
    """
    cell_tree = {"virtual_root": []}
    cell_list = []
    if len(cell_matrix) == 0 or len(cell_matrix[0]) == 0:
        return cell_tree

    total_rows = len(cell_matrix)
    open_ancestors = []
    row = 0
    while row < total_rows:
        current = cell_matrix[row][0]
        if current is None:
            row += 1
            continue

        if current.cell_string.strip() == "" or not current.iscorecell:
            # Blank or non-anchor cell: look rightwards for the first
            # non-blank anchor and adopt it when it is horizontally merged.
            probe = current
            while probe is not None and not (probe.cell_string.strip() != "" and probe.iscorecell):
                probe = probe.right_cell
            if probe is not None and probe.span["col_span"] > 1:
                current = probe

        if open_ancestors:
            candidate_parent = open_ancestors[-1]
            verdict = compair_leftattributes(extract_leftattributes(candidate_parent),
                                             extract_leftattributes(current))
            if verdict != 1:
                # Not a descendant of the stack top: close it and retry this row.
                open_ancestors.pop()
                continue
            candidate_parent.child_cells.append(current)
            current.parent_cell.append(candidate_parent)

        open_ancestors.append(current)
        cell_list.append(current)
        row += 1

    # add virtual node
    for cell in cell_list:
        if cell is None:
            continue
        cell.mark_leftheader = True
        if cell.iscorecell and len(cell.parent_cell) == 0:
            cell_tree["virtual_root"].append(cell)
            cell.parent_cell.append("virtual_root")

    # update cell level
    for _ in range(5):
        for cell in cell_list:
            if cell is None or not cell.iscorecell:
                continue
            for child in cell.child_cells:
                child.left_level = max(child.left_level, cell.left_level + 1)

    return cell_tree


def extract_leftattributes(cell):
    """Collect the attributes used to rank two left-header cells.

    :return: dict with ``horizontalmerge`` (True when the cell spans more than
        one column), ``indentlevel`` and ``fontbold`` copied from the cell.
    """
    bold = cell.font_bold
    indent = cell.indent_level
    spans_columns = True if cell.span["col_span"] > 1 else False
    return {
        "horizontalmerge": spans_columns,
        "indentlevel": indent,
        "fontbold": bold,
    }


def compair_leftattributes(cell1_att, cell2_att):
    """Rank two attribute dicts produced by ``extract_leftattributes``.

    Criteria are applied in order and the first one that differs decides:
    horizontal merge (merged ranks higher), indent level (smaller ranks
    higher), bold font (bold ranks higher).

    :return: ``1`` when the first cell ranks higher, ``-1`` when the second
        does, ``0`` when they tie on every criterion.
    """
    first_merged = bool(cell1_att["horizontalmerge"])
    second_merged = bool(cell2_att["horizontalmerge"])
    if first_merged != second_merged:
        return 1 if first_merged else -1

    first_indent, second_indent = cell1_att["indentlevel"], cell2_att["indentlevel"]
    if first_indent < second_indent:
        return 1
    if first_indent > second_indent:
        return -1

    first_bold = bool(cell1_att["fontbold"])
    second_bold = bool(cell2_att["fontbold"])
    if first_bold != second_bold:
        return 1 if first_bold else -1
    return 0


def extract_multilefthierarchy(cell_matrix):
    """Build a left-header tree from a header region with several columns.

    A core cell becomes the parent of every core cell that starts in a column
    further right and overlaps it vertically; ``parent_cell`` keeps the
    overlapping cell with the greatest ``start_col``.  Parentless core cells
    are attached to ``"virtual_root"``.

    :param cell_matrix: rectangular list of rows of cells (or None).
    :return: ``{"virtual_root": [root cells]}``; the cells are mutated in place.
    """
    cell_tree = {"virtual_root": []}
    if len(cell_matrix) == 0 or len(cell_matrix[0]) == 0:
        return cell_tree

    row_total, col_total = len(cell_matrix), len(cell_matrix[0])
    flat_cells = [cell_matrix[r][c] for r in range(row_total) for c in range(col_total)]
    anchors = [cell for cell in flat_cells if cell is not None and cell.iscorecell]

    # link parent and child
    for outer in anchors:
        outer_span = outer.span
        for inner in anchors:
            inner_span = inner.span
            if outer_span["start_col"] >= inner_span["start_col"]:
                continue
            no_row_overlap = (outer_span["start_row"] > inner_span["end_row"]
                              or outer_span["end_row"] < inner_span["start_row"])
            if no_row_overlap:
                continue
            outer.child_cells.append(inner)
            if len(inner.parent_cell) == 0 \
                    or inner.parent_cell[0].span["start_col"] < outer_span["start_col"]:
                inner.parent_cell = [outer]

    # add virtual node
    for cell in flat_cells:
        if cell is None:
            continue
        cell.mark_leftheader = True
        if cell.iscorecell and len(cell.parent_cell) == 0:
            cell_tree["virtual_root"].append(cell)
            cell.parent_cell.append("virtual_root")

    # update cell level
    for _ in range(5):
        for cell in anchors:
            for child in cell.child_cells:
                child.left_level = max(child.left_level, cell.left_level + 1)
    return cell_tree


def convert_data_type(data_region: List[List]):
    """Convert data cells from strings to numbers by round-tripping through Excel.

    The values are written to a temporary ``tmp.xlsx`` in the table directory,
    read back with pandas (``thousands=','``) so numeric-looking text is
    parsed, and any value that is still a string is passed through
    ``naive_str_to_float``.  Missing values come back as ``''``.

    :param data_region: rows of ``{"value": ...}`` dicts; the width of the
        first row determines how many columns are written.
    :return: a new list of rows of ``{"value": ...}`` dicts.
    :raises ValueError: if ``data_region`` is empty.
    """
    if not data_region:
        # Fail with a clear message instead of an IndexError on data_region[0].
        raise ValueError("convert_data_type() got an empty data region.")

    tmp_file = os.path.join(args.root_dir, args.data_dir, args.table_dir, 'tmp.xlsx')
    # write
    wb = Workbook()
    ws = wb.active
    ws.title = 'tmp'
    max_rows, max_cols = len(data_region), len(data_region[0])
    for i in range(1, max_rows + 1):
        for j in range(1, max_cols + 1):
            ws.cell(i, j).value = data_region[i - 1][j - 1]['value']
    try:
        wb.save(tmp_file)
        # read
        df = pd.read_excel(tmp_file, engine='openpyxl', thousands=',', header=None)
    finally:
        # Never leave the scratch workbook behind, even when save/read fails.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
    df.fillna('', inplace=True)

    converted = []
    for _, row in df.iterrows():
        converted.append([
            {"value": naive_str_to_float(v) if isinstance(v, str) else v}
            for v in row
        ])
    return converted


def find_offset(children_cells):
    """Return the smallest data line index found in a header tree.

    Thin wrapper around ``_find_offset``; yields ``10000`` when no cell in the
    tree indexes a data line.
    """
    smallest = [10000]  # one-element list so the recursive helper can update it
    _find_offset(children_cells, smallest, 0)
    (result,) = smallest
    return result


def _find_offset(children_cells, offset, depth):
    """ Find row/col offset of line_idx."""
    if depth > 10:
        raise ValueError("Find offset error. Seems there are recursive children.")
    for cell in children_cells:
        if cell.mark_topheader:
            line_idx = cell.topnode_datacolumn if cell.topnode_datacolumn != -1 else None
        else:
            line_idx = cell.leftnode_datarow if cell.leftnode_datarow != -1 else None
        # print(cell.cell_string, line_idx)
        if line_idx is not None:
            offset[0] = min(offset[0], line_idx)
        _find_offset(cell.child_cells, offset, depth+1)


def build_hmt_hierarchy(tree_node: Dict, children_cells: List[Cell], row_offset: int, col_offset: int,
                        cur_max_valid_idx: List):
    """Recursively append hmt header nodes for ``children_cells`` to ``tree_node``.

    Each node's ``line_idx`` is the cell's data column (top header) or data
    row (left header) minus the matching offset, or ``None`` when the cell
    indexes no line.  ``cur_max_valid_idx`` is a one-element counter holding
    the next expected index: larger indices are clamped to it, and it advances
    whenever a node lands exactly on it.
    """
    for cell in children_cells:
        label = normalize(cell.cell_string)
        if 'mathtype' in label:
            label = remove_math_type(label)

        if cell.mark_topheader:
            raw_idx, shift = cell.topnode_datacolumn, col_offset
        else:
            raw_idx, shift = cell.leftnode_datarow, row_offset
        line_idx = None if raw_idx == -1 else raw_idx - shift

        next_expected = cur_max_valid_idx[0]
        if line_idx is not None and line_idx > next_expected:  # avoid merge cell in data region
            line_idx = next_expected
        if line_idx == next_expected:
            cur_max_valid_idx[0] += 1

        node = {
            "name": label,
            "value": label,  # TODO: use infer_type() to real value, like datetime/number
            "type": "string",
            "children_dict": [],
            "line_idx": line_idx,
        }
        build_hmt_hierarchy(node, cell.child_cells, row_offset, col_offset, cur_max_valid_idx)
        tree_node['children_dict'].append(node)


def build_matrix_hierarchy(tree_node: Dict, children_cells: List[Cell]):
    """Recursively append matrix-style nodes for core cells to ``tree_node['Cd']``.

    Every core cell becomes ``{'RI': cur_row, 'CI': cur_col, 'Cd': [...]}``;
    non-core cells are skipped together with their subtrees.
    """
    for cell in children_cells:
        if cell.iscorecell:
            position = cell.span
            node = {'RI': position['cur_row'], 'CI': position['cur_col'], 'Cd': []}
            build_matrix_hierarchy(node, cell.child_cells)
            tree_node['Cd'].append(node)


def save_as_json(table_id, table_structure):
    """Write ``<table_id>.json`` containing the ``hmt`` and ``matrix`` views of a table.

    ``hmt`` holds the top/left header trees, the converted data region and
    the title (input dict of HMTable).  ``matrix`` holds the core-cell header
    trees, the grid of cell texts, the merged regions and the title (for tuta).

    :param table_id: used as the output file stem.
    :param table_structure: dict with ``top_header``, ``left_header``,
        ``cell_matrix`` and ``title`` entries.
    """
    def new_hmt_root(tag):
        return {"name": tag, "value": tag, "type": "string", "line_idx": None, "children_dict": []}

    def new_matrix_root():
        return {'RI': -1, 'CI': -1, 'Cd': []}

    left_roots = table_structure['left_header']['virtual_root']
    top_roots = table_structure['top_header']['virtual_root']

    # hmt
    hmt = {"top_root": new_hmt_root("<TOP>"), "left_root": new_hmt_root("<LEFT>")}
    row_offset = find_offset(left_roots)
    col_offset = find_offset(top_roots)
    build_hmt_hierarchy(hmt['top_root'], top_roots, row_offset, col_offset, [0])
    build_hmt_hierarchy(hmt['left_root'], left_roots, row_offset, col_offset, [0])

    cell_matrix = table_structure['cell_matrix']
    max_rows, max_cols = len(cell_matrix), len(cell_matrix[0])

    # Keep only rows that contain at least one data cell.
    data_region = []
    for row in range(max_rows):
        row_values = []
        for col in range(max_cols):
            cell = cell_matrix[row][col]
            if cell is None or not cell.mark_datacell:
                continue
            text = cell.cell_string.strip()
            row_values.append({"value": '#' if text in INVALID_DATA_TOKENS else text})
        if row_values:
            data_region.append(row_values)
    hmt['data'] = convert_data_type(data_region)
    hmt['title'] = normalize(table_structure['title'])

    # cell matrix (for tuta)
    matrix = {"TopTreeRoot": new_matrix_root(), "LeftTreeRoot": new_matrix_root()}
    build_matrix_hierarchy(matrix['TopTreeRoot'], top_roots)
    build_matrix_hierarchy(matrix['LeftTreeRoot'], left_roots)

    texts = []
    merged_spans = set()  # use set to remove duplicates
    for row in range(max_rows):
        row_texts = []
        for col in range(max_cols):
            cell = cell_matrix[row][col]
            if cell is None or not cell.iscorecell:
                row_texts.append('')
                continue
            raw_text = cell.cell_string
            if 'mathtype' in raw_text:
                raw_text = remove_math_type(raw_text)
            row_texts.append(normalize(raw_text))
            span = cell.span
            if span['start_row'] != span['end_row'] or span['start_col'] != span['end_col']:
                merged_spans.add((span['start_row'], span['end_row'], span['start_col'], span['end_col']))
        texts.append(row_texts)

    matrix['Texts'] = texts
    matrix['MergedRegions'] = [
        {"FirstRow": first_row, "LastRow": last_row, "FirstColumn": first_col, "LastColumn": last_col}
        for first_row, last_row, first_col, last_col in merged_spans
    ]
    matrix['Title'] = normalize(table_structure['title'])

    # save in one json
    out_path = os.path.join(args.root_dir, args.data_dir, args.table_dir, f'{table_id}.json')
    with open(out_path, 'w') as f:
        f.write(json.dumps({"hmt": hmt, "matrix": matrix}, indent=2))


def html2structure(table_dict: Dict[int, Table]):
    """Parse crawled HTML tables into structures and write them to disk.

    For every table this builds the cell matrix from the HTML rows, detects
    the header region, links the top/left header hierarchies, marks data
    cells, records which data column/row each header node indexes, saves the
    JSON description via ``save_as_json`` (errors there are printed and do
    not stop the run) and finally writes a labelled ``<table_id>.xlsx``.
    The detected header sizes are logged one line per table to ``log.txt``.

    :param table_dict: mapping of table index to crawled ``Table``.
    """
    # The context manager closes the log even when a table raises.
    with open("log.txt", "w") as f:
        for index, table in table_dict.items():
            wb = Workbook()
            ws = wb.active
            ws.title = 'labeling'
            max_rows, max_cols = find_ranges(table)
            print("Processing table {}, max_rows: {}; max_cols: {}".format(table.table_id, max_rows, max_cols))

            # init
            table_stru = {}
            table_stru["caption"] = prettify_caption(table.html.find('caption').text.strip().split('\n'))
            table_stru['title'] = prettify_title(table.html.find('caption').text.strip().split('\n'))
            table_stru["max_rows"] = max_rows
            table_stru["max_cols"] = max_cols
            cell_matrix = [[None for _ in range(max_cols)] for _ in range(max_rows)]
            table_stru["cell_matrix"] = cell_matrix

            # cell matrix: table_mask marks grid positions already taken by a span
            table_mask = np.zeros((max_rows, max_cols))
            row, col = 0, 0
            for x in table.html.find_all('tr'):
                for y in x.find_all(['th', 'td']):
                    if row >= max_rows:  # table-footer, currently not add footer
                        continue
                    font, fill, border, alignment, number_format = get_cell_style(y)

                    colspan = get_int(y, 'colspan')
                    rowspan = get_int(y, 'rowspan')
                    row, col = find_position_zeroindex(table_mask, row, col, max_rows, max_cols)
                    text = clear_footer(y)
                    # Every position of the span gets its own Cell; only the
                    # top-left one is the core cell.
                    for internal_row in range(row, min(row + rowspan, max_rows)):
                        for internal_col in range(col, min(col + colspan, max_cols)):
                            iscorecell = (row == internal_row) and (col == internal_col)
                            cell_matrix[internal_row][internal_col] = Cell(
                                text, alignment, font, fill, border, number_format, alignment.indent, font.bold,
                                fill.fgColor, fill.patternType,
                                [row, col, row + rowspan - 1, col + colspan - 1, internal_row, internal_col],
                                iscorecell)
                    table_mask[row: row + rowspan, col: col + colspan] = 1
                col = 0
                row += 1

            # detect header region
            num_topheaderrows, num_leftheadercols = extract_header(cell_matrix)
            f.write(f"{table.table_id!s} {num_topheaderrows} {num_leftheadercols}\n")
            f.flush()

            # link each cell to its right-hand neighbour within the same row
            for row in range(max_rows):
                left_neighbour = None
                for col in range(max_cols):
                    current = cell_matrix[row][col]
                    if left_neighbour is not None:
                        left_neighbour.right_cell = current
                    left_neighbour = current

            top_header = [[cell_matrix[i][j] for j in range(num_leftheadercols, max_cols)]
                          for i in range(num_topheaderrows)]
            left_header = [[cell_matrix[i][j] for j in range(num_leftheadercols)]
                           for i in range(num_topheaderrows, max_rows)]

            # header hierarchy
            table_stru["top_header"] = extract_tophierarchy(top_header)
            table_stru["left_header"] = extract_lefthierarchy(left_header)

            # mark datacell: a cell with a top-header cell above it and a
            # left-header cell to its left (belonging to other spans)
            for row in range(max_rows):
                for col in range(max_cols):
                    cell = cell_matrix[row][col]
                    if cell is None:
                        continue
                    mark_hastopheader = False
                    for i in range(row):
                        celltop = cell_matrix[i][col]
                        if celltop is None or celltop.span["start_row"] == cell.span["start_row"]:
                            continue
                        if celltop.mark_topheader:
                            mark_hastopheader = True
                    mark_hasleftheader = False
                    for j in range(col):
                        cellleft = cell_matrix[row][j]
                        if cellleft is None or cellleft.span["start_col"] == cell.span["start_col"]:
                            continue
                        if cellleft.mark_leftheader:
                            mark_hasleftheader = True
                    if mark_hastopheader and mark_hasleftheader:
                        cell.mark_datacell = True

            # data index
            for row in range(max_rows):
                for col in range(max_cols):
                    cell = cell_matrix[row][col]
                    if cell is None:
                        continue
                    # top index: leaf top-header cells index their own column
                    if cell.mark_topheader and len(cell.child_cells) == 0 and len(cell.parent_cell) > 0:
                        cell.topnode_datacolumn = col
                    # left index: single-row left-header cells with no child on
                    # the same row and a data cell somewhere to their right
                    if cell.mark_leftheader and cell.span["row_span"] == 1:
                        mark_childsamerow = False
                        for child in cell.child_cells:
                            if child.span["start_row"] == cell.span["start_row"]:
                                mark_childsamerow = True
                        cellright = cell.right_cell
                        while cellright is not None and not cellright.mark_datacell:
                            cellright = cellright.right_cell
                        if not mark_childsamerow and cellright is not None:
                            cell.leftnode_datarow = row

            # save as json (best effort: a failing table must not stop the run)
            try:
                save_as_json(index, table_stru)
            except Exception as e:
                print(f"Error: {e}")

            # save file
            for row in range(max_rows):
                for col in range(max_cols):
                    cell = cell_matrix[row][col]
                    if cell is None or not cell.iscorecell:
                        continue
                    if cell.mark_topheader:
                        prefix = str(cell.top_level) + " " + str(cell.topnode_datacolumn) + " "
                    elif cell.mark_leftheader:
                        prefix = str(cell.left_level) + " " + str(cell.leftnode_datarow) + " "
                    elif cell.mark_datacell:
                        prefix = "data "
                    else:
                        prefix = ""
                    ws_cell = ws.cell(row + 1, col + 1)
                    ws_cell.value = prefix + cell.cell_string
                    ws_cell.font = cell.font
                    ws_cell.fill = cell.fill
                    ws_cell.border = cell.border
                    ws_cell.alignment = cell.alignment
                    ws_cell.number_format = cell.number_format
                    colspan = cell.span["col_span"]
                    rowspan = cell.span["row_span"]
                    if colspan > 1 and rowspan > 1:
                        ws.merge_cells(start_row=row + 1, start_column=col + 1, end_row=row + rowspan,
                                       end_column=col + colspan)
                    elif colspan > 1:
                        ws.merge_cells(start_row=row + 1, start_column=col + 1, end_row=row + 1,
                                       end_column=col + colspan)
                    elif rowspan > 1:
                        ws.merge_cells(start_row=row + 1, start_column=col + 1, end_row=row + rowspan,
                                       end_column=col + 1)
            file_name = '{}.xlsx'.format(table.table_id)
            save_path = os.path.join(args.root_dir, args.data_dir, args.structure_path, file_name)
            source = wb.active
            target = wb.copy_worksheet(source)
            target.title = 'original'
            wb.save(save_path)

def cellmatrix2structure(cell_matrix, num_topheaderrows, num_leftheadercols, max_rows, max_cols, hier_filename):
    """Derive the header hierarchies and data-cell labels for a ready-made cell matrix.

    Links right-hand neighbours, builds the top/left header trees, marks data
    cells and records which data column/row each header node indexes.  A
    labelled workbook is assembled in memory as well, but it is not written:
    the save call is commented out, so ``hier_filename`` is currently unused.

    :return: dict with ``max_rows``, ``max_cols``, ``cell_matrix``,
        ``top_header`` and ``left_header``.
    """
    table_stru = {"max_rows": max_rows, "max_cols": max_cols, "cell_matrix": cell_matrix}
    grid = table_stru["cell_matrix"]

    wb = Workbook()
    ws = wb.active
    ws.title = 'labeling'

    # Chain every cell to the one on its right within the same row.
    for r in range(max_rows):
        previous = None
        for c in range(max_cols):
            here = grid[r][c]
            if previous is not None:
                previous.right_cell = here
            previous = here

    top_rows = range(num_topheaderrows)
    body_rows = range(num_topheaderrows, max_rows)
    top_header = [[grid[i][j] for j in range(num_leftheadercols, max_cols)] for i in top_rows]
    left_header = [[grid[i][j] for j in range(num_leftheadercols)] for i in body_rows]

    # header hierarchy
    table_stru["top_header"] = extract_tophierarchy(top_header)
    table_stru["left_header"] = extract_lefthierarchy(left_header)

    # mark datacell
    for r in range(max_rows):
        for c in range(max_cols):
            cell = grid[r][c]
            if cell is None:
                continue
            under_top_header = any(
                above is not None
                and above.span["start_row"] != cell.span["start_row"]
                and above.mark_topheader
                for above in (grid[i][c] for i in range(r))
            )
            beside_left_header = any(
                beside is not None
                and beside.span["start_col"] != cell.span["start_col"]
                and beside.mark_leftheader
                for beside in (grid[r][j] for j in range(c))
            )
            if under_top_header and beside_left_header:
                cell.mark_datacell = True

    # data index
    for r in range(max_rows):
        for c in range(max_cols):
            cell = grid[r][c]
            if cell is None:
                continue
            if cell.mark_topheader and len(cell.child_cells) == 0 and len(cell.parent_cell) > 0:
                cell.topnode_datacolumn = c
            # left index
            if not (cell.mark_leftheader and cell.span["row_span"] == 1):
                continue
            child_on_same_row = any(
                child.span["start_row"] == cell.span["start_row"] for child in cell.child_cells
            )
            neighbour = cell.right_cell
            while neighbour is not None and not neighbour.mark_datacell:
                neighbour = neighbour.right_cell
            if not child_on_same_row and neighbour is not None:
                cell.leftnode_datarow = r

    # save file
    for r in range(max_rows):
        for c in range(max_cols):
            cell = grid[r][c]
            if cell is None or not cell.iscorecell:
                continue
            if cell.mark_topheader:
                label = str(cell.top_level) + " " + str(cell.topnode_datacolumn) + " " + cell.cell_string
            elif cell.mark_leftheader:
                label = str(cell.left_level) + " " + str(cell.leftnode_datarow) + " " + cell.cell_string
            elif cell.mark_datacell:
                label = "data " + cell.cell_string
            else:
                label = cell.cell_string

            target_cell = ws.cell(r + 1, c + 1)
            target_cell.value = label
            # Styles are copied before being attached to the worksheet cell.
            target_cell.font = syscopy.copy(cell.font)
            target_cell.fill = syscopy.copy(cell.fill)
            target_cell.border = syscopy.copy(cell.border)
            target_cell.alignment = syscopy.copy(cell.alignment)
            target_cell.number_format = syscopy.copy(cell.number_format)

            width = cell.span["col_span"]
            height = cell.span["row_span"]
            if width > 1 or height > 1:
                ws.merge_cells(start_row=r + 1, start_column=c + 1,
                               end_row=r + max(height, 1), end_column=c + max(width, 1))

    original_sheet = wb.copy_worksheet(wb.active)
    original_sheet.title = 'original'
    # wb.save(hier_filename)
    return table_stru

def table_to_struc(table: 'Table') -> dict:
    """Convert a crawled HTML table into a structure dict of cells and header hierarchies.

    The table's ``<tr>``/``<th>``/``<td>`` tags are laid out on a
    ``max_rows x max_cols`` grid of ``Cell`` objects (one ``Cell`` per grid
    slot, including every slot covered by a merged cell), the header region is
    detected, and the cells are annotated in place:

    * ``right_cell``: the grid neighbour to the right (``None`` at the row end),
    * ``mark_datacell``: set for cells that have both a top-header cell above
      and a left-header cell to their left,
    * ``topnode_datacolumn`` / ``leftnode_datarow``: the data column / row a
      header cell directly indexes.

    Args:
        table: crawled table; ``table.html`` is searched for ``caption``,
            ``tr``, ``th`` and ``td`` tags.
            NOTE(review): assumes a ``<caption>`` element is always present;
            a missing caption fails on the ``.text`` access, as before.

    Returns:
        Dict with the keys ``caption``, ``title``, ``max_rows``, ``max_cols``,
        ``cell_matrix``, ``num_top_header_rows``, ``num_left_header_cols``,
        ``top_header`` and ``left_header``.
    """
    max_rows, max_cols = find_ranges(table)

    # init -- the caption text is read once and feeds both caption and title;
    # each helper gets its own freshly split list.
    caption_text = table.html.find('caption').text.strip()
    cell_matrix = [[None for _ in range(max_cols)] for _ in range(max_rows)]
    table_stru = {}
    table_stru["caption"] = prettify_caption(caption_text.split('\n'))
    table_stru['title'] = prettify_title(caption_text.split('\n'))
    table_stru["max_rows"] = max_rows
    table_stru["max_cols"] = max_cols
    table_stru["cell_matrix"] = cell_matrix

    # cell matrix
    # table_mask flags grid slots already claimed by an earlier cell's
    # rowspan/colspan; it is handed to find_position_zeroindex to place the
    # next tag.
    table_mask = np.zeros((max_rows, max_cols))
    row, col = 0, 0
    for tr_tag in table.html.find_all('tr'):
        for cell_tag in tr_tag.find_all(['th', 'td']):
            if row >= max_rows:  # table footer; footers are currently not added
                continue
            font, fill, border, alignment, number_format = get_cell_style(cell_tag)

            colspan = get_int(cell_tag, 'colspan')
            rowspan = get_int(cell_tag, 'rowspan')
            row, col = find_position_zeroindex(table_mask, row, col, max_rows, max_cols)
            text = clear_footer(cell_tag)
            # Every slot covered by the (possibly merged) cell gets its own
            # Cell object; only the top-left slot is the "core" cell.
            for internal_row in range(row, min(row + rowspan, max_rows)):
                for internal_col in range(col, min(col + colspan, max_cols)):
                    iscorecell = (row == internal_row) and (col == internal_col)
                    cell_matrix[internal_row][internal_col] = Cell(
                        text, alignment, font, fill, border, number_format,
                        alignment.indent, font.bold, fill.fgColor, fill.patternType,
                        [row, col, row + rowspan - 1, col + colspan - 1, internal_row, internal_col],
                        iscorecell)
            table_mask[row: row + rowspan, col: col + colspan] = 1
        col = 0
        row += 1

    # detect header region
    num_topheaderrows, num_leftheadercols = extract_header(cell_matrix)
    table_stru['num_top_header_rows'] = num_topheaderrows
    table_stru['num_left_header_cols'] = num_leftheadercols

    # Link every cell to its right-hand grid neighbour. The last cell of each
    # row is explicitly terminated with None so the rightward walk in the
    # "data index" step below always ends cleanly.
    for matrix_row in cell_matrix:
        for col, cell in enumerate(matrix_row):
            if cell is None:
                continue
            cell.right_cell = matrix_row[col + 1] if col + 1 < max_cols else None

    top_header = [[cell_matrix[i][j] for j in range(num_leftheadercols, max_cols)]
                  for i in range(num_topheaderrows)]
    left_header = [[cell_matrix[i][j] for j in range(num_leftheadercols)]
                   for i in range(num_topheaderrows, max_rows)]

    # header hierarchy
    table_stru["top_header"] = extract_tophierarchy(top_header)
    table_stru["left_header"] = extract_lefthierarchy(left_header)

    # mark datacell: a cell is data when some top-header cell sits above it and
    # some left-header cell sits to its left. Slots sharing the cell's own
    # start row / start column are ignored in the respective scan.
    for row in range(max_rows):
        for col in range(max_cols):
            cell = cell_matrix[row][col]
            if cell is None:
                continue
            start_row = cell.span["start_row"]
            start_col = cell.span["start_col"]
            has_top_header = any(
                above is not None
                and above.span["start_row"] != start_row
                and above.mark_topheader
                for above in (cell_matrix[i][col] for i in range(row)))
            if not has_top_header:
                continue
            has_left_header = any(
                left is not None
                and left.span["start_col"] != start_col
                and left.mark_leftheader
                for left in cell_matrix[row][:col])
            if has_left_header:
                cell.mark_datacell = True

    # data index
    for row in range(max_rows):
        for col in range(max_cols):
            cell = cell_matrix[row][col]
            if cell is None:
                continue
            # top index: a top-header leaf (no children, has a parent) indexes
            # its own column
            if cell.mark_topheader and len(cell.child_cells) == 0 and len(cell.parent_cell) > 0:
                cell.topnode_datacolumn = col
            # left index: a single-row left-header cell indexes its row when
            # none of its children start on the same row and a data cell
            # exists somewhere to its right
            if cell.mark_leftheader and cell.span["row_span"] == 1:
                has_child_on_same_row = any(
                    child.span["start_row"] == cell.span["start_row"]
                    for child in cell.child_cells)
                cellright = cell.right_cell
                while cellright is not None and not cellright.mark_datacell:
                    cellright = cellright.right_cell
                if not has_child_on_same_row and cellright is not None:
                    cell.leftnode_datarow = row

    return table_stru



def main():
    """Load the crawled HTML tables, then run structure extraction on them.

    Reads ``root_dir`` and ``data_dir`` from the module-level ``args``
    namespace; the tables are loaded from the ``html/`` folder under
    ``data_dir``.
    """
    print("---------------------Loading-----------------")
    root_dir = args.root_dir
    html_dir = os.path.join(args.data_dir, 'html/')
    tables = load_tables(root_dir, html_dir)
    print("Done.")

    print("---------------------Extracting-----------------")
    html2structure(tables)
    print("Done.")


if __name__ == '__main__':
    # Script entry point: every command-line option is a string-valued path
    # with a default, so the options are declared from a (flag, default) table.
    parser = ArgumentParser()
    for option_flag, option_default in (
            ('--root_dir', '/data/home/hdd3000/USER/HMT/'),
            ('--data_dir', 'qa/data/'),
            ('--structure_path', 'spreadsheet_stru/'),
            ('--table_dir', 'raw_input/table/'),
    ):
        parser.add_argument(option_flag, type=str, default=option_default)
    # main() reads the parsed options through this module-level name.
    args = parser.parse_args()

    main()
