from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from transformers import BatchEncoding, Trainer, AutoTokenizer
from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model
from trl.models import PreTrainedModelWrapper
from torch.nn import functional as F
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.train.d2o.scheduler import FixedSamplerScheduler, LinearSamplerScheduler, ExponentialSamplerScheduler, IncreasingDensityScheduler
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps, get_ref_context


if TYPE_CHECKING:
    from transformers import PreTrainedModel, ProcessorMixin
    from ...hparams import FinetuningArguments
    

class D2OTrainer(DPOTrainer):
    r"""
    Preference trainer implementing the D2O objective on top of ``trl.DPOTrainer``.

    Every batch is expected to hold ``multiple_K + 1`` row groups of equal size ``B``:
    ``multiple_K`` blocks of "chosen" responses followed by one block of ``B`` rejected
    responses. Row ``j`` of chosen block ``i`` is compared element-wise against row ``j``
    of the rejected block (presumably both belong to the same prompt — confirm against
    the data collator).

    NOTE(review): the data collator is handed the policy, the reference model, a sampling
    tokenizer and a step scheduler, which suggests it generates the chosen responses on
    the fly; that logic lives outside this class.
    """

    def __init__(
        self,
        model: Union["PreTrainedModel", torch.nn.Module],
        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]],
        finetuning_args: "FinetuningArguments",
        processor: Optional["ProcessorMixin"] = None,
        disable_dropout: Optional[bool] = True,
        scheduler: Optional[str] = None,
        warm_up_steps: Optional[int] = 1600,
        sample_scale: Optional[float] = 2,
        sample_interval: Optional[int] = 400,
        tokenizer: Optional[AutoTokenizer] = None,
        **kwargs,
    ):
        r"""
        Args:
            model: Policy model to optimize.
            ref_model: Reference model, or ``None`` to reuse ``model`` through
                ``get_ref_context`` whenever reference log-probs are computed.
            finetuning_args: Supplies ``multiple_K``, ``d2o_alpha``, ``pref_beta``,
                ``pref_loss``, ``pref_ftx``, ``dpo_label_smoothing``, ``simpo_gamma``
                and ``use_ref_model``.
            processor: Optional processor whose ``image_processor`` is saved with checkpoints.
            disable_dropout: Disable dropout in ``model`` and, if given, ``ref_model``.
            scheduler: Sampling schedule attached to the data collator: ``"increase"``,
                ``"fix-step"``, ``"decrease"``, or ``None``/empty for no schedule.
            warm_up_steps: Forwarded to the selected sampling scheduler.
            sample_scale: Forwarded to the selected sampling scheduler.
            sample_interval: Forwarded to the selected sampling scheduler.
            tokenizer: Handed to the data collator as ``sample_tokenizer``.
                NOTE(review): it is NOT forwarded to ``Trainer.__init__`` — confirm the
                Trainer does not need it (e.g. for saving checkpoints).
            **kwargs: Forwarded to ``transformers.Trainer.__init__``.

        Raises:
            AttributeError: If the installed ``Trainer`` does not provide ``accelerator``.
            ValueError: If ``scheduler`` names an unknown sampling schedule.
        """
        if disable_dropout:
            disable_dropout_in_model(model)
            if ref_model is not None:
                disable_dropout_in_model(ref_model)

        self.finetuning_args = finetuning_args
        self.processor = processor
        # Number of chosen blocks per batch (see the class docstring for the batch layout).
        self.multiple_K = finetuning_args.multiple_K

        # DPOTrainer.__init__ is bypassed (Trainer.__init__ is called directly below), so the
        # attributes that DPOTrainer's methods presumably expect are set by hand.
        self.reference_free = False
        self.use_dpo_data_collator = True  # hack to avoid warning
        self.generate_during_eval = False  # disable at evaluation
        self.label_pad_token_id = IGNORE_INDEX
        self.padding_value = 0
        self.is_encoder_decoder = model.config.is_encoder_decoder
        self.precompute_ref_log_probs = False
        self._precomputed_train_ref_log_probs = False
        self._precomputed_eval_ref_log_probs = False
        self._peft_has_been_casted_to_bf16 = False
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        self.ref_model = ref_model

        # Loss hyperparameters. ``alpha`` weights the rejected term of formula 2.
        self.alpha = finetuning_args.d2o_alpha
        # NOTE(review): ``finetuning_args.d2o_beta`` used to be assigned to ``self.beta`` as
        # well, but it was immediately overwritten by ``pref_beta`` and never took effect.
        # ``pref_beta`` is kept to preserve training behavior — confirm which is intended.
        self.beta = finetuning_args.pref_beta
        self.loss_type = finetuning_args.pref_loss
        self.ftx_gamma = finetuning_args.pref_ftx
        self.label_smoothing = finetuning_args.dpo_label_smoothing
        self.simpo_gamma = finetuning_args.simpo_gamma

        Trainer.__init__(self, model=model, **kwargs)
        if not hasattr(self, "accelerator"):
            raise AttributeError("Please update `transformers`.")

        if ref_model is not None:
            if self.is_deepspeed_enabled:
                # Quantized models are already placed on the correct device.
                if not (
                    getattr(ref_model, "is_loaded_in_8bit", False)
                    or getattr(ref_model, "is_loaded_in_4bit", False)
                ):
                    self.ref_model = self._prepare_deepspeed(self.ref_model)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
                self.ref_model.eval()

        # Expose training state to the data collator (presumably used to sample responses).
        collator = self.data_collator
        collator.model = self.model
        collator.accelerator = self.accelerator
        collator.steps = 0
        collator.ref_model = self.ref_model
        collator.sample_tokenizer = tokenizer
        collator.scheduler = self._build_sampler_scheduler(
            scheduler, warm_up_steps, sample_scale, sample_interval
        )

        # Selects the ``formula_*`` used by ``compute_preference_loss`` (1, 2 or 3).
        # Original author note (translated): "for now, follow alpha == beta, d2o_alpha".
        # NOTE(review): nothing enforces alpha == beta; alpha is d2o_alpha, beta is pref_beta.
        self.loss_num = 2

    @staticmethod
    def _build_sampler_scheduler(
        name: Optional[str],
        warm_up_steps: Optional[int],
        sample_scale: Optional[float],
        sample_interval: Optional[int],
    ):
        r"""
        Instantiate the sampling scheduler selected by ``name``.

        Returns ``None`` when ``name`` is ``None`` or empty.

        Raises:
            ValueError: If ``name`` is not one of ``"increase"``, ``"fix-step"``, ``"decrease"``.
        """
        if not name:
            return None
        if name == "increase":
            return IncreasingDensityScheduler(
                warmup_steps=warm_up_steps, initial_interval=sample_interval, increase_factor=sample_scale
            )
        if name == "fix-step":
            return FixedSamplerScheduler(warmup_steps=warm_up_steps, sample_interval=sample_interval)
        if name == "decrease":
            return ExponentialSamplerScheduler(warmup_steps=warm_up_steps, scale=sample_scale)
        raise ValueError(
            f"Unknown sampler scheduler {name!r}; expected 'increase', 'fix-step', 'decrease' or None."
        )

    def create_optimizer(self) -> "torch.optim.Optimizer":
        r"""Build the project's custom optimizer if none exists yet, then defer to ``Trainer``."""
        if self.optimizer is None:
            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
        return super().create_optimizer()

    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        r"""Run the project's custom scheduler hook, then defer to ``Trainer.create_scheduler``."""
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
        r"""Save the checkpoint and, when a processor is set, its ``image_processor`` next to it."""
        super()._save(output_dir, state_dict)
        if self.processor is not None:
            output_dir = output_dir if output_dir is not None else self.args.output_dir
            self.processor.image_processor.save_pretrained(output_dir)

    def concatenated_forward(
        self,
        model: Optional[torch.nn.Module] = None,
        batch: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        r"""
        Run ``model`` on the whole batch and split the outputs into chosen and rejected parts.

        The batch must contain ``(multiple_K + 1) * B`` rows: ``multiple_K * B`` chosen rows
        followed by ``B`` rejected rows. ``B`` is derived from the batch itself (not from
        ``per_device_train_batch_size``), so eval batches and a smaller final batch split
        correctly.

        Returns:
            ``(chosen_logps, rejected_logps, chosen_logits, rejected_logits)``. The log-probs
            are the per-sequence values from ``get_batch_logps``; logits are cast to float32.

        Raises:
            ValueError: If the row count is not a multiple of ``multiple_K + 1``.
        """
        # Clone the inputs before the forward pass; the original code did this to "avoid error"
        # (presumably in-place modification of the shared batch tensors).
        batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()})

        all_logits = model(
            input_ids=batch_copied["input_ids"],
            attention_mask=batch_copied["attention_mask"],
            return_dict=True,
        ).logits.to(torch.float32)

        all_logps, _ = get_batch_logps(all_logits, batch["labels"])

        num_rows = all_logps.size(0)
        group_count = self.multiple_K + 1
        if num_rows % group_count != 0:
            raise ValueError(
                f"Expected (multiple_K + 1) * B rows with multiple_K={self.multiple_K}, got {num_rows} rows."
            )
        num_chosen = self.multiple_K * (num_rows // group_count)
        split_sizes = [num_chosen, num_rows - num_chosen]
        chosen_logps, rejected_logps = all_logps.split(split_sizes, dim=0)
        chosen_logits, rejected_logits = all_logits.split(split_sizes, dim=0)
        return chosen_logps, rejected_logps, chosen_logits, rejected_logits

    def compute_preference_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        r"""
        Compute the D2O loss with the formula selected by ``self.loss_num``.

        Args:
            policy_chosen_logps: Policy log-probs of the chosen blocks, shape ``(multiple_K * B,)``.
            policy_rejected_logps: Policy log-probs of the rejected responses, shape ``(B,)``.
            reference_chosen_logps: Reference counterpart of ``policy_chosen_logps``, or ``None``.
            reference_rejected_logps: Reference counterpart of ``policy_rejected_logps``, or ``None``.
                When either reference tensor is ``None`` (e.g. ``use_ref_model`` is disabled)
                the loss is computed reference-free, i.e. reference log-probs are taken as zero.

        Returns:
            ``(losses, chosen_rewards, rejected_rewards)``, each of shape ``(B,)``.

        Raises:
            ValueError: If ``self.loss_num`` is not 1, 2 or 3.
        """
        formulas = {1: self.formula_1, 2: self.formula_2, 3: self.formula_3}
        formula = formulas.get(self.loss_num)
        if formula is None:
            raise ValueError(f"Unsupported loss_num {self.loss_num!r}; expected 1, 2 or 3.")

        reference_free = (
            self.reference_free or reference_chosen_logps is None or reference_rejected_logps is None
        )
        return formula(
            policy_chosen_logps,
            policy_rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps,
            reference_free=reference_free,
        )

    def _chosen_blocks(self, multiple_chosen_logps: torch.Tensor, batch_size: int) -> torch.Tensor:
        r"""Reshape flat ``(multiple_K * batch_size,)`` chosen log-probs to ``(multiple_K, batch_size)``."""
        expected = self.multiple_K * batch_size
        if multiple_chosen_logps.size(0) != expected:
            raise ValueError(
                f"Expected {expected} chosen log-probs (multiple_K={self.multiple_K}, "
                f"batch_size={batch_size}), got {multiple_chosen_logps.size(0)}."
            )
        # Row i is the slice [i * batch_size:(i + 1) * batch_size] of the flat tensor.
        return multiple_chosen_logps.reshape(self.multiple_K, batch_size)

    @staticmethod
    def _reference_or_zeros(
        reference_logps: Optional[torch.Tensor], policy_logps: torch.Tensor, reference_free: bool
    ) -> torch.Tensor:
        r"""Return ``reference_logps``, or zeros shaped like ``policy_logps`` in reference-free mode."""
        if reference_free or reference_logps is None:
            return torch.zeros_like(policy_logps)
        return reference_logps

    def formula_1(
        self,
        multiple_policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        multiple_reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
        reference_free: bool = False,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        r"""
        Average of ``multiple_K`` DPO sigmoid losses, one per chosen block.

        For block ``i``: ``-logsigmoid(beta * ((pi_c_i - pi_r) - (ref_c_i - ref_r)))``.
        Losses and chosen rewards are averaged over the blocks.

        Args:
            multiple_policy_chosen_logps: Flat ``(multiple_K * B,)`` policy chosen log-probs.
            policy_rejected_logps: ``(B,)`` policy rejected log-probs.
            multiple_reference_chosen_logps: Reference counterpart of the chosen log-probs.
            reference_rejected_logps: Reference counterpart of the rejected log-probs.
            reference_free: Treat all reference log-probs as zero.

        Returns:
            ``(losses, chosen_rewards, rejected_rewards)``, each of shape ``(B,)``.
        """
        batch_size = policy_rejected_logps.size(0)
        ref_chosen_flat = self._reference_or_zeros(
            multiple_reference_chosen_logps, multiple_policy_chosen_logps, reference_free
        )
        ref_rejected = self._reference_or_zeros(reference_rejected_logps, policy_rejected_logps, reference_free)

        policy_chosen = self._chosen_blocks(multiple_policy_chosen_logps, batch_size)  # (K, B)
        reference_chosen = self._chosen_blocks(ref_chosen_flat, batch_size)  # (K, B)

        # (K, B) - (B,) broadcasts the rejected terms over every chosen block.
        pi_logratios = policy_chosen - policy_rejected_logps
        ref_logratios = reference_chosen - ref_rejected
        losses = (-F.logsigmoid(self.beta * (pi_logratios - ref_logratios))).mean(dim=0)
        chosen_rewards = (self.beta * (policy_chosen - reference_chosen).detach()).mean(dim=0)
        rejected_rewards = self.beta * (policy_rejected_logps - ref_rejected).detach()
        return losses, chosen_rewards, rejected_rewards

    def formula_2(
        self,
        multiple_policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        multiple_reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
        reference_free: bool = False,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        r"""
        Single sigmoid loss on the mean chosen log-ratio against the rejected log-ratio.

        ``-logsigmoid(beta * mean_i(pi_c_i - ref_c_i) - alpha * (pi_r - ref_r))``.

        Args:
            multiple_policy_chosen_logps: Flat ``(multiple_K * B,)`` policy chosen log-probs.
            policy_rejected_logps: ``(B,)`` policy rejected log-probs.
            multiple_reference_chosen_logps: Reference counterpart of the chosen log-probs.
            reference_rejected_logps: Reference counterpart of the rejected log-probs.
            reference_free: Treat all reference log-probs as zero.

        Returns:
            ``(losses, chosen_rewards, rejected_rewards)``, each of shape ``(B,)``; the chosen
            reward is scaled by ``beta`` and the rejected reward by ``alpha``.
        """
        batch_size = policy_rejected_logps.size(0)
        ref_chosen_flat = self._reference_or_zeros(
            multiple_reference_chosen_logps, multiple_policy_chosen_logps, reference_free
        )
        ref_rejected = self._reference_or_zeros(reference_rejected_logps, policy_rejected_logps, reference_free)

        chosen_diff = self._chosen_blocks(multiple_policy_chosen_logps - ref_chosen_flat, batch_size)
        mean_chosen_diff = chosen_diff.mean(dim=0)  # (B,)
        rejected_diff = policy_rejected_logps - ref_rejected

        logits = self.beta * mean_chosen_diff - self.alpha * rejected_diff
        losses = -F.logsigmoid(logits)
        chosen_rewards = self.beta * mean_chosen_diff.detach()
        rejected_rewards = self.alpha * rejected_diff.detach()
        return losses, chosen_rewards, rejected_rewards

    def formula_3(
        self,
        multiple_policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        multiple_reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
        reference_free: bool = False,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        r"""
        Sigmoid loss on ``log(m / (m + r))``.

        Here ``m = mean_i exp(beta * (pi_c_i - ref_c_i))`` and ``r = exp(beta * (pi_r - ref_r))``.
        Computed in log space (``logsumexp`` / ``logaddexp``) so large log-ratios cannot
        overflow. Unlike the previous version, ``r`` is exponentiated once (not twice), and
        the chosen reward is the mean over the ``multiple_K`` blocks rather than their sum.

        Args:
            multiple_policy_chosen_logps: Flat ``(multiple_K * B,)`` policy chosen log-probs.
            policy_rejected_logps: ``(B,)`` policy rejected log-probs.
            multiple_reference_chosen_logps: Reference counterpart of the chosen log-probs.
            reference_rejected_logps: Reference counterpart of the rejected log-probs.
            reference_free: Treat all reference log-probs as zero.

        Returns:
            ``(losses, chosen_rewards, rejected_rewards)``, each of shape ``(B,)``.
        """
        import math

        batch_size = policy_rejected_logps.size(0)
        ref_chosen_flat = self._reference_or_zeros(
            multiple_reference_chosen_logps, multiple_policy_chosen_logps, reference_free
        )
        ref_rejected = self._reference_or_zeros(reference_rejected_logps, policy_rejected_logps, reference_free)

        policy_chosen = self._chosen_blocks(multiple_policy_chosen_logps, batch_size)  # (K, B)
        reference_chosen = self._chosen_blocks(ref_chosen_flat, batch_size)  # (K, B)
        chosen_logratios = policy_chosen - reference_chosen

        # log m = logsumexp_i(beta * ratio_i) - log K
        log_mean_chosen = torch.logsumexp(self.beta * chosen_logratios, dim=0) - math.log(self.multiple_K)
        # log r is simply the scaled rejected log-ratio.
        log_rejected = self.beta * (policy_rejected_logps - ref_rejected)
        # log(m / (m + r)) = log m - log(m + r)
        logits = log_mean_chosen - torch.logaddexp(log_mean_chosen, log_rejected)

        losses = -F.logsigmoid(logits)
        chosen_rewards = self.beta * chosen_logratios.mean(dim=0).detach()
        rejected_rewards = self.beta * (policy_rejected_logps - ref_rejected).detach()
        return losses, chosen_rewards, rejected_rewards

    def compute_reference_log_probs(
        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
    ) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
        r"""
        Compute reference-model log-probs for the chosen and rejected parts of ``batch``.

        Returns ``(None, None)`` when ``finetuning_args.use_ref_model`` is false. Without a
        separate ``ref_model``, ``model`` itself is run inside ``get_ref_context``.
        """
        if not self.finetuning_args.use_ref_model:
            return None, None

        if self.ref_model is None:
            ref_model = model
            ref_context = get_ref_context(self.accelerator, model)
        else:
            ref_model = self.ref_model
            ref_context = nullcontext()

        with torch.no_grad(), ref_context:
            reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(ref_model, batch)

        return reference_chosen_logps, reference_rejected_logps

    def get_batch_loss_metrics(
        self,
        model: "PreTrainedModel",
        batch: Dict[str, "torch.Tensor"],
        train_eval: Literal["train", "eval"] = "train",
    ) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
        r"""
        Compute the mean D2O loss and logging metrics for ``batch``.

        Metric keys are prefixed with ``"eval_"`` when ``train_eval == "eval"``; metric
        values are detached scalars moved to CPU.
        """
        metrics = {}
        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
        ) = self.concatenated_forward(model, batch)

        reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(model, batch)
        losses, chosen_rewards, rejected_rewards = self.compute_preference_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps,
        )

        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
        return losses.mean(), metrics