import warnings
from pathlib import Path
from typing import Optional, Union

import numpy as np

from hedal.context import Context
from hedal.core.config import ScaleType
from hedal.ml.dataset import DatasetMultinomial, DatasetOVR


def fit_train_set_ovr(
    context: Context,
    X: np.ndarray,
    y: np.ndarray,
    multiclass: bool,
    path: Optional[Path] = None,
    encrypted: bool = False,
    scale_type: Union[str, ScaleType] = "normal",
) -> DatasetOVR:
    """Return a scaled one-vs-rest (OVR) dataset built from ndarrays.

    Args:
        context (Context): Context forwarded to ``DatasetOVR.from_ndarray``.
        X (np.ndarray): The input features.
        y (np.ndarray): The input labels.
        multiclass (bool): If True, the dataset is multiclass. Otherwise, it is binary.
        path (Optional[Path], optional): Path of the dataset. Defaults to None.
        encrypted (bool, optional): If True, ``encrypt()`` is called on the dataset
            after it is built. Defaults to False.
        scale_type (Union[str, ScaleType], optional): Type of scale. Accepts either a
            ``ScaleType`` member, which is used as-is, or one of the strings
            ``"minmax"`` (``ScaleType.MINMAX``), ``"std"`` (``ScaleType.STNDRD``) or
            ``"normal"`` (``ScaleType.NORMAL``). Any other string falls back to
            ``ScaleType.NORMAL`` and emits a warning. Defaults to "normal".

    Returns:
        DatasetOVR: The scaled dataset, encrypted if ``encrypted`` is True.
    """
    if isinstance(scale_type, ScaleType):
        # Without this check an enum member would not match the string
        # comparisons below and would be silently replaced by NORMAL.
        resolved_scale = scale_type
    elif scale_type == "minmax":
        resolved_scale = ScaleType.MINMAX
    elif scale_type == "std":
        resolved_scale = ScaleType.STNDRD
    else:
        if scale_type != "normal":
            # Keep the historical NORMAL fallback, but make typos visible.
            warnings.warn(
                f"Unknown scale_type {scale_type!r}; falling back to 'normal'.",
                stacklevel=2,
            )
        resolved_scale = ScaleType.NORMAL

    dataset = DatasetOVR.from_ndarray(
        context, X, y, multiclass, scale_type=resolved_scale, path=path
    )
    if encrypted:
        dataset.encrypt()
    return dataset


def fit_train_set_multinomial(
    context: Context,
    X: np.ndarray,
    y: np.ndarray,
    path: Optional[Path] = None,
    encrypted: bool = False,
    scale_type: Union[str, ScaleType] = "normal",
    target_tiled: bool = True,
) -> DatasetMultinomial:
    """Return a scaled multinomial dataset built from ndarrays.

    Args:
        context (Context): Context forwarded to ``DatasetMultinomial.from_ndarray``.
        X (np.ndarray): The input features.
        y (np.ndarray): The input labels.
        path (Optional[Path], optional): Path of the dataset. Defaults to None.
        encrypted (bool, optional): If True, ``encrypt()`` is called on the dataset
            after it is built. Defaults to False.
        scale_type (Union[str, ScaleType], optional): Type of scale. Accepts either a
            ``ScaleType`` member, which is used as-is, or one of the strings
            ``"minmax"`` (``ScaleType.MINMAX``), ``"std"`` (``ScaleType.STNDRD``) or
            ``"normal"`` (``ScaleType.NORMAL``). Any other string falls back to
            ``ScaleType.NORMAL`` and emits a warning. Defaults to "normal".
        target_tiled (bool, optional): If True, the target will be tiled. Defaults to True.

    Returns:
        DatasetMultinomial: The scaled dataset, encrypted if ``encrypted`` is True.
    """
    if isinstance(scale_type, ScaleType):
        # Without this check an enum member would not match the string
        # comparisons below and would be silently replaced by NORMAL.
        resolved_scale = scale_type
    elif scale_type == "minmax":
        resolved_scale = ScaleType.MINMAX
    elif scale_type == "std":
        resolved_scale = ScaleType.STNDRD
    else:
        if scale_type != "normal":
            # Keep the historical NORMAL fallback, but make typos visible.
            warnings.warn(
                f"Unknown scale_type {scale_type!r}; falling back to 'normal'.",
                stacklevel=2,
            )
        resolved_scale = ScaleType.NORMAL

    dataset = DatasetMultinomial.from_ndarray(
        context, X, y, scale_type=resolved_scale, target_tiled=target_tiled, path=path
    )
    if encrypted:
        dataset.encrypt()
    return dataset
