from functools import partial
from typing import Dict, Optional, List, Union, Any, Tuple, Callable, Sequence, NamedTuple

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax

from common.update import jax_update
from common.policies import Model 
from transformers import FlaxBertForSequenceClassification, BertTokenizerFast, FlaxBertModel
from common.invariants import n_rewards
from model.base import BasePolicy

# Opaque aliases (all plain ``Any``) used only to make signatures self-describing.
PRNGKey = Dtype = Array = Any
Shape = Sequence[int]

# Structured aliases for parameter trees, tensor maps and activation callables.
Params = flax.core.FrozenDict[str, Any]
TensorDict = Dict[Union[str, int], jnp.ndarray]
Activation = Callable[[jnp.ndarray], jnp.ndarray]
ModuleLike = Union[flax.linen.Module, Activation]


class DenseLayer(nn.Module):
    """Linear projection to ``num_rewards`` units followed by ``activation_fn_last``.

    Attributes:
        num_rewards: Output width of the dense projection.
        activation_fn: Hidden-layer activation. Currently unused by ``__call__``
            (there is no hidden layer).
        activation_fn_last: Activation applied to the projection. With the
            default ``nn.softmax`` the outputs along the last axis sum to 1,
            i.e. they are normalized scores rather than raw logits.
    """

    num_rewards: int
    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    activation_fn_last: Callable[[jnp.ndarray], jnp.ndarray] = nn.softmax

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        projected = nn.Dense(self.num_rewards)(x)
        return self.activation_fn_last(projected)
    
class TrpNetwork(nn.Module):
    """Wrapper module that applies an inner Flax module to precomputed embeddings.

    Attributes:
        n_rewards: Number of reward components. Stored for configuration only;
            not read by ``forward`` (the output width is determined by ``model``).
        model: Flax submodule applied to ``input_emb``. Defaults to ``None`` but
            must be provided before the module is initialized or applied.
    """

    n_rewards: int
    # Annotated as a Flax module: the value passed in (e.g. ``DenseLayer``) is a
    # submodule, not the ``Model`` container produced by ``Model.create``.
    model: Optional[nn.Module] = None

    def __call__(self, *args, **kwargs) -> jnp.ndarray:
        """Delegate to :meth:`forward`."""
        return self.forward(*args, **kwargs)

    def forward(
        self,
        input_emb: jnp.ndarray,  # [b, d]
        deterministic: bool = True,
    ) -> jnp.ndarray:
        """Apply ``self.model`` to ``input_emb``.

        Args:
            input_emb: Input embeddings, ``[b, d]``.
            deterministic: Accepted so callers can pass it through ``apply``;
                not used here because ``self.model`` is called with the
                embeddings only.

        Returns:
            The output of ``self.model``.

        Raises:
            TypeError: If ``model`` was not provided (same exception type that
                calling ``None`` would raise, but with an actionable message).
        """
        if self.model is None:
            raise TypeError("TrpNetwork.model must be a Flax module, got None.")
        return self.model(input_emb)


@partial(jax.jit, static_argnames=("deterministic",))
def forward(
    rng: jnp.ndarray,
    model: Model,
    input_emb: jnp.ndarray,  # [b, d]
    deterministic: bool = False
) -> Tuple[jnp.ndarray, jnp.ndarray]:  # (rng, [b, n_rewards])
    """Run ``model.apply_fn`` on precomputed embeddings under ``jax.jit``.

    Args:
        rng: PRNG key. A new key is split off and returned so callers can
            advance their RNG state.
        model: Object exposing ``params`` and an ``apply_fn`` that is called
            like ``flax.linen.Module.apply``.
        input_emb: Input embeddings; the callers in this file pass the
            backbone's pooled output, shaped ``[b, d]``.
        deterministic: Static flag forwarded to the model. When ``False`` a
            ``"dropout"`` PRNG stream is also supplied to ``apply_fn``.

    Returns:
        ``(new_rng, output)``.
    """
    rng, dropout_rng = jax.random.split(rng)

    # Only pass `rngs` on the stochastic path. The deterministic call is
    # unchanged. A model containing nn.Dropout no longer fails for lack of a
    # "dropout" stream when deterministic=False. Previously this key was
    # split off and then discarded.
    apply_kwargs = {} if deterministic else {"rngs": {"dropout": dropout_rng}}

    output = model.apply_fn(
        {"params": model.params},
        input_emb=input_emb,
        deterministic=deterministic,
        **apply_kwargs,
    )

    return rng, output


def get_basic_rngs(rng: jnp.ndarray) -> Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]:
    """Split ``rng`` into a carry key plus named keys for Flax initialization.

    Returns:
        ``(carry_key, {"params": ..., "dropout": ..., "batch_stats": ...})``,
        using the four subkeys of ``jax.random.split(rng, 4)`` in that order.
    """
    subkeys = jax.random.split(rng, 4)
    named = dict(zip(("params", "dropout", "batch_stats"), subkeys[1:]))
    return subkeys[0], named

class Classifier(BasePolicy):
    """Reward-weighted classifier on top of a pretrained BERT encoder.

    Prompts are tokenized with ``bert-base-uncased`` and encoded by a
    ``FlaxBertModel``. Index 1 of its output (the pooled embedding) is fed to
    a trainable :class:`TrpNetwork` head. Only the head (``self.model``) is
    updated by :meth:`update`. The backbone is called outside ``jax_update``
    and receives no gradient updates here.
    """

    # Fixed sequence length for every tokenizer call (pad / truncate to this).
    _MAX_LENGTH = 100

    def __init__(
        self,
        seed: int,
        cfg: Dict,
        init_build_model: bool = True
    ):
        super().__init__(seed=seed, cfg=cfg)

        self.model = None
        self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        # NOTE(review): `num_labels` configures classification heads; the
        # headless FlaxBertModel presumably ignores it — confirm and drop.
        self.pretrained_backbone = FlaxBertModel.from_pretrained("bert-base-uncased", num_labels=3)
        # Both objects are rebuilt from the hub in __init__. The tokenizer entry
        # was previously misspelled "tokenzier", which did not match the
        # attribute name.
        self.excluded_save_params.extend(["tokenizer", "pretrained_backbone"])

        if init_build_model:
            self.build_model()

    def get_model_dummy_input(self):
        """Tokenize a short placeholder prompt used to initialize the head."""
        dummy_text = "Hello. "
        return self.tokenizer(
            dummy_text,
            return_tensors="jax",
            padding="max_length",
            truncation=True,
            max_length=self._MAX_LENGTH,
        )

    def build_model(self):
        """Create the trainable head and initialize ``self.model`` from a dummy prompt."""
        # NOTE(review): the head width comes from the module-level `n_rewards`
        # (common.invariants), while TrpNetwork receives cfg["n_rewards"].
        # Presumably these must agree for `predict`'s reward weighting — confirm.
        network = DenseLayer(num_rewards=n_rewards)
        model_def = TrpNetwork(
            n_rewards=self.cfg["n_rewards"],
            model=network,
        )

        dummy_input = self.get_model_dummy_input()
        tx = optax.chain(optax.adam(learning_rate=self.cfg["lr"]))

        # NOTE(review): the init key is fixed at 0, so parameter initialization
        # does not depend on `seed`. Consider deriving it from the policy RNG.
        rng = jax.random.PRNGKey(0)
        dummy_emb = self.pretrained_backbone(**dummy_input)[1]

        self.model = Model.create(model_def=model_def, inputs=[rng, dummy_emb], tx=tx)

    def _encode_prompts(self, prompts: Union[str, List[str]]) -> jnp.ndarray:
        """Tokenize ``prompts`` and return the backbone output at index 1 (pooled embeddings)."""
        tokens = self.tokenizer(
            prompts,
            return_tensors="np",
            max_length=self._MAX_LENGTH,
            padding="max_length",
            truncation=True,
        ).data
        return self.pretrained_backbone(**tokens)[1]

    def predict(
        self,
        prompts: Union[str, List[str]],
        rewards: Optional[Union[float]],
        deterministic: Optional[bool] = True
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Score prompts with the current head.

        Args:
            prompts: Text or batch of texts accepted by the tokenizer.
            rewards: Weights multiplied element-wise with the head output
                before summing over the last axis.
            deterministic: Forwarded to the jitted ``forward`` (static argument).

        Returns:
            ``(predictions, output)``. ``output`` is the head output, shaped
            ``[batch, n_rewards]``. ``predictions`` is
            ``sum(rewards * output, axis=-1)``, shaped ``[batch]``.
        """
        prompt_emb = self._encode_prompts(prompts)

        self.rng, output = forward(
            rng=self.rng,
            model=self.model,
            input_emb=prompt_emb,
            deterministic=deterministic,
        )

        # Advance the key once more. Kept so the RNG sequence matches the
        # previous behavior.
        self.rng, _ = jax.random.split(self.rng)
        predictions = jnp.sum(rewards * output, -1)

        return predictions, output

    def _after_update(self):
        """Advance the PRNG key and increment the update counter."""
        self.rng, _ = jax.random.split(self.rng)
        self.n_update += 1

    def update(
        self,
        prompts: Union[str, List[str]],
        rewards: Optional[Union[float]],
        succs: Optional[Union[float]],
        dummys: Optional[Union[float]] = None,
        reward_masks: Optional[Union[float]] = None,
        deterministic: Optional[bool] = False
    ) -> Dict[str, Any]:
        """Run one optimization step of the head on a batch of prompts.

        The prompts are encoded by the pretrained backbone first. Only the
        resulting embeddings are passed to ``jax_update``.

        Returns:
            The ``info`` dict produced by ``jax_update``.
        """
        prompt_emb = self._encode_prompts(prompts)

        # `input_tokens` is jax_update's parameter name. It receives the
        # pooled embeddings, not token ids.
        new_model, info = jax_update(
            rng=self.rng,
            model=self.model,
            input_tokens=prompt_emb,
            rewards=rewards,
            dummys=dummys,
            succs=succs,
            reward_masks=reward_masks,
            deterministic=deterministic,
        )

        self.model = new_model

        self._after_update()
        return info

