import math
import os
from typing import Optional, Union
from hedal import Block, NumColumn
import numpy as np
from hedal.stats.preprocessing import norm_metadata


def make_norm_blocks(column: NumColumn) -> None:
    """Build the standalone norm blocks of ``column`` from its trailing norm slots.

    ``norm_col_idx`` is the number of blocks needed to hold the column's rows plus
    its norm slots, so block ``norm_col_idx - 1`` is the one containing the norm
    slots. That block is loaded and two masked copies of it are written:

    * block ``norm_col_idx``: only slot ``num_slots - 2 - i`` is kept, where ``i`` is
      the first entry of the norm "plus" index list or, if that list is empty/falsy,
      of the "minus" list (in which case the block is also negated);
    * block ``norm_col_idx + 1``: only the last slot (``num_slots - 1``) is kept.

    Any existing block files at ``norm_col_idx`` .. ``norm_col_idx + 3`` (the derived
    and cached norm blocks, e.g. the ones ``norm_abs_max`` and
    ``norm_sqrt_abs_max_inv`` save) are deleted first so stale results are not reused.

    Args:
        column: Column whose norm blocks are (re)built on disk.

    Raises:
        Exception: If neither a "plus" nor a "minus" abs-max index is recorded
            for ``column``.
    """
    num_slots = column.num_slots
    norm_col_idx = math.ceil((column.num_rows + norm_metadata.get_norm_size(column)) / num_slots)

    norm_block = Block.zeros(column.context, encrypted=column.encrypted)
    norm_block.path = column.data.block_path(norm_col_idx - 1)
    norm_block.load()

    # Clear every derived/cached norm block before writing fresh ones.
    for offset in range(4):
        stale_path = column.data.block_path(norm_col_idx + offset)
        if os.path.isfile(stale_path):
            os.remove(stale_path)

    # Block 1: keep only the slot of the signed abs-max entry.
    msg_block = Block.zeros(column.context, encrypted=False)
    to_minus = False
    norm_plus = norm_metadata.get_norm_plus(column)
    if norm_plus:
        msg_block[num_slots - 2 - norm_plus[0]] = 1
    else:
        norm_minus = norm_metadata.get_norm_minus(column)
        if not norm_minus:
            raise Exception("[Error] No absmax data.")
        to_minus = True
        msg_block[num_slots - 2 - norm_minus[0]] = 1
    # NOTE(review): the 0/1 mask is applied twice, which leaves the values unchanged;
    # presumably this is done to consume a multiplicative level so the result sits at
    # a particular level. Confirm before simplifying to a single multiplication.
    norm_block1 = norm_block * msg_block
    norm_block1 = norm_block1 * msg_block
    if to_minus:
        norm_block1 = -norm_block1
    norm_block1.path = column.data.block_path(norm_col_idx)
    norm_block1.save()

    # Block 2: keep only the last slot. Its stale file was already removed above.
    msg_block = Block.zeros(column.context, encrypted=False)
    msg_block[num_slots - 1] = 1
    norm_block2 = norm_block * msg_block
    norm_block2 = norm_block2 * msg_block
    norm_block2.path = column.data.block_path(norm_col_idx + 1)
    norm_block2.save()


def remove_norm_blocks(column: NumColumn) -> None:
    """Delete the norm-related files of ``column`` that exist on disk.

    The candidates are the ``origin_<k>.bin`` file directly under ``column.path``,
    where ``k`` is the index of the block holding the norm slots, and the four
    derived norm block files at indices ``norm_col_idx`` .. ``norm_col_idx + 3``.
    Missing files (or paths that are not regular files) are skipped.

    Args:
        column: Column whose norm files are removed.
    """
    first_norm_idx = math.ceil((column.num_rows + norm_metadata.get_norm_size(column)) / column.num_slots)

    doomed = [f"{column.path}/origin_{first_norm_idx - 1}.bin"]
    doomed.extend(column.data.block_path(first_norm_idx + shift) for shift in range(4))

    for file_path in doomed:
        if os.path.isfile(file_path):
            os.remove(file_path)


def add_norm(column: NumColumn, target_column: Optional[NumColumn] = None) -> None:
    """Copy the norm slots of ``column`` into the last block of the destination column.

    The destination is ``target_column`` when it is truthy, otherwise ``column``
    itself. Block ``norm_col_idx - 1`` of ``column`` (the block holding the norm
    slots) is loaded, and then:

    * if the norm slots spill into a block beyond the real data
      (``norm_col_idx != real_col_idx``), that block is saved unchanged to the same
      index of the destination;
    * otherwise everything but the trailing ``norm_size`` slots is masked to zero and
      the result is added onto the destination's block at that index.

    Finally, any derived norm block files at ``norm_col_idx`` .. ``norm_col_idx + 3``
    of the destination are deleted.

    Args:
        column: Column providing the norm slots and the norm metadata.
        target_column: Optional column receiving the norm slots.
    """
    norm_size = norm_metadata.get_norm_size(column)
    norm_col_idx = math.ceil((column.num_rows + norm_size) / column.num_slots)
    real_col_idx = math.ceil(column.num_rows / column.num_slots)
    dest_column = target_column if target_column else column
    dest_path = dest_column.data.block_path(norm_col_idx - 1)

    norm_block = Block.zeros(column.context, encrypted=column.encrypted)
    norm_block.path = column.data.block_path(norm_col_idx - 1)
    norm_block.load()

    if norm_col_idx != real_col_idx:
        # The norm slots live in their own block: copy it verbatim.
        norm_block.path = dest_path
        norm_block.save()
    else:
        # Keep only the trailing norm slots, dropping the real data sharing the block.
        np_message = np.append(
            np.full(column.context.num_slots - norm_size, 0),
            np.full(norm_size, 1),
        )
        norm_block *= Block(column.context, data=np_message, encrypted=False)
        # NOTE(review): when target_column is None, dest_path is the same file the norm
        # slots were just read from, so they are added onto themselves. Confirm callers
        # only take this path after those slots have been cleared/rescaled on disk.
        last_block = Block.zeros(column.context, encrypted=column.encrypted)
        last_block.path = dest_path
        last_block.load()
        last_block += norm_block
        last_block.save()

    # Invalidate the destination's derived/cached norm blocks.
    for offset in range(4):
        stale_path = dest_column.data.block_path(norm_col_idx + offset)
        if os.path.isfile(stale_path):
            os.remove(stale_path)


def mult_abs_max_inv(target_column: Union[NumColumn, Block], basic_column: NumColumn) -> None:
    """Multiply ``target_column`` twice by ``norm_sqrt_abs_max_inv(basic_column)``.

    By its name the factor block holds ``1 / sqrt(abs max)``, so applying it twice
    presumably divides the data by the column's absolute maximum (confirm against
    ``norm_sqrt_abs_max_inv``). For a ``NumColumn`` every block in ``.data`` is
    processed; any other value is treated as a single ``Block``. Each block is
    bootstrapped whenever ``need_bootstrap`` reports it lacks the required depth
    (1 before each multiplication, 4 afterwards).

    The function returns ``None``, so the result is only visible to callers if
    ``Block``'s ``*=`` mutates the block in place.

    Args:
        target_column: Column or single block to scale.
        basic_column: Column whose norm blocks provide the scaling factor.
    """
    sqrt_abs_max_inv_block = norm_sqrt_abs_max_inv(basic_column)
    # isinstance (not an exact type check) so NumColumn subclasses are iterated
    # block-by-block instead of being treated as a single Block.
    data_blocks = target_column.data if isinstance(target_column, NumColumn) else (target_column,)
    for data_block in data_blocks:
        if data_block.need_bootstrap(1):
            data_block.bootstrap()
        data_block *= sqrt_abs_max_inv_block
        if data_block.need_bootstrap(1):
            data_block.bootstrap()
        data_block *= sqrt_abs_max_inv_block
        if data_block.need_bootstrap(4):
            data_block.bootstrap()


def mult_abs_max(target_column: Union[NumColumn, Block], basic_column: NumColumn) -> None:
    """Multiply ``target_column`` by the abs-max block of ``basic_column``.

    The factor comes from ``norm_abs_max(basic_column)``. For a ``NumColumn`` every
    block in ``.data`` is processed; any other value is treated as a single
    ``Block``. A block is bootstrapped first if ``need_bootstrap(5)`` says it lacks
    the required depth.

    The function returns ``None``, so the result is only visible to callers if
    ``Block``'s ``*=`` mutates the block in place.

    Args:
        target_column: Column or single block to scale.
        basic_column: Column whose norm blocks provide the scaling factor.
    """
    abs_max_block = norm_abs_max(basic_column)
    # isinstance (not an exact type check) so NumColumn subclasses are iterated
    # block-by-block instead of being treated as a single Block.
    data_blocks = target_column.data if isinstance(target_column, NumColumn) else (target_column,)
    for data_block in data_blocks:
        if data_block.need_bootstrap(5):
            data_block.bootstrap()
        data_block *= abs_max_block


def norm_abs_max(column: NumColumn) -> Block:
    """Return the abs-max block of ``column``, computing and caching it on first use.

    If block file ``norm_col_idx + 3`` exists, it is loaded and returned. Otherwise
    the single-slot block at ``norm_col_idx`` (written by ``make_norm_blocks``) is
    loaded and passed through ``rotate_sum``, which presumably spreads that slot's
    value across the block. The result is saved to ``norm_col_idx + 3`` and the
    source file is deleted.

    Args:
        column: Column whose abs-max block is requested.

    Returns:
        The loaded or freshly computed abs-max block.
    """
    result = Block.zeros(column.context, encrypted=column.encrypted)
    base_idx = math.ceil((column.num_rows + norm_metadata.get_norm_size(column)) / column.num_slots)
    cache_path = column.data.block_path(base_idx + 3)

    if os.path.isfile(cache_path):
        result.path = cache_path
        result.load()
        return result

    source_path = column.data.block_path(base_idx)
    result.path = source_path
    result.load()
    result = result.rotate_sum()
    result.path = cache_path
    result.save()
    os.remove(source_path)
    return result


def norm_sqrt_abs_max_inv(column: NumColumn) -> Block:
    """Return the inverse-square-root abs-max block of ``column``, building it if needed.

    Resolution order:

    1. If block file ``norm_col_idx + 2`` exists, it is loaded and returned.
    2. For the ``"basic"`` norm type, the single-slot block at ``norm_col_idx + 1``
       is loaded and passed through ``rotate_sum``. The result is saved to
       ``norm_col_idx + 2`` and the source file is deleted.
    3. For any other norm type, the block is derived from ``norm_abs_max`` as
       ``sqrt_inv(abs_max / f) / sqrt(f)``, where ``f`` is the norm factor. The
       pre-division presumably keeps the input in the range ``sqrt_inv`` expects
       (confirm). The result is then bootstrapped and saved to ``norm_col_idx + 2``.

    Args:
        column: Column whose norm blocks are used.

    Returns:
        The loaded or freshly computed block.
    """
    base_idx = math.ceil((column.num_rows + norm_metadata.get_norm_size(column)) / column.num_slots)
    cached_path = column.data.block_path(base_idx + 2)

    if os.path.isfile(cached_path):
        cached = Block.zeros(column.context, encrypted=column.encrypted)
        cached.path = cached_path
        cached.load()
        return cached

    if norm_metadata.get_norm_type(column) == "basic":
        source_path = column.data.block_path(base_idx + 1)
        result = Block.zeros(column.context, encrypted=column.encrypted)
        result.path = source_path
        result.load()
        result = result.rotate_sum()
        result.path = cached_path
        result.save()
        os.remove(source_path)
        return result

    result = norm_abs_max(column)
    result *= 1 / norm_metadata.get_norm_factor(column)
    result = result.sqrt_inv(one_slot=True, greater_than_one=True)
    result *= 1 / math.sqrt(norm_metadata.get_norm_factor(column))
    result.bootstrap(one_slot=True)
    result.path = cached_path
    result.save()
    return result


def normalize(column: NumColumn, save: bool = False) -> Optional[Union[NumColumn, Block]]:
    """Rebuild the norm blocks of ``column`` and return it scaled by ``mult_abs_max_inv``.

    Args:
        column: Column to normalize. Its norm blocks are regenerated on disk via
            ``make_norm_blocks``.
        save: If True, the scaling is applied to ``column.copy(column.path.parent)``.
            Otherwise it is applied to a column reopened from ``column.path`` with
            ``NumColumn.from_path``.

    Returns:
        The column the scaling was applied to.
    """
    make_norm_blocks(column)
    scaled = column.copy(column.path.parent) if save else NumColumn.from_path(column.context, column.path)
    mult_abs_max_inv(scaled, column)
    return scaled
