import math

import numpy as np

from hedal.block import Block
from hedal.context import Context
from hedal.core.config import PackingType
from hedal.matrix.matrix import HedalMatrix


def add_bias_col(context: Context, mat: HedalMatrix) -> HedalMatrix:
    """Append a bias column of ones to the right-hand side of a matrix.

    Args:
        context (Context): Context passed to the mask and zero-block constructors;
            ``context.shape[1]`` is compared against the column count of ``mat``.
        mat (HedalMatrix): Input matrix.

    Returns:
        HedalMatrix: Sum of ``mat`` (padded with one extra block per row when
            needed) and the bias mask, with ``num_cols`` set to one more than
            the column count of ``mat``.
    """
    width = mat.shape[1]

    # Mask over the widened shape, selecting the column right after the last
    # existing one (index == current column count, axis=1).
    ones_col = HedalMatrix.mask(
        context,
        shape=(mat.shape[0], width + 1),
        index=width,
        axis=1,
        encrypted=mat.encrypted,
    )

    base = mat
    if width % context.shape[1] == 0:
        # The column count is an exact multiple of context.shape[1], so every
        # row of a copy gets one extra zero block before the mask is added --
        # presumably because the existing blocks have no free column left.
        base = mat.copy_memory()
        base.num_cols = mat.num_cols + 1
        for padded_row in base:
            extra_block = Block.zeros(context, encrypted=mat.encrypted, type=PackingType.MATRIX)
            padded_row.block_list.append(extra_block)

    result = base + ones_col
    result.num_cols = width + 1
    return result


def get_Z_from_X_y(context: Context, input: HedalMatrix, target: HedalMatrix, class_idx: int):
    """Scale each row of the bias-augmented input by the selected target column.

    Args:
        context (Context): Context used to build intermediate matrices;
            ``context.shape[1]`` sets the mask width and the number of
            rotate-and-add rounds.
        input (HedalMatrix): Feature matrix; a bias column is appended to it.
        target (HedalMatrix): Target matrix, multiplied by a mask that is 1 in
            column ``class_idx`` and 0 elsewhere.
        class_idx (int): Column of ``target`` to keep.

    Returns:
        The bias-augmented input, with every block of row ``i`` multiplied in
        place by the first block of row ``i`` of the processed target.
    """
    # Bias-augmented input; its blocks are scaled in place at the end.
    Z = add_bias_col(context, input)

    slot_count = context.shape[1]

    # Plain 0/1 selector with ones only in column class_idx.
    selector = np.zeros(shape=(input.num_rows, slot_count))
    selector[:, class_idx] = 1

    # Keep only the chosen column of the target, then rotate left by class_idx
    # (presumably moving that column to position 0 -- confirm against rot_left).
    class_target = (target * HedalMatrix.from_ndarray(context, selector)).rot_left(class_idx)

    # Rotate-and-add with shifts 1, 2, 4, ... for int(log2(slot_count)) rounds.
    # NOTE(review): this looks like it replicates the kept value across the
    # slots of the first block, assuming slot_count is a power of two -- confirm.
    for round_idx in range(int(math.log2(slot_count))):
        shift = 1 << round_idx
        for target_row in class_target:
            shifted = target_row[0] >> shift
            target_row[0] += shifted

    # Multiply every block of each Z row by the matching row's first target block.
    for row_idx, z_row in enumerate(Z):
        for z_block in z_row:
            z_block *= class_target[row_idx][0]
    return Z
