# Lookup table: task key -> value passed to train_peft_roberta.py as --class_num.
# Every task listed in TaskList must have an entry here.
# NOTE(review): 'sts-b' maps to 5; STS-B is often treated as regression —
# confirm this count is what train_peft_roberta.py expects.
declare -A task2class
task2class+=(
  ['cola']=2 ['mnli']=3 ['mrpc']=2 ['qnli']=2 ['qqp']=2
  ['rte']=2 ['sst-2']=2 ['sts-b']=5 ['wnli']=2
  ['boolq']=2 ['cb']=3 ['copa']=2 ['mrc']=2 ['srte']=2 ['wic']=2
  ['snli']=3 ['paws']=2 ['imdb']=2
  ['anli_r1']=3 ['anli_r2']=3 ['anli_r3']=3
  ['scitail']=2 ['swag']=4 ['swag2']=2 ['winogrande']=2
  ['arceasy']=2 ['ag']=4 ['yelppolarity']=2
)


# --- Prompt settings (consumed by the BiKT_lora runs) ---
prompt_len=20
prompt_config="configs/config_prompt_p20.json"

# --- Batch sizes ---
# NOTE(review): batch_size_pt is not referenced anywhere in this script;
# presumably kept for parity with a prompt-training script — confirm before removing.
batch_size_pt=32
batch_size_ft=16

# --- Model, data and experiment round ---
text_encoder="roberta-base"
dataset_root="./data/discrimination/"
round=1

# --- GPU selection (exported so the launched python processes inherit it) ---
CUDA_VISIBLE_DEVICES=0
export CUDA_VISIBLE_DEVICES

# --- Target tasks to fine-tune; each needs an entry in task2class ---
TaskList=(boolq mrc)

# Prints a three-line banner: a rule, the run label ($1), then the rule again.
print_banner() {
  local -r rule='================================================================================================================='
  printf '%s\n' "$rule" "$1" "$rule"
}

# Prints how many torch.distributed workers to start: one per entry of the
# comma-separated CUDA_VISIBLE_DEVICES list, or the fallback ($1, default 8)
# when CUDA_VISIBLE_DEVICES is unset or empty.
count_visible_gpus() {
  local -r fallback=${1:-8}
  local -a gpus=()
  if [[ -z "${CUDA_VISIBLE_DEVICES:-}" ]]; then
    printf '%s\n' "$fallback"
    return 0
  fi
  IFS=',' read -r -a gpus <<< "$CUDA_VISIBLE_DEVICES"
  printf '%s\n' "${#gpus[@]}"
}

# Fills the global array common_args with the trailing flags shared by every
# run of task $1 with seed $2 (reads task2class, batch_size_ft, text_encoder,
# dataset_root). Returns 1 with a warning on stderr when the task has no
# task2class entry, because --class_num would otherwise be passed empty.
set_common_args() {
  local -r task=$1 task_seed=$2
  if [[ -z "${task2class[$task]+set}" ]]; then
    printf 'warning: no task2class entry for task %s; skipping\n' "$task" >&2
    return 1
  fi
  common_args=(
    --class_num "${task2class[$task]}"
    --task "$task"
    --batch_size_train "$batch_size_ft"
    --batch_size_test "$batch_size_ft"
    --text_encoder "$text_encoder"
    --seed "$task_seed"
    --dataset_root "$dataset_root"
    --use_lora
    --max_length 512
  )
}

# Hard-coding 8 workers breaks whenever CUDA_VISIBLE_DEVICES exposes fewer
# devices (e.g. CUDA_VISIBLE_DEVICES=0): ranks are started with no GPU of
# their own. Size the worker pool from the visible device list instead.
nproc_per_node=$(count_visible_gpus 8)
launch=(python -m torch.distributed.launch
  "--nproc_per_node=${nproc_per_node}" --master_port=25648 --use_env
  train_peft_roberta.py)

for seed in {42..42}; do
  ### roberta-base
  # Stage 1: LoRA fine-tuning on each target task.
  for t in "${TaskList[@]}"; do
    set_common_args "$t" "$seed" || continue
    print_banner lora
    "${launch[@]}" \
      --config "./configs/Round${round}/Target/${t}_lora.yaml" \
      --output_dir "output/roberta_base/round${round}/lora/seed${seed}/Target/${t}/lora" \
      "${common_args[@]}"
  done

  # Stage 2: LoRA fine-tuning with prompts, passing this seed's multitask BiKT
  # checkpoint via --checkpoint. Skipped when that file does not exist yet.
  bikt_ckpt="output/roberta_base/round${round}/p${prompt_len}/seed${seed}/multitask/bikt/checkpoint_best.pth"
  if [[ ! -f "$bikt_ckpt" ]]; then
    printf 'warning: BiKT checkpoint %s not found; skipping BiKT_lora runs for seed %s\n' \
      "$bikt_ckpt" "$seed" >&2
    continue
  fi
  for t in "${TaskList[@]}"; do
    set_common_args "$t" "$seed" || continue
    print_banner BiKT_lora
    "${launch[@]}" \
      --config "./configs/Round${round}/Target/${t}_lora.yaml" \
      --output_dir "output/roberta_base/round${round}/lora/seed${seed}/Target/${t}/lora_bikt" \
      --checkpoint "$bikt_ckpt" \
      --use_prompt \
      --prompt_config "$prompt_config" \
      "${common_args[@]}"
  done
done