#!/usr/bin/env bash
#
# Target-task training sweep for BERT-base: for every seed, experiment variant
# and task in TaskList, launches train_bert.py via torch.distributed.launch.
#
# Requires bash >= 4 (associative arrays).

# Number of output classes passed to train_bert.py via --class_num, keyed by
# task name.
declare -A task2class=(
  ['cola']=2
  ['mnli']=3
  ['mrpc']=2
  ['qnli']=2
  ['qqp']=2
  ['rte']=2
  ['sst-2']=2
  ['sts-b']=5
  ['wnli']=2
  ['boolq']=2
  ['cb']=3
  ['copa']=2
  ['mrc']=2
  ['srte']=2
  ['wic']=2
  ['snli']=3
  ['paws']=2
  ['imdb']=2
  ['anli_r1']=3
  ['anli_r2']=3
  ['anli_r3']=3
  ['scitail']=2
  ['winogrande']=2
)

# Only used to name output/checkpoint directories; keep it in sync with the
# prompt length configured in prompt_config.
prompt_len=20
prompt_config='configs/config_prompt_p20.json'
# Per-GPU batch sizes: *_pt for prompt-tuning runs, *_ft for finetuning runs.
batch_size_pt=64
batch_size_ft=32
text_encoder='bert-base-uncased'
dataset_root=./data/discrimination/
# Selects ./configs/Round${round}/ and the round${round} output tree.
round=2
# Other target tasks that can be added to TaskList:
# 'boolq' 'cb' 'copa' 'srte' 'wic' 'mrc' 'snli' 'paws' 'imdb'
TaskList=('winogrande')

#######################################
# Print a message framed by horizontal rules so each run stands out in the log.
# Arguments: message words (joined with single spaces)
# Outputs:   three lines on stdout
#######################################
banner() {
  local -r rule='================================================================================================================='
  printf '%s\n%s\n%s\n' "$rule" "$*" "$rule"
}

#######################################
# Launch one distributed train_bert.py run. Appends the arguments every run in
# this sweep shares (class count, task, batch sizes, encoder, seed, data root)
# after the run-specific ones.
# Globals:   task2class, text_encoder, dataset_root, seed (read)
#            NPROC_PER_NODE (read, optional; processes per node, default 8)
# Arguments: $1   - task name (key of task2class)
#            $2   - per-GPU batch size, used for both training and testing
#            $3.. - run-specific train_bert.py arguments (--config, ...)
# Returns:   exit status of the launcher
#######################################
run_train() {
  local -r task=$1 batch_size=$2
  shift 2
  python -m torch.distributed.launch \
    --nproc_per_node="${NPROC_PER_NODE:-8}" --use_env \
    train_bert.py \
    "$@" \
    --class_num "${task2class[$task]}" \
    --task "$task" \
    --batch_size_train "$batch_size" \
    --batch_size_test "$batch_size" \
    --text_encoder "$text_encoder" \
    --seed "$seed" \
    --dataset_root "$dataset_root"
}

# Per-task configs for this round: ${t}.yaml (finetuning) and
# ${t}_prompt.yaml (prompt tuning).
cfg_dir=./configs/Round${round}/Target

for seed in {42..44}; do
  # Every target-task output and every multitask checkpoint consumed below
  # lives under this round / prompt-length / seed directory.
  out_root=output/bert_base/round${round}/p${prompt_len}/seed${seed}
  mt_root=${out_root}/multitask

  ### bert-base
  ### baseline1: finetuning
  for t in "${TaskList[@]}"; do
    banner "finetuning task: $t config: ${cfg_dir}/${t}.yaml class_num: ${task2class[$t]}"
    run_train "$t" "$batch_size_ft" \
      --config "${cfg_dir}/${t}.yaml" \
      --output_dir "${out_root}/Target/${t}/finetuning"
  done

  ### baseline2: finetuning prompts & backbone
  for t in "${TaskList[@]}"; do
    banner "finetuning with prompt task: $t config: ${cfg_dir}/${t}.yaml class_num: ${task2class[$t]}"
    run_train "$t" "$batch_size_ft" \
      --config "${cfg_dir}/${t}.yaml" \
      --output_dir "${out_root}/Target/${t}/finetuning_with_prompt" \
      --use_prompt \
      --prompt_config "$prompt_config"
  done

  ### baseline3: vanilla prompt tuning
  for t in "${TaskList[@]}"; do
    banner "vanilla prompt tuning task: $t config: ${cfg_dir}/${t}_prompt.yaml class_num: ${task2class[$t]}"
    run_train "$t" "$batch_size_pt" \
      --config "${cfg_dir}/${t}_prompt.yaml" \
      --output_dir "${out_root}/Target/${t}/prompt_tuning" \
      --use_prompt \
      --only_prompt \
      --prompt_config "$prompt_config"
  done

  ### baseline4: multitask prompt tuning as initialization
  for t in "${TaskList[@]}"; do
    banner "multitask prompt tuning as initialization task: $t config: ${cfg_dir}/${t}_prompt.yaml class_num: ${task2class[$t]}"
    run_train "$t" "$batch_size_pt" \
      --config "${cfg_dir}/${t}_prompt.yaml" \
      --output_dir "${out_root}/Target/${t}/prompt_tuning_mpt" \
      --checkpoint "${mt_root}/prompt_tuning/checkpoint_best.pth" \
      --use_prompt \
      --only_prompt \
      --prompt_config "$prompt_config"
  done

  ### BiKT_PT: prompt tuning on feedback model
  for t in "${TaskList[@]}"; do
    banner "BiKT_PT: prompt tuning on feedback model task: $t config: ${cfg_dir}/${t}_prompt.yaml class_num: ${task2class[$t]}"
    run_train "$t" "$batch_size_pt" \
      --config "${cfg_dir}/${t}_prompt.yaml" \
      --output_dir "${out_root}/Target/${t}/prompt_tuning_bikt" \
      --checkpoint "${mt_root}/bikt/checkpoint_best.pth" \
      --use_prompt \
      --only_prompt \
      --prompt_config "$prompt_config"
  done

  ### BiKT_PT_no_pre_round: prompt tuning on feedback model - Ablation
  for t in "${TaskList[@]}"; do
    banner "BiKT_PT_no_pre_round task: $t config: ${cfg_dir}/${t}_prompt.yaml class_num: ${task2class[$t]}"
    run_train "$t" "$batch_size_pt" \
      --config "${cfg_dir}/${t}_prompt.yaml" \
      --output_dir "${out_root}/Target/${t}/prompt_tuning_bikt_no_pre_round" \
      --checkpoint "${mt_root}/bikt_no_pre_round/checkpoint_best.pth" \
      --use_prompt \
      --only_prompt \
      --prompt_config "$prompt_config"
  done

  ### BiKT_PT_raw: prompt tuning on feedback model - Ablation
  for t in "${TaskList[@]}"; do
    banner "BiKT_PT_raw task: $t config: ${cfg_dir}/${t}_prompt.yaml class_num: ${task2class[$t]}"
    run_train "$t" "$batch_size_pt" \
      --config "${cfg_dir}/${t}_prompt.yaml" \
      --output_dir "${out_root}/Target/${t}/prompt_tuning_bikt_raw" \
      --checkpoint "${mt_root}/bikt_raw/checkpoint_best.pth" \
      --use_prompt \
      --only_prompt \
      --prompt_config "$prompt_config"
  done

  ### BiKT_PT_wo_frozen: prompt tuning on feedback model - Ablation
  for t in "${TaskList[@]}"; do
    banner "BiKT_PT_wo_frozen task: $t config: ${cfg_dir}/${t}_prompt.yaml class_num: ${task2class[$t]}"
    run_train "$t" "$batch_size_pt" \
      --config "${cfg_dir}/${t}_prompt.yaml" \
      --output_dir "${out_root}/Target/${t}/prompt_tuning_bikt_wo_frozen" \
      --checkpoint "${mt_root}/bikt_wo_frozen/checkpoint_best.pth" \
      --use_prompt \
      --only_prompt \
      --prompt_config "$prompt_config"
  done

  ### BiKT_PT: prompt tuning (averaged) on feedback model
  for t in "${TaskList[@]}"; do
    banner "BiKT_PT (averaged prompt) task: $t config: ${cfg_dir}/${t}_prompt.yaml class_num: ${task2class[$t]}"
    run_train "$t" "$batch_size_pt" \
      --config "${cfg_dir}/${t}_prompt.yaml" \
      --output_dir "${out_root}/Target/${t}/prompt_tuning_bikt_average" \
      --checkpoint "${mt_root}/bikt/checkpoint_averaged_prompt.pth" \
      --use_prompt \
      --only_prompt \
      --prompt_config "$prompt_config"
  done

  ### BiKT_PT: prompt tuning (weighted averaged) on feedback model
  for t in "${TaskList[@]}"; do
    banner "BiKT_PT (weighted averaged prompt) task: $t config: ${cfg_dir}/${t}_prompt.yaml class_num: ${task2class[$t]}"
    run_train "$t" "$batch_size_pt" \
      --config "${cfg_dir}/${t}_prompt.yaml" \
      --output_dir "${out_root}/Target/${t}/prompt_tuning_bikt_weighted_average" \
      --checkpoint "${mt_root}/bikt/checkpoint_weighted_averaged_prompt.pth" \
      --use_prompt \
      --only_prompt \
      --prompt_config "$prompt_config"
  done

  ### BiKT_FT: finetuning with prompt (random_init) on feedback model
  for t in "${TaskList[@]}"; do
    banner "BiKT_FT task: $t config: ${cfg_dir}/${t}.yaml class_num: ${task2class[$t]}"
    run_train "$t" "$batch_size_ft" \
      --config "${cfg_dir}/${t}.yaml" \
      --output_dir "${out_root}/Target/${t}/finetuning_bikt" \
      --checkpoint "${mt_root}/bikt/checkpoint_best.pth" \
      --use_prompt \
      --prompt_config "$prompt_config"
  done

  ### BiKT_FT: finetuning with prompt (average) on feedback model
  for t in "${TaskList[@]}"; do
    banner "BiKT_FT (averaged prompt) task: $t config: ${cfg_dir}/${t}.yaml class_num: ${task2class[$t]}"
    run_train "$t" "$batch_size_ft" \
      --config "${cfg_dir}/${t}.yaml" \
      --output_dir "${out_root}/Target/${t}/finetuning_bikt_average" \
      --checkpoint "${mt_root}/bikt/checkpoint_averaged_prompt.pth" \
      --use_prompt \
      --prompt_config "$prompt_config"
  done

  ### BiKT_FT: finetuning with prompt (weighted average) on feedback model
  for t in "${TaskList[@]}"; do
    banner "BiKT_FT (weighted averaged prompt) task: $t config: ${cfg_dir}/${t}.yaml class_num: ${task2class[$t]}"
    run_train "$t" "$batch_size_ft" \
      --config "${cfg_dir}/${t}.yaml" \
      --output_dir "${out_root}/Target/${t}/finetuning_bikt_weighted_average" \
      --checkpoint "${mt_root}/bikt/checkpoint_weighted_averaged_prompt.pth" \
      --use_prompt \
      --prompt_config "$prompt_config"
  done

  ### BiKT without prompt: prompt tuning on feedback model (without prompt)
  for t in "${TaskList[@]}"; do
    banner "BiKT without prompt task: $t config: ${cfg_dir}/${t}_prompt.yaml class_num: ${task2class[$t]}"
    run_train "$t" "$batch_size_pt" \
      --config "${cfg_dir}/${t}_prompt.yaml" \
      --output_dir "${out_root}/Target/${t}/prompt_tuning_bikt_ablation" \
      --checkpoint "${mt_root}/finetuning/checkpoint_best.pth" \
      --use_prompt \
      --only_prompt \
      --prompt_config "$prompt_config"
  done
done