import torch
import torch.nn.functional as F
from trl import DPOTrainer

# Public API of this module: only the DPOP trainer is exported.
__all__ = ["DPOPTrainer"]


class DPOPTrainer(DPOTrainer):
    """DPO trainer that optimises the DPO-Positive (DPOP) loss.

    DPOP adds a penalty to the standard DPO margin. The penalty is active
    whenever the policy gives the chosen completion a lower log-probability
    than the reference model does. The per-example loss is::

        -log sigmoid(beta * (
            (log pi(y_w) - log ref(y_w)) - (log pi(y_l) - log ref(y_l))
            - lambda * max(0, log ref(y_w) - log pi(y_w))
        ))

    where ``y_w`` / ``y_l`` are the chosen / rejected completions. With
    ``dpop_lambda == 0`` this reduces to ``-log sigmoid(beta * margin)``.

    Args:
        dpop_lambda: Weight ``lambda`` of the DPOP penalty term.
        *args, **kwargs: Forwarded unchanged to ``trl.DPOTrainer``.
    """

    def __init__(self, dpop_lambda, *args, **kwargs):
        # Assigned first so the attribute already exists for any code the
        # base initializer runs.
        self.dpop_lambda = dpop_lambda
        super().__init__(*args, **kwargs)

    def dpo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
        reference_free: bool = False,
    ):
        """Compute the per-example DPOP loss and the implicit rewards.

        Overrides ``DPOTrainer.dpo_loss``. NOTE(review): this signature has an
        explicit ``reference_free`` argument, so it presumably targets TRL
        releases whose base method takes the same arguments. Confirm against
        the pinned TRL version.

        Args:
            policy_chosen_logps: Policy log-probabilities of the chosen
                completions (assumed one value per example -- TODO confirm).
            policy_rejected_logps: Policy log-probabilities of the rejected
                completions.
            reference_chosen_logps: Reference-model log-probabilities of the
                chosen completions.
            reference_rejected_logps: Reference-model log-probabilities of the
                rejected completions.
            reference_free: If True, the reference log-ratio is dropped from
                the DPO margin. The DPOP penalty still reads
                ``reference_chosen_logps``.

        Returns:
            Tuple ``(losses, chosen_rewards, rejected_rewards)``. The rewards
            are ``beta``-scaled policy/reference log-ratios, detached from the
            autograd graph.
        """
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = (
            0 if reference_free else reference_chosen_logps - reference_rejected_logps
        )

        # DPOP penalty max(0, log ref(y_w) - log pi(y_w)). It is positive only
        # when the policy has pushed the chosen completion's likelihood below
        # the reference model's.
        positive_penalty = torch.clamp(
            reference_chosen_logps - policy_chosen_logps, min=0
        )

        # The penalty sits inside the sigmoid and is scaled by beta together
        # with the DPO margin, as in the DPOP objective. Subtracting it outside
        # logsigmoid would instead add a term that is not scaled by beta and
        # does not interact with the margin. That changes the objective and
        # the meaning of dpop_lambda.
        logits = pi_logratios - ref_logratios - self.dpop_lambda * positive_penalty
        losses = -F.logsigmoid(self.beta * logits)

        chosen_rewards = (
            self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
        )
        rejected_rewards = (
            self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
        )

        return losses, chosen_rewards, rejected_rewards
