#!/bin/bash
# CUDA-related environment settings. They are exported so that every process
# started by this script inherits them.
CUDA_DEVICE_MAX_CONNECTIONS=1
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export CUDA_DEVICE_MAX_CONNECTIONS CUDA_VISIBLE_DEVICES

# Working directory at the time the script was started.
DIR=$(pwd)

# Guide:
# This script supports distributed training across multiple GPU workers (as well as single-worker training).
# Please set the options below according to the comments.
# For multi-worker training, these options must be set manually on each worker.
# After setting the options, run the script on each worker.

# Number of GPUs on this worker, as reported by torch.cuda.device_count()
GPUS_PER_NODE="$(python -c 'import torch; print(torch.cuda.device_count())')"

# Number of GPU workers; keeps a value already present in NNODES, otherwise 1
# (use 1 for single-worker training)
: "${NNODES:=1}"

# Rank of this worker, should be in {0, ..., NNODES-1}; keeps a value already
# present in NODE_RANK, otherwise 0 (use 0 for single-worker training)
: "${NODE_RANK:=0}"

# IP address of the rank-0 worker; keeps a value already present in
# MASTER_ADDR, otherwise localhost (use localhost for single-worker training)
: "${MASTER_ADDR:=localhost}"

# Port used for communication; keeps a value already present in MASTER_PORT,
# otherwise 6001
: "${MASTER_PORT:=6001}"

# Model to fine-tune. To load from a local path or another checkpoint instead,
# replace the value, e.g.:
# MODEL="Qwen/Qwen1.5-1.8B"
MODEL='microsoft/phi-1_5'

# ATTENTION: training data files. Each should be a json file consisting of a
# list of conversations; see the finetuning section of the README for details.
# The brace expansion yields one train.json path per task directory, in the
# order listed.
DATA=(data/glue_data/{CoLA,MNLI,MRPC,QNLI,QQP,RTE,SST-2,STS-B,WNLI}/train.json)

# Evaluation data files, in the order listed (MNLI contributes two files:
# dev_matched.json and dev_mismatched.json).
DATA_EVAL=(data/glue_data/{CoLA/dev,MNLI/dev_matched,MNLI/dev_mismatched,MRPC/dev,QNLI/dev,QQP/dev,RTE/dev,SST-2/dev,STS-B/dev,WNLI/dev}.json)

# Directory passed to the training script for both checkpoints and eval output.
OUTPUT_PATH='output/phi/round1/multitask/bikt'

# Task name lists: none for the previous round, all nine names for this round.
TASKS_PRE_ROUND=()
TASKS_CURR_ROUND=(cola mnli mrpc qnli qqp rte sst2 stsb wnli)

# Print the usage summary to stdout: a blank line, the usage line, and two
# trailing newlines.
usage() {
    printf '\n%s\n\n' 'Usage: bash finetune/finetune_ds.sh [-m MODEL_PATH] [-d DATA_PATH]'
}

# Parse command-line options, overriding the defaults assigned earlier.
#
#   -m | --model PATH   set MODEL to PATH
#   -d | --data  PATH   replace the whole DATA list with the single entry PATH
#   -h | --help         print the usage text and exit 0
#
# Parsing stops at the first empty argument. An unrecognised argument, or an
# option given without its value, is reported and the script exits with 1.
# Globals written: MODEL, DATA.
parse_args() {
    while [[ -n "$1" ]]; do
        case "$1" in
            -m | --model )
                if (( $# < 2 )); then
                    echo "Missing value for ${1}" >&2
                    exit 1
                fi
                MODEL=$2
                shift
                ;;
            -d | --data )
                if (( $# < 2 )); then
                    echo "Missing value for ${1}" >&2
                    exit 1
                fi
                # DATA is an array. A plain `DATA=...` would only overwrite
                # element 0 and keep every other default entry, so assign a
                # fresh one-element array instead.
                DATA=("$2")
                shift
                ;;
            -h | --help )
                usage
                exit 0
                ;;
            * )
                echo "Unknown argument ${1}"
                exit 1
                ;;
        esac
        shift
    done
}

parse_args "$@"

# Launcher options for torchrun. They are kept in an array (not a single
# string) so each option and each value reaches torchrun as exactly one
# argument, without relying on unquoted word splitting.
DISTRIBUTED_ARGS=(
    --nproc_per_node "$GPUS_PER_NODE"
    --nnodes "$NNODES"
    --node_rank "$NODE_RANK"
    --master_addr "$MASTER_ADDR"
    --master_port "$MASTER_PORT"
)

# Start finetune_phi.py through torchrun with the configured model, data
# files, output directory and task list. All expansions are quoted so values
# containing spaces or glob characters are passed through unchanged; the
# array expansions still yield one argument per element.
# NOTE(review): the option is spelled "--classfication" here; confirm the
# argument name defined by finetune_phi.py before changing the spelling.
torchrun "${DISTRIBUTED_ARGS[@]}" finetune_phi.py \
    --model_name_or_path "$MODEL" \
    --data_path "${DATA[@]}" \
    --eval_data_path "${DATA_EVAL[@]}" \
    --fp16 True \
    --output_dir "$OUTPUT_PATH" \
    --eval_output_dir "$OUTPUT_PATH" \
    --num_train_epochs 10 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 16 \
    --gradient_accumulation_steps 8 \
    --evaluation_strategy "steps" \
    --eval_steps 1 \
    --save_strategy "epoch" \
    --save_total_limit 10 \
    --learning_rate 5e-6 \
    --weight_decay 0.1 \
    --adam_beta2 0.95 \
    --warmup_ratio 0.01 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --report_to "none" \
    --model_max_length 512 \
    --gradient_checkpointing True \
    --lazy_preprocess True \
    --deepspeed finetune/ds_config_zero2.json \
    --classfication True \
    --use_prompt True \
    --only_prompt False \
    --fix_prompt False \
    --prompt_config "finetune/configs/config_prompt_p20.json" \
    --fix_word_embeddings True \
    --multitask_train_prompt True \
    --is_multitask True \
    --fix_prompt_pre_round True \
    --tasks_curr_round "${TASKS_CURR_ROUND[@]}"

