#!/bin/bash

# Positional arguments:
#   $1  MODEL_TYPE  model to fine-tune: condenser, cocondenser, orqa-unsup or
#                   bert; any other value is treated as the name of a run
#                   directory under ../od_splinter/runs/.
#   $2  DATASET     dataset name, used as the prefix of the train/dev files.
#   $3  N_SAMPLES   training-subset tag appended to the train file name, or
#                   "full" for the unsuffixed train file.
#   $4  (optional)  value to export as CUDA_VISIBLE_DEVICES.
MODEL_TYPE=$1
DATASET=$2
N_SAMPLES=$3

# Only override CUDA_VISIBLE_DEVICES when a 4th argument was actually passed.
# Exporting an empty value would clobber an inherited setting and hide every
# GPU from CUDA. An explicit "" is still honoured.
if (( $# >= 4 )); then
  export CUDA_VISIBLE_DEVICES="$4"
fi

# Choose the value later passed as --pretrained_model_cfg.
# Model types without a dedicated entry fall back to bert-base-uncased.
case "$MODEL_TYPE" in
  condenser)
    pretrained_model_cfg="Luyu/condenser"
    ;;
  cocondenser)
    pretrained_model_cfg="Luyu/co-condenser-wiki"
    ;;
  orqa-unsup)
    pretrained_model_cfg="../dpr_models/orqa_torch"
    ;;
  *)
    pretrained_model_cfg="bert-base-uncased"
    ;;
esac

# Base names (without directory or .json extension) of the train and dev files.
# Subset runs read a train file suffixed with the sample tag; the dev file
# name never carries a suffix.
DEV_FILE="${DATASET}-dev"
if [[ $N_SAMPLES == "full" ]]; then
  TRAIN_FILE="${DATASET}-train"
else
  TRAIN_FILE="${DATASET}-train-${N_SAMPLES}"
fi

# Batch size and number of epochs.
# Start from the general setting, then apply the two special cases:
#   - 128-sample runs use a smaller batch and more epochs;
#   - full webquestions / curatedtrec runs train for more epochs.
BATCH_SIZE=128
N_EPOCHS=40

if [[ $N_SAMPLES == "128" ]]; then
  BATCH_SIZE=32
  N_EPOCHS=80
elif [[ $N_SAMPLES == "full" ]]; then
  case "$DATASET" in
    webquestions | curatedtrec)
      N_EPOCHS=100
      ;;
  esac
fi

# Warm-up steps: full-data runs get a per-dataset value, every subset run
# uses the same fixed value.
if [[ $N_SAMPLES == "full" ]]; then
  case "$DATASET" in
    webquestions) WARMUP_STEPS=200 ;;
    curatedtrec)  WARMUP_STEPS=100 ;;
    *)            WARMUP_STEPS=1000 ;;
  esac
else
  WARMUP_STEPS=32
fi

# Run identifier; used below as the suffix of --wandb_name.
MODEL_NAME="${MODEL_TYPE}-${DATASET}-${N_SAMPLES}"

# Optional trainer flags. An array keeps each flag/value as exactly one
# argument, so a checkpoint path with spaces or glob characters is passed
# through intact instead of being word-split.
ARGS=()
if [[ $N_SAMPLES != "full" ]]; then
  # Subset runs neither save checkpoints nor run evaluation.
  ARGS+=(--no_save --no_eval)
fi

# Model types other than the four listed here are initialised from a local
# checkpoint, which must exist before training starts.
if [[ $MODEL_TYPE != "bert" && $MODEL_TYPE != "orqa-unsup" && $MODEL_TYPE != "condenser" && $MODEL_TYPE != "cocondenser" ]]; then
  MODEL_PATH="../od_splinter/runs/${MODEL_TYPE}/checkpoints/dpr_biencoder.100000"
  if [[ ! -f "$MODEL_PATH" ]]; then
    # Report on stderr and exit non-zero so callers can detect the failure.
    # A bare `exit` would return the status of the preceding echo, i.e. 0.
    echo "Model does not exist!" >&2
    exit 1
  fi
  ARGS+=(--model_file "$MODEL_PATH")
fi

python train_dense_encoder.py \
  --max_grad_norm 2.0 \
  --encoder_model_type hf_bert \
  --pretrained_model_cfg "$pretrained_model_cfg" \
  --load_only_model \
  --do_lower_case \
  --seed 12345 \
  --sequence_length 240 \
  --warmup_steps "$WARMUP_STEPS" \
  --batch_size "$BATCH_SIZE" \
  --train_file "../DPR/data/retriever/${TRAIN_FILE}.json" \
  --dev_file "../DPR/data/retriever/${DEV_FILE}.json" \
  --output_dir "../od_splinter/runs/${MODEL_TYPE}/${DATASET}-${N_SAMPLES}" \
  --learning_rate 1e-05 \
  --num_train_epochs "$N_EPOCHS" \
  --dev_batch_size "$BATCH_SIZE" \
  --val_av_rank_start_epoch "$(( N_EPOCHS - 1 ))" \
  --log_batch_step 1000000 \
  --train_rolling_loss_step 10 \
  --wandb_project od_splinter \
  --wandb_name "train-${MODEL_NAME}" \
  --fp16 \
  "${ARGS[@]}"
