from __future__ import annotations

import math
from math import ceil
from pathlib import Path
from typing import List, Optional

import numpy as np
import pandas as pd

from hedal.context import Context
from hedal.core.config import ScaleType
from hedal.matrix.matrix import HedalMatrix
from hedal.matrix.vector import HedalVector
from hedal.ml.dataset import Dataset
from hedal.ml.reports import ReportSet
from hedal.ml.weights import WeightParamMultinomial, WeightParamOVR, WeightSet, WeightSetMultinomial, WeightSetOVR


class LogisticRegression:
    """Logistic regression model whose parameters are held in a ``WeightSet``.

    Every operation other than ``fit`` forwards directly to ``self.weight_set``.
    Subclasses replace ``weight_set`` with a specialised container
    (one-vs-rest or multinomial).
    """

    def __init__(self, context: Context, path: Path, scale_type: ScaleType = ScaleType.NORMAL) -> None:
        """Create a model backed by a plain ``WeightSet``.

        Args:
            context (Context): Context passed to the weight set.
            path (Path): Storage path of the weight set.
            scale_type (ScaleType): Scale type forwarded to ``WeightSet``. Default is ``ScaleType.NORMAL``.
        """
        self.context: Context = context
        self.weight_set: WeightSet = WeightSet(context, path, scale_type=scale_type)

    def fit(
        self,
        Z: Dataset,
        learning_rate: float = 0.01,
        num_epoch: int = 10,
        batch_size: int = 0,
        eta: float = 0.9,
    ) -> None:
        """Fit the model according to the given training data.

        Args:
            Z (Dataset): Training dataset.
            learning_rate (float): Learning rate of the weights. Default is 0.01.
            num_epoch (int): Number of calls to ``weight_set.learn`` (one per epoch). Default is 10.
            batch_size (int): Size of a batch. ``0`` (the default) means ``len(Z)``, i.e. one full batch.
            eta (float): Coefficient forwarded as ``eta`` to ``weight_set.learn``; kept constant
                across epochs. Default is 0.9.

        Raises:
            ValueError: If ``batch_size`` is negative.
        """
        if batch_size < 0:
            raise ValueError(f"batch_size must be non-negative, got {batch_size}")
        if batch_size == 0:
            batch_size = len(Z)
        # ceil(batch_size / rows per context block), capped at len(Z).
        num_batch = min(int(ceil(batch_size / self.context.shape[0])), len(Z))

        # Initialise the weights from the data only when the weight set is still empty.
        if not self.weight_set.objects:
            self.weight_set.fit(Z)

        print("learning rate:", learning_rate)
        print("eta:", eta)

        for idx in range(num_epoch):
            print(f"{idx}-th epoch")
            self.weight_set.learn(Z, num_batch, learning_rate, eta=eta)

    def predict(self, X: HedalMatrix, path: Path, normalize: bool = True, last_activation: bool = False) -> ReportSet:
        """Run prediction on ``X`` through the weight set.

        Args:
            X (HedalMatrix): Input features.
            path (Path): Path handed to ``weight_set.predict``; presumably where the resulting
                report set is stored — confirm against ``WeightSet.predict``.
            normalize (bool): Forwarded unchanged. Default is True.
            last_activation (bool): Forwarded unchanged. Default is False.

        Returns:
            ReportSet: Whatever ``weight_set.predict`` returns.
        """
        return self.weight_set.predict(X, path, normalize, last_activation)

    def encrypt(self) -> None:
        """Encrypt the weights (delegates to ``weight_set.encrypt``)."""
        self.weight_set.encrypt()

    def decrypt(self) -> None:
        """Decrypt the weights (delegates to ``weight_set.decrypt``)."""
        self.weight_set.decrypt()

    def save(self) -> None:
        """Persist the weights (delegates to ``weight_set.save``)."""
        self.weight_set.save()

    def load(self) -> None:
        """Load the weights (delegates to ``weight_set.load``)."""
        self.weight_set.load()

    def remove(self) -> None:
        """Remove the stored weights (delegates to ``weight_set.remove``)."""
        self.weight_set.remove()

    def level_down(self, target_level: int) -> None:
        """Forward ``target_level`` to ``weight_set.level_down``.

        Presumably lowers the ciphertext level of the encrypted weights — confirm in ``WeightSet``.
        """
        self.weight_set.level_down(target_level)


class LogisticRegressionOVR(LogisticRegression):
    """One-vs-rest logistic regression: one ``WeightParamOVR`` per class inside a ``WeightSetOVR``."""

    def __init__(self, context: Context, path: Path, scale_type: ScaleType = ScaleType.NORMAL) -> None:
        """Create a one-vs-rest model.

        Note: ``LogisticRegression.__init__`` is not called; ``weight_set`` is built directly
        as a ``WeightSetOVR``.

        Args:
            context (Context): Context passed to the weight set.
            path (Path): Storage path of the weight set.
            scale_type (ScaleType): Scale type forwarded to ``WeightSetOVR``. Default is ``ScaleType.NORMAL``.
        """
        self.context = context
        self.weight_set = WeightSetOVR(context, path, scale_type=scale_type)

    @staticmethod
    def from_path(
        context: Context, path: Path, scale_type: ScaleType = ScaleType.NORMAL
    ) -> LogisticRegressionOVR:
        """Build a model at ``path`` and load its weights from storage.

        Args:
            context (Context): Context of the model.
            path (Path): Path of the stored model.
            scale_type (ScaleType): Scale type of the weight set. Default is ``ScaleType.NORMAL``.

        Returns:
            LogisticRegressionOVR: Model with loaded weights.
        """
        model = LogisticRegressionOVR(context, path, scale_type=scale_type)
        model.load()
        return model

    @staticmethod
    def from_ndarray(
        context: Context,
        array: np.ndarray,
        path: Optional[Path] = None,
        scale_type: ScaleType = ScaleType.NORMAL,
    ) -> LogisticRegressionOVR:
        """Load model from ndarray of weights and returns model.

        Row ``i`` of ``array`` becomes the ``theta`` of class ``str(i)``.

        Args:
            context (Context): Context of the model.
            array (np.ndarray): ndarray of weights of shape (num_classes, num_features + 1),
                where the rightmost column is a bias column.
            path (Optional[Path], optional): Path of the model. Defaults to None.
            scale_type (ScaleType): Scale type of the weight set. Default is ``ScaleType.NORMAL``.

        Returns:
            LogisticRegressionOVR: One-vs-rest logistic regression model with loaded weights.

        Raises:
            ValueError: If ``array`` is not two-dimensional.
        """
        array = np.asarray(array)
        if array.ndim != 2:
            raise ValueError(f"array must be 2-D (num_classes, num_features + 1), got shape {array.shape}")
        size = array.shape[1]

        model = LogisticRegressionOVR(context, path, scale_type=scale_type)
        model.weight_set.size = size
        for class_idx, row in enumerate(array):
            class_name = str(class_idx)
            param = WeightParamOVR(context, path=model.weight_set.class_path(class_name), size=size)
            param["theta"] = HedalVector.from_ndarray(context, row, param.path / "theta")
            model.weight_set.objects[class_name] = param

        return model

    def to_ndarray(self) -> np.ndarray:
        """Stack the first row of each class's ``theta`` (as returned by ``HedalVector.to_ndarray``).

        Rows follow the iteration order of ``weight_set.values()``.
        """
        return np.array([value["theta"].to_ndarray()[0] for value in self.weight_set.values()])

    def to_dataframe(self) -> pd.DataFrame:
        """Return ``to_ndarray()`` wrapped in a ``pandas.DataFrame``."""
        return pd.DataFrame(self.to_ndarray())


class LogisticRegressionMultinomial(LogisticRegression):
    """Multinomial logistic regression: all class weights packed into one ``theta`` vector."""

    def __init__(self, context: Context, path: Path, classes: List, scale_type: ScaleType = ScaleType.NORMAL) -> None:
        """Create a multinomial model.

        Note: ``LogisticRegression.__init__`` is not called; ``weight_set`` is built directly
        as a ``WeightSetMultinomial``.

        Args:
            context (Context): Context passed to the weight set.
            path (Path): Storage path of the weight set.
            classes (List): List of classes.
            scale_type (ScaleType): Scale type forwarded to ``WeightSetMultinomial``. Default is ``ScaleType.NORMAL``.
        """
        self.context = context
        self.weight_set = WeightSetMultinomial(context, path=path, classes=classes, scale_type=scale_type)

    @staticmethod
    def from_path(
        context: Context, path: Path, classes: List, scale_type: ScaleType = ScaleType.NORMAL
    ) -> LogisticRegressionMultinomial:
        """Build a model at ``path`` and load its weights from storage.

        Args:
            context (Context): Context of the model.
            path (Path): Path of the stored model.
            classes (List): List of classes.
            scale_type (ScaleType): Scale type of the weight set. Default is ``ScaleType.NORMAL``.

        Returns:
            LogisticRegressionMultinomial: Model with loaded weights.
        """
        model = LogisticRegressionMultinomial(context, path=path, classes=classes, scale_type=scale_type)
        model.load()
        return model

    @staticmethod
    def from_ndarray(
        context: Context,
        array: np.ndarray,
        classes: List,
        path: Optional[Path] = None,
        scale_type: ScaleType = ScaleType.NORMAL,
    ) -> LogisticRegressionMultinomial:
        """Load model from ndarray of weights and returns model.

        The rows of ``array`` are zero-padded up to the next power of two, and the padded block
        is tiled vertically to fill ``context.shape[0]`` rows before being wrapped in a
        ``HedalVector`` stored as ``theta``.

        Args:
            context (Context): Context of the model.
            array (np.ndarray): ndarray of weights of shape (num_classes, num_features + 1),
                where the rightmost column is a bias column.
            classes (List): List of classes.
            path (Optional[Path], optional): Path of the model. Although it defaults to None,
                a path is required; ``None`` raises ``TypeError``.
            scale_type (ScaleType): Scale type of the weight set. Default is ``ScaleType.NORMAL``.

        Returns:
            LogisticRegressionMultinomial: Multinomial logistic regression model with loaded weights.

        Raises:
            TypeError: If ``path`` is None.
            ValueError: If ``array`` is not a non-empty 2-D array, or if the padded number of rows
                exceeds ``context.shape[0]``, which would leave nothing to tile.
        """
        if path is None:
            raise TypeError("LogisticRegressionMultinomial.from_ndarray requires a path; got None")
        path = Path(path)

        array = np.asarray(array)
        if array.ndim != 2 or array.shape[0] == 0:
            raise ValueError(f"array must be a non-empty 2-D array, got shape {array.shape}")
        num_rows, num_cols = array.shape

        # Smallest power of two >= num_rows, computed exactly with integers (no float log2).
        row_pad = 1 << (num_rows - 1).bit_length()
        context_rows = context.shape[0]
        if row_pad > context_rows:
            raise ValueError(
                f"padded class count {row_pad} exceeds the context row count {context_rows}"
            )

        model = LogisticRegressionMultinomial(context, path, classes=classes, scale_type=scale_type)
        model.weight_set.size = (row_pad, num_cols)
        model.weight_set.objects = WeightParamMultinomial(
            context, path=path / "param", size=array.shape, encrypted=False
        )

        padded = np.concatenate([array, np.zeros((row_pad - num_rows, num_cols))], axis=0)
        tiled = np.tile(padded, (context_rows // row_pad, 1))
        theta = HedalVector.from_ndarray(context, tiled, path=model.weight_set.path / "theta")
        # The logical shape is one padded block, not the full tiled buffer.
        theta.shape = (row_pad, num_cols)
        model.weight_set.objects.objects["theta"] = theta
        return model

    def to_ndarray(self) -> np.ndarray:
        """Return ``theta`` as an ndarray truncated to its first ``weight_set.num_classes`` rows."""
        res = self.weight_set.objects["theta"].to_ndarray()
        return np.array(res)[: self.weight_set.num_classes]

    def to_dataframe(self) -> pd.DataFrame:
        """Return ``to_ndarray()`` wrapped in a ``pandas.DataFrame``."""
        return pd.DataFrame(self.to_ndarray())
