from __future__ import annotations

import os
import shutil
from numbers import Number
from pathlib import Path
from typing import List, Optional, Tuple, Union

import numpy as np
import pandas as pd

from hedal.block import Block
from hedal.context import Context
from hedal.core import config
from hedal.core.config import PackingType
from hedal.core.object import Object


class HedalVector(Object):
    """Vector whose data is stored as a list of ``PackingType.MATRIX`` blocks.

    The logical data has shape ``(num_rows, num_cols)``. It is split column-wise
    into ``len(self)`` blocks of ``self.block_shape[1]`` columns each (the last
    block is zero-padded on the right).
    """

    def __init__(
        self,
        context: Context,
        shape: Tuple[int, int] = (0, 0),
        block_list: Optional[List[Block]] = None,
        **kwargs,
    ):
        """
        Args:
            context (hedal.Context)
            shape (Tuple[int,int]): shape of the vector.
            block_list (Optional[List[hedal.Block]], optional): List of blocks to initialize vector.
                Defaults to None, meaning no blocks are set.
            **kwargs: forwarded to ``Object.__init__``. In this file, callers pass ``path``,
                ``encrypted`` and ``type``.
        """
        super().__init__(context, **kwargs)
        self.type: PackingType = PackingType.MATRIX
        self.shape = shape

        # ``None`` replaces the former shared mutable ``[]`` default; an empty list is still accepted.
        if block_list:
            self.set_block_list(block_list)

    @property
    def num_rows(self):
        """Number of logical rows (``shape[0]``)."""
        return self.shape[0]

    @property
    def num_cols(self):
        """Number of logical columns (``shape[1]``)."""
        return self.shape[1]

    @num_cols.setter
    def num_cols(self, value: int):
        self.shape = (self.num_rows, value)

    def copy_memory(self) -> HedalVector:
        """Copy vector to a new location on memory.

        Returns:
            HedalVector: Copy of vector with copied blocks, same shape and encryption status.
        """
        new_vector = HedalVector(self.context, shape=self.shape, encrypted=self.encrypted)
        new_vector.set_block_list([block.copy() for block in self])
        return new_vector

    def copy(self, dst_parent_path: Optional[Path] = None, **kwargs) -> HedalVector:
        """Copy the vector's directory on disk and load the copy.

        Args:
            dst_parent_path (Optional[Path]): Parent directory of the copy. If given, the copy goes to
                ``dst_parent_path / f"{self.name}({k})"`` for the smallest ``k >= recursive`` whose
                path does not exist yet. If falsy, ``config.temp_path()`` is used instead.
            **kwargs: ``recursive`` (int) - first suffix index to try. Defaults to 0.

        Returns:
            HedalVector: Vector loaded from the destination path.
        """
        if dst_parent_path:
            suffix = kwargs.get("recursive", 0)
            dst_path = Path(dst_parent_path) / f"{self.name}({suffix})"
            # Probe increasing suffixes until a free destination name is found.
            while dst_path.exists():
                suffix += 1
                dst_path = Path(dst_parent_path) / f"{self.name}({suffix})"
        else:
            dst_path = config.temp_path()

        # NOTE(review): if self.path does not exist nothing is copied, yet from_path still
        # tries to load blocks from dst_path - confirm this is intended.
        if self.path.exists():
            shutil.copytree(self.path, dst_path)
        return HedalVector.from_path(self.context, dst_path, self.shape, self.encrypted)

    def to_series(self) -> pd.Series:
        """Convert a single-row or single-column (decrypted) vector into a ``pd.Series``.

        Raises:
            TypeError: If both dimensions are larger than 1.
        """
        if (self.num_rows > 1) and (self.num_cols > 1):
            raise TypeError("Unsupported shape of vector", self.shape)
        return pd.Series(self.to_ndarray().squeeze())

    def save(self, dst_path: Optional[Union[str, Path]] = None) -> None:
        """Save object to the destination path.

        Directory structure:
            dst_path
            - block_0.bin
            - block_1.bin
            - ...

        Args:
            dst_path (Optional[Union[str, Path]], optional): Destination path. Defaults to None,
                which means ``self.path``.
        """
        if dst_path is None:
            dst_path = self.path
        if isinstance(dst_path, str):
            dst_path = Path(dst_path)
        if not dst_path.exists():
            os.makedirs(dst_path, mode=0o775, exist_ok=True)

        for idx, block in enumerate(self):
            # Keep the block file name that self.block_path would use, but place it under dst_path.
            block_path = dst_path / os.path.basename(self.block_path(idx))
            block.save(block_path)

    @staticmethod
    def from_path(context: Context, path: Path, shape: Tuple[int, int], encrypted: bool) -> HedalVector:
        """Load a vector whose blocks were saved under ``path``.

        Args:
            context (Context): Context of the vector.
            path (Path): Directory containing the block files.
            shape (Tuple[int, int]): Logical shape; determines how many blocks are loaded.
            encrypted (bool): Whether the stored blocks are encrypted.

        Returns:
            HedalVector: Loaded vector.
        """
        vector = HedalVector(context, path=path, shape=shape, encrypted=encrypted)
        for idx in range(len(vector)):
            block = Block(context, config.block_path(vector, idx), type=PackingType.MATRIX, encrypted=encrypted)
            block.load()
            vector.block_list.append(block)
        return vector

    @staticmethod
    def from_ndarray(context: Context, array: np.ndarray, path: Optional[Path] = None) -> HedalVector:
        """Build a plaintext vector from a 1-D or 2-D array.

        A 1-D array is tiled across ``context.shape[0]`` rows and gets the logical shape
        ``(1, len(array))``. A 2-D array keeps its own shape and may have at most
        ``context.shape[0]`` rows.

        Raises:
            TypeError: If ``array`` is not 1-D/2-D, or has too many rows.
        """
        if array.ndim > 2 or array.ndim < 1:
            raise TypeError("Unsupported dimension of array", array.ndim)

        if array.ndim == 1:
            array = np.tile(array, (context.shape[0], 1))
            shape = (1, array.shape[1])
        else:
            shape = array.shape
            if shape[0] > context.shape[0]:
                raise TypeError("Unsupported shape of array", array.shape)

        vector = HedalVector(context, path=path, shape=shape)
        vector.block_list = HedalVector.ndarray_to_blocks(array, vector)
        return vector

    @staticmethod
    def ndarray_to_blocks(array: np.ndarray, vector: HedalVector) -> List[Block]:
        """Split ``array`` column-wise into blocks of ``vector.block_shape[1]`` columns.

        The array is zero-padded on the right to a multiple of the block width.
        """
        block_list = []
        unit_size = vector.block_shape[1]
        if array.shape[1] % unit_size != 0:
            array = np.pad(array, ((0, 0), (0, unit_size - array.shape[1] % unit_size)), "constant")
        for block_idx in range(len(vector)):
            start = block_idx * unit_size
            end = (block_idx + 1) * unit_size
            block = Block.from_ndarray(
                vector.context, array[:, start:end], type="matrix", path=vector.block_path(block_idx)
            )
            block_list.append(block)
        return block_list

    def to_ndarray(self, complex: bool = False) -> np.ndarray:
        """Convert a decrypted vector into an array of shape ``(num_rows, num_cols)``.

        Args:
            complex (bool): Forwarded to ``Block.to_ndarray``. This name shadows the builtin
                but is kept for API compatibility.

        Raises:
            Exception: If the vector is still encrypted.
            ValueError: If the vector has columns but no blocks to read them from.
        """
        if self.encrypted:
            raise Exception("Do after decrypt")

        block_arrays = [block.to_ndarray(complex=complex).reshape(self.block_shape) for block in self]
        if not block_arrays:
            if self.num_cols == 0:
                # Nothing to concatenate: return an empty array with the right row count.
                return np.zeros((self.num_rows, 0), dtype=np.complex128 if complex else np.float64)
            raise ValueError("Vector has no blocks to convert", self.shape)

        # A single concatenate avoids re-copying the growing result once per block.
        array = np.concatenate(block_arrays, axis=1)
        return array[: self.num_rows, : self.num_cols]

    @staticmethod
    def mask(context: Context, shape: Tuple[int, int], index: int, axis: int, encrypted: bool = False) -> HedalVector:
        """Return a vector of zeros with a single row (``axis=0``) or column (``axis=1``) set to 1.

        Raises:
            ValueError: If ``axis`` is neither 0 nor 1.
        """
        array = np.zeros(shape=shape)
        if axis == 0:
            array[index] = 1
        elif axis == 1:
            array[:, index] = 1
        else:
            raise ValueError("Unsupported axis", axis)
        vector = HedalVector.from_ndarray(context, array)
        if encrypted:
            vector.encrypt()
        return vector

    def __len__(self) -> int:
        """Number of blocks, i.e. ``ceil(num_cols / block_shape[1])``."""
        unit = self.block_shape[1]
        # Integer ceiling division; avoids float rounding of np.ceil on large values.
        return int(-(-self.num_cols // unit))

    @staticmethod
    def zeros(
        context, shape: Union[int, Tuple[int, int]], path: Optional[Path] = None, encrypted: bool = False
    ) -> HedalVector:
        """Return a new vector of given shape, filled with zeros.

        Args:
            context (Context): Context of the vector.
            shape (Union[int, Tuple[int,int]]): shape of the vector. An int gives a 1-D array, which
                ``from_ndarray`` turns into a vector of shape ``(1, shape)``.
            path (Optional[Path], optional): path of the vector. Defaults to None.
            encrypted (bool, optional): status of the vector. Defaults to False.
        """
        array = np.zeros(shape)
        vector = HedalVector.from_ndarray(context, array, path=path)
        if encrypted:
            vector.encrypt()
        return vector

    @staticmethod
    def ones(
        context, shape: Union[int, Tuple[int, int]], path: Optional[Path] = None, encrypted: bool = False
    ) -> HedalVector:
        """Return a new vector of given shape, filled with ones.

        Args:
            context (Context): Context of the vector.
            shape (Union[int, Tuple[int, int]]): shape of the vector. An int gives a 1-D array, which
                ``from_ndarray`` turns into a vector of shape ``(1, shape)``.
            path (Optional[Path], optional): path of the vector. Defaults to None.
            encrypted (bool, optional): status of the vector. Defaults to False.
        """
        array = np.ones(shape)
        vector = HedalVector.from_ndarray(context, array, path=path)
        if encrypted:
            vector.encrypt()
        return vector

    def rot_up(self, rot_idx: int, in_place: bool = False) -> HedalVector:
        """Rotate rows of the vector upward (cyclically over ``context.shape[0]`` rows).

        Args:
            rot_idx (int): Rotation index
            in_place (bool): If True, rotate in place. Defaults to False.

        Returns:
            HedalVector: Rotated vector.

        Example (assuming ``context.shape[0] == 3``):
            >>> v = HedalVector.from_ndarray(context, array=np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
            >>> v.rot_up(1)
            HedalVector([[4, 5, 6], [7, 8, 9], [1, 2, 3]])
        """
        rot_idx = rot_idx % self.context.shape[0]
        if in_place:
            if rot_idx == 0:
                return self
            # NOTE(review): relies on Block implementing in-place <<=; rebinding the loop
            # variable alone would not update the vector's blocks.
            for block in self:
                block <<= rot_idx * self.context.shape[1]
            return self
        else:
            if rot_idx == 0:
                return self.copy_memory()
            # Shifting by whole rows of slots moves each row up by rot_idx.
            res_block_list = [block << (rot_idx * self.context.shape[1]) for block in self]
            return HedalVector(
                self.context, shape=self.shape, block_list=res_block_list, type=self.type, encrypted=self.encrypted
            )

    def rot_down(self, rot_idx: int, in_place: bool = False) -> HedalVector:
        """Rotate rows of the vector downward (cyclically over ``context.shape[0]`` rows).

        Args:
            rot_idx (int): Rotation index
            in_place (bool): If True, rotate in place. Defaults to False.

        Returns:
            HedalVector: Rotated vector.

        Example (assuming ``context.shape[0] == 3``):
            >>> v = HedalVector.from_ndarray(context, array=np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
            >>> v.rot_down(1)
            HedalVector([[7, 8, 9], [1, 2, 3], [4, 5, 6]])
        """
        rot_idx = rot_idx % self.context.shape[0]
        if in_place:
            if rot_idx == 0:
                return self
            # NOTE(review): relies on Block implementing in-place >>=.
            for block in self:
                block >>= rot_idx * self.context.shape[1]
            return self
        else:
            if rot_idx == 0:
                return self.copy_memory()
            res_block_list = [block >> (rot_idx * self.context.shape[1]) for block in self]
            return HedalVector(
                self.context, shape=self.shape, block_list=res_block_list, type=self.type, encrypted=self.encrypted
            )

    def rot_left(self, rot_idx: int, in_place: bool = False) -> HedalVector:
        """Rotate each row of the vector to the left (cyclically over ``context.shape[1]`` columns).

        Args:
            rot_idx (int): Rotation index
            in_place (bool): Not supported; True raises NotImplementedError. Defaults to False.

        Returns:
            HedalVector: Rotated vector.

        Raises:
            NotImplementedError: If ``num_cols > context.shape[1]`` or ``in_place`` is True.

        Example (assuming ``context.shape[1] == 3``):
            >>> v = HedalVector.from_ndarray(context, array=np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
            >>> v.rot_left(1)
            HedalVector([[2, 3, 1], [5, 6, 4], [8, 9, 7]])
        """
        if self.num_cols > self.context.shape[1]:
            raise NotImplementedError(
                "Rotate left is not implemented when number of columns is larger than self.context.shape[1]"
            )

        rot_idx = rot_idx % self.context.shape[1]
        if in_place:
            raise NotImplementedError
        else:
            if rot_idx == 0:
                return self.copy_memory()
            # Shifting the flat slots left by rot_idx pushes the first rot_idx entries of each
            # row onto the end of the previous row.
            res_block = self.block(0) << rot_idx
            # mask1 keeps entries that stayed in their own row.
            mask1_arr = np.ones(shape=self.context.shape)
            mask1_arr[:, -rot_idx:] = 0
            # mask2 selects the wrapped entries, which are then moved one row down to their own row.
            mask2_arr = np.ones(shape=self.context.shape)
            mask2_arr[:, :-rot_idx] = 0
            mask1 = Block.from_ndarray(self.context, mask1_arr)
            mask2 = Block.from_ndarray(self.context, mask2_arr)
            res_block = (res_block * mask1) + ((res_block * mask2) >> self.context.shape[1])
            return HedalVector(
                self.context, shape=self.shape, block_list=[res_block], type=self.type, encrypted=self.encrypted
            )

    def __neg__(self) -> HedalVector:
        """Return the block-wise negation as a new vector."""
        res_vector = HedalVector(self.context, shape=self.shape, encrypted=self.encrypted)

        for idx in range(len(res_vector)):
            block = -self.block(idx)
            block.path = res_vector.block_path(idx)
            res_vector.block_list.append(block)

        return res_vector

    def __add__(self, other) -> HedalVector:
        """Block-wise addition with a HedalVector, Block, or scalar.

        Raises:
            TypeError: For unsupported operand types.
        """
        if isinstance(other, HedalVector):
            res_vector = HedalVector(self.context, shape=self.shape, encrypted=self.encrypted or other.encrypted)

            for idx in range(len(res_vector)):
                block = self.block(idx) + other.block(idx)
                block.path = res_vector.block_path(idx)
                res_vector.block_list.append(block)

        elif isinstance(other, (Block, Number)):
            res_vector = HedalVector(self.context, shape=self.shape, encrypted=self.encrypted)

            for idx, self_block in enumerate(self):
                block = self_block + other
                block.path = res_vector.block_path(idx)
                res_vector.block_list.append(block)

        else:
            raise TypeError("Unsupported type", type(other))

        return res_vector

    def __iadd__(self, other) -> HedalVector:
        """In-place block-wise addition; the result is encrypted if either operand is."""
        if isinstance(other, HedalVector):
            for self_block, other_block in zip(self, other):
                self_block += other_block

        elif isinstance(other, (Block, Number)):
            for self_block in self:
                self_block += other

        else:
            raise TypeError("Unsupported type", type(other))

        if isinstance(other, (HedalVector, Block)):
            self.encrypted = self.encrypted or other.encrypted

        return self

    def __radd__(self, other) -> HedalVector:
        """Addition is commutative: ``other + self == self + other``."""
        return self + other

    def __sub__(self, other) -> HedalVector:
        """Block-wise subtraction of a HedalVector, Block, or scalar.

        Raises:
            TypeError: For unsupported operand types.
        """
        if isinstance(other, HedalVector):
            res_vector = HedalVector(self.context, shape=self.shape, encrypted=self.encrypted or other.encrypted)

            for idx in range(len(res_vector)):
                block = self.block(idx) - other.block(idx)
                block.path = res_vector.block_path(idx)
                res_vector.block_list.append(block)

        elif isinstance(other, (Block, Number)):
            res_vector = HedalVector(self.context, shape=self.shape, encrypted=self.encrypted)

            for idx, self_block in enumerate(self):
                block = self_block - other
                block.path = res_vector.block_path(idx)
                res_vector.block_list.append(block)

        else:
            raise TypeError("Unsupported type", type(other))

        return res_vector

    def __isub__(self, other) -> HedalVector:
        """In-place block-wise subtraction; the result is encrypted if either operand is."""
        if isinstance(other, HedalVector):
            for self_block, other_block in zip(self, other):
                self_block -= other_block

        elif isinstance(other, (Block, Number)):
            for self_block in self:
                self_block -= other

        else:
            raise TypeError("Unsupported type", type(other))

        if isinstance(other, (HedalVector, Block)):
            self.encrypted = self.encrypted or other.encrypted

        return self

    def __rsub__(self, other) -> HedalVector:
        """Compute ``other - self`` as ``(-self) + other``."""
        return -self + other

    def __mul__(self, other) -> HedalVector:
        """Block-wise (element-wise) multiplication with a HedalVector, Block, or scalar.

        Raises:
            TypeError: For unsupported operand types.
        """
        if isinstance(other, HedalVector):
            res_vector = HedalVector(self.context, shape=self.shape, encrypted=self.encrypted or other.encrypted)

            # When ``other is self`` this yields ``block * block`` for each block, so no special case is needed.
            for idx, (self_block, other_block) in enumerate(zip(self, other)):
                block = self_block * other_block
                block.path = res_vector.block_path(idx)
                res_vector.block_list.append(block)

        elif isinstance(other, (Block, Number)):
            res_vector = HedalVector(self.context, shape=self.shape, encrypted=self.encrypted)

            for idx, self_block in enumerate(self):
                block = self_block * other
                block.path = res_vector.block_path(idx)
                res_vector.block_list.append(block)

        else:
            raise TypeError("Unsupported type", type(other))

        return res_vector

    def __imul__(self, other) -> HedalVector:
        """Return ``self * other`` as a new vector (not modified in place)."""
        return self * other

    def __rmul__(self, other) -> HedalVector:
        """Multiplication is commutative: ``other * self == self * other``."""
        return self * other

    def i_mult(self) -> HedalVector:
        """Return a new vector whose blocks are ``block.i_mult()`` of this vector's blocks."""
        vec = self.copy_memory()
        for idx, block in enumerate(self):
            vec[idx] = block.i_mult()
        return vec

    def conjugate(self) -> HedalVector:
        """Return a new vector whose blocks are ``block.conjugate()`` of this vector's blocks."""
        vec = self.copy_memory()
        for idx, block in enumerate(self):
            vec[idx] = block.conjugate()
        return vec

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback) -> bool:
        """Save the vector when leaving a ``with`` block.

        Returns False so that exceptions raised inside the ``with`` body propagate. Returning
        True, as before, silently suppressed them.
        """
        self.save()
        return False
