from functools import partial
from typing import Tuple, Dict, Union, Any

import jax
import numpy as np
import optax
from jax import numpy as jnp

from models.common.type_aliases import Params
from models.common.utils import Model


@jax.jit
def soft_update(q_net: Model, q_net_target: Model, tau: float) -> Model:
    """Blend the online network's parameters into the target network.

    Every leaf of the target parameter tree becomes
    ``tau * online_leaf + (1 - tau) * target_leaf``.

    Args:
        q_net: Model whose ``params`` are blended in.
        q_net_target: Model whose ``params`` are being updated.
        tau: Weight given to the online parameters. ``tau=1`` yields the
            online parameters, ``tau=0`` leaves the target unchanged.

    Returns:
        The result of ``q_net_target.replace(params=...)`` with the blended
        parameter tree.
    """
    # The top-level ``jax.tree_map`` alias was deprecated and later removed;
    # ``jax.tree_util.tree_map`` is available on both old and new JAX releases.
    new_target_params = jax.tree_util.tree_map(
        lambda p, tp: p * tau + tp * (1 - tau),
        q_net.params,
        q_net_target.params,
    )
    return q_net_target.replace(params=new_target_params)


@partial(jax.jit, static_argnames=("deterministic",))
def jax_update_imitation(
    rng: jnp.ndarray,
    mm_student: Model,
    multimodal_embeddings: Union[np.ndarray, jnp.ndarray],  # [b, l, d]
    embedding_masks: Union[np.ndarray, jnp.ndarray],  # [b, l]
    label: Union[np.ndarray, jnp.ndarray],  # [b,]
    deterministic: bool = True
) -> Tuple[Model, Dict[str, Any]]:
    """Run one cross-entropy update of the student model.

    Args:
        rng: PRNG key; split once to derive the dropout key.
        mm_student: Model being trained. Its ``apply_fn`` output is treated
            as unnormalized logits of shape [b, n_skills].
        multimodal_embeddings: Input embeddings, [b, l, d].
        embedding_masks: Masks for the embeddings, [b, l].
        label: Target class indices, [b,]. Used as gather indices, so it
            must have an integer dtype.
        deterministic: Forwarded to ``apply_fn``; static under ``jit``.

    Returns:
        The ``(new_model, info)`` pair returned by
        ``mm_student.apply_gradient``. The auxiliary dict built here holds
        ``"ce_loss"`` (scalar) and ``"__likelihood"`` (per-class
        log-probabilities, [b, n_skills]).
    """
    _, dropout_key = jax.random.split(rng)
    label = label.reshape(-1, 1)  # [b, 1] so it can index along the class axis

    def loss_fn(params: Params) -> Tuple[jnp.ndarray, Dict]:
        predictions = mm_student.apply_fn(
            {"params": params},
            multimodal_embeddings=multimodal_embeddings,
            embedding_masks=embedding_masks,
            deterministic=deterministic,
            rngs={"dropout": dropout_key}
        )  # [b, n_skills]

        # log_softmax is computed in a numerically stable way.
        # log(softmax(x)) underflows to log(0) = -inf when one logit dominates,
        # which makes the loss infinite and the gradients NaN.
        log_probs = jax.nn.log_softmax(predictions, axis=-1)  # [b, n_skills]

        # Negative log-likelihood of the labelled class, averaged over the batch.
        ce_loss = - jnp.mean(jnp.take_along_axis(log_probs, label, axis=-1))

        _info = {
            "__likelihood": log_probs,
            "ce_loss": ce_loss
        }

        return ce_loss, _info

    new_model, info = mm_student.apply_gradient(loss_fn=loss_fn)

    return new_model, info


@partial(jax.jit, static_argnames=("gamma", "deterministic"))
def jax_update_dqn(
    rng: jnp.ndarray,
    q_net: Model,
    q_net_target: Model,
    observations: jnp.ndarray,  # [b, l, d]: Multimodal embedding
    observation_masks: jnp.ndarray,  # [b, l]: Multimodal embedding masks
    actions: jnp.ndarray,  # [b, ]: Skill indices
    rewards: jnp.ndarray,  # [b, ]
    next_observations: jnp.ndarray,  # [b, l, d]: Multimodal embedding
    next_observation_masks: jnp.ndarray,  # [b, l]: Multimodal embedding masks
    dones: jnp.ndarray,  # [b, ]
    gamma: float,
    deterministic: bool = True
) -> Tuple[Model, Dict[str, Any]]:
    """Run one Double-DQN update of ``q_net`` with a Huber TD loss.

    The online network chooses the greedy next action and the target network
    evaluates it. The TD target is
    ``rewards + (1 - dones) * gamma * Q_target(s', argmax_a Q(s', a))``.

    Args:
        rng: PRNG key; split once to derive the dropout key.
        q_net: Online Q-network being trained.
        q_net_target: Target Q-network used to evaluate the next action.
        observations: Current observations, [b, l, d].
        observation_masks: Masks for ``observations``, [b, l].
        actions: Taken skill indices, [b,].
        rewards: Rewards, [b,].
        next_observations: Next observations, [b, l, d].
        next_observation_masks: Masks for ``next_observations``, [b, l].
        dones: Episode-termination flags, [b,].
        gamma: Discount factor; static under ``jit``.
        deterministic: Forwarded to the networks; static under ``jit``.

    Returns:
        The ``(new_model, info)`` pair returned by ``q_net.apply_gradient``.
        In the auxiliary dict built here, ``"__current_q_values"``,
        ``"__next_q_values"`` and ``"__target_q_values"`` are all [b,].
    """
    _, dropout_key = jax.random.split(rng)
    # NOTE(review): the same dropout key drives every forward pass below, so
    # with deterministic=False the dropout masks are not independent.

    target_next_q_values = q_net_target(
        multimodal_embeddings=next_observations,
        embedding_masks=next_observation_masks,
        deterministic=deterministic,
        rngs={"dropout": dropout_key}
    )  # [b, n_skills]

    online_next_q_values = q_net(
        multimodal_embeddings=next_observations,
        embedding_masks=next_observation_masks,
        deterministic=deterministic,
        rngs={"dropout": dropout_key}
    )  # [b, n_skills]

    # Action selection by the online net, evaluation by the target net.
    max_next_actions = jnp.argmax(online_next_q_values, axis=1, keepdims=True)  # [b, 1]
    next_q_values = jnp.take_along_axis(
        target_next_q_values,
        max_next_actions,
        axis=1
    ).reshape(-1)  # [b, ]

    # Everything entering the target must be rank-1. A [b, 1] operand mixed
    # with [b] operands silently broadcasts to [b, b], which would make the
    # loss an average over all b*b (prediction, target) pairs.
    flat_rewards = rewards.reshape(-1)  # [b, ]
    flat_dones = dones.reshape(-1)  # [b, ]
    target_q_values = flat_rewards + (1 - flat_dones) * gamma * next_q_values  # [b, ]

    def loss_fn(params: Params) -> Tuple[jnp.ndarray, Dict]:
        # Q-value estimates for every skill in the current state.
        original_current_q_values = q_net.apply_fn(
            {"params": params},
            multimodal_embeddings=observations,
            embedding_masks=observation_masks,
            deterministic=deterministic,
            rngs={"dropout": dropout_key}
        )  # [b, n_skills]

        # Keep only the Q-value of the action stored in the replay buffer.
        current_q_values = jnp.take_along_axis(
            original_current_q_values,
            actions.reshape(-1, 1),
            axis=1
        ).reshape(-1, )  # [b, ]

        loss = optax.huber_loss(current_q_values, target_q_values).mean()

        _info = {
            "dqn_loss": loss,
            "__current_q_values": current_q_values,
            "__next_q_values": next_q_values,
            "__target_q_values": target_q_values,
            "__original_current_q_values": original_current_q_values,
            "__rewards": rewards,
            "__dones": dones,
        }

        return loss, _info

    new_model, info = q_net.apply_gradient(loss_fn=loss_fn)
    return new_model, info


@partial(jax.jit, static_argnames=("gamma", "alpha", "deterministic"))
def jax_update_irl(
    rng: jnp.ndarray,
    q_net: Model,
    q_net_target: Model,
    observations: jnp.ndarray,  # [b, l, d]: Multimodal embedding
    observation_masks: jnp.ndarray,  # [b, l]: Multimodal embedding masks
    actions: jnp.ndarray,
    next_observations: jnp.ndarray,  # [b, l, d]: Multimodal embedding
    next_observation_masks: jnp.ndarray,  # [b, l]: Multimodal embedding masks
    dones: jnp.ndarray,
    gamma: float,
    alpha: float,
    deterministic: bool = True
) -> Tuple[Model, Dict[str, Any]]:
    """Run one inverse soft-Q learning update of ``q_net``.

    With the soft value ``V(s) = alpha * logsumexp(Q(s, .) / alpha)`` and
    ``y = (1 - dones) * gamma * V_target(s')``, the minimized loss is
    ``mean(V(s) - y) - mean(Q(s, a) - y)``.

    Args:
        rng: PRNG key; split once to derive the dropout key.
        q_net: Online Q-network being trained.
        q_net_target: Target Q-network used for the next-state soft value.
        observations: Current observations, [b, l, d].
        observation_masks: Masks for ``observations``, [b, l].
        actions: Taken action indices; reshaped to [b, 1] for gathering.
        next_observations: Next observations, [b, l, d].
        next_observation_masks: Masks for ``next_observations``, [b, l].
        dones: Episode-termination flags; flattened to [b,].
        gamma: Discount factor; static under ``jit``.
        alpha: Soft-value temperature (must be non-zero); static under ``jit``.
        deterministic: Forwarded to the networks; static under ``jit``.

    Returns:
        The ``(new_q_net, info)`` pair returned by ``q_net.apply_gradient``.
        In the auxiliary dict built here every per-sample entry
        (``"__q_value"``, ``"__v_star"``, ``"__y"``, ``"reward"``,
        ``"phi_reward"``, ``"value_loss"``) is [b,].
    """
    _, dropout_key = jax.random.split(rng)

    tar_next_q_value = q_net_target(
        multimodal_embeddings=next_observations,
        embedding_masks=next_observation_masks,
        deterministic=deterministic,
        rngs={"dropout": dropout_key}
    )  # [b, n_actions]
    # Same temperature scaling as the online soft value below; without the
    # division the two value estimates disagree whenever alpha != 1.
    tar_next_v_star = alpha * jax.nn.logsumexp(tar_next_q_value / alpha, axis=-1)  # [b, ]

    # Discounted next-state value; it does not depend on the trained params.
    y = (1 - dones.reshape(-1)) * gamma * tar_next_v_star  # [b, ]

    def loss_fn(params: Params) -> Tuple[jnp.ndarray, Dict]:
        all_q_values = q_net.apply_fn(
            {"params": params},
            multimodal_embeddings=observations,
            embedding_masks=observation_masks,
            deterministic=deterministic,
            rngs={"dropout": dropout_key}
        )  # [b, n_actions]

        # The soft value must be taken over ALL actions, so compute it before
        # narrowing the Q-values down to the taken action.
        v_star = alpha * jax.nn.logsumexp(all_q_values / alpha, axis=-1)  # [b, ]

        # Flatten the gathered values: a [b, 1] operand combined with the
        # [b] vector ``y`` would broadcast to a [b, b] matrix.
        q_value = jnp.take_along_axis(
            all_q_values,
            actions.reshape(-1, 1),
            axis=-1
        ).reshape(-1)  # [b, ]

        reward = q_value - y  # implied reward Q(s, a) - gamma * V(s')

        # Logged for inspection only; it is not part of the loss.
        phi_reward = reward - (reward ** 2) / 4
        value_loss = v_star - y

        # The reward term is maximized and the value term is minimized.
        # Negating both terms would leave the loss unbounded below.
        loss = jnp.mean(value_loss) - jnp.mean(reward)

        _info = {
            "__q_value": q_value,
            "__v_star": v_star,
            "__y": y,
            "reward": reward,
            "phi_reward": phi_reward,
            "value_loss": value_loss,
            "loss": loss
        }

        return loss, _info

    new_q_net, info = q_net.apply_gradient(loss_fn)

    return new_q_net, info