from typing import Optional, Union, Dict, List

import jax
import numpy as np
import optax
import torch as th
from jax import numpy as jnp

from models.common.type_aliases import MultiModalEncoderOutput
from models.common.utils import Model
from models.third_party.multimodal_encoder import LlavaMultiModalEncoderTh
from transformers import (
    AutoProcessor,
    FlaxBertModel,
    BertTokenizerFast,
    FlaxViTModel
)
class LlavaMultiModalEncoder:
    """Encode text prompts and images with a pretrained LLaVA-style encoder (PyTorch).

    Wraps ``LlavaMultiModalEncoderTh`` together with an ``AutoProcessor``; both are
    loaded from the same ``cfg`` keyword arguments.
    """

    # Embedding width advertised to consumers of this encoder.
    # NOTE(review): not derived from the loaded checkpoint here — confirm it matches.
    embedding_dim: int = 1024
    # Every processed (text + image) sequence is padded/truncated to this many tokens.
    max_length: int = 1024

    def __init__(self, cfg: Optional[Dict] = None, device: str = "cuda"):
        """Load the encoder and its processor.

        Args:
            cfg: Keyword arguments forwarded verbatim to both
                ``LlavaMultiModalEncoderTh.from_pretrained`` and
                ``AutoProcessor.from_pretrained``. Required; the ``None`` default
                is kept only so the signature stays backward compatible.
            device: Torch device the encoder is moved to and that processed inputs
                are sent to. Defaults to ``"cuda"`` (the previously hard-coded value).

        Raises:
            TypeError: If ``cfg`` is ``None``.
        """
        if cfg is None:
            # ``**None`` would raise an opaque TypeError; fail with an actionable message.
            raise TypeError(
                "LlavaMultiModalEncoder requires `cfg`: a dict of from_pretrained kwargs"
            )
        self.device = device
        self.encoder = LlavaMultiModalEncoderTh.from_pretrained(**cfg).to(device)
        self.processor = AutoProcessor.from_pretrained(**cfg)

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`forward`."""
        return self.forward(*args, **kwargs)

    @th.no_grad()
    def forward(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        images: Optional[List[Union[jnp.ndarray, np.ndarray]]] = None
    ) -> MultiModalEncoderOutput:
        """Preprocess ``prompts``/``images`` and encode them with autograd disabled.

        Args:
            prompts: Text prompt(s), passed to the processor as ``text=``.
            images: Image(s), passed to the processor as ``images=``.

        Returns:
            The result of ``self.encoder.encode(..., as_np=True)`` (annotated as
            ``MultiModalEncoderOutput``).
        """
        observation_inputs = self.processor(
            text=prompts,
            images=images,
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
        ).to(self.device)
        return self.encoder.encode(**observation_inputs, as_np=True)
    

class VitBertMultiModalEncoderForImage:
    """Embed image/prompt pairs with a ViT image encoder and a BERT text encoder (Flax).

    ``forward`` takes index 1 of each model's output (the pooled output for HF Flax
    BERT/ViT modules) and stacks the image and text features along a new axis 1.
    It returns them with an all-ones mask over those two slots.
    """

    # Declared as 2 * 768: one slot for the ViT feature and one for the BERT feature.
    embedding_dim: int = 768 * 2
    # Prompts are padded/truncated to this many BERT tokens.
    max_text_length: int = 100

    def __init__(self, cfg: Optional[Dict] = None):
        """Load the pretrained ViT/BERT checkpoints and wrap their Flax modules in ``Model``.

        Args:
            cfg: Unused by this encoder.
        """
        self.vit_processor = AutoProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
        self.vit_encoder = FlaxViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

        self.bert_tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        self.bert_encoder = FlaxBertModel.from_pretrained("bert-base-uncased")
        self.rng = jax.random.PRNGKey(0)

        # Dummy inputs for ``Model.create`` (presumably used to trace/initialise the module).
        # NOTE(review): only ``.module`` is handed over, not the checkpoint's ``.params``;
        # if ``Model.create`` initialises fresh parameters from ``self.rng``, the
        # pretrained weights loaded above are discarded — confirm in models.common.utils.
        text_input = self.bert_tokenizer("Hello, world", return_tensors="np")
        inputs = [self.rng, text_input["input_ids"], text_input["attention_mask"]]
        self.bert_encoder = Model.create(self.bert_encoder.module, inputs=inputs)

        # The bare ViT module is fed pixel values transposed with (0, 2, 3, 1),
        # i.e. NCHW -> NHWC; ``vit_forward`` applies the same transpose.
        image_input = self.vit_processor(jnp.zeros((2, 400, 400, 3)), return_tensors="np")  # [b, h, w, rgb]
        inputs = [self.rng, jnp.transpose(image_input["pixel_values"], (0, 2, 3, 1))]
        self.vit_encoder = Model.create(self.vit_encoder.module, inputs=inputs)

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`forward`."""
        return self.forward(*args, **kwargs)

    # NOTE: ``th.no_grad`` has no effect here — the body runs only JAX/NumPy code.
    @th.no_grad()
    def forward(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        captions: Optional[Union[str, List[str]]] = None,
    ) -> MultiModalEncoderOutput:
        """Encode images and prompts into stacked ViT/BERT features.

        Args:
            prompts: Text prompt(s), tokenized for BERT.
            captions: Forwarded to the ViT processor as ``images=``.
                NOTE(review): despite the name and ``str`` annotation, this is the
                image input — confirm what callers pass.

        Returns:
            ``MultiModalEncoderOutput`` with ``multimodal_embeddings`` of shape
            ``(batch, 2, hidden)``, where ``[:, 0]`` holds the image features and
            ``[:, 1]`` the text features. ``embedding_masks`` is all ones with shape
            ``(batch, 2)``.
        """
        img_inputs = self.vit_processor(images=captions, return_tensors="np")
        img_features = VitBertMultiModalEncoderForImage.vit_forward(
            model=self.vit_encoder,
            **img_inputs,
        )[1]

        text_inputs = self.bert_tokenizer(
            prompts,
            return_tensors="np",
            padding="max_length",
            max_length=self.max_text_length,
            truncation=True,
        )
        text_features = VitBertMultiModalEncoderForImage.bert_forward(
            model=self.bert_encoder,
            **text_inputs,
        )[1]

        multimodal_embeddings = np.stack((img_features, text_features), axis=1)
        # Take the batch size from the encoded output, not ``len(prompts)``. A single
        # ``str`` prompt is tokenized as a batch of 1, whereas ``len`` would count characters.
        batch_size = multimodal_embeddings.shape[0]
        embedding_masks = np.ones((batch_size, 2))

        return MultiModalEncoderOutput(multimodal_embeddings=multimodal_embeddings, embedding_masks=embedding_masks)

    @staticmethod
    @jax.jit
    def bert_forward(
        model,
        input_ids,
        token_type_ids,
        attention_mask
    ):
        """Jitted BERT call; the arguments mirror the tokenizer's output keys."""
        return model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)

    @staticmethod
    @jax.jit
    def vit_forward(
        model,
        pixel_values,
    ):
        """Jitted ViT call; transposes ``pixel_values`` with (0, 2, 3, 1) (NCHW -> NHWC)."""
        return model(pixel_values=jnp.transpose(pixel_values, (0, 2, 3, 1)))

    
class VitBertMultiModalEncoder:
    """Text-only encoder: embed prompts with a BERT model (Flax).

    ``forward`` returns index 1 of the BERT output (the pooled output for HF Flax
    BERT modules) together with an all-ones mask of shape ``(batch,)``.
    """

    # NOTE(review): ``forward`` returns one BERT feature per example (no stacking with an
    # image feature), yet this is declared as 768 * 2 like the image+text variant —
    # confirm which value consumers expect.
    embedding_dim: int = 768 * 2
    # Prompts are padded/truncated to this many BERT tokens.
    max_text_length: int = 180

    def __init__(self, cfg: Optional[Dict] = None):
        """Load the pretrained BERT checkpoint and wrap its Flax module in ``Model``.

        Args:
            cfg: Unused by this encoder.
        """
        self.bert_tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        self.bert_encoder = FlaxBertModel.from_pretrained("bert-base-uncased")
        self.rng = jax.random.PRNGKey(0)

        # Dummy inputs for ``Model.create`` (presumably used to trace/initialise the module).
        # NOTE(review): only ``.module`` is handed over, not the checkpoint's ``.params``;
        # if ``Model.create`` initialises fresh parameters from ``self.rng``, the
        # pretrained weights loaded above are discarded — confirm in models.common.utils.
        text_input = self.bert_tokenizer("Hello, world", return_tensors="np")
        inputs = [self.rng, text_input["input_ids"], text_input["attention_mask"]]
        self.bert_encoder = Model.create(self.bert_encoder.module, inputs=inputs)

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`forward`."""
        return self.forward(*args, **kwargs)

    # NOTE: ``th.no_grad`` has no effect here — the body runs only JAX/NumPy code.
    @th.no_grad()
    def forward(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        captions: Optional[Union[str, List[str]]] = None,
    ) -> MultiModalEncoderOutput:
        """Encode prompts with BERT.

        Args:
            prompts: Text prompt(s), tokenized for BERT.
            captions: Unused; accepted so the call signature matches
                ``VitBertMultiModalEncoderForImage.forward``.

        Returns:
            ``MultiModalEncoderOutput`` whose ``multimodal_embeddings`` is index 1 of
            the BERT output. ``embedding_masks`` is all ones with shape ``(batch,)``.
        """
        text_inputs = self.bert_tokenizer(
            prompts,
            return_tensors="np",
            padding="max_length",
            max_length=self.max_text_length,
            truncation=True,
        )
        text_features = VitBertMultiModalEncoder.bert_forward(
            model=self.bert_encoder,
            **text_inputs,
        )[1]

        multimodal_embeddings = text_features
        # Take the batch size from the encoded output, not ``len(prompts)``. A single
        # ``str`` prompt is tokenized as a batch of 1, whereas ``len`` would count characters.
        batch_size = multimodal_embeddings.shape[0]
        embedding_masks = np.ones((batch_size,))

        return MultiModalEncoderOutput(multimodal_embeddings=multimodal_embeddings, embedding_masks=embedding_masks)

    @staticmethod
    @jax.jit
    def bert_forward(
        model,
        input_ids,
        token_type_ids,
        attention_mask
    ):
        """Jitted BERT call; the arguments mirror the tokenizer's output keys."""
        return model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)

