from __future__ import annotations

import json
import math
import os
import random
import shutil
from abc import ABC, abstractclassmethod, abstractmethod
from pathlib import Path
from typing import Dict, ItemsView, KeysView, List, Optional, Tuple, Union, ValuesView

import numpy as np
from pydantic import BaseModel

from hedal.block import Block
from hedal.context import Context
from hedal.core import config
from hedal.core.config import PackingType, ScaleType
from hedal.core.sequence import ObjectDict
from hedal.matrix.matrix import HedalMatrix
from hedal.matrix.ops import block_ops as bop
from hedal.matrix.ops import mat_ops as mop
from hedal.matrix.ops import vec_ops as vop
from hedal.matrix.vector import HedalVector
from hedal.ml.dataset import Dataset, DatasetMultinomial, DatasetOVR
from hedal.ml.reports import ReportSet, ReportSetMultinomial, ReportSetOVR
from hedal.ml.scaler import Scaler


class WeightParam(ObjectDict):
    """Base container for the trainable parameters of a linear model.

    Concrete subclasses store two parameters in this ObjectDict: ``"theta"`` (the weights used for
    inference) and ``"auxil"`` (the auxiliary weights at which the gradient is evaluated; see
    ``_update``). Subclasses define the valid ``size``, the initialisation, the gradient and inference.
    """

    def __init__(self, context: Context, size: Union[int, Tuple[int, int]], encrypted: bool = False, **kwargs):
        """Initialise the parameters and optionally encrypt them.

        Args:
            context (hedal.Context): Context object.
            size (Union[int, Tuple[int, int]]): Size of the weights; validated by the subclass's ``_init_size``.
            encrypted (bool, optional): Whether the weights are encrypted. Defaults to False.
            **kwargs: Forwarded to ``ObjectDict.__init__`` (the subclasses pass ``path``, the directory
                of the weights).
        """
        super(WeightParam, self).__init__(context, **kwargs)
        self.type: PackingType = PackingType.MATRIX

        self._init_size(size)
        self._init_params()

        self.encrypted = encrypted
        if encrypted:
            self.encrypt()

    # NOTE: these are instance methods, so they are marked with `abstractmethod`
    # (`abstractclassmethod` is deprecated and would wrap them as classmethods).
    @abstractmethod
    def _init_size(self, size: Union[int, Tuple[int, int]]) -> None:
        """Validate ``size`` and store it as ``self.size``."""

    @abstractmethod
    def _init_params(self) -> None:
        """Create the ``"theta"`` and ``"auxil"`` parameter objects."""

    @abstractmethod
    def _evaluate_gradient(self, Z_class: HedalMatrix, batch_list: List[int], depth: int) -> HedalVector:
        """Return the (un-normalised) gradient over the row-blocks of ``Z_class`` listed in ``batch_list``."""

    def _get_batch_size(self, input: HedalMatrix, batch_list: List[int]) -> int:
        """Return the total number of rows in the row-blocks of ``input`` selected by ``batch_list``."""
        return sum(input[i].num_rows for i in batch_list)

    def _update(self, gradient: HedalVector, learning_rate: float, eta: float, batch_size: int) -> None:
        """Apply one momentum-style update step (it has the form of Nesterov's accelerated gradient).

        ``theta_new = auxil - (learning_rate / batch_size) * gradient`` and
        ``auxil = (1 - eta) * theta_new + eta * theta_old``.

        Args:
            gradient (HedalVector): Gradient summed over the mini-batch.
            learning_rate (float): Learning rate.
            eta (float): Mixing coefficient between the new and the previous ``theta``.
            batch_size (int): Number of rows in the mini-batch; the gradient is divided by it.
        """
        theta_old = self["theta"]
        theta_new = self["auxil"] - (learning_rate / batch_size) * gradient

        self["auxil"] = ((1 - eta) * theta_new) + (eta * theta_old)
        self["theta"] = theta_new

    def learn(self, Z_class: HedalMatrix, batch_list: List[int], depth: int, learning_rate: float, eta: float) -> None:
        """Run one mini-batch step: evaluate the gradient at ``auxil`` and update ``theta``/``auxil``."""
        gradient = self._evaluate_gradient(Z_class, batch_list, depth)
        bsz = self._get_batch_size(Z_class, batch_list)
        self._update(gradient, learning_rate, eta, bsz)

    @abstractmethod
    def predict(self, mat: HedalMatrix, path: Optional[Path] = None, last_activation: bool = False) -> HedalMatrix:
        """Run inference on ``mat`` with the current weights."""

    class Metadata(BaseModel):
        # Schema of the `metadata.json` written by `save`.
        encrypted: bool
        description: str
        params: List[str]
        size: Union[int, Tuple[int, int]]

    def metadata(self):
        """Return the metadata (encryption flag, description, parameter names, size) to be saved."""
        return self.Metadata(
            encrypted=self.encrypted, description=self.description, params=list(self.keys()), size=self.size
        )

    def save(self, dst_path: Optional[Union[str, Path]] = None) -> None:
        """Save WeightParam to the destination path.

        Directory structure:
            dst_path
                - metadata.json
                - theta (HedalVector)
                    - block_0.bin
                    - block_1.bin
                    - ...
                - auxil (HedalVector)
                    - block_0.bin
                    - block_1.bin
                    - ...

        Args:
            dst_path (Optional[Union[str, Path]], optional): Destination path. Defaults to None, which
                means ``self.path``. A ``str`` is accepted and converted to ``Path``.
        """
        # Normalise to Path so `/` works for str arguments as annotated.
        dst_path = Path(self.path if dst_path is None else dst_path)
        metadata_path = dst_path / config._metadata_file_name
        if not dst_path.exists():
            # parents=True so saving into a not-yet-existing directory tree does not fail.
            dst_path.mkdir(mode=0o775, parents=True, exist_ok=True)

        for key, value in self.items():
            value.save(dst_path / key)

        with open(metadata_path, "w") as m_file:
            m_file.write(self.metadata().json())

    def load(self):
        raise NotImplementedError

    def copy(self, dst_parent_path: Path) -> WeightParam:
        """Copy the saved weights at ``self.path`` to ``dst_parent_path / self.name`` and load the copy.

        Raises:
            OSError: If the destination already exists.
        """
        dst_path = Path(dst_parent_path) / self.name
        if dst_path.exists():
            raise OSError("Already exists:", dst_path)
        shutil.copytree(self.path, dst_path)
        # Dispatch to the concrete subclass: `WeightParam.from_path` is not implemented, and every
        # `from_path` requires `size` and `encrypted`.
        return type(self).from_path(self.context, dst_path, size=self.size, encrypted=self.encrypted)

    def encrypt(self) -> None:
        """Encrypt all parameters and mark the weights as encrypted."""
        super(WeightParam, self).encrypt()
        self.encrypted = True

    def decrypt(self) -> None:
        """Decrypt all parameters and mark the weights as not encrypted."""
        super(WeightParam, self).decrypt()
        self.encrypted = False

    @staticmethod
    def from_path(context: Context, path: Path, size: int, encrypted: bool) -> WeightParam:
        raise NotImplementedError


class WeightParamOVR(WeightParam):
    """WeightParam class for One-vs-Rest classification (uses sigmoid).

    ``theta`` and ``auxil`` are HedalVectors of length ``size`` = number of features + 1; the last
    entry pairs with the bias column that ``predict`` appends to the input.
    """

    def __init__(self, context: Context, size: int, encrypted: bool = False, **kwargs):
        """Create randomly initialised One-vs-Rest weights.

        Args:
            context (hedal.Context): Context object.
            size (int): Size of the weights, which is the number of features + 1.
            encrypted (bool, optional): Whether the weights are encrypted. Defaults to False.
            **kwargs: Forwarded to ``WeightParam`` (e.g. ``path``, the directory of the weights).
        """
        super().__init__(context, size, encrypted, **kwargs)

    def _init_size(self, size: int) -> None:
        """Validate and store ``size``.

        Raises:
            ValueError: If ``size`` is not greater than 1 (at least one feature is required).
        """
        if size <= 1:
            raise ValueError("Size (number of features + 1) must be greater than 1.")
        self.size = size

    def _init_params(self) -> None:
        """Draw ``theta`` uniformly from ``[-b, b)`` with ``b = sqrt(1 / num_features)``; ``auxil`` starts equal to ``theta``."""
        bound = np.sqrt(1 / (self.size - 1))
        theta = np.random.uniform(-bound, bound, self.size)
        # Explicit copy: `theta[:]` would only be a view sharing theta's buffer.
        auxil = theta.copy()
        self.objects: Dict[str, HedalVector] = {
            "theta": HedalVector.from_ndarray(self.context, theta, self.path / "theta"),
            "auxil": HedalVector.from_ndarray(self.context, auxil, self.path / "auxil"),
        }

    def _evaluate_gradient(self, Z_class: HedalMatrix, batch_list: List[int], depth: int) -> HedalVector:
        """Evaluate the gradient at ``auxil`` over the row-blocks of ``Z_class`` listed in ``batch_list``.

        Accumulates ``-sum(sigmoid(-z . auxil) * z)`` over the rows ``z`` of the selected blocks. The
        sum is not divided by the batch size; ``_update`` does that.

        Args:
            Z_class (HedalMatrix): Training matrix for this class (presumably rows already multiplied by
                the +/-1 label and including the bias column; confirm against ``DatasetOVR``).
            batch_list (List[int]): Indices of the row-blocks of ``Z_class`` forming the mini-batch.
            depth (int): Depth argument forwarded to ``bop.sigmoid`` (presumably its approximation depth).

        Returns:
            HedalVector: Gradient of shape ``(1, self.size)``.
        """
        gradient = HedalVector.zeros(self.context, encrypted=self.encrypted, shape=(1, self.size))

        if self["auxil"].need_bootstrap(4):
            self["auxil"].bootstrap()

        for idx in batch_list:
            mat_row = Z_class[idx]
            # Per-row dot products; `fill=True` presumably replicates each result across its row.
            block_dot = vop.dot(mat_row, self["auxil"], fill=True)[0]
            block_sigmoid = bop.sigmoid(-block_dot, depth)
            if block_sigmoid.need_bootstrap(1):
                block_sigmoid.bootstrap()
            vector_tmp = mat_row * block_sigmoid
            gradient -= vop.sum(vector_tmp, axis=0, direction=0, fill=True)

        if gradient.need_bootstrap(2):
            gradient.bootstrap()
        return gradient

    def predict(self, mat: HedalMatrix, path: Optional[Path] = None, last_activation: bool = False) -> HedalMatrix:
        """Model inference.

        The input matrix is not modified: when the bias column needs a new block, a copy is padded.

        Args:
            mat (HedalMatrix): Input matrix of shape (num_samples, num_features), without bias column.
            path (Optional[Path], optional): Path to save the results. Defaults to None.
            last_activation (bool, optional): Whether to apply sigmoid at last or not. Defaults to False.

        Returns:
            HedalMatrix: Output matrix of shape (num_samples, 1).
        """
        num_features = mat.shape[1]
        # All-ones bias column at index `num_features` (the rightmost column).
        mask = HedalMatrix.mask(
            self.context, shape=(mat.shape[0], mat.shape[1] + 1), index=num_features, axis=1, encrypted=mat.encrypted,
        )
        if num_features % self.context.shape[1] == 0:
            # The bias column starts a new block. Pad a *copy* with one zero block per row. Padding
            # `mat` itself would change the caller's matrix, and repeated calls (e.g. one per class)
            # would keep appending blocks.
            padded = mat.copy_memory()
            padded.num_cols = mat.num_cols + 1
            for row in padded:
                row.block_list.append(Block.zeros(self.context, encrypted=mat.encrypted, type=PackingType.MATRIX))
            mat_with_bias = padded + mask
        else:
            mat_with_bias = mat + mask
        mat_pred = mop.dot(mat_with_bias, self["theta"], fill=False)
        mat_pred.shape = (mat.shape[0], 1)
        if last_activation:
            mat_pred = mop.sigmoid(mat_pred)
        if path is not None:
            mat_pred.path = path
        return mat_pred

    @staticmethod
    def from_path(context: Context, path: Path, size: int, encrypted: bool) -> WeightParamOVR:
        """Construct weights at ``path`` and overwrite them with the saved ones (see ``load``)."""
        weight = WeightParamOVR(context, path=path, size=size, encrypted=encrypted)
        weight.load()
        return weight

    def load(self):
        """Read ``metadata.json`` and load each listed parameter from ``self.path / name``."""
        with open(self.metadata_path, "r") as m_file:
            m_info = json.load(m_file)
        self.encrypted = m_info["encrypted"]
        self.description = m_info["description"]
        self.size = m_info["size"]
        for name in m_info["params"]:
            self[name] = HedalVector.from_path(
                self.context, self.path / name, shape=(1, self.size), encrypted=self.encrypted
            )

    def level_down(self, target_level: int) -> None:
        """Lower every parameter to ``target_level``."""
        for name in self.objects:
            self.objects[name].level_down(target_level)


class WeightParamMultinomial(WeightParam):
    """WeightParam class for multinomial classification (uses softmax)"""

    def __init__(self, context: Context, size: Tuple[int, int], encrypted: bool = False, **kwargs):
        """Create randomly initialised multinomial weights.

        Args:
            context (hedal.Context): Context object.
            size (Tuple[int, int]): Size of the weights, which equals (num_classes, num_features + 1).
                For now, the number of classes must not exceed `context.shape[0]`, which equals 256 for FVa parameter.
            encrypted (bool, optional): Whether the weights are encrypted. Defaults to False.
            **kwargs: Forwarded to ``WeightParam`` (e.g. ``path``, the directory of the weights).
        """
        super().__init__(context, size, encrypted, **kwargs)

    def _init_size(self, size: Tuple[int, int]) -> None:
        """Validate ``size`` and store it as a 2-tuple (a list, e.g. read from JSON, is accepted).

        Raises:
            ValueError: If ``size`` does not have exactly two entries, there are fewer than 2 classes,
                more classes than ``context.shape[0]``, or no features.
        """
        if len(size) != 2:
            raise ValueError("Size must be a pair (number of classes, number of features + 1).")
        if size[0] <= 1:
            raise ValueError("Size[0] (number of classes) must be greater than 1.")
        if size[0] > self.context.shape[0]:
            raise ValueError("Number of classes must not exceed `context.shape[0]`.")
        if size[1] <= 1:
            raise ValueError("Size[1] (number of features + 1) must be greater than 1.")
        self.size = tuple(size)

    def _init_params(self) -> None:
        """Initialize the weights.

        The weights are initialized as a matrix of size (self.context.shape[0], num_features + 1). Its
        data are a vertical tiling of the random (self.size[0], num_features + 1) matrix, zero-padded to
        the next power-of-two number of rows.
        """
        row_pad = int(2 ** math.ceil(math.log2(self.size[0]))) - self.size[0]
        bound = np.sqrt(1 / (self.size[1] - 1))
        theta = np.random.uniform(-bound, bound, self.size)
        theta = np.concatenate([theta, np.zeros((row_pad, self.size[1]))], axis=0)
        theta = np.tile(theta, (self.context.shape[0] // theta.shape[0], 1))
        # Explicit copy: `theta[:]` would only be a view sharing theta's buffer.
        auxil = theta.copy()
        self.objects: Dict[str, HedalVector] = {
            "theta": HedalVector.from_ndarray(self.context, theta, self.path / "theta"),
            "auxil": HedalVector.from_ndarray(self.context, auxil, self.path / "auxil"),
        }
        # The logical shape is the padded (power-of-two rows) matrix, not the tiled one.
        self.objects["theta"].shape = (self.size[0] + row_pad, self.size[1])
        self.objects["auxil"].shape = (self.size[0] + row_pad, self.size[1])

    def _add_bias_col(self, mat: HedalMatrix) -> HedalMatrix:
        """Add bias column with all ones as a rightmost column of the input matrix.

        The input matrix is not modified; when a new block is needed a copy is padded.

        Args:
            mat (HedalMatrix): Input matrix.

        Returns:
            HedalMatrix: Output matrix with bias column.
        """
        num_features = mat.shape[1]
        bias_col = HedalMatrix.mask(
            self.context, shape=(mat.shape[0], mat.shape[1] + 1), index=num_features, axis=1, encrypted=mat.encrypted,
        )
        if num_features % self.context.shape[1] == 0:
            # The bias column starts a new block: append a zero block to each row of a copy.
            copy_mat = mat.copy_memory()
            copy_mat.num_cols = mat.num_cols + 1
            for row in copy_mat:
                row.block_list.append(Block.zeros(self.context, encrypted=mat.encrypted, type=PackingType.MATRIX))
            new_mat = copy_mat + bias_col
        else:
            new_mat = mat + bias_col
        new_mat.num_cols = mat.shape[1] + 1
        return new_mat

    def _evaluate_gradient(self, input: HedalMatrix, target: HedalMatrix, batch_list: List[int]) -> HedalVector:
        """Evaluate the gradient at ``auxil`` over the row-blocks listed in ``batch_list``.

        Args:
            input (HedalMatrix): Input matrix without bias column.
            target (HedalMatrix): Target matrix (presumably one-hot labels; confirm against DatasetMultinomial).
            batch_list (List[int]): Indices of the row-blocks forming the mini-batch.

        Returns:
            The gradient computed by ``mop.mat_mul_row_tiled`` from ``softmax - target`` and the
            bias-extended inputs (not divided by the batch size; ``_update`` does that).
        """
        num_batch = sum(input[idx].shape[0] for idx in batch_list)
        batch_input_matrix = HedalMatrix(self.context, shape=(num_batch, input.shape[1]), encrypted=input.encrypted)
        batch_target_matrix = HedalMatrix(self.context, shape=(num_batch, target.shape[1]), encrypted=target.encrypted)
        for idx in batch_list:
            batch_input_matrix.objects.append(input[idx])
            batch_target_matrix.objects.append(target[idx])

        # Softmax probabilities at the auxiliary weights, in the tiled layout consumed below.
        prob = self.predict(batch_input_matrix, last_activation=True, softmax_output_tiled=True, param_key="auxil")
        prob.num_cols = target.shape[1]
        if prob.need_bootstrap(3):
            prob.bootstrap()
        gradient = mop.mat_mul_row_tiled(
            prob - batch_target_matrix, self._add_bias_col(batch_input_matrix), tile_col=True
        )
        if gradient.need_bootstrap(2):
            gradient.bootstrap()
        return gradient

    def learn(self, dataset: DatasetMultinomial, batch_list: List[int], learning_rate: float, eta: float) -> None:
        """Run one mini-batch step on ``dataset["input"]`` / ``dataset["target"]``."""
        gradient = self._evaluate_gradient(dataset["input"], dataset["target"], batch_list)
        bsz = self._get_batch_size(dataset["input"], batch_list)
        self._update(gradient, learning_rate, eta, bsz)

    def predict(
        self,
        mat: HedalMatrix,
        path: Optional[Path] = None,
        last_activation: bool = False,
        softmax_output_tiled: bool = False,
        param_key: str = "theta",
    ) -> HedalMatrix:
        """Model inference.

        Args:
            mat (HedalMatrix): Input matrix (scaled / normalized) of shape (num_samples, num_features), without bias column.
            path (Optional[Path], optional): Path to save the results. Defaults to None.
            last_activation (bool, optional): Whether to apply softmax at last or not. Defaults to False.
            softmax_output_tiled (bool, optional): Whether to tile the output of the softmax.
            param_key (str, optional): Key of the parameter to use. Defaults to "theta".

        Returns:
            HedalMatrix: Output matrix. Dot-product of the input matrix and the weights, softmax-ed only
                when ``last_activation`` is True.
        """
        if self[param_key].need_bootstrap(3):
            self[param_key].bootstrap()
        mat_pred = mop.mat_vec_mul_col_tiled(self._add_bias_col(mat), self[param_key])
        mat_pred.num_cols = self.size[0]
        if last_activation:
            if mat_pred.need_bootstrap(3):
                mat_pred.bootstrap()
            mat_pred = mop.softmax_wide(mat_pred, output_tiled=softmax_output_tiled)
        if path is not None:
            mat_pred.path = path
        return mat_pred

    @staticmethod
    def from_path(context: Context, path: Path, size: int, encrypted: bool) -> WeightParamMultinomial:
        """Construct weights at ``path`` and overwrite them with the saved ones (see ``load``)."""
        weight = WeightParamMultinomial(context, path=path, size=size, encrypted=encrypted)
        weight.load()
        return weight

    def load(self):
        """Read ``metadata.json`` and load each listed parameter with its padded shape."""
        with open(self.metadata_path, "r") as m_file:
            m_info = json.load(m_file)
        self.encrypted = m_info["encrypted"]
        self.description = m_info["description"]
        # JSON stores the size as a list; keep it a tuple as `_init_size` does.
        self.size = tuple(m_info["size"])
        # Same power-of-two row padding as `_init_params`.
        padded_row = 2 ** math.ceil(math.log2(self.size[0]))
        padded_size = (padded_row, self.size[1])
        for name in m_info["params"]:
            self[name] = HedalVector.from_path(
                self.context, self.path / name, shape=padded_size, encrypted=self.encrypted
            )

    def level_down(self, target_level: int) -> None:
        """Lower every parameter to ``target_level``."""
        for name in self.objects:
            self.objects[name].level_down(target_level)


class WeightSet(ABC):
    """Abstract set of model weights plus the scaler used to normalise inputs.

    Subclasses keep their parameters in ``self.objects`` and implement training (``fit``/``learn``),
    inference (``predict``) and persistence (``load``/``save``/``copy``/``from_path``).
    """

    def __init__(self, context: Context, path: Path, scale_type: ScaleType = ScaleType.NORMAL):
        """Initialise an empty weight set.

        Args:
            context (hedal.Context): Context object.
            path (Path): Directory of the weight set.
            scale_type (ScaleType, optional): Scaling type of the initial scaler. Defaults to ScaleType.NORMAL.
        """
        self.context = context
        self.size = 0
        self._path: Path = Path(path)
        self.type: PackingType = PackingType.MATRIX
        self.description = ""
        self.encrypted = False
        self.epoch_state: int = 0  # incremented by the subclasses' `learn` after each epoch
        self.scaler: Scaler = Scaler(context, scale_type=scale_type)
        self.objects = None

    @property
    def path(self) -> Path:
        return self._path

    @property
    def name(self) -> str:
        return config.name(self)

    @property
    def classes(self) -> List:
        raise NotImplementedError

    @property
    def num_classes(self) -> int:
        return len(self.classes)

    @property
    def scaler_path(self) -> Path:
        return config.scaler_path(self)

    def __getitem__(self, key: str) -> WeightParam:
        return self.objects[key]

    # NOTE: instance methods are marked with `abstractmethod`; `abstractclassmethod` is deprecated
    # and would wrap them as classmethods.
    @abstractmethod
    def fit(self, dataset: Dataset) -> None:
        """Initialize scaler and weights from ``dataset``.

        Args:
            dataset (Dataset): Training dataset providing the scaler and the weight size.
        """

    @abstractmethod
    def learn(self, Z_train: Dataset, num_batch: int, learning_rate: float, eta: float) -> None:
        """Train the weights for one epoch of mini-batch updates.

        Args:
            Z_train (Dataset): Training dataset.
            num_batch (int): Number of row-blocks per mini-batch.
            learning_rate (float): Learning rate.
            eta (float): Coefficient mixing the new and previous weights in the update.
        """

    @abstractmethod
    def predict(self, X: HedalMatrix, path: Path, normalize: bool = True, last_activation: bool = False) -> ReportSet:
        """Model inference.

        Args:
            X (HedalMatrix): Input matrix, without bias column.
            path (Path): Directory of the resulting report set.
            normalize (bool, optional): Whether to scale ``X`` with ``self.scaler`` first. Defaults to True.
            last_activation (bool, optional): Whether to apply the final activation. Defaults to False.

        Returns:
            ReportSet: Predictions.
        """

    class Metadata(BaseModel):
        # Schema of the weight set's `metadata.json`.
        description: str
        epoch_state: int
        size: Union[int, Tuple[int, int]]
        classes: List[str]
        encrypted: bool

    def metadata(self):
        """Return the metadata to be written by ``save``."""
        return self.Metadata(
            encrypted=self.encrypted,
            description=self.description,
            epoch_state=self.epoch_state,
            size=self.size,
            classes=self.classes,
        )

    @property
    def metadata_path(self) -> Path:
        return config.metadata_path(self)

    def load(self):
        raise NotImplementedError

    def save(self, dst_parent_path: Optional[Path] = None) -> None:
        raise NotImplementedError

    def copy(self, dst_parent_path: Path) -> WeightSet:
        raise NotImplementedError

    def remove(self) -> None:
        """Delete the weight set's directory from disk."""
        shutil.rmtree(self.path)

    @staticmethod
    def from_path(context: Context, path: Path) -> WeightSet:
        raise NotImplementedError

    def level_down(self, target_level: int) -> None:
        raise NotImplementedError


class WeightSetOVR(WeightSet):
    """One-vs-Rest weight set holding one ``WeightParamOVR`` per class, keyed by class name."""

    def __init__(self, context: Context, path: Path, scale_type: ScaleType = ScaleType.NORMAL):
        """Create an empty One-vs-Rest weight set (populated by ``fit`` or ``load``).

        Args:
            context (hedal.Context): Context object.
            path (Path): Directory of the weight set; class ``name`` lives under ``path / name``.
            scale_type (ScaleType, optional): Scaling type of the initial scaler. Defaults to ScaleType.NORMAL.
        """
        super().__init__(context, path=path, scale_type=scale_type)
        self.objects: Dict[str, WeightParamOVR] = {}

    @property
    def classes(self) -> List:
        # Class names, in insertion order of `self.objects`.
        return [*self.objects]

    def keys(self) -> KeysView[str]:
        """Class names."""
        return self.objects.keys()

    def values(self) -> ValuesView[WeightParam]:
        """Per-class weights."""
        return self.objects.values()

    def items(self) -> ItemsView[str, WeightParam]:
        """(class name, weights) pairs."""
        return self.objects.items()

    def class_path(self, name: str) -> Path:
        """Directory holding the weights of class ``name``."""
        return self.path.joinpath(name)

    def fit(self, dataset: DatasetOVR) -> None:
        """Copy the dataset's scaler and create freshly initialised weights for every class in ``dataset``."""
        self.size = dataset.size
        self.scaler = dataset.scaler.copy(self.path)
        for class_name in dataset.keys():
            self.objects[class_name] = WeightParamOVR(
                self.context,
                path=self.class_path(class_name),
                size=self.size,
                encrypted=dataset.encrypted,
            )

    def learn(self, Z_train: DatasetOVR, num_batch: int, learning_rate: float, eta: float) -> None:
        """Run one epoch of mini-batch training for every class.

        The depth handed to ``WeightParamOVR.learn`` is derived from the epoch count, learning rate
        and weight size, and clamped to ``[0, 10]``.
        """
        per_epoch = learning_rate ** 2 * self.size ** 2 + 0.6 * learning_rate * self.size
        estimate = (self.epoch_state + 1) * per_epoch
        raw_depth = int(np.ceil(np.log(np.sqrt(estimate)) / np.log(2.45)))
        depth = max(0, min(raw_depth, 10))

        # Shuffle row-block indices once per epoch; all classes share the same mini-batches.
        order = list(range(len(Z_train)))
        random.shuffle(order)
        batches = [order[start : start + num_batch] for start in range(0, len(Z_train), num_batch)]

        for class_name, param in self.items():
            for batch in batches:
                param.learn(Z_train[class_name], batch, depth, learning_rate, eta)
        self.epoch_state += 1

    def predict(self, X: HedalMatrix, path: Path, normalize: bool = True, last_activation: bool = False) -> ReportSet:
        """Predict every class on ``X`` and gather the results in a ``ReportSetOVR``.

        Args:
            X (HedalMatrix): Input matrix, without bias column.
            path (Path): Directory of the report set; class ``name`` is written under ``path / name``.
            normalize (bool, optional): Whether to scale ``X`` with ``self.scaler`` first. Defaults to True.
            last_activation (bool, optional): Whether to apply sigmoid at last. Defaults to False.

        Returns:
            ReportSet: One prediction per class.
        """
        path = Path(path)
        report_set = ReportSetOVR(
            self.context, path=path, classes=self.classes, num_data=X.num_rows, encrypted=X.encrypted or self.encrypted
        )
        if not normalize:
            # Wrap X's rows unchanged in a new HedalMatrix.
            Z_test = HedalMatrix(self.context, shape=X.shape, encrypted=X.encrypted)
            Z_test.objects = X.objects
        else:
            Z_test = self.scaler.scale(X)

        for class_name, param in self.items():
            report_set[class_name] = param.predict(Z_test, path / class_name, last_activation=last_activation)
        return report_set

    def encrypt(self, keys: Optional[List[str]] = None) -> None:
        """Encrypt the weights of the given classes (all classes if ``keys`` is empty or None)."""
        targets = keys if keys else list(self.keys())
        for class_name in targets:
            self[class_name].encrypt()
        self.encrypted = True

    def decrypt(self, keys: Optional[List[str]] = None) -> None:
        """Decrypt the weights of the given classes (all classes if ``keys`` is empty or None)."""
        targets = keys if keys else list(self.keys())
        for class_name in targets:
            self[class_name].decrypt()
        self.encrypted = False

    def load(self):
        """Load metadata, per-class weights and scaler from ``self.path``."""
        with open(self.metadata_path, "r") as meta_file:
            meta = json.load(meta_file)
        self.description = meta["description"]
        self.epoch_state = meta["epoch_state"]
        self.size = meta["size"]
        self.encrypted = meta["encrypted"]
        for class_name in meta["classes"]:
            self.objects[class_name] = WeightParamOVR.from_path(
                self.context, self.class_path(class_name), size=self.size, encrypted=self.encrypted
            )
        self.scaler = Scaler.from_path(self.context, self.scaler_path)

    def save(self, dst_path: Optional[Path] = None) -> None:
        """Save weight set to the destination path.

        Directory structure:
            dst_path
            - metadata.json
            - class_0 (WeightParamOVR)
                - metadata.json
                - auxil (HedalVector)
                    ...
                - theta (HedalVector)
                    ...
            - class_1 (WeightParamOVR)
                - metadata.json
                - auxil (HedalVector)
                    ...
                - theta (HedalVector)
                    ...
            ...
            - scale (Scaler)
                ...

        Args:
            dst_path (Optional[Path], optional): Destination path. Defaults to None, meaning ``self.path``.
        """
        if dst_path is None:
            dst_path = self.path
        elif isinstance(dst_path, str):
            dst_path = Path(dst_path)

        if not dst_path.exists():
            os.makedirs(dst_path, mode=0o775, exist_ok=True)

        for class_name, param in self.objects.items():
            param.save(dst_path / class_name)
        self.scaler.save(dst_path / "scale")

        with open(dst_path / config._metadata_file_name, "w") as meta_file:
            meta_file.write(self.metadata().json())

    def copy(self, dst_parent_path: Path) -> WeightSetOVR:
        """Copy the saved weight set to ``dst_parent_path / self.name`` and load the copy.

        Raises:
            OSError: If the destination already exists.
        """
        target = Path(dst_parent_path) / self.name
        if target.exists():
            raise OSError("Already exists:", target)
        shutil.copytree(self.path, target)
        return WeightSetOVR.from_path(self.context, target)

    @staticmethod
    def from_path(context: Context, path: Path) -> WeightSetOVR:
        """Load a weight set previously written by ``save`` from ``path``."""
        weight_set = WeightSetOVR(context, path=path)
        weight_set.load()
        return weight_set

    def level_down(self, target_level: int) -> None:
        """Lower every class's weights and the scaler to ``target_level``."""
        for param in self.objects.values():
            param.level_down(target_level)
        self.scaler.level_down(target_level)


class WeightSetMultinomial(WeightSet):
    """Multinomial (softmax) weight set: a single ``WeightParamMultinomial`` covering all classes."""

    def __init__(self, context: Context, classes: List, path: Path, scale_type: ScaleType = ScaleType.NORMAL):
        """Create an empty multinomial weight set (populated by ``fit`` or ``load``).

        Args:
            context (hedal.Context): Context object.
            classes (List): Class labels. ``load`` replaces them with the labels stored in the metadata.
            path (Path): Directory of the weight set; the parameters live under ``path / "param"``.
            scale_type (ScaleType, optional): Scaling type of the initial scaler. Defaults to ScaleType.NORMAL.
        """
        super().__init__(context, path=path, scale_type=scale_type)
        self._classes = classes
        self.objects: Optional[WeightParamMultinomial] = None

    @property
    def classes(self) -> List:
        return self._classes

    @property
    def param_path(self):
        """Directory of the ``WeightParamMultinomial`` parameters."""
        return self.path / "param"

    def fit(self, dataset: DatasetMultinomial) -> None:
        """Copy the dataset's scaler and create freshly initialised weights of size (num_classes, num_features + 1)."""
        self.size = (dataset.num_classes, dataset.num_features + 1)
        self.scaler = dataset.scaler.copy(self.path)
        self.objects = WeightParamMultinomial(
            self.context, path=self.param_path, size=self.size, encrypted=dataset.encrypted
        )

    def learn(self, Z_train: DatasetMultinomial, num_batch: int, learning_rate: float, eta: float) -> None:
        """Run one epoch of mini-batch training over shuffled row-blocks of ``Z_train``."""
        total_list = list(range(len(Z_train)))
        random.shuffle(total_list)
        batch_set = [total_list[idx : idx + num_batch] for idx in range(0, len(Z_train), num_batch)]

        for batch_list in batch_set:
            self.objects.learn(Z_train, batch_list, learning_rate, eta)
        self.epoch_state += 1

    def predict(
        self, X: HedalMatrix, path: Union[str, Path], normalize: bool = True, last_activation: bool = False
    ) -> ReportSet:
        """Model inference; the result matrix is stored at ``path / "results"`` in the report set.

        Args:
            X (HedalMatrix): Input matrix, without bias column.
            path (Union[str, Path]): Directory of the report set.
            normalize (bool, optional): Whether to scale ``X`` with ``self.scaler`` first. Defaults to True.
            last_activation (bool, optional): Whether to apply softmax at last. Defaults to False.

        Returns:
            ReportSet: Report set holding the predictions.
        """
        if isinstance(path, str):
            path = Path(path)
        report_set = ReportSetMultinomial(
            self.context, self.classes, path=path, num_data=X.num_rows, encrypted=X.encrypted or self.encrypted
        )

        if normalize:
            X_test = self.scaler.scale(X)
        else:
            X_test = HedalMatrix(self.context, shape=X.shape, encrypted=X.encrypted)
            X_test.objects = X.objects

        report_set.results = self.objects.predict(X_test, path / "results", last_activation=last_activation)
        return report_set

    def encrypt(self) -> None:
        """Encrypt the weights."""
        self.objects.encrypt()
        self.encrypted = True

    def decrypt(self) -> None:
        """Decrypt the weights."""
        self.objects.decrypt()
        self.encrypted = False

    def load(self):
        """Load metadata (including the class labels), weights and scaler from ``self.path``."""
        with open(self.metadata_path, "r") as m_file:
            m_info = json.load(m_file)
        self.description = m_info["description"]
        self.epoch_state = m_info["epoch_state"]
        # JSON stores the size as a list; keep it a tuple as `fit` does.
        self.size = tuple(m_info["size"])
        self.encrypted = m_info["encrypted"]
        # Restore the labels written by `metadata()` (declared there as `List[str]`).
        self._classes = m_info["classes"]
        self.objects = WeightParamMultinomial.from_path(
            self.context, self.param_path, size=self.size, encrypted=self.encrypted
        )
        self.scaler = Scaler.from_path(self.context, self.scaler_path)

    def save(self, dst_path: Optional[Path] = None) -> None:
        """Save weight set to the destination path.

        Directory structure:
            dst_path
            - metadata.json
            - param
                - auxil (HedalVector)
                    ...
                - theta (HedalVector)
                    ...
            - scale (Scaler)
                ...

        Args:
            dst_path (Optional[Path], optional): Destination path. Defaults to None, meaning ``self.path``.
        """
        if dst_path is None:
            dst_path = self.path
        if isinstance(dst_path, str):
            dst_path = Path(dst_path)
        metadata_path = dst_path / config._metadata_file_name

        if not dst_path.exists():
            os.makedirs(dst_path, mode=0o775, exist_ok=True)

        self.objects.save(dst_path / "param")
        self.scaler.save(dst_path / "scale")

        with open(metadata_path, "w") as m_file:
            m_file.write(self.metadata().json())

    def copy(self, dst_parent_path: Path) -> WeightSetMultinomial:
        """Copy the saved weight set to ``dst_parent_path / self.name`` and load the copy.

        Raises:
            OSError: If the destination already exists.
        """
        dst_path = Path(dst_parent_path) / self.name
        if dst_path.exists():
            raise OSError("Already exists:", dst_path)
        shutil.copytree(self.path, dst_path)
        return WeightSetMultinomial.from_path(self.context, dst_path)

    @staticmethod
    def from_path(context: Context, path: Path) -> WeightSetMultinomial:
        """Load a weight set previously written by ``save`` from ``path``."""
        # `classes` is a required constructor argument; the placeholder is replaced by `load` with
        # the labels recorded in the metadata file.
        weight = WeightSetMultinomial(context, classes=[], path=path)
        weight.load()
        return weight

    def level_down(self, target_level: int) -> None:
        """Lower the weights and the scaler to ``target_level``."""
        self.objects.level_down(target_level)
        self.scaler.level_down(target_level)
