"""Checking functions for the GROUPING type of operators.

1.  The answer/data must be non-empty and contain more than one cell,
    because the operation needs at least two cells to make sense.
2.  Linked data cells must hold numeric values.
"""

from preprocess.align import link_table_cells
from preprocess.number_utils import parse_number, NormalNumber
from preprocess.position_utils import LinkedPosition
from preprocess.operation.formula_utils import GroupFormula, BiTermFormula, _CALC_FUNCS



# for answer cells

def check_answer(answer_dict):
    """Validate the answer part of a GROUPING sample.

    TODO(@USER): decide which requirements the answer must satisfy; until
    then every answer is accepted unconditionally.

    Args:
        answer_dict: the answer to check (currently ignored).
    Returns:
        bool: always True.
    """
    is_acceptable = True
    return is_acceptable



# process and/or filter data cells

def remove_nonum_data(data_dict):
    """Keep only the data cells whose text parses as a number.

    Commas (thousands separators) are stripped from each cell's
    ``cell_string`` before it is handed to ``parse_number``; cells for which
    ``parse_number`` returns None are skipped.

    Args:
        data_dict: mapping of cm-coordinates -> cell object exposing a
            ``cell_string`` attribute.
    Returns:
        (linked_positions, data_nums): two parallel lists holding, for each
        kept cell in iteration order, its ``LinkedPosition`` and its parsed
        numeric value.
    """
    positions = []
    numbers = []
    for coords, cell in data_dict.items():
        value = parse_number(cell.cell_string.replace(',', ''))
        if value is None:
            continue  # non-numeric cell: drop it
        positions.append(LinkedPosition.from_cm_coords(coords))
        numbers.append(value)
    return positions, numbers



# %% calculation

# %% checking functions
def check_group(structure, subdict, operator):
    """Check and refine whether a structure/sub-sentence pair is a workable sample.

    The sub-sentence is linked to table cells, the numeric data cells are
    kept, and ``_CALC_FUNCS[operator]`` is applied to their values. If the
    result and ``subdict['key_part']`` can be unified into one number, then:
    ``subdict['replaced_sub_sent']`` is set to the sub-sentence with the key
    part replaced by the unified number string, and the formula is filled in
    from the data positions/values.

    Args:
        structure: table structure, passed through to ``link_table_cells``.
        subdict: {'start_row', 'sub_sent', 'key_part', 'key_index',
            'schema_link', 'answer_cells', 'answer', 'aggregation'}
        operator: one of ['sum', 'average', 'max', 'min', 'counta']
    Returns:
        GroupFormula, e.g. =SUM(A1, B2) or =SUM(C1:C5). When the sample is
        rejected (fewer than two answer cells, no numeric data cells, or no
        unified number), it is returned as constructed, without
        ``get_formula`` having been called.
    Raises:
        KeyError: if ``operator`` is not a key of ``_CALC_FUNCS`` (only
            reached once the early checks pass).
    """
    name = operator.upper()

    # Initialize a formula holding the worksheet location; it is only filled
    # in once every check below has passed.
    formula = GroupFormula(start_row=subdict['start_row'], name=name)

    # TODO(@USER): should this check 'answer' or 'data', or report the
    # difference between them?
    if len(subdict['answer_cells']) < 2:
        return formula

    linked_cells = link_table_cells(structure, subdict)
    data_pos, data_nums = remove_nonum_data(linked_cells['data'])
    if not data_nums:
        # Nothing to aggregate. Aggregators may reject an empty list (builtin
        # max/min raise ValueError), and a formula without arguments would be
        # meaningless, so reject the sample.
        return formula
    if len(data_nums) < 2:  # still keep it valid for now
        print(f'[warning] {name} less than two cells.')

    # Check that 1) the key part and 2) the aggregate of the data agree;
    # create_number returning None is treated as a mismatch.
    calc_result = _CALC_FUNCS[operator](data_nums)
    normal_num = NormalNumber.create_number(
        sources=[subdict['key_part'], calc_result])
    if normal_num is None:
        return formula

    # Replace the key part of the original sub-sentence with the unified number.
    sent = subdict['sub_sent']
    start = subdict['key_index']
    end = start + len(subdict['key_part'])
    subdict['replaced_sub_sent'] = sent[:start] + normal_num.string + sent[end:]

    # Generate the formula from the data positions.
    formula.get_formula(data_pos, data_nums)
    return formula



# binary inputs

def check_binary(structure, subdict, operator):
    """Check whether a sub-sentence is explained by a binary operation on two cells.

    The two numeric data cells linked to the sub-sentence are combined with
    ``_CALC_FUNCS[operator]`` in both argument orders. If exactly one order
    gives a result that ``NormalNumber.create_number`` can unify with
    ``subdict['key_part']``, then the key part of the sub-sentence is replaced
    by that number (stored under ``subdict['replaced_sub_sent']``). The
    formula is then filled in with the operands in the matching order.

    Args:
        structure: table structure, passed through to ``link_table_cells``.
        subdict: sample dict; this function reads 'start_row', 'sub_sent',
            'key_part', 'key_index' and 'answer_cells' (plus whatever
            ``link_table_cells`` reads).
        operator: key into ``_CALC_FUNCS`` naming a two-argument function
            (e.g. a difference).
    Returns:
        BiTermFormula. When two operands cannot be found, or when neither or
        both argument orders match the key part, it is returned as
        constructed, without ``get_formula`` having been called.
    """
    name = operator.upper()
    formula = BiTermFormula(start_row=subdict['start_row'], name=name)

    linked_cells = link_table_cells(structure, subdict)
    data_pos, data_nums = remove_nonum_data(linked_cells['data'])
    if len(data_nums) != 2:
        print(f'warning: {name} not-two cells. Try using subdict-answer.')
        # Fall back to the answer cells as the two operands.
        # NOTE(review): the values of subdict['answer_cells'] are used directly
        # as numbers here, unlike the data cells above, which are parsed by
        # remove_nonum_data; confirm the answer-cell values are already numeric.
        answer_cells = subdict['answer_cells']
        data_pos = [LinkedPosition.from_cm_coords(cm) for cm in answer_cells]
        data_nums = [answer_cells[pos.cm_coords] for pos in data_pos]
        if len(data_nums) != 2:
            return formula
    num_a, num_b = data_nums

    # TODO(@USER): should the answer number also be added as a source?
    key_part = subdict['key_part']
    calc = _CALC_FUNCS[operator]
    normal_num_a = NormalNumber.create_number(sources=[key_part, calc(num_a, num_b)])
    normal_num_b = NormalNumber.create_number(sources=[key_part, calc(num_b, num_a)])

    # Accept only an unambiguous match. If neither or both orders match, the
    # operand order cannot be decided and the sample is rejected.
    if normal_num_a is not None and normal_num_b is None:
        numstr = normal_num_a.string
    elif normal_num_a is None and normal_num_b is not None:
        numstr = normal_num_b.string
        # Swap the operands so the formula computes calc(num_b, num_a).
        data_pos = data_pos[::-1]
        data_nums = data_nums[::-1]
    else:
        return formula

    # Replace the key part of the original sub-sentence with the unified number.
    sent = subdict['sub_sent']
    start = subdict['key_index']
    end = start + len(key_part)
    subdict['replaced_sub_sent'] = sent[:start] + numstr + sent[end:]

    formula.get_formula(data_pos, data_nums)
    return formula
