#!/bin/bash -l
# SLURM SUBMIT SCRIPT
#SBATCH --job-name=Mulima
#SBATCH --account=TODO
#SBATCH --partition=TODO
#SBATCH --nodes=1
#SBATCH --gres=gpu:4
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=48
#SBATCH --hint=nomultithread
#SBATCH --time=10:00:00
#SBATCH --output=sbatch_logs/%x-%A-%a.out           # output file name
#SBATCH --error=sbatch_logs/%x-%A-%a.err            # error file name
#SBATCH --array=0

#### Environment variables ####
# Use GCC for any build step that honours CC/CXX.
export CXX=g++
export CC=gcc
# Limit each GPU to a single host-to-device work queue.
export CUDA_DEVICE_MAX_CONNECTIONS=1
# Force crashing on NCCL issues like a hanging broadcast instead of blocking.
# PyTorch >= 2.2 reads TORCH_NCCL_ASYNC_ERROR_HANDLING and only honours the
# NCCL_-prefixed name as a deprecated alias, so both are exported to behave
# the same across PyTorch versions.
export NCCL_ASYNC_ERROR_HANDLING=1
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
# InfiniBand / UCX transport timeouts and retry count.
export NCCL_IB_TIMEOUT=50
export UCX_RC_TIMEOUT=4s
export NCCL_IB_RETRY_CNT=10
# Bind NCCL and Gloo socket (bootstrap) traffic to the InfiniBand interface.
export NCCL_SOCKET_IFNAME=ib0
export GLOO_SOCKET_IFNAME=ib0
# Verbose NCCL logging.
export NCCL_DEBUG=INFO

# Use the pure-Python protobuf backend. NOTE(review): presumably a workaround
# for a protobuf C-extension version clash — confirm it is still needed.
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

# Enable logging: trace every command (-x) and abort on the first failing
# command (-e).
set -x -e
echo "START TIME: $(date)"

source "path/to/env"
# NOTE(review): INSTALLATION_DIR is presumably provided by the sourced env
# file — confirm. Fail with a clear message instead of cd-ing to "/FastChat".
: "${INSTALLATION_DIR:?INSTALLATION_DIR must be set (e.g. by the sourced env file)}"
cd "$INSTALLATION_DIR/FastChat"

##### Network parameters #####
# The first host of the allocation serves as the torchrun rendezvous endpoint.
MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
# Allow communication over InfiniBand cells. Some clusters expose the IB
# interface under a suffixed hostname; set IB_HOSTNAME_SUFFIX for that case.
# It defaults to empty, i.e. the plain hostname is resolved.
MASTER_ADDR="${MASTER_ADDR}${IB_HOSTNAME_SUFFIX:-}"
# Get IP for hostname. nslookup may report several addresses; keep only the
# first so the rendezvous endpoint stays a single host:port token (a newline
# here would split the torchrun command when it is re-parsed by `bash -c`).
MASTER_ADDR="$(nslookup "$MASTER_ADDR" | grep -oP '(?<=Address: ).*' | head -n 1)"
if [[ -z "$MASTER_ADDR" ]]; then
    echo "ERROR: could not resolve an IP address for the master node" >&2
    exit 1
fi
MASTER_PORT=6000

# Kept as plain strings (not arrays): they are spliced into the `bash -c`
# command line handed to srun and re-parsed on every node.
export LAUNCHER="python -m torch.distributed.run \
    --nproc_per_node 4 \
    --nnodes $SLURM_JOB_NUM_NODES \
    --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
    --rdzv_backend c10d \
    --max_restarts 0 \
    --tee 3 \
    "
SRUN_ARGS=" \
    --wait=60 \
    --kill-on-bad-exit=1 \
    "

##### Run configuration #####
# Languages to fine-tune on. The SLURM array task ID selects one entry, so
# un-comment N entries and submit with a matching --array=0-<N-1>.
LANGS=(
    # "DE"
    # "EN"
    # "EN_DE_FR_IT_ES_sampled"
    "FR"
    # "IT"
    # "ES"
    # "EN_DE_FR_IT_ES"
)
# LIMA
# TRAIN_DATA_PATH_PREFIX="/path/to/train_with_stats_EN_DE_ES_FR_IT_fastchat_format"
# VAL_DATA_PATH_PREFIX="/path/to/val_dataset_EN_manually_curated_EN_DE_ES_FR_IT_fastchat_format"
# NUM_EPOCHS=7
# Bactrian-X
TRAIN_DATA_PATH_PREFIX="../path/to/train_bactrianx"
VAL_DATA_PATH_PREFIX="../path/to/val_bactrianx"
NUM_EPOCHS=3
LR="1e-5"
# The selected language is deliberately NOT stored in LANG: LANG is the POSIX
# locale variable and is normally already exported, so overwriting it with
# e.g. "FR" would hand an invalid locale to srun, torchrun and Python.
# Outside of an array job SLURM_ARRAY_TASK_ID is unset; fall back to entry 0.
DATA_LANG=${LANGS[${SLURM_ARRAY_TASK_ID:-0}]}
if [[ -z "$DATA_LANG" ]]; then
    echo "ERROR: no language configured for array task ${SLURM_ARRAY_TASK_ID:-0} (LANGS has ${#LANGS[@]} entries)" >&2
    exit 1
fi
MODEL_PATH="7B_EU24/checkpoints/iter_0236250_trfs"
# NOTE(review): presumably lets Python import custom modelling code stored
# with the checkpoint (see --trust_remote_code below) — confirm.
export PYTHONPATH="$MODEL_PATH:$PYTHONPATH"
DIR_NAME="7B_single_token_$DATA_LANG"
NAME_PREFIX="train_run_24EU_bactrianx"
DATE=$(date +%Y_%m_%d_%H_%M_%S)
# Timestamped so repeated runs with the same settings never collide.
UNIQUE_NAME="${NAME_PREFIX}_${DATA_LANG}_lr_${LR}_${DATE}"
SAVE_DIR="24EU_runs/$DIR_NAME/$UNIQUE_NAME"
mkdir -p "$SAVE_DIR/checkpoints/"
mkdir -p "$SAVE_DIR/tensorboard_logs/"


# Training script plus arguments, appended to LAUNCHER and re-parsed by the
# inner `bash -c` at launch time: the escaped quotes keep the --fsdp value as a
# single argument, and the trailing `| tee` mirrors all output to a log file.
CMD="fastchat/train/train.py \
    --model_name_or_path $MODEL_PATH  \
    --data_path ${TRAIN_DATA_PATH_PREFIX}_${DATA_LANG}.jsonl \
    --eval_data_path ${VAL_DATA_PATH_PREFIX}_${DATA_LANG}.jsonl \
    --output_dir $SAVE_DIR/checkpoints \
    --logging_dir $SAVE_DIR/tensorboard_logs \
    --report_to tensorboard \
    --run_name $UNIQUE_NAME \
    --logging_steps 1 \
    --evaluation_strategy epoch \
    --save_strategy epoch \
    --num_train_epochs $NUM_EPOCHS \
    --fsdp \"full_shard auto_wrap\" \
    --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer \
    --bf16 True \
    --torch_compile False \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 16 \
    --learning_rate $LR \
    --weight_decay 0.1 \
    --adam_beta1 0.9 \
    --adam_beta2 0.95 \
    --lr_scheduler_type linear \
    --tf32 False \
    --gradient_checkpointing True \
    --model_max_length 2048 \
    --lazy_preprocess False \
    --trust_remote_code True 2>&1 | tee $SAVE_DIR/out_$UNIQUE_NAME.log"

echo "[$(date)] Start fine-tuning $UNIQUE_NAME..."
# One srun task per node (--ntasks-per-node=1) runs torchrun, which spawns the
# per-GPU workers. The command string is expanded twice:
#   * by this script: LAUNCHER, CMD, SRUN_ARGS, SLURM_JOB_ID;
#   * by the inner `bash -c` on each node: \$SLURM_PROCID is escaped so it
#     resolves to that task's own rank rather than this batch script's value.
# `set -o pipefail` makes the inner pipeline (CMD ends in `| tee <log>`) return
# the training process's exit status instead of tee's, so failures reach srun,
# --kill-on-bad-exit and this script.
# SRUN_ARGS is a flag string that must word-split, hence left unquoted.
# shellcheck disable=SC2086
if srun $SRUN_ARGS --jobid "$SLURM_JOB_ID" bash -c "set -o pipefail; $LAUNCHER --node_rank \$SLURM_PROCID $CMD"; then
    echo "[$(date)] Done fine-tuning $UNIQUE_NAME"
else
    srun_status=$?
    echo "[$(date)] Fine-tuning $UNIQUE_NAME FAILED (exit code $srun_status)" >&2
    exit "$srun_status"
fi
set +x