from functools import partial
from typing import Dict, Optional, List, Union, Any, Tuple, Callable, Sequence

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
import torch as th
from tqdm import tqdm
import transformers
from algos.mm_student import (
    jax_update_imitation as _imitation_update,
    jax_update_dqn as _dqn_update,
    jax_update_irl as _irl_update,
    soft_update
)
from models.base import BasePolicy
from models.common.type_aliases import MultiModalEncoderOutput
from models.common.utils import Model, get_basic_rngs, create_mlp
from models.multimodal_encoders import LlavaMultiModalEncoder

from models.third_party.gpt_modules import FlaxGPT2ModuleWoTimePosEmb
from common.vh_invariants import id2skill, available, init_pos, n_skills


# Loose type aliases used purely for annotations in this module.
# The first four are all `Any`, so they document intent without constraining types.
PRNGKey = Any
Shape = Sequence[int]
Dtype = Any
Array = Any
# Parameter container: a frozen mapping keyed by string.
Params = flax.core.FrozenDict[str, Any]
# Mapping from a string or integer key to an array.
TensorDict = Dict[Union[str, int], jnp.ndarray]
# An activation is any callable mapping an array to an array.
Activation = Callable[[jnp.ndarray], jnp.ndarray]
# Either a Flax module or a plain activation callable.
ModuleLike = Union[flax.linen.Module, Activation]


class FlaxTransformerSkillDecoder(nn.Module):
    """Transformer head that scores skills from a sequence of multimodal embeddings.

    Attributes:
        gpt2_config: Keyword arguments forwarded to ``transformers.GPT2Config``.
            ``n_embd`` is supplied separately from ``multimodal_embed_dim``, so
            it must not also appear in this mapping.
        multimodal_embed_dim: Width of the embeddings; used as GPT-2 ``n_embd``.
        n_skills: Size of the output layer (one score per skill).
    """
    gpt2_config: Dict
    multimodal_embed_dim: int
    n_skills: int

    # Submodules are created in ``setup``; these are only placeholders.
    transformer = None
    skill_pred = None

    def setup(self):
        """Build the GPT-2 backbone and the linear skill head."""
        backbone_config = transformers.GPT2Config(
            **self.gpt2_config,
            n_embd=self.multimodal_embed_dim,
        )
        self.transformer = FlaxGPT2ModuleWoTimePosEmb(backbone_config, dtype=jnp.float32)
        self.skill_pred = nn.Dense(self.n_skills)

    def __call__(self, *args, **kwargs) -> jnp.ndarray:
        """Alias for :meth:`forward`."""
        return self.forward(*args, **kwargs)

    def forward(
        self,
        multimodal_embeddings: jnp.ndarray,  # [b, l, d]
        embedding_masks: jnp.ndarray,  # [b, l]
        deterministic: bool = True,
    ) -> jnp.ndarray:  # [b, n_skills]
        """Return the skill scores read from the last sequence position.

        Args:
            multimodal_embeddings: Input embeddings, passed to the backbone as
                ``hidden_states``.
            embedding_masks: Passed to the backbone as ``attention_mask``.
            deterministic: Forwarded to the backbone unchanged.

        Returns:
            The output of the skill head at index ``-1`` of the sequence axis.
        """
        last_hidden_state = self.transformer(
            hidden_states=multimodal_embeddings,
            attention_mask=embedding_masks,
            deterministic=deterministic,
        )["last_hidden_state"]

        # The head is applied to every position ([b, l, n_skills]); only the
        # final position is returned.
        # NOTE(review): the final index is used regardless of the mask -- confirm
        # that padding never occupies the last position.
        per_position_logits = self.skill_pred(last_hidden_state)
        return per_position_logits[:, -1]


class FlaxMlpSkillDecoder(nn.Module):
    """MLP head that scores skills from multimodal embeddings.

    Attributes:
        n_skills: Output dimension of the MLP (one score per skill).
        net_arch: Hidden-layer specification forwarded to ``create_mlp``.
        activation_fn: Activation forwarded to ``create_mlp``. The default is
            the function ``nn.relu`` even though the annotation says ``nn.Module``.
        dropout: Dropout setting forwarded to ``create_mlp``.
        kernel_init: Kernel initializer forwarded to ``create_mlp``.
        bias_init: Bias initializer forwarded to ``create_mlp``.
    """
    n_skills: int
    net_arch: List
    activation_fn: nn.Module = nn.relu
    dropout: float = 0.0
    kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.xavier_normal()
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros

    # Populated in ``setup``; these are only placeholders.
    layers = None
    last_activation = None

    def setup(self) -> None:
        """Create the MLP (with layer norm enabled) and record the softmax function."""
        mlp_kwargs = dict(
            output_dim=self.n_skills,
            net_arch=self.net_arch,
            activation_fn=self.activation_fn,
            dropout=self.dropout,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            layer_norm=True,
        )
        self.layers = create_mlp(**mlp_kwargs)
        # Stored for reference only: ``forward`` does not apply it.
        self.last_activation = nn.softmax

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`forward`."""
        return self.forward(*args, **kwargs)

    def forward(
        self,
        multimodal_embeddings: jnp.ndarray,  # [b, d]
        deterministic: bool = False,
        **kwargs
    ):
        """Return the raw MLP output for the given embeddings.

        Args:
            multimodal_embeddings: Input embeddings fed to the MLP.
            deterministic: Forwarded to the MLP. Note the default is ``False``.
            **kwargs: Accepted and ignored, so callers may pass extra arguments
                such as ``embedding_masks``.

        Returns:
            The output of ``self.layers``; ``self.last_activation`` is not applied.
        """
        return self.layers(multimodal_embeddings, deterministic=deterministic)



@partial(jax.jit, static_argnames=("deterministic",))
def forward(
    rng: jnp.ndarray,
    model: Model,
    multimodal_embeddings: jnp.ndarray,  # [b, l, d]
    embedding_masks: jnp.ndarray,  # [b, l]
    deterministic: bool = True
) -> Tuple[jnp.ndarray, jnp.ndarray]:  # (rng, [b, n_skills])
    """Run ``model`` on a batch of embeddings under ``jax.jit``.

    Args:
        rng: PRNG key. It is split once; one half feeds dropout and the other
            is returned to the caller.
        model: Object exposing ``apply_fn`` and ``params``.
        multimodal_embeddings: Forwarded to ``apply_fn`` by keyword.
        embedding_masks: Forwarded to ``apply_fn`` by keyword.
        deterministic: Static jit argument forwarded to ``apply_fn``.

    Returns:
        ``(next_rng, decoder_output)``.
    """
    next_rng, dropout_key = jax.random.split(rng)
    variables = {"params": model.params}

    decoder_output = model.apply_fn(
        variables,
        multimodal_embeddings=multimodal_embeddings,
        embedding_masks=embedding_masks,
        deterministic=deterministic,
        rngs={"dropout": dropout_key},
    )
    return next_rng, decoder_output


class SkillDecoder(BasePolicy):
    """Skill-selection policy operating on multimodal (prompt + caption) embeddings.

    ``self.model`` is used either as a behavior-cloning policy or as the
    Q-network of a DQN-style learner. A target network
    (``self.q_net_target``) is created only when ``cfg["mode"]`` is ``"rl"``
    or ``"irl"``.

    ``self.multimodal_encoder`` is not created here. It must be assigned by the
    caller before any method that encodes prompts/captions is used.
    """

    def __init__(
        self,
        seed: int,
        cfg: Dict,
        init_build_model: bool = True
    ):
        """Store the configuration and optionally build the networks.

        Args:
            seed: Forwarded to ``BasePolicy``.
            cfg: Configuration mapping. ``"mode"`` is read here; further keys
                are read by :meth:`build_model` and the update methods.
            init_build_model: When true, :meth:`build_model` is called immediately.
        """
        super().__init__(seed=seed, cfg=cfg)
        self.model = None  # Regarded as: (1) Behavior cloning policy, or (2) Q-network for DQN
        self.q_net_target = None  # This is not None only when update in DQN style

        self.multimodal_encoder = None  # type: LlavaMultiModalEncoder
        self.mode = cfg["mode"]

        self.excluded_save_params.extend(["multimodal_processor", "multimodal_encoder"])
        if init_build_model:
            self.build_model()

    def get_model_dummy_input(self) -> Dict[str, jnp.ndarray]:
        """Return a small random batch used only to initialize model parameters.

        The embedding width comes from ``cfg["multimodal_embed_dim"]`` so the
        initialized parameters match the configured width. The insertion order
        of the returned dict matters: :meth:`build_model` passes the values
        positionally (embeddings first, then masks).
        """
        batch_size = 7
        seq_len = 2
        embed_dim = self.cfg["multimodal_embed_dim"]

        multimodal_embeddings = jax.random.normal(
            key=self.rng,
            shape=(batch_size, seq_len, embed_dim),
        )
        embedding_masks = jnp.ones((batch_size, seq_len))

        return {"multimodal_embeddings": multimodal_embeddings, "embedding_masks": embedding_masks}

    def build_model(self):
        """Create ``self.model`` and, in ``rl``/``irl`` mode, ``self.q_net_target``.

        Raises:
            NotImplementedError: If ``cfg["arch"]`` is neither ``"mlp"`` nor
                ``"transformer"`` (the default).
        """
        arch = self.cfg.get("arch", "transformer")
        if arch == "mlp":
            # NOTE(review): hidden sizes are fixed here rather than read from
            # cfg["net_arch"] -- confirm this is intentional.
            model_def = FlaxMlpSkillDecoder(
                n_skills=self.cfg["n_skills"],
                net_arch=[512, 512],
                activation_fn=nn.leaky_relu,
                kernel_init=nn.initializers.he_normal(),
            )
        elif arch == "transformer":
            model_def = FlaxTransformerSkillDecoder(
                gpt2_config=self.cfg["gpt2_config"],
                multimodal_embed_dim=self.cfg["multimodal_embed_dim"],
                n_skills=self.cfg["n_skills"]
            )
        else:
            raise NotImplementedError(f"Undefined model architecture: {arch}")

        self.rng, rngs = get_basic_rngs(self.rng)
        self.rng, _ = jax.random.split(self.rng)

        dummy_input = self.get_model_dummy_input()

        tx = optax.chain(optax.adam(learning_rate=self.cfg["lr"]))

        # Positional init inputs: rngs first, then the dummy tensors in dict order.
        inputs = [rngs, *dummy_input.values()]

        self.model = Model.create(model_def=model_def, inputs=inputs, tx=tx)

        if self.mode in ["rl", "irl"]:
            # The target network has no optimizer; tau=1.0 is passed so that it
            # starts from the online network's parameters.
            self.q_net_target = Model.create(model_def=model_def, inputs=inputs)
            self.q_net_target = soft_update(q_net=self.model, q_net_target=self.q_net_target, tau=1.0)
            self.param_components.append("q_net_target")

    def predict(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        captions: Optional[Union[str, List[str]]] = None,
        deterministic: bool = True,
    ) -> jnp.ndarray:
        """Return the index of the highest-scoring skill for each sample.

        Args:
            prompts: Prompt(s) passed to the multimodal encoder.
            captions: Caption(s) passed to the multimodal encoder.
            deterministic: Forwarded to the model.

        Returns:
            ``argmax`` over the last axis of :meth:`predict_action`'s output.  # [b, ]
        """
        scores = self.predict_action(
            prompts=prompts,
            captions=captions,
            deterministic=deterministic,
        )
        return jnp.argmax(scores, axis=-1)

    def predict_action(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        captions: Optional[List[Union[jnp.ndarray, np.ndarray]]] = None,
        deterministic: bool = True,
    ) -> jnp.ndarray:
        """Return the raw model output (per-skill scores) for the given inputs.

        Args:
            prompts: Prompt(s) passed to the multimodal encoder.
            captions: Caption(s) passed to the multimodal encoder.
            deterministic: Forwarded to the model.

        Returns:
            The decoder output produced by the jitted module-level ``forward``.
        """
        mm_encoded = self.get_multimodal_observations(prompts=prompts, captions=captions)

        self.rng, output = forward(
            rng=self.rng,
            model=self.model,
            multimodal_embeddings=mm_encoded.multimodal_embeddings,
            embedding_masks=mm_encoded.embedding_masks,
            deterministic=deterministic
        )
        # Extra split kept so the RNG stream matches the established behavior.
        self.rng, _ = jax.random.split(self.rng)
        return output

    def _after_update(self):
        """Advance the RNG and the update counter after a training step."""
        self.rng, _ = jax.random.split(self.rng)
        self.n_update += 1

    @th.no_grad()
    def get_multimodal_observations(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        captions: Optional[Union[str, List[str]]] = None,
    ) -> MultiModalEncoderOutput:
        """Encode prompts and captions with ``self.multimodal_encoder`` (no torch grad)."""
        return self.multimodal_encoder(
            prompts=prompts,
            captions=captions
        )

    def rl_update(
        self,
        prompts: Optional[Union[str, List[str]]] = None,  # List of prompts
        next_prompts: Optional[Union[str, List[str]]] = None,  # List of prompts
        captions: Optional[Union[str, List[str]]] = None,  # Batch of image observations
        actions: Optional[Union[jnp.ndarray, np.ndarray]] = None,  # [b,]  # Skill indices
        rewards: Optional[Union[jnp.ndarray, np.ndarray]] = None,  # [b,]
        # List of next image observations
        next_captions: Optional[Union[str, List[str]]] = None,  # Batch of image observations
        dones: Optional[Union[jnp.ndarray, np.ndarray]] = None,
        deterministic: Optional[bool] = True
    ) -> Dict[str, Any]:
        """Run one DQN-style update and periodically refresh the target network.

        Current and next observations are encoded separately, then handed to
        the DQN update together with ``cfg["gamma"]``. Every
        ``cfg["target_update_interval"]`` updates the target network is
        soft-updated with ``cfg["tau"]``.

        Returns:
            The info dict produced by the DQN update.

        Raises:
            RuntimeError: If no target network exists (the model was not built
                in ``"rl"``/``"irl"`` mode).
        """
        if self.q_net_target is None:
            raise RuntimeError(
                "rl_update requires a target Q-network; build the model with "
                "cfg['mode'] set to 'rl' or 'irl'."
            )

        observations = self.get_multimodal_observations(
            prompts=prompts,
            captions=captions
        )
        next_observations = self.get_multimodal_observations(
            prompts=next_prompts,
            captions=next_captions
        )

        new_model, info = _dqn_update(
            rng=self.rng,
            q_net=self.model,
            q_net_target=self.q_net_target,
            observations=observations.multimodal_embeddings,
            observation_masks=observations.embedding_masks,
            actions=actions,
            rewards=rewards,
            next_observations=next_observations.multimodal_embeddings,
            next_observation_masks=next_observations.embedding_masks,
            dones=dones,
            gamma=self.cfg["gamma"],
            deterministic=deterministic
        )

        self.model = new_model
        self._after_update()

        if (self.n_update % self.cfg["target_update_interval"]) == 0:
            self.q_net_target = soft_update(self.model, self.q_net_target, self.cfg["tau"])

        return info

    def imitation_update(
        self,
        prompts: Optional[Union[str, List[str]]] = None,  # List of prompts
        # List of image. Each image has a shape of [h, w, channel]
        captions: Optional[List[Union[jnp.ndarray, np.ndarray]]] = None,
        label: Optional[Union[jnp.ndarray, np.ndarray]] = None,  # [b,]: Target skill indices (=Expert actions)
        deterministic: Optional[bool] = True
    ) -> Dict[str, Any]:
        """Run one imitation (behavior-cloning) update against expert labels.

        Returns:
            The info dict produced by the imitation update.
        """
        mm_encoded = self.get_multimodal_observations(prompts=prompts, captions=captions)

        new_model, info = _imitation_update(
            rng=self.rng,
            mm_student=self.model,
            multimodal_embeddings=mm_encoded.multimodal_embeddings,
            embedding_masks=mm_encoded.embedding_masks,
            label=label,
            deterministic=deterministic
        )
        self.model = new_model

        self._after_update()
        return info