import math
from typing import Union

import numpy as np

from hedal.block import Block
from hedal.core.config import PackingType
from hedal.matrix.matrix import HedalMatrix
from hedal.matrix.ops import block_ops as bop
from hedal.matrix.ops import vec_ops as vop
from hedal.matrix.vector import HedalVector


def mat_vec_mul_col_tiled(mat: HedalMatrix, vec: HedalVector, complex: bool = False) -> HedalMatrix:
    """Multiply every row of a matrix with a column-tiled vector.

    Each row of ``mat`` is handed to ``vop.vec_mul_col_tiled`` together with ``vec``;
    the per-row results are collected, in iteration order, as the rows of the output.

    Args:
        mat (HedalMatrix): Input matrix of shape (num_rows_1, num_cols).
        vec (HedalVector): Input vector of shape (num_rows_2, num_cols), assumed to be
            tiled along columns.
        complex (bool, optional): Whether inputs are complex matrix/vectors or not. If False,
            faster implementation will be used. Forwarded unchanged to the row-level routine.
            Defaults to False.

    Returns:
        HedalMatrix: Output matrix of shape (num_rows_1, num_rows_2). It is flagged as
            encrypted when either operand is encrypted.
    """
    out_shape = (mat.num_rows, vec.num_rows)
    is_encrypted = mat.encrypted or vec.encrypted
    product = HedalMatrix(mat.context, shape=out_shape, encrypted=is_encrypted)

    for row in mat:
        product.objects.append(vop.vec_mul_col_tiled(row, vec, complex=complex))

    return product


def mat_mul_row_tiled(
    mat1: HedalMatrix, mat2: HedalMatrix, tile_col: bool = False, complex: bool = False
) -> HedalVector:
    """Multiply two matrices row by row and accumulate, with mat1 tiled along rows.

    Rows of ``mat1`` and ``mat2`` are paired up in order, each pair is multiplied with
    ``vop.vec_mul_row_tiled`` and the partial products are summed into one vector.

    Args:
        mat1 (HedalMatrix): Input matrix of shape (num_rows, num_cols_1), assumed to be tiled along rows.
        mat2 (HedalMatrix): Input matrix of shape (num_rows, num_cols_2).
        tile_col (bool, optional): If True, the result will be tiled along columns. Defaults to False.
        complex (bool, optional): Whether inputs are complex matrix/vectors or not. If False,
            faster implementation will be used. Defaults to False.

    Returns:
        HedalVector: Output vector of shape (num_cols_1, num_cols_2).

    Raises:
        ValueError: If ``mat1.num_cols`` is greater than ``mat1.context.shape[0]``.
    """
    # Equality is accepted; only a strictly larger column count is rejected.
    if mat1.num_cols > mat1.context.shape[0]:
        raise ValueError(
            f"Number of columns of mat1 ({mat1.num_cols}) should be smaller than mat1.context.shape[0] ({mat1.context.shape[0]})"
        )

    accumulator = HedalVector.zeros(
        mat1.context,
        shape=(mat1.num_cols, mat2.num_cols),
        encrypted=mat1.encrypted or mat2.encrypted,
    )
    # zip() stops at the shorter operand, so surplus rows on either side are ignored.
    for lhs_row, rhs_row in zip(mat1, mat2):
        accumulator += vop.vec_mul_row_tiled(lhs_row, rhs_row, tile_col=tile_col, complex=complex)
    return accumulator


def mvmul(mat: HedalMatrix, vec: HedalVector) -> HedalVector:
    """Matrix-vector multiplication for MATRIX-packed operands.

    For every row of blocks in ``mat`` the blocks are paired with the blocks of ``vec``,
    multiplied, reduced with a sequence of ``bop.sum`` calls and masks, and accumulated
    into a single block that is stored at the matching index of the result.

    Args:
        mat (HedalMatrix): Left operand; its packing type must be ``PackingType.MATRIX``.
        vec (HedalVector): Right operand with the same ``num_cols`` as ``mat`` and
            packing type ``PackingType.MATRIX``.

    Returns:
        HedalVector: Result constructed with shape (1, mat.num_rows); encrypted when
            either operand is encrypted.

    Raises:
        TypeError: If the column counts differ, or if either operand is not MATRIX-packed.
    """
    if mat.num_cols != vec.num_cols:
        raise TypeError("Invalid dimension of matrix and vector")
    if (mat.type != PackingType.MATRIX) or (vec.type != PackingType.MATRIX):
        raise TypeError("Invalid vector type")

    context = mat.context
    packing = mat.type
    result = HedalVector(context, shape=(1, mat.num_rows), encrypted=mat.encrypted or vec.encrypted)

    # Both masks are loop invariants, so they are built once up front.
    diagonal_mask = Block.identity(context, type=packing)
    first_col_mask = Block.mask(context, index=0, axis=1, type=packing)

    for row_index, block_row in enumerate(mat):
        accumulated = Block.zeros(context, type=packing)
        for mat_block, vec_block in zip(block_row, vec):
            partial = mat_block * vec_block
            # Presumably: reduce along axis 1, keep only index 0 on that axis, spread it
            # back along axis 1, keep the diagonal, then reduce along axis 0.
            # NOTE(review): confirm the bop.sum `direction` semantics against block_ops.
            partial = bop.sum(partial, axis=1, direction=0)
            partial *= first_col_mask
            partial = bop.sum(partial, axis=1, direction=1)
            partial *= diagonal_mask
            partial = bop.sum(partial, axis=0, direction=0)
            accumulated += partial
        result[row_index] = accumulated

    return result


def matmul(A: HedalMatrix, B: HedalMatrix) -> HedalMatrix:
    """Matrix-matrix product ``A @ B``; the result is saved before being returned.

    When neither operand is encrypted the product is computed with NumPy on the decoded
    arrays. Otherwise both operands are converted with ``diagonal_to_col`` /
    ``diagonal_to_row``, combined by ``_rot_mul_sum``, and the two intermediate matrices
    are removed afterwards.

    Args:
        A (HedalMatrix): Left operand.
        B (HedalMatrix): Right operand; ``B.num_rows`` must equal ``A.num_cols``.

    Returns:
        HedalMatrix: The product, after ``save()`` has been called on it.

    Raises:
        TypeError: If ``A.num_cols`` differs from ``B.num_rows``.
    """
    if A.num_cols != B.num_rows:
        raise TypeError("Invalid dimension of operands")

    if A.encrypted or B.encrypted:
        lhs_diag = diagonal_to_col(A)
        rhs_diag = diagonal_to_row(B)
        product = _rot_mul_sum(lhs_diag, rhs_diag)
        # The diagonal layouts are temporaries that are only needed for the product.
        lhs_diag.remove()
        rhs_diag.remove()
    else:
        plain_product = np.matmul(A.to_ndarray(), B.to_ndarray())
        product = HedalMatrix.from_ndarray(A.context, plain_product)

    product.save()
    return product


def transpose(mat: HedalMatrix) -> HedalMatrix:
    """Return the transpose of a matrix, computed block by block.

    The block found at position (i, j) when iterating ``mat`` is transposed with
    ``bop.transpose`` and written to position (j, i) of the result.

    Args:
        mat (HedalMatrix): Matrix to transpose.

    Returns:
        HedalMatrix: New matrix constructed with shape (mat.num_cols, mat.num_rows) and
            the same encryption flag as ``mat``.
    """
    flipped_shape = (mat.num_cols, mat.num_rows)
    transposed = HedalMatrix(mat.context, shape=flipped_shape, encrypted=mat.encrypted)

    for i, block_row in enumerate(mat):
        for j, blk in enumerate(block_row):
            transposed[j][i] = bop.transpose(blk)

    return transposed


def diagonal_to_col(mat: HedalMatrix) -> HedalMatrix:
    """Apply ``bop.diagonal_to_col`` to every block of a matrix.

    Block positions are preserved: the block at (i, j) of ``mat`` is converted and stored
    at (i, j) of the result.

    Args:
        mat (HedalMatrix): Source matrix.

    Returns:
        HedalMatrix: New matrix with the same shape and encryption flag as ``mat``.
    """
    converted = HedalMatrix(mat.context, shape=mat.shape, encrypted=mat.encrypted)

    for i, block_row in enumerate(mat):
        for j, blk in enumerate(block_row):
            converted[i][j] = bop.diagonal_to_col(blk)

    return converted


def diagonal_to_row(mat: HedalMatrix) -> HedalMatrix:
    """Apply ``bop.diagonal_to_row`` to every block of a matrix.

    Block positions are preserved: the block at (i, j) of ``mat`` is converted and stored
    at (i, j) of the result.

    Args:
        mat (HedalMatrix): Source matrix.

    Returns:
        HedalMatrix: New matrix with the same shape and encryption flag as ``mat``.
    """
    converted = HedalMatrix(mat.context, shape=mat.shape, encrypted=mat.encrypted)

    for i, block_row in enumerate(mat):
        for j, blk in enumerate(block_row):
            converted[i][j] = bop.diagonal_to_row(blk)

    return converted


def horizontal_sum(mat: HedalMatrix, direction: int, fill: bool = False, num_fill_cols: int = 0) -> HedalMatrix:
    """Apply ``vop.horizontal_sum`` to each row of a matrix.

    Args:
        mat (HedalMatrix): Matrix whose rows are summed.
        direction (int): Forwarded unchanged to ``vop.horizontal_sum``.
        fill (bool, optional): Forwarded unchanged to ``vop.horizontal_sum``. Defaults to False.
        num_fill_cols (int, optional): Forwarded unchanged to ``vop.horizontal_sum``. Defaults to 0.

    Returns:
        HedalMatrix: Matrix constructed with shape (mat.num_rows, mat.context.shape[1]) whose
            rows are the per-row results, in order.
    """
    out_shape = (mat.num_rows, mat.context.shape[1])
    summed = HedalMatrix(mat.context, shape=out_shape, encrypted=mat.encrypted)

    for mat_row in mat:
        row_sum = vop.horizontal_sum(mat_row, direction=direction, fill=fill, num_fill_cols=num_fill_cols)
        summed.objects.append(row_sum)

    return summed


def exp(mat: HedalMatrix, degree: int = 8) -> HedalMatrix:
    """Approximate element-wise exponential of a matrix.

    For encrypted input it is approximated as exp(x) = (g(x) + 1)^8, where
    g(x) = exp(x/8) - 1 is approximated using the Taylor expansion of the given degree.
    For plaintext input the exact ``np.exp`` of the decoded array is returned and
    ``degree`` is ignored.

    Args:
        mat (HedalMatrix): HedalMatrix to be exponentiated.
        degree (int, optional): Degree of Taylor approximation polynomial; must be at
            least 1 for encrypted input. Defaults to 8.

    Returns:
        HedalMatrix: Approximated exponentiation result matrix.

    Raises:
        ValueError: If ``mat`` is encrypted and ``degree`` is smaller than 1.
    """
    if not mat.encrypted:
        res_arr = np.exp(mat.to_ndarray())
        return HedalMatrix.from_ndarray(mat.context, res_arr)

    if degree < 1:
        raise ValueError(f"degree must be at least 1, got {degree}")

    scale_pow_param = 3
    scale = 1 << scale_pow_param  # exp(x) = exp(x / scale) ** scale

    # prev_sum / cur_sum are consecutive Taylor partial sums of g(x); their difference is
    # the latest term, from which the next term follows by multiplying with x / (i * scale).
    prev_sum = HedalMatrix.zeros(mat.context, shape=mat.shape)
    cur_sum = mat.copy_memory() * (1 / scale)  # degree-1 partial sum: x / scale
    for i in range(2, degree + 1):
        next_sum = cur_sum + mat * (cur_sum - prev_sum) * (1 / (i * scale))
        if next_sum.need_bootstrap(2):
            next_sum.bootstrap()
        prev_sum, cur_sum = cur_sum, next_sum

    # Continuing from cur_sum (rather than a loop-local variable) keeps degree == 1 valid,
    # where the loop body above never runs.
    cur_sum += 1
    cur_sum.bootstrap()  # FGb: level = 12
    # Square scale_pow_param times to raise the result to the power `scale`.
    for _ in range(scale_pow_param):
        cur_sum *= cur_sum
    return cur_sum  # FGb: level = 9


def exp_wide(mat: HedalMatrix, degree: int = 8, n: int = 3) -> HedalMatrix:
    """Approximate exponential composed with a domain extension function.

    This is used for softmax_wide, which is essential because of the restricted domain of
    the inverse function. For plaintext input the exact ``np.exp`` of the decoded array is
    returned and both ``degree`` and ``n`` are ignored.

    Args:
        mat (HedalMatrix): HedalMatrix to compute exponential wide function.
        degree (int, optional): Degree of Taylor approximation polynomial. Defaults to 8
        n (int, optional): Depth of the domain extension; the cubic update below is applied
            for i = n, n - 1, ..., 0. Defaults to 3.

    Returns:
        HedalMatrix: Approximated exponentiation result matrix.
    """
    if not mat.encrypted:
        plain = np.exp(mat.to_ndarray())
        return HedalMatrix.from_ndarray(mat.context, plain)

    r = 10
    squeezed = mat * (1 / r)
    bootstrap_scale = int(np.ceil(2.25 ** n))
    for step in range(n, -1, -1):
        if squeezed.need_bootstrap(4):
            # Scale down before bootstrapping and back up afterwards.
            # NOTE(review): presumably keeps values inside the bootstrappable range - confirm.
            squeezed *= 1 / bootstrap_scale
            squeezed.bootstrap()
            squeezed *= bootstrap_scale
        # Domain extension step: x <- x - 4 / (27 * 2.25^step) * x^3
        squeezed = squeezed - (4 / (27 * (2.25 ** step))) * (squeezed * squeezed * squeezed)

    # Undo the initial 1 / r scaling before exponentiating.
    return exp(r * squeezed, degree=degree)


def inverse(
    mat: HedalMatrix, one_slot: bool = False, greater_than_one: bool = True, inverse_num_iter: int = 20
) -> HedalMatrix:
    """Apply ``vop.inverse`` to each row of a matrix.

    Args:
        mat (HedalMatrix): Matrix to invert row by row.
        one_slot (bool, optional): Forwarded unchanged to ``vop.inverse``. Defaults to False.
        greater_than_one (bool, optional): Forwarded unchanged to ``vop.inverse``. Defaults to True.
        inverse_num_iter (int, optional): Forwarded unchanged to ``vop.inverse``. Defaults to 20.

    Returns:
        HedalMatrix: Matrix with the same shape and encryption flag as ``mat`` whose rows
            are the per-row results, in order.
    """
    inverted = HedalMatrix(mat.context, shape=mat.shape, encrypted=mat.encrypted)

    for mat_row in mat:
        inverted.objects.append(
            vop.inverse(
                mat_row,
                one_slot=one_slot,
                greater_than_one=greater_than_one,
                inverse_num_iter=inverse_num_iter,
            )
        )

    return inverted


def softmax(mat: HedalMatrix, output_tiled: bool, exp_degree: int = 8, inverse_num_iter: int = 20) -> HedalMatrix:
    """Approximate row-wise softmax function. Appropriate input range is [-10, 10]

    For plaintext input the exact (max-shifted) softmax is computed with NumPy.

    Args:
        mat (HedalMatrix): Matrix to be softmaxed.
        output_tiled (bool): If True, the result is padded and tiled along rows.
            For example, if the real result is [a, b, c], then the padded & tiled result is [a, b, c, 0, a, b, c, 0, ...].
        exp_degree (int, optional): Degree of Taylor approximation for exponential. Defaults to 8.
        inverse_num_iter (int, optional): Number of iterations for inverse. Defaults to 20.

    Returns:
        HedalMatrix: Softmaxed matrix. If output_tiled is True, then the shape of the result is (mat.shape[0], padded_num_cols),
            where the actual data is padded and tiled along rows. Otherwise ``num_cols`` of the result is left as produced.
    """
    # Column count rounded up to the next power of two.
    padded_num_cols = int(2 ** math.ceil(math.log2(mat.num_cols)))

    if mat.encrypted:
        # TODO: use first-slot subtraction only when the number of columns (number of classes) is 'small'
        # first_slot_mask = HedalVector.mask(vec.context, vec.shape, index=0, axis=1)
        # sub_block = (vec * first_slot_mask).block_list[0]
        # for rot_idx in range(min(math.ceil(math.log2(vec.shape[1])), int(math.log2(vec.context.shape[1])))):
        #     sub_block += sub_block >> (1 << rot_idx)

        # assumes that the result of exponential has enough level so that we don't need additional bootstrapping
        # this is ensured when using FGb parameter.

        exp_mat = exp(mat, degree=exp_degree)
        # NOTE(review): presumably clears the slots outside mat.shape so they do not
        # contribute to the row sums - confirm against HedalMatrix.from_ndarray padding.
        mask = HedalMatrix.from_ndarray(mat.context, array=np.ones(mat.shape))
        exp_mat *= mask

        exp_sum_mat = horizontal_sum(exp_mat, direction=0, fill=True)
        exp_sum_mat = inverse(exp_sum_mat, greater_than_one=True, inverse_num_iter=inverse_num_iter)
        res_mat = exp_mat * exp_sum_mat

        if output_tiled:
            # Repeated rotate-and-add with doubling offsets; only the first block of each
            # row is tiled.
            rot_num = int(math.log2(mat.context.shape[1] // padded_num_cols))
            for row in res_mat:
                for i in range(rot_num):
                    row[0] += row[0] >> ((1 << i) * padded_num_cols)
            res_mat.num_cols = padded_num_cols
        return res_mat

    arr = mat.to_ndarray()
    arr = arr - arr.max(axis=1, keepdims=True)
    arr = np.exp(arr)
    arr = arr / arr.sum(axis=1, keepdims=True)

    if output_tiled:
        arr = np.concatenate((arr, np.zeros((arr.shape[0], padded_num_cols - mat.shape[1]))), axis=1)
        arr = np.tile(arr, (1, mat.context.shape[1] // arr.shape[1]))
    res_mat = HedalMatrix.from_ndarray(mat.context, arr)
    # Only a tiled result carries the padded column count, matching the encrypted branch;
    # an untiled result keeps the column count of the array it was built from.
    if output_tiled:
        res_mat.num_cols = padded_num_cols
    return res_mat


# def softmax(mat: HedalMatrix, output_tiled: bool, exp_degree: int = 8, inverse_num_iter: int = 20) -> HedalVector:
#     """Approximate row-wise softmax function. Appropriate input range is [-10, 10]

#     Args:
#         mat (HedalVector): Matrix to be softmaxed.
#         output_tiled (bool): If True, the result is padded and tiled along rows.
#             For example, if the real result is [a, b, c], then the padded & tiled result is [a, b, c, 0, a, b, c, 0, ...].
#         exp_depth (int, optional): Degree of Taylor approximation for exponential. Defaults to 8.
#         inverse_num_iter (int, optional): Number of iterations for inverse. Defaults to 20.

#     Returns:
#         HedalMatrix: Softmaxed vector. If output_tiled is True, then the shape of the result is (mat.shape[0], padded_num_cols),
#             where the actual data is padded and tiled along rows.
#     """
#     mat_softmax = HedalMatrix(mat.context, shape=mat.shape, encrypted=mat.encrypted)
#     padded_num_cols = int(2 ** math.ceil(math.log2(mat.num_cols)))
#     for mat_row in mat:
#         assert mat_row.shape[1] == mat.shape[1], "Invalid shape of matrix"
#         mat_softmax_row = vop.softmax(
#             mat_row, output_tiled=output_tiled, exp_degree=exp_degree, inverse_num_iter=inverse_num_iter
#         )
#         mat_softmax.objects.append(mat_softmax_row)
#     if output_tiled:
#         mat_softmax.num_cols = padded_num_cols
#     return mat_softmax


def softmax_wide(
    mat: HedalMatrix, output_tiled: bool, exp_degree: int = 8, inverse_num_iter: int = 20, n: int = 3
) -> HedalMatrix:
    """Approximate row-wise softmax_wide function.
    The error of this function with the original softmax function is quite large, but there's no accuracy loss for logistic regression
    compared to the original softmax function.
    This has wider input range than the original softmax function, which is possible because of the domain extension.
    For a given depth (n), the appropriate input range is  [-1.5^n*10, 1.5^n*10].

    For plaintext input the exact (max-shifted) softmax is computed with NumPy and ``n`` is ignored.

    Args:
        mat (HedalMatrix): Matrix to be softmaxed.
        output_tiled (bool): If True, the result is padded and tiled along rows.
            For example, if the real result is [a, b, c], then the padded & tiled result is [a, b, c, 0, a, b, c, 0, ...].
        exp_degree (int, optional): Degree of Taylor approximation for exponential. Defaults to 8.
        inverse_num_iter (int, optional): Number of iterations for inverse. Defaults to 20.
        n (int, optional): Depth for domain extension function. Defaults to 3.

    Returns:
        HedalMatrix: Softmaxed matrix. If output_tiled is True, then the shape of the result is (mat.shape[0], padded_num_cols),
            where the actual data is padded and tiled along rows. Otherwise ``num_cols`` of the result is left as produced.
    """
    # Column count rounded up to the next power of two.
    padded_num_cols = int(2 ** math.ceil(math.log2(mat.num_cols)))

    if mat.encrypted:
        # TODO: use first-slot subtraction only when the number of columns (number of classes) is 'small'
        # first_slot_mask = HedalVector.mask(vec.context, vec.shape, index=0, axis=1)
        # sub_block = (vec * first_slot_mask).block_list[0]
        # for rot_idx in range(min(math.ceil(math.log2(vec.shape[1])), int(math.log2(vec.context.shape[1])))):
        #     sub_block += sub_block >> (1 << rot_idx)

        # assumes that the result of exponential has enough level so that we don't need additional bootstrapping
        # this is ensured when using FGb parameter.

        exp_mat = exp_wide(mat, degree=exp_degree, n=n)
        # NOTE(review): presumably clears the slots outside mat.shape so they do not
        # contribute to the row sums - confirm against HedalMatrix.from_ndarray padding.
        mask = HedalMatrix.from_ndarray(mat.context, array=np.ones(mat.shape))
        exp_mat *= mask

        exp_sum_mat = horizontal_sum(exp_mat, direction=0, fill=True)
        exp_sum_mat = inverse(exp_sum_mat, greater_than_one=True, inverse_num_iter=inverse_num_iter)
        res_mat = exp_mat * exp_sum_mat

        if output_tiled:
            # Repeated rotate-and-add with doubling offsets; only the first block of each
            # row is tiled.
            rot_num = int(math.log2(mat.context.shape[1] // padded_num_cols))
            for row in res_mat:
                for i in range(rot_num):
                    row[0] += row[0] >> ((1 << i) * padded_num_cols)
            res_mat.num_cols = padded_num_cols
        return res_mat

    arr = mat.to_ndarray()
    arr = arr - arr.max(axis=1, keepdims=True)
    arr = np.exp(arr)
    arr = arr / arr.sum(axis=1, keepdims=True)

    if output_tiled:
        arr = np.concatenate((arr, np.zeros((arr.shape[0], padded_num_cols - mat.shape[1]))), axis=1)
        arr = np.tile(arr, (1, mat.context.shape[1] // arr.shape[1]))
    res_mat = HedalMatrix.from_ndarray(mat.context, arr)
    # Only a tiled result carries the padded column count, matching the encrypted branch;
    # an untiled result keeps the column count of the array it was built from.
    if output_tiled:
        res_mat.num_cols = padded_num_cols
    return res_mat


# def softmax_wide(
#     mat: HedalMatrix, output_tiled: bool, exp_degree: int = 8, inverse_num_iter: int = 20, n: int = 3
# ) -> HedalMatrix:
#     """Approximate row-wise softmax_wide function.
#     The error of this function with the original softmax function is quite large, but there's no accuracy loss for logistic regression
#     compared to the original softmax function.
#     This has wider input range than the original softmax function, which is possible because of the domain extension.
#     For a given depth (n), the appropriate input range is  [-1.5^n*10, 1.5^n*10].

#     Args:
#         mat (HedalMatrix): Matrix to be softmaxed.
#         output_tiled (bool): If True, the result is padded and tiled along rows.
#             For example, if the real result is [a, b, c], then the padded & tiled result is [a, b, c, 0, a, b, c, 0, ...].
#         exp_depth (int, optional): Degree of Taylor approximation for exponential. Defaults to 8.
#         inverse_num_iter (int, optional): Number of iterations for inverse. Defaults to 20.
#         n (int, optional): Depth for domain extension function. Defaults to 3.

#     Returns:
#         HedalMatrix: Softmaxed vector. If output_tiled is True, then the shape of the result is (mat.shape[0], padded_num_cols),
#             where the actual data is padded and tiled along rows.
#     """
#     mat_softmax = HedalMatrix(mat.context, shape=mat.shape, encrypted=mat.encrypted)
#     padded_num_cols = int(2 ** math.ceil(math.log2(mat.num_cols)))
#     for mat_row in mat:
#         assert mat_row.shape[1] == mat.shape[1], "Invalid shape of matrix"
#         mat_softmax_row = vop.softmax_wide(
#             mat_row, output_tiled=output_tiled, exp_degree=exp_degree, inverse_num_iter=inverse_num_iter, n=n
#         )
#         mat_softmax.objects.append(mat_softmax_row)
#     if output_tiled:
#         mat_softmax.num_cols = padded_num_cols
#     return mat_softmax


def dot(m1: HedalMatrix, m2: Union[HedalMatrix, HedalVector], fill: bool = False) -> HedalMatrix:
    """Dot product of two matrices or matrix and vector.

    Every row of ``m1`` is combined through ``vop.dot`` with the matching row of ``m2``
    (matrix operand) or with ``m2`` itself (vector operand). Each resulting row is given
    the path ``res.path / str(idx)`` and appended to the result.

    Args:
        m1 (HedalMatrix): First matrix.
        m2 (Union[HedalMatrix, HedalVector]): Second matrix or vector.
        fill (bool, optional): Whether to fill the matrix with result or not. Defaults to False.

    Raises:
        ValueError: If the number of rows of the two matrices are not matched.
        TypeError: If the type of m2 is not HedalMatrix or HedalVector.
    Returns:
        HedalMatrix: Result matrix.
    """
    # Validate the operands before touching m2's attributes or allocating the result, so
    # an unsupported m2 raises the documented TypeError instead of an AttributeError.
    rhs_is_matrix = isinstance(m2, HedalMatrix)
    if rhs_is_matrix:
        if m1.num_rows != m2.num_rows:
            raise ValueError(
                f"Invalid dimension of operands: number of rows of m1({m1.num_rows}) and m2({m2.num_rows}) should be the same."
            )
    elif not isinstance(m2, HedalVector):
        raise TypeError(f"Invalid type of operands m2: {type(m2)}")

    res = HedalMatrix(m1.context, shape=m1.shape, encrypted=m1.encrypted or m2.encrypted)

    # Pair each row of m1 with its right-hand operand: the matching row of a matrix, or
    # the same vector for every row.
    operand_pairs = zip(m1, m2) if rhs_is_matrix else ((m1_row, m2) for m1_row in m1)
    for idx, (m1_row, rhs) in enumerate(operand_pairs):
        res_row = vop.dot(m1_row, rhs, fill=fill)
        res_row.path = res.path / str(idx)
        res.objects.append(res_row)
    return res


def sigmoid(mat: HedalMatrix, depth: int = 10) -> HedalMatrix:
    """Apply ``vop.sigmoid`` to each row of a matrix.

    Args:
        mat (HedalMatrix): Input matrix.
        depth (int, optional): Forwarded unchanged to ``vop.sigmoid``. Defaults to 10.

    Returns:
        HedalMatrix: Matrix with the same shape and encryption flag as ``mat``; row ``i``
            of the result is given the path ``<result path> / str(i)``.
    """
    activated = HedalMatrix(mat.context, shape=mat.shape, encrypted=mat.encrypted)

    for row_index, row in enumerate(mat):
        activated_row = vop.sigmoid(row, depth=depth)
        # Place the row beneath the result's own path before attaching it.
        activated_row.path = activated.path / str(row_index)
        activated.objects.append(activated_row)

    return activated


def _rot_mul_sum(A: HedalMatrix, B: HedalMatrix) -> HedalMatrix:
    """Placeholder for the kernel behind the encrypted path of ``matmul``.

    ``matmul`` calls this with ``diagonal_to_col(A)`` and ``diagonal_to_row(B)``.

    Args:
        A (HedalMatrix): Left operand.
        B (HedalMatrix): Right operand.

    Raises:
        NotImplementedError: Always; the kernel has not been written yet.
    """
    raise NotImplementedError(
        "_rot_mul_sum is not implemented; matmul is currently unavailable for encrypted operands"
    )
