"""Checking procedures for the ARG-based functions.
e.g. ARGMAX, ARGMIN; PAIR-ARGMAX, PAIR-ARGMIN; 
     TOPK-ARGMAX, TOPK-ARGMIN; KTH-ARGMAX, KTH-ARGMIN

group the linked cells into top, left, data, corner
find the comparison classes and value dimension
match a list of header cells with a list of data cells
format the argmax formula
"""

from preprocess.align import link_table_cells
from preprocess.position_utils import LinkedPosition, get_formula_argument
from preprocess.operation.checkop_utils import ( CheckOpMsg, 
    _OP_ANSWER, _OP_CLASS, _OP_DIMENSION, _OP_DATA, 
    _OP_ERROR_NONE, _OP_ERROR_MULTI, _OP_ERROR_PLACE, _OP_ERROR_NUM, _OP_ERROR_GENERAL
)
from preprocess.operation.formula_utils import ArgumentFormula



# preliminary

def check_unique_answer(answer_dict, cell_matrix):
    """Require exactly one linked answer and fetch its cell from the matrix.

    Args:
        answer_dict: mapping whose keys are (row, col) cell-matrix coordinates
            of the linked answer cells; only its length and first key are read.
        cell_matrix: nested container indexed as ``cell_matrix[row][col]``.

    Returns:
        ``((answer_position, answer_cell), ans_msg)``:
        * exactly one answer, cell found: ``((position, cell), None)``
        * exactly one answer, lookup failed: ``((position, None), GENERAL msg)``
        * no answer: ``((None, None), NONE msg)``
        * several answers: ``((None, None), MULTI msg)``
    """
    answer_count = len(answer_dict)
    if answer_count == 0:
        return (None, None), CheckOpMsg(_OP_ANSWER, _OP_ERROR_NONE)
    if answer_count > 1:
        return (None, None), CheckOpMsg(_OP_ANSWER, _OP_ERROR_MULTI)

    answer_cm_coords = next(iter(answer_dict))
    answer_position = LinkedPosition.from_cm_coords(answer_cm_coords)
    arow, acol = answer_cm_coords
    try:
        answer_cell = cell_matrix[arow][acol]
    except (IndexError, KeyError, TypeError):
        # Only failures of the matrix lookup itself are reported as a general
        # answer error; anything else (e.g. KeyboardInterrupt) propagates.
        return (answer_position, None), CheckOpMsg(_OP_ANSWER, _OP_ERROR_GENERAL)
    return (answer_position, answer_cell), None



def find_orientation(apos, linked_cells):
    """Find the orientation of the argmax/argmin from the answer's location.

    Args:
        apos: answer position exposing ``cm_coords`` as a (row, col) pair.
        linked_cells: dict holding at least the 'top' and 'left' containers,
            each supporting ``(row, col) in container``.

    Returns:
        ``('horizontal', None)`` when the answer is among the 'top' cells,
        ``('vertical', None)`` when it is among the 'left' cells, otherwise
        ``(None, msg)`` where ``msg`` flags a misplaced answer.
    """
    arow, acol = apos.cm_coords
    answer_coords = (arow, acol)

    if answer_coords in linked_cells['top']:
        return 'horizontal', None    # single dim row
    if answer_coords in linked_cells['left']:
        return 'vertical', None      # dim column

    # The answer sits outside both headers (e.g. in the corner or data region),
    # so no comparison axis can be derived from it.
    return None, CheckOpMsg(_OP_ANSWER, _OP_ERROR_PLACE)
    
def get_keys(orientation):
    """Map an orientation to the linked-cells keys for classes and dimension.

    Returns ``(class_key, dim_key)``: ('top', 'left') for 'horizontal' and
    ('left', 'top') for 'vertical'. Any other value raises ``ValueError``.
    """
    if orientation not in ('horizontal', 'vertical'):
        raise ValueError(f'[get-keys] reads unexpected ori: {orientation}')
    if orientation == 'horizontal':
        return 'top', 'left'
    return 'left', 'top'

# organize classes to compare

def get_ascendents(cell):
    """Collect the ancestors of ``cell`` by walking up its first-parent chain.

    Only the first entry of each ``parent_cell`` list is followed. The walk
    stops at the 'virtual_root' sentinel, or when a node has no parents at all.

    Returns:
        dict mapping ``(start_row, start_col)`` of each ancestor's span to the
        ancestor itself; empty when ``cell`` hangs directly under the root.
    """
    ascendents = {}

    parents = cell.parent_cell
    # An empty/missing parent list ends the walk instead of raising IndexError.
    while parents and parents[0] != 'virtual_root':
        parent = parents[0]
        coords = (parent.span['start_row'], parent.span['start_col'])
        ascendents[coords] = parent
        parents = parent.parent_cell
    return ascendents
    

def find_class_from_answer_child(header_dict, answer_cell):
    """Derive the class name and sibling classes from a single answer header.

    ``header_dict`` must hold exactly one cell, equal to ``answer_cell``, and
    that cell must have exactly one parent (all enforced with assertions).

    Returns:
        ``(cls_dict, None)`` with keys 'class_name' (the parent, wrapped in a
        list), 'classes' (the parent's ``child_cells``) and 'answer' (the
        answer cell, wrapped in a list); or ``(None, msg)`` when the parent
        is the 'virtual_root' sentinel.
    """
    assert len(header_dict) == 1

    only_child = tuple(header_dict.values())[0]
    assert only_child == answer_cell, 'Unmatched answer and header-dict..'

    parents = only_child.parent_cell
    assert len(parents) == 1, "Answer Child has Non-Unique Parent Class Name.."

    parent = parents[0]
    if parent == 'virtual_root':
        # A top-level header has no real parent to act as the class name.
        return None, CheckOpMsg(_OP_CLASS, _OP_ERROR_GENERAL)

    siblings = parent.child_cells
    assert only_child in siblings

    return {
        'class_name': [parent],
        'classes': siblings,
        'answer': [only_child],
    }, None


def find_common_parent(cell_list):
    """Find the ancestors shared by every cell in ``cell_list``.

    Each cell's ancestors come from ``get_ascendents``, whose keys are
    ``(start_row, start_col)`` coordinates; the result is the intersection of
    those key sets.

    Returns:
        list of coordinates common to all cells (order unspecified, as it
        comes from a set). May hold zero, one or several entries. An empty
        ``cell_list`` yields an empty list.
    """
    shared_coords = None   # set of (row, col) ancestor coordinates

    for cell in cell_list:
        cell_coords = set(get_ascendents(cell))
        if shared_coords is None:
            shared_coords = cell_coords
        else:
            shared_coords &= cell_coords

    # No cells means no ancestors, rather than a TypeError on list(None).
    if shared_coords is None:
        return []
    return list(shared_coords)


def heuristic_class_split(header_dict, cell_matrix, answer_cell, sort_row_first):
    """Heuristically treat the first sorted header cell as the class name.

    The header coordinates are sorted by (row, col) when ``sort_row_first`` is
    true, otherwise by (col, row). The first cell is the candidate class name
    and the rest are the candidate classes. If the first cell's coordinates
    are not among the common ancestors of the rest, every header cell is
    taken as a class and the class name becomes their common ancestors,
    looked up in ``cell_matrix``.

    Args:
        header_dict: ``{(irow, icol): cell}`` with at least two entries.
        cell_matrix: nested container indexed as ``cell_matrix[row][col]``.
        answer_cell: accepted for signature compatibility; currently unused.
        sort_row_first: selects the sort order described above.

    Returns:
        ``(cls_dict, None)`` where ``cls_dict`` has 'class_name' and 'classes'.
    """
    assert len(header_dict) > 1   # at least 1 child except for the single parent

    if sort_row_first:
        ordered_coords = sorted(header_dict, key=lambda rc: (rc[0], rc[1]))
    else:
        ordered_coords = sorted(header_dict, key=lambda rc: (rc[1], rc[0]))
    head_coords = ordered_coords[0]

    # 'Cell' items
    cand_class_name = [header_dict[head_coords]]
    cand_classes = [header_dict[rc] for rc in ordered_coords[1:]]

    if head_coords not in find_common_parent(cand_classes):
        # The first cell is not an ancestor of the others: it is a class too,
        # and the class name is whatever all of the cells descend from.
        cand_classes = [header_dict[rc] for rc in ordered_coords]
        cand_class_name = [
            cell_matrix[r][c] for r, c in find_common_parent(cand_classes)
        ]

    cls_dict = {
        'class_name': cand_class_name,
        'classes': cand_classes,
    }
    return cls_dict, None


def parse_tree_snippet(header_dict, cell_matrix, answer_cell, sort_row_first):
    """Split a dictionary of header cells into a class name and its classes.

    Dispatch on the number of header cells:
    1. ``len(header_dict) == 0``: assertion failure
    2. ``len(header_dict) == 1``: the cell is treated as the answer child, see
       ``find_class_from_answer_child``
    3. ``len(header_dict) >  1``: heuristic parent/children split, see
       ``heuristic_class_split``

    Args:
        header_dict: {(irow, icol): cell}
        cell_matrix: forwarded to the heuristic split
        answer_cell: the answer cell
        sort_row_first: forwarded to the heuristic split
    Returns:
        (cls_dict, msg) exactly as produced by the selected helper
    """
    header_count = len(header_dict)
    assert header_count > 0

    if header_count != 1:
        return heuristic_class_split(
            header_dict, cell_matrix, answer_cell, sort_row_first)
    return find_class_from_answer_child(header_dict, answer_cell)


def trim_classes(classes, sort_row_first):
    """Keep only the classes that sit on the most common row (top) / column (left).

    With ``sort_row_first`` (top header) the cells are compared by their
    span's 'start_row'; otherwise (left header) by 'start_col'. Cells whose
    index differs from the most frequent one are dropped.
    """
    span_key = 'start_row' if sort_row_first else 'start_col'
    positions = [cell.span[span_key] for cell in classes]
    dominant = max(set(positions), key=positions.count)
    return [
        cell
        for cell, position in zip(classes, positions)
        if position == dominant
    ]

def organize_class(header_dict, cell_matrix, answer_cell, sort_row_first):
    """Resolve the class name and the comparable classes from header cells.

    The headers are first split by ``parse_tree_snippet``; on success the
    'classes' entry is narrowed by ``trim_classes`` to cells on one shared
    row/column. For a top header dict pass ``sort_row_first=True``; for a
    left header dict pass ``False``.

    Returns:
        ``(cls_dict, split_msg)`` as given by ``parse_tree_snippet``, with
        'classes' replaced only when trimming actually removed something.
    """
    cls_dict, split_msg = parse_tree_snippet(
        header_dict, cell_matrix, answer_cell, sort_row_first)

    split_failed = (cls_dict is None) and (split_msg is not None)
    if not split_failed:
        kept = trim_classes(cls_dict['classes'], sort_row_first)
        if len(kept) != len(cls_dict['classes']):
            cls_dict['classes'] = kept

    return cls_dict, split_msg



# find unique dimension
def get_dimension(header_dict, orientation):
    """Get the unique dimension cell and/or an error message.

    For 'vertical' the dimension is taken from the right-most linked column
    (lowest cell in that column); for 'horizontal' from the bottom-most
    linked row (right-most cell in that row).

    Args:
        header_dict: ``{(irow, icol): cell}`` of the dimension-side header.
        orientation: 'vertical' or 'horizontal'.

    Returns:
        ``(None, NONE msg)`` when ``header_dict`` is empty; otherwise
        ``([dimension_cell], dim_msg)`` where ``dim_msg`` is a MULTI message
        if the headers span several columns (vertical) / rows (horizontal),
        else ``None``.

    Raises:
        ValueError: for any other ``orientation`` (previously this surfaced
            as an UnboundLocalError on the return statement).
    """
    if orientation == 'vertical':
        axis = 1      # dimension candidates are distinguished by column
    elif orientation == 'horizontal':
        axis = 0      # dimension candidates are distinguished by row
    else:
        raise ValueError(f'[get-dimension] reads unexpected ori: {orientation}')

    cm_coords_list = list(header_dict)
    if not cm_coords_list:
        return None, CheckOpMsg(_OP_DIMENSION, _OP_ERROR_NONE)

    dim_msg = None
    if len({coords[axis] for coords in cm_coords_list}) > 1:
        # Several candidate columns/rows: flag it, but still pick the last one.
        dim_msg = CheckOpMsg(_OP_DIMENSION, _OP_ERROR_MULTI)

    # Largest index along the axis first, then the largest index across it.
    dim_coords = max(
        cm_coords_list,
        key=lambda coords: (coords[axis], coords[1 - axis]),
    )
    return [header_dict[dim_coords]], dim_msg


def cross_source_data(class_coords, dim_coords, orientation):
    """Select the data coordinates crossing the classes with the dimension.

    'orientation': 'horizontal'(dim in 'left') OR 'vertical'(dim in 'top').
    For 'vertical' the rows come from ``class_coords`` and the columns from
    ``dim_coords``; for 'horizontal' it is the other way round.

    Duplicated indices are dropped while keeping first-seen order, so the
    returned coordinates follow the order of the inputs (rows outer, columns
    inner) instead of an arbitrary set order.

    Returns:
        list of ``(row, col)`` tuples; empty for an unsupported orientation.
    """
    if orientation == 'vertical':
        row_indices = dict.fromkeys(irow for (irow, _) in class_coords)
        col_indices = dict.fromkeys(icol for (_, icol) in dim_coords)
    elif orientation == 'horizontal':
        col_indices = dict.fromkeys(icol for (_, icol) in class_coords)
        row_indices = dict.fromkeys(irow for (irow, _) in dim_coords)
    else:
        # orientation not supported: nothing to cross
        return []

    return [(r, c) for r in row_indices for c in col_indices]


def select_data_region(cell_matrix, classes, dimension, orientation):
    """Pick the data cells that feed the argmax/argmin comparison.

    'orientation': 'horizontal'(dim in 'left') OR 'vertical'(dim in 'top')

    ``classes`` and ``dimension`` are iterated as pairs whose second item is a
    cell with a ``span``; the first item is ignored here. Cells with an empty
    ``cell_string`` are discarded.

    Returns:
        ``(data_cells, None)`` when exactly one non-empty data cell per class
        was found; ``(None, NONE msg)`` when none remain; ``(None, NUM msg)``
        when the count differs from ``len(classes)``.
    """
    class_coords = [
        (cls_cell.span['start_row'], cls_cell.span['start_col'])
        for _, cls_cell in classes
    ]
    dim_coords = [
        (dim_cell.span['start_row'], dim_cell.span['start_col'])
        for _, dim_cell in dimension
    ]
    crossed = cross_source_data(
        class_coords=class_coords,
        dim_coords=dim_coords,
        orientation=orientation,
    )

    candidates = [cell_matrix[irow][icol] for (irow, icol) in crossed]
    filled = list(filter(lambda cell: cell.cell_string != '', candidates))

    if len(filled) == 0:
        return None, CheckOpMsg(_OP_DATA, _OP_ERROR_NONE)
    if len(filled) != len(classes):
        return None, CheckOpMsg(_OP_DATA, _OP_ERROR_NUM)
    return filled, None


# Name strings for the two basic ARG functions.
# NOTE(review): not referenced elsewhere in this section; presumably used by callers - confirm.
_ARGMAX = 'ARGMAX'
_ARGMIN = 'ARGMIN'


def get_excel_string(cell_dict):
    """Comma-join the ``cell_string`` of every cell in ``cell_dict``.

    cell_dict = {(irow, icol): cell}; the keys are ignored and the cells are
    joined in the dict's iteration order.
    """
    return ','.join(cell.cell_string for cell in cell_dict.values())

def format_arg_formula(function_name, class_cells, data_cells):
    """Render ``NAME(classes; data)`` from a function name and two cell dicts.

    Both cell dicts are turned into comma-separated strings with
    ``get_excel_string`` and separated from each other by '; '.
    """
    rendered_args = '; '.join(
        [get_excel_string(class_cells), get_excel_string(data_cells)]
    )
    return '{}({})'.format(function_name, rendered_args)



# %% ARG checking function

def check_argument(structure, subdict, operator):
    """Build the ArgumentFormula for one ARG-type annotation.

    Stages, in order: unique answer -> orientation -> classes -> dimension ->
    data region. As soon as a stage fails, the formula is returned in
    whatever state it has reached (error messages from the stages are
    discarded here).

    Args:
        structure: table structure; ``structure['cell_matrix']`` is read and
            the whole object is passed to ``link_table_cells``.
        subdict: annotation dict; 'start_row' and 'answer_cells' are read.
        operator: operator name, upper-cased for the formula.

    Returns:
        the ``ArgumentFormula`` instance.
    """
    formula = ArgumentFormula(
        start_row=subdict['start_row'],
        name=operator.upper(),
    )
    cell_matrix = structure['cell_matrix']

    # stage 1: exactly one answer cell
    answer_pair, _ = check_unique_answer(subdict['answer_cells'], cell_matrix)
    ans_pos, ans_cell = answer_pair
    if ans_pos is None or ans_cell is None:
        return formula
    formula.answer = ans_cell.cell_string

    # stage 2: comparison axis from where the answer sits
    linked_cells = link_table_cells(structure, subdict)
    ori, _ = find_orientation(ans_pos, linked_cells)
    if ori is None:
        return formula
    class_key, dim_key = get_keys(ori)

    # stage 3: classes to compare
    cls_dict, _ = organize_class(
        header_dict=linked_cells[class_key],
        cell_matrix=cell_matrix,
        answer_cell=ans_cell,
        sort_row_first=(class_key == 'top'),
    )
    if cls_dict is None:
        return formula
    # result is not used below; the call is kept as in the original flow
    class_name = get_formula_argument(cls_dict['class_name'])
    class_args = get_formula_argument(cls_dict['classes'])

    # stage 4: the value dimension
    dimension, _ = get_dimension(header_dict=linked_cells[dim_key], orientation=ori)
    if dimension is None:
        return formula
    dim_args = get_formula_argument(dimension)

    # stage 5: data cells at the class/dimension crossing
    data_cells, _ = select_data_region(
        cell_matrix=cell_matrix,
        classes=class_args,
        dimension=dim_args,
        orientation=ori,
    )
    if data_cells is None:
        return formula
    data_args = get_formula_argument(data_cells)

    formula.get_formula(class_args, data_args)
    return formula


# %% Complete linked cells

def check_unique_answer_loose(answer_dict, cell_matrix):
    """Take the first linked answer, without requiring it to be the only one.

    Args:
        answer_dict: mapping whose keys are (row, col) cell-matrix coordinates.
        cell_matrix: nested container indexed as ``cell_matrix[row][col]``.

    Returns:
        ``(answer_position, answer_cell)`` for the first key of
        ``answer_dict``; ``(None, None)`` when there is no answer or the cell
        cannot be looked up in ``cell_matrix``.
    """
    if len(answer_dict) == 0:
        return None, None

    answer_cm_coords = next(iter(answer_dict))
    answer_position = LinkedPosition.from_cm_coords(answer_cm_coords)
    arow, acol = answer_cm_coords
    try:
        answer_cell = cell_matrix[arow][acol]
    except (IndexError, KeyError, TypeError):
        # Only a failed matrix lookup is treated as "no answer".
        return None, None
    return answer_position, answer_cell


def complete_cells(structure, subdict, linked_cells):
    """Complete linked-cells from annotation, using the formula heuristics.

    Runs the same stages as the ARG check (unique answer, orientation,
    classes, dimension, data region). If every stage succeeds, each found
    data cell missing from ``linked_cells['data']`` is inserted under its
    ``(start_row, start_col)`` key. If a stage fails, a diagnostic is printed
    and ``linked_cells`` is returned untouched.

    Returns:
        the same ``linked_cells`` dict (mutated in place on success).
    """
    cell_matrix = structure['cell_matrix']

    # check unique answer
    answer_pair, _ = check_unique_answer(
        answer_dict=subdict['answer_cells'],
        cell_matrix=cell_matrix,
    )
    ans_pos, ans_cell = answer_pair
    if ans_pos is None or ans_cell is None:
        print('[args>>complete_cells] NOT any answer..')
        return linked_cells

    # find orientation
    ori, _ = find_orientation(ans_pos, linked_cells)
    if ori is None:
        print('[args>>complete_cells] NOT viable orientation..')
        return linked_cells
    class_key, dim_key = get_keys(ori)

    # then get classes and dimension
    cls_dict, _ = organize_class(
        header_dict=linked_cells[class_key],
        cell_matrix=cell_matrix,
        answer_cell=ans_cell,
        sort_row_first=(class_key == 'top'),
    )
    if cls_dict is None:
        print('[args>>complete_cells] NOT viable classes..')
        return linked_cells
    class_args = get_formula_argument(cls_dict['classes'])

    dimension, _ = get_dimension(header_dict=linked_cells[dim_key], orientation=ori)
    if dimension is None:
        print('[args>>complete_cells] NOT viable dimension..')
        return linked_cells
    dim_args = get_formula_argument(dimension)

    # find cross-indexed data cells
    data_cells, _ = select_data_region(
        cell_matrix=cell_matrix,
        classes=class_args,
        dimension=dim_args,
        orientation=ori,
    )
    if data_cells is None:
        print('[args>>complete_cells] NOT viable data..')
        return linked_cells

    # update data cells into the linked-cells dict
    if data_cells:
        print(f'[args>>complete_cells] got {len(data_cells)} new data cells!')
    for dcell in data_cells:
        data_key = (dcell.span['start_row'], dcell.span['start_col'])
        if data_key in linked_cells['data']:
            # already linked by the annotation: report it, keep the old entry
            print(f'[old] at {data_key} reads {dcell.cell_string}')
        else:
            linked_cells['data'][data_key] = dcell
    return linked_cells

