"""
 Copyright (c) 2023, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import logging
from typing import Mapping, Any

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from lavis.models.base_model import BaseModel
from lavis.models.blip2_models.blip2_opt import Blip2OPT
from lavis.models.blip_models.blip_outputs import BlipOutputFeatures


class Blip2PromptTuning(BaseModel):
    """Prompt-tuning wrapper around a BLIP-2 OPT model.

    Learnable prompt tokens (``PromptEmbedding``) can be prepended to the
    Q-Former query tokens and/or to the projected visual tokens fed into the
    OPT language model. When any prompt is enabled, the wrapped model's own
    parameters are frozen unless ``prompt_cfg["train_model_param"]`` is truthy.
    """

    def __init__(
            self,
            prompt_cfg: dict,
            model: Blip2OPT,
    ):
        """
        Args:
            prompt_cfg: prompt configuration. Reads ``prompt_cfg["q_former"]``
                and ``prompt_cfg["language"]`` (each a mapping with an
                ``"enable"`` flag plus keyword arguments forwarded to
                ``PromptEmbedding``) and ``prompt_cfg["train_model_param"]``
                (only consulted when a prompt is enabled). The mapping is
                not modified.
            model: the BLIP-2 OPT model to wrap. Its ``_apply_lemmatizer``
                flag and ``_lemmatize`` method are reused by
                ``predict_answers()``.
        """
        super().__init__()

        # Initialize model
        self.model = model
        self._apply_lemmatizer = model._apply_lemmatizer

        q_former_cfg = prompt_cfg["q_former"]
        language_cfg = prompt_cfg["language"]

        # Freeze the backbone only when prompt tuning is actually in use.
        if (q_former_cfg["enable"] or language_cfg["enable"]) \
                and not prompt_cfg["train_model_param"]:
            self.freeze()

        # Prompt prepended to the Q-Former query tokens.
        if q_former_cfg["enable"]:
            self.q_former_prompt = PromptEmbedding(
                **self._prompt_kwargs(q_former_cfg, model.Qformer.config.hidden_size)
            )

        # Prompt prepended to the projected tokens fed into the language model.
        if language_cfg["enable"]:
            self.language_prompt = PromptEmbedding(
                **self._prompt_kwargs(language_cfg, model.opt_model.config.hidden_size)
            )

    @staticmethod
    def _prompt_kwargs(sub_cfg, hidden_size):
        """Build ``PromptEmbedding`` kwargs from a prompt sub-config without mutating it."""
        kwargs = {key: value for key, value in sub_cfg.items() if key != "enable"}
        kwargs["hidden_size"] = hidden_size
        return kwargs

    def freeze(self):
        """Disable gradients for every parameter of the wrapped model."""
        print("Freeze Model Parameters")
        for p in self.model.parameters():
            p.requires_grad = False

    def _encode_image(self, image):
        """Turn an image batch into prompt-augmented OPT input embeddings.

        The vision encoder always runs under ``maybe_autocast``; the rest runs
        under whatever autocast context the caller has entered.

        Returns:
            ``(inputs_opt, atts_opt)``: the visual input embeddings for OPT
            (language prompt, if any, prepended) and an all-ones ``long``
            attention mask covering them.
        """
        with self.model.maybe_autocast():
            image_embeds = self.model.ln_vision(self.model.visual_encoder(image))
        image_atts = torch.ones(
            image_embeds.size()[:-1], dtype=torch.long, device=image.device
        )

        query_tokens = self.model.query_tokens.expand(image_embeds.shape[0], -1, -1)
        if hasattr(self, "q_former_prompt"):
            query_tokens = self.q_former_prompt(query_tokens)

        query_output = self.model.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )

        inputs_opt = self.model.opt_proj(query_output.last_hidden_state)
        if hasattr(self, "language_prompt"):
            inputs_opt = self.language_prompt(inputs_opt)
        atts_opt = torch.ones(
            inputs_opt.size()[:-1], dtype=torch.long, device=image.device
        )
        return inputs_opt, atts_opt

    def forward(self, samples, output_features: bool = False):
        """Compute the language-modeling loss on image/text pairs.

        Args:
            samples: dict with ``"image"`` (image batch tensor) and
                ``"text_input"`` (list of target strings, one per image).
            output_features: unused; kept for interface compatibility.

        Returns:
            dict with the scalar ``"loss"``.
        """
        image = samples["image"]
        inputs_opt, atts_opt = self._encode_image(image)

        self.model.opt_tokenizer.padding_side = "right"

        text = [t + "\n" for t in samples["text_input"]]

        opt_tokens = self.model.opt_tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.model.max_txt_len,
        ).to(image.device)

        # Padding positions are ignored by the loss (-100).
        targets = opt_tokens.input_ids.masked_fill(
            opt_tokens.input_ids == self.model.opt_tokenizer.pad_token_id, -100
        )
        if self.model.prompt:
            targets[:, : self.model.prompt_length] = -100  # do not apply loss to the prompt

        # Visual (and language-prompt) positions never contribute to the loss.
        empty_targets = torch.full(
            atts_opt.size(), -100, dtype=torch.long, device=image.device
        )
        targets = torch.cat([empty_targets, targets], dim=1)

        text_embeds = self.model.opt_model.get_input_embeddings()(opt_tokens.input_ids)
        inputs_embeds = torch.cat([inputs_opt, text_embeds], dim=1)
        attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)

        with self.model.maybe_autocast():
            outputs = self.model.opt_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=targets,
            )

        return {
            "loss": outputs.loss,
        }

    @torch.no_grad()
    def generate(
            self,
            samples,
            use_nucleus_sampling=False,
            num_beams=5,
            max_length=30,
            min_length=1,
            top_p=0.9,
            repetition_penalty=1.0,
            length_penalty=1.0,
            num_captions=1,
            temperature=1,
    ):
        """
        Args:
            samples (dict): A dictionary containing the following keys:
                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
                - prompt (str or list of str, optional): text prompt; a single
                  string is used for every image, a list must have one entry
                  per image. Defaults to the wrapped model's ``prompt``.
            use_nucleus_sampling (bool): Whether to use nucleus sampling. If False,
                decode deterministically (beam search, or greedy when num_beams=1).
            num_beams (int): Number of beams for beam search. 1 means no beam search.
            max_length (int): The maximum length of the sequence to be generated.
            min_length (int): The minimum length of the sequence to be generated.
            top_p (float): The cumulative probability for nucleus sampling.
            repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.
            length_penalty (float): Length penalty passed to beam search.
            num_captions (int): Number of captions to be generated for each image.
            temperature (float): Sampling temperature.
        Returns:
            captions (list): A list of strings of length batch_size * num_captions.
        Raises:
            ValueError: if a list prompt does not have one entry per image.
        """
        image = samples["image"]
        batch_size = image.size(0)
        with self.model.maybe_autocast():
            inputs_opt, atts_opt = self._encode_image(image)

            prompt = samples.get("prompt", self.model.prompt)
            if isinstance(prompt, str):
                prompt = [prompt] * batch_size
            else:
                prompt = list(prompt)
                if len(prompt) != batch_size:
                    raise ValueError(
                        f"Expected {batch_size} prompts, got {len(prompt)}."
                    )

            # Left padding keeps every prompt adjacent to its generated continuation.
            self.model.opt_tokenizer.padding_side = "left"
            opt_tokens = self.model.opt_tokenizer(
                prompt,
                return_tensors="pt",
                padding="longest",
                truncation=True,
                max_length=self.model.max_txt_len,
            ).to(image.device)
            attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)

            # new version for transformers>=4.27
            text_embeds = self.model.opt_model.get_input_embeddings()(opt_tokens.input_ids)
            inputs_embeds = torch.cat([inputs_opt, text_embeds], dim=1)

            outputs = self.model.opt_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                do_sample=use_nucleus_sampling,
                top_p=top_p,
                temperature=temperature,
                num_beams=num_beams,
                max_length=max_length,
                min_length=min_length,
                eos_token_id=self.model.eos_token_id,
                repetition_penalty=repetition_penalty,
                length_penalty=length_penalty,
                num_return_sequences=num_captions,
            )
            output_text = self.model.opt_tokenizer.batch_decode(
                outputs, skip_special_tokens=True
            )

        return [text.strip() for text in output_text]

    def predict_answers(
            self,
            samples,
            num_beams=5,
            inference_method="generate",
            max_len=10,
            min_len=1,
            num_ans_candidates=128,
            answer_list=None,
            prompt="",
            length_penalty=0,
            **kwargs
    ):
        """Generate short answers to visual questions.

        Args:
            samples: dict with ``"image"`` and ``"text_input"`` (a question
                string or list of question strings; a single string is
                wrapped into a list in place). An optional truthy
                ``"apply_lemmatizer"`` entry forces lemmatization.
            num_beams: number of beams for beam search.
            max_len: maximum number of new tokens to generate.
            min_len: minimum generated length.
            prompt: optional format string; each question is inserted via
                ``prompt.format(question)``.
            length_penalty: length penalty passed to beam search.
            inference_method, num_ans_candidates, answer_list, kwargs: unused;
                kept for interface compatibility.

        Returns:
            list of answer strings, lemmatized by the wrapped model when
            requested.
        """
        image = samples["image"]
        with self.model.maybe_autocast():
            inputs_opt, atts_opt = self._encode_image(image)

            if isinstance(samples["text_input"], str):
                samples["text_input"] = [samples["text_input"]]
            if prompt:
                text_input = [prompt.format(question) for question in samples["text_input"]]
            else:
                text_input = samples["text_input"]

            self.model.opt_tokenizer.padding_side = "left"
            opt_tokens = self.model.opt_tokenizer(
                text_input,
                return_tensors="pt",
                padding="longest",
                truncation=True,
                max_length=self.model.max_txt_len,
            ).to(image.device)

            attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)

            # require transformers>=4.27
            text_embeds = self.model.opt_model.get_input_embeddings()(opt_tokens.input_ids)
            inputs_embeds = torch.cat([inputs_opt, text_embeds], dim=1)

            outputs = self.model.opt_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                do_sample=False,
                num_beams=num_beams,
                max_new_tokens=max_len,
                min_length=min_len,
                eos_token_id=self.model.eos_token_id,
                length_penalty=length_penalty,
            )
            output_text = self.model.opt_tokenizer.batch_decode(
                outputs, skip_special_tokens=True
            )
            output_text = [text.strip() for text in output_text]

        if self._apply_lemmatizer or samples.get("apply_lemmatizer"):
            # The lemmatizer lives on the wrapped model; this class does not define one.
            output_text = self.model._lemmatize(output_text)

        return output_text


class PromptEmbedding(nn.Module):
    """Learnable prompt tokens prepended to a batch of embeddings.

    Args:
        token_num: number of prompt tokens.
        hidden_size: embedding width; must match the last dimension of the
            embeddings passed to ``forward``.
        enable_reparameterize: if True, the raw prompt is passed through a
            two-layer MLP (hidden_size -> hidden_size // 2 -> hidden_size,
            GELU in between) before use.
        enable_type_embedding: if True, a single learnable vector is added to
            every prompt token.
    """

    def __init__(
            self,
            token_num,
            hidden_size,
            enable_reparameterize: bool = True,
            enable_type_embedding: bool = True,
    ):

        super().__init__()

        init_std = math.sqrt(1 / hidden_size)

        self.prompt = nn.Parameter(
            torch.empty(token_num, hidden_size, dtype=torch.float32)
        )
        nn.init.normal_(self.prompt, mean=0, std=init_std)

        # Re-parameterization of the prompt (bottleneck MLP).
        self.reparameterize = None

        if enable_reparameterize:
            bottleneck = hidden_size // 2
            self.reparameterize = nn.Sequential(
                nn.Linear(hidden_size, bottleneck),
                nn.GELU(),
                nn.Linear(bottleneck, hidden_size)
            )

        # Type embedding shared by all prompt tokens.
        self.type_embedding = None

        if enable_type_embedding:
            self.type_embedding = nn.Parameter(
                torch.empty(1, hidden_size, dtype=torch.float32)
            )
            nn.init.normal_(self.type_embedding, mean=0, std=init_std)

    def forward(self, input_embeddings):
        """Prepend the prompt tokens to ``input_embeddings``.

        Args:
            input_embeddings: tensor of shape (batch, seq_len, hidden_size).

        Returns:
            Tensor of shape (batch, token_num + seq_len, hidden_size). The
            prompt tokens come first and are cast to
            ``input_embeddings.dtype``.
        """
        prompts = self.prompt

        if self.reparameterize is not None:
            prompts = self.reparameterize(prompts)

        if self.type_embedding is not None:
            # Must be out-of-place: without re-parameterization `prompts` is the
            # leaf parameter itself, so `+=` would raise under autograd and
            # silently overwrite the learned prompt under torch.no_grad().
            prompts = prompts + self.type_embedding

        # Match the input dtype (e.g. fp16 under autocast) before concatenating.
        prompts = prompts.to(input_embeddings.dtype)
        prompts = prompts.unsqueeze(0).expand(input_embeddings.size(0), -1, -1)

        return torch.cat([prompts, input_embeddings], dim=1)
