#!/usr/bin/env bash

# Usage:
#   bash xgiga_finetune_component_rouge_eval.sh MODE LG EVAL_LG DATAVER NAME [EXTRA_ARGS...]
# Example:
#   bash xgiga_finetune_component_rouge_eval.sh train en zh mspm4_xgiga debug_rouge_eval
#
#   MODE     train | keep_training | generate
#   LG       language of the fine-tuning data (selects data-bin/<DATAVER>/<LG> on HDFS)
#   EVAL_LG  language whose <EVAL_LG>_RL_F metric selects the best checkpoint
#   DATAVER  dataset version directory under data-bin/ on HDFS
#   NAME     run name; names the HDFS log/checkpoint dirs and the local work dir
#   EXTRA_ARGS are appended to the fairseq train/generate command line.

MODE="$1"
LG="$2"
EVAL_LG="$3"
DATAVER="$4"
NAME="$5"

# Gather positional arguments 6..N into one space-separated string.
# NOTE(review): the string is later expanded unquoted, so any single extra
# argument that itself contains spaces gets split into several words.
argslist=""
for extra_arg in "${@:6}"; do
  argslist="${argslist} $extra_arg "
done
# shellcheck disable=SC2086  # unquoted on purpose: log the args space-collapsed
echo $argslist >&2

# Languages whose validation splits are downloaded and used as valid subsets.
predicted_lgs=(en zh fr)

#######################################
# Rename one binarized doc->sum split inside a fairseq data-bin directory,
# e.g. valid.doc-sum.{doc,sum}.{bin,idx} -> valid_en.doc-sum.{doc,sum}.{bin,idx}.
# Arguments: $1 - data-bin directory
#            $2 - current split name (e.g. "valid")
#            $3 - new split name (e.g. "valid_en")
# Outputs:   a "rename ..." log line on stdout
#######################################
renameSplit(){
  # Locals so the caller's globals (e.g. a variable named "path") survive.
  local path=$1
  local oldsplit=$2
  local newsplit=$3
  local side ext

  for side in doc sum; do
    for ext in bin idx; do
      mv -- "${path}/${oldsplit}.doc-sum.${side}.${ext}" \
        "${path}/${newsplit}.doc-sum.${side}.${ext}"
    done
  done

  echo "rename ${path}/${oldsplit}.* to ${path}/${newsplit}.*"
}

#######################################
# Rename a single binarized split (<split>.bin / <split>.idx) inside a
# data-bin directory.
# Arguments: $1 - data-bin directory
#            $2 - current split prefix
#            $3 - new split prefix
# Outputs:   a "rename ..." log line on stdout
#######################################
renameParallelSplit(){
  # Locals so the caller's globals are not overwritten.
  local path=$1
  local oldsplit=$2
  local newsplit=$3
  local ext

  for ext in bin idx; do
    mv -- "${path}/${oldsplit}.${ext}" "${path}/${newsplit}.${ext}"
  done
  echo "rename ${path}/${oldsplit}.* to ${path}/${newsplit}.*"
}

# Run from the script's own directory so the relative paths used below
# (fairseq/, requirements.txt, RELEASE-1.5.5/, utils/, examples/) resolve.
# `return` is only valid inside a function or a sourced file; in an executed
# script it just prints an error and continues, so use `exit` instead.
cd "$(dirname "$0")" || exit 1
echo "Install fairseq" >&2

# NOTE(review): presumably pre-creates the system site-packages directory
# the editable install below needs — confirm it is still required.
sudo mkdir -p /usr/lib/python3.7/site-packages/
sudo pip3 install -e fairseq
sudo pip3 install -r requirements.txt -i http://pypi.byted.org/simple/ --trusted-host=pypi.byted.org

# Point pyrouge at the bundled ROUGE-1.5.5 Perl scripts.
# Assign before exporting so a failing $(pwd) is not masked by `export`.
PYROUGE_HOME_DIR="$(pwd)/RELEASE-1.5.5"
export PYROUGE_HOME_DIR
export PYROUGE_TEMP_PATH=/opt/tiger

pyrouge_set_rouge_path "$PYROUGE_HOME_DIR"
chmod +x "$PYROUGE_HOME_DIR/ROUGE-1.5.5.pl"

# HDFS workspace root; every remote path below is derived from it.
prefix=hdfs://haruna/home/byte_arnold_hl_mlnlc/user/wuxianze.0

dataset_path=${prefix}/Datasets/multilingual/data-bin/${DATAVER}/${LG}
tensorboard_logdir=${prefix}/Workspace/Multilingual/xgiga/${LG}/logs/${NAME}
checkpoint_path=${prefix}/Workspace/Multilingual/xgiga/${LG}/checkpoints/${NAME}
pretrained_path=${prefix}/Workspace/Multilingual/pretrained

# Ensure this run's HDFS log and checkpoint directories exist.
hdfs dfs -mkdir -p "$tensorboard_logdir"
hdfs dfs -mkdir -p "$checkpoint_path"

# Local working tree for this run.
local_root=~/xgiga_${NAME}
resource_root=${local_root}/resource
output_path=${local_root}/output
model_path=${local_root}/model
component_path=~/component
mkdir -p -- "${resource_root}"
mkdir -p -- "${output_path}"
mkdir -p -- "${model_path}"
mkdir -p -- "${component_path}"

local_dataset_path=${resource_root}/dataset
mkdir -p -- "${local_dataset_path}"
# The glob is quoted so the local shell leaves it alone and hadoop expands
# it against HDFS.
hadoop fs -copyToLocal "${dataset_path}/*" "${local_dataset_path}"
echo "Download resource from ${dataset_path} to ${local_dataset_path}" >&2

# The LG dataset copied above presumably contains valid.doc-sum.* files;
# give that split a per-language name so several languages can coexist.
renameSplit "${local_dataset_path}" "valid" "valid_$LG"

# Fetch (and rename) the validation split of every predicted language that
# is not already present locally under its valid_<lg> name.
for lg in "${predicted_lgs[@]}"; do
  if [[ ! -e "${local_dataset_path}/valid_${lg}.doc-sum.doc.bin" ]]; then
    echo "Download resource from ${prefix}/Datasets/multilingual/data-bin/${DATAVER}/$lg/valid.* to ${local_dataset_path}/" >&2
    # Quoted glob: expanded by hdfs against the remote directory.
    hdfs dfs -get "${prefix}/Datasets/multilingual/data-bin/${DATAVER}/$lg/valid.*" "${local_dataset_path}/"
    renameSplit "${local_dataset_path}" "valid" "valid_$lg"
  fi
done

local_tensorboard_path=${output_path}/tensorboard_logdir
mkdir -p -- "${local_tensorboard_path}"

# Fetch the latest checkpoint of this run from HDFS (fails harmlessly with
# an error message when none exists yet).
local_checkpoint_path=${output_path}/checkpoint_path
mkdir -p -- "${local_checkpoint_path}"
hadoop fs -copyToLocal "${checkpoint_path}/checkpoint_last.pt" "${local_checkpoint_path}"
echo "Load checkpoints from ${checkpoint_path}/checkpoint_last.pt to ${local_checkpoint_path}" >&2

# Download and unpack the pretrained mBART cc25 archive unless it is
# already unpacked locally.
local_pretrained_path=${model_path}/mbart.cc25.v2
if [[ ! -d "${local_pretrained_path}" ]]; then
  echo "Load pretrained model from ${pretrained_path}/mbart.cc25.v2.tar.gz to ${local_pretrained_path}" >&2
  hadoop fs -copyToLocal "${pretrained_path}/mbart.cc25.v2.tar.gz" "${model_path}"
  tar -xvzf "${model_path}/mbart.cc25.v2.tar.gz" -C "${model_path}"
else
  echo "Pretrained model in ${local_pretrained_path}" >&2
fi

# Component config and component files; both live under the same HDFS
# workspace prefix as everything else. The glob is quoted so hdfs expands it.
hdfs dfs -get "${prefix}/Workspace/Multilingual/ComponentConfig/baseConfig.json" "${model_path}/component_config.json"
hdfs dfs -get "${prefix}/Datasets/cc100/components/*" "${component_path}/"

echo "Finish download files" >&2

# Language codes passed to fairseq via --langs (25 entries; presumably the
# vocabulary of the mbart.cc25 pretrained model).
langs=ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN

#######################################
# Build fairseq's --valid-subset value: one "valid_<lg>" per argument,
# comma-separated.
# Arguments: language codes, e.g. en zh fr
# Outputs:   e.g. "valid_en,valid_zh,valid_fr" on stdout
#######################################
join_valid_subsets() {
  local lg subsets=""
  for lg in "$@"; do
    subsets+="${subsets:+,}valid_${lg}"
  done
  printf '%s\n' "$subsets"
}

#######################################
# Extract hypotheses from fairseq generate output: keep lines starting with
# "H", version-sort them (so H-2 precedes H-10), keep fields 3.. of each
# tab-separated line, and strip language tags like "[en_XX]".
# Arguments: $1 - generate.py output file
# Outputs:   one hypothesis per line on stdout
#######################################
extract_hypotheses() {
  grep '^H' -- "$1" \
    | sort -V \
    | cut -f 3- \
    | sed -e 's/\[[a-z]\{2\}_[A-Z]\{2\}\]//g'
}

#######################################
# Watch a local directory and upload every *.pt file that is closed after
# writing to HDFS; the local copy is deleted only if the upload succeeded.
# Runs until killed (inotifywait -m); meant to be started in the background.
# Arguments: $1 - local checkpoint directory
#            $2 - HDFS destination directory
#######################################
upload_checkpoints_on_write() {
  local watch_dir=$1 dest_dir=$2
  local _dir _event file
  inotifywait -m "$watch_dir" -e close_write |
    while read -r _dir _event file; do
      [[ "$file" == *.pt ]] || continue
      echo "Checkpoint detected: $file" >&2
      if hadoop fs -put -f "${watch_dir}/${file}" "${dest_dir}/"; then
        echo "checkpoint uploaded: $file to ${dest_dir}/$file" >&2
        rm -f -- "${watch_dir}/${file}"
      else
        echo "checkpoint upload failed, keeping ${watch_dir}/${file}" >&2
      fi
    done
}

if [[ "$MODE" == "train" || "$MODE" == "keep_training" ]]; then
  echo "Training..."

  upload_checkpoints_on_write "${local_checkpoint_path}" "${checkpoint_path}" &

  # A fresh fine-tune resets optimizer / dataloader / meter / LR-scheduler
  # state; keep_training omits these flags.
  # NOTE(review): --restore-file below always points at the pretrained
  # model.pt, so depending on this fairseq fork keep_training may not resume
  # from checkpoint_last.pt — confirm (a --restore-file in the extra args
  # overrides it, since argparse keeps the last occurrence).
  if [[ "$MODE" == "train" ]]; then
    argslist="${argslist} --reset-optimizer --reset-dataloader --reset-meters --reset-lr-scheduler"
  fi

  echo "${local_dataset_path} --ddp-backend=no_c10d \
    --save-dir ${local_checkpoint_path} \
    --tensorboard-logdir ${local_tensorboard_path} \
    --restore-file ${local_pretrained_path}/model.pt $argslist"

  valid_subsets=$(join_valid_subsets "${predicted_lgs[@]}")

  # $argslist is intentionally unquoted: it is split into separate flags.
  # shellcheck disable=SC2086
  python3 fairseq/train.py "${local_dataset_path}" --ddp-backend=no_c10d \
    --save-dir "${local_checkpoint_path}" \
    --tensorboard-logdir "${local_tensorboard_path}" \
    --restore-file "${local_pretrained_path}/model.pt" \
    --task summarization_from_pretrained_mbart_mspm4 \
    --arch mbart_large \
    --source-lang doc --target-lang sum \
    --langs "$langs" \
    --dataset-impl mmap \
    --truncate-source \
    --encoder-normalize-before --decoder-normalize-before \
    --layernorm-embedding \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
    --required-batch-size-multiple 1 \
    --dropout 0.1 --attention-dropout 0.1 \
    --weight-decay 0.01 --optimizer adam \
    --lr 3e-5 --min-lr -1 \
    --lr-scheduler polynomial_decay \
    --clip-norm 0.1 \
    --skip-invalid-size-inputs-valid-test \
    --find-unused-parameters \
    --num-workers 10  --adam-betas '(0.9, 0.999)' --adam-eps 1e-08 \
    --total-num-update 781250 --warmup-updates 15625 \
    --max-epoch 50 --fp16 \
    --log-format simple \
    --keep-best-checkpoints 3 \
    --no-epoch-checkpoints \
    --user-dir examples/summarization \
    --valid-subset "${valid_subsets}" \
    --best-checkpoint-metric "${EVAL_LG}_RL_F" \
    --eval-rouge --maximize-best-checkpoint-metric \
    --eval-rouge-args '{"beam": 5, "lenpen": 0.6, "max_len_a": 0.78, "max_len_b": 2, "min_len": 2, "no_repeat_ngram_size": 3}' \
    --eval-rouge-print-samples \
    --max-sentences-valid 8 --eval-tokenized-bleu \
    --component-config "${model_path}/component_config.json" \
    --batch-size 8 \
    --log-interval 800 \
    --save-interval-updates 4000 \
    --patience 12 \
    $argslist

elif [[ "$MODE" == "generate" ]]; then
  echo "Generating..."

  hadoop fs -copyToLocal "${checkpoint_path}/checkpoint_best.pt" "${local_checkpoint_path}"
  echo "Load checkpoints from ${checkpoint_path}/checkpoint_best.pt to ${local_checkpoint_path}" >&2

  # Filename-safe suffix from the extra args: drop every "-" and turn runs
  # of spaces into "_" (e.g. " --beam  4 " -> "_beam_4_").
  suffix=$(echo "$argslist" | sed -e "s/-//g" -e "s/  */_/g")
  generate_output="${local_tensorboard_path}/output${suffix}"
  hypo_file="${local_tensorboard_path}/test${suffix}.hypo"

  # shellcheck disable=SC2086  # $argslist is split into separate flags on purpose
  python3 fairseq/generate.py "${local_dataset_path}" \
    --path "${local_checkpoint_path}/checkpoint_best.pt" \
    --task summarization_from_pretrained_mbart_mspm4 \
    --gen-subset test \
    --source-lang doc --target-lang sum \
    --langs "$langs" \
    --remove-bpe 'sentencepiece' \
    --min-len 30 \
    --max-len-b 50 \
    --lenpen 0.6 \
    --no-repeat-ngram-size 3 \
    --truncate-source \
    --user-dir examples/summarization \
    $argslist \
    > "$generate_output"

  extract_hypotheses "$generate_output" > "$hypo_file"

  # TEST_LG is not assigned anywhere in this script; default it to LG (the
  # language whose test split was generated) so the path is not "test.y.".
  test_lg=${TEST_LG:-$LG}
  echo "Load ground truth file from ${prefix}/Datasets/multilingual/xgiga/raw/test.y.${test_lg}"
  hadoop fs -get "${prefix}/Datasets/multilingual/xgiga/raw/test.y.${test_lg}" "${local_dataset_path}"

  # NOTE(review): the reference scored below is test.${LG}.sum, not the
  # test.y.* file fetched just above — confirm which one is intended.
  python3 utils/calRouge.py \
    -c "$hypo_file" \
    -r "${local_dataset_path}/test.${LG}.sum" \
    -l "${LG}" -d "<q>"

fi

# Publish everything under the local tensorboard dir (training logs, plus
# the generate output / hypothesis files in generate mode) to HDFS.
printf 'Put %s to %s\n' "${local_tensorboard_path}" "${tensorboard_logdir}" >&2
hadoop fs -put -f "${local_tensorboard_path}"/* "${tensorboard_logdir}/"

# NOTE(review): fixed 10-minute pause before exiting; presumably gives the
# background checkpoint uploader time to finish — confirm.
sleep 600
