from __future__ import annotations

import json
import math
import os
import shutil
from abc import ABC, abstractclassmethod, abstractmethod
from pathlib import Path
from typing import Dict, ItemsView, KeysView, List, Optional, Tuple, Union, ValuesView

import numpy as np
from pydantic import BaseModel

from hedal.context import Context
from hedal.core import config
from hedal.core.config import PackingType, ScaleType
from hedal.matrix.matrix import HedalMatrix
from hedal.ml.scaler import Scaler


class Dataset(ABC):
    """Base container for a dataset made of named ``HedalMatrix`` objects plus a ``Scaler``.

    The matrices live in ``self.objects`` (keyed by name), the feature scaler in
    ``self.scaler``, and a JSON metadata file describes encryption status, shape,
    classes and a free-text description. Concrete subclasses decide which matrices
    are stored and implement ``from_ndarray``.
    """

    def __init__(
        self,
        context: Context,
        path: Optional[Union[str, Path]],
        encrypted: bool = False,
        shape: Tuple[int, int] = (0, 0),
        description: str = "",
    ):
        """Create an (initially empty) dataset.

        Args:
            context (hedal.Context): context shared by every matrix of the dataset.
            path (Optional[Union[str, Path]]): path of the dataset. If None, a
                temporary path from ``config.temp_path()`` is used instead.
            encrypted (bool): status of the dataset. If True, ``encrypt`` is called
                immediately.
            shape (Tuple[int, int]): shape of the dataset, (num_data, num_features).
            description (str): description of the dataset.

        Notes:
            For Dataset.shape = (num_data, num_features), an OVR class matrix
            Dataset[class_idx] has shape (num_data, size=num_features+1)
            because of the extra bias column.
        """
        self.context = context
        # Path(None) raises TypeError, which made ``from_ndarray(..., path=None)``
        # unusable. Fall back to a temporary location, as ``copy`` does when no
        # destination is given.
        self.path = Path(path) if path is not None else Path(config.temp_path())
        self.encrypted = encrypted
        self.type = PackingType.MATRIX
        self.shape = shape
        self.scaler: Optional[Scaler] = None
        self.description = description
        self.objects: Dict[str, HedalMatrix] = {}
        self.classes = []

        if encrypted:
            self.encrypt()

    def keys(self) -> KeysView[str]:
        """Names of the stored matrices."""
        return self.objects.keys()

    def values(self) -> ValuesView[HedalMatrix]:
        """The stored matrices."""
        return self.objects.values()

    def items(self) -> ItemsView[str, HedalMatrix]:
        """(name, matrix) pairs of the stored matrices."""
        return self.objects.items()

    @property
    def block_shape(self) -> Tuple[int, int]:
        """Shape of one block, taken from the context."""
        return self.context.shape

    @property
    def num_data(self) -> int:
        return self.shape[0]

    @property
    def num_features(self) -> int:
        return self.shape[1]

    @property
    def num_classes(self) -> int:
        return len(self.classes)

    @property
    def size(self) -> int:
        """Number of features plus one (room for the bias column)."""
        return self.num_features + 1

    @property
    def scale_type(self) -> ScaleType:
        """Scale type of the attached scaler (AttributeError if no scaler is set)."""
        return self.scaler.scale_type

    @property
    def scaler_path(self) -> Path:
        return config.scaler_path(self)

    @property
    def parent_path(self) -> Path:
        return config.parent_path(self)

    @property
    def name(self) -> str:
        return config.name(self)

    @property
    def metadata_path(self) -> Path:
        return config.metadata_path(self)

    class Metadata(BaseModel):
        encrypted: bool
        classes: List[str]
        num_data: int
        num_features: int
        description: str

    def metadata(self) -> Dataset.Metadata:
        """Return the serializable metadata of this dataset."""
        return self.Metadata(
            encrypted=self.encrypted,
            description=self.description,
            num_data=self.num_data,
            num_features=self.num_features,
            # ``classes`` may hold raw labels (e.g. ints from ``np.unique(y)``),
            # but the field is ``List[str]``. Convert explicitly instead of
            # relying on pydantic's implicit number-to-str coercion.
            classes=[str(c) for c in self.classes],
        )

    def save(self, dst_path: Optional[Union[str, Path]] = None, target_level: Optional[int] = None) -> None:
        """Save Dataset to the destination path.

        Directory structure:

        for ovr:
            dst_path
            - metadata.json
            - class_0 (HedalMatrix)
                - ...
            - class_1 (HedalMatrix)
                - ...
            - ...
            - scale (Scaler)
                - ...

        for multinomial:
            dst_path
            - metadata.json
            - input (HedalMatrix)
                - ...
            - target (HedalMatrix)
                - ...
            - scale (Scaler)
                - ...

        Args:
            dst_path (Optional[Union[str, Path]], optional): Destination path. Defaults to None,
                meaning ``self.path``.
            target_level (Optional[int], optional): Target level. If it is not None, then the dataset
                matrices' level will be decreased to the given level before save. Defaults to None.
                NOTE(review): ``level_down``'s return value is ignored, so this appears to lower
                the level of the in-memory matrices/scaler too — confirm.
        """
        # Accept plain strings as the signature advertises; ``str / str`` would fail.
        dst_path = self.path if dst_path is None else Path(dst_path)
        os.makedirs(dst_path, mode=0o775, exist_ok=True)

        with open(dst_path / config._metadata_file_name, "w") as m_file:
            m_file.write(self.metadata().json())

        for key, matrix in self.objects.items():
            if target_level is not None:
                matrix.level_down(target_level)
            matrix.save(dst_path / key)

        if target_level is not None:
            self.scaler.level_down(target_level)
        self.scaler.save(dst_path / "scale")

    def load(self):
        """Load metadata and scaler from ``self.path``. Subclasses load their matrices."""
        with open(self.metadata_path, "r") as m_file:
            m_info = json.load(m_file)
        self.encrypted = m_info["encrypted"]
        self.description = m_info["description"]
        self.shape = (m_info["num_data"], m_info["num_features"])
        self.scaler = Scaler.from_path(self.context, self.scaler_path)
        self.classes = m_info["classes"]

    def bootstrap(self):
        """Call ``bootstrap`` on every stored matrix and on every scaler entry."""
        for key in self.keys():
            self[key].bootstrap()
        for key in self.scaler.keys():
            self.scaler[key].bootstrap()

    @staticmethod
    def from_path(context: Context, path: Path) -> Dataset:
        # NOTE(review): Dataset declares an abstract ``from_ndarray``, so
        # ``Dataset(...)`` raises TypeError; concrete subclasses override this.
        dataset = Dataset(context, path=path)
        dataset.load()
        return dataset

    def copy(self, dst_parent_path: Optional[Path] = None, **kwargs) -> Dataset:
        """Copy the on-disk dataset and load the copy.

        Args:
            dst_parent_path (Optional[Path]): parent directory of the copy. The copy is
                named ``"{name}({i})"`` with the first ``i`` (starting at
                ``kwargs["recursive"]``, default 0) whose directory does not exist yet.
                If falsy, a temporary path from ``config.temp_path()`` is used.

        Returns:
            Dataset: the copied dataset, loaded with ``type(self).from_path``.
        """
        if dst_parent_path:
            index = kwargs.get("recursive", 0)
            dst_path = Path(dst_parent_path) / f"{self.name}({index})"
            # Iterate instead of recursing so many existing copies cannot hit the
            # recursion limit.
            while dst_path.exists():
                index += 1
                dst_path = Path(dst_parent_path) / f"{self.name}({index})"
        else:
            dst_path = config.temp_path()
        shutil.copytree(self.path, dst_path)
        # ``Dataset`` itself is abstract, so load the copy with the concrete class.
        return type(self).from_path(self.context, dst_path)

    @staticmethod
    @abstractmethod
    def from_ndarray(
        context: Context,
        X: np.ndarray,
        y: np.ndarray,
        multiclass: bool,
        scale_type: ScaleType = ScaleType.NORMAL,
        path: Optional[Path] = None,
    ) -> Dataset:
        """Build a dataset from a feature array ``X`` and a label array ``y``."""

    def encrypt(self) -> None:
        """Encrypt every stored matrix and the scaler, then mark the dataset encrypted.

        Placeholder ``None`` entries and a missing scaler are skipped. ``__init__``
        calls this before anything has been attached.
        """
        for matrix in self.objects.values():
            if matrix is not None:
                matrix.encrypt()
        if self.scaler is not None:
            self.scaler.encrypt()
        self.encrypted = True

    def decrypt(self) -> None:
        """Decrypt every stored matrix and the scaler, then mark the dataset decrypted."""
        for matrix in self.objects.values():
            if matrix is not None:
                matrix.decrypt()
        if self.scaler is not None:
            self.scaler.decrypt()
        self.encrypted = False

    def __len__(self) -> int:
        """Number of row blocks needed to hold ``num_data`` rows (ceiling division)."""
        return -(-self.num_data // self.block_shape[0])

    def __getitem__(self, key: str) -> HedalMatrix:
        return self.objects[key]

    def __setitem__(self, key: str, value: HedalMatrix) -> None:
        # Remove first so a replaced key moves to the end of the insertion order.
        self.objects.pop(key, None)
        self.objects[key] = value


class DatasetOVR(Dataset):
    """One-vs-rest dataset: one ``HedalMatrix`` per class.

    The matrix for a class holds ``[X_scale | 1]`` with every row multiplied by
    that row's +1/-1 label for the class.
    """

    def __init__(
        self,
        context: Context,
        path: Union[str, Path],
        encrypted: bool = False,
        shape: Tuple[int, int] = (0, 0),
        description: str = "",
    ):
        """Create an (initially empty) one-vs-rest dataset.

        Args:
            context (hedal.Context): context shared by every matrix of the dataset.
            path (Union[str, Path]): path of the dataset.
            encrypted (bool): status of the dataset.
            shape (Tuple[int, int]): shape of the dataset, (num_data, num_features).
            description (str): description of the dataset.

        Notes:
            For Dataset.shape = (num_data, num_features),
            Dataset[class_idx].shape would be (num_data, size=num_features+1)
            because of the space of bias_col.
        """
        super().__init__(
            context, path=path, shape=shape, encrypted=encrypted, description=description,
        )

    def class_path(self, key: str) -> Path:
        """Directory in which the matrix for class ``key`` is stored."""
        return self.path / key

    def load(self):
        """Load metadata and scaler, then one matrix per entry of ``self.classes``."""
        super().load()
        for key in self.classes:
            self.objects[key] = HedalMatrix.from_path(self.context, self.class_path(key))

    @staticmethod
    def from_path(context: Context, path: Path) -> DatasetOVR:
        """Load a one-vs-rest dataset previously saved at ``path``."""
        dataset = DatasetOVR(context, path=path)
        dataset.load()
        return dataset

    def copy(self, dst_parent_path: Optional[Path] = None, **kwargs) -> DatasetOVR:
        """Copy the on-disk dataset and load the copy.

        Args:
            dst_parent_path (Optional[Path]): parent directory of the copy. The copy is
                named ``"{name}({i})"`` with the first ``i`` (starting at
                ``kwargs["recursive"]``, default 0) whose directory does not exist yet.
                If falsy, a temporary path from ``config.temp_path()`` is used.

        Returns:
            DatasetOVR: the copied dataset.
        """
        if dst_parent_path:
            index = kwargs.get("recursive", 0)
            dst_path = Path(dst_parent_path) / f"{self.name}({index})"
            # Loop instead of recursion so many existing copies cannot hit the
            # recursion limit.
            while dst_path.exists():
                index += 1
                dst_path = Path(dst_parent_path) / f"{self.name}({index})"
        else:
            dst_path = config.temp_path()
        shutil.copytree(self.path, dst_path)
        return DatasetOVR.from_path(self.context, dst_path)

    @staticmethod
    def from_ndarray(
        context: Context,
        X: np.ndarray,
        y: np.ndarray,
        multiclass: bool,
        scale_type: ScaleType = ScaleType.NORMAL,
        path: Optional[Path] = None,
    ) -> DatasetOVR:
        """Return a dataset from ndarray.

        Args:
            context (Context): Context of the dataset.
            X (np.ndarray): ndarray of features, shape (num_data, num_features).
            y (np.ndarray): ndarray of integer labels.
            multiclass (bool): True if the dataset is multiclass. False if the dataset is binary.
            scale_type (ScaleType, optional): type of scale. Defaults to ScaleType.NORMAL.
            path (Optional[Path], optional): path of the dataset. Defaults to None, in which
                case a temporary path from ``config.temp_path()`` is used.

        Raises:
            TypeError: if ``y`` is not of an integer dtype, or if a binary ``y`` does not
                contain exactly the labels -1 and 1.
        """
        # ``y.dtype != int`` only accepted the platform default int (e.g. it rejected
        # int32 labels on Linux); accept any integer dtype instead.
        if not np.issubdtype(y.dtype, np.integer):
            raise TypeError("Unsupported dtype of y", y.dtype)
        if path is None:
            path = config.temp_path()

        dataset = DatasetOVR(context, path=path, shape=X.shape)
        X_scale, scaler = Scaler.fit(context, dataset.scaler_path, scale_type, X)
        dataset.scaler = scaler

        if multiclass:
            classes = np.unique(y)
            for key in classes:
                # +1 for rows of this class, -1 for all other rows.
                y_bin = 2 * (y == key).astype(int) - 1
                dataset._set_matrix(X_scale, y_bin, str(key))
            # NOTE(review): matrices are keyed by ``str(key)`` while ``classes`` keeps the
            # raw labels; after save/load the classes come back as strings.
            dataset.classes = classes.tolist()
        else:
            if np.unique(y).tolist() != [-1, 1]:
                raise TypeError("y must be a ndarray containing -1 and 1")

            dataset._set_matrix(X_scale, y, "0")
            dataset.classes = ["0"]
        return dataset

    def _set_matrix(self, X_scale: np.ndarray, y: np.ndarray, key: str) -> None:
        """Store ``[X_scale | 1]`` with row i multiplied by ``y[i]`` under ``key``."""
        bias_col = np.expand_dims(np.ones(self.num_data), axis=1)
        Z_scale = np.concatenate((X_scale, bias_col), axis=1)
        Z_scale *= np.expand_dims(y, axis=1)

        self.objects[key] = HedalMatrix.from_ndarray(self.context, Z_scale, self.class_path(key))


class DatasetMultinomial(Dataset):
    """Multinomial dataset holding a scaled ``input`` matrix and a one-hot ``target`` matrix."""

    def __init__(
        self,
        context: Context,
        classes: List,
        path: Union[str, Path],
        encrypted: bool = False,
        shape: Tuple[int, int] = (0, 0),
        description: str = "",
    ):
        """Create an (initially empty) multinomial dataset.

        Args:
            context (hedal.Context): context shared by every matrix of the dataset.
            classes (List): labels, in the order used for the one-hot target columns.
            path (Union[str, Path]): path of the dataset.
            encrypted (bool): status of the dataset.
            shape (Tuple[int, int]): shape of the dataset, (num_data, num_features).
            description (str): description of the dataset.

        Notes:
            Unlike the one-vs-rest dataset, no bias column is added to the input
            matrix (see ``_set_matrix``).
        """
        super().__init__(
            context, path=path, shape=shape, encrypted=encrypted, description=description,
        )
        self.classes = classes
        # Placeholders until ``_set_matrix`` or ``load`` fills them in.
        self.objects: Dict[str, HedalMatrix] = {"input": None, "target": None}

    @property
    def input_path(self) -> Path:
        return self.path / "input"

    @property
    def target_path(self) -> Path:
        return self.path / "target"

    def load(self):
        """Load metadata and scaler, then the ``input`` and ``target`` matrices."""
        super().load()
        self.objects["input"] = HedalMatrix.from_path(self.context, self.input_path)
        self.objects["target"] = HedalMatrix.from_path(self.context, self.target_path)

    @staticmethod
    def from_path(context: Context, classes: List, path: Path) -> DatasetMultinomial:
        """Load a multinomial dataset saved at ``path``.

        Note that ``load`` replaces ``classes`` with the list stored in the metadata.
        """
        dataset = DatasetMultinomial(context, classes=classes, path=path)
        dataset.load()
        return dataset

    def copy(self, dst_parent_path: Optional[Path] = None, **kwargs) -> DatasetMultinomial:
        """Copy the on-disk dataset and load the copy.

        Args:
            dst_parent_path (Optional[Path]): parent directory of the copy. The copy is
                named ``"{name}({i})"`` with the first ``i`` (starting at
                ``kwargs["recursive"]``, default 0) whose directory does not exist yet.
                If falsy, a temporary path from ``config.temp_path()`` is used.

        Returns:
            DatasetMultinomial: the copied dataset.
        """
        if dst_parent_path:
            index = kwargs.get("recursive", 0)
            dst_path = Path(dst_parent_path) / f"{self.name}({index})"
            while dst_path.exists():
                index += 1
                dst_path = Path(dst_parent_path) / f"{self.name}({index})"
        else:
            dst_path = config.temp_path()
        shutil.copytree(self.path, dst_path)
        # ``from_path`` takes ``classes`` before ``path``. Omitting it made this
        # call raise TypeError (missing ``path``).
        return DatasetMultinomial.from_path(self.context, self.classes, dst_path)

    @staticmethod
    def from_ndarray(
        context: Context,
        X: np.ndarray,
        y: np.ndarray,
        scale_type: ScaleType = ScaleType.NORMAL,
        target_tiled: bool = True,
        path: Optional[Path] = None,
    ) -> DatasetMultinomial:
        """Return a dataset from ndarray.

        Args:
            context (Context): Context of the dataset.
            X (np.ndarray): ndarray of features, shape (num_data, num_features).
            y (np.ndarray): ndarray of labels.
            scale_type (ScaleType, optional): type of scale. Defaults to ScaleType.NORMAL.
            target_tiled (bool, optional): if True, the one-hot target is zero-padded to a
                power-of-two number of columns and tiled across the context width.
                Defaults to True.
            path (Optional[Path], optional): path of the dataset. Defaults to None, in which
                case a temporary path from ``config.temp_path()`` is used.
        """
        if path is None:
            path = config.temp_path()
        classes = np.unique(y).tolist()
        dataset = DatasetMultinomial(context, path=path, shape=X.shape, classes=classes)
        X_scale, scaler = Scaler.fit(context, dataset.scaler_path, scale_type, X)
        dataset.scaler = scaler
        dataset._set_matrix(X_scale, y, target_tiled)
        return dataset

    def _set_matrix(self, X_scale: np.ndarray, y: np.ndarray, target_tiled: bool) -> None:
        """Make matrices for scaled input and target and set them in ``self.objects``.

        Note that no bias column is added to the scaled input matrix.
        The target is one-hot encoded with shape (num_data, num_classes). If
        ``target_tiled``, it is zero-padded to ``padded_num_cols`` (the next power of
        two >= num_classes) and tiled ``context.shape[1] // padded_num_cols`` times
        along the columns.

        Args:
            X_scale (np.ndarray): Scaled input array of shape (num_data, num_features).
            y (np.ndarray): Target (label) array; assumed 1-D of length num_data.
            target_tiled (bool): whether to pad and tile the target as described above.

        Raises:
            ValueError: if there are no classes (e.g. ``y`` is empty).
        """
        if self.num_classes == 0:
            raise ValueError("y must contain at least one label")

        self.objects["input"] = HedalMatrix.from_ndarray(self.context, X_scale, self.input_path)

        # Column j is 1 where y equals self.classes[j], else 0.
        target = np.stack([(y == k).astype(int) for k in self.classes], axis=0).T
        padded_num_cols = 2 ** int(math.ceil(math.log2(self.num_classes)))
        if target_tiled:
            padding = np.zeros((self.num_data, padded_num_cols - self.num_classes))
            target = np.concatenate((target, padding), axis=1)
            target = np.tile(target, (1, self.context.shape[1] // padded_num_cols))
        self.objects["target"] = HedalMatrix.from_ndarray(self.context, target, self.target_path)
        self.objects["target"].num_cols = padded_num_cols
