from __future__ import annotations

import json
from pathlib import Path
from typing import Dict, Optional, Tuple, Union

import numpy as np
from pydantic import BaseModel

from hedal.context import Context
from hedal.core import config
from hedal.core.config import ScaleType
from hedal.core.sequence import ObjectDict
from hedal.matrix.matrix import HedalMatrix
from hedal.matrix.vector import HedalVector


class Scaler(ObjectDict):
    """
    Scaler class, which normalizes the data.

    A fitted scaler maps every column ``x`` of the data to ``(x - sub) * div``:

    * ``ScaleType.NORMAL``: identity; no constants are stored.
    * ``ScaleType.STNDRD``: ``sub = mean(x)``, ``div = 1 / (3 * std(x))``.
    * ``ScaleType.MINMAX``: ``sub = min(x)``, ``div = 1 / (max(x) - min(x))``.

    Columns whose spread is zero get ``div = 1`` so no division by zero occurs.

    Attributes:
        scale_size (int): The size of the scaler (number of scaled columns).
        scale_type (ScaleType): The type of the scaler.
        objects (Dict[str, HedalVector]): Scaling constants. ``"sub"`` holds the offset and
            ``"div"`` holds the *reciprocal* of the spread (it is multiplied, not divided).
            Empty for ``ScaleType.NORMAL``.

    NOTE(review): ``path``, ``encrypted``, ``name`` and ``metadata_path`` are used below but
    are presumably provided by ``ObjectDict`` (populated through ``**kwargs``) — confirm there.
    """

    def __init__(self, context: Context, scale_size: int = 0, scale_type: ScaleType = ScaleType.NORMAL, **kwargs):
        """
        Args:
            context (Context): HEDAL context, forwarded to ``ObjectDict``.
            scale_size (int, optional): Number of columns handled by the scaler. Defaults to 0.
            scale_type (ScaleType, optional): Scaling method. Defaults to ScaleType.NORMAL.
            **kwargs: Forwarded to ``ObjectDict`` (this class passes ``path`` and ``encrypted``).
        """
        super().__init__(context, **kwargs)
        self.scale_size = scale_size
        self.scale_type = scale_type
        self.objects: Dict[str, HedalVector] = {}

    class Metadata(BaseModel):
        # Serialized to metadata.json by ``save`` and read back by ``load``.
        size: int
        encrypted: bool
        scale_type: ScaleType

    def metadata(self) -> Scaler.Metadata:
        """Return the serializable metadata describing this scaler."""
        return self.Metadata(size=self.scale_size, encrypted=self.encrypted, scale_type=self.scale_type)

    def load(self):
        """Load metadata and, for non-NORMAL scalers, the ``sub``/``div`` vectors from ``self.path``."""
        with open(self.metadata_path, "r", encoding="utf-8") as m_file:
            m_info = json.load(m_file)
        self.scale_size = m_info["size"]
        self.encrypted = m_info["encrypted"]
        self.scale_type = ScaleType(m_info["scale_type"])
        if self.scale_type != ScaleType.NORMAL:
            # Both constants are stored as (1, scale_size) vectors in sibling directories.
            for key in ("sub", "div"):
                self.objects[key] = HedalVector.from_path(
                    self.context, self.path / key, shape=(1, self.scale_size), encrypted=self.encrypted
                )

    def save(self, dst_path: Optional[Union[str, Path]] = None) -> None:
        """Save Scaler.

        Directory structure:
            dst_path
            - metadata.json
            - sub (HedalVector)
                - block_0.bin
            - div (HedalVector)
                - block_0.bin

        Missing parent directories of ``dst_path`` are created as well.

        Args:
            dst_path (Optional[Union[str, Path]], optional): Path to save the scaler. Defaults to None,
                meaning ``self.path``.
        """
        if dst_path is None:
            dst_path = self.path
        dst_path = Path(dst_path)
        metadata_path = dst_path / config._metadata_file_name

        # parents=True so saving into a not-yet-existing parent directory works.
        dst_path.mkdir(mode=0o775, parents=True, exist_ok=True)

        for key, value in self.objects.items():
            value.save(dst_path / key)

        with open(metadata_path, "w", encoding="utf-8") as m_file:
            m_file.write(self.metadata().json())

    @staticmethod
    def from_path(context: Context, path: Path) -> Scaler:
        """Construct a Scaler bound to ``path`` and load its contents from disk."""
        scaler = Scaler(context, path=path)
        scaler.load()
        return scaler

    def copy(self, dst_parent_path: Path) -> Scaler:
        """Create a Scaler located at ``dst_parent_path / self.name``.

        The new scaler references the same ``sub``/``div`` vector objects as this one (they are
        not duplicated), and this method does not call ``save``; call ``save()`` on the result to
        persist it.

        Raises:
            FileExistsError: If the destination path already exists (subclass of ``OSError``).
        """
        dst_path = Path(dst_parent_path) / self.name
        if dst_path.exists():
            raise FileExistsError(f"Already exists: {dst_path}")
        scaler = Scaler(
            self.context,
            path=dst_path,
            encrypted=self.encrypted,
            scale_type=self.scale_type,
            scale_size=self.scale_size,
        )
        if self.scale_type != ScaleType.NORMAL:
            scaler.objects = {"sub": self["sub"], "div": self["div"]}
        return scaler

    @staticmethod
    def fit(context: Context, path: Path, scale_type: ScaleType, X: np.ndarray) -> Tuple[np.ndarray, Scaler]:
        """Compute scaling constants from plaintext data and return the scaled data.

        Args:
            context (Context): HEDAL context for the created vectors.
            path (Path): Location of the new scaler; constants go to ``path/sub`` and ``path/div``.
            scale_type (ScaleType): NORMAL, STNDRD or MINMAX.
            X (np.ndarray): 2-D data; statistics are taken per column (axis 0).

        Returns:
            Tuple[np.ndarray, Scaler]: The scaled data (``X`` itself for NORMAL) and the scaler.

        Raises:
            ValueError: If ``scale_type`` is not one of the handled types.
        """
        scale_size = X.shape[1]
        scaler = Scaler(context, path=path, scale_size=scale_size, scale_type=scale_type)
        if scale_type == ScaleType.NORMAL:
            return X, scaler

        if scale_type == ScaleType.STNDRD:
            sub_coeff = np.mean(X, axis=0)
            spread = np.std(X, axis=0) * 3
        elif scale_type == ScaleType.MINMAX:
            sub_coeff = np.min(X, axis=0)
            spread = np.max(X, axis=0) - sub_coeff
        else:
            # Previously this surfaced as an UnboundLocalError on sub_coeff.
            raise ValueError(f"Unsupported scale type: {scale_type}")

        # Store the reciprocal so scaling is a multiplication; zero-spread columns use 1.
        spread = np.asarray(spread, dtype=np.float64)
        div_coeff = np.ones_like(spread)
        np.divide(1.0, spread, out=div_coeff, where=spread != 0)

        base_path = Path(path)
        scaler.objects["sub"] = HedalVector.from_ndarray(context, sub_coeff, base_path / "sub")
        scaler.objects["div"] = HedalVector.from_ndarray(context, div_coeff, base_path / "div")

        return (X - sub_coeff) * div_coeff, scaler

    def scale(self, X: HedalMatrix) -> HedalMatrix:
        """Apply the fitted scaling ``(X - sub) * div`` to a HedalMatrix.

        The result is marked encrypted if either the scaler or ``X`` is encrypted.
        For ``ScaleType.NORMAL`` the returned matrix shares ``X``'s ``objects`` container
        (no copy is made).
        """
        X_scale = HedalMatrix(
            self.context, shape=(X.num_rows, self.scale_size), encrypted=self.encrypted or X.encrypted
        )
        # Reuse X's underlying blocks rather than copying them.
        X_scale.objects = X.objects

        if self.scale_type == ScaleType.NORMAL:
            return X_scale
        return (X_scale - self["sub"]) * self["div"]
