from datasets import Dataset
from torch.utils.data import DataLoader
from trl import DPOTrainer
from trl.trainer.utils import pad_to_length
from contextlib import contextmanager, nullcontext
from transformers import PreTrainedModel
import warnings
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from accelerate.utils import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F

# Public API of this module: `from <module> import *` exports only the trainer class.
__all__ = ["BAPOTrainer"]


class BAPOTrainer(DPOTrainer):
    def __init__(self, bapo_lambda1, bapo_lambda2, *args, **kwargs):
        """Build a BAPO trainer on top of ``DPOTrainer``.

        Args:
            bapo_lambda1: Weight of the base-anchor hinge penalty applied in
                ``bapo_loss``.
            bapo_lambda2: Stored as ``self.bapo_lambda2``. NOTE(review): it is not
                referenced by ``bapo_loss``; confirm where (or whether) it is used.
            *args: Positional arguments forwarded unchanged to ``DPOTrainer``.
            **kwargs: Keyword arguments forwarded unchanged to ``DPOTrainer``.
        """
        super().__init__(*args, **kwargs)
        self.bapo_lambda1, self.bapo_lambda2 = bapo_lambda1, bapo_lambda2
        banner = "\n================ BAPO Trainer ================="
        print(banner)

    def _add_bapo_reference_logps_columns(
        self, dataset: Dataset, batch_size: int, desc: str
    ) -> Dataset:
        """Return ``dataset`` with reference-model log-prob columns appended.

        Iterates ``dataset`` in order (``shuffle=False``) through an
        accelerator-prepared dataloader, runs :meth:`compute_reference_log_probs`
        on each padded batch, and gathers the results across processes with
        ``accelerator.gather_for_metrics``. The gathered values are appended as the
        ``reference_chosen_logps``, ``reference_rejected_logps`` and
        ``reference_base_logps`` columns, in that order, as float32 numpy arrays.

        Args:
            dataset: Dataset to score; must support ``add_column``.
            batch_size: Per-device batch size for the scoring dataloader.
            desc: Progress-bar description.

        Returns:
            The dataset returned by the final ``add_column`` call.
        """
        dataloader_params = {
            "batch_size": batch_size,
            "collate_fn": self.data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "shuffle": False,
        }
        # prepare dataloader
        data_loader = self.accelerator.prepare(DataLoader(dataset, **dataloader_params))

        # Column names in the same order as the tuple returned by
        # compute_reference_log_probs.
        column_names = (
            "reference_chosen_logps",
            "reference_rejected_logps",
            "reference_base_logps",
        )
        collected = {name: [] for name in column_names}
        for padded_batch in tqdm(iterable=data_loader, desc=desc):
            reference_logps = self.compute_reference_log_probs(padded_batch)
            reference_logps = self.accelerator.gather_for_metrics(reference_logps)
            for name, logps in zip(column_names, reference_logps):
                collected[name].append(logps.cpu())

        for name in column_names:
            dataset = dataset.add_column(
                name=name, column=torch.cat(collected[name]).float().numpy()
            )
        return dataset

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training [`~torch.utils.data.DataLoader`].

        Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`.
        When `precompute_ref_log_probs` is enabled, the reference log-probs for the chosen, rejected and
        base responses are computed once and stored as columns on `self.train_dataset`.
        """
        if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
            self.train_dataset = self._add_bapo_reference_logps_columns(
                self.train_dataset,
                batch_size=self.args.per_device_train_batch_size,
                desc="Train dataset reference log probs",
            )
            self._precomputed_train_ref_log_probs = True

        return super().get_train_dataloader()

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`].

        Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`.

        Args:
            eval_dataset (`torch.utils.data.Dataset`, *optional*):
                If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
                by the `model.forward()` method are automatically removed. It must implement `__len__`.

        Raises:
            ValueError: If neither `eval_dataset` nor `self.eval_dataset` is set.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

        if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
            eval_dataset = self._add_bapo_reference_logps_columns(
                eval_dataset,
                batch_size=self.args.per_device_eval_batch_size,
                desc="Eval dataset reference log probs",
            )
            # Save the computed reference log-probs on self.eval_dataset for subsequent runs
            if self.eval_dataset is not None:
                self.eval_dataset = eval_dataset
            self._precomputed_eval_ref_log_probs = True

        return super().get_eval_dataloader(eval_dataset=eval_dataset)

    def tokenize_row(
        self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None
    ) -> Dict:
        """Tokenize a single row from a BAPO specific dataset.

        The row is read through the keys ``prompt``, ``chosen``, ``rejected`` and
        ``base``. No PyTorch tensors are built here, except the optional
        ``*_decoder_input_ids`` in the encoder-decoder branch.

        Decoder-only models:
            * A BOS token is prepended to every prompt and an EOS token is appended
              to every response.
            * If prompt + longest response exceeds ``self.max_length``, the prompt
              is truncated to ``self.max_prompt_length`` tokens, according to
              ``self.truncation_mode``.
            * If that is still too long, each response is truncated to
              ``self.max_length - self.max_prompt_length`` tokens.
            * ``{chosen,rejected,base}_labels`` are the full sequence ids, with the
              prompt positions replaced by ``self.label_pad_token_id``.

        Encoder-decoder models:
            * Prompt and responses are tokenized separately, with special tokens
              and truncation.
            * Each response's ids become ``*_labels``.
            * When ``model`` has ``prepare_decoder_input_ids_from_labels``, the
              matching ``*_decoder_input_ids`` are also produced.

        Args:
            feature: A single dataset row.
            model: Optional model; only used in the encoder-decoder branch.

        Returns:
            A dict of token-id / attention-mask / label lists keyed with the
            ``chosen_``, ``rejected_``, ``base_`` and ``prompt_`` prefixes.

        Raises:
            ValueError: If a decoder-only field is not a ``str``, if the three
                prompt tokenizations diverge by more than one token or length
                unit, or if ``self.truncation_mode`` is unknown.
        """
        batch = {}
        prompt = feature["prompt"]
        chosen = feature["chosen"]
        rejected = feature["rejected"]
        base = feature["base"]

        if not self.is_encoder_decoder:
            # Check issues below for more details
            #  1. https://github.com/huggingface/trl/issues/907
            #  2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
            #  3. https://github.com/LianjiaTech/BELLE/issues/337

            if not isinstance(prompt, str):
                raise ValueError(f"prompt should be an str but got {type(prompt)}")
            prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)
            prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}

            if not isinstance(chosen, str):
                raise ValueError(f"chosen should be an str but got {type(chosen)}")
            chosen_tokens = self.build_tokenized_answer(prompt, chosen)

            if not isinstance(rejected, str):
                raise ValueError(f"rejected should be an str but got {type(rejected)}")
            rejected_tokens = self.build_tokenized_answer(prompt, rejected)

            if not isinstance(base, str):
                raise ValueError(f"base should be an str but got {type(base)}")
            base_tokens = self.build_tokenized_answer(prompt, base)

            # Same dict objects as chosen_tokens / rejected_tokens / base_tokens,
            # in that order; mutations below are visible through either name.
            answer_tokens_list = [chosen_tokens, rejected_tokens, base_tokens]

            # The last prompt token might get merged with the response by the
            # tokenizer and should not be included for generation if that happens,
            # so cut the standalone prompt to the shortest of the three prompt parts.
            prompt_len_input_ids = min(
                len(toks["prompt_input_ids"]) for toks in answer_tokens_list
            )
            for k, v in prompt_tokens.items():
                prompt_tokens[k] = v[:prompt_len_input_ids]

            # Pairwise, the prompt parts may differ in at most one token and their
            # lengths by at most one; anything more is not a boundary merge.
            chosen_prompt_ids = chosen_tokens["prompt_input_ids"]
            rejected_prompt_ids = rejected_tokens["prompt_input_ids"]
            base_prompt_ids = base_tokens["prompt_input_ids"]
            for ids_a, ids_b in (
                (chosen_prompt_ids, rejected_prompt_ids),
                (chosen_prompt_ids, base_prompt_ids),
                (rejected_prompt_ids, base_prompt_ids),
            ):
                num_diff_tokens = sum(a != b for a, b in zip(ids_a, ids_b))
                num_diff_len = abs(len(ids_a) - len(ids_b))
                if num_diff_tokens > 1 or num_diff_len > 1:
                    raise ValueError(
                        "Chosen, rejected, and base prompt_input_ids might only differ on the "
                        "last token due to tokenizer merge ops."
                    )

            # add BOS token to the head of every prompt
            for toks in (prompt_tokens, *answer_tokens_list):
                toks["prompt_input_ids"] = [self.tokenizer.bos_token_id] + toks[
                    "prompt_input_ids"
                ]
                toks["prompt_attention_mask"] = [1] + toks["prompt_attention_mask"]

            # add EOS token to the end of every answer
            for toks in answer_tokens_list:
                toks["input_ids"].append(self.tokenizer.eos_token_id)
                toks["attention_mask"].append(1)

            longer_response_length = max(
                len(toks["input_ids"]) for toks in answer_tokens_list
            )

            # if combined sequence is too long, truncate the prompt
            for answer_tokens in [*answer_tokens_list, prompt_tokens]:
                if (
                    len(answer_tokens["prompt_input_ids"]) + longer_response_length
                    > self.max_length
                ):
                    if self.truncation_mode == "keep_start":
                        for k in ["prompt_input_ids", "prompt_attention_mask"]:
                            answer_tokens[k] = answer_tokens[k][
                                : self.max_prompt_length
                            ]
                    elif self.truncation_mode == "keep_end":
                        for k in ["prompt_input_ids", "prompt_attention_mask"]:
                            answer_tokens[k] = answer_tokens[k][
                                -self.max_prompt_length :
                            ]
                    else:
                        raise ValueError(
                            f"Unknown truncation mode: {self.truncation_mode}"
                        )

            # if that's still too long, truncate the response
            for answer_tokens in answer_tokens_list:
                if (
                    len(answer_tokens["prompt_input_ids"]) + longer_response_length
                    > self.max_length
                ):
                    for k in ["input_ids", "attention_mask"]:
                        answer_tokens[k] = answer_tokens[k][
                            : self.max_length - self.max_prompt_length
                        ]

            # create prompt+answer sequences and their labels (prompt positions masked)
            tokens_by_prefix = {}
            for prefix, answer_tokens in zip(
                ("chosen_", "rejected_", "base_"), answer_tokens_list
            ):
                sequence_tokens = {
                    k: answer_tokens[f"prompt_{k}"] + answer_tokens[k]
                    for k in ["input_ids", "attention_mask"]
                }
                num_prompt_tokens = len(answer_tokens["prompt_input_ids"])
                sequence_tokens["labels"] = sequence_tokens["input_ids"][:]
                sequence_tokens["labels"][:num_prompt_tokens] = [
                    self.label_pad_token_id
                ] * num_prompt_tokens
                tokens_by_prefix[prefix] = sequence_tokens
            tokens_by_prefix[""] = prompt_tokens

            for k, toks in tokens_by_prefix.items():
                for type_key, tokens in toks.items():
                    if type_key == "token_type_ids":
                        continue
                    batch[f"{k}{type_key}"] = tokens

        else:
            chosen_tokens = self.tokenizer(
                chosen,
                truncation=True,
                max_length=self.max_target_length,
                add_special_tokens=True,
            )
            rejected_tokens = self.tokenizer(
                rejected,
                truncation=True,
                max_length=self.max_target_length,
                add_special_tokens=True,
            )
            base_tokens = self.tokenizer(
                base,
                truncation=True,
                max_length=self.max_target_length,
                add_special_tokens=True,
            )
            # The prompt is the encoder input; it must be tokenized here because
            # its ids/mask are stored in the batch below.
            prompt_tokens = self.tokenizer(
                prompt,
                truncation=True,
                max_length=self.max_prompt_length,
                add_special_tokens=True,
            )
            batch["chosen_labels"] = chosen_tokens["input_ids"]
            batch["rejected_labels"] = rejected_tokens["input_ids"]
            batch["base_labels"] = base_tokens["input_ids"]
            batch["prompt_input_ids"] = prompt_tokens["input_ids"]
            batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]

            if model is not None and hasattr(
                model, "prepare_decoder_input_ids_from_labels"
            ):
                batch["base_decoder_input_ids"] = (
                    model.prepare_decoder_input_ids_from_labels(
                        labels=torch.tensor(batch["base_labels"])
                    )
                )
                batch["rejected_decoder_input_ids"] = (
                    model.prepare_decoder_input_ids_from_labels(
                        labels=torch.tensor(batch["rejected_labels"])
                    )
                )
                batch["chosen_decoder_input_ids"] = (
                    model.prepare_decoder_input_ids_from_labels(
                        labels=torch.tensor(batch["chosen_labels"])
                    )
                )

        return batch

    def compute_reference_log_probs(
        self, padded_batch: Dict
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Computes log probabilities of the reference model for a single padded batch of a BAPO specific dataset.

        Runs under ``torch.no_grad()``. It additionally runs under
        ``torch.cuda.amp.autocast`` when ``self._peft_has_been_casted_to_bf16`` is
        set. When ``self.ref_model`` is ``None``, ``self.model`` is used inside
        ``self.null_ref_context()`` as the reference.

        Returns:
            ``(reference_chosen_logps, reference_rejected_logps, reference_base_logps)``,
            the first three outputs of :meth:`concatenated_forward`.
        """
        compute_ref_context_manager = (
            torch.cuda.amp.autocast
            if self._peft_has_been_casted_to_bf16
            else nullcontext
        )

        # compute reference logps
        with torch.no_grad(), compute_ref_context_manager():
            if self.ref_model is None:
                with self.null_ref_context():
                    (
                        reference_chosen_logps,
                        reference_rejected_logps,
                        reference_base_logps,
                        _,
                        _,
                        _,
                    ) = self.concatenated_forward(self.model, padded_batch)
            else:
                (
                    reference_chosen_logps,
                    reference_rejected_logps,
                    reference_base_logps,
                    _,
                    _,
                    _,
                ) = self.concatenated_forward(self.ref_model, padded_batch)

        return reference_chosen_logps, reference_rejected_logps, reference_base_logps

    @staticmethod
    def concatenated_inputs(
        batch: Dict[str, Union[List, torch.LongTensor]],
        is_encoder_decoder: bool = False,
        label_pad_token_id: int = -100,
        padding_value: int = 0,
        device: Optional[torch.device] = None,
    ) -> Dict[str, torch.LongTensor]:
        """Concatenate the chosen, rejected, and base inputs into a single tensor.

        Every tensor whose key starts with ``chosen``, ``rejected`` or ``base`` is
        right-padded (via ``pad_to_length``) to a common length. The common length
        is the maximum width of the three ``*_labels`` tensors (encoder-decoder) or
        of the three ``*_input_ids`` tensors (decoder-only). The padded tensors are
        stacked along dim 0 in chosen, rejected, base order, under the key with the
        prefix replaced by ``concatenated``.

        Padding values:
            * keys containing ``labels``, or any key when ``is_encoder_decoder``:
              ``label_pad_token_id``
            * keys ending in ``_input_ids``: ``padding_value``
            * keys ending in ``_attention_mask``: 0

        Args:
            batch: A batch of data. Must contain the keys 'chosen_input_ids', 'rejected_input_ids', and 'base_input_ids'
                (or the corresponding '*_labels' for encoder-decoder models), which are tensors of shape
                (batch_size, sequence_length).
            is_encoder_decoder: Whether the model is an encoder-decoder model.
            label_pad_token_id: The label pad token id.
            padding_value: The padding value to use for the concatenated inputs_ids.
            device: The device for the concatenated inputs.

        Returns:
            A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
            For encoder-decoder models, the prompt ids/mask are repeated three times (once per response set).

        Raises:
            ValueError: If a chosen/rejected/base tensor key matches none of the padding rules above.
        """
        concatenated_batch = {}

        if is_encoder_decoder:
            max_length = max(
                batch["chosen_labels"].shape[1],
                batch["rejected_labels"].shape[1],
                batch["base_labels"].shape[1],
            )
        else:
            max_length = max(
                batch["chosen_input_ids"].shape[1],
                batch["rejected_input_ids"].shape[1],
                batch["base_input_ids"].shape[1],
            )

        # Order matters: chosen rows first, then rejected, then base, matching the
        # slicing in concatenated_forward.
        for prefix in ("chosen", "rejected", "base"):
            for k in batch:
                if not k.startswith(prefix) or not isinstance(batch[k], torch.Tensor):
                    continue
                # pad_value is recomputed for every key so a value from a previous
                # key can never leak into this one.
                if "labels" in k or is_encoder_decoder:
                    pad_value = label_pad_token_id
                elif k.endswith("_input_ids"):
                    pad_value = padding_value
                elif k.endswith("_attention_mask"):
                    pad_value = 0
                else:
                    raise ValueError(
                        f"concatenated_inputs: cannot infer a padding value for tensor key {k!r}"
                    )
                concatenated_key = k.replace(prefix, "concatenated")
                padded = pad_to_length(batch[k], max_length, pad_value=pad_value)
                if prefix == "chosen":
                    concatenated_batch[concatenated_key] = padded
                else:
                    concatenated_batch[concatenated_key] = torch.cat(
                        (concatenated_batch[concatenated_key], padded),
                        dim=0,
                    ).to(device=device)

        if is_encoder_decoder:
            # Three response sets (chosen, rejected, base) share the same prompt, so
            # the encoder inputs are repeated three times to match the label rows.
            concatenated_batch["concatenated_input_ids"] = (
                batch["prompt_input_ids"].repeat(3, 1).to(device=device)
            )
            concatenated_batch["concatenated_attention_mask"] = (
                batch["prompt_attention_mask"].repeat(3, 1).to(device=device)
            )

        return concatenated_batch

    def bapo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        policy_base_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
        reference_base_logps: torch.FloatTensor,
    ) -> Tuple[
        torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor
    ]:
        """Compute the BAPO loss from policy and reference log probabilities.

        The per-example loss is::

            -logsigmoid(beta * ((pi_c - pi_r) - (ref_c - ref_r)))
                + bapo_lambda1 * max(0, ref_b - pi_b)

        Here ``c``, ``r`` and ``b`` denote chosen, rejected and base responses. The
        reference margin ``ref_c - ref_r`` is replaced by 0 when
        ``self.reference_free`` is set. The second term penalizes the policy only
        when it assigns the base response a lower log-probability than the
        reference does.

        Args:
            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
            policy_base_logps: Log probabilities of the policy model for the base responses. Shape: (batch_size,)
            reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
            reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
            reference_base_logps: Log probabilities of the reference model for the base responses. Shape: (batch_size,)

        Returns:
            ``(losses, chosen_rewards, rejected_rewards, base_rewards)``. ``losses``
            holds the per-example BAPO loss. Each reward is
            ``beta * (policy_logps - reference_logps)``, detached from the graph,
            for the chosen, rejected and base responses respectively.
        """
        target_device = self.accelerator.device

        policy_margin = policy_chosen_logps - policy_rejected_logps
        if not self.reference_free:
            reference_margin = reference_chosen_logps - reference_rejected_logps
        else:
            reference_margin = torch.tensor(
                [0], dtype=policy_margin.dtype, device=policy_margin.device
            )
        logits = policy_margin.to(target_device) - reference_margin.to(target_device)

        # Base-Anchored Regularization: hinge on how far the policy's base
        # log-prob falls below the reference's (lambda1 * max(0, ratio)).
        base_shortfall = torch.clamp(reference_base_logps - policy_base_logps, min=0)
        losses = self.bapo_lambda1 * base_shortfall - F.logsigmoid(self.beta * logits)

        def _reward(policy_logps, reference_logps):
            return self.beta * (policy_logps - reference_logps).detach()

        return (
            losses,
            _reward(policy_chosen_logps, reference_chosen_logps),
            _reward(policy_rejected_logps, reference_rejected_logps),
            _reward(policy_base_logps, reference_base_logps),
        )

    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
        torch.FloatTensor,
        torch.FloatTensor,
        torch.FloatTensor,
        torch.FloatTensor,
    ]:
        """Score chosen, rejected and base responses with a single forward pass.

        The three response sets are stacked along dim 0 by
        :meth:`concatenated_inputs` and run through ``model`` once, instead of once
        per set (the DPO rationale is that this is faster for FSDP). The outputs are
        then split back using the row counts of ``chosen_labels`` and
        ``rejected_labels``. Log-probs are averaged per token when
        ``self.loss_type == "ipo"``, otherwise summed, as decided by
        ``get_batch_logps``.

        Returns:
            ``(chosen_logps, rejected_logps, base_logps,
            chosen_logits, rejected_logits, base_logits)``.
        """
        stacked = self.concatenated_inputs(
            batch,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
            padding_value=self.padding_value,
            device=self.accelerator.device,
        )
        chosen_end = batch["chosen_labels"].shape[0]
        rejected_end = chosen_end + batch["rejected_labels"].shape[0]

        extra_model_inputs = {}
        if self.is_encoder_decoder:
            extra_model_inputs["labels"] = stacked["concatenated_labels"]
            extra_model_inputs["decoder_input_ids"] = stacked.pop(
                "concatenated_decoder_input_ids", None
            )

        outputs = model(
            stacked["concatenated_input_ids"],
            attention_mask=stacked["concatenated_attention_mask"],
            use_cache=False,
            **extra_model_inputs,
        )
        stacked_logits = outputs.logits

        stacked_logps = self.get_batch_logps(
            stacked_logits,
            stacked["concatenated_labels"],
            average_log_prob=self.loss_type == "ipo",
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )

        # Row ranges: [0, chosen_end) chosen, [chosen_end, rejected_end) rejected,
        # [rejected_end, end) base.
        bounds = (slice(None, chosen_end), slice(chosen_end, rejected_end), slice(rejected_end, None))
        logps_parts = tuple(stacked_logps[part] for part in bounds)
        logits_parts = tuple(stacked_logits[part] for part in bounds)
        return logps_parts + logits_parts

    def get_batch_loss_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ):
        """Compute the BAPO loss and other metrics for the given batch of inputs for train or test.

        The policy ``model`` is run over the chosen, rejected and base completions
        via ``self.concatenated_forward``. Reference log-probabilities are taken
        from ``batch`` when all three precomputed entries are present. Otherwise
        they come from a no-grad forward pass through ``self.ref_model``; when that
        is ``None``, ``self.model`` is used inside ``self.null_ref_context()``.
        Both sets of log-probs are passed to ``self.bapo_loss``.

        Args:
            model: The policy model to evaluate on ``batch``.
            batch: Collated batch. It may carry precomputed
                ``reference_chosen_logps`` / ``reference_rejected_logps`` /
                ``reference_base_logps`` entries.
            train_eval: When ``"eval"``, every metric key is prefixed with ``eval_``.

        Returns:
            A ``(loss, metrics)`` tuple:
                - ``loss`` is the mean of the per-example losses returned by
                  ``self.bapo_loss``.
                - ``metrics`` maps metric names to detached CPU scalar tensors.
        """
        metrics = {}

        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_base_logps,
            policy_chosen_logits,
            policy_rejected_logits,
            policy_base_logits,
        ) = self.concatenated_forward(model, batch)

        # Use the precomputed reference log-probs only when ALL three are present.
        # A batch lacking ``reference_base_logps`` then falls back to the reference
        # forward pass instead of raising KeyError.
        reference_keys = (
            "reference_chosen_logps",
            "reference_rejected_logps",
            "reference_base_logps",
        )
        if all(key in batch for key in reference_keys):
            reference_chosen_logps, reference_rejected_logps, reference_base_logps = (
                batch[key] for key in reference_keys
            )
        else:
            with torch.no_grad():
                if self.ref_model is None:
                    # No separate reference model: evaluate self.model under the
                    # trainer-provided null-reference context.
                    with self.null_ref_context():
                        reference_outputs = self.concatenated_forward(self.model, batch)
                else:
                    reference_outputs = self.concatenated_forward(self.ref_model, batch)
            # Only the log-probs are needed; the reference logits are discarded.
            (
                reference_chosen_logps,
                reference_rejected_logps,
                reference_base_logps,
                _,
                _,
                _,
            ) = reference_outputs

        losses, chosen_rewards, rejected_rewards, base_rewards = self.bapo_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            policy_base_logps,
            reference_chosen_logps,
            reference_rejected_logps,
            reference_base_logps,
        )

        # Suffixes: "cb" = chosen vs base, "br" = base vs rejected.
        per_example = {
            "rewards/chosen": chosen_rewards,
            "rewards/rejected": rejected_rewards,
            "rewards/base": base_rewards,
            "rewards/accuracies": (chosen_rewards > rejected_rewards).float(),
            "rewards/accuracies_cb": (chosen_rewards > base_rewards).float(),
            "rewards/accuracies_br": (base_rewards > rejected_rewards).float(),
            "rewards/margins": chosen_rewards - rejected_rewards,
            "rewards/margins_cb": chosen_rewards - base_rewards,
            "rewards/margins_br": base_rewards - rejected_rewards,
            "logps/rejected": policy_rejected_logps,
            "logps/chosen": policy_chosen_logps,
            "logps/base": policy_base_logps,
            "logits/chosen": policy_chosen_logits,
            "logits/rejected": policy_rejected_logits,
            "logits/base": policy_base_logits,
        }

        prefix = "eval_" if train_eval == "eval" else ""
        for name, values in per_example.items():
            # Detach so logged metrics never keep the autograd graph alive.
            metrics[f"{prefix}{name}"] = values.detach().mean().cpu()

        return losses.mean(), metrics

    def prediction_step(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ):
        """Run one evaluation step and return ``(loss, logits, labels)``.

        This method must be named ``prediction_step`` so that it overrides the
        parent Trainer hook. Under the misspelled name ``prediction_setp`` it was
        never called. That name is kept below as an alias for existing callers.

        Args:
            model: The model to evaluate.
            inputs: Collated evaluation batch, forwarded to
                ``self.get_batch_loss_metrics``.
            prediction_loss_only: If true, return only the loss.
            ignore_keys: Metric keys (``eval_logits/...``) to leave out of the
                returned logits. Defaults to
                ``model.config.keys_to_ignore_at_inference`` when available,
                otherwise an empty list.

        Returns:
            A ``(loss, logits, labels)`` tuple:
                - ``loss`` is detached.
                - ``logits`` stacks the mean chosen/rejected/base logits that were
                  not ignored; it is ``None`` when ``prediction_loss_only`` is set
                  or every logits key is ignored.
                - ``labels`` is a zero tensor of matching length, or ``None`` in
                  the same cases as ``logits``.
        """
        if not self.use_dpo_data_collator:
            warnings.warn(
                "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
                "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
            )
        if ignore_keys is None:
            if hasattr(model, "config"):
                ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        # Autocast only when the PEFT model was cast to bf16 by the trainer.
        prediction_context_manager = (
            torch.cuda.amp.autocast
            if self._peft_has_been_casted_to_bf16
            else nullcontext
        )

        with torch.no_grad(), prediction_context_manager():
            loss, metrics = self.get_batch_loss_metrics(
                model, inputs, train_eval="eval"
            )

        # force log the metrics
        self.store_metrics(metrics, train_eval="eval")

        if prediction_loss_only:
            return (loss.detach(), None, None)

        # Mean logits for the chosen, rejected and base samples from the model.
        # Only non-ignored keys are read.
        logits = tuple(
            metrics[key].unsqueeze(dim=0)
            for key in ("eval_logits/chosen", "eval_logits/rejected", "eval_logits/base")
            if key not in ignore_keys
        )
        if not logits:
            # torch.stack rejects an empty sequence; nothing left to report.
            return (loss.detach(), None, None)
        logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
        labels = torch.zeros(logits.shape[0], device=self.accelerator.device)

        return (loss.detach(), logits, labels)

    # Backward-compatible alias for the original misspelled method name.
    prediction_setp = prediction_step
