#!/usr/bin/env bash


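# Usage:
#   bash adapter_proj.sh MODE DATAVER NAME ADAPTER_CONFIG_NAME HDFS_ADAPTER_DIR COMPONENT [extra train.py args...]
#
#   MODE is "train" or "train_2stage"; COMPONENT "layer10" selects layer10.json,
#   any other value selects baseConfig.json. Arguments after the sixth are
#   appended to the fairseq/train.py command line.
#
# Example:
#   bash adapter_proj.sh train_2stage mspm4_cc100 mbartV2_cc100_adapterV1_encProjFix_projLN_trainLA_noTA_ls0 enzhfr_final_ln.json hdfs://haruna/home/byte_arnold_hl_mlnlc/user/wuxianze.0/Workspace/Multilingual/xgiga/ZhEnFr_infillNoise_r0.35_len256_right/checkpoints/mbartV2_cc100_adapterV1_encProjFix_projLN_trainLA_noTA_ls0 layer10 --freeze-decoder --freeze-encoder --num-workers 6 --criterion label_smoothed_cross_entropy --arch mbart_summ_abs_w_adapter_large --encoder-version v2 --adapter-num-layer 2 --ln-after-proj --freeze-adapter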

# Step up one directory from wherever the script was launched.
cd ..

# Positional parameters (all are plain strings; none are validated here).
MODE="$1"                  # "train" or "train_2stage" start the training stage
DATAVER="$2"               # dataset version directory on HDFS
NAME="$3"                  # run name, used to build the local working directory
ADAPTER_CONFIG_NAME="$4"   # adapter config file name fetched from HDFS
hdfs_adapter="$5"          # HDFS directory holding the trained adapter checkpoints
COMPONENT="$6"             # "layer10" selects layer10.json, anything else baseConfig.json

# Everything after the sixth argument is forwarded to train.py. The flat
# string keeps its historical " arg " formatting because later code expands
# it unquoted.
argslist=""
for extra_arg in "${@:7}"; do
  argslist="${argslist} ${extra_arg} "
done
# Log the extra arguments; printf with a quoted expansion avoids glob
# expansion and `echo` treating a leading "-n"/"-e" as its own option.
printf '%s\n' "${*:7}" >&2

# The first entry doubles as the mixed-language dataset directory name; each
# entry gets its own "valid_<entry>" validation split.
EVAL_LGS=(ZhEnFr_infillNoise_r0.35_len256_right en zh fr)
# Set to "True" to skip installation and all downloads.
skip_loading=False

#######################################
# Rename the four binarized files of one dataset split
# (<split>.doc-sum.{doc,sum}.{bin,idx}) inside a directory.
# Arguments:
#   $1 - directory containing the files
#   $2 - current split name (e.g. "valid")
#   $3 - new split name (e.g. "valid_en")
# Outputs:
#   Writes a one-line summary to stdout.
# Returns:
#   0 if all four files were renamed, 1 if any rename failed
#   (all four renames are still attempted).
#######################################
renameSplit(){
  local path=$1
  local oldsplit=$2
  local newsplit=$3
  local suffix
  local status=0

  for suffix in doc-sum.doc.bin doc-sum.doc.idx doc-sum.sum.bin doc-sum.sum.idx; do
    # Quoted so directories or split names containing spaces work; "--"
    # protects against names starting with a dash.
    mv -- "${path}/${oldsplit}.${suffix}" "${path}/${newsplit}.${suffix}" || status=1
  done

  echo "rename ${path}/${oldsplit}.* to ${path}/${newsplit}.*"
  return "$status"
}

#######################################
# Download a path from HDFS with `hdfs dfs -get`.
# Arguments:
#   $1 - HDFS source path (may contain a glob; it is passed through quoted
#        so the local shell does not try to expand it)
#   $2 - local destination
# Outputs:
#   On success, a one-line summary on stdout; on failure, an error on stderr.
# Returns:
#   The exit status of `hdfs dfs -get`.
#######################################
hadoop_get(){
  local src=$1
  local dst=$2
  local status=0

  hdfs dfs -get "$src" "$dst" || status=$?
  if (( status == 0 )); then
    echo "Download $src to $dst ... "
  else
    echo "Failed to download $src to $dst (exit status $status)" >&2
  fi
  return "$status"
}

# Change into the directory named by $0.
# NOTE(review): an earlier `cd ..` in this script has already changed the
# working directory, so a relative $0 is resolved against that parent
# directory — confirm this is intended.
if ! cd "$(dirname -- "$0")"; then
  echo "Cannot cd to $(dirname -- "$0")" >&2
  # `return` is only legal when the file is sourced; when it is executed the
  # `return` itself fails and we fall through to `exit`, instead of carrying
  # on in the wrong directory.
  return 1 2>/dev/null || exit 1
fi

# Per-run working tree under the user's home directory.
local_root=~/xgiga_${NAME}
resource_root=${local_root}/resource
output_path=${local_root}/output
model_path=${local_root}/model
component_path=~/component            # not per-run: sits directly under the home directory
adapter_path=${local_root}/adapter

local_dataset_path=${resource_root}/dataset
local_tensorboard_path=${output_path}/tensorboard_logdir
local_checkpoint_path=${output_path}/checkpoint_path

# Install dependencies and fetch datasets, configs and adapter checkpoints.
# Skipped entirely when skip_loading is "True". Every step is best-effort:
# a failing command is reported by the tool itself and the script carries on.
if [ "$skip_loading" != "True" ]; then
  echo "Install fairseq" >&2

  sudo mkdir -p /usr/lib/python3.7/site-packages/
  sudo pip3 install -e fairseq
  sudo pip3 install -r requirements.txt -i http://pypi.byted.org/simple/ --trusted-host=pypi.byted.org

  # Assign first, export second, so the status of $(pwd) is not masked.
  PYROUGE_HOME_DIR="$(pwd)/RELEASE-1.5.5"
  PYROUGE_TEMP_PATH=/opt/tiger
  export PYROUGE_HOME_DIR PYROUGE_TEMP_PATH

  pyrouge_set_rouge_path "$PYROUGE_HOME_DIR"
  chmod +x "$PYROUGE_HOME_DIR/ROUGE-1.5.5.pl"

  prefix=hdfs://haruna/home/byte_arnold_hl_mlnlc/user/wuxianze.0
  pretrained_path=${prefix}/Workspace/Multilingual/pretrained

  # The local_* paths used here are computed before this block.
  mkdir -p "${resource_root}" "${output_path}" "${model_path}" \
    "${component_path}" "${adapter_path}" "${local_dataset_path}"

  # Every entry is downloaded into the same local directory; its "valid"
  # split is then renamed to "valid_<entry>" so the next download does not
  # collide with it.
  for LG in "${EVAL_LGS[@]}"; do
    dataset_path=${prefix}/Datasets/multilingual/data-bin/${DATAVER}/${LG}
    # The glob is quoted on purpose: it is meant for hadoop, not the local shell.
    hadoop fs -copyToLocal "${dataset_path}/*" "${local_dataset_path}"
    echo "Download resource from ${dataset_path} to ${local_dataset_path}" >&2
    renameSplit "${local_dataset_path}" "valid" "valid_$LG"
  done

  mkdir -p "${local_tensorboard_path}" "${local_checkpoint_path}"

  # local_pretrained_path=${model_path}/mbart.cc25.v2
  # if [ ! -d ${local_pretrained_path} ]; then
  #   echo "Load pretrained model from ${pretrained_path}/mbart.cc25.v2.tar.gz to ${local_pretrained_path}" >&2
  #   hadoop fs -copyToLocal ${pretrained_path}/mbart.cc25.v2.tar.gz ${model_path}
  #   tar -xvzf ${model_path}/mbart.cc25.v2.tar.gz -C ${model_path}
  # else
  #   echo "Pretrained model in ${local_pretrained_path}" >&2
  # fi

  hdfs dfs -get "${prefix}/Workspace/Multilingual/AdapterConfig/${ADAPTER_CONFIG_NAME}" "${model_path}/"
  hdfs dfs -get "${prefix}/Datasets/cc100/components/*" "${component_path}/"

  # Drop any component config left by a previous run before fetching a fresh
  # one (presumably because `-get` will not overwrite — TODO confirm). -f keeps
  # the first run quiet when nothing exists yet.
  rm -rf -- "${model_path}/component_config.json"
  if [ "$COMPONENT" = "layer10" ]; then
    component_config_name=layer10.json
  else
    component_config_name=baseConfig.json
  fi
  component_config_url=${prefix}/Workspace/Multilingual/ComponentConfig/${component_config_name}
  hdfs dfs -get "${component_config_url}" "${model_path}/component_config.json"
  echo "downloading ${component_config_url}"

  # download adapter
  # hdfs dfs -get ${hdfs_adapter}/*_best.pt ${adapter_path}/
  hadoop_get "${hdfs_adapter}/*_best.pt" "${adapter_path}/"

  echo "Finish download files" >&2
fi

# Language codes handed to train.py through --langs.
langs=ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN

#######################################
# Launch fairseq/train.py with the adapter configuration.
# Globals:
#   read:    EVAL_LGS, langs, model_path, adapter_path, ADAPTER_CONFIG_NAME,
#            local_dataset_path, local_checkpoint_path,
#            local_tensorboard_path, local_pretrained_path (optional)
#   written: restore_file, valid_subsets (kept global, as they were when
#            this logic lived at top level)
# Arguments:
#   $1   - mode: "train" restores from the pretrained model; any other value
#          restores from the downloaded adapter checkpoint
#   $2.. - extra arguments appended verbatim to the train.py command line
# Returns:
#   The exit status of train.py.
#######################################
run_adapter_training() {
  local mode=$1
  shift
  local lg

  if [[ "$mode" == "train" ]]; then
    # local_pretrained_path is not assigned anywhere visible in this script
    # (its assignment is commented out), which used to yield "/model.pt".
    # Fall back to the directory named in that commented-out code.
    # NOTE(review): the pretrained-model download is also commented out, so
    # the file must be provided some other way — confirm.
    restore_file="${local_pretrained_path:-${model_path}/mbart.cc25.v2}/model.pt"
  else
    restore_file="${adapter_path}/checkpoint_best.pt"
  fi

  # Comma-separated "valid_<entry>" list, one item per EVAL_LGS entry.
  valid_subsets=""
  for lg in "${EVAL_LGS[@]}"; do
    valid_subsets="${valid_subsets:+${valid_subsets},}valid_${lg}"
  done

  # Built as an array so every argument reaches train.py as exactly one word.
  local -a train_cmd=(
    python3 fairseq/train.py "${local_dataset_path}" --ddp-backend=no_c10d
    --save-dir "${local_checkpoint_path}"
    --tensorboard-logdir "${local_tensorboard_path}"
    --restore-file "${restore_file}"
    --task summarization_from_pretrained_mbart_mspm4
    --arch mbart_summ_abs_w_adapter_large
    --source-lang doc --target-lang sum
    --langs "$langs"
    --dataset-impl mmap
    --truncate-source
    --encoder-normalize-before --decoder-normalize-before
    --layernorm-embedding
    --criterion label_smoothed_cross_entropy --label-smoothing 0
    --required-batch-size-multiple 1
    --dropout 0.1 --attention-dropout 0.1
    --weight-decay 0 --optimizer adam
    --lr 1e-10 --min-lr -1
    --lr-scheduler polynomial_decay
    --clip-norm 0.1
    --skip-invalid-size-inputs-valid-test
    --find-unused-parameters
    --num-workers 10 --adam-betas '(0.9, 0.999)' --adam-eps 1e-08
    --total-num-update 2 --warmup-updates 0
    --max-tokens 2048
    --max-epoch 3 --fp16
    --log-interval 1
    --save-interval-updates 1
    --log-format simple
    --keep-best-checkpoints 3
    --no-epoch-checkpoints
    --user-dir examples/summarization
    --train-subset valid_en
    --valid-subset "${valid_subsets}"
    --loose-load --iadapter-config "${model_path}/${ADAPTER_CONFIG_NAME}"
    --train-iadapter
    --eval-batchs 20
    --eval-rouge-args '{"beam": 1, "max_len_a": 1.5, "max_len_b": 2, "min_len": 2}'
    --eval-rouge
    --eval-rouge-print-samples
    --component-config "${model_path}/component_config.json"
    --trained-iadapter-dir "${adapter_path}"
    --reset-optimizer --reset-dataloader --reset-meters --reset-lr-scheduler
  )

  # Caller-supplied extras go last, after the built-in flags.
  "${train_cmd[@]}" "$@"
}

if [[ "$MODE" == "train" || "$MODE" == "train_2stage" ]]; then
  # Forward the script's extra arguments (7th onwards) straight from the
  # positional parameters so arguments containing spaces stay intact.
  run_adapter_training "$MODE" "${@:7}"
fi