# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py

from typing import TYPE_CHECKING, Optional, List
from transformers import Seq2SeqTrainingArguments

from llmtuner.data import split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments
from llmtuner.model import load_model_and_tokenizer
from llmtuner.train.d2o_ema.collator import DPODataCollatorWithPadding
from llmtuner.train.d2o_ema.trainer import D2OTrainer
from llmtuner.train.d2o_ema.data import get_dataset, preprocess_dataset
from llmtuner.train.utils import create_modelcard_and_push, create_ref_model

if TYPE_CHECKING:
    from transformers import TrainerCallback
    from llmtuner.hparams import DataArguments, FinetuningArguments


def run_d2o_ema(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    callbacks: Optional[List["TrainerCallback"]] = None
):
    r"""
    Run the D2O-EMA preference-optimization workflow.

    Loads and preprocesses the dataset, loads the model/tokenizer, builds the
    reference model, then trains and/or evaluates a ``D2OTrainer`` according to
    ``training_args.do_train`` / ``training_args.do_eval``, and finally creates
    the model card (and pushes it, as handled by ``create_modelcard_and_push``).

    Args:
        model_args: model loading arguments.
        data_args: dataset arguments (``ignore_pad_token_for_loss`` is read here).
        training_args: Hugging Face training arguments. A copy with
            ``remove_unused_columns=False`` is built and used internally.
        finetuning_args: fine-tuning arguments. Fields read here are
            ``multiple_K``, ``ref_model``, ``dpo_beta``, ``label_smoothing``,
            ``loss_type`` and ``plot_loss``.
        callbacks: optional list of trainer callbacks.

    Raises:
        ValueError: if ``finetuning_args.multiple_K`` is larger than the
            ``max_K`` reported by ``get_dataset``.
    """
    # `max_K`: maximum number of positive examples found in the data file
    # (as reported by `get_dataset`).
    dataset, max_K = get_dataset(model_args, data_args)
    multiple_K = finetuning_args.multiple_K
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if multiple_K > max_K:
        raise ValueError(
            f"multiple_K ({multiple_K}) should be less than or equal to max_K ({max_K})."
        )

    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, multiple_K, stage="rm")

    data_collator = DPODataCollatorWithPadding(
        tokenizer=tokenizer,
        pad_to_multiple_of=4,
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    )

    # Create the reference model. For an evaluation-only run without an explicit
    # reference model, the policy model itself is reused as the reference.
    if finetuning_args.ref_model is None and (not training_args.do_train):
        ref_model = model
    else:
        ref_model = create_ref_model(model_args, finetuning_args, stage="dpo")

    # Rebuild the training arguments with `remove_unused_columns=False`.
    # This is important for the pairwise dataset: its columns must reach the collator.
    training_args_dict = training_args.to_dict()
    training_args_dict["remove_unused_columns"] = False
    training_args = Seq2SeqTrainingArguments(**training_args_dict)

    # Use `.get` so that a split result without a training set (e.g. an eval-only
    # run) does not raise KeyError. Also forward any held-out split so that
    # `do_eval` has data to evaluate.
    # NOTE(review): assumes `split_dataset` returns a mapping with optional
    # "train_dataset" / "eval_dataset" keys — confirm against llmtuner.data.
    dataset_splits = split_dataset(dataset, data_args, training_args)
    train_dataset = dataset_splits.get("train_dataset")
    eval_dataset = dataset_splits.get("eval_dataset")

    trainer = D2OTrainer(
        beta=finetuning_args.dpo_beta,
        multiple_K=multiple_K,
        model=model,
        ref_model=ref_model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks,
        label_smoothing=finetuning_args.label_smoothing,
        loss_type=finetuning_args.loss_type,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        # Train metrics are recorded only under the "train" split; they used to be
        # re-logged and saved as "eval" as well, which mislabeled them.
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss", "rewards/accuracies"])

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval")
        if model is ref_model:  # unable to compute rewards without a reference model
            for key in [key for key in metrics if "rewards" in key]:
                metrics.pop(key)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Create model card
    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
