from __future__ import annotations

import os
import shutil
import warnings
from numbers import Number
from pathlib import Path
from typing import Iterable, Optional, Tuple, Union

import numpy as np
import pandas as pd

from hedal.context import Context
from hedal.core.config import PackingType, temp_path


class Block:
    def __init__(
        self,
        context: Context,
        path: Optional[str] = None,
        encrypted: bool = False,
        data: Optional[Iterable[Number]] = None,
        type: PackingType = PackingType.FRAME,
    ):
        """

        Args:
            context (Context): Context of the block.
            path (Optional[str], Optional): Path to save the block. Defaults to None.
            encrypted (bool, Optional): Whether the block is encrypted or not.
            data (Optional[Iterable[Number]], Optional): Data to initialize block. Defaults to None.
            type (PackingType, Optional): Type of the block. Defaults to PackingType.FRAME.

        Raises:
            Exception: When path and data are both given.
        """
        self.context = context
        if path is None:
            self._path = str(temp_path() / "block_0.bin")
        else:
            self._path: str = str(path)

        if encrypted:
            self.data = context.heaan.Ciphertext(context.context)
        else:
            self.data = context.heaan.Message(context.log_slots)

        if not isinstance(type, PackingType):
            raise Exception("type must be a PackingType")
        self.type = type

        if data is not None:
            self.set_data(data)

    @property
    def encrypted(self) -> bool:
        return isinstance(self.data, self.context.heaan.Ciphertext)

    @property
    def level(self) -> int:
        return self.data.level

    @property
    def num_slots(self) -> int:
        return self.context.num_slots

    @property
    def log_slots(self) -> int:
        if self.encrypted:
            return self.data.log_slots
        else:
            return self.context.log_slots

    @log_slots.setter
    def log_slots(self, num: int) -> None:
        if self.encrypted:
            self.data.log_slots = num

    @property
    def shape(self) -> Tuple[int, int]:
        if self.type == PackingType.FRAME:
            return (self.context.num_slots,)
        elif self.type == PackingType.MATRIX:
            return self.context.shape
        else:
            raise TypeError("Invalid type")

    @property
    def path(self) -> str:
        return str(self._path)

    @path.setter
    def path(self, new_path: str) -> None:
        self._path = new_path[:]

    @property
    def encryptor(self):
        return self.context.encryptor

    @property
    def decryptor(self):
        return self.context.decryptor

    @property
    def homevaluator(self):
        return self.context.homevaluator

    def __getitem__(self, idx: int) -> Number:
        if self.encrypted:
            raise TypeError("cannot access data inside ciphertext")
        return self.data[idx]

    def __setitem__(self, idx: int, num: float) -> None:
        if self.encrypted:
            raise TypeError("cannot access data inside ciphertext")
        self.data[idx] = num

    def set_data(self, data: Iterable[Number] = []) -> None:
        if len(data) > self.context.num_slots:
            raise Warning("Some data would be lost")
        for idx, num in enumerate(data):
            self.data[idx] = num

    def level_down(self, target_level: int, inplace: bool = True) -> Block:
        if not self.encrypted:
            raise TypeError("Cannot level down messages.")
        if target_level > self.level:
            raise ValueError("Target level is higher than current level.")
        if inplace:
            self.homevaluator.level_down(self.data, target_level, self.data)
        else:
            new_block = self.copy()
            new_block.level_down(target_level, inplace=True)
            return new_block

    def save(self, dst_path: Optional[Union[str, Path]] = None) -> None:
        """Save a single block.

        Args:
            dst_path (Optional[Union[str, Path]], optional): Path to save the block.
                If it is None, then `self.path` will be used. Defaults to None.

        Raises:
            Exception: When path is not given.
        """
        if dst_path:
            if isinstance(dst_path, Path):
                dst_path = str(dst_path)
            self.data.save(dst_path)
        elif self.path:
            self.data.save(self.path)
        else:
            raise TypeError("Uninitialized path")

    def load(self) -> None:
        if not self.path:
            raise TypeError("Uninitialized path")
        self.data.load(self.path)

    def copy(self, path: Optional[str] = None) -> Block:
        block = Block(self.context, encrypted=self.encrypted, path=path)
        block.data = (
            self.context.heaan.Ciphertext(self.data) if self.encrypted else self.context.heaan.Message(self.data)
        )
        block.type = self.type
        if path:
            block.save()
        return block

    def move(self, path: str) -> Block:
        if not self.path:
            raise TypeError("Invalid function")
        shutil.move(self.path, path)
        self.path = path
        return self

    def remove(self) -> None:
        if self.path:
            os.remove(self.path)

    def encrypt(self, inplace: bool = True) -> Optional[Block]:
        if self.encrypted:
            raise TypeError("already encrypted")
        msg = self.data
        public_key = self.context.public_key

        if inplace:
            self.data = self.context.heaan.Ciphertext(self.context.context)
            self.encryptor.encrypt(msg, public_key, self.data)
            return None
        else:
            res = Block(self.context, encrypted=True, type=self.type)
            self.encryptor.encrypt(msg, public_key, res.data)
            return res

    def decrypt(self, inplace: bool = True) -> Optional[Block]:
        if not self.encrypted:
            raise TypeError("already decrypted")
        ctxt = self.data
        secret_key = self.context.secret_key

        if inplace:
            self.data = self.context.heaan.Message()
            self.decryptor.decrypt(ctxt, secret_key, self.data)
            return None
        else:
            res = Block(self.context, encrypted=False, type=self.type)
            self.decryptor.decrypt(ctxt, secret_key, res.data)
            return res

    @staticmethod
    def from_ndarray(context: Context, array: np.ndarray, type: str = "matrix", path: str = "") -> Block:
        """Build a block from a numpy array, flattened in numpy's default order.

        Args:
            context (Context): Context of the block.
            array (np.ndarray): Source array.
            type (str, optional): "matrix", "frame", or a PackingType. A falsy value
                means FRAME. "frame" emits a warning because this conversion targets
                matrices/vectors. Defaults to "matrix".
            path (str, optional): Path of the block. Defaults to "".

        Raises:
            TypeError: When `type` is a string other than "matrix" or "frame".

        Returns:
            Block: The new block.
        """
        packing = type or PackingType.FRAME
        if isinstance(packing, str):
            if packing not in ("matrix", "frame"):
                raise TypeError("Invalid type")
            if packing == "frame":
                warnings.warn("Block & numpy array conversion is for HedalMatrices and HedalVectors.")
            packing = PackingType.MATRIX if packing == "matrix" else PackingType.FRAME

        return Block(context, path, type=packing, data=array.flatten())

    def to_ndarray(self, complex: bool = False) -> np.ndarray:
        """Convert a block to a 2D numpy array of shape `self.context.shape`.

        Args:
            complex (bool, optional): Result will be complex array. Defaults to False.

        Raises:
            Exception: When the block is encrypted.

        Returns:
            np.ndarray: 2D numpy array.
        """
        if self.encrypted:
            raise Exception("Do after decrypt")

        array = np.array(self.data) if complex else np.array(self.data).real
        return array.reshape(self.shape)

    @staticmethod
    def from_series(context: Context, series: pd.Series, type: str = "frame", path: str = "") -> Block:
        """Build a block from the values of a pandas Series.

        Args:
            context (Context): Context of the block.
            series (pd.Series): Source series; its `.values` fill the slots.
            type (str, optional): "frame", "matrix", or a PackingType. A falsy value
                means FRAME. "matrix" emits a warning because this conversion targets
                frames/columns. Defaults to "frame".
            path (str, optional): Path of the block. Defaults to "".

        Raises:
            TypeError: When `type` is a string other than "matrix" or "frame".

        Returns:
            Block: The new block.
        """
        packing = type if type else PackingType.FRAME
        if isinstance(packing, str):
            if packing == "frame":
                packing = PackingType.FRAME
            elif packing == "matrix":
                warnings.warn("Block & pandas.Series conversion is for HedalFrames and Columns.")
                packing = PackingType.MATRIX
            else:
                raise TypeError("Invalid type")

        return Block(context, path, type=packing, data=series.values)

    def to_series(self, complex: bool = False) -> pd.Series:
        """Convert a block to a pandas Series.

        Args:
            complex (bool, optional): Result will be complex series. Defaults to False.

        Raises:
            Exception: When the block is encrypted.

        Returns:
            pd.Series: Series created from the block.
        """
        if self.encrypted:
            raise Exception("Do after decrypt")

        array = np.array(self.data) if complex else np.array(self.data).real
        return pd.Series(array)

    def __hash__(self) -> int:
        return hash(self.path)

    def __enter__(self) -> Block:
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback) -> bool:
        return True

    @staticmethod
    def zeros(context: Context, encrypted: bool = False, type: PackingType = PackingType.FRAME) -> Block:
        """Create a block with all `context.num_slots` slots set to 0, optionally encrypted."""
        zero_block = Block(context, data=np.zeros(context.num_slots), type=type)
        if not encrypted:
            return zero_block
        zero_block.encrypt()
        return zero_block

    @staticmethod
    def ones(context: Context, encrypted: bool = False, type: PackingType = PackingType.FRAME) -> Block:
        """Create a block with all `context.num_slots` slots set to 1, optionally encrypted."""
        one_block = Block(context, data=np.ones(context.num_slots), type=type)
        if not encrypted:
            return one_block
        one_block.encrypt()
        return one_block

    @staticmethod
    def identity(context: Context, encrypted: bool = False) -> Block:
        """Create a MATRIX block holding an identity matrix of shape `context.shape`.

        For non-square shapes, ones fill only the main diagonal, i.e. the first
        min(rows, cols) diagonal entries. Assumes `context.shape` is (rows, cols).
        """
        n_rows, n_cols = context.shape[0], context.shape[1]
        eye = np.eye(n_rows, n_cols)
        block = Block(context, data=eye.reshape(-1), type=PackingType.MATRIX)
        if encrypted:
            block.encrypt()
        return block

    @staticmethod
    def mask(context: Context, index: int, axis: int, encrypted: bool = False, type: Optional[str] = None,) -> Block:
        """Create a MATRIX block with ones on one row or column and zeros elsewhere.

        Args:
            context (Context): Context of the block.
            index (int): Row (axis=0) or column (axis=1) to fill with ones.
            axis (int): 0 for a row mask, 1 for a column mask.
            encrypted (bool, optional): Encrypt the mask. Defaults to False.
            type (Optional[str], optional): Packing type forwarded to `from_ndarray`.
                None means "matrix". Anything resolving to a non-MATRIX packing is
                rejected. Defaults to None.

        Raises:
            TypeError: When `axis` is not 0 or 1, or the packing type is not MATRIX.

        Returns:
            Block: The mask block.
        """
        array = np.zeros(context.shape)
        if axis == 0:
            array[index] = 1
        elif axis == 1:
            array[:, index] = 1
        else:
            raise TypeError(f"axis must be 0 or 1, got {axis}")
        # from_ndarray maps a falsy type to FRAME, which is rejected below, so the
        # old default of None always raised. Treat None as "matrix" instead.
        packing = "matrix" if type is None else type
        block = Block.from_ndarray(context, array.flatten(), type=packing)
        if block.type != PackingType.MATRIX:
            raise TypeError("Support only for MATRIX type")
        if encrypted:
            block.encrypt()
        return block

    def need_bootstrap(self, cost_per_iter: int = 2) -> bool:
        if self.encrypted:
            if self.context.context.is_bootstrappable_parameter:
                return self.level - cost_per_iter < self.context.min_level_for_bootstrap
            else:
                return False
        else:
            return False

    def __add__(self, other: Union[Block, Number]) -> Block:
        if isinstance(other, Block):
            if self.type != other.type:
                raise TypeError("Type does not match")
            res = Block(self.context, encrypted=self.encrypted or other.encrypted, type=self.type)

            if res.encrypted and not self.encrypted:
                self.homevaluator.add(other.data, self.data, res.data)
            else:
                self.homevaluator.add(self.data, other.data, res.data)

        elif isinstance(other, Number):
            res = Block(self.context, encrypted=self.encrypted, type=self.type)
            self.homevaluator.add(self.data, other, res.data)

        else:
            raise TypeError("Unsupported type:", type(other))
        return res

    def __iadd__(self, other: Union[Block, Number]) -> Block:
        if isinstance(other, Block):
            if self.type != other.type:
                raise TypeError("Type does not match")
            res_encrypted = self.encrypted or other.encrypted
            if res_encrypted:
                res = self.context.heaan.Ciphertext(self.context.context)
            else:
                res = self.context.heaan.Message(self.context.log_slots)

            if not self.encrypted:
                self.homevaluator.add(other.data, self.data, res)
            else:
                self.homevaluator.add(self.data, other.data, res)

            self.data = res

        elif isinstance(other, Number):
            self.homevaluator.add(self.data, other, self.data)

        else:
            raise TypeError("Unsupported type:", type(other))
        return self

    def __radd__(self, other: Union[Block, Number]) -> Block:
        return self + other

    def __sub__(self, other: Union[Block, Number]) -> Block:
        if isinstance(other, Block):
            if self.type != other.type:
                raise TypeError("Type does not match")
            res = Block(self.context, encrypted=self.encrypted or other.encrypted, type=self.type)

            if res.encrypted and not self.encrypted:
                self.homevaluator.sub(other.data, self.data, res.data)
                self.homevaluator.negate(res.data, res.data)
            else:
                self.homevaluator.sub(self.data, other.data, res.data)

        elif isinstance(other, Number):
            res = Block(self.context, encrypted=self.encrypted, type=self.type)
            self.homevaluator.sub(self.data, other, res.data)

        else:
            raise TypeError("Unsupported type:", type(other))
        return res

    def __isub__(self, other: Union[Block, Number]) -> Block:
        if isinstance(other, Block):
            if self.type != other.type:
                raise TypeError("Type does not match")
            res_encrypted = self.encrypted or other.encrypted
            if res_encrypted:
                res = self.context.heaan.Ciphertext(self.context.context)
            else:
                res = self.context.heaan.Message(self.context.log_slots)

            if not self.encrypted:
                self.homevaluator.sub(other.data, self.data, res)
                self.homevaluator.negate(res, res)
            else:
                self.homevaluator.sub(self.data, other.data, res)

            self.data = res

        elif isinstance(other, Number):
            self.homevaluator.sub(self.data, other, self.data)
        else:
            raise TypeError("Unsupported type:", type(other))
        return self

    def __rsub__(self, other: Union[Block, Number]) -> Block:
        return -self + other

    def __mul__(self, other: Union[Block, Number]) -> Block:
        if isinstance(other, Block):
            if self.type != other.type:
                raise TypeError("Type does not match")
            res = Block(self.context, encrypted=self.encrypted or other.encrypted, type=self.type)

            self.to_device()
            other.to_device()

            if res.encrypted and not self.encrypted:
                self.homevaluator.mult(other.data, self.data, res.data)
            else:
                self.homevaluator.mult(self.data, other.data, res.data)

            self.to_host()
            other.to_host()
            res.to_host()

        elif isinstance(other, Number):
            res = Block(self.context, encrypted=self.encrypted, type=self.type)

            self.to_device()
            self.homevaluator.mult(self.data, other, res.data)
            self.to_host()
            res.to_host()

        else:
            raise TypeError("Unsupported type:", type(other))
        return res

    def __imul__(self, other: Union[Block, Number]) -> Block:
        if isinstance(other, Block):
            if self.type != other.type:
                raise TypeError("Type does not match")
            res_encrypted = self.encrypted or other.encrypted
            if res_encrypted:
                res = self.context.heaan.Ciphertext(self.context.context)
            else:
                res = self.context.heaan.Message(self.context.log_slots)

            self.to_device()
            other.to_device()

            if not self.encrypted:
                self.homevaluator.mult(other.data, self.data, res)
            else:
                self.homevaluator.mult(self.data, other.data, res)

            self.to_host()
            other.to_host()
            if res_encrypted:
                res.to_host()
            self.data = res

        elif isinstance(other, Number):
            self.to_device()
            self.homevaluator.mult(self.data, other, self.data)
            self.to_host()
        else:
            raise TypeError("Unsupported type:", type(other))
        return self

    def __rmul__(self, other: Union[Block, Number]) -> Block:
        return self * other

    def __lshift__(self, rot_idx: int) -> Block:
        result = Block(self.context, encrypted=self.encrypted, type=self.type)

        self.to_device()
        self.homevaluator.left_rotate(self.data, rot_idx, result.data)
        self.to_host()
        result.to_host()

        return result

    def __ilshift__(self, rot_idx: int) -> Block:
        if self.encrypted:
            res = self.context.heaan.Ciphertext(self.context.context)
            if self.context.with_gpu:
                res.to_device()
                self.to_device()
        else:
            res = self.context.heaan.Message(self.context.log_slots)

        self.homevaluator.left_rotate(self.data, rot_idx, res)
        if self.encrypted and self.context.with_gpu:
            self.to_host()
            res.to_host()
        self.data = res
        return self

    def __rshift__(self, rot_idx: int) -> Block:
        result = Block(self.context, encrypted=self.encrypted, type=self.type)

        self.to_device()
        self.homevaluator.right_rotate(self.data, rot_idx, result.data)
        self.to_host()
        result.to_host()
        return result

    def __irshift__(self, rot_idx: int) -> Block:
        if self.encrypted:
            res = self.context.heaan.Ciphertext(self.context.context)
            if self.context.with_gpu:
                res.to_device()
                self.to_device()
        else:
            res = self.context.heaan.Message(self.context.log_slots)

        self.homevaluator.right_rotate(self.data, rot_idx, res)
        if self.encrypted and self.context.with_gpu:
            self.to_host()
            res.to_host()
        self.data = res
        return self

    def __neg__(self) -> Block:
        result = Block(self.context, encrypted=self.encrypted, type=self.type)

        self.to_device()
        self.homevaluator.negate(self.data, result.data)
        self.to_host()
        result.to_host()
        return result

    def i_mult(self) -> Block:
        """Multiply \sqrt{-1} to a block."""
        result = Block(self.context, encrypted=self.encrypted, type=self.type)
        self.homevaluator.i_mult(self.data, result.data)
        return result

    def conjugate(self) -> Block:
        """Complex conjugation of a block."""
        result = Block(self.context, encrypted=self.encrypted, type=self.type)
        self.homevaluator.conjugate(self.data, result.data)
        return result

    def rotate_sum(self) -> Block:
        result = self.copy()
        for idx in range(self.log_slots):
            rot_idx = 1 << idx
            tmp = result << rot_idx
            result += tmp
        return result

    def bootstrap(self, one_slot: bool = False, complex: bool = False) -> None:
        if self.encrypted:
            if one_slot:
                self.log_slots = 0
            self.to_device()
            self.homevaluator.bootstrap(self.data, self.data, complex)
            self.to_host()
            if one_slot:
                self.log_slots = self.context.log_slots

    @staticmethod
    def bootstrap_two_ctxts(block1, block2):
        """Bootstrap two REAL ciphertexts at once, in place.

        Raises:
            TypeError: When either block is not encrypted.
        """
        both_encrypted = block1.encrypted and block2.encrypted
        if not both_encrypted:
            raise TypeError(f"Both blocks must be encrypted. block1: {block1.encrypted}, block2: {block2.encrypted}")
        pair = (block1, block2)
        for blk in pair:
            blk.to_device()
        block1.homevaluator.bootstrap(block1.data, block2.data, block1.data, block2.data)
        for blk in pair:
            blk.to_host()

    def to_device(self) -> None:
        if self.encrypted and self.context.with_gpu:
            if self.data.device_type == self.context.heaan.DeviceType.CPU:
                self.data.to_device()

    def to_host(self) -> None:
        if self.encrypted and self.context.with_gpu:
            if self.data.device_type == self.context.heaan.DeviceType.GPU:
                self.data.to_host()

    def inverse(self, one_slot: bool = False, greater_than_one: bool = True, num_iter: Optional[int] = None) -> Block:
        if one_slot:
            self.log_slots = 0
        result = Block(self.context, encrypted=self.encrypted, type=self.type)

        self.to_device()

        if num_iter is None:
            self.context.heaan.math.approx.inverse(self.homevaluator, self.data, result.data, greater_than_one)
        else:
            self.context.heaan.math.approx.inverse(
                self.homevaluator, self.data, result.data, 2 ** (-18), num_iter,
            )

        self.to_host()
        if greater_than_one and result.need_bootstrap(5):
            result.bootstrap()
        if one_slot:
            self.log_slots = self.context.log_slots
            result.log_slots = self.context.log_slots
        return result

    def sqrt(self, one_slot: bool = False) -> Block:
        if one_slot:
            self.log_slots = 0
        result = Block(self.context, encrypted=self.encrypted, type=self.type)

        self.to_device()
        self.context.heaan.math.approx.sqrt(self.homevaluator, self.data, result.data)
        self.to_host()
        result.to_host()

        if one_slot:
            self.log_slots = self.context.log_slots
            result.log_slots = self.context.log_slots
        return result

    def sqrt_inv(self, one_slot: bool = False, greater_than_one: bool = False) -> Block:
        if one_slot:
            self.log_slots = 0
        result = Block(self.context, encrypted=self.encrypted, type=self.type)

        if self.need_bootstrap():
            self.bootstrap()

        self.to_device()
        self.context.heaan.math.approx.sqrt_inverse(self.homevaluator, self.data, result.data, greater_than_one)
        self.to_host()
        result.to_host()

        if result.need_bootstrap():
            result.bootstrap()

        if one_slot:
            self.log_slots = self.context.log_slots
            result.log_slots = self.context.log_slots
        return result

    def compare(self, other: Block) -> Block:
        result = Block(self.context, encrypted=self.encrypted, type=self.type)

        self.to_device()
        other.to_device()
        self.context.heaan.math.approx.compare(self.homevaluator, self.data, other.data, result.data, 6, 3)
        self.to_host()
        other.to_host()
        result.to_host()
        return result

    def sort(self, N: int, ascent_new: bool):
        """Sort the block's slots with `heaan.math.sort.sort` into a new block.

        The new block has the same encryption state and packing type as this one.

        NOTE(review): the meaning of `N`, `ascent_new`, and the trailing `False`
        argument comes from HEaaN's sort API; confirm there.
        NOTE(review): no `return result` is visible here. If the method really ends
        after `result.to_host()`, the sorted block is discarded and None is
        returned; confirm.
        """
        result = Block(self.context, encrypted=self.encrypted, type=self.type)

        self.to_device()
        self.context.heaan.math.sort.sort(self.homevaluator, self.data, result.data, N, ascent_new, False)
        self.to_host()
        result.to_host()
