from typing import Optional, Union, Dict, List

import jax
import numpy as np
import optax
import torch as th
from jax import numpy as jnp

from models.common.type_aliases import MultiModalEncoderOutput
from models.common.utils import Model
from models.third_party.multimodal_encoder import LlavaMultiModalEncoderTh
from transformers import (
    AutoProcessor,
    FlaxBertModel,
    BertTokenizerFast,
    FlaxViTModel
)

class LlavaMultiModalEncoder:
    """Text + image encoder backed by ``LlavaMultiModalEncoderTh`` (PyTorch).

    The encoder and its ``AutoProcessor`` are both loaded from the same
    ``cfg`` keyword arguments. The encoder and every processed batch are
    placed on ``device``.
    """

    embedding_dim: int = 1024

    def __init__(self, cfg: Optional[Dict] = None, device: str = "cuda"):
        """Load the encoder and processor.

        Args:
            cfg: Keyword arguments forwarded verbatim to both
                ``LlavaMultiModalEncoderTh.from_pretrained`` and
                ``AutoProcessor.from_pretrained``.
            device: Torch device that the encoder and processed inputs are
                moved to (defaults to ``"cuda"`` as before).
        """
        # ``**None`` fails with an unhelpful "must be a mapping" TypeError.
        # Treat a missing cfg as empty so ``from_pretrained`` reports what is
        # actually missing.
        cfg = {} if cfg is None else cfg
        self.device = device
        self.encoder = LlavaMultiModalEncoderTh.from_pretrained(**cfg).to(device)
        self.processor = AutoProcessor.from_pretrained(**cfg)

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`forward`."""
        return self.forward(*args, **kwargs)

    @th.no_grad()
    def forward(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        images: Optional[List[Union[jnp.ndarray, np.ndarray]]] = None
    ) -> MultiModalEncoderOutput:
        """Encode prompts and/or images without tracking gradients.

        Inputs are padded/truncated to exactly 1024 tokens by the processor.
        The result is whatever ``self.encoder.encode(..., as_np=True)``
        returns. NOTE(review): presumably a ``MultiModalEncoderOutput`` with
        NumPy arrays, per the annotation and ``as_np`` flag; confirm.
        """
        observation_inputs = self.processor(
            text=prompts,
            images=images,
            return_tensors="pt",
            padding="max_length",
            max_length=1024,
            truncation=True
        ).to(self.device)
        return self.encoder.encode(**observation_inputs, as_np=True)
    
    
class VitBertMultiModalEncoderForCaption:
    """Encode prompts (and optionally captions) with a Flax BERT model.

    Only ``bert-base-uncased`` is loaded here; no ViT model is constructed.
    Every text is tokenized to a fixed ``max_length`` and reduced to the
    element at index ``[1]`` of the BERT forward output.
    NOTE(review): this is presumably the pooled output; confirm against
    ``Model``.
    """

    embedding_dim: int = 768 * 2

    # Fixed token length used for both prompts and captions. The original
    # code carried the note "120 (eval cross)" for captions.
    max_length: int = 150

    def __init__(self, cfg: Optional[Dict] = None):
        """Load the BERT tokenizer and build the ``Model`` wrapper.

        Args:
            cfg: Unused. Accepted for constructor-signature compatibility with
                the other encoders in this module.
        """
        self.bert_tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        self.bert_encoder = FlaxBertModel.from_pretrained("bert-base-uncased")
        self.rng = jax.random.PRNGKey(0)

        # A dummy tokenized batch plus the RNG key is handed to Model.create.
        # NOTE(review): if Model.create (re)initialises parameters from these
        # inputs, the pretrained weights loaded above are discarded. Confirm
        # against models.common.utils.Model.
        text_input = self.bert_tokenizer("Hello, world", return_tensors="np")
        inputs = [self.rng, text_input["input_ids"], text_input["attention_mask"]]
        self.bert_encoder = Model.create(self.bert_encoder.module, inputs=inputs)

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`forward`."""
        return self.forward(*args, **kwargs)

    def _encode_text(self, texts):
        """Tokenize *texts* to ``max_length`` and return BERT output ``[1]``."""
        tokens = self.bert_tokenizer(
            texts,
            return_tensors="np",
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
        )
        return VitBertMultiModalEncoderForCaption.bert_forward(
            model=self.bert_encoder,
            **tokens,
        )[1]

    @th.no_grad()
    def forward(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        captions: Optional[Union[str, List[str]]] = None
    ) -> MultiModalEncoderOutput:
        """Encode prompts and, when provided, captions.

        Args:
            prompts: A single prompt string or a list of prompts.
            captions: Optional caption(s). They must yield the same batch size
                as ``prompts``, because the two feature arrays are stacked.

        Returns:
            Without captions: the prompt features as ``multimodal_embeddings``
            and an all-ones mask of shape ``(batch, 1)``.
            With captions: ``np.stack((caption, prompt), axis=1)`` as
            embeddings and an all-ones mask of shape ``(batch, 2)``.
        """
        # The body runs only JAX/NumPy code; the decorator is kept as before.
        text_features = self._encode_text(prompts)
        # Use the encoded batch dimension. ``len(prompts)`` would be the
        # character count when a single string prompt is passed.
        batch_size = text_features.shape[0]

        if captions is None:
            return MultiModalEncoderOutput(
                multimodal_embeddings=text_features,
                embedding_masks=np.ones((batch_size, 1)),
            )

        caption_features = self._encode_text(captions)
        multimodal_embeddings = np.stack((caption_features, text_features), axis=1)
        embedding_masks = np.ones((batch_size, 2))
        return MultiModalEncoderOutput(
            multimodal_embeddings=multimodal_embeddings,
            embedding_masks=embedding_masks,
        )

    @staticmethod
    @jax.jit
    def bert_forward(
        model,
        input_ids,
        token_type_ids,
        attention_mask
    ):
        """Run one jitted forward pass of ``model`` on a tokenized batch.

        NOTE(review): ``model`` is a traced argument, so it must be a JAX
        pytree (presumably ``Model`` is). Confirm.
        """
        return model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)


class VitBertMultiModalEncoder:
    """Encode prompts with a Flax BERT model (text only).

    Only ``bert-base-uncased`` is loaded here; no ViT model is constructed.
    Prompts are tokenized to a fixed ``max_length`` and reduced to the element
    at index ``[1]`` of the BERT forward output.
    NOTE(review): this is presumably the pooled output; confirm against
    ``Model``.
    """

    # NOTE(review): advertises 768 * 2 although forward() returns only the
    # prompt features; confirm what callers expect.
    embedding_dim: int = 768 * 2

    # Fixed token length used when tokenizing prompts.
    max_length: int = 180

    def __init__(self, cfg: Optional[Dict] = None):
        """Load the BERT tokenizer and build the ``Model`` wrapper.

        Args:
            cfg: Unused. Accepted for constructor-signature compatibility with
                the other encoders in this module.
        """
        self.bert_tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        self.bert_encoder = FlaxBertModel.from_pretrained("bert-base-uncased")
        self.rng = jax.random.PRNGKey(0)

        # A dummy tokenized batch plus the RNG key is handed to Model.create.
        # NOTE(review): if Model.create (re)initialises parameters from these
        # inputs, the pretrained weights loaded above are discarded. Confirm
        # against models.common.utils.Model.
        text_input = self.bert_tokenizer("Hello, world", return_tensors="np")
        inputs = [self.rng, text_input["input_ids"], text_input["attention_mask"]]
        self.bert_encoder = Model.create(self.bert_encoder.module, inputs=inputs)

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`forward`."""
        return self.forward(*args, **kwargs)

    def _encode_text(self, texts):
        """Tokenize *texts* to ``max_length`` and return BERT output ``[1]``."""
        tokens = self.bert_tokenizer(
            texts,
            return_tensors="np",
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
        )
        return VitBertMultiModalEncoder.bert_forward(
            model=self.bert_encoder,
            **tokens,
        )[1]

    @th.no_grad()
    def forward(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        captions: Optional[Union[str, List[str]]] = None,
    ) -> MultiModalEncoderOutput:
        """Encode prompts.

        Args:
            prompts: A single prompt string or a list of prompts.
            captions: Accepted for call compatibility but ignored.

        Returns:
            ``MultiModalEncoderOutput`` with the prompt features as
            ``multimodal_embeddings`` and an all-ones mask of shape
            ``(batch,)``.
            NOTE(review): the caption encoder returns a ``(batch, 1)`` mask in
            the same situation; confirm which shape callers expect.
        """
        # The body runs only JAX/NumPy code; the decorator is kept as before.
        text_features = self._encode_text(prompts)
        # Use the encoded batch dimension. ``len(prompts)`` would be the
        # character count when a single string prompt is passed.
        batch_size = text_features.shape[0]

        return MultiModalEncoderOutput(
            multimodal_embeddings=text_features,
            embedding_masks=np.ones((batch_size,)),
        )

    @staticmethod
    @jax.jit
    def bert_forward(
        model,
        input_ids,
        token_type_ids,
        attention_mask
    ):
        """Run one jitted forward pass of ``model`` on a tokenized batch.

        NOTE(review): ``model`` is a traced argument, so it must be a JAX
        pytree (presumably ``Model`` is). Confirm.
        """
        return model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)

