"""
 Copyright (c) 2023, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import logging
from typing import Mapping, Any

import math

import torch
import torch.nn as nn

from lavis.models.blip_models.blip_vqa import BlipVQA
from lavis.models.blip_models.blip import BlipBase
from lavis.models.med import BertEmbeddings


class BlipPromptTuning(BlipBase):
    """Wrap a BLIP VQA model and inject learnable prompts into its text stacks.

    The embedding layers of the wrapped model's text encoder and/or text
    decoder are replaced by ``BlipLanguageEmbedding`` wrappers, depending on
    the ``"enable"`` flags found in ``prompt_cfg``.
    """

    def __init__(
            self,
            prompt_cfg: dict,
            model: BlipVQA,
    ):
        """
        Args:
            prompt_cfg: configuration with the sections ``"text_encoder"`` and
                ``"text_decoder"``. Each section holds an ``"enable"`` flag;
                its remaining entries are forwarded as keyword arguments to
                ``PromptEmbedding``. ``"train_model_param"`` is read only when
                at least one section is enabled; a falsy value freezes the
                wrapped model. The mapping is left unmodified.
            model: the model whose text encoder / decoder embeddings are
                wrapped.
        """
        super().__init__()

        # Initialize model
        self.model = model

        encoder_cfg = prompt_cfg["text_encoder"]
        decoder_cfg = prompt_cfg["text_decoder"]
        use_encoder_prompt = encoder_cfg["enable"]
        use_decoder_prompt = decoder_cfg["enable"]

        # Freeze before the prompt modules are attached so that only the
        # parameters that exist on the model at this point are affected.
        if use_encoder_prompt or use_decoder_prompt:
            if not prompt_cfg["train_model_param"]:
                self.freeze()

        # Prompt to text encoder
        if use_encoder_prompt:
            text_encoder = self.model.text_encoder
            text_encoder.embeddings = BlipLanguageEmbedding(
                text_encoder.embeddings,
                self._prompt_kwargs(encoder_cfg, text_encoder.config.hidden_size),
            )

        # Prompt to text decoder
        if use_decoder_prompt:
            text_decoder = self.model.text_decoder
            text_decoder.base_model.embeddings = BlipLanguageEmbedding(
                text_decoder.base_model.embeddings,
                self._prompt_kwargs(decoder_cfg, text_decoder.config.hidden_size),
            )

    @staticmethod
    def _prompt_kwargs(section_cfg, hidden_size):
        """Return a new kwargs dict for one section without touching the input.

        The ``"enable"`` flag is dropped and ``hidden_size`` is added. Working
        on a copy keeps the caller's configuration reusable (e.g. to build a
        second model from the same config).
        """
        kwargs = {
            key: value for key, value in section_cfg.items() if key != "enable"
        }
        kwargs["hidden_size"] = hidden_size
        return kwargs

    def freeze(self):
        """Disable gradients for every parameter currently on the wrapped model."""
        print("Freeze Model Parameters")
        for p in self.model.parameters():
            p.requires_grad = False

    def forward(self, samples):
        """Delegate to the wrapped model's ``forward``."""
        return self.model.forward(samples)

    def predict_answers(
            self,
            samples,
            num_beams=3,
            inference_method="rank",
            max_len=10,
            min_len=1,
            num_ans_candidates=128,
            answer_list=None,
            **kwargs
    ):
        """Delegate to the wrapped model's ``predict_answers`` with the same arguments."""
        return self.model.predict_answers(
            samples,
            num_beams,
            inference_method,
            max_len,
            min_len,
            num_ans_candidates,
            answer_list,
            **kwargs
        )

class BlipLanguageEmbedding(nn.Module):
    """Embedding layer wrapper that prepends prompt embeddings.

    Reuses the sub-modules of the wrapped ``embedding_layer``
    (``word_embeddings``, ``token_type_embeddings``, ``position_embeddings``,
    ``position_ids``, ``LayerNorm``, ``dropout``) and inserts the output of a
    ``PromptEmbedding`` in front of the token embeddings before layer
    normalisation and dropout are applied.
    """

    def __init__(self, embedding_layer, prompt_cfg):
        """
        Args:
            embedding_layer: the original embedding module to wrap.
            prompt_cfg: keyword arguments forwarded to ``PromptEmbedding``.
        """
        super().__init__()
        self.embedding = embedding_layer
        self.prompt_embedding = PromptEmbedding(**prompt_cfg)

    def forward(
            self,
            input_ids=None,
            token_type_ids=None,
            position_ids=None,
            inputs_embeds=None,
            past_key_values_length=0,
    ):
        """Embed the input and prepend the prompt along the sequence axis.

        Either ``input_ids`` or ``inputs_embeds`` must be given. The returned
        sequence is longer than the input by however many tokens the prompt
        module prepends.

        NOTE(review): callers presumably have to account for the longer
        sequence (e.g. attention masks) -- confirm against the call sites.
        """
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.embedding.position_ids[
                :, past_key_values_length : seq_length + past_key_values_length
            ]

        if inputs_embeds is None:
            inputs_embeds = self.embedding.word_embeddings(input_ids)

        if token_type_ids is not None:
            token_type_embeddings = self.embedding.token_type_embeddings(token_type_ids)

            embeddings = inputs_embeds + token_type_embeddings
        else:
            embeddings = inputs_embeds

        if self.embedding.position_embedding_type == "absolute":
            position_embeddings = self.embedding.position_embeddings(position_ids)
            # Out-of-place add: ``embeddings`` may still alias the caller's
            # ``inputs_embeds``; an in-place ``+=`` would overwrite that tensor
            # (and fails outright when it is a leaf that requires grad).
            embeddings = embeddings + position_embeddings

        embeddings = self.prompt_embedding(embeddings)
        embeddings = self.embedding.LayerNorm(embeddings)
        embeddings = self.embedding.dropout(embeddings)

        return embeddings

class PromptEmbedding(nn.Module):
    """Learnable prompt tokens that are prepended to a batch of embeddings."""

    def __init__(
            self,
            token_num,
            hidden_size,
            enable_reparameterize: bool = True,
            enable_type_embedding: bool = True,
    ):
        """
        Args:
            token_num: number of prompt tokens.
            hidden_size: size of each prompt vector; must match the last
                dimension of the embeddings passed to ``forward``.
            enable_reparameterize: when True, the raw prompt is passed through
                a Linear -> GELU -> Linear bottleneck (half of ``hidden_size``)
                on every forward call.
            enable_type_embedding: when True, a single learnable vector is
                added to every prompt token.
        """
        super().__init__()

        init_std = math.sqrt(1 / hidden_size)

        self.prompt = nn.Parameter(torch.empty(token_num, hidden_size))
        nn.init.normal_(self.prompt, mean=0, std=init_std)

        # Re-parameterization of Prompt
        self.reparameterize = None

        if enable_reparameterize:
            bottleneck_size = hidden_size // 2
            self.reparameterize = nn.Sequential(
                nn.Linear(hidden_size, bottleneck_size),
                nn.GELU(),
                nn.Linear(bottleneck_size, hidden_size)
            )

        # Type Embedding for Prompt
        self.type_embedding = None

        if enable_type_embedding:
            self.type_embedding = nn.Parameter(torch.empty(1, hidden_size))
            nn.init.normal_(self.type_embedding, mean=0, std=init_std)

    def forward(self, input_embeddings):
        """Return ``input_embeddings`` with the prompt prepended on dim 1.

        The prompt is expanded to the size of dim 0 of ``input_embeddings``,
        so the result has ``token_num`` more entries along dim 1.
        """
        prompts = self.prompt

        if self.reparameterize is not None:
            prompts = self.reparameterize(prompts)

        if self.type_embedding is not None:
            # Out-of-place add: without re-parameterization ``prompts`` is the
            # parameter itself, and an in-place ``+=`` would either fail under
            # autograd or permanently alter the stored weights.
            prompts = prompts + self.type_embedding

        prompts = prompts.unsqueeze(0).expand(input_embeddings.size(0), -1, -1)

        return torch.cat([prompts, input_embeddings], dim=1)