from functools import partial
from typing import Tuple, Dict, Union, Any

import jax
import numpy as np
from jax import numpy as jnp
from jax import lax
from common.policies import Model
from common.type_aliases import Params


@partial(jax.jit, static_argnames=("deterministic", "max_epi_len"))
def jax_update(
    rng: jnp.ndarray,
    model: Model,
    input_tokens: Union[np.ndarray, jnp.ndarray],
    rewards: Union[np.ndarray, jnp.ndarray],
    dummys: Union[np.ndarray, jnp.ndarray],
    succs: Union[np.ndarray, jnp.ndarray],
    reward_masks: Union[np.ndarray, jnp.ndarray, None] = None,
    max_epi_len: int = 7,
    deterministic: bool = True,
) -> Tuple[Model, Dict[str, Any]]:
    """Take one gradient step regressing ensembled per-step rewards onto ``succs``.

    The model produces one weight per reward source for every step of every
    episode. Those weights are normalized by their sum over the reward sources
    selected by ``reward_masks`` and used to combine ``rewards`` into a single
    reward per step. The per-episode sum of those rewards, divided by twice the
    number of non-padding steps, is fitted to ``succs`` with a mean-squared-error
    loss.

    Args:
        rng: PRNG key. It is split once and one half is used as the dropout key.
            The advanced key is not returned, so the caller must supply a
            fresh key on each call if it wants different dropout masks.
        model: model wrapper; this function uses its ``apply_fn`` and
            ``apply_gradient``.
        input_tokens: model inputs, passed as ``input_emb``. The model output
            must reshape to ``(batch, max_epi_len, n_rewards)``, so the inputs
            are presumably flattened to ``batch * max_epi_len`` rows. Confirm
            against the model.
        rewards: per-step rewards, ``(batch, max_epi_len, n_rewards)``.
        dummys: ``(batch, max_epi_len)`` padding indicator. It is assumed to be
            0/1, with 1 marking a padding step. Rewards on those steps are zeroed.
        succs: per-episode regression targets, ``(batch,)``.
        reward_masks: ``(batch, max_epi_len, n_rewards)`` mask applied to the
            weights when computing the normalizing sum. ``None`` (the default)
            means every reward source is included.
        max_epi_len: padded episode length. It is static under ``jit``.
        deterministic: forwarded to ``model.apply_fn``. It is static under ``jit``.

    Returns:
        The ``(new_model, info)`` pair produced by ``model.apply_gradient``.
        ``info`` is presumably the auxiliary dict returned by the inner
        ``loss_fn``, but that depends on ``Model.apply_gradient``; confirm.
    """
    _, dropout_key = jax.random.split(rng)
    batch_size = rewards.shape[0]

    dummy_col = jnp.expand_dims(dummys, axis=-1)  # [b, max_epi_len, 1]
    rewards = rewards * (1 - dummy_col)  # zero the rewards on padding steps

    # With no mask supplied, every reward source takes part in the normalization.
    # The `is None` check runs at trace time, which is fine under jit.
    masks = jnp.ones_like(rewards) if reward_masks is None else reward_masks

    # Number of non-padding steps per episode, shape [b].
    n_valid = max_epi_len - jnp.sum(dummys, axis=-1)
    # Clamp to at least one step. An all-padding episode would otherwise yield
    # 0 / 0 = NaN in its predicted return, which poisons the mean loss and
    # every gradient in the batch. Such an episode now predicts 0.
    # NOTE(review): the factor 2 is kept from the original; its meaning is not
    # visible here, so confirm it is intended.
    return_denominator = 2 * jnp.maximum(n_valid, 1)

    def loss_fn(params: Params) -> Tuple[jnp.ndarray, Dict]:
        """MSE between predicted episode returns and ``succs``; returns (loss, info)."""
        weights = model.apply_fn(
            {"params": params},
            input_emb=input_tokens,
            deterministic=deterministic,
            rngs={"dropout": dropout_key},
        )  # [b * max_epi_len, n_rewards]
        weights = weights.reshape(batch_size, max_epi_len, -1)  # [b, max_epi_len, n_rewards]

        # Normalize by the masked weight sum. Padding steps get +1 added, so a
        # zero masked sum there does not produce 0/0 (their rewards are zero anyway).
        normalizing_factors = (
            jnp.sum(masks * weights, axis=-1, keepdims=True) + dummy_col
        )  # [b, max_epi_len, 1]
        # NOTE(review): only the denominator is masked. Masked-out sources with
        # nonzero rewards still contribute to the numerator; confirm this is intended.
        normalized_weights = weights / normalizing_factors  # [b, max_epi_len, n_rewards]

        ensembled_reward = jnp.sum(rewards * normalized_weights, axis=-1)  # [b, max_epi_len]
        predicted_return = jnp.sum(ensembled_reward, axis=-1) / return_denominator  # [b]

        loss = jnp.mean((predicted_return - succs) ** 2)

        info = {
            "__likelihood": weights,
            "__normalizing_factors": normalizing_factors,
            "__predicted_return": predicted_return,
            "loss": loss,
        }
        return loss, info

    new_model, info = model.apply_gradient(loss_fn=loss_fn)
    return new_model, info
