from dataclasses import dataclass
from typing import Sequence, Any, Dict, Union, Callable, Optional

import flax
from jax import numpy as jnp
import numpy as np


@dataclass
class MultiModalEncoderOutput:
    """Container for two optional array fields, both defaulting to ``None``.

    Each field accepts either a JAX array (``jnp.ndarray``) or a NumPy array
    (``np.ndarray``). Judging by the names, this presumably carries the result
    of a multimodal encoder -- TODO confirm against the producing code.

    NOTE(review): plain ``@dataclass`` generates an ``__eq__`` that compares
    the fields as tuples; with multi-element array fields that comparison can
    raise on an ambiguous truth value. Compare fields explicitly instead of
    using ``==`` on instances.
    """

    # Embedding array; shape and dtype are not fixed here -- TODO confirm.
    multimodal_embeddings: Optional[Union[jnp.ndarray, np.ndarray]] = None
    # Mask array, presumably paired with `multimodal_embeddings`; shape,
    # dtype and mask polarity are not fixed here -- TODO confirm.
    embedding_masks: Optional[Union[jnp.ndarray, np.ndarray]] = None


# Shared type aliases used for annotations.

# Deliberately untyped (`Any`); presumably a JAX random key -- TODO confirm.
PRNGKey = Any
# Any sequence of ints (e.g. a tuple or list of dimension sizes).
Shape = Sequence[int]
# Deliberately untyped (`Any`); presumably a NumPy/JAX dtype specifier.
Dtype = Any
# Deliberately untyped (`Any`); presumably any array-like value.
Array = Any
# Flax immutable mapping from string keys to arbitrary values.
Params = flax.core.FrozenDict[str, Any]
# Plain dict keyed by either str or int, holding JAX arrays.
TensorDict = Dict[Union[str, int], jnp.ndarray]
# Callable taking one JAX array and returning a JAX array.
Activation = Callable[[jnp.ndarray], jnp.ndarray]
# Either a Flax linen module or a bare `Activation` callable.
ModuleLike = Union[flax.linen.Module, Activation]
