from dataclasses import dataclass
from typing import Optional

import torch
from jax import numpy as jnp

from transformers import LlavaForConditionalGeneration, LlavaConfig
from models.common.type_aliases import MultiModalEncoderOutput


class LlavaMultiModalEncoder(LlavaForConditionalGeneration):
    """LLaVA model exposing an ``encode`` method that returns merged text/image embeddings.

    ``encode`` embeds the token ids, optionally runs the vision tower and the
    multi-modal projector on ``pixel_values``, merges both streams and returns
    them without running the language model.
    """

    def __init__(self, config: LlavaConfig):
        super().__init__(config=config)

    def encode(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[int] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        as_np: Optional[bool] = False
    ) -> MultiModalEncoderOutput:
        """Build the (optionally image-augmented) input embeddings.

        Args:
            input_ids: Token ids to embed. Required, because pre-computed
                ``inputs_embeds`` are not supported.
            pixel_values: Image input for the vision tower. When ``None``, or
                when ``input_ids`` has a single token along dim 1, only the
                text embeddings are returned.
            attention_mask: Mask forwarded to the merge step. It is returned
                unchanged when no merge happens, so it may be ``None``.
            inputs_embeds: Must be ``None``.
            vision_feature_layer: Index into the vision tower hidden states.
                Falls back to ``config.vision_feature_layer``.
            vision_feature_select_strategy: ``"default"`` drops the first
                position of the selected hidden state, ``"full"`` keeps all
                positions. Falls back to
                ``config.vision_feature_select_strategy``.
            labels: Forwarded to the merge step; the merged labels are
                discarded.
            as_np: When truthy, outputs are detached, moved to CPU and
                converted to NumPy arrays.

        Returns:
            ``MultiModalEncoderOutput`` carrying ``multimodal_embeddings`` and
            ``embedding_masks``.

        Raises:
            NotImplementedError: If ``inputs_embeds`` is given.
            ValueError: If ``input_ids`` is missing, or the feature select
                strategy is neither ``"default"`` nor ``"full"``.
        """
        if inputs_embeds is not None:
            raise NotImplementedError("input_embeds should be not given")
        if input_ids is None:
            # Fail early with a clear message rather than inside the embedding layer.
            raise ValueError("input_ids must be provided")

        if vision_feature_layer is None:
            vision_feature_layer = self.config.vision_feature_layer
        if vision_feature_select_strategy is None:
            vision_feature_select_strategy = self.config.vision_feature_select_strategy

        # 1. Extract the input embeddings
        inputs_embeds = self.get_input_embeddings()(input_ids)

        # 2. Merge text and images
        if pixel_values is not None and input_ids.shape[1] != 1:
            # Reject a bad strategy before paying for the vision tower forward
            # pass, and report the value that was actually used (which may be
            # a caller override rather than the config default).
            if vision_feature_select_strategy not in ("default", "full"):
                raise ValueError(
                    f"Unexpected select feature strategy: {vision_feature_select_strategy}"
                )

            # NOTE: output_hidden_states=True keeps every hidden state of the
            # vision tower in memory although only one layer is used.
            image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
            selected_image_feature = image_outputs.hidden_states[vision_feature_layer]

            if vision_feature_select_strategy == "default":
                # Drop the first position (presumably the CLS token -- confirm
                # against the vision tower in use).
                selected_image_feature = selected_image_feature[:, 1:]

            image_features = self.multi_modal_projector(selected_image_feature)
            inputs_embeds, attention_mask, *_ = self._merge_input_ids_with_image_features(
                image_features, inputs_embeds, input_ids, attention_mask, labels
            )

        if as_np:
            inputs_embeds = inputs_embeds.detach().cpu().numpy()
            # The mask stays None when the caller passed none and no merge ran.
            if attention_mask is not None:
                attention_mask = attention_mask.detach().cpu().numpy()
        return MultiModalEncoderOutput(multimodal_embeddings=inputs_embeds, embedding_masks=attention_mask)
