"""ARG-related operators."""

# %% answer

from processing.position_utils import LinkedPosition

def check_unique_answer(answer_dict, cell_matrix):
    """Locate the first annotated answer and the matching cell.

    Only the first key of `answer_dict` is inspected; any further entries
    are ignored.

    Args:
        answer_dict: mapping whose keys are (row, col) cell-matrix coordinates.
        cell_matrix: container indexed as cell_matrix[row][col].

    Returns:
        (answer_position, answer_cell). Both are None when `answer_dict` is
        empty; `answer_cell` alone is None when the coordinates cannot be
        looked up in `cell_matrix`.
    """
    # No annotated answer at all: report it as "nothing found" instead of
    # failing with an IndexError on the first-key lookup.
    if not answer_dict:
        return (None, None)

    answer_cm_coords = next(iter(answer_dict))
    answer_position = LinkedPosition.from_cm_coords(answer_cm_coords)
    arow, acol = answer_cm_coords
    try:
        answer_cell = cell_matrix[arow][acol]
    except (LookupError, TypeError):
        # Out-of-range / missing coordinates (or a non-indexable row) mean
        # the answer cell is unavailable; other errors are left to surface.
        answer_cell = None

    return (answer_position, answer_cell)


# %% orientation

def find_orientation(apos, linked_cells):
    """Decide the orientation of the argmax/argmin from the answer location.

    The answer coordinates (`apos.cm_coords`) are looked up in the linked
    header cells; the 'top' header is checked before the 'left' one.

    Returns:
        'horizontal' when the answer is among linked_cells['top'],
        'vertical' when it is among linked_cells['left'], otherwise None.
    """
    row, col = apos.cm_coords
    answer_key = (row, col)
    top_cells = linked_cells['top']
    left_cells = linked_cells['left']

    if answer_key in top_cells:
        return 'horizontal'
    if answer_key in left_cells:
        return 'vertical'
    return None


def get_keys(orientation):
    """Map an orientation to the linked-cells keys for classes and dimension.

    Returns:
        (class_key, dim_key): ('top', 'left') for 'horizontal',
        ('left', 'top') for 'vertical'.

    Raises:
        ValueError: for any other orientation value.
    """
    if orientation == 'horizontal':
        return 'top', 'left'
    if orientation == 'vertical':
        return 'left', 'top'
    raise ValueError(f'[get-keys] reads unexpected ori: {orientation}')



# %% argument header/classes

def get_ascendents(cell):
    """Collect every ancestor of `cell` by following first parents upwards.

    The walk starts at `cell.parent_cell[0]` and stops once the parent is
    the 'virtual_root' marker. Only the first parent is followed at each
    level.

    Returns:
        dict mapping (start_row, start_col) of each ancestor to the ancestor
        itself, ordered from the nearest ancestor to the farthest.
    """
    lineage = {}

    current = cell.parent_cell[0]
    while current != 'virtual_root':
        span = current.span
        lineage[(span['start_row'], span['start_col'])] = current
        current = current.parent_cell[0]

    return lineage
    

def find_common_parent(cell_list):
    """Find the ancestors shared by every cell in `cell_list`.

    Each cell's ancestors come from `get_ascendents`, which keys them by
    (start_row, start_col); the intersection is taken over those keys.

    Args:
        cell_list: iterable of cells accepted by `get_ascendents`.

    Returns:
        list of (start_row, start_col) coordinates common to all cells, in
        no particular order. Empty when nothing is shared or when
        `cell_list` itself is empty.
    """
    joint_asc = None   # set of ancestor coordinates seen in every cell so far

    for cell in cell_list:
        cell_asc = set(get_ascendents(cell))
        joint_asc = cell_asc if joint_asc is None else joint_asc & cell_asc

    # No cell was visited: there is no common parent rather than an error.
    if joint_asc is None:
        return []
    return list(joint_asc)


def find_class_from_answer_child(header_dict, answer_cell):
    """Derive the class group from a header dict holding only the answer cell.

    The single header cell must equal `answer_cell` and have exactly one
    parent. That parent is taken as the class name and its children as the
    classes.

    Returns:
        None when the parent is the 'virtual_root' marker, otherwise a dict
        with 'class_name' ([parent]), 'classes' (the parent's child_cells)
        and 'answer' ([answer child]).
    """
    assert len(header_dict) == 1

    header_cells = list(header_dict.values())
    answer_child = header_cells[0]
    assert answer_child == answer_cell, 'Unmatched answer and header-dict..'

    parents = answer_child.parent_cell
    assert len(parents) == 1, "Answer Child has Non-Unique Parent Class Name.."

    class_name = parents[0]
    if class_name == 'virtual_root':
        return None

    classes = class_name.child_cells
    assert answer_child in classes

    return {
        'class_name': [class_name],
        'classes': classes,
        'answer': [answer_child],
    }


def heuristic_class_split(header_dict, cell_matrix, answer_cell, sort_row_first):
    """Heuristically split header cells into a class name and its classes.

    The coordinates are sorted (row-major when `sort_row_first`, otherwise
    column-major) and the first cell is tentatively taken as the class name,
    with the rest as classes. If that first coordinate is not a common
    ancestor of the rest, every header cell is treated as a class and the
    class name becomes the common ancestors of all of them, read from
    `cell_matrix`.

    `answer_cell` is accepted for signature parity but is not used.

    Returns:
        dict with 'class_name' (list of cells) and 'classes' (list of cells).
    """
    assert len(header_dict) > 1   # a candidate parent plus at least one child

    if sort_row_first:
        ordered_coords = sorted(header_dict.keys(), key=lambda rc: (rc[0], rc[1]))
    else:
        ordered_coords = sorted(header_dict.keys(), key=lambda rc: (rc[1], rc[0]))
    head_coords = ordered_coords[0]

    # tentative split: first sorted cell is the parent, the rest are children
    cand_class_name = [header_dict[head_coords]]
    cand_classes = [header_dict[rc] for rc in ordered_coords[1:]]

    if head_coords not in find_common_parent(cand_classes):
        # the first cell is a sibling, not the parent: look the parent up
        cand_classes = [header_dict[rc] for rc in ordered_coords]
        parent_coords = find_common_parent(cand_classes)
        cand_class_name = [cell_matrix[r][c] for (r, c) in parent_coords]

    return {
        'class_name': cand_class_name,
        'classes': cand_classes,
    }


def trim_classes(classes, sort_row_first):
    """Keep only the classes lying on the most common row/column.

    For a top header (`sort_row_first` true) cells are compared by
    span['start_row']; for a left header by span['start_col']. Cells whose
    index differs from the most frequent one are dropped.

    Args:
        classes: list of cells exposing a `span` mapping.
        sort_row_first: True for the top header, False for the left header.

    Returns:
        A new list with the retained cells in their original order; an
        empty list when `classes` is empty.
    """
    # Nothing to trim; avoids max() failing on an empty set below.
    if not classes:
        return []

    span_key = 'start_row' if sort_row_first else 'start_col'
    indices = [c.span[span_key] for c in classes]
    # most frequent index among the cells
    dominant = max(set(indices), key=indices.count)
    return [c for c, idx in zip(classes, indices) if idx == dominant]


def parse_tree_snippet(header_dict, cell_matrix, answer_cell, sort_row_first):
    """Turn a dictionary of header cells into a class dictionary.

    1. len(header_dict) == 0: assertion failure
    2. len(header_dict) == 1: the cell is treated as the answer child
    3. len(header_dict) >  1: heuristic parent/children split

    Args:
        header_dict: {(irow, icol): cell}
        cell_matrix: matrix of cells, used by the heuristic split
        answer_cell: the answer cell, checked in the single-cell case
        sort_row_first: True for the top header, False for the left header
    Returns:
        The dict (or None) produced by `find_class_from_answer_child` or
        `heuristic_class_split`.
    """
    header_count = len(header_dict)
    assert header_count > 0

    if header_count != 1:
        return heuristic_class_split(
            header_dict, cell_matrix, answer_cell, sort_row_first)

    return find_class_from_answer_child(header_dict, answer_cell)


def organize_class(header_dict, cell_matrix, answer_cell, sort_row_first):
    """Build the class dictionary and trim classes that are out of line.

    Use sort_row_first=True for the top header and False for the left one.
    Returns None when no class dictionary could be parsed.
    """
    cls_dict = parse_tree_snippet(header_dict, cell_matrix, answer_cell, sort_row_first)

    if cls_dict is not None:
        kept = trim_classes(cls_dict['classes'], sort_row_first)
        # replace the entry only when trimming actually removed something
        if len(kept) != len(cls_dict['classes']):
            cls_dict['classes'] = kept

    return cls_dict



# %% value dimension

def get_dimension(header_dict, orientation):
    """Pick the single dimension cell out of the dimension header.

    For 'vertical' the cell with the largest column index is chosen (ties
    broken by the largest row index); for 'horizontal' the cell with the
    largest row index (ties broken by the largest column index).

    Args:
        header_dict: {(irow, icol): cell} of candidate dimension cells.
        orientation: 'vertical' or 'horizontal'.

    Returns:
        A one-element list holding the chosen cell, or None when
        `header_dict` is empty.

    Raises:
        ValueError: if `orientation` is neither 'vertical' nor 'horizontal'.
    """
    if orientation == 'vertical':
        axis = 1      # compare column indices
    elif orientation == 'horizontal':
        axis = 0      # compare row indices
    else:
        # Fail with a clear message instead of an unbound-local error.
        raise ValueError(f'[get-dimension] reads unexpected ori: {orientation}')

    if not header_dict:
        return None

    dim_index = max(coords[axis] for coords in header_dict)
    candidates = [coords for coords in header_dict if coords[axis] == dim_index]
    # among cells on the chosen row/column, take the last one in sorted order
    return [header_dict[max(candidates)]]



# %% data

def cross_source_data(class_coords, dim_coords, orientation):
    """Cross class coordinates with dimension coordinates into data coordinates.

    'vertical' (dim in 'top'): rows come from the classes, columns from the
    dimension. 'horizontal' (dim in 'left'): columns come from the classes,
    rows from the dimension. Any other orientation yields an empty list.

    Returns:
        list of (row, col) pairs, one for every row/column combination.
    """
    if orientation == 'vertical':
        rows = {r for (r, _) in class_coords}
        cols = {c for (_, c) in dim_coords}
    elif orientation == 'horizontal':
        cols = {c for (_, c) in class_coords}
        rows = {r for (r, _) in dim_coords}
    else:
        # unsupported orientation: nothing to cross
        return []

    crossed = []
    for r in rows:
        for c in cols:
            crossed.append((r, c))
    return crossed


def select_data_region(cell_matrix, classes, dimension, orientation):
    """Collect the data cells addressed by the classes and the dimension.

    `classes` and `dimension` are sequences of 2-item pairs whose second
    item is a cell exposing `span`. Their start coordinates are crossed
    according to `orientation` ('horizontal': dim in 'left', 'vertical':
    dim in 'top'), and cells with an empty `cell_string` are discarded.

    Returns:
        list of data cells, or None when no cell remains or when the number
        of remaining cells differs from the number of classes.
    """
    class_coords = [
        (cell.span['start_row'], cell.span['start_col']) for _, cell in classes
    ]
    dim_coords = [
        (cell.span['start_row'], cell.span['start_col']) for _, cell in dimension
    ]
    data_coords = cross_source_data(
        class_coords=class_coords,
        dim_coords=dim_coords,
        orientation=orientation,
    )

    located = []
    for row, col in data_coords:
        located.append(cell_matrix[row][col])
    data_cells = [cell for cell in located if cell.cell_string != '']

    # expect exactly one non-empty data cell per class
    if len(data_cells) == 0 or len(data_cells) != len(classes):
        return None
    return data_cells




# %% formulaic heuristic linking

from processing.position_utils import get_formula_argument


def formulaic_new_cells(structure, subdict, linked_cells):
    """Complete the annotated linked cells using formula heuristics.

    Steps: resolve the answer cell, infer the orientation, organize the
    classes, pick the dimension, then cross both to find the data cells.
    Whenever a step fails, a message is printed and an empty dict returned.

    Args:
        structure: mapping providing 'cell_matrix'.
        subdict: mapping providing 'answer_cells'.
        linked_cells: mapping providing the 'top' and 'left' header dicts.

    Returns:
        {(start_row, start_col): cell} for the dimension, data, class-name
        and class cells that were found.
    """
    cell_matrix = structure['cell_matrix']

    # the answer must resolve to both a position and a cell
    ans_pos, ans_cell = check_unique_answer(
        answer_dict=subdict['answer_cells'],
        cell_matrix=cell_matrix,
    )
    if ans_pos is None or ans_cell is None:
        print('[args>>complete_cells] NOT any answer..')
        return {}

    # orientation decides which header holds classes and which the dimension
    ori = find_orientation(ans_pos, linked_cells)
    if ori is None:
        print('[args>>complete_cells] NOT viable orientation..')
        return {}
    class_key, dim_key = get_keys(ori)

    # classes
    cls_dict = organize_class(
        header_dict=linked_cells[class_key],
        cell_matrix=cell_matrix,
        answer_cell=ans_cell,
        sort_row_first=(class_key == 'top'),
    )
    if cls_dict is None:
        print('[args>>complete_cells] NOT viable classes..')
        return {}
    # NOTE(review): get_formula_argument presumably returns pairs whose
    # second item is the cell (see pair[1] below) -- confirm against helper.
    class_name = get_formula_argument(cls_dict['class_name'])
    class_args = get_formula_argument(cls_dict['classes'])

    # dimension
    dimension = get_dimension(
        header_dict=linked_cells[dim_key],
        orientation=ori,
    )
    if dimension is None:
        print('[args>>complete_cells] NOT viable dimension..')
        return {}
    dim_args = get_formula_argument(dimension)

    # cross-indexed data cells
    data_cells = select_data_region(
        cell_matrix=cell_matrix,
        classes=class_args,
        dimension=dim_args,
        orientation=ori,
    )

    # gather in a fixed order: dimension, data, class name, classes
    gathered = list(dimension)
    if data_cells is not None:
        gathered.extend(data_cells)
    for pairs in (class_name, class_args):
        if pairs is not None:
            gathered.extend([pair[1] for pair in pairs])

    return {
        (cell.span['start_row'], cell.span['start_col']): cell
        for cell in gathered
    }
