"""
 Copyright (c) 2023, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import torch

from lavis.models.base_model import BaseModel
from lavis.models.blip2_models.blip2_opt import Blip2OPT


class Blip2RankExtraction(BaseModel):
    """Wrapper around a ``Blip2OPT`` model that runs its image-conditioned
    language-modeling forward pass and can optionally return intermediate
    features.

    Two optional hooks are applied if they have been attached to the instance
    (they are not created in ``__init__``; presence is checked with ``hasattr``):

    * ``q_former_prompt`` -- callable applied to the expanded Q-Former query
      tokens before they are fed to the Q-Former.
    * ``language_prompt`` -- callable applied to the ``opt_proj`` output before
      it is prepended to the text token embeddings.
    """

    def __init__(
            self,
            model: Blip2OPT,
    ):
        """
        Args:
            model: the ``Blip2OPT`` model whose vision encoder, Q-Former,
                ``opt_proj`` projection, tokenizer and OPT decoder are used by
                ``forward``.
        """
        super().__init__()

        # Initialize model
        self.model = model

    def forward(self, samples: dict, output_features: bool = False) -> dict:
        """Compute the OPT language-modeling loss for a batch of image/text pairs.

        Args:
            samples: batch dict. Reads ``samples["image"]`` (tensor passed to the
                vision encoder) and ``samples["text_input"]`` (iterable of str).
            output_features: when True, the returned dict also carries
                intermediate embeddings; otherwise those entries are ``None``.

        Returns:
            dict with keys, in this order:
              ``loss`` -- ``outputs.loss`` of the OPT model. The query positions
                  and padding tokens are excluded via -100 labels. When
                  ``self.model.prompt`` is set, the first
                  ``self.model.prompt_length`` text positions are excluded too.
              ``image_embeds_input`` -- ``ln_vision(visual_encoder(image))``.
              ``image_embeds_proj`` -- Q-Former last hidden state for the first
                  ``self.model.query_tokens.shape[1]`` positions. Note this is
                  taken *before* ``opt_proj``.
              ``text_embeds_input`` -- OPT input token embeddings of the text.
              ``text_embeds_proj`` -- last entry of the OPT model's
                  ``hidden_states`` at the text positions.
        """
        image = samples["image"]
        device = image.device

        # ------- vision encoder ------- #
        with self.model.maybe_autocast():
            image_embeds = self.model.ln_vision(self.model.visual_encoder(image))
        # All vision tokens are attended to by the Q-Former cross-attention.
        # Allocate directly on the target device (no CPU alloc + copy).
        image_atts = torch.ones(
            image_embeds.size()[:-1], dtype=torch.long, device=device
        )

        query_tokens = self.model.query_tokens.expand(image_embeds.shape[0], -1, -1)
        # ------- q-former prompt ------- #
        if hasattr(self, "q_former_prompt"):
            query_tokens = self.q_former_prompt(query_tokens)
        # ------- ------- ------- ------- #

        query_output = self.model.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )

        inputs_opt = self.model.opt_proj(query_output.last_hidden_state)
        # ------- language prompt ------- #
        if hasattr(self, "language_prompt"):
            inputs_opt = self.language_prompt(inputs_opt)
        # ------- ------- ------- ------- #
        # Computed after the optional language prompt so the mask length always
        # matches the (possibly modified) number of visual prefix positions.
        atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long, device=device)
        n_prefix = atts_opt.shape[1]

        # NOTE: mutates the shared tokenizer; right padding keeps the real text
        # tokens directly after the visual prefix.
        self.model.opt_tokenizer.padding_side = "right"

        text = [t + "\n" for t in samples["text_input"]]

        opt_tokens = self.model.opt_tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.model.max_txt_len,
        ).to(device)

        # -100 is the ignore index for the LM loss: skip padding tokens.
        targets = opt_tokens.input_ids.masked_fill(
            opt_tokens.input_ids == self.model.opt_tokenizer.pad_token_id, -100
        )
        if self.model.prompt:
            # Presumably text_input already starts with the prompt (it is not
            # prepended here) -- TODO confirm against the data pipeline.
            targets[:, : self.model.prompt_length] = -100  # do not apply loss to the prompt

        # No loss on the visual prefix positions.
        empty_targets = torch.full(atts_opt.size(), -100, dtype=torch.long, device=device)
        targets = torch.cat([empty_targets, targets], dim=1)

        inputs_embeds = self.model.opt_model.model.decoder.embed_tokens(opt_tokens.input_ids)
        inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
        attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)

        with self.model.maybe_autocast():
            outputs = self.model.opt_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=targets,
                # Per-layer hidden states are only read when features are
                # requested; skip collecting them otherwise.
                output_hidden_states=output_features,
            )
        loss = outputs.loss

        if not output_features:
            return {
                "loss": loss,
                "image_embeds_input": None,
                "image_embeds_proj": None,
                "text_embeds_input": None,
                "text_embeds_proj": None,
            }

        num_query_tokens = self.model.query_tokens.shape[1]
        return {
            "loss": loss,
            "image_embeds_input": image_embeds,
            "image_embeds_proj": query_output.last_hidden_state[:, :num_query_tokens, :],
            "text_embeds_input": inputs_embeds[:, n_prefix:, :],
            "text_embeds_proj": outputs.hidden_states[-1][:, n_prefix:, :],
        }