#!/bin/bash
#set -e

################################################################################ Parameters Setting ################################################################################

# --- Base model shared by the SFT and DPO stages ---
base_model="huggyllama/llama-7b"
model_name="LLaMA_7B"

# --- GPUs to run on (comma-separated ids, passed to `deepspeed --include localhost:<ids>`) ---
gpus="0,1,2,3"

# --- Time stamps embedded in output paths ---
# 0520 chosen/reject fix
sft_time_stamp="20240306_1900"
dpo_time_stamp="20240525_1100"
now_time="20240525_1108"

# --- Pipeline stage switches: a stage runs only when its flag is exactly "true" ---
if_sft_train=true
if_sft_eval=false
if_dpo_train=true
if_dpo_eval=true
if_compare_res=false

# --- Hyper-parameters of the preference-alignment algorithm ---
dpo_type="skto"       # forwarded as --stage to train_bash.py
dpo_name="SDPO"       # label used in output paths; alternative: SKTO
dpo_loss="sigmoid"    # alternative: kto_pair
skto_gamma=0.01
skto_class_nums=3
dpo_ftx=0.0


# --- Training parameters ---
sft_batch_size=4
sft_train_epoch_num=2.0
dpo_batch_size=2
dpo_train_epoch_num=1.0
eval_batch_size=4
gradient_steps=4

sft_deepspeed="./scripts/deepspeed_configs/ds_config.json"
dpo_deepspeed="./scripts/deepspeed_configs/ds_config.json"

# --- Datasets ---
# Anthropic HH (active)
sft_dataset="hh_rlhf_en_train_sft"
dpo_dataset="hh_rlhf_en_train"
eval_dataset="hh_rlhf_en_test_2K"

# TL;DR (uncomment to use instead of HH)
# sft_dataset=tldr_train_sft
# dpo_dataset=tldr_train_preference
# eval_dataset=tldr_test_500


# --- Where the pairwise comparison results are saved ---
eval_model="gpt-4"
compair_name="${model_name}/Baseline-VS-${dpo_name}/${eval_dataset}/gamma_${skto_gamma}-classe_${skto_class_nums}-${now_time}-${eval_model}"

# --- Training checkpoints (./outputs) and exported models (./models) ---
sft_model_checkpoint="./outputs/${model_name}-SFT-Trainset_${sft_dataset}-${sft_time_stamp}"
sft_export_model="./models/${model_name}-SFT-Trainset_${sft_dataset}-${sft_time_stamp}"
dpo_model_checkpoint="./outputs/${model_name}-SFT-${dpo_name}-gamma_${skto_gamma}-classe_${skto_class_nums}-Trainset_${dpo_dataset}-${dpo_time_stamp}"
dpo_export_model="./models/${model_name}-SFT-${dpo_name}-gamma_${skto_gamma}-classe_${skto_class_nums}-Trainset_${dpo_dataset}-${dpo_time_stamp}"

# --- Prediction outputs on the eval set ---
eval_sft_res="./results/${model_name}-SFT-Evalset_${eval_dataset}-${sft_time_stamp}"
eval_dpo_res="./results/${model_name}-SFT-${dpo_name}-gamma_${skto_gamma}-classe_${skto_class_nums}-Evalset_${eval_dataset}-${now_time}"

# --- Runtime counter of executed steps (incremented by each enabled stage) ---
step_id=0

####################################################################################### End #######################################################################################


# [STEP 1] train SFT model based on raw model
# Trains a LoRA adapter (q_proj, v_proj) on $base_model with $sft_dataset and
# writes it to $sft_model_checkpoint; [STEP 2] then exports it to
# $sft_export_model, the model path later stages load.
# The master port is drawn from $((RANDOM % 55535 + 10000)); since bash's
# RANDOM is 0..32767 the effective range is 10000..42767 (presumably chosen to
# avoid port clashes between concurrent runs on one host).
if [[ "$if_sft_train" == true ]]; then
    step_id=$((step_id + 1))
    printf '\n\nStarting step [%d]: train SFT model.\n\n\n' "$step_id"

    deepspeed --include "localhost:$gpus" --master_port "$((RANDOM % 55535 + 10000))" src/train_bash.py \
        --report_to none \
        --deepspeed "$sft_deepspeed" \
        --stage sft \
        --model_name_or_path "$base_model" \
        --do_train \
        --dataset "$sft_dataset" \
        --template default \
        --finetuning_type lora \
        --lora_target q_proj,v_proj \
        --output_dir "$sft_model_checkpoint" \
        --overwrite_output_dir \
        --save_steps 1000 \
        --num_train_epochs "$sft_train_epoch_num" \
        --per_device_train_batch_size "$sft_batch_size" \
        --gradient_accumulation_steps "$gradient_steps" \
        --optim adamw_torch \
        --learning_rate 5e-5 \
        --lr_scheduler_type cosine \
        --bf16


    # [STEP 2] export SFT model
    # Presumably merges the LoRA adapter into the base weights — confirm in
    # src/export_model.py. NOTE(review): runs even if training failed, because
    # `set -e` is disabled at the top of the script.
    step_id=$((step_id + 1))
    printf '\n\nStarting step [%d]: export SFT model.\n\n\n' "$step_id"

    python src/export_model.py \
        --model_name_or_path "$base_model" \
        --adapter_name_or_path "$sft_model_checkpoint" \
        --template default \
        --finetuning_type lora \
        --export_dir "$sft_export_model" \
        --export_size 2 \
        --export_legacy_format False
fi


# [STEP 3] eval SFT model
# Generates predictions for $eval_dataset with the exported SFT model and
# writes them under $eval_sft_res.
# NOTE(review): the eval batch size here is hard-coded to 1 while the DPO eval
# uses $eval_batch_size — confirm whether this is intentional.
if [[ "$if_sft_eval" == true ]]; then
    step_id=$((step_id + 1))
    printf '\n\nStarting step [%d]: evaluate SFT model.\n\n\n' "$step_id"

    deepspeed --include "localhost:$gpus" --master_port "$((RANDOM % 55535 + 10000))" src/train_bash.py \
        --report_to none \
        --stage sft \
        --model_name_or_path "$sft_export_model" \
        --do_predict \
        --dataset "$eval_dataset" \
        --template default \
        --finetuning_type lora \
        --output_dir "$eval_sft_res" \
        --per_device_eval_batch_size 1 \
        --predict_with_generate


    # [STEP 4] change predictions.jsonl to predictions.json
    step_id=$((step_id + 1))
    printf '\n\nStarting step [%d]: transfer eval results to predictions.\n\n\n' "$step_id"

    python src/process_gen_predictions_to_predictions.py \
            --generate_predictions "$eval_sft_res" \
            --dataset "$eval_dataset"

fi


# [STEP 5] Preference alignment on SFT model under deepspeed
# Trains a new LoRA adapter on top of $sft_export_model using the algorithm
# selected by $dpo_type (forwarded as --stage) and logs to TensorBoard under
# $dpo_model_checkpoint; [STEP 6] then exports it to $dpo_export_model.
# NOTE(review): this stage uses --fp16 while SFT uses --bf16 — confirm intent.
if [[ "$if_dpo_train" == true ]]; then
    step_id=$((step_id + 1))
    # Report the configured method instead of a hard-coded name so the log
    # stays correct when dpo_name changes.
    printf '\n\nStarting step [%d]: model preference alignment using %s.\n\n\n' "$step_id" "$dpo_name"

    deepspeed --include "localhost:$gpus" --master_port "$((RANDOM % 55535 + 10000))" src/train_bash.py \
        --deepspeed "$dpo_deepspeed" \
        --report_to tensorboard \
        --logging_dir "$dpo_model_checkpoint" \
        --stage "$dpo_type" \
        --model_name_or_path "$sft_export_model" \
        --do_train \
        --dataset "$dpo_dataset" \
        --template default \
        --finetuning_type lora \
        --create_new_adapter \
        --lora_target q_proj,v_proj \
        --output_dir "$dpo_model_checkpoint" \
        --overwrite_output_dir \
        --save_steps 1000 \
        --num_train_epochs "$dpo_train_epoch_num" \
        --per_device_train_batch_size "$dpo_batch_size" \
        --gradient_accumulation_steps "$gradient_steps" \
        --optim adamw_torch \
        --learning_rate 1e-5 \
        --lr_scheduler_type cosine \
        --logging_steps 10 \
        --flash_attn \
        --dpo_ftx "$dpo_ftx" \
        --dpo_loss "$dpo_loss" \
        --skto_gamma "$skto_gamma" \
        --skto_class_nums "$skto_class_nums" \
        --fp16



    # [STEP 6] export DPO model
    # Exports the aligned adapter on top of the SFT model (presumably merged —
    # confirm in src/export_model.py).
    step_id=$((step_id + 1))
    printf '\n\nStarting step [%d]: export aligned model.\n\n\n' "$step_id"

    python src/export_model.py \
        --model_name_or_path "$sft_export_model" \
        --adapter_name_or_path "$dpo_model_checkpoint" \
        --template default \
        --finetuning_type lora \
        --export_dir "$dpo_export_model" \
        --export_size 2 \
        --export_legacy_format False

fi


# [STEP 7] eval DPO model
# Generates predictions for $eval_dataset with the exported aligned model and
# writes them under $eval_dpo_res.
if [[ "$if_dpo_eval" == true ]]; then
    step_id=$((step_id + 1))
    printf '\n\nStarting step [%d]: evaluate aligned model.\n\n\n' "$step_id"


    deepspeed --include "localhost:$gpus" --master_port "$((RANDOM % 55535 + 10000))" src/train_bash.py \
        --report_to none \
        --stage sft \
        --model_name_or_path "$dpo_export_model" \
        --do_predict \
        --dataset "$eval_dataset" \
        --template default \
        --finetuning_type lora \
        --output_dir "$eval_dpo_res" \
        --per_device_eval_batch_size "$eval_batch_size" \
        --predict_with_generate \
        --flash_attn


    # [STEP 8] transfer generate_predictions to predictions
    step_id=$((step_id + 1))
    printf '\n\nStarting step [%d]: transfer eval results to predictions.\n\n\n' "$step_id"

    python src/process_gen_predictions_to_predictions.py \
            --generate_predictions "$eval_dpo_res" \
            --dataset "$eval_dataset"

fi

# [STEP 9] compare results using $eval_model (GPT-4 by default)
# Runs src/res_compare.py on the baseline and the aligned model's predictions
# and saves the comparison under $compair_name.
# NOTE(review): --baseline_path is just the eval dataset name and
# --candidate_path is a bare results-directory name; presumably res_compare.py
# resolves both against its own base directories — confirm.
if [[ "$if_compare_res" == true ]]; then
    step_id=$((step_id + 1))
    printf '\n\nStarting step [%d]: comparing results using %s.\n\n\n' "$step_id" "$eval_model"

    # Derive the candidate name from eval_dpo_res (strip the ./results/ prefix)
    # instead of re-spelling it, so the two can never drift apart.
    python src/res_compare.py \
            --baseline_path "$eval_dataset" \
            --candidate_path "${eval_dpo_res#./results/}" \
            --save_path "$compair_name" \
            --eval_model "$eval_model"
fi
