from typing import Dict, List, Union

import jax
from jax import numpy as jnp
import numpy as np
from models.common.interfaces import IJaxSavable
from models.common.type_aliases import Params
from models.common.utils import Model, set_seed


class BasePolicy(IJaxSavable):
    """Base class for JAX policies that support saving and loading.

    Stores the seed, the config, a JAX PRNG key and an update counter, and
    implements the save/load hooks by collecting the ``.params`` attribute
    of every object named in ``param_components``.
    """

    def __init__(
        self,
        seed: int,
        cfg: Dict
    ):
        """Seed the process and initialize bookkeeping state.

        Args:
            seed: Seed passed to ``set_seed`` and used to build ``self.rng``.
            cfg: Configuration mapping, stored unchanged on ``self.cfg``.
        """
        set_seed(seed)
        self.seed = seed
        self.cfg = cfg
        self.rng = jax.random.PRNGKey(seed)
        # Update counter, starts at 0. Not modified in this class;
        # presumably incremented by subclasses during training.
        self.n_update = 0
        # Attribute names reported by ``_excluded_save_params``.
        # NOTE(review): presumably IJaxSavable skips these when serializing
        # the object itself, because the model's params are saved separately
        # through ``_get_save_params`` — confirm against IJaxSavable.
        self.excluded_save_params: List[str] = ["model"]
        # Attribute names whose ``.params`` are collected for save/load.
        self.param_components: List[str] = ["model"]

    def get_param_components(self) -> List[str]:
        """Return the attribute names whose params are saved and loaded.

        The internal list itself is returned, not a copy.
        """
        return self.param_components

    def predict(self, *args, **kwargs) -> Union[np.ndarray, jnp.ndarray]:
        """Predict from the given inputs; to be overridden by subclasses.

        This base implementation does nothing and returns ``None``.
        """
        return None

    def _excluded_save_params(self) -> List[str]:
        """Return the attribute names listed in ``excluded_save_params``."""
        return self.excluded_save_params

    def _get_save_params(self) -> Dict[str, Params]:
        """Collect the parameters of every configured component.

        Returns:
            Mapping from component attribute name to that component's
            ``.params``.

        Raises:
            ValueError: If no param components are configured.
            AttributeError: If a listed component attribute does not exist
                on this instance or has no ``params`` attribute.
        """
        param_components = self.get_param_components()
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently save an empty params dict.
        if not param_components:
            raise ValueError(
                f"{type(self).__name__} has no param components to save; "
                "'param_components' must name at least one attribute."
            )
        # Each named attribute is expected to be a ``Model`` exposing ``.params``.
        return {name: getattr(self, name).params for name in param_components}

    def _get_load_params(self) -> List[str]:
        """Return the component names whose params should be restored.

        This is the same list used for saving (``get_param_components``).
        """
        return self.get_param_components()
