"""Formula classes designed for each single operator."""


_ERROR_MSG = '<ERROR>'

# %% Calculation Functions

def average(numbers):
    """Return the arithmetic mean of ``numbers``.

    Accepts any iterable of numbers (list, tuple, generator, ...). The input
    is materialised once so that a one-shot iterator is not exhausted by
    ``sum`` before its length is taken.

    Raises:
        ZeroDivisionError: if ``numbers`` is empty (unchanged from the plain
            ``sum / len`` formulation).
    """
    values = list(numbers)
    return sum(values) / len(values)

def diff(number_a, number_b):
    """Return ``number_a`` minus ``number_b`` (the DIFF operator)."""
    difference = number_a - number_b
    return difference

def div(number_a, number_b):
    """Return ``number_a`` divided by ``number_b`` (the DIV operator).

    Division by zero propagates Python's ZeroDivisionError.
    """
    quotient = number_a / number_b
    return quotient

# Dispatch table from lower-cased operator name to the Python function that
# evaluates it. The group operators ('sum', 'average', 'max', 'min', 'counta')
# are called with one collection of numbers; the binary operators ('diff',
# 'div') are called with two numbers.
_CALC_FUNCS = dict(
    sum=sum,
    average=average,
    max=max,
    min=min,
    counta=len,
    diff=diff,
    div=div,
)




# %% Derived Formula Classes

class ExcelFormula(object):
    """Basic/General class for all kinds of operators.

    Holds the cell the formula is written to (``row_index``, ``col_index``),
    the operator ``name`` and the Excel-readable formula ``excel_reads``.
    """

    def __init__(self, start_row, row_offset, col_index, name=None, excel_reads=_ERROR_MSG):
        # The formula cell sits ``row_offset`` rows below ``start_row``.
        self.row_index = start_row + row_offset
        self.col_index = col_index
        self.name = name
        # Placeholder until a subclass's get_formula() builds a real formula.
        self.excel_reads = excel_reads

    @staticmethod
    def get_region(linked_positions, sort=True, strict=False):
        """Check whether the given cell positions form one continuous table region.

        A range such as ``"A1:A5"`` is produced only when the positions lie
        in a single row or a single column, contain no duplicates and have
        no gaps. Otherwise a fallback is returned: an explicit comma list
        of every cell (``"A1,B2,C3"``) when ``strict`` is true, else None.

        Args:
            linked_positions: List['LinkedPosition'] -- objects exposing
                ``ws_coords`` (a (row, col) pair; assumed to be integers --
                TODO confirm) and ``ws_str`` (the "A1"-style reference).
            sort: sort the positions by ``ws_coords`` first, so the first
                and last items become the corners of the range.
            strict: fall back to the comma list instead of None.
        Returns:
            region_repr: "A1:B2", "A1,B2,C3", or None.
        """
        if sort:
            linked_positions = sorted(linked_positions, key=lambda x: x.ws_coords)

        if strict:
            fallback = ','.join([lp.ws_str for lp in linked_positions])
        else:
            fallback = None

        ws_row_indices = {pos.ws_coords[0] for pos in linked_positions}
        ws_col_indices = {pos.ws_coords[1] for pos in linked_positions}
        num_rows, num_cols = len(ws_row_indices), len(ws_col_indices)

        # Has to be a single row or a single column (an empty list also
        # fails here, having zero rows and zero columns).
        if (num_rows != 1) and (num_cols != 1):
            return fallback
        # Duplicate cells make the count exceed the number of distinct indices.
        if len(linked_positions) != (num_rows * num_cols):
            return fallback
        # Reject gaps: A1 and A3 must not become "A1:A3", which would also
        # pull A2 into the formula.
        for indices in (ws_row_indices, ws_col_indices):
            if max(indices) - min(indices) + 1 != len(indices):
                return fallback

        sregion = linked_positions[0].ws_str
        eregion = linked_positions[-1].ws_str
        return f"{sregion}:{eregion}"



# Header-Data Argument Calculation

class ArgumentFormula(ExcelFormula):
    """Formula for ARGMAX, ARGMIN, etc.
    Continuous Header or Data regions. Single row or column.
    """

    def __init__(self, start_row, row_offset=6, col_index=7, name='ARGMAX', data_args=None, answer=None):
        """name in ['ARGMAX', 'ARGMIN', 'PAIRARGMAX', 'PAIRARGMIN', ...]"""
        ExcelFormula.__init__(self, start_row, row_offset, col_index, name)
        self.data_args = data_args
        self.answer = answer

    def get_formula(self, class_args, data_args, indices=None):
        """Build ``excel_reads`` by dispatching on ``self.name`` and argument counts.

        Args:
            class_args: header cells, as pairs whose first item is a position
                object exposing ``ws_coords`` and ``ws_str``.
            data_args: data cells in the same pair form; must have the same
                length as ``class_args``, otherwise nothing is built.
            indices: optional sequence whose length selects the variant:
                empty for ARG*/PAIR*, one item for KTH*, several for TOPK*.
                None (the default) means empty.
        """
        indices = [] if indices is None else indices
        nidx = len(indices)
        ncls, ndata = len(class_args), len(data_args)
        if ncls != ndata:
            return

        self.data_args = data_args
        if self.name.startswith('ARG') and nidx == 0 and ncls > 2 and ndata > 2:
            self.get_formula_multi(class_args, data_args)
        elif self.name.startswith('PAIR') and nidx == 0 and ncls == 2 and ndata == 2:
            self.get_formula_pair(class_args, data_args)
        elif self.name.startswith('TOPK') and nidx > 1 and ncls > nidx:  # and ndata > nidx
            self.get_formula_topk(class_args, data_args)
        elif self.name.startswith('KTH') and nidx == 1 and ncls > nidx:  # and ndata > nidx
            # NOTE(review): indices[0] is presumably the rank k; once its base
            # (0- or 1-based) is confirmed, pass it on as ``k=`` so that the
            # emitted LARGE/SMALL call is complete.
            self.get_formula_kth(class_args, data_args)

    def get_formula_multi(self, class_args, data_args):
        """ARGMAX, ARGMIN over more than two cells: look up the header whose
        data cell holds the MAX/MIN of the data region."""
        creg_repr = self.get_region([lp for (lp, _) in class_args], strict=True)
        dreg_repr = self.get_region([lp for (lp, _) in data_args], strict=True)
        if (creg_repr is None) or (dreg_repr is None):
            return
        # name[-3:] is the aggregate, e.g. 'ARGMAX' -> 'MAX'.
        self.excel_reads = f"=XLOOKUP({self.name[-3:]}({dreg_repr}), {dreg_repr}, {creg_repr})"
        self.human_reads = self.excel_reads

    def get_formula_pair(self, class_args, data_args):
        """PAIRARGMAX, PAIRARGMIN: choose between two headers with an IF."""
        class_args = sorted(class_args, key=lambda x: x[0].ws_coords)
        class_wsstr = [pos.ws_str for (pos, _) in class_args]
        data_args = sorted(data_args, key=lambda x: x[0].ws_coords)
        data_wsstr = [pos.ws_str for (pos, _) in data_args]

        if self.name.endswith('MAX'):
            comparison = '>'
        elif self.name.endswith('MIN'):
            comparison = '<'
        else:
            return
        # On a tie the comparison is false, so the second header is chosen.
        self.excel_reads = f"=IF({data_wsstr[0]} {comparison} {data_wsstr[1]}, {class_wsstr[0]}, {class_wsstr[1]})"
        # Keep human_reads in step, as the other builders do.
        self.human_reads = self.excel_reads

    def get_formula_topk(self, class_args, data_args):
        """TOPK* variants: not implemented yet; leaves the formula untouched."""
        return

    def get_formula_kth(self, class_args, data_args, k=None):
        """KTHARGMAX, KTHARGMIN: header of the k-th largest/smallest data cell.

        Excel's LARGE/SMALL take the rank ``k`` as a second argument; supply
        it through ``k`` (a number or a cell reference). When ``k`` is None
        the formula is emitted without it, exactly as before.
        """
        func_key = {'KTHARGMAX': 'LARGE', 'KTHARGMIN': 'SMALL'}.get(self.name)
        if func_key is None:
            # Unknown KTH* variant: previously raised UnboundLocalError.
            return
        creg_repr = self.get_region([lp for (lp, _) in class_args], strict=True)
        dreg_repr = self.get_region([lp for (lp, _) in data_args], strict=True)
        if (creg_repr is None) or (dreg_repr is None):
            return
        func_args = dreg_repr if k is None else f"{dreg_repr}, {k}"
        self.excel_reads = f"=XLOOKUP({func_key}({func_args}), {dreg_repr}, {creg_repr})"
        self.human_reads = self.excel_reads

    def get_answer(self):
        return self.answer


# Single Group Calculation

class GroupFormula(ExcelFormula):
    """Formula for multi-data grouping operator."""

    def __init__(self, start_row, row_offset=6, col_index=7, name='SUM', data_args=None, answer=None):
        """name in ['SUM', 'AVERAGE', 'COUNTA', 'MIN', 'MAX']"""
        ExcelFormula.__init__(self, start_row, row_offset, col_index, name)
        self.data_args = data_args
        self.answer = answer

    def get_formula(self, linked_positions, data_nums):
        """Build the Excel formula and the numeric answer for a group operator.

        Args:
            linked_positions: positions (``ws_coords``/``ws_str``) of the cells
                being aggregated.
            data_nums: the numbers in those cells; passed to the Python
                function registered for ``self.name`` in ``_CALC_FUNCS``.
        Sets ``excel_reads`` (e.g. ``"=SUM(A1:A3)"``), ``data_args`` and
        ``answer`` (formatted with two decimals).
        """
        # strict=True: if the cells are not one contiguous range, fall back to
        # an explicit "A1,B2,C3" list (valid inside SUM/AVERAGE/MIN/...)
        # instead of get_region's None, which produced "=SUM(None)".
        data_repr = self.get_region(linked_positions, strict=True)
        self.excel_reads = f"={self.name}({data_repr})"

        # Only positions are known here, so the second slot of each pair is None.
        self.data_args = [(pos, None) for pos in linked_positions]

        result = _CALC_FUNCS[self.name.lower()](data_nums)
        self.answer = f'{result:.2f}'

    def get_answer(self):
        return self.answer
        

# Binary Formula Input

class BiTermFormula(ExcelFormula):
    """Formula for DIFF and DIV operators."""

    def __init__(self, start_row, row_offset=6, col_index=7, name='DIFF', data_args=None, answer=None):
        """name in ['DIFF', 'DIV']"""
        ExcelFormula.__init__(self, start_row, row_offset, col_index, name)
        self.data_args = data_args
        self.answer = answer

    def get_formula(self, linked_positions, data_nums):
        """Build the Excel formula (and, if unset, the answer) for a binary operator.

        Args:
            linked_positions: exactly two positions; the first is the left
                operand and the second the right one. Any other count is
                ignored and nothing is built.
            data_nums: the two numbers at those positions; only used when
                ``self.answer`` has not already been provided.
        Raises:
            ValueError: if the answer has to be computed and ``data_nums``
                does not hold exactly two values.
        """
        if len(linked_positions) != 2:
            return

        # @USER TODO: how to figure out the order of two arguments
        pos_a, pos_b = linked_positions
        if self.name == 'DIFF':
            self.excel_reads = f"={pos_a.ws_str}-{pos_b.ws_str}"
        elif self.name == 'DIV':
            self.excel_reads = f"={pos_a.ws_str}/{pos_b.ws_str}"

        if self.answer is None:
            # Explicit check rather than `assert`, which `python -O` strips.
            if len(data_nums) != 2:
                raise ValueError(
                    f"{self.name} expects exactly 2 numbers, got {len(data_nums)}"
                )
            num_a, num_b = data_nums
            result = _CALC_FUNCS[self.name.lower()](num_a, num_b)
            self.answer = f'{result:.2f}'

    def get_answer(self):
        return self.answer



# NULL operator
class NullFormula(ExcelFormula):
    """Formula for the NONE (plain reference) and OPP (negated reference) operators."""

    def __init__(self, start_row, row_offset=6, col_index=7, name='NONE', answer=None):
        """name in ['NONE', 'OPP']"""
        ExcelFormula.__init__(self, start_row, row_offset, col_index, name)
        self.answer = answer

    def get_formula(self, linked_positions):
        """Point ``excel_reads`` at the single linked cell.

        NONE gives ``=A1`` and OPP gives ``=-A1``. Anything other than exactly
        one position, or an unrecognised name, leaves the formula untouched.
        """
        if len(linked_positions) == 1:
            sign = {'NONE': '', 'OPP': '-'}.get(self.name)
            if sign is not None:
                refs = [f'{sign}{lp.ws_str}' for lp in linked_positions]
                self.excel_reads = '=' + ', '.join(refs)

    def get_answer(self):
        return self.answer


# Compound

class CompoundFormula(ExcelFormula):
    """Formula for compound operators; no answer is attached to these."""

    def __init__(self, start_row, row_offset=6, col_index=7, name='COMPOUND'):
        """Record the target cell and name; ``excel_reads`` keeps its default."""
        ExcelFormula.__init__(self, start_row, row_offset, col_index, name)

    def get_answer(self):
        """Always None: compound formulas carry no precomputed answer."""
        no_answer = None
        return no_answer
