from __future__ import annotations

import json
import os
from abc import ABC, abstractclassmethod, abstractmethod
from pathlib import Path
from typing import Dict, ItemsView, KeysView, List, Optional, ValuesView

import numpy as np
import pandas as pd
from pydantic import BaseModel

from hedal.context import Context
from hedal.core import config
from hedal.matrix.matrix import HedalMatrix


class ReportSet(ABC):
    """Abstract container for a set of (possibly encrypted) prediction reports.

    Concrete subclasses hold their scores in ``HedalMatrix`` objects and persist
    them under ``self.path`` together with a metadata file whose location is
    given by ``config.metadata_path``.

    Args:
        context: HE context; its ``shape`` is the block shape of the matrices.
        classes: Class labels, in the order used by the report.
        path: Storage directory. Defaults to ``config.temp_path()`` when ``None``.
        encrypted: Whether the held matrices are currently encrypted.
        num_data: Number of data items (rows) the report covers.
        description: Free-form description stored in the metadata.
    """

    def __init__(
        self,
        context: Context,
        classes: List,
        path: Optional[Path] = None,
        encrypted: bool = False,
        num_data: int = 0,
        description: str = "",
    ):
        self.context = context
        self.classes = classes
        # Fall back to a temporary location when the caller gives no path.
        self.path = config.temp_path() if path is None else Path(path)
        self.encrypted = encrypted
        self.num_data = num_data
        self.description = description

    @property
    def name(self) -> str:
        """Name of this report set as computed by ``config.name``."""
        return config.name(self)

    @property
    def block_shape(self) -> tuple:
        """Block shape of the underlying HE context (``context.shape``)."""
        return self.context.shape

    @property
    def num_classes(self) -> int:
        """Number of class labels."""
        return len(self.classes)

    class Metadata(BaseModel):
        """Serializable metadata written next to the report data."""

        encrypted: bool
        description: str
        num_data: int
        classes: List[str]

    def metadata(self) -> ReportSet.Metadata:
        """Build the metadata model describing the current state."""
        return self.Metadata(
            encrypted=self.encrypted, description=self.description, num_data=self.num_data, classes=self.classes
        )

    @property
    def metadata_path(self) -> Path:
        """Path of the metadata file, as resolved by ``config.metadata_path``."""
        return config.metadata_path(self)

    # NOTE: these hooks are instance methods, so they use ``abstractmethod``
    # (``abstractclassmethod`` is deprecated and would wrap them as classmethods).
    @abstractmethod
    def load(self) -> None:
        """Populate this report set from ``self.path``."""

    @abstractmethod
    def save(self, dst_path: Optional[Path] = None) -> None:
        """Persist this report set to ``dst_path`` (``self.path`` when ``None``)."""

    @staticmethod
    def from_path(context: Context, path: Path) -> ReportSet:
        """Construct and load a report set stored at ``path``.

        Concrete subclasses in this module override this with an extra
        ``classes`` argument.
        """
        raise NotImplementedError

    def copy(self, dst_parent_path: Optional[Path] = None, **kwargs) -> ReportSet:
        """Copy this report set under ``dst_parent_path`` (not implemented)."""
        raise NotImplementedError

    @abstractmethod
    def to_dataframe(self, *args, **kwargs) -> pd.DataFrame:
        """Return the decrypted report as a DataFrame."""

    @abstractmethod
    def encrypt(self) -> None:
        """Encrypt all held matrices in place."""

    @abstractmethod
    def decrypt(self) -> None:
        """Decrypt all held matrices in place."""


class ReportSetOVR(ReportSet):
    """One-vs-rest report set: one ``HedalMatrix`` of scores per class.

    ``self.objects`` maps each class name to its own matrix; the matrix for
    class ``key`` is stored under ``<dst_path>/key``. Column 0 of each matrix
    holds the per-row score used by ``to_dataframe``.
    """

    def __init__(
        self,
        context: Context,
        classes: List,
        path: Optional[Path] = None,
        encrypted: bool = False,
        num_data: int = 0,
        description: str = "",
    ):
        super().__init__(
            context, classes=classes, path=path, encrypted=encrypted, num_data=num_data, description=description
        )
        # Class name -> score matrix for that class.
        self.objects: Dict[str, HedalMatrix] = {}

    def keys(self) -> KeysView[str]:
        """Class names that currently have a matrix."""
        return self.objects.keys()

    def values(self) -> ValuesView[HedalMatrix]:
        """Per-class score matrices."""
        return self.objects.values()

    def items(self) -> ItemsView[str, HedalMatrix]:
        """``(class name, matrix)`` pairs."""
        return self.objects.items()

    def _sigmoid(self, x: np.ndarray) -> np.ndarray:
        """Element-wise logistic sigmoid.

        For very negative inputs ``np.exp(-x)`` overflows to ``inf``; the result
        ``1 / (1 + inf) == 0`` is the correct limit, so that overflow warning is
        silenced instead of being surfaced to callers.
        """
        with np.errstate(over="ignore"):
            return 1 / (1 + np.exp(-x))

    def to_dataframe(self, apply_sigmoid: bool = True, normalize: bool = False) -> pd.DataFrame:
        """Collect the decrypted per-class scores into a DataFrame.

        Args:
            apply_sigmoid: Map raw scores through the logistic sigmoid.
            normalize: Divide every row by its sum across classes.

        Returns:
            One column per class (in ``self.objects`` order) holding the first
            ``num_data`` entries of column 0 of that class's matrix.

        Raises:
            Exception: If the report set is still encrypted.
        """
        if self.encrypted:
            raise Exception("Do after decrypt")

        df = pd.DataFrame()
        for key, value in self.items():
            val_arr = value.to_ndarray()[: self.num_data, 0]
            if apply_sigmoid:
                val_arr = self._sigmoid(val_arr)
            df[key] = val_arr

        return df.divide(df.sum(axis=1), axis=0) if normalize else df

    def load(self) -> None:
        """Read metadata and every per-class matrix from ``self.path``.

        Any previously held matrices are replaced, so ``self.objects`` exactly
        mirrors the classes listed in the stored metadata.
        """
        with open(self.metadata_path, "r", encoding="utf-8") as m_file:
            m_info = json.load(m_file)
        self.encrypted = m_info["encrypted"]
        self.description = m_info["description"]
        self.num_data = m_info["num_data"]
        self.classes = m_info["classes"]
        self.objects = {key: HedalMatrix.from_path(self.context, self.path / key) for key in self.classes}

    def save(self, dst_path: Optional[Path] = None) -> None:
        """Save report.

        Directory structure:
            dst_path
            - metadata.json
            - class_0 (HedalMatrix)
                - metadata.json
                - 0 (HedalVector)
                    - block_0.bin
                    - block_1.bin
                    ...
                - 1 (HedalVector)
                    - block_0.bin
                    - block_1.bin
                    ...
                ...
            - class_1 (HedalMatrix)
                ...

        Args:
            dst_path (Optional[Path], optional): Path to save reports. If dst_path is None,
                then the report will be saved on `self.path`. Defaults to None.

        Raises:
            KeyError: If a class in ``self.classes`` has no matrix in ``self.objects``.
        """
        # Accept str / any path-like as well as Path.
        dst_path = self.path if dst_path is None else Path(dst_path)
        metadata_path = dst_path / config._metadata_file_name

        os.makedirs(dst_path, mode=0o775, exist_ok=True)

        for key in self.classes:
            self.objects[key].save(dst_path / key)

        with open(metadata_path, "w", encoding="utf-8") as m_file:
            m_file.write(self.metadata().json())

    @staticmethod
    def from_path(context: Context, classes: List, path: Path) -> ReportSetOVR:
        """Create a report set for ``path`` and load it from disk.

        ``classes`` is only the initial value; ``load`` replaces it with the
        classes recorded in the stored metadata.
        """
        reports = ReportSetOVR(context, classes=classes, path=path)
        reports.load()
        return reports

    def copy(self, dst_parent_path: Optional[Path] = None, **kwargs) -> ReportSetOVR:
        """Copy this report set under ``dst_parent_path`` (not implemented)."""
        raise NotImplementedError

    def encrypt(self) -> None:
        """Encrypt every per-class matrix in place.

        Raises:
            Exception: If already encrypted.
        """
        if self.encrypted:
            raise Exception("Already encrypted")
        for value in self.values():
            value.encrypt()
        self.encrypted = True

    def decrypt(self) -> None:
        """Decrypt every per-class matrix in place.

        Raises:
            Exception: If already decrypted.
        """
        if not self.encrypted:
            raise Exception("Already decrypted")
        for value in self.values():
            value.decrypt()
        self.encrypted = False

    def __len__(self) -> int:
        """Number of context-row blocks needed to hold ``num_data`` rows."""
        rows_per_block = self.context.shape[0]
        # Exact integer ceil-division (avoids float rounding on large counts).
        return -(-self.num_data // rows_per_block)

    def __getitem__(self, key: str) -> HedalMatrix:
        return self.objects[key]

    def __setitem__(self, key: str, value: HedalMatrix) -> None:
        # Re-home the matrix under this report set's directory.
        value.path = self.path / key
        self.objects[key] = value


class ReportSetMultinomial(ReportSet):
    """Multinomial report set: a single ``HedalMatrix`` of per-class scores.

    Row ``i``, column ``j`` of ``self.results`` is the raw score of data item
    ``i`` for ``self.classes[j]``; the matrix is stored under
    ``<dst_path>/results``.
    """

    def __init__(
        self,
        context: Context,
        classes: List,
        path: Optional[Path] = None,
        encrypted: bool = False,
        num_data: int = 0,
        description: str = "",
    ):
        super().__init__(
            context, classes=classes, path=path, encrypted=encrypted, num_data=num_data, description=description
        )
        # Score matrix; stays ``None`` until assigned or populated by ``load``.
        self.results: Optional[HedalMatrix] = None

    def copy(self, dst_parent_path: Optional[Path] = None, **kwargs) -> ReportSetMultinomial:
        """Copy this report set under ``dst_parent_path`` (not implemented)."""
        raise NotImplementedError

    def _softmax(self, x: np.ndarray) -> np.ndarray:
        """Row-wise softmax, shifted by each row's max for numerical stability."""
        x = x - x.max(axis=1, keepdims=True)
        x = np.exp(x)
        x = x / np.sum(x, axis=1, keepdims=True)
        return x

    def to_dataframe(self, apply_softmax: bool = True) -> pd.DataFrame:
        """Return the decrypted scores as a DataFrame with one column per class.

        Args:
            apply_softmax: Convert each row of raw scores to probabilities.

        Returns:
            ``num_data`` rows; column ``self.classes[j]`` holds column ``j`` of
            the (optionally softmaxed) score matrix.

        Raises:
            Exception: If the report set is still encrypted.
        """
        if self.encrypted:
            raise Exception("Do after decrypt")

        arr = self.results.to_ndarray()[: self.num_data, : self.num_classes]
        if apply_softmax:
            arr = self._softmax(arr)
        # Columns of ``arr`` are already aligned with ``self.classes``, so build
        # the frame in one step instead of a per-class ``list.index`` lookup.
        # ``copy=True`` keeps the frame independent of the matrix's buffer.
        return pd.DataFrame(arr, columns=list(self.classes), copy=True)

    def load(self) -> None:
        """Read metadata and the score matrix from ``self.path``."""
        with open(self.metadata_path, "r", encoding="utf-8") as m_file:
            m_info = json.load(m_file)
        self.encrypted = m_info["encrypted"]
        self.description = m_info["description"]
        self.num_data = m_info["num_data"]
        self.classes = m_info["classes"]
        self.results = HedalMatrix.from_path(self.context, self.path / "results")

    def save(self, dst_path: Optional[Path] = None) -> None:
        """Save report.

        Directory structure:
            dst_path
            - metadata.json
            - results (HedalMatrix)
                - metadata.json
                - 0 (HedalVector)
                    - block_0.bin
                    - block_1.bin
                    - ...
                - 1 (HedalVector)
                    - block_0.bin
                    - block_1.bin
                    - ...
                - ...

        Args:
            dst_path (Optional[Path], optional): Path to save reports. If dst_path is None,
                then the report will be saved on `self.path`. Defaults to None.
        """
        # Accept str / any path-like as well as Path.
        dst_path = self.path if dst_path is None else Path(dst_path)
        metadata_path = dst_path / config._metadata_file_name

        os.makedirs(dst_path, mode=0o775, exist_ok=True)

        self.results.save(dst_path / "results")
        with open(metadata_path, "w", encoding="utf-8") as m_file:
            m_file.write(self.metadata().json())

    def encrypt(self) -> None:
        """Encrypt the score matrix in place.

        Raises:
            Exception: If already encrypted.
        """
        if self.encrypted:
            raise Exception("Already encrypted")
        self.results.encrypt()
        self.encrypted = True

    def decrypt(self) -> None:
        """Decrypt the score matrix in place.

        Raises:
            Exception: If already decrypted.
        """
        if not self.encrypted:
            raise Exception("Already decrypted")
        self.results.decrypt()
        self.encrypted = False

    @staticmethod
    def from_path(context: Context, classes: List, path: Path) -> ReportSetMultinomial:
        """Create a report set for ``path`` and load it from disk.

        ``classes`` is only the initial value; ``load`` replaces it with the
        classes recorded in the stored metadata.
        """
        reports = ReportSetMultinomial(context, classes=classes, path=path)
        reports.load()
        return reports
