from __future__ import annotations

import json
import shutil
from numbers import Number
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np
from pydantic import BaseModel

from hedal.block import Block
from hedal.context import Context
from hedal.core import config
from hedal.core.config import PackingType
from hedal.core.sequence import ObjectList
from hedal.matrix.vector import HedalVector


class HedalMatrix(ObjectList):
    """Two-dimensional matrix stored as a sequence of ``HedalVector`` objects.

    Rows are grouped into consecutive chunks of ``self.block_shape[0]`` rows and
    each chunk is held by one ``HedalVector`` (see ``ndarray_to_vectors`` and
    ``__len__``).

    NOTE(review): ``block_shape``, ``path``, ``metadata_path``, ``name``,
    ``objects``, ``encrypted``, ``description`` and ``encrypt()`` are not defined
    in this class; they are presumably provided by ``ObjectList`` -- confirm.
    """

    def __init__(self, context: Context, **kwargs):
        """Create a matrix.

        Args:
            context (Context): Context forwarded to the base class.
            **kwargs: Forwarded unchanged to the base class. ``shape`` (a
                ``(num_rows, num_cols)`` tuple) is also read here and defaults
                to ``(0, 0)``.
        """
        super().__init__(context, **kwargs)
        self.shape: Tuple[int, int] = kwargs.get("shape", (0, 0))
        self.type: PackingType = PackingType.MATRIX

    def __getitem__(self, idx: int) -> HedalVector:
        """Return the ``idx``-th item as resolved by the base class."""
        return super().__getitem__(idx)

    @property
    def num_rows(self) -> int:
        """Number of rows, i.e. ``shape[0]``."""
        return self.shape[0]

    @property
    def num_cols(self) -> int:
        """Number of columns, i.e. ``shape[1]``."""
        return self.shape[1]

    @num_cols.setter
    def num_cols(self, value: int):
        """Set the column count on the matrix and on every vector it holds."""
        self.shape = (self.num_rows, value)
        for vec in self:
            vec.num_cols = value

    def metadata(self):
        """Return a ``Metadata`` model describing the current state of the matrix."""
        return self.Metadata(
            num_rows=self.num_rows,
            num_cols=self.num_cols,
            encrypted=self.encrypted,
            type=self.type,
            description=self.description,
        )

    class Metadata(BaseModel):
        """Schema of the metadata produced by ``metadata()`` and read by ``load()``."""

        num_rows: int
        num_cols: int
        encrypted: bool
        type: PackingType
        description: str

    @staticmethod
    def mask(context: Context, shape: Tuple[int, int], index: int, axis: int, encrypted: bool = False) -> HedalMatrix:
        """Build a 0/1 matrix with a single row or column set to one.

        Args:
            context (Context): Context for the new matrix.
            shape (Tuple[int, int]): Shape of the mask.
            index (int): Index of the row (``axis == 0``) or column
                (``axis == 1``) to fill with ones.
            axis (int): ``0`` to select a row, ``1`` to select a column.
            encrypted (bool): If true, ``encrypt()`` is called on the result.

        Returns:
            HedalMatrix: The mask matrix.

        Raises:
            ValueError: If ``axis`` is neither 0 nor 1.
        """
        array = np.zeros(shape=shape)
        if axis == 0:
            array[index, :] = 1
        elif axis == 1:
            array[:, index] = 1
        else:
            raise ValueError(f"Axis must be 0 or 1, not {axis}")
        matrix = HedalMatrix.from_ndarray(context, array)
        if encrypted:
            matrix.encrypt()
        return matrix

    def load(self):
        """Restore the matrix state from the JSON file at ``self.metadata_path``.

        Reads ``num_rows``, ``num_cols``, ``encrypted``, ``type`` and
        ``description`` from the file, then rebuilds ``self.objects`` by loading
        one ``HedalVector`` from ``self.path / str(idx)`` for each index.
        """
        with open(self.metadata_path, "r") as m_file:
            m_info = json.load(m_file)
        # The shape must be set before the vectors are loaded: ``len(self)``
        # is derived from ``num_rows``.
        self.shape = (m_info["num_rows"], m_info["num_cols"])
        self.encrypted = m_info["encrypted"]
        self.type = PackingType(m_info["type"])
        self.description = m_info["description"]
        self.objects = [
            HedalVector.from_path(
                self.context, self.path / str(idx), (self.block_shape[0], self.num_cols), self.encrypted
            )
            for idx in range(len(self))
        ]

    def copy_memory(self) -> HedalMatrix:
        """Copy the matrix to a new location on memory.

        Returns:
            HedalMatrix: Copied matrix.
        """
        new_matrix = HedalMatrix(self.context, shape=self.shape, encrypted=self.encrypted)
        for row in self:
            new_matrix.objects.append(row.copy_memory())
        return new_matrix

    def copy(self, dst_parent_path: Optional[Path] = None, **kwargs) -> HedalMatrix:
        """Copy the matrix directory and load the copy as a new matrix.

        Args:
            dst_parent_path (Optional[Path]): Directory to copy into. The copy
                is placed at ``<dst_parent_path>/<name>(<k>)`` where ``k`` is
                the first counter value whose path does not exist yet. When
                omitted, ``config.temp_path()`` is used as the destination.
            **kwargs: ``recursive`` (int) sets the starting value of the
                counter ``k``; defaults to 0.

        Returns:
            HedalMatrix: Matrix loaded from the destination path.
        """
        if dst_parent_path:
            parent = Path(dst_parent_path)
            counter = kwargs.get("recursive", 0)
            dst_path = parent / f"{self.name}({counter})"
            # Probe successive suffixes with a loop rather than one recursive
            # call per collision, so many existing copies cannot exhaust the
            # recursion limit.
            while dst_path.exists():
                counter += 1
                dst_path = parent / f"{self.name}({counter})"
        else:
            dst_path = config.temp_path()

        if self.path:
            shutil.copytree(self.path, dst_path)
        return HedalMatrix.from_path(self.context, dst_path)

    def to_ndarray(self, complex: bool = False) -> np.ndarray:
        """Convert the matrix to a NumPy array of shape ``(num_rows, num_cols)``.

        Args:
            complex (bool): Forwarded to ``HedalVector.to_ndarray``.

        Returns:
            np.ndarray: Concatenation of the vectors' arrays along axis 0,
            trimmed to ``num_cols`` columns and ``num_rows`` rows.

        Raises:
            Exception: If the matrix is encrypted.
        """
        if self.encrypted:
            raise Exception("Do after decrypt")

        # The empty (0, num_cols) seed keeps the result two-dimensional even
        # when there are no vectors at all. All pieces are concatenated once
        # instead of growing the array on every iteration.
        pieces = [np.array([]).reshape(0, self.num_cols)]
        pieces.extend(vec.to_ndarray(complex=complex)[:, : self.num_cols] for vec in self)
        return np.concatenate(pieces, axis=0)[: self.num_rows]

    @staticmethod
    def from_path(context: Context, path: Path, encrypted: bool = False) -> HedalMatrix:
        """Create a matrix bound to ``path`` and populate it via ``load()``.

        Args:
            context (Context): Context for the matrix.
            path (Path): Directory of the stored matrix.
            encrypted (bool): Initial flag passed to the constructor; ``load()``
                then overwrites it with the value stored in the metadata file.

        Returns:
            HedalMatrix: The loaded matrix.
        """
        matrix = HedalMatrix(context, path=path, encrypted=encrypted)
        matrix.load()
        return matrix

    @staticmethod
    def from_ndarray(context: Context, array: np.ndarray, path: Optional[Path] = None) -> HedalMatrix:
        """Build a matrix from a NumPy array.

        Args:
            context (Context): Context for the matrix.
            array (np.ndarray): Source data; its ``shape`` becomes the matrix shape.
            path (Optional[Path]): Target directory. It must not already exist
                with content.

        Returns:
            HedalMatrix: Matrix whose vectors hold the rows of ``array``.

        Raises:
            Exception: If ``path`` exists and is a non-empty directory.
        """
        if path:
            path = Path(path)
            if path.exists() and next(path.iterdir(), None) is not None:
                raise Exception("Already exists:", path)

        matrix = HedalMatrix(context, path=path, shape=array.shape)
        matrix.objects = HedalMatrix.ndarray_to_vectors(array, matrix)
        return matrix

    @staticmethod
    def ndarray_to_vectors(array: np.ndarray, matrix: HedalMatrix) -> List[HedalVector]:
        """Split ``array`` row-wise into one ``HedalVector`` per chunk.

        Args:
            array (np.ndarray): Two-dimensional source data.
            matrix (HedalMatrix): Matrix supplying the context, the chunk
                height (``block_shape[0]``), the number of chunks
                (``len(matrix)``) and the parent path of the vectors.

        Returns:
            List[HedalVector]: Vectors in row order; vector ``i`` is created
            with path ``matrix.path / str(i)``.
        """
        rows_per_vector = matrix.block_shape[0]
        return [
            HedalVector.from_ndarray(
                matrix.context,
                array[idx * rows_per_vector : (idx + 1) * rows_per_vector, :],
                path=matrix.path / str(idx),
            )
            for idx in range(len(matrix))
        ]

    def __len__(self) -> int:
        """Number of vectors needed to hold ``num_rows`` rows (ceiling division)."""
        return int(np.ceil(self.num_rows / self.block_shape[0]))

    @staticmethod
    def _filled(context: Context, make_array, **kwargs) -> HedalMatrix:
        """Build a matrix from ``make_array(shape)``; shared by ``zeros`` and ``ones``."""
        array = make_array(kwargs.get("shape"))
        matrix = HedalMatrix.from_ndarray(context, array, path=kwargs.get("path"))
        if kwargs.get("encrypted"):
            matrix.encrypt()
        return matrix

    @staticmethod
    def zeros(context: Context, **kwargs) -> HedalMatrix:
        """Create a matrix filled with zeros.

        Args:
            context (Context): Context for the matrix.
            **kwargs: ``shape`` (passed to ``np.zeros``), optional ``path`` and
                optional ``encrypted`` (truthy to encrypt the result).

        Returns:
            HedalMatrix: The zero matrix.
        """
        return HedalMatrix._filled(context, np.zeros, **kwargs)

    @staticmethod
    def ones(context: Context, **kwargs) -> HedalMatrix:
        """Create a matrix filled with ones.

        Args:
            context (Context): Context for the matrix.
            **kwargs: ``shape`` (passed to ``np.ones``), optional ``path`` and
                optional ``encrypted`` (truthy to encrypt the result).

        Returns:
            HedalMatrix: The all-ones matrix.
        """
        return HedalMatrix._filled(context, np.ones, **kwargs)

    def _encrypted_with(self, other) -> bool:
        """Return the encryption flag of a result combining ``self`` with ``other``."""
        if isinstance(other, Number):
            # Plain numbers carry no encryption state.
            return self.encrypted
        return self.encrypted or other.encrypted

    def _collect(self, vectors, encrypted: bool) -> HedalMatrix:
        """Wrap ``vectors`` in a new matrix of the same shape, assigning each its path."""
        res_matrix = HedalMatrix(self.context, shape=self.shape, encrypted=encrypted)
        for idx, vec in enumerate(vectors):
            vec.path = res_matrix.path / str(idx)
            res_matrix.objects.append(vec)
        return res_matrix

    def __neg__(self) -> HedalMatrix:
        """Return a new matrix holding the negation of every vector."""
        # The negated vectors must be stored in the result; rebinding the loop
        # variable alone would leave the returned matrix empty.
        return self._collect((-vec for vec in self), self.encrypted)

    def __add__(self, other) -> HedalMatrix:
        """Return ``self + other`` as a new matrix.

        Args:
            other: A ``HedalMatrix`` (added vector by vector) or a
                ``HedalVector``, ``Block`` or ``Number`` (added to every vector).

        Raises:
            TypeError: If ``other`` has an unsupported type.
        """
        if isinstance(other, HedalMatrix):
            vectors = (self_vec + other_vec for self_vec, other_vec in zip(self, other))
        elif isinstance(other, (HedalVector, Block, Number)):
            vectors = (self_vec + other for self_vec in self)
        else:
            raise TypeError("Unsupported type", type(other))

        return self._collect(vectors, self._encrypted_with(other))

    def __iadd__(self, other) -> HedalMatrix:
        """Add ``other`` to every vector in place and return ``self``.

        Raises:
            TypeError: If ``other`` has an unsupported type.
        """
        if isinstance(other, HedalMatrix):
            for self_vec, other_vec in zip(self, other):
                self_vec += other_vec

        elif isinstance(other, (HedalVector, Block, Number)):
            for self_vec in self:
                self_vec += other

        else:
            raise TypeError("Unsupported type", type(other))

        if isinstance(other, (HedalMatrix, HedalVector, Block)):
            self.encrypted = self.encrypted or other.encrypted

        return self

    def __radd__(self, other) -> HedalMatrix:
        """Return ``other + self``; delegates to ``__add__``."""
        return self + other

    def __sub__(self, other) -> HedalMatrix:
        """Return ``self - other`` as a new matrix.

        Args:
            other: A ``HedalMatrix`` (subtracted vector by vector) or a
                ``HedalVector``, ``Block`` or ``Number`` (subtracted from every
                vector).

        Raises:
            TypeError: If ``other`` has an unsupported type.
        """
        if isinstance(other, HedalMatrix):
            vectors = (self_vec - other_vec for self_vec, other_vec in zip(self, other))
        elif isinstance(other, (HedalVector, Block, Number)):
            vectors = (self_vec - other for self_vec in self)
        else:
            raise TypeError("Unsupported type", type(other))

        return self._collect(vectors, self._encrypted_with(other))

    def __isub__(self, other) -> HedalMatrix:
        """Subtract ``other`` from every vector in place and return ``self``.

        Raises:
            TypeError: If ``other`` has an unsupported type.
        """
        if isinstance(other, HedalMatrix):
            for self_vec, other_vec in zip(self, other):
                self_vec -= other_vec

        elif isinstance(other, (HedalVector, Block, Number)):
            for self_vec in self:
                self_vec -= other

        else:
            raise TypeError("Unsupported type", type(other))

        if isinstance(other, (HedalMatrix, HedalVector, Block)):
            self.encrypted = self.encrypted or other.encrypted

        return self

    def __rsub__(self, other) -> HedalMatrix:
        """Return ``other - self``, computed as ``(-self) + other``."""
        return -self + other

    def __mul__(self, other) -> HedalMatrix:
        """Return the element-wise product ``self * other`` as a new matrix.

        Args:
            other: A ``HedalMatrix`` (multiplied vector by vector) or a
                ``HedalVector``, ``Block`` or ``Number`` (multiplied into every
                vector).

        Raises:
            TypeError: If ``other`` has an unsupported type.
        """
        if isinstance(other, HedalMatrix):
            if self is other:
                # Squaring: hand each vector to itself so the vector-level
                # multiply receives the very same object on both sides.
                vectors = (vec * vec for vec in self)
            else:
                vectors = (self_vec * other_vec for self_vec, other_vec in zip(self, other))
        elif isinstance(other, (HedalVector, Block, Number)):
            vectors = (self_vec * other for self_vec in self)
        else:
            raise TypeError("Unsupported type", type(other))

        return self._collect(vectors, self._encrypted_with(other))

    def __imul__(self, other) -> HedalMatrix:
        """Return ``self * other``.

        Unlike ``__iadd__``/``__isub__`` this does not mutate ``self``; the
        name on the left of ``*=`` is rebound to the new matrix.
        """
        return self * other

    def __rmul__(self, other) -> HedalMatrix:
        """Return ``other * self``; delegates to ``__mul__``."""
        return self * other
