import math
from typing import Optional, Union

from hedal.stats.preprocessing import normalize
import numpy as np

from hedal import Block, NumColumn
from hedal.frame.mask_column import MaskColumn


def __safe_inverse(block: Block, num_rows: int) -> Block:
    """Return a block approximating ``1 / block`` (first-slot value).

    Plaintext input: the reciprocal of the real part of slot 0 is broadcast to every slot.

    Encrypted input: the homomorphic ``inverse(one_slot=True)`` is used. When ``num_rows``
    exceeds 100000, the operand is first multiplied by ``50000 / num_rows`` and the result is
    multiplied by the same factor again. Algebraically ``c * (1 / (c * x)) == 1 / x``. This is
    presumably done to keep the operand inside the inverse approximation's valid input range.

    Args:
        block: block whose value is to be inverted.
        num_rows: row count used to decide whether the pre-scaling is applied.

    Returns:
        A new block; ``block`` itself is not modified.
    """
    if not block.encrypted:
        reciprocal = 1 / block.data[0].real
        return Block(block.context, data=np.full(block.context.num_slots, reciprocal), encrypted=False)

    result = block.copy()
    if num_rows <= 100000:
        return result.inverse(one_slot=True)

    shrink = 50000 / num_rows
    result *= shrink
    result = result.inverse(one_slot=True)
    result *= shrink
    return result


def __safe_sqrt_inverse_large(block: Block, num_rows: int) -> Block:
    """Return a block approximating ``1 / sqrt(block)`` (first-slot value).

    Plaintext input: ``1 / sqrt`` of the real part of slot 0 is broadcast to every slot.

    Encrypted input: the value is divided by ``num_rows`` before ``sqrt_inv`` and the result is
    multiplied by ``1 / sqrt(num_rows)``. Since ``sqrt(n / x) / sqrt(n) == 1 / sqrt(x)``, the
    scaling cancels. The evaluation runs with ``log_slots = 0`` and restores
    ``context.log_slots`` afterwards. The result is bootstrapped when ``need_bootstrap(5)``
    reports it.

    Returns:
        A new block; ``block`` itself is not modified.
    """
    if not block.encrypted:
        value = 1 / math.sqrt(block.data[0].real)
        return Block(block.context, data=np.full(block.context.num_slots, value), encrypted=False)

    result = block.copy()
    result.log_slots = 0
    result *= 1.0 / num_rows  # x / n
    result = result.sqrt_inv()  # sqrt(n / x)
    result *= 1.0 / math.sqrt(num_rows)  # 1 / sqrt(x)
    if result.need_bootstrap(5):
        result.bootstrap()
    result.log_slots = result.context.log_slots
    return result


def __small_inverse(block: Block, one_slot: bool = False) -> Block:
    """Approximate the slot-wise reciprocal ``1 / block`` for inputs of either sign.

    The sign ``s`` of each slot is formed as ``2 * (block.compare(zero) - 0.5)``. The reciprocal
    is then evaluated as ``s * (1 / sqrt(s * x)) ** 2``, so ``sqrt_inv`` only ever receives
    ``|x|``. This assumes ``compare`` yields ~1 for a positive receiver and ~0 for a negative
    one, as its other uses in this file suggest.

    Args:
        block: input block. When ``one_slot`` is True, its ``log_slots`` is set to 0 for the
            computation and reset to ``block.context.log_slots`` afterwards. The reset also
            happens if an exception is raised, so the caller's block is not left half-modified.
        one_slot: evaluate with ``log_slots = 0``.

    Returns:
        A new block approximating ``1 / block``. With ``one_slot``, its ``log_slots`` is
        ``block.context.log_slots``.
    """
    # NOTE(review): bare attribute access kept from the original; presumably it makes sure the
    # public key is available before ``compare`` — confirm before removing.
    block.context.public_key
    zero_block = Block.zeros(block.context, encrypted=block.encrypted)
    if one_slot:
        block.log_slots = 0
        zero_block.log_slots = 0
    try:
        sign_block = block.compare(zero_block)
        sign_block -= 0.5
        sign_block *= 2  # expected ~+1 / ~-1 by sign of x
        res_block = block * sign_block  # |x|
        res_block = res_block.sqrt_inv()  # 1 / sqrt(|x|)
        res_block *= res_block  # 1 / |x|
        res_block *= sign_block  # 1 / x
    finally:
        if one_slot:
            block.log_slots = block.context.log_slots
    if one_slot:
        res_block.log_slots = block.context.log_slots
    return res_block


def check_valid(block: Block, column: NumColumn) -> Block:
    """Return ``block`` with every slot past the column's last row multiplied by 0.

    Meant for the final data block of ``column``. The first ``num_rows % num_slots`` slots are
    kept; when ``num_rows`` is a positive multiple of ``num_slots``, all slots are kept.

    Returns:
        ``block * mask`` for a plaintext 0/1 mask; ``block`` itself is not modified in place.
    """
    remainder = column.num_rows % column.num_slots
    valid_slots = column.num_slots if (remainder == 0 and column.num_rows > 0) else remainder
    mask = Block.zeros(column.context, encrypted=False)
    for slot in range(valid_slots):
        mask[slot] = 1
    return block * mask


def filtering(num_column: NumColumn, mask_column: MaskColumn) -> NumColumn:
    """Return a saved copy of ``num_column`` whose validity bits are restricted by ``mask_column``.

    The copy is created with ``num_column.copy(num_column.path.parent)`` and reloaded from its
    path.
    - If it already has vbits, each vbit block is multiplied by the matching mask block.
    - Otherwise the mask column itself becomes the vbit column.

    ``vbit_encrypted`` is OR-ed with the mask's encryption flag, and the result is saved.

    Raises:
        TypeError: ``mask_column`` is not a ``MaskColumn``.
        ValueError: the two columns have different ``num_rows``. This is a subclass of
            ``Exception``, so callers catching the previously raised bare ``Exception`` still work.
    """
    if not isinstance(mask_column, MaskColumn):
        raise TypeError("[Error] Mask should have type 'MaskColumn'")
    if num_column.num_rows != mask_column.num_rows:
        raise ValueError("num_rows of two columns must be equal")

    copied = num_column.copy(num_column.path.parent)
    res_column = NumColumn.from_path(copied.context, copied.path)

    if res_column.has_vbit():
        # In-place multiplication of the iterated vbit blocks; relies on the blocks being the
        # objects that ``save()`` below writes out — TODO confirm with the column implementation.
        for res_vbit_block, mask_block in zip(res_column.vbit, mask_column):
            res_vbit_block *= mask_block
    else:
        # NOTE(review): this stores the caller's ``mask_column`` object itself and then re-points
        # its ``path`` into the result column, so the caller's mask is aliased and mutated.
        # Confirm whether a copy of the mask was intended here.
        res_column.vbit = mask_column
        res_column.vbit.path = res_column.path / "vbit"

    res_column.vbit_encrypted |= mask_column.encrypted
    res_column.save()
    return res_column


def count(column: NumColumn) -> Block:
    """Return the number of valid rows of ``column`` as a block.

    - With validity bits: the vbit blocks are added up and reduced with ``rotate_sum``. The
      accumulator is created with ``column.encrypted``.
    - Without them: a plaintext block filled with ``column.num_rows`` is returned.
    """
    if not column.has_vbit():
        return Block(column.context, data=np.full(column.context.num_slots, column.num_rows), encrypted=False)

    total = Block.zeros(column.context, encrypted=column.encrypted)
    for vbit_block in column.vbit:
        total += vbit_block
    return total.rotate_sum()


def sum(column: NumColumn) -> Block:
    """Return the total of the valid entries of ``column``, reduced with ``rotate_sum``.

    - With validity bits: every data block is multiplied by its vbit block before adding.
    - Without them: only the last data block is masked with ``check_valid``, which drops the
      padding slots past ``num_rows``.

    Note: this function shadows the builtin ``sum`` inside this module.
    """
    total = Block.zeros(column.context, encrypted=column.encrypted)
    if column.has_vbit():
        for data_block, vbit_block in zip(column.data, column.vbit):
            total += data_block * vbit_block
    else:
        last_idx = len(column.data) - 1
        for idx, data_block in enumerate(column.data):
            total += check_valid(data_block, column) if idx == last_idx else data_block
    return total.rotate_sum()


def average(column: NumColumn, count_block: Optional[Block] = None) -> Block:
    """Return the mean of the valid entries of ``column``.

    Args:
        column: input column.
        count_block: optional precomputed block.
            - When given, the value in its last slot (index ``num_slots - 1``) is broadcast with
              ``rotate_sum`` and multiplied into the sum twice. That slot is therefore expected
              to hold ``1 / sqrt(count)`` — inferred from the double multiplication; confirm with
              the producer of ``count_block``.
            - When omitted, the count comes from ``count(column)`` and is inverted with
              ``__safe_inverse``.

    Returns:
        ``sum(column) / count``, as a block.
    """
    block_sum = sum(column)
    # Explicit None check: the original ``if count_block:`` depended on Block's truthiness.
    if count_block is not None:
        selector = Block.zeros(column.context, encrypted=False)
        selector[column.num_slots - 1] = 1
        count_sqrt_inv = (count_block * selector).rotate_sum()
        block_avg = block_sum * count_sqrt_inv
        block_avg *= count_sqrt_inv
        return block_avg

    block_count_inv = __safe_inverse(count(column), column.num_rows)
    return block_sum * block_count_inv


def variance(column: NumColumn, count_block: Optional[Block] = None) -> Block:
    """Return the sample variance (divisor ``count - 1``) of the valid entries of ``column``.

    Args:
        column: input column (not normalized here).
        count_block: optional precomputed block. Its last slot is used twice as the mean
            divisor and its second-to-last slot twice as the variance divisor. They are
            presumably ``1 / sqrt(count)`` and ``1 / sqrt(count - 1)`` — confirm with the
            producer. When omitted, both divisors are computed with ``count`` / ``__safe_inverse``.

    Returns:
        ``sum((x - mean)^2) / (count - 1)``, as a block.
    """
    block_sum = sum(column)
    if count_block is not None:
        selector = Block.zeros(column.context, encrypted=False)
        selector[column.num_slots - 1] = 1
        block_count_inv = (count_block * selector).rotate_sum()
        block_avg = block_sum * block_count_inv
        block_avg *= block_count_inv
        selector[column.num_slots - 1] = 0
        selector[column.num_slots - 2] = 1
        block_count_minus_inv = (count_block * selector).rotate_sum()
    else:
        block_count = count(column)
        block_count_inv = __safe_inverse(block_count, column.num_rows)
        block_avg = block_sum * block_count_inv
        block_count_minus_inv = __safe_inverse(block_count - 1, column.num_rows)

    if block_avg.need_bootstrap(3):
        # The raw (un-normalized) mean is divided by the column's abs-max before bootstrapping
        # and multiplied back afterwards, presumably to stay inside the bootstrap input range.
        normalize.mult_abs_max_inv(block_avg, column)
        block_avg.bootstrap(one_slot=True)
        normalize.mult_abs_max(block_avg, column)

    block_var = Block.zeros(column.context, encrypted=column.encrypted)
    if column.has_vbit():
        for data_block, vbit_block in zip(column.data, column.vbit):
            diff = data_block - block_avg
            diff *= vbit_block
            diff *= diff
            block_var += diff
    else:
        last_idx = len(column.data) - 1
        for idx, data_block in enumerate(column.data):
            diff = data_block - block_avg
            if idx == last_idx:
                diff = check_valid(diff, column)
            diff *= diff
            block_var += diff
    block_var = block_var.rotate_sum()

    block_var *= block_count_minus_inv
    if count_block is not None:
        # count_block supplies the divisor as a square root, so it is applied twice.
        block_var *= block_count_minus_inv
    return block_var


def standarddev(column: NumColumn, count_block: Optional[Block] = None) -> Block:
    """Return the sample standard deviation (divisor ``count - 1``) of ``column``.

    Steps:
    1. The variance is computed on ``normalize.normalize(column)``.
    2. The square root is taken in place with HEaaN's ``math.approx.sqrt`` on the device.
    3. The result is rescaled with ``normalize.mult_abs_max(..., column)``.

    Args:
        column: input column.
        count_block: optional precomputed block. Its last / second-to-last slots are each
            applied twice as the mean / variance divisors. They are presumably
            ``1 / sqrt(count)`` and ``1 / sqrt(count - 1)`` — confirm with the producer.

    Returns:
        The standard deviation, as a block.
    """
    norm_column = normalize.normalize(column)
    block_sum = sum(norm_column)

    if count_block is not None:
        selector = Block.zeros(column.context, encrypted=False)
        selector[column.num_slots - 1] = 1
        block_count_inv = (count_block * selector).rotate_sum()
        block_avg = block_sum * block_count_inv
        block_avg *= block_count_inv
        selector[column.num_slots - 1] = 0
        selector[column.num_slots - 2] = 1
        block_count_minus_inv = (count_block * selector).rotate_sum()
    else:
        block_count = count(norm_column)
        block_count_inv = __safe_inverse(block_count, norm_column.num_rows)
        block_avg = block_sum * block_count_inv
        block_count_minus_inv = __safe_inverse(block_count - 1, norm_column.num_rows)

    if block_avg.need_bootstrap(3):
        block_avg.bootstrap(one_slot=True)

    block_std = Block.zeros(norm_column.context, encrypted=norm_column.encrypted)
    if norm_column.has_vbit():
        for data_block, vbit_block in zip(norm_column.data, norm_column.vbit):
            diff = data_block - block_avg
            diff *= vbit_block
            diff *= diff
            block_std += diff
    else:
        last_idx = len(norm_column.data) - 1
        for idx, data_block in enumerate(norm_column.data):
            diff = data_block - block_avg
            if idx == last_idx:
                diff = check_valid(diff, norm_column)
            diff *= diff
            block_std += diff
    block_std = block_std.rotate_sum()

    block_std *= block_count_minus_inv
    if count_block is not None:
        block_std *= block_count_minus_inv

    # Variance -> standard deviation, evaluated in place on the device.
    block_std.to_device()
    block_std.context.heaan.math.approx.sqrt(block_std.homevaluator, block_std.data, block_std.data)
    block_std.to_host()
    normalize.mult_abs_max(block_std, column)
    return block_std


def standarderr(column: NumColumn, count_block: Optional[Block] = None) -> Block:
    """Return the standard error of the mean, ``std / sqrt(count)``, of ``column``.

    The standard deviation is computed on ``normalize.normalize(column)``, square-rooted with
    HEaaN's ``math.approx.sqrt`` and rescaled with ``normalize.mult_abs_max``. It is then
    multiplied by ``1 / sqrt(count)``.

    Args:
        column: input column.
        count_block: optional precomputed block.
            - Its last slot is broadcast and used as ``1 / sqrt(count)``: applied twice for the
              mean and once for the final division.
            - Its second-to-last slot is applied twice as the variance divisor, presumably
              ``1 / sqrt(count - 1)`` — confirm with the producer.
            - When omitted, ``__safe_sqrt_inverse_large`` / ``__safe_inverse`` are used.

    Returns:
        The standard error, as a block.
    """
    norm_column = normalize.normalize(column)
    block_sum = sum(norm_column)
    if count_block is not None:
        selector = Block.zeros(column.context, encrypted=False)
        selector[column.num_slots - 1] = 1
        block_count_sqrt_inv = (count_block * selector).rotate_sum()
        selector[column.num_slots - 1] = 0
        selector[column.num_slots - 2] = 1
        block_count_minus_inv = (count_block * selector).rotate_sum()
    else:
        block_count = count(norm_column)
        block_count_sqrt_inv = __safe_sqrt_inverse_large(block_count, norm_column.num_rows)
        block_count_minus_inv = __safe_inverse(block_count - 1, norm_column.num_rows)

    # mean = sum * (1 / sqrt(count))^2 in both branches.
    block_avg = block_sum * block_count_sqrt_inv
    block_avg *= block_count_sqrt_inv
    if block_avg.need_bootstrap(3):
        block_avg.bootstrap(one_slot=True)

    block_stderr = Block.zeros(norm_column.context, encrypted=column.encrypted)
    if norm_column.has_vbit():
        for data_block, vbit_block in zip(norm_column.data, norm_column.vbit):
            diff = data_block - block_avg
            diff *= vbit_block
            diff *= diff
            block_stderr += diff
    else:
        last_idx = len(norm_column.data) - 1
        for idx, data_block in enumerate(norm_column.data):
            diff = data_block - block_avg
            if idx == last_idx:
                diff = check_valid(diff, norm_column)
            diff *= diff
            block_stderr += diff
    block_stderr = block_stderr.rotate_sum()

    block_stderr *= block_count_minus_inv
    if count_block is not None:
        block_stderr *= block_count_minus_inv

    # Variance -> standard deviation on the device, then undo the normalization.
    block_stderr.to_device()
    block_stderr.context.heaan.math.approx.sqrt(block_stderr.homevaluator, block_stderr.data, block_stderr.data)
    block_stderr.to_host()
    normalize.mult_abs_max(block_stderr, column)
    block_stderr *= block_count_sqrt_inv
    return block_stderr


def coeffvar(column: NumColumn, count_block: Optional[Block] = None) -> Block:
    """Return the coefficient of variation ``std / mean`` of ``column``.

    Both the standard deviation and the mean are computed on ``normalize.normalize(column)``
    and are not rescaled. The ratio is unaffected by that, assuming ``normalize`` only
    rescales. The reciprocal of the mean comes from ``__small_inverse``, which handles either
    sign.

    Args:
        column: input column.
        count_block: optional precomputed block. Its last / second-to-last slots are each
            applied twice as the mean / variance divisors. They are presumably
            ``1 / sqrt(count)`` and ``1 / sqrt(count - 1)`` — confirm with the producer.

    Returns:
        The coefficient of variation, as a block.
    """
    norm_column = normalize.normalize(column)
    block_sum = sum(norm_column)
    if count_block is not None:
        selector = Block.zeros(column.context, encrypted=False)
        selector[column.num_slots - 1] = 1
        block_count_inv = (count_block * selector).rotate_sum()
        block_avg = block_sum * block_count_inv
        block_avg *= block_count_inv
        selector[column.num_slots - 1] = 0
        selector[column.num_slots - 2] = 1
        block_count_minus_inv = (count_block * selector).rotate_sum()
    else:
        block_count = count(norm_column)
        block_count_inv = __safe_inverse(block_count, norm_column.num_rows)
        block_avg = block_sum * block_count_inv
        block_count_minus_inv = __safe_inverse(block_count - 1, norm_column.num_rows)

    if block_avg.need_bootstrap(3):
        block_avg.bootstrap(one_slot=True)

    block_coeffvar = Block.zeros(norm_column.context, encrypted=column.encrypted)
    if norm_column.has_vbit():
        for data_block, vbit_block in zip(norm_column.data, norm_column.vbit):
            diff = data_block - block_avg
            diff *= vbit_block
            diff *= diff
            block_coeffvar += diff
    else:
        last_idx = len(norm_column.data) - 1
        for idx, data_block in enumerate(norm_column.data):
            diff = data_block - block_avg
            if idx == last_idx:
                diff = check_valid(diff, norm_column)
            diff *= diff
            block_coeffvar += diff
    block_coeffvar = block_coeffvar.rotate_sum()

    block_coeffvar *= block_count_minus_inv
    if count_block is not None:
        block_coeffvar *= block_count_minus_inv

    # Variance -> standard deviation, evaluated in place on the device.
    block_coeffvar.to_device()
    block_coeffvar.context.heaan.math.approx.sqrt(block_coeffvar.homevaluator, block_coeffvar.data, block_coeffvar.data)
    block_coeffvar.to_host()

    avg_inv = __small_inverse(block_avg, one_slot=True)
    block_coeffvar *= avg_inv
    return block_coeffvar


def skewness(column: NumColumn, count_block: Optional[Block] = None) -> Block:  # NORMALIZED
    """Return the sample skewness of ``column``.

    The skewness is ``(sum((x - mean)^3) / count) / std^3``, where ``std`` uses divisor
    ``count - 1``. Everything is computed on ``normalize.normalize(column)`` and not rescaled.
    Skewness is unchanged by a positive rescaling, assuming ``normalize`` only rescales.

    Args:
        column: input column.
        count_block: optional precomputed block. Its last / second-to-last slots are each
            applied twice as the mean / variance divisors. They are presumably
            ``1 / sqrt(count)`` and ``1 / sqrt(count - 1)`` — confirm with the producer.

    Returns:
        The skewness, as a block.
    """
    norm_column = normalize.normalize(column)
    block_sum = sum(norm_column)
    if count_block is not None:
        selector = Block.zeros(column.context, encrypted=False)
        selector[column.num_slots - 1] = 1
        block_count_inv = (count_block * selector).rotate_sum()
        block_avg = block_sum * block_count_inv
        block_avg *= block_count_inv
        selector[column.num_slots - 1] = 0
        selector[column.num_slots - 2] = 1
        block_count_minus_inv = (count_block * selector).rotate_sum()
    else:
        block_count = count(norm_column)
        block_count_inv = __safe_inverse(block_count, norm_column.num_rows)
        block_avg = block_sum * block_count_inv
        block_count_minus_inv = __safe_inverse(block_count - 1, norm_column.num_rows)

    if block_avg.need_bootstrap(4):
        block_avg.bootstrap(one_slot=True)

    block_diff2 = Block.zeros(norm_column.context, encrypted=column.encrypted)
    block_diff3 = Block.zeros(norm_column.context, encrypted=column.encrypted)
    if norm_column.has_vbit():
        for data_block, vbit_block in zip(norm_column.data, norm_column.vbit):
            diff = data_block - block_avg
            diff *= vbit_block
            power = diff * diff
            block_diff2 += power  # (x - mean)^2
            power *= diff
            block_diff3 += power  # (x - mean)^3
    else:
        last_idx = len(norm_column.data) - 1
        for idx, data_block in enumerate(norm_column.data):
            diff = data_block - block_avg
            if idx == last_idx:
                diff = check_valid(diff, norm_column)
            power = diff * diff
            block_diff2 += power  # (x - mean)^2
            power *= diff
            block_diff3 += power  # (x - mean)^3

    block_diff2 = block_diff2.rotate_sum()
    block_diff3 = block_diff3.rotate_sum()

    # block_diff2 becomes the sample variance.
    block_diff2 *= block_count_minus_inv
    if count_block is not None:
        block_diff2 *= block_count_minus_inv

    block_stdinv3 = block_diff2.sqrt_inv(one_slot=True)  # 1 / std
    std_inv = block_stdinv3.copy()
    block_stdinv3 *= std_inv
    block_stdinv3 *= std_inv  # 1 / std^3

    block_skew = block_diff3 * block_stdinv3
    block_skew *= block_count_inv
    if count_block is not None:
        block_skew *= block_count_inv
    return block_skew


def kurtosis(column: NumColumn, count_block: Optional[Block] = None) -> Block:  # NORMALIZED
    """Return the excess kurtosis of ``column``.

    The excess kurtosis is ``(sum((x - mean)^4) / count) / var^2 - 3``, where ``var`` uses
    divisor ``count - 1``. Everything is computed on ``normalize.normalize(column)`` and not
    rescaled.

    The accumulators follow ``column.encrypted``, like ``skewness`` and the other statistics in
    this module. Previously they were always created encrypted, which forced encryption (and
    key access) even for plaintext columns.

    Args:
        column: input column.
        count_block: optional precomputed block. Its last / second-to-last slots are each
            applied twice as the mean / variance divisors. They are presumably
            ``1 / sqrt(count)`` and ``1 / sqrt(count - 1)`` — confirm with the producer.

    Returns:
        The excess kurtosis, as a block.
    """
    norm_column = normalize.normalize(column)
    block_sum = sum(norm_column)
    if count_block is not None:
        selector = Block.zeros(column.context, encrypted=False)
        selector[column.num_slots - 1] = 1
        block_count_inv = (count_block * selector).rotate_sum()
        block_avg = block_sum * block_count_inv
        block_avg *= block_count_inv
        selector[column.num_slots - 1] = 0
        selector[column.num_slots - 2] = 1
        block_count_minus_inv = (count_block * selector).rotate_sum()
    else:
        block_count = count(norm_column)
        block_count_inv = __safe_inverse(block_count, norm_column.num_rows)
        block_avg = block_sum * block_count_inv
        block_count_minus_inv = __safe_inverse(block_count - 1, norm_column.num_rows)

    if block_avg.need_bootstrap(4):
        block_avg.bootstrap(one_slot=True)

    block_diff2 = Block.zeros(norm_column.context, encrypted=column.encrypted)
    block_diff4 = Block.zeros(norm_column.context, encrypted=column.encrypted)
    if norm_column.has_vbit():
        for data_block, vbit_block in zip(norm_column.data, norm_column.vbit):
            diff = data_block - block_avg
            diff *= vbit_block
            diff *= diff
            block_diff2 += diff  # (x - mean)^2
            diff *= diff
            block_diff4 += diff  # (x - mean)^4
    else:
        last_idx = len(norm_column.data) - 1
        for idx, data_block in enumerate(norm_column.data):
            diff = data_block - block_avg
            if idx == last_idx:
                diff = check_valid(diff, norm_column)
            diff *= diff
            block_diff2 += diff  # (x - mean)^2
            diff *= diff
            block_diff4 += diff  # (x - mean)^4

    block_diff4 = block_diff4.rotate_sum()
    block_diff2 = block_diff2.rotate_sum()

    # block_diff2 becomes the sample variance.
    block_diff2 *= block_count_minus_inv
    if count_block is not None:
        block_diff2 *= block_count_minus_inv

    block_stdinv4 = block_diff2.sqrt_inv(one_slot=True)  # 1 / std
    block_stdinv4 *= block_stdinv4  # 1 / std^2
    block_stdinv4 *= block_stdinv4  # 1 / std^4

    block_4cm = block_diff4 * block_count_inv  # fourth central moment
    if count_block is not None:
        block_4cm *= block_count_inv
    block_4cm *= block_stdinv4
    block_4cm -= 3
    return block_4cm


def absVal(column: NumColumn) -> NumColumn:
    """Return ``column``'s data with every block multiplied by its slot-wise sign.

    The sign comes from the normalized column: ``compare`` against an encrypted zero block
    gives a selector, which is mapped to ``2 * (selector - 0.5)``.

    NOTE(review): the result is loaded from ``column.path`` itself and modified in memory only
    (no ``save`` here). Confirm that callers do not expect an independent on-disk column.
    """
    norm_column = normalize.normalize(column)
    res_column = NumColumn.from_path(column.context, column.path)
    zero = Block.zeros(column.context, encrypted=True)
    for sign_source, target in zip(norm_column.data, res_column.data):
        sign = sign_source.compare(zero)
        sign -= 0.5
        sign *= 2
        target *= sign
    return res_column


def _minmax_select(current: Block, candidate: Block, keep_larger: bool) -> Block:
    """Slot-wise, obliviously keep the larger (``keep_larger``) or smaller of two blocks.

    The selector is ``s = current.compare(candidate)`` for max, or
    ``s = candidate.compare(current)`` for min. The result is
    ``current * s + candidate * (1 - s)``.

    The selector and the result are bootstrapped whenever ``need_bootstrap(1)`` reports it.
    ``current`` may be updated in place; always use the returned block.
    """
    selector = current.compare(candidate) if keep_larger else candidate.compare(current)
    if selector.need_bootstrap(1):
        selector.bootstrap()
    current *= selector
    selector -= 1
    selector = -selector  # 1 - s
    selector *= candidate
    current += selector
    if current.need_bootstrap(1):
        current.bootstrap()
    return current


def _pad_last_block(data_block: Block, column: NumColumn, pad_value: float) -> Block:
    """Return ``data_block`` with the slots past the column's last row set to ``pad_value``.

    Valid slots keep their value (via ``check_valid``); the remaining slots become
    ``pad_value``. The pads used by ``minmaxVal`` (-0.5 for max, 0.5 for min) are presumably
    the bounds of the half-scaled normalized data, so padding never wins the search.
    """
    padded = check_valid(data_block, column)
    pad = Block(column.context, data=np.full(column.context.num_slots, pad_value), encrypted=False)
    valid_slots = column.num_rows % column.num_slots
    if valid_slots == 0 and column.num_rows > 0:
        valid_slots = column.num_slots
    for slot in range(valid_slots):
        pad[slot] = 0
    padded += pad
    return padded


def minmaxVal(
    column: Union[NumColumn, Block], minmaxType: str, num_rows: Optional[int] = None, rotate_index: Optional[int] = None
) -> Block:
    """Return the maximum, minimum or both of the valid entries of ``column``.

    Args:
        column: either of
            - a ``NumColumn``: normalized and halved internally; the result is rescaled by
              ``normalize.mult_abs_max`` and ``* 2``;
            - a single ``Block``: scaled by ``1 / num_rows`` during the search and by
              ``num_rows`` afterwards; this is the path ``mode`` uses. The caller's block is
              not modified.
        minmaxType: ``"max"``, ``"min"`` or ``"range"``.
        num_rows: number of valid rows. A falsy value means "not given": ``column.num_rows`` is
            used for a column, or the block's ``num_slots`` with a printed warning.
        rotate_index: if given, the rotation tournament runs ``ceil(log2(rotate_index))``
            steps instead of ``ceil(log2(min(num_rows, num_slots)))``.

    Returns:
        - For ``"max"`` / ``"min"``: the resulting block; the extreme ends up in slot 0.
        - For ``"range"``: a block whose slot 0 holds the max and slot 1 the min.

    Raises:
        ValueError: unknown ``minmaxType`` (this used to fall through and return ``None``).
        Exception: ``num_rows`` exceeds the ``num_slots`` of a ``Block`` input.
    """
    if minmaxType not in ("max", "min", "range"):
        raise ValueError(f"[Error] minmaxType must be 'max', 'min' or 'range', got {minmaxType!r}")
    want_max = minmaxType in ("max", "range")
    want_min = minmaxType in ("min", "range")
    is_column = isinstance(column, NumColumn)

    # NOTE(review): bare attribute access kept from the original; presumably it makes sure the
    # public key is available before the comparisons — confirm before removing.
    column.context.public_key
    if not num_rows:
        if is_column:
            num_rows = column.num_rows
        else:
            num_rows = column.num_slots
            print("Warning:using block min&max without declaring num_rows")
    elif isinstance(column, Block) and num_rows > column.num_slots:
        raise Exception("[Error] num_rows is bigger than num_slots of a block")

    max_block = Block.zeros(column.context, encrypted=column.encrypted) if want_max else None
    min_block = Block.zeros(column.context, encrypted=column.encrypted) if want_min else None

    if is_column:
        norm_column = normalize.normalize(column)
        for norm_block in norm_column.data:
            norm_block *= 0.5
        if column.has_vbit():
            for idx, (data_block, vbit_block) in enumerate(zip(norm_column.data, norm_column.vbit)):
                # Invalid slots (vbit 0) are pushed to -0.5 for max / +0.5 for min. The offset
                # is built without mutating vbit_block: the old in-place ``vbit_block -= 1``
                # corrupted the vbit used by the min branch of "range" and the column's vbits.
                if want_max:
                    candidate = data_block * vbit_block
                    offset = vbit_block - 1  # 0 on valid slots, -1 on invalid ones
                    candidate += offset * 0.5
                    if idx == 0:
                        max_block.data = candidate.data
                    else:
                        max_block = _minmax_select(max_block, candidate, keep_larger=True)
                if want_min:
                    candidate = data_block * vbit_block
                    offset = vbit_block - 1
                    offset = -offset  # 0 on valid slots, +1 on invalid ones
                    candidate += offset * 0.5
                    if idx == 0:
                        min_block.data = candidate.data
                    else:
                        min_block = _minmax_select(min_block, candidate, keep_larger=False)
        else:
            last_idx = len(norm_column.data) - 1
            for idx, data_block in enumerate(norm_column.data):
                # Copies keep max_block / min_block from sharing one ciphertext with each other
                # or with norm_column, so the in-place selection updates cannot interfere.
                if want_max:
                    candidate = _pad_last_block(data_block, column, -0.5) if idx == last_idx else data_block.copy()
                    if idx == 0:
                        max_block.data = candidate.data
                    else:
                        max_block = _minmax_select(max_block, candidate, keep_larger=True)
                if want_min:
                    candidate = _pad_last_block(data_block, column, 0.5) if idx == last_idx else data_block.copy()
                    if idx == 0:
                        min_block.data = candidate.data
                    else:
                        min_block = _minmax_select(min_block, candidate, keep_larger=False)
    else:
        # Single-block input (used by ``mode``). Work on copies so the caller's block is
        # untouched and max/min do not share one ciphertext.
        if want_max:
            max_block.data = column.copy().data
            max_block *= 1.0 / num_rows
        if want_min:
            min_block.data = column.copy().data
            min_block *= 1.0 / num_rows
            # Slots past num_rows are raised by 1 so they cannot be selected as the minimum.
            # num_rows <= num_slots is guaranteed above, so the valid prefix is min(...).
            pad = Block.ones(column.context, encrypted=False)
            for slot in range(min(num_rows, column.num_slots)):
                pad[slot] = 0
            min_block += pad

    # Rotation tournament: after step k each slot holds the extreme of 2**(k+1) slots.
    span = rotate_index if rotate_index else min(num_rows, column.num_slots)
    for step in range(math.ceil(math.log2(span))):
        shift = 1 << step
        if want_max:
            max_block = _minmax_select(max_block, max_block << shift, keep_larger=True)
        if want_min:
            min_block = _minmax_select(min_block, min_block << shift, keep_larger=False)

    if want_max:
        if is_column:
            normalize.mult_abs_max(max_block, column)
            max_block *= 2
        else:
            max_block *= num_rows
    if want_min:
        if is_column:
            normalize.mult_abs_max(min_block, column)
            min_block *= 2
        else:
            min_block *= num_rows

    # Only the first slot carries the max / min.
    if minmaxType == "max":
        return max_block
    if minmaxType == "min":
        return min_block
    # "range": slot 0 = max, slot 1 = min.
    block_index = Block.zeros(column.context, encrypted=False)
    block_index[0] = 1
    max_block *= block_index
    min_block *= block_index
    max_block += min_block >> 1
    return max_block


def mode(column: NumColumn, cell_columns) -> NumColumn:
    """Build a grouped frequency table over ``cell_columns`` and mark the most frequent group(s).

    Every combination of cell values (one value from each entry of ``cell_columns``) is a
    group. Groups are enumerated in odometer order (the last cell column varies fastest), and
    group ``g`` occupies slot ``g`` of the count block.
    NOTE(review): this assumes the number of groups fits in one block
    (``num_group <= num_slots``) — confirm.

    Layout of the returned column's data blocks:
        ``0 .. cd-1``: the integer cell value of each group, one block per cell column;
        ``cd``: the number of rows belonging to each group (``block_multbinsum``);
        ``cd + 1``: the counts multiplied by the output of ``heaan.math.approx.compare``.
            That compare is run between each rounding-corrected, ``1/num_rows``-scaled count
            and the maximum count, so the block is presumably non-zero only for the modal
            group(s).
    The vbit blocks mark the first ``row_index`` (= number of groups) slots for blocks
    ``0 .. cd``. For block ``cd + 1`` they hold the compare output.

    Args:
        column: the column whose rows are grouped. Only its context, shape, encryption flag,
            name and vbits are read here; its data values are not.
        cell_columns: categorical columns exposing ``cell_values``, ``path``, ``context``,
            ``num_rows`` and ``encrypted``. One mask per cell value is read from
            ``path / cell_value``.

    NOTE(review): ``DataColumn`` is not imported in this file's import block, so this function
    raises ``NameError`` unless that name is provided elsewhere.
    NOTE(review): in the no-vbit branch the return value of ``cell_block.copy(path=...)`` is
    discarded. ``cell_block`` therefore stays all-zero and every count from that branch is 0 —
    confirm which call was meant to load the mask block from ``cell_block_path``.
    """
    cd = len(cell_columns)
    res_column = NumColumn(
        column.context, num_rows=column.num_rows, encrypted=column.encrypted, name=f"{column.name}.mode"
    )
    # cd blocks of cell values + 1 block of counts + 1 block of modal counts.
    res_column.data = DataColumn.zeros(column.context, path="tmp", num_rows=column.num_slots * (cd + 2))
    res_column.vbit = DataColumn.ones(column.context, path="tmp", num_rows=column.num_slots * (cd + 2))
    eval = column.context.homevaluator
    keypack = column.context.public_key
    # NOTE(review): cell_name_index is filled but never used below.
    cell_name_index = list()
    each_cell_bin = list()
    num_group = 1
    for cell_column in cell_columns:
        cell_name_index.append(column.name + ".mode_" + cell_column.name + "_index")
        each_cell_bin.append(len(cell_column.cell_values))
        num_group *= len(cell_column.cell_values)
    # bin_table[g][i] = integer cell value of cell column i for group g.
    bin_table = [[0 for _ in range(cd)] for _ in range(num_group)]
    # Odometer of cell-value indices; shadows the module-level ``count`` inside this function.
    count = [0 for _ in range(cd)]
    row_index = 0

    block_multbinsum = Block.zeros(column.context, encrypted=column.encrypted)
    while count[0] < each_cell_bin[0]:
        for i in range(cd):
            bin_table[row_index][i] = int(cell_columns[i].cell_values[count[i]])

        block_multbin = Block.zeros(column.context, encrypted=column.encrypted)
        # One-hot plaintext mask selecting this group's slot.
        msg_multbin = Block.zeros(column.context, encrypted=False)
        msg_multbin[row_index] = 1
        tmp_block = Block.zeros(column.context, encrypted=column.encrypted)
        if column.has_vbit():
            for j in range(len(column.data)):
                # Start from the row validity bits and multiply in every cell mask of this group.
                tmp_block = column.vbit.block(j).copy()
                for i in range(cd):
                    # NOTE(review): the mask column is reloaded for every (j, i) pair.
                    mask_col = MaskColumn.from_path(
                        cell_columns[i].context,
                        cell_columns[i].path / cell_columns[i].cell_values[count[i]],
                        cell_columns[i].num_rows,
                        cell_columns[i].encrypted,
                    )
                    cell_block = Block.zeros(column.context, encrypted=column.encrypted)
                    cell_block = mask_col.block(j)
                    tmp_block *= cell_block
                block_multbin += tmp_block
        else:
            for j in range(len(column.data)):
                tmp_block = Block.ones(column.context, encrypted=column.encrypted)
                for i in range(cd):
                    cell_block = Block.zeros(column.context, encrypted=column.encrypted)
                    cell_block_path = (
                        str(cell_columns[i].path / cell_columns[i].cell_values[count[i]]) + "/block_" + str(j) + ".bin"
                    )
                    # NOTE(review): the returned copy is discarded, so cell_block stays all-zero.
                    cell_block.copy(path=cell_block_path)
                    tmp_block *= cell_block
                block_multbin += tmp_block
        # Total number of rows in this group, kept only in slot ``row_index``.
        block_multbin = block_multbin.rotate_sum()
        block_multbin *= msg_multbin
        block_multbinsum += block_multbin

        # Advance the odometer: the last cell column moves fastest and carries into earlier ones.
        count[cd - 1] += 1
        row_index += 1
        for i in range(cd - 1, 0, -1):
            if count[i] == each_cell_bin[i]:
                count[i] = 0
                count[i - 1] += 1

    # Blocks 0 .. cd-1: encrypted cell values of each group, with a vbit over the group slots.
    for j in range(cd):
        block_bintable = Block.zeros(column.context, encrypted=False)
        block_vbit = Block.zeros(column.context, encrypted=False)
        for i in range(row_index):
            block_bintable[i] = bin_table[i][j]
            block_vbit[i] = 1
        block_bintable.encrypt()
        block_vbit.encrypt()
        block_bintable.path = res_column.data.block_path(j)
        block_bintable.save()
        block_vbit.path = res_column.vbit.block_path(j)
        block_vbit.save()

    # Block cd: per-group counts.
    block_vbit = Block.zeros(column.context, encrypted=False)
    for i in range(row_index):
        block_vbit[i] = 1
    block_vbit.encrypt()
    block_multbinsum.path = res_column.data.block_path(cd)
    block_multbinsum.save()
    block_vbit.path = res_column.vbit.block_path(cd)
    block_vbit.save()
    # End of the frequency-table part.

    # Maximum count over the group slots: minmaxVal's result is masked to slot 0, then
    # broadcast with rotate_sum.
    block_multbinsum_copy = block_multbinsum.copy()
    block_max = minmaxVal(block_multbinsum_copy, "max", column.num_rows, num_group)
    tmp_block = Block.zeros(column.context, encrypted=False)
    tmp_block[0] = 1
    block_max *= tmp_block
    block_max = block_max.rotate_sum()
    # Reuse tmp_block as a plaintext mask over the num_group group slots.
    for i in range(num_group):
        tmp_block[i] = 1
    log2N = math.ceil(math.log2(num_group))
    block_multbinsum_copy = block_multbinsum.copy()
    # NOTE(review): rounding margin added to the counts before comparing. The constant
    # presumably absorbs approximation error from the max search's compare rounds — confirm.
    round_correction = 0.5 + (log2N + 1) * column.num_rows * 0.000002
    block_multbinsum_copy += round_correction
    block_multbinsum_copy *= 1.0 / column.num_rows
    block_multbinsum_copy *= tmp_block
    block_max *= 1.0 / column.num_rows
    tmp_block.encrypt()
    # tmp_block.data is passed as the last argument and presumably receives the comparison output.
    column.context.heaan.math.approx.compare(eval, keypack, block_multbinsum_copy.data, block_max.data, tmp_block.data)
    vbit_block = Block.ones(column.context, encrypted=True)
    vbit_block *= tmp_block
    # Block cd+1: counts kept where the comparison selected the group (the modal group(s)).
    tmp_block *= block_multbinsum
    tmp_block.path = res_column.data.block_path(cd + 1)
    tmp_block.save()
    vbit_block.path = res_column.vbit.block_path(cd + 1)
    vbit_block.save()
    return res_column


def _rotate_left_pow2(block: Block, amount: int) -> Block:
    """Rotate ``block`` left by ``amount`` slots using only power-of-two steps.

    The amount is decomposed into its binary components, largest first
    (e.g. 13 -> 8, 4, 1), and one ``<<`` is applied per component. An
    ``amount`` below 1 leaves the block unchanged.
    """
    amount = int(amount)
    while amount >= 1:
        # Largest power of two <= amount; exact, unlike int(math.log2(amount)).
        step = 1 << (amount.bit_length() - 1)
        block = block << step
        amount -= step
    return block


def percentile(column: NumColumn) -> Block:
    """Return the 0th..100th percentiles of ``column`` packed into one block.

    For each ``n`` in 0..100, let ``pos = (num_rows - 1) * n / 100``, ``k = int(pos)`` and
    ``r = pos - k``. The value ``(1 - r) * x[k] + r * x[k + 1]`` is placed in slot ``n``
    of the returned block. At ``n == 100``, ``r`` is 0 and only ``x[k]`` is used.

    NOTE(review): no sorting is performed here, so the result is only a percentile if
    ``column`` is already sorted ascending. Confirm that callers guarantee this. The
    column's validity bits are not consulted.

    Args:
        column: numeric column whose rows are read block by block.

    Returns:
        Block whose slots 0..100 hold the interpolated percentiles. It is encrypted
        iff ``column`` is.
    """
    num_slots = column.num_slots
    block_percentile = Block.zeros(column.context, encrypted=column.encrypted)
    for n in range(101):
        position = float((column.num_rows - 1) * n) / 100
        row = int(position)
        frac = position - row
        idx = row % num_slots
        block_idx1 = int(row // num_slots)
        # The upper neighbour is row + 1, except at n == 100 where row is the last row.
        block_idx2 = block_idx1 if n == 100 else int((row + 1) // num_slots)

        lower = column.data.block(block_idx1).copy()
        upper = column.data.block(block_idx2).copy()
        mask_lower = Block.zeros(column.context, encrypted=False)
        mask_upper = Block.zeros(column.context, encrypted=False)

        if block_idx1 == block_idx2:
            if n != 100:
                # Shift x[row + 1] into slot idx so both neighbours line up.
                upper = upper << 1
            mask_lower[idx] = 1 - frac
            mask_upper[idx] = frac
            lower *= mask_lower
            upper *= mask_upper
            lower += upper

            # Move the interpolated value from slot idx to slot n.
            rotation = idx - n
            if rotation < 0:
                rotation += num_slots
            lower = _rotate_left_pow2(lower, rotation)
        else:
            # x[row] is the last slot of block_idx1; x[row + 1] is slot 0 of the next block.
            mask_lower[idx] = 1 - frac
            mask_upper[0] = frac
            lower *= mask_lower
            upper *= mask_upper

            # Bring slot (num_slots - 1) and slot 0 respectively to slot n.
            lower = _rotate_left_pow2(lower, num_slots - 1 - n)
            upper = _rotate_left_pow2(upper, num_slots - n)
            lower += upper
        block_percentile += lower

    return block_percentile


def covariance(column1: NumColumn, column2: NumColumn) -> Block:
    """Return the sample covariance of two numeric columns.

    The row count comes from ``count`` on the product column ``column1 * column2``. When
    that product has validity bits, its vbit masks both the mean and the centered-product
    sums, so only rows valid in both columns contribute. Otherwise the last block of each
    column is passed through ``check_valid``.

    The result is ``sum((x - mean_x) * (y - mean_y)) * 1 / (count - 1)``, computed with
    ``rotate_sum`` and ``__safe_inverse``.

    NOTE(review): ``tmp_column`` is reloaded from ``column1.path`` and then multiplied in
    place. Confirm that ``*=`` does not persist the product over column1's stored blocks.

    Args:
        column1: first numeric column.
        column2: second numeric column, with the same row layout as ``column1``.

    Returns:
        Encrypted block holding the covariance. The accumulator below is always
        encrypted.
    """
    tmp_column = NumColumn.from_path(column1.context, column1.path)
    tmp_column *= column2
    block_count = count(tmp_column)
    block_count_inv = __safe_inverse(block_count, column1.num_rows)
    block_count_minus = block_count - 1
    block_count_minus_inv = __safe_inverse(block_count_minus, tmp_column.num_rows)

    block_avg1 = Block.zeros(column1.context, encrypted=column1.encrypted)
    block_avg2 = Block.zeros(column1.context, encrypted=column2.encrypted)

    if tmp_column.has_vbit():
        for data_block1, data_block2, vbit_block in zip(column1.data, column2.data, tmp_column.vbit):
            tmp1 = data_block1 * vbit_block
            tmp2 = data_block2 * vbit_block
            block_avg1 += tmp1
            block_avg2 += tmp2
    else:
        for idx, (data_block1, data_block2) in enumerate(zip(column1.data, column2.data)):
            # Only the last block can contain slots beyond num_rows.
            if idx == len(column1.data) - 1:
                data_block1 = check_valid(data_block1, column1)
            block_avg1 += data_block1
            if idx == len(column2.data) - 1:
                data_block2 = check_valid(data_block2, column2)
            block_avg2 += data_block2

    block_avg1 = block_avg1.rotate_sum()
    block_avg2 = block_avg2.rotate_sum()
    block_avg1 *= block_count_inv
    block_avg2 *= block_count_inv
    # The inverse consumes depth, and the centered products below need more. Refresh the
    # means the same way correlation() does. This is a no-op when enough levels remain.
    for block_avg in (block_avg1, block_avg2):
        if block_avg.encrypted and block_avg.need_bootstrap(3):
            block_avg.bootstrap(one_slot=True)

    block_cov = Block.zeros(column1.context, encrypted=True)
    if tmp_column.has_vbit():
        for data_block1, data_block2, vbit_block in zip(column1.data, column2.data, tmp_column.vbit):
            tmp1 = data_block1 - block_avg1
            tmp1 *= vbit_block
            tmp2 = data_block2 - block_avg2
            tmp2 *= vbit_block
            block_cov += tmp1 * tmp2
    else:
        for idx, (data_block1, data_block2) in enumerate(zip(column1.data, column2.data)):
            tmp1 = data_block1 - block_avg1
            if idx == len(column1.data) - 1:
                tmp1 = check_valid(tmp1, column1)
            tmp2 = data_block2 - block_avg2
            if idx == len(column2.data) - 1:
                tmp2 = check_valid(tmp2, column2)
            block_cov += tmp1 * tmp2
    block_cov = block_cov.rotate_sum()
    block_cov *= block_count_minus_inv
    return block_cov


def _inverse_sample_std(column: NumColumn) -> Block:
    """Return ``1 / sqrt(sum((x - mean)^2) / (count - 1))`` for ``column``.

    The mean, count and squared deviations use ``column``'s own validity: its vbit when
    present, otherwise ``check_valid`` on the last block.
    """
    block_sum = sum(column)
    block_count = count(column)
    block_count_minus = block_count - 1
    block_count_minus_inv = __safe_inverse(block_count_minus, column.num_rows)
    block_count_inv = __safe_inverse(block_count, column.num_rows)
    block_avg = block_sum * block_count_inv
    if block_avg.need_bootstrap(3):
        block_avg.bootstrap(one_slot=True)

    block_var = Block.zeros(column.context, encrypted=column.encrypted)
    if column.has_vbit():
        for data_block, vbit_block in zip(column.data, column.vbit):
            tmp = data_block - block_avg
            tmp *= vbit_block
            tmp *= tmp
            block_var += tmp
    else:
        last_idx = len(column.data) - 1
        for idx, data_block in enumerate(column.data):
            tmp = data_block - block_avg
            if idx == last_idx:
                tmp = check_valid(tmp, column)
            tmp *= tmp
            block_var += tmp

    block_var = block_var.rotate_sum()
    block_var *= block_count_minus_inv
    return block_var.sqrt_inv(one_slot=True)


# Pearson Correlation Coefficient
def correlation(column1: NumColumn, column2: NumColumn) -> Block:
    """Return the Pearson correlation coefficient of two numeric columns.

    Both columns are passed through ``normalize.normalize`` and scaled by 0.5 first. This
    does not change the coefficient, which is scale invariant. It presumably keeps values
    in the range the approximate HE ops expect (confirm). The result is
    ``cov(x, y) * (1 / std_x) * (1 / std_y)``.

    NOTE(review): the covariance uses the row count and vbit of ``normalize(column1) * column2``.
    Each standard deviation uses only its own column's validity. With differing missing
    rows, this is not strictly pairwise-complete Pearson.

    Args:
        column1: first numeric column.
        column2: second numeric column, with the same row layout as ``column1``.

    Returns:
        Block holding the correlation coefficient.
    """
    norm_column1 = normalize.normalize(column1)
    norm_column2 = normalize.normalize(column2)
    norm_column1 *= 0.5
    norm_column2 *= 0.5

    # --- covariance ---
    # NOTE(review): normalize(column1) is recomputed only to get a separate column whose
    # vbit is combined with column2's. A cheaper copy may suffice; confirm before changing.
    tmp_column = normalize.normalize(column1)
    tmp_column *= column2
    block_count = count(tmp_column)
    block_count_inv = __safe_inverse(block_count, norm_column1.num_rows)
    block_count_minus = block_count - 1
    block_count_minus_inv = __safe_inverse(block_count_minus, tmp_column.num_rows)

    block_avg1 = Block.zeros(norm_column1.context, encrypted=norm_column1.encrypted)
    block_avg2 = Block.zeros(norm_column1.context, encrypted=norm_column2.encrypted)

    if tmp_column.has_vbit():
        for data_block1, data_block2, vbit_block in zip(norm_column1.data, norm_column2.data, tmp_column.vbit):
            tmp1 = data_block1 * vbit_block
            tmp2 = data_block2 * vbit_block
            block_avg1 += tmp1
            block_avg2 += tmp2
    else:
        for idx, (data_block1, data_block2) in enumerate(zip(norm_column1.data, norm_column2.data)):
            if idx == len(norm_column1.data) - 1:
                data_block1 = check_valid(data_block1, norm_column1)
            block_avg1 += data_block1
            if idx == len(norm_column2.data) - 1:
                data_block2 = check_valid(data_block2, norm_column2)
            block_avg2 += data_block2

    block_avg1 = block_avg1.rotate_sum()
    block_avg2 = block_avg2.rotate_sum()
    block_avg1 *= block_count_inv
    block_avg2 *= block_count_inv
    if block_avg1.need_bootstrap(3):
        block_avg1.bootstrap(one_slot=True)
        block_avg2.bootstrap(one_slot=True)

    block_cov = Block.zeros(norm_column1.context, encrypted=True)
    if tmp_column.has_vbit():
        for data_block1, data_block2, vbit_block in zip(norm_column1.data, norm_column2.data, tmp_column.vbit):
            tmp1 = data_block1 - block_avg1
            tmp1 *= vbit_block
            tmp2 = data_block2 - block_avg2
            tmp2 *= vbit_block
            block_cov += tmp1 * tmp2
    else:
        for idx, (data_block1, data_block2) in enumerate(zip(norm_column1.data, norm_column2.data)):
            tmp1 = data_block1 - block_avg1
            if idx == len(norm_column1.data) - 1:
                tmp1 = check_valid(tmp1, norm_column1)
            tmp2 = data_block2 - block_avg2
            if idx == len(norm_column2.data) - 1:
                tmp2 = check_valid(tmp2, norm_column2)
            block_cov += tmp1 * tmp2

    block_cov = block_cov.rotate_sum()
    block_cov *= block_count_minus_inv

    # --- inverse standard deviations ---
    block_std1 = _inverse_sample_std(norm_column1)
    block_std2 = _inverse_sample_std(norm_column2)

    # --- correlation ---
    # block_cov's level could be 3 (FGb)
    block_cov *= block_std1
    block_cov *= block_std2
    return block_cov
