import math

import numpy as np

from hedal.block import Block
from hedal.core.config import PackingType


def transpose(block: Block) -> Block:
    """Return the transpose of a matrix-packed block.

    Encrypted blocks are transposed homomorphically with a baby-step
    giant-step (BSGS) schedule: each diagonal of the matrix (the entries with
    a fixed ``column - row`` offset) is isolated by a 0/1 mask and moved to
    its transposed position by slot rotations. Plaintext blocks are decoded,
    transposed with NumPy and re-packed.

    NOTE(review): the encrypted path appears to assume row-major packing of a
    square ``num_d x num_d`` matrix that fills ``block.num_slots`` (so the
    rotations wrap modulo ``num_d ** 2``) — confirm against Block/Context
    before using it with other shapes.

    Args:
        block (Block): Block packed with ``PackingType.MATRIX``.

    Returns:
        Block: New block holding the transposed matrix.

    Raises:
        TypeError: If ``block.type`` is not ``PackingType.MATRIX``.
    """
    if block.type != PackingType.MATRIX:
        raise TypeError("Invalid block type")

    if block.encrypted:
        # Matrix dimension (taken from the column count).
        num_d = block.shape[1]

        # gs giant steps (a power of two near sqrt(num_d)) and bs baby steps;
        # gs * bs == 2 * num_d when num_d is a power of two, which is enough to
        # index all 2 * num_d - 1 diagonals.
        gs = pow(2, int((np.log2(num_d)) // 2))
        bs = 2 * num_d // gs

        ### PREPARE
        # Baby-step rotations, computed once and reused by every giant step:
        # blocks_rot[j] == block >> (num_d + j * (num_d - 1)).
        blocks_rot = [block >> num_d]
        for j in range(1, bs):
            blocks_rot.append(blocks_rot[j - 1] >> (num_d - 1))

        ### START baby-step-giant-step
        block_res = Block.zeros(block.context, type=block.type)
        for i in range(gs):
            # Index i * bs + j == 0 gives mask_idx == num_d, whose mask selects
            # no entry, so that term is skipped.
            j_start = 1 if i == 0 else 0

            block_sum = Block.zeros(block.context, type=block.type)

            # Giant-step rotation. Each mask is pre-rotated left by rot_num so
            # the accumulated block_sum needs only one right rotation per giant
            # step instead of one rotation per baby-step term.
            rot_num = (i * bs * (num_d - 1)) % block.num_slots

            for j in range(j_start, bs):
                # make mask for mask_index = num_d - (i*bs + j)
                # mask_idx is the diagonal offset (column - row) selected,
                # assuming row-major packing with row width num_d.
                mask = Block.zeros(block.context, type=block.type)
                mask_idx = num_d - i * bs - j

                if mask_idx < 0:
                    # Entries (row k - mask_idx, col k): below the main diagonal.
                    for k in range(num_d + mask_idx):
                        mask.data[k * (num_d + 1) - mask_idx * num_d] = 1
                else:
                    # Entries (row k, col k + mask_idx): main diagonal or above.
                    for k in range(num_d - mask_idx):
                        mask.data[k * (num_d + 1) + mask_idx] = 1
                if not rot_num == 0:
                    mask <<= rot_num

                # mult sum
                block_sum += blocks_rot[j] * mask

            block_res += block_sum >> rot_num
    else:
        array = block.to_ndarray()
        block_res = Block.from_ndarray(block.context, array.transpose(), type=block.type)

    return block_res


def diagonal_to_col(block: Block) -> Block:
    """Cyclically rotate each row ``r`` of a matrix-packed block left by ``r``.

    The result satisfies ``res[r, c] == block[r, (c + r) % num_cols]`` (this
    is exactly what the plaintext branch computes with ``np.roll``). Encrypted
    blocks are processed with a baby-step giant-step (BSGS) schedule: every
    (row, wrap/no-wrap) segment is isolated by a 0/1 mask and combined from a
    suitably rotated copy of the block.

    NOTE(review): the encrypted path uses ``num_d = block.shape[0]`` as the
    row width of its masks, i.e. it appears to assume a square, row-major
    ``num_d x num_d`` matrix — confirm before using other shapes.

    Args:
        block (Block): Block packed with ``PackingType.MATRIX``.

    Returns:
        Block: New block with each row rotated.

    Raises:
        TypeError: If ``block.type`` is not ``PackingType.MATRIX``.
    """
    if block.type != PackingType.MATRIX:
        raise TypeError("Invalid block type")

    if block.encrypted:
        num_d = block.shape[0]
        # gs giant steps (a power of two near sqrt(num_d)) and bs baby steps;
        # gs * bs == 2 * num_d when num_d is a power of two.
        gs = pow(2, int(np.log2(num_d) // 2))
        bs = 2 * num_d // gs

        ### PREPARE
        # Baby-step rotations: blocks_rot[j] == block << (num_d - j).
        blocks_rot = [block << num_d]
        for j in range(1, bs):
            blocks_rot.append(blocks_rot[j - 1] >> 1)

        ### START baby-step-giant-step
        block_res = Block.zeros(block.context, type=block.type)

        for i in range(gs):
            # Index i * bs + j == 0 (mask_idx == num_d) is skipped.
            j_start = 1 if i == 0 else 0
            block_sum = Block.zeros(block.context, type=block.type)
            # Giant-step rotation; masks are pre-rotated left by rot_num so the
            # sum is rotated right once per giant step. Net effect per term:
            # (block << mask_idx) restricted to the mask.
            rot_num = (i * bs) % block.num_slots

            for j in range(j_start, bs):
                # make mask for mask index = num_d -  (i*bs + j)
                mask = Block.zeros(block.context, type=block.type)
                mask_idx = num_d - i * bs - j

                if mask_idx < 0:
                    # Row num_d + mask_idx, columns [-mask_idx, num_d): the
                    # entries whose left rotation wraps to the start of the row.
                    row_idx = (num_d + mask_idx) * num_d
                    for k in range(-mask_idx, num_d):
                        mask.data[row_idx + k] = 1
                else:
                    # Row mask_idx, columns [0, num_d - mask_idx): the entries
                    # whose left rotation by mask_idx stays inside the row.
                    row_idx = mask_idx * num_d
                    for k in range(num_d - mask_idx):
                        mask.data[row_idx + k] = 1
                if not rot_num == 0:
                    mask <<= rot_num

                # mult sum
                block_sum += blocks_rot[j] * mask

            block_res += block_sum >> rot_num

    else:
        array = block.to_ndarray()
        res_array = np.zeros(block.shape)
        for idx in range(block.shape[0]):
            res_array[idx, :] = np.roll(array[idx, :], -idx)
        block_res = Block.from_ndarray(block.context, res_array, type=block.type)

    return block_res


def diagonal_to_row(block: Block) -> Block:
    """Cyclically rotate each column ``c`` of a matrix-packed block up by ``c``.

    The result satisfies ``res[r, c] == block[(r + c) % num_rows, c]`` (this
    is exactly what the plaintext branch computes with ``np.roll``). Encrypted
    blocks are processed with a baby-step giant-step (BSGS) schedule: column
    ``c`` is masked out of a copy of the block rotated by ``c`` whole rows.

    NOTE(review): the encrypted path uses ``num_d = block.shape[1]`` both as
    the row width and as the number of rows it rotates through, i.e. it
    appears to assume a square, row-major ``num_d x num_d`` matrix — confirm
    before using other shapes.

    Args:
        block (Block): Block packed with ``PackingType.MATRIX``.

    Returns:
        Block: New block with each column rotated.

    Raises:
        TypeError: If ``block.type`` is not ``PackingType.MATRIX``.
    """
    if block.type != PackingType.MATRIX:
        raise TypeError("Invalid block type")

    if block.encrypted:
        num_d = block.shape[1]
        # Only num_d - 1 non-trivial row shifts are needed, so gs * bs == num_d
        # (for power-of-two num_d) rather than 2 * num_d.
        gs = pow(2, int(np.log2(num_d) // 2))
        bs = num_d // gs

        ### PREPARE
        # Baby-step rotations by whole rows: blocks_rot[j] == block << (j * num_d).
        blocks_rot = [block.copy()]
        for j in range(1, bs):
            blocks_rot.append(blocks_rot[j - 1] << num_d)

        ### START baby-step-giant-step
        # Column 0 needs no rotation: mask it straight out of the input.
        mask = Block.zeros(block.context, type=block.type)
        for i in range(num_d):
            mask.data[i * num_d] = 1
        block_res = blocks_rot[0] * mask

        for i in range(gs):
            # Index i * bs + j == 0 (column 0) was handled above.
            j_start = 1 if i == 0 else 0
            block_sum = Block.zeros(block.context, type=block.type)
            # Giant-step rotation by i * bs whole rows; masks are pre-rotated
            # right so the sum is rotated left once per giant step.
            rot_num = (i * bs * num_d) % block.num_slots

            for j in range(j_start, bs):
                # Mask selecting column mask_idx = i*bs + j in every row; the two
                # loops below together cover rows k = 0 .. num_d - 1.
                mask = Block.zeros(block.context, type=block.type)
                mask_idx = i * bs + j

                col_idx = mask_idx
                for k in range(num_d - mask_idx):
                    mask.data[k * num_d + col_idx] = 1

                for k in range(num_d - mask_idx, num_d):
                    mask.data[k * num_d + col_idx] = 1

                mask >>= rot_num

                # mult sum
                block_sum += blocks_rot[j] * mask

            block_res += block_sum << rot_num

    else:
        array = block.to_ndarray()
        res_array = np.zeros(block.shape)
        for idx in range(block.shape[1]):
            res_array[:, idx] = np.roll(array[:, idx], -idx)
        block_res = Block.from_ndarray(block.context, res_array, type=block.type)

    return block_res


def sum(block: Block, axis: int, direction: int) -> Block:
    """Rotate-and-add reduction of a matrix-packed block along one axis.

    Performs ``int(log2(n))`` steps, where ``n`` is the length of the chosen
    axis; step ``idx`` adds to the running result a copy of itself rotated by
    ``2 ** idx`` strides. A stride is one row (``num_cols`` slots) for
    ``axis=0`` and one slot for ``axis=1``. Afterwards each slot holds the sum
    of ``2 ** int(log2(n))`` entries spaced one stride apart (all ``n`` of them
    when ``n`` is a power of two), wrapping around the slot vector.

    NOTE: this function shadows the builtin ``sum`` inside this module; the
    name is kept for API compatibility.

    Args:
        block (Block): Block packed with ``PackingType.MATRIX``; not modified.
        axis (int): 0 sums vertically (over rows), 1 horizontally (over columns).
        direction (int): 0 uses left rotations (``<<``), 1 right rotations (``>>``).

    Returns:
        Block: New block holding the accumulated sums.

    Raises:
        TypeError: If the block is not matrix-packed, or if ``axis`` or
            ``direction`` is not 0 or 1.
    """
    if block.type != PackingType.MATRIX:
        raise TypeError("Invalid block type")

    if (axis not in (0, 1)) or (direction not in (0, 1)):
        # TypeError (rather than ValueError) is kept for backward compatibility.
        raise TypeError(
            f"axis and direction must each be 0 or 1, got axis={axis!r}, direction={direction!r}"
        )

    num_rows, num_cols = block.shape
    if axis == 0:
        # Vertical: one row is num_cols consecutive slots.
        num_steps, unit = int(np.log2(num_rows)), num_cols
    else:
        # Horizontal: neighbouring columns are adjacent slots.
        num_steps, unit = int(np.log2(num_cols)), 1

    res_block = block.copy()
    for idx in range(num_steps):
        stride = (1 << idx) * unit
        if direction == 0:
            res_block += res_block << stride
        else:
            res_block += res_block >> stride

    return res_block


def matmul(a: Block, b: Block) -> Block:
    """Rotate-multiply-accumulate product of two matrix-packed blocks.

    Computes ``sum_k rot_cols(a, k) * rot_rows(b, k)`` for
    ``k = 0 .. num_cols - 1``. Here ``rot_cols(a, k)`` cyclically rotates every
    row of ``a`` by ``k`` columns (with wrap-around kept inside the row), and
    ``rot_rows(b, k)`` rotates ``b`` by ``k`` whole rows.

    NOTE(review): this is the accumulation step of the diagonal-alignment
    matrix product. Presumably ``a`` and ``b`` are expected to be pre-aligned
    (e.g. with ``diagonal_to_col`` / ``diagonal_to_row``) so that the sum
    equals the ordinary matrix product — confirm with callers.

    Args:
        a (Block): Left operand, packed with ``PackingType.MATRIX``.
        b (Block): Right operand, packed with ``PackingType.MATRIX``.

    Returns:
        Block: New block holding the accumulated product.

    Raises:
        TypeError: If either block is not ``PackingType.MATRIX``.
    """
    if (a.type != PackingType.MATRIX) or (b.type != PackingType.MATRIX):
        raise TypeError("Invalid block type")

    num_rows, num_cols = a.shape
    block_a, block_b = a.copy(), b.copy()
    res_block = a * b  # k = 0 term
    # After iteration k, mask[r, c] == 1 exactly for c < num_cols - k: the
    # slots whose in-row rotation by k does not wrap past the end of the row.
    mask = Block.ones(a.context, type=res_block.type)
    for i in range(1, num_cols):
        block_a <<= 1  # a rotated by i slots in total
        block_b <<= num_cols  # b rotated by i whole rows in total
        # Clear column (num_cols - i) in every row; earlier iterations already
        # cleared the columns to its right.
        for j in range(num_rows):
            mask.data[j * num_cols + num_cols - i] = 0

        # Slots that wrapped past the end of their row must read from the same
        # row instead, i.e. from num_cols slots further back.
        tmp = block_a * mask
        tmp += (block_a >> num_cols) * ((-mask) + 1)
        res_block += block_b * tmp

    return res_block


def sigmoid15(block: Block) -> Block:
    """Approximate the logistic sigmoid with a degree-15 polynomial.

    With ``y = 0.01 * x**2`` the polynomial is evaluated in factored form::

        0.5 - 1.09472135759754 * x * (y - 1.6014381203)
            * (y**2 - 2.65078846906 * y + 1.89130125130)
            * (y**2 - 1.35875365695 * y + 0.70493730522)
            * (y**2 - 0.179296867889 * y + 0.1019531383631)

    i.e. ``0.5`` plus an odd polynomial of degree 15 in ``x``. The input
    interval on which the approximation is accurate is not stated here —
    TODO confirm.

    If ``block.need_bootstrap(5)`` is true, a copy scaled by 1/12 is
    bootstrapped and scaled back by 12 before evaluation. Only out-of-place
    operations are applied to the argument, so the caller's block is left
    untouched.

    Args:
        block (Block): Input block.

    Returns:
        Block: New block with the approximated sigmoid values.
    """
    eps = 1 / 12
    if block.need_bootstrap(5):
        # Scale down before bootstrapping (presumably to stay inside the
        # bootstrap input range) and back up afterwards. Rebinding the local
        # name, instead of in-place `*=`, avoids rescaling/bootstrapping the
        # caller's block as a side effect.
        block = block * eps
        block.bootstrap()
        block = block * 12
    block1 = block * -1.09472135759754  # -c * x
    block2 = block * block * 0.01  # y = 0.01 * x**2
    block3 = block2 - 1.6014381203
    block1 *= block3  # * (y - 1.6014381203)
    block4 = block2 * block2  # y**2, shared by the three quadratic factors
    block3 = block2 * -2.65078846906
    block3 += block4
    block3 += 1.89130125130
    block1 *= block3  # * (y**2 - 2.65078846906 y + 1.89130125130)
    block3 = block2 * -1.35875365695
    block3 += block4
    block3 += 0.70493730522  # y**2 - 1.35875365695 y + 0.70493730522
    block2 *= -0.179296867889
    block2 += block4
    block2 += 0.1019531383631  # y**2 - 0.179296867889 y + 0.1019531383631
    block2 *= block3
    block1 *= block2
    res_block = block1 + 0.5
    return res_block


def sigmoid(block: Block, depth: int = 10) -> Block:
    """Apply the logistic sigmoid ``1 / (1 + exp(-x))`` slot-wise.

    Encrypted blocks are moved to the device and evaluated with the
    polynomial approximation ``sigmoid15``; the result is moved back to the
    host. Plaintext blocks are evaluated exactly with NumPy on the real part
    of every slot.

    Note that for encrypted input the argument itself is moved to the device
    (``block.to_device()``) and is not moved back.

    Args:
        block (Block): Input block.
        depth (int, optional): Currently unused; retained so existing callers
            keep working. Defaults to 10.

    Returns:
        Block: New block with the sigmoid values.
    """
    if block.encrypted:
        block.to_device()
        result = sigmoid15(block)
        result.to_host()
        return result

    # Only the plaintext path fills a preallocated result; allocating it in
    # the encrypted path would create a block that is immediately discarded.
    result = Block(block.context, encrypted=block.encrypted, type=block.type)
    for idx in range(block.num_slots):
        result[idx] = 1 / (1 + np.exp(-block[idx].real))
    return result


def exp(block: Block, degree: int = 8) -> Block:
    """Approximated exponential function.

    For encrypted blocks it is approximated as ``exp(x) = (g(x) + 1) ** 8``,
    where ``g(x) = exp(x / 8) - 1`` is approximated by its Taylor expansion of
    the given degree; the 8th power is taken with three successive squarings.
    Plaintext blocks are evaluated exactly with ``math.exp`` on the real part
    of every slot (``degree`` is ignored there).

    Args:
        block (Block): Block to compute exponential function.
        degree (int, optional): Degree of Taylor approximation polynomial;
            must be >= 1 for encrypted blocks. Defaults to 8.

    Returns:
        Block: Approximated exponentiation result block.

    Raises:
        ValueError: If ``block`` is encrypted and ``degree < 1``.
    """
    if block.encrypted:
        if degree < 1:
            raise ValueError(f"degree must be >= 1, got {degree}")
        scale_pow_param = 3
        scale = 1 << scale_pow_param
        # bx and by are consecutive Taylor partial sums of exp(x/8) - 1, so
        # by - bx is the most recent term (x/8)**(i-1) / (i-1)!.
        bx = Block.zeros(block.context, encrypted=block.encrypted, type=block.type)
        by = block * (1 / scale)
        bz = by  # degree == 1: g(x) ~= x / 8 (the loop below does not run)
        for i in range(2, degree + 1):
            # Next term = previous term * (x / 8) / i.
            bz = by + block * (by - bx) * (1 / (i * scale))
            if bz.need_bootstrap(2):
                bz.bootstrap()
            bx, by = by, bz
        bz += 1
        bz.bootstrap()  # FGb: level = 12
        for _ in range(scale_pow_param):
            bz *= bz
        return bz  # FGb: level = 9
    else:
        res_block = block.copy()
        for j in range(block.num_slots):
            res_block[j] = math.exp(block[j].real)
        return res_block


def exp_wide(block: Block, degree: int = 8, n: int = 3) -> Block:
    """Approximated exponential function with domain extension.

    For encrypted blocks, the input is scaled by ``1 / r`` (``r = 10``). It is
    then passed through ``n + 1`` iterations of the cubic map
    ``b <- b - 4 / (27 * 2.25 ** i) * b ** 3`` for ``i = n, n - 1, ..., 0``,
    scaled back by ``r`` and handed to ``exp``. The map is a domain-extension
    polynomial; it presumably keeps large-magnitude inputs within the range
    where ``exp``'s Taylor approximation is usable — confirm the supported
    interval. Whenever ``b`` needs bootstrapping, it is bootstrapped while
    scaled down by ``ceil(2.25 ** n)``.

    Plaintext blocks are evaluated exactly with ``math.exp`` on the real part
    of every slot (``degree`` and ``n`` are ignored there).

    Args:
        block (Block): Block to compute exponential function.
        degree (int, optional): Degree of the Taylor polynomial used by ``exp``.
            Defaults to 8.
        n (int, optional): Highest index of the domain-extension iterations;
            ``n + 1`` iterations are applied. Defaults to 3.

    Returns:
        Block: Approximated exponentiation result block.
    """
    if block.encrypted:
        r = 10
        b = block * (1 / r)
        bootstrap_m = int(np.ceil(2.25 ** n))
        for i in range(n, -1, -1):
            if b.need_bootstrap(4):
                # Scale down around the bootstrap and restore afterwards.
                b *= 1 / bootstrap_m
                b.bootstrap()
                b *= bootstrap_m
            b = b - (4 / (27 * (2.25 ** i))) * (b * b * b)
        b = exp(r * b, degree=degree)
        return b
    else:
        res_block = block.copy()
        for j in range(block.num_slots):
            res_block[j] = math.exp(block[j].real)
        return res_block
