from __future__ import annotations

import os
import shutil
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, ItemsView, Iterator, KeysView, List, Optional, Tuple, Union, ValuesView

import numpy as np
import pandas as pd
from pydantic import BaseModel

from hedal.block import Block
from hedal.context import Context
from hedal.core import config
from hedal.core.config import PackingType
from hedal.core.object import Object


class ObjectSequence(ABC):
    """Abstract base class for a directory-backed collection of objects.

    This base class manages the sequence's on-disk location (``path``) and the
    attributes shared by all sequences. Subclasses hold the objects themselves
    and implement loading, copying, encryption and decryption.

    Args:
        context: Context the sequence belongs to; ``num_slots`` and ``shape``
            are read from it.
        **kwargs: Optional attributes:
            path: Directory of the sequence. A missing or falsy value falls
                back to ``config.temp_path()``.
            encrypted: Encryption flag. Missing or falsy -> ``False``.
            type: ``PackingType`` of the sequence. Missing or falsy ->
                ``PackingType.FRAME``.
            description: Free-form text. Missing -> ``""``.
    """

    def __init__(self, context: Context, **kwargs):
        self.context = context
        self._path: Path = Path(kwargs["path"]) if kwargs.get("path") else config.temp_path()
        self.encrypted: bool = kwargs.get("encrypted") or False
        self.type: PackingType = kwargs.get("type") or PackingType.FRAME
        # Unlike the options above, an explicitly passed falsy description
        # (e.g. None) is kept as-is; only a missing key falls back to "".
        self.description: str = kwargs.get("description", "")

    @property
    def num_slots(self) -> int:
        """Number of slots, taken from the context."""
        return self.context.num_slots

    @property
    def block_shape(self) -> Tuple[int, int]:
        """Shape of a single block for the current packing type.

        Returns:
            ``(num_slots, 1)`` for ``PackingType.FRAME``; ``context.shape`` for
            ``PackingType.MATRIX``.

        Raises:
            TypeError: If ``self.type`` is neither FRAME nor MATRIX.
        """
        if self.type == PackingType.FRAME:
            return (self.context.num_slots, 1)
        elif self.type == PackingType.MATRIX:
            return self.context.shape
        else:
            raise TypeError("Invalid type")

    @property
    def path(self) -> Path:
        """Directory of the sequence."""
        return Path(self._path)

    @path.setter
    def path(self, new_path: Path) -> None:
        self._path = Path(new_path)

    @property
    def metadata_path(self) -> Path:
        """Location of the metadata file, as computed by ``config.metadata_path``."""
        return config.metadata_path(self)

    @property
    def parent_path(self) -> Path:
        """Parent directory, as computed by ``config.parent_path``."""
        return config.parent_path(self)

    @property
    def name(self) -> str:
        """Name of the sequence, as computed by ``config.name``.

        Assigning to it renames the directory on disk (see ``rename``).
        """
        return config.name(self)

    @name.setter
    def name(self, new_name: str) -> None:
        self.rename(new_name)

    def move(self, dst_parent_path: Path) -> ObjectSequence:
        """Move the sequence directory into ``dst_parent_path`` and reload it.

        Does nothing when the sequence already lives in ``dst_parent_path``.

        Args:
            dst_parent_path: New parent directory.

        Returns:
            ``self``, to allow chaining.

        Raises:
            FileExistsError: If ``dst_parent_path / self.name`` already exists
                (a subclass of ``OSError``, so existing handlers still match).
        """
        dst_parent_path = Path(dst_parent_path)
        if self.parent_path != dst_parent_path:
            dst_path = dst_parent_path / self.name
            if dst_path.exists():
                raise FileExistsError(f"Already exists: {dst_path}")
            shutil.move(str(self.path), str(dst_path))
            self.path = dst_path
            self.load()
        return self

    def rename(self, new_name: str) -> ObjectSequence:
        """Rename the sequence directory (same parent) and reload it.

        Does nothing when ``new_name`` equals the current name.

        Args:
            new_name: New directory name.

        Returns:
            ``self``, to allow chaining.

        Raises:
            FileExistsError: If ``self.parent_path / new_name`` already exists
                (a subclass of ``OSError``, so existing handlers still match).
        """
        if self.name != new_name:
            new_path = self.parent_path / new_name
            if new_path.exists():
                raise FileExistsError(f"Already exists: {new_path}")
            self.path.rename(new_path)
            self.path = new_path
            self.load()
        return self

    def remove(self) -> None:
        """Recursively delete the sequence directory from disk."""
        shutil.rmtree(self.path)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback) -> bool:
        # Returning a truthy value here would silently swallow any exception
        # raised inside the ``with`` body; return False so it propagates.
        return False

    def __delete__(self, instance) -> None:
        # NOTE(review): ``__delete__`` is the descriptor-protocol hook (invoked
        # on ``del owner.attr`` when this object is stored as a class
        # attribute); it is NOT a destructor and does not run when a sequence
        # is deleted or garbage-collected. Kept unchanged for compatibility.
        shutil.rmtree(self.path)

    class Metadata(BaseModel):
        ...

    def metadata(self):
        """Return the metadata model for this sequence (implemented by subclasses)."""
        raise NotImplementedError

    @abstractmethod
    def load(self):
        """Reload the sequence from ``self.path``."""
        pass

    @staticmethod
    @abstractmethod
    def from_path(context: Context, path: Path) -> ObjectSequence:
        """Construct a sequence from an existing directory."""
        pass

    @abstractmethod
    def copy(self, dst_parent_path: Optional[Path] = None, **kwargs) -> ObjectSequence:
        """Return a copy of the sequence, optionally under ``dst_parent_path``."""
        pass

    @abstractmethod
    def encrypt(self) -> None:
        pass

    @abstractmethod
    def decrypt(self) -> None:
        pass

    @abstractmethod
    def __len__(self) -> int:
        pass

    def save(self, dst_parent_path: Optional[Union[str, Path]] = None) -> None:
        raise NotImplementedError

    @staticmethod
    def from_ndarray(context: Context, array: np.ndarray, path: Optional[Path] = None):
        raise NotImplementedError

    def to_ndarray(self) -> np.ndarray:
        raise NotImplementedError

    @staticmethod
    def from_series(context: Context, series: pd.Series, path: Optional[Path] = None):
        raise NotImplementedError

    def to_series(self) -> pd.Series:
        raise NotImplementedError


class ObjectDict(ObjectSequence):
    """String-keyed collection of ``Object``s stored under ``self.path``.

    Behaves like a read/write mapping: ``keys``/``values``/``items``,
    ``d[key]``, ``d[key] = obj``, ``len(d)``, ``key in d`` and iteration over
    keys. Assigning an object sets its ``path`` to ``self.path / key``.
    """

    def __init__(self, context: Context, **kwargs):
        super(ObjectDict, self).__init__(context, **kwargs)
        self.objects: Dict[str, Object] = {}

    def keys(self) -> KeysView[str]:
        return self.objects.keys()

    def values(self) -> ValuesView[Object]:
        return self.objects.values()

    def items(self) -> ItemsView[str, Object]:
        return self.objects.items()

    def copy_memory(self) -> ObjectDict:
        """Return an in-memory copy of this dictionary (implemented by subclasses)."""
        raise NotImplementedError

    def level_down(self, target_level: int, inplace: bool = True) -> ObjectDict:
        """Call ``level_down(target_level)`` on every stored object.

        Args:
            target_level: Level passed to each object's ``level_down``.
            inplace: If False, operate on ``self.copy_memory()`` instead of
                ``self``.

        Returns:
            The dictionary that was modified (``self`` when ``inplace``).
        """
        objects = self if inplace else self.copy_memory()
        for obj in objects.values():
            obj.level_down(target_level)
        return objects

    def encrypt(self, keys: Optional[List[str]] = None) -> None:
        """Encrypt the objects under ``keys`` (``None`` or empty -> every key).

        Note that ``self.encrypted`` is set to True even when only a subset of
        keys was encrypted.
        """
        if not keys:
            keys = list(self.keys())
        for key in keys:
            self[key].encrypt()
        self.encrypted = True

    def decrypt(self, keys: Optional[List[str]] = None) -> None:
        """Decrypt the objects under ``keys`` (``None`` or empty -> every key).

        Note that ``self.encrypted`` is set to False even when only a subset of
        keys was decrypted.
        """
        if not keys:
            keys = list(self.keys())
        for key in keys:
            self[key].decrypt()
        self.encrypted = False

    def save(self, dst_parent_path: Optional[Union[str, Path]] = None) -> None:
        """Save the dictionary under ``dst_parent_path / self.name``.

        Each value is saved with ``value.save(dst_path)``, then the metadata
        JSON is written into the same directory.

        Args:
            dst_parent_path: Destination parent directory. Defaults to
                ``self.parent_path`` (the metadata then goes to
                ``self.metadata_path``).
        """
        if isinstance(dst_parent_path, str):
            dst_parent_path = Path(dst_parent_path)
        if dst_parent_path:
            metadata_path = dst_parent_path / self.name / config._metadata_file_name
        else:
            metadata_path = self.metadata_path
            dst_parent_path = self.parent_path
        dst_path = dst_parent_path / self.name
        # exist_ok=True already covers the "directory exists" case.
        os.makedirs(dst_path, mode=0o775, exist_ok=True)

        for value in self.values():
            value.save(dst_path)

        with open(metadata_path, "w") as m_file:
            m_file.write(self.metadata().json())

    def __len__(self) -> int:
        return len(self.objects)

    def __iter__(self) -> Iterator[str]:
        # Without this, ``for k in d`` falls back to the legacy sequence
        # protocol (``self[0]``, ``self[1]``, ...) and raises KeyError.
        return iter(self.objects)

    def __contains__(self, key: object) -> bool:
        # Same reason as ``__iter__``: ``key in d`` would otherwise iterate via
        # integer ``__getitem__`` calls and raise KeyError.
        return key in self.objects

    def __getitem__(self, key: str) -> Object:
        return self.objects[key]

    def __setitem__(self, key: str, obj: Object) -> None:
        """Store ``obj`` under ``key`` and point its ``path`` at ``self.path / key``.

        An existing entry is removed first, so a re-assigned key moves to the
        end of the insertion order.
        """
        self.objects.pop(key, None)
        obj.path = self.path / key
        self.objects[key] = obj


class ObjectList(ObjectSequence):
    """Index-addressed collection of ``Object``s stored under ``self.path``.

    The object at index ``i`` has its ``path`` set to ``self.path / str(i)``
    when assigned through ``__setitem__``.
    """

    def __init__(self, context: Context, **kwargs):
        super(ObjectList, self).__init__(context, **kwargs)
        self.objects: List[Object] = []

    @property
    def level(self) -> int:
        """Level of the first object; raises IndexError when the list is empty."""
        return self.objects[0].level

    def copy_memory(self) -> ObjectList:
        """Return an in-memory copy of this list (implemented by subclasses)."""
        raise NotImplementedError

    def level_down(self, target_level: int, inplace: bool = True) -> ObjectList:
        """Call ``level_down(target_level)`` on every stored object.

        Args:
            target_level: Level passed to each object's ``level_down``.
            inplace: If False, operate on ``self.copy_memory()`` instead of
                ``self``.

        Returns:
            The list that was modified (``self`` when ``inplace``).
        """
        obj_list = self if inplace else self.copy_memory()
        for obj in obj_list:
            obj.level_down(target_level)
        return obj_list

    def need_bootstrap(self, cost_per_iter: int = 2) -> bool:
        """Delegate to the first object's ``need_bootstrap``."""
        return self.objects[0].need_bootstrap(cost_per_iter)

    def save(self, dst_path: Optional[Union[str, Path]] = None) -> None:
        """Save ObjectSequence to the destination path.

        Directory structure:
            dst_path
                - metadata.json
                - 0 (HedalVector)
                    - ...
                - 1 (HedalVector)
                    - ...
                - ...

        Args:
            dst_path (Optional[Union[str, Path]], optional): Destination path.
                Defaults to ``self.path``.
        """
        if isinstance(dst_path, str):
            dst_path = Path(dst_path)
        if dst_path is None:
            dst_path = self.path
        metadata_path = dst_path / config._metadata_file_name

        # exist_ok=True already covers the "directory exists" case.
        os.makedirs(dst_path, mode=0o775, exist_ok=True)

        for idx, value in enumerate(self.objects):
            value.save(dst_path / str(idx))

        with open(metadata_path, "w") as m_file:
            m_file.write(self.metadata().json())

    def encrypt(self) -> None:
        """Encrypt every object; raises if the list is already encrypted."""
        if self.encrypted:
            raise Exception("Already encrypted")
        for value in self:
            value.encrypt()
        self.encrypted = True

    def decrypt(self) -> None:
        """Decrypt every object; raises if the list is not encrypted."""
        if not self.encrypted:
            raise Exception("Already decrypted")
        for value in self:
            value.decrypt()
        self.encrypted = False

    def bootstrap(self, one_slot: bool = False, complex: bool = False) -> None:
        """Bootstrap every block of every object in the list.

        Args:
            one_slot: Forwarded to each object's ``bootstrap`` when ``complex``
                is True; ignored otherwise.
            complex: If True, bootstrap each object individually. Otherwise the
                blocks of all objects are flattened in order and bootstrapped
                two at a time with ``Block.bootstrap_two_ctxts``; a leftover
                odd block is bootstrapped on its own.
        """
        if complex:
            for value in self:
                value.bootstrap(one_slot=one_slot, complex=complex)
            return

        # Flatten (object, block) order exactly as before: all blocks of
        # object 0, then object 1, ... and pair consecutive blocks.
        blocks = [block for obj in self.objects for block in obj.block_list]
        for first, second in zip(blocks[0::2], blocks[1::2]):
            Block.bootstrap_two_ctxts(first, second)
        if len(blocks) % 2 == 1:
            blocks[-1].bootstrap()

    def __getitem__(self, idx: int) -> Object:
        """Return the object at ``idx``; negative indices are rejected."""
        if (idx >= len(self)) or (idx < 0):
            raise IndexError("Out of range")
        return self.objects[idx]

    def __setitem__(self, idx: int, obj: Object) -> None:
        """Replace the object at ``idx`` and set its path to ``self.path / str(idx)``."""
        if (idx >= len(self)) or (idx < 0):
            raise IndexError("Out of range")
        obj.path = self.path / str(idx)
        self.objects[idx] = obj

    def __iter__(self) -> Iterator[Object]:
        """Return a fresh, independent iterator over the stored objects.

        Previously the list was its own iterator with a single shared cursor,
        so nested loops over the same list (or a method that iterates while a
        loop is running) reset the cursor and cut the outer loop short.
        """
        # Reset the legacy cursor so ``iter(obj_list); next(obj_list)`` still
        # behaves as before.
        self._idx = 0
        return iter(self.objects)

    def __next__(self) -> Object:
        # Legacy cursor-style iteration kept for backward compatibility;
        # ``for`` loops use the independent iterator from ``__iter__``.
        if self._idx >= len(self):
            raise StopIteration
        index = self._idx
        self._idx += 1
        return self[index]
