#!/bin/bash
#SBATCH --job-name=qnli # create a short name for your job
#SBATCH -N 1 # node count
#SBATCH --ntasks-per-node 1 # number of tasks to run per node
#SBATCH --cpus-per-task 16 # cpu-cores per task (>1 if multi-threaded tasks)
#SBATCH --gpus-per-node 1 # total gpus for job

# ---- Run configuration -----------------------------------------------------
# Mode switch: the literal string "1" selects the test-only branch below;
# any other value selects the training branch.
DO_TEST='0'

# Task name: forwarded to finetune.py as --task and used as the last
# component of the output directory and of the run name.
TASK='qnli'

# Optimisation settings forwarded to the training command
# (--per_device_train_batch_size and --learning_rate respectively).
BSZ='8'
LR='3e-4'

# Components of the output path (output/<prefix>/<checkpoint>/<run>/<task>
# in the training branch); prefix and run also form the run name.
OUTPUT_PREFIX='finetune_hyperlora'
CHECKPOINT='flanv2_1121_t5_large_auto_demo'
RUN='glue'

# TASKS=("rte" "sst2" "mrpc" "stsb" "qqp" "mnli" "qnli" "cola")

# Dispatch on DO_TEST:
#   anything other than "1" -> fine-tune on $TASK (with --do_eval/--do_test),
#                              print the test results, then delete the
#                              intermediate checkpoint-* directories;
#   "1"                     -> run finetune.py with --do_test only.
#
# NOTE(review): the YOUR_*_CHECKPOINT_FILE arguments look like placeholders
# that must be replaced with real paths before submitting — confirm.
if [[ "$DO_TEST" != "1" ]]; then
  echo "DO TRAIN"

  # Single definition of the training output directory. The :? guards abort
  # the script if any component is empty/unset, so the cleanup further down
  # can never run against a truncated path.
  train_out_dir="output/${OUTPUT_PREFIX:?}/${CHECKPOINT:?}/${RUN:?}/${TASK:?}"

  # On failure, stop here with finetune.py's own exit status: the saved
  # checkpoints are kept for inspection/resume, and the job is not reported
  # as successful just because the cleanup commands succeeded.
  python3 finetune.py \
    --finetune \
    --seed 42 \
    --do_train \
    --do_eval \
    --do_test \
    --bf16 \
    --task "$TASK" \
    --hypelora_name_or_path hf_models/hyperlora-t5-large \
    --hyper_model_name_or_path hf_models/t5-large \
    --model_name_or_path hf_models/t5-large \
    --pretrain_checkpoint YOUR_PRETRAIN_CHECKPOINT_FILE \
    --per_device_train_batch_size "$BSZ" \
    --per_device_eval_batch_size 32 \
    --gradient_accumulation_steps 1 \
    --learning_rate "$LR" \
    --preprocessing_num_workers 12 \
    --generation_max_length 512 \
    --logging_strategy steps \
    --logging_steps 5 \
    --lr_scheduler_type 'cosine' \
    --num_train_epochs 30 \
    --warmup_ratio 0.06 \
    --max_seq_length 512 \
    --max_answer_length 256 \
    --val_max_answer_length 256 \
    --load_best_model_at_end \
    --greater_is_better True \
    --predict_with_generate \
    --evaluation_strategy 'epoch' \
    --save_strategy 'epoch' \
    --save_total_limit 1 \
    --input_column src_texts \
    --output_column tgt_texts \
    --temperature 8 \
    --loss_beta 0.2 \
    --lora_rank 16 \
    --lora_alpha 8 \
    --lora_target_modules "['q', 'v']" \
    --lora_dropout 0.05 \
    --report_to wandb \
    --run_name "${OUTPUT_PREFIX}_${RUN}_${TASK}" \
    --output_dir "$train_out_dir" || {
    status=$?
    echo "finetune.py failed with exit status ${status}; keeping checkpoints in ${train_out_dir}" >&2
    exit "$status"
  }

  # Summary of the run: task, learning rate, then the test metrics file.
  printf '%s\n' "$TASK"
  printf '%s\n' "$LR"
  results_file="${train_out_dir}/test_results.json"
  if [[ -f "$results_file" ]]; then
    cat -- "$results_file"
  else
    echo "warning: ${results_file} not found" >&2
  fi

  # Delete intermediate checkpoints. If the glob matches nothing it stays
  # literal, the -e test fails, and rm is skipped instead of erroring.
  checkpoints=("$train_out_dir"/checkpoint-*)
  if [[ -e "${checkpoints[0]}" ]]; then
    rm -r -- "${checkpoints[@]}"
  fi
else
  echo "DO TEST"

  # NOTE(review): this branch differs from the training branch in several
  # ways — base model hf_models/flan-t5-large (vs hf_models/t5-large),
  # --max_seq_length 256 (vs 512), no --hyper_model_name_or_path, and an
  # output dir without the $CHECKPOINT component. Left unchanged here;
  # confirm these differences are intentional.
  python3 finetune.py \
    --finetune \
    --seed 42 \
    --do_test \
    --bf16 \
    --task "$TASK" \
    --hypelora_name_or_path hf_models/hyperlora-t5-large \
    --model_name_or_path hf_models/flan-t5-large \
    --pretrain_checkpoint YOUR_PRETRAIN_CHECKPOINT_FILE \
    --finetune_checkpoint YOUR_FINETUNE_CHECKPOINT_FILE \
    --per_device_eval_batch_size 32 \
    --preprocessing_num_workers 12 \
    --generation_max_length 512 \
    --max_seq_length 256 \
    --max_answer_length 256 \
    --val_max_answer_length 256 \
    --greater_is_better True \
    --predict_with_generate \
    --input_column src_texts \
    --output_column tgt_texts \
    --temperature 8 \
    --loss_beta 0.2 \
    --lora_rank 16 \
    --lora_alpha 8 \
    --lora_target_modules "['q', 'v']" \
    --lora_dropout 0.05 \
    --output_dir "output/${OUTPUT_PREFIX}/${RUN}/${TASK}"
fi