"""Different linking strategies for constructing the inputs.
#1. RAW: get the schema-linking annotated cells.
#2. FORMULA: get both the schema-linking cells and the formula-parsed/heuristic cells.
"""

# %% general

_LINK_KEYS = {'top', 'left', 'corner', 'data'}   # type == set


def find_corecell(cell_matrix, irow, icol):
    """Return the core cell anchoring the span that covers (irow, icol).

    The cell at ``cell_matrix[irow][icol]`` carries a ``span`` dict whose
    ``'start_row'`` / ``'start_col'`` give the top-left position of its span;
    the cell stored at that position is returned.

    Raises:
        AssertionError: if the anchoring cell is not flagged ``iscorecell``.
    """
    span = cell_matrix[irow][icol].span
    anchor = cell_matrix[span['start_row']][span['start_col']]
    assert anchor.iscorecell
    return anchor

def update_linked_dict(input_dict, linked_cells, cell_matrix, hdr_nrows, hdr_ncols):
    """Merge the cells referenced by 'input_dict' into 'linked_cells'.

    Each key of ``input_dict`` is an ``(irow, icol)`` position in
    ``cell_matrix``. Positions are skipped if their value is ``None`` or if
    they fall outside the matrix. Each remaining position is resolved to the
    core cell of its span (via ``find_corecell``). That cell is filed under
    the region the position lies in:

    * ``'corner'`` -- irow <  hdr_nrows and icol <  hdr_ncols
    * ``'top'``    -- irow <  hdr_nrows and icol >= hdr_ncols
    * ``'left'``   -- irow >= hdr_nrows and icol <  hdr_ncols
    * ``'data'``   -- irow >= hdr_nrows and icol >= hdr_ncols

    Region dicts are keyed by the core cell's (start_row, start_col). An
    entry already present at that key is left untouched.

    Args:
        input_dict: mapping ``(irow, icol) -> item``. Items are only tested
            for ``None``; their values are otherwise unused.
        linked_cells: dict with one sub-dict per region name; updated in place.
        cell_matrix: list of rows of cells. The column count is taken from
            row 0.
        hdr_nrows: number of top-header rows.
        hdr_ncols: number of left-header columns.

    Returns:
        The same ``linked_cells`` object.
    """
    if not cell_matrix:
        # No position can be in range; also avoids IndexError on cell_matrix[0].
        return linked_cells
    nrows = len(cell_matrix)
    ncols = len(cell_matrix[0])

    for (irow, icol), item in input_dict.items():
        if item is None:
            continue
        # Reject negatives too: they would silently wrap to the matrix end.
        if not (0 <= irow < nrows and 0 <= icol < ncols):
            continue

        in_top_rows = irow < hdr_nrows
        in_left_cols = icol < hdr_ncols
        if in_top_rows:
            region = 'corner' if in_left_cols else 'top'
        else:
            region = 'left' if in_left_cols else 'data'

        core_cell = find_corecell(cell_matrix, irow, icol)
        key = (core_cell.span['start_row'], core_cell.span['start_col'])
        # Keep the first cell linked at this position.
        linked_cells[region].setdefault(key, core_cell)
    return linked_cells

def link_schema_cells(cell_matrix, schema_dict, linked_cells, hdr_nrows, hdr_ncols):
    """Fold every phrase's schema-link positions into 'linked_cells'.

    ``schema_dict`` maps a phrase to a dict of ``(irow, icol) -> item``
    positions. Each of those dicts is merged in turn via
    ``update_linked_dict``; the phrase keys themselves are not used.

    Returns:
        The updated linked-cells dict.
    """
    for position_dict in schema_dict.values():
        linked_cells = update_linked_dict(
            position_dict, linked_cells, cell_matrix, hdr_nrows, hdr_ncols)
    return linked_cells

def link_answer_cells(cell_matrix, answer_cells, linked_cells, hdr_nrows, hdr_ncols):
    """Merge the answer-cell positions into 'linked_cells'.

    ``answer_cells`` maps ``(irow, icol)`` positions to items. It is passed
    directly to ``update_linked_dict``; no formula parsing happens here.

    Returns:
        The updated linked-cells dict.
    """
    return update_linked_dict(
        answer_cells, linked_cells, cell_matrix, hdr_nrows, hdr_ncols)



def expand_path(header_dict, hdr_nrows, hdr_ncols):
    """Add every ancestor of the nodes in 'header_dict' to 'header_dict'.

    For each node, the parent chain is followed through ``parent_cell[0]``
    until the ``'virtual_root'`` sentinel. Each ancestor is keyed by its
    span's (start_row, start_col); keys already present are not overwritten.
    ``hdr_nrows`` and ``hdr_ncols`` are currently unused.

    Returns:
        The same ``header_dict``, modified in place.
    """
    # Gather all ancestors first: inserting while iterating the dict would raise.
    ancestors = []
    for node in header_dict.values():
        current = node.parent_cell[0]
        while current != 'virtual_root':
            ancestors.append(current)
            current = current.parent_cell[0]

    for ancestor in ancestors:
        key = (ancestor.span['start_row'], ancestor.span['start_col'])
        if key not in header_dict:
            header_dict[key] = ancestor
    return header_dict

def match_node_to_cell(node, cell):
    """Search the subtree under 'node' for the node matching 'cell'.

    A node matches when its ``span`` equals ``cell.span`` and it is flagged
    ``iscorecell``. Nodes are visited depth-first in pre-order: a node comes
    before its ``child_cells``, and children are visited left to right. The
    first match is returned, or ``None`` if nothing in the subtree matches.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        if current.span == cell.span and current.iscorecell:
            return current
        # Push children reversed so the leftmost child is visited next.
        pending.extend(list(current.child_cells)[::-1])
    return None

def map_cell_to_leaf(root, cell):
    """Find the header-tree node corresponding to 'cell'.

    ``root['virtual_root']`` lists the top-level tree nodes. Each is searched
    in order with ``match_node_to_cell``, and the first match is returned.
    Returns ``None`` when no subtree contains a match.
    """
    matches = (match_node_to_cell(top_node, cell) for top_node in root['virtual_root'])
    return next((found for found in matches if found is not None), None)

def augment_data_cells(linked_cells, structure, hdr_nrows, hdr_ncols):
    """Add the header cells that govern each linked data cell.

    For every position in ``linked_cells['data']``, two header leaves are
    looked up:

    * the top-header leaf, starting from the cell in the last top-header row
      (``hdr_nrows - 1``) of the same column;
    * the left-header leaf, starting from the cell in the last left-header
      column (``hdr_ncols - 1``) of the same row.

    Leaves are searched in ``structure['top_header']`` /
    ``structure['left_header']`` with ``map_cell_to_leaf``. Each found leaf
    and all of its ancestors (via ``expand_path``) are merged into
    ``linked_cells['top']`` / ``linked_cells['left']``. Existing entries at
    the same position are overwritten.

    A side with no header rows (``hdr_nrows <= 0``) or no header columns
    (``hdr_ncols <= 0``) is skipped. Otherwise the ``- 1`` index would wrap
    to the last data row/column.

    Returns:
        The same ``linked_cells`` dict, updated in place.
    """
    top_headers, left_headers = {}, {}
    cell_matrix = structure['cell_matrix']

    # map_cell_to_leaf returns either None or a node flagged iscorecell,
    # so a None check is all that is needed below.
    for irow, icol in linked_cells['data']:
        if hdr_nrows > 0:
            top_cell = find_corecell(cell_matrix, hdr_nrows - 1, icol)
            top_leaf = map_cell_to_leaf(structure['top_header'], top_cell)
            if top_leaf is not None:
                pos = (top_leaf.span['start_row'], top_leaf.span['start_col'])
                top_headers[pos] = find_corecell(cell_matrix, *pos)
        if hdr_ncols > 0:
            left_cell = find_corecell(cell_matrix, irow, hdr_ncols - 1)
            left_leaf = map_cell_to_leaf(structure['left_header'], left_cell)
            if left_leaf is not None:
                pos = (left_leaf.span['start_row'], left_leaf.span['start_col'])
                left_headers[pos] = find_corecell(cell_matrix, *pos)

    linked_cells['top'].update(expand_path(top_headers, hdr_nrows, hdr_ncols))
    linked_cells['left'].update(expand_path(left_headers, hdr_nrows, hdr_ncols))
    return linked_cells



def find_ascendants(node):
    """Return the ancestors of 'node', nearest parent first.

    Follows ``parent_cell[0]`` upward until the ``'virtual_root'`` sentinel.
    Neither the sentinel nor ``node`` itself is included.
    """
    def _walk_up(start):
        current = start.parent_cell[0]
        while current != 'virtual_root':
            yield current
            current = current.parent_cell[0]

    return list(_walk_up(node))

def augment_data_paths(linked_cells, structure, hdr_nrows, hdr_ncols):
    """Build one header-path record per linked cell.

    Each entry of ``linked_cells['data']`` yields a record
    ``{'data': cell, 'top': [...], 'left': [...]}``. Its ``'top'`` and
    ``'left'`` lists hold the header leaf governing the cell, followed by
    that leaf's ancestors (nearest first; ``'virtual_root'`` excluded).

    Leaves are located the same way as in ``augment_data_cells``: from the
    last top-header row in the cell's column, and from the last left-header
    column in the cell's row. A side with no header rows/columns
    (``<= 0``) gets an empty list.

    After that, every entry of ``linked_cells['top']`` and
    ``linked_cells['left']`` gets its own record with ``'data': None``.

    Returns:
        list of path dicts: data records first, then top, then left.
    """
    path_dicts = []
    cell_matrix = structure['cell_matrix']

    for (irow, icol), data_cell in linked_cells['data'].items():
        data_path_dict = {'data': data_cell, 'top': [], 'left': []}

        # find_ascendants() returns only the parents, so prepend the leaf
        # itself; otherwise a single-level header would give an empty path.
        if hdr_nrows > 0:
            top_cell = find_corecell(cell_matrix, hdr_nrows - 1, icol)
            top_leaf = map_cell_to_leaf(structure['top_header'], top_cell)
            if top_leaf is not None:
                core_top_leaf = find_corecell(
                    cell_matrix, top_leaf.span['start_row'], top_leaf.span['start_col'])
                data_path_dict['top'] = [core_top_leaf] + find_ascendants(core_top_leaf)
        if hdr_ncols > 0:
            left_cell = find_corecell(cell_matrix, irow, hdr_ncols - 1)
            left_leaf = map_cell_to_leaf(structure['left_header'], left_cell)
            if left_leaf is not None:
                core_left_leaf = find_corecell(
                    cell_matrix, left_leaf.span['start_row'], left_leaf.span['start_col'])
                data_path_dict['left'] = [core_left_leaf] + find_ascendants(core_left_leaf)

        path_dicts.append(data_path_dict)

    for top_cell in linked_cells['top'].values():
        path_dicts.append({'top': [top_cell], 'left': [], 'data': None})
    for left_cell in linked_cells['left'].values():
        path_dicts.append({'left': [left_cell], 'top': [], 'data': None})

    return path_dicts



# %% raw

def get_anno_linked_cells(structure, subdict):
    """Collect only the cells referenced by the schema-linking annotation.

    Only ``subdict['schema_link']`` is read. No header augmentation is
    applied and no paths are built (compare ``get_raw_linked_cells``).

    Returns:
        dict mapping each region name in ``_LINK_KEYS`` to a dict of
        ``(start_row, start_col) -> core cell``.
    """
    n_top = structure['num_top_header_rows']
    n_left = structure['num_left_header_cols']

    empty_links = {region: {} for region in _LINK_KEYS}
    return link_schema_cells(
        structure['cell_matrix'], subdict['schema_link'], empty_links, n_top, n_left)



def get_raw_linked_cells(structure, subdict):
    """Link the schema-annotated cells, then add their headers and paths.

    Formula information is not used; only ``subdict['schema_link']`` is read.

    Returns:
        (linked_cells, paths): the header-augmented linked-cells dict and
        the list of path dicts from ``augment_data_paths``.
    """
    n_top = structure['num_top_header_rows']
    n_left = structure['num_left_header_cols']

    linked = link_schema_cells(
        structure['cell_matrix'], subdict['schema_link'],
        {region: {} for region in _LINK_KEYS}, n_top, n_left)

    # augment_data_cells updates 'linked' in place and returns it, so the
    # path builder below also sees the added header cells.
    augmented = augment_data_cells(linked, structure, n_top, n_left)
    paths = augment_data_paths(linked, structure, n_top, n_left)
    return augmented, paths



# %% formulaic

from processing.operation.args import formulaic_new_cells


def get_formula_linked_cells(structure, subdict):
    """Link schema, answer, and formula-heuristic cells; add headers and paths.

    Steps:
      1. link the schema-linking annotation (``subdict['schema_link']``);
      2. link the answer cells (``subdict['answer_cells']``);
      3. if any entry of ``subdict['aggregation']`` contains ``'arg'``, merge
         the cells proposed by ``formulaic_new_cells``;
      4. add governing header cells and build per-cell header paths.

    Returns:
        (linked_cells, paths): the header-augmented linked-cells dict and
        the list of path dicts from ``augment_data_paths``.
    """
    n_top = structure['num_top_header_rows']
    n_left = structure['num_left_header_cols']
    cell_matrix = structure['cell_matrix']

    linked = {region: {} for region in _LINK_KEYS}
    linked = link_schema_cells(cell_matrix, subdict['schema_link'], linked, n_top, n_left)
    linked = link_answer_cells(cell_matrix, subdict['answer_cells'], linked, n_top, n_left)

    # Heuristic extension; presumably targets argmax/argmin-style
    # aggregations (entries containing 'arg') -- confirm against callers.
    arg_flags = [('arg' in op) for op in subdict['aggregation']]
    if any(arg_flags):
        extra_cells = formulaic_new_cells(structure, subdict, linked)
        linked = update_linked_dict(extra_cells, linked, cell_matrix, n_top, n_left)

    augmented = augment_data_cells(linked, structure, n_top, n_left)
    paths = augment_data_paths(linked, structure, n_top, n_left)
    return augmented, paths
