# August version; to be revised when this code is used again.
def chi_square_test(column_X: NumColumn, column_A: CellColumn, column_B: CellColumn):
    """Compute a two-factor chi-square statistic over (encrypted) data blocks.

    For every pair of levels (a, b) of factors A and B:

        prod_ab = (sum of X over rows in level a) * (sum of X over rows in level b)
        E_ab    = prod_ab / block_sum          (block_sum = sum(column_X))

    The statistic accumulates (x - E_ab) ** 2 / E_ab over the rows that both
    level masks select. At the end, the slots are folded together with
    ``__rotate_sum``.

    Args:
        column_X: data column. It must have no missing values.
        column_A: factor A. It is indexed by each entry of ``cell_values``
            and yields one mask block per data block.
        column_B: factor B, with the same layout as ``column_A``.

    Returns:
        The Block produced by ``__rotate_sum`` on the accumulated terms.

    Raises:
        TypeError: if ``column_X.has_vbit()`` is true.
    """
    if column_X.has_vbit():
        # Missing-value masks of column_X are never applied below, so such
        # data has to be rejected up front.
        raise TypeError("chi square test applies only to data without missing values")

    def _masked_total(mask_blocks):
        # Sum X over the rows selected by mask_blocks, then fold the slots
        # together.
        total = Block.zeros(column_X.context)
        for data_block, mask_block in zip(column_X.data, mask_blocks):
            total += data_block * mask_block
        return __rotate_sum(total)

    # NOTE(review): unlike the per-level totals, this sum is not passed through
    # __rotate_sum. Confirm that sum(column_X) already yields the grand total
    # in the slot(s) that inverse(one_slot=True) reads.
    block_sum = sum(column_X)
    block_sum_inv = block_sum.inverse(one_slot=True)

    block_sum_a_dict = {cell: _masked_total(column_A[cell]) for cell in column_A.cell_values}
    block_sum_b_dict = {cell: _masked_total(column_B[cell]) for cell in column_B.cell_values}

    res_block = Block.zeros(column_X.context)
    for cell_a in column_A.cell_values:
        for cell_b in column_B.cell_values:
            # Product of the two level totals for the cell (a, b).
            marginal_prod = block_sum_a_dict[cell_a] * block_sum_b_dict[cell_b]
            marginal_prod_inv = marginal_prod.inverse(one_slot=True)

            # E_ab = prod / block_sum, and 1 / E_ab = block_sum / prod.
            block_E = marginal_prod * block_sum_inv
            block_E_inv = marginal_prod_inv * block_sum

            for data_block, mask_a, mask_b in zip(column_X.data, column_A[cell_a], column_B[cell_b]):
                term = data_block - block_E
                term *= term
                term *= block_E_inv
                # Keep only rows that belong to both level a and level b.
                term *= mask_a
                term *= mask_b
                res_block += term

    return __rotate_sum(res_block)


def chi_sq_gof_seq(observed: CellColumn, expected: CellColumn = None):
    """Chi-square goodness-of-fit statistic, computed one category at a time.

    Each category count is obtained by adding the category's mask blocks and
    applying ``__rotate_sum``. The expected count for a category is either:

    * the same count taken from ``expected`` (when it is given), or
    * the observed total divided by the number of categories (uniform).

    The result is sum over categories of (obs - exp) ** 2 * inv(exp), where
    inv(exp) comes from ``__safe_inverse``.

    Args:
        observed: categorical column of observations. Its ``cell_values``
            define the categories.
        expected: optional categorical column. It is indexed by the categories
            of ``observed``, and its counts are used as the expected counts.

    Returns:
        A Block (accumulated from an encrypted zero block) holding the statistic.

    Raises:
        ZeroDivisionError: when ``expected`` is omitted and ``observed`` has
            no categories.
    """

    def _count(column, cat):
        # Number of rows of `column` in category `cat`: add its mask blocks,
        # then fold the slots together.
        block_count = Block.zeros(observed.context, encrypted=True)
        for cmask_block in column[cat]:
            block_count += cmask_block
        return __rotate_sum(block_count)

    block_chi_sq = Block.zeros(observed.context, encrypted=True)

    # Compare against None explicitly: the default is None, and a column
    # object's truthiness is not the same question.
    if expected is not None:
        # NOTE(review): expected counts are used as-is, not rescaled to the
        # observed total. Confirm that both columns have the same number of rows.
        for cat in observed.cell_values:
            block_obs_count = _count(observed, cat)
            block_exp_count = _count(expected, cat)
            block_exp_count_inv = __safe_inverse(block_exp_count, num_rows=expected.num_rows)

            block_sq = block_obs_count - block_exp_count
            block_sq *= block_sq
            block_chi_sq += block_sq * block_exp_count_inv
    else:
        obs_counts = [_count(observed, cat) for cat in observed.cell_values]

        block_obs_total = Block.zeros(observed.context, encrypted=True)
        for block_obs_count in obs_counts:
            block_obs_total += block_obs_count

        # Uniform expectation: the same expected count for every category.
        block_exp_count = block_obs_total * (1 / len(observed.cell_values))
        block_exp_count_inv = __safe_inverse(block_exp_count, num_rows=observed.num_rows)

        for block_obs_count in obs_counts:
            block_sq = block_obs_count - block_exp_count
            block_sq *= block_sq
            block_chi_sq += block_sq * block_exp_count_inv

    return block_chi_sq


def chi_sq_gof_par(observed: CellColumn, expected: CellColumn = None):
    """Chi-square goodness-of-fit statistic with all categories packed in one Block.

    The count of category i is multiplied by ``Block.oneone(context, i)``,
    presumably a one-hot selector for slot i, and added into a single Block.
    Expected counts are packed the same way when ``expected`` is given.
    Otherwise they are the rotate-summed observed total divided by the number
    of categories. The per-slot terms (obs - exp) ** 2 * inv(exp) are masked
    with ``Block.prefixones(context, num_cats)`` and summed with
    ``__rotate_sum``.

    Args:
        observed: categorical column of observations. Its ``cell_values``
            define the categories.
        expected: optional categorical column. It is indexed by the categories
            of ``observed``, and its counts are used as the expected counts.

    Returns:
        The Block produced by ``__rotate_sum`` on the masked per-slot terms.

    Raises:
        ZeroDivisionError: when ``expected`` is omitted and ``observed`` has
            no categories.
    """

    def _count(column, cat):
        # Number of rows of `column` in category `cat`: add its mask blocks,
        # then fold the slots together.
        block_count = Block.zeros(observed.context, encrypted=True)
        for cmask_block in column[cat]:
            block_count += cmask_block
        return __rotate_sum(block_count)

    num_cats = len(observed.cell_values)
    select_blocks = [Block.oneone(observed.context, i) for i in range(num_cats)]

    # Place the observed count of each category in its own slot.
    block_obs = Block.zeros(observed.context, encrypted=True)
    for cat, select_block in zip(observed.cell_values, select_blocks):
        block_obs += _count(observed, cat) * select_block

    # Compare against None explicitly: the default is None, and a column
    # object's truthiness is not the same question.
    if expected is not None:
        block_exs = Block.zeros(observed.context, encrypted=True)
        for cat, select_block in zip(observed.cell_values, select_blocks):
            block_exs += _count(expected, cat) * select_block
        # Slots hold different expected counts here. one_slot=False presumably
        # requests a slot-wise inverse.
        block_exs_inv = __safe_inverse(block_exs, num_rows=expected.num_rows, one_slot=False)
    else:
        # Uniform expectation: total / num_cats, the same value in every slot.
        block_obs_total = __rotate_sum(block_obs)
        block_exs = block_obs_total * (1 / num_cats)
        block_exs_inv = __safe_inverse(block_exs, num_rows=observed.num_rows)

    block_diff_squares = block_obs - block_exs
    block_diff_squares *= block_diff_squares
    block_chi_squares = block_diff_squares * block_exs_inv

    # Presumably keeps only the first num_cats slots, so padding slots do not
    # contribute to the sum.
    block_chi_squares *= Block.prefixones(observed.context, num_cats, encrypted=True)
    return __rotate_sum(block_chi_squares)
