#!/usr/bin/env bash

# bash xgiga_proj.sh generate en zh mspm4_xgiga debug wiki40b_10w _best;
# bash xgiga_proj.sh train en en mspm4_xgiga fl_debug layer12;

# Positional arguments:
#   $1 MODE            train | train_2stage | keep_training | generate | dump_embedding
#   $2 LG              language of the training data
#   $3 EVAL_LG         'all' or single language
#   $4 DATAVER         dataset version (directory name under data-bin)
#   $5 NAME            experiment name (used for log/checkpoint directories)
#   $6 component_type  e.g. layer10_new
#   $7                 (generate mode only) checkpoint suffix, read in that branch
# Everything after these is forwarded verbatim to the fairseq command line.
MODE="$1"
LG="$2"
EVAL_LG="$3" # 'all' or single language
DATAVER="$4"
NAME="$5"
component_type="$6" # layer10_new

# Index of the first pass-through argument; generate mode consumes one more
# positional argument ($7) before the pass-through arguments start.
declare -i args_start=7
if [[ "${MODE}" == "generate" ]]; then
    args_start=$((args_start + 1))
fi

# argslist stays a plain string (each argument padded with one space on both
# sides) because later commands expand it unquoted and a suffix is derived
# from it with sed.
argslist=""
for arg in "${@:${args_start}}"; do
    argslist+=" ${arg} "
done
# printf rather than echo so that a leading "-n"/"-e" pass-through argument is
# logged instead of being interpreted as an echo option.
printf '%s\n' "${argslist}" >&2

# Select the component-config file name: each recognised component type maps
# to "<component_type>.json"; any other value falls back to the base config.
case "${component_type}" in
    layer10 | wiki40b_10w | wiki40b_10w_layer10_wtag | layer10_wtag | layer10_new | layer12_v1Enc)
        json_name="${component_type}.json"
        ;;
    *)
        json_name="baseConfig.json"
        ;;
esac

# Languages to validate on. EVAL_LG is split on whitespace on purpose, so a
# single language yields a one-element array; the keyword 'all' expands to
# the fixed language set below.
# shellcheck disable=SC2206
predicted_lgs=(${EVAL_LG})
if [[ "${EVAL_LG}" == 'all' ]]; then
    predicted_lgs=(en zh fr)
fi

# Space-prefixed, space-separated rendering of the language list (only used
# for logging further down).
predicted_lg_str=""
for lg in "${predicted_lgs[@]}"; do
    predicted_lg_str+=" ${lg}"
done

#######################################
# Rename the four "<split>.doc-sum.{doc,sum}.{bin,idx}" files of one dataset
# split inside a directory.
# Arguments:
#   $1 - directory containing the files
#   $2 - current split name (file-name prefix)
#   $3 - new split name
# Outputs:
#   One summary line on stdout; mv reports missing files on stderr.
# Returns:
#   Status of the final echo. A failing mv does not abort the remaining
#   renames (best-effort, as callers may invoke this for absent splits).
#######################################
renameSplit(){
    local path=$1
    local oldsplit=$2
    local newsplit=$3
    local ext

    for ext in doc.bin doc.idx sum.bin sum.idx; do
        mv -- "${path}/${oldsplit}.doc-sum.${ext}" "${path}/${newsplit}.doc-sum.${ext}"
    done

    echo "rename ${path}/${oldsplit}.* to ${path}/${newsplit}.*"
}

# Work from the script's own directory so the relative paths used below
# (fairseq/, install_pyrouge.sh, RELEASE-1.5.5, xnlg/) resolve.
# The script is executed (see the usage examples at the top), not sourced, so
# `return` here would only print an error and let the script continue in the
# wrong directory; abort explicitly instead.
cd "$(dirname -- "$0")" || { echo "Cannot cd to the script directory" >&2; exit 1; }
echo "Install fairseq" >&2

export https_proxy=http://bj-rd-proxy.byted.org:3128 http_proxy=http://bj-rd-proxy.byted.org:3128 no_proxy=code.byted.org

sudo mkdir -p /usr/lib/python3.7/site-packages/
sudo pip3 install -e fairseq
bash install_pyrouge.sh

# Assign and export separately so a failing $(pwd) is not masked by export.
PYROUGE_HOME_DIR="$(pwd)/RELEASE-1.5.5"
export PYROUGE_HOME_DIR
export PYROUGE_TEMP_PATH=/opt/tiger

pyrouge_set_rouge_path "${PYROUGE_HOME_DIR}"
chmod +x "${PYROUGE_HOME_DIR}/ROUGE-1.5.5.pl"

# Remote (HDFS) locations for this dataset version / language / experiment.
prefix=hdfs://haruna/home/byte_arnold_hl_mlnlc/user/wuxianze.0
dataset_path="${prefix}/Datasets/multilingual/data-bin/${DATAVER}/${LG}"
tensorboard_logdir="${prefix}/Workspace/Multilingual/xgiga/${LG}/logs/${NAME}"
checkpoint_path="${prefix}/Workspace/Multilingual/xgiga/${LG}/checkpoints/${NAME}"
pretrained_path="${prefix}/Workspace/Multilingual/pretrained"

# Make sure the remote log and checkpoint directories exist.
hdfs dfs -mkdir -p "${tensorboard_logdir}"
hdfs dfs -mkdir -p "${checkpoint_path}"

# Local working tree for this experiment (the tilde is left unquoted so it
# still expands to the home directory).
local_root=~/"xgiga_${NAME}"
resource_root="${local_root}/resource"
output_path="${local_root}/output"
model_path="${local_root}/model"
component_path=~/component
local_dataset_path="${resource_root}/dataset"
mkdir -p "${resource_root}" "${output_path}" "${model_path}" "${component_path}" "${local_dataset_path}"

# The glob is quoted so it reaches hadoop literally instead of being subject
# to local pathname expansion.
hadoop fs -copyToLocal "${dataset_path}/*" "${local_dataset_path}"
echo "Download resource from ${dataset_path} to ${local_dataset_path}" >&2

# Tag the splits that were just downloaded with the training language so that
# validation sets of several languages can live in the same directory.
renameSplit "${local_dataset_path}" "valid" "valid_${LG}"
renameSplit "${local_dataset_path}" "test" "test_${LG}"

# Fetch the validation split of every evaluated language that is not present
# locally yet. (The path previously carried a stray '}', so the existence
# check could never succeed and every language was downloaded again.)
for lg in "${predicted_lgs[@]}"; do
    if [[ ! -e "${local_dataset_path}/valid_${lg}.doc-sum.doc.bin" ]]; then
        echo "Download resource from ${prefix}/Datasets/multilingual/data-bin/${DATAVER}/$lg/valid.* to ${local_dataset_path}/" >&2
        # Glob quoted so hadoop, not the local shell, expands it.
        hdfs dfs -get "${prefix}/Datasets/multilingual/data-bin/${DATAVER}/${lg}/valid.*" "${local_dataset_path}/"
        renameSplit "${local_dataset_path}" "valid" "valid_${lg}"
    fi
done

# Local output directories for tensorboard logs / generation output and for
# checkpoints.
local_tensorboard_path="${output_path}/tensorboard_logdir"
local_checkpoint_path="${output_path}/checkpoint_path"
mkdir -p "${local_tensorboard_path}" "${local_checkpoint_path}"

# Download and unpack the pretrained mBART archive unless its directory
# already exists locally.
local_pretrained_path="${model_path}/mbart.cc25.v2"
if [[ ! -d "${local_pretrained_path}" ]]; then
    echo "Load pretrained model from ${pretrained_path}/mbart.cc25.v2.tar.gz to ${local_pretrained_path}" >&2
    hadoop fs -copyToLocal "${pretrained_path}/mbart.cc25.v2.tar.gz" "${model_path}"
    tar -xvzf "${model_path}/mbart.cc25.v2.tar.gz" -C "${model_path}"
else
    echo "Pretrained model in ${local_pretrained_path}" >&2
fi

# Component configuration (selected via json_name) and component files.
# ${prefix} is the same HDFS root that was previously spelled out literally.
hdfs dfs -get "${prefix}/Workspace/Multilingual/ComponentConfig/${json_name}" "${model_path}/component_config.json"
hdfs dfs -get "${prefix}/Datasets/cc100/components/*" "${component_path}/"

echo "Finish download files" >&2

# Log the effective configuration for later inspection.
echo "component: ${json_name}" >&2
cat "${model_path}/component_config.json" >&2
echo "predicted_lg_str: ${predicted_lg_str}" >&2


# Language-tag list passed to every fairseq invocation via --langs.
langs=ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN

if [[ "$MODE" == "train" || "$MODE" == "train_2stage" ]]; then
    echo "Training..."

    # Pick the checkpoint to initialise from before any background job is
    # started, so an early abort leaves nothing running.
    if [[ "$MODE" == "train" ]]; then
        restore_file="${local_pretrained_path}/model.pt"
    else
        # train_2stage: adapter_path is not assigned anywhere in this script,
        # so it has to be provided by the environment. Fail fast instead of
        # silently pointing at "/checkpoint_best.pt".
        : "${adapter_path:?adapter_path must be set for MODE=train_2stage}"
        restore_file="${adapter_path}/checkpoint_best.pt"
    fi

    # Background uploader: every *.pt file that is closed after writing in
    # the local checkpoint directory is pushed to HDFS. The local copy is
    # removed only after a successful upload, so a failed upload no longer
    # destroys the checkpoint.
    (inotifywait -m "${local_checkpoint_path}" -e close_write |
        while read -r _dir _event file; do
            if [[ "$file" =~ .*pt$ ]]; then
                echo "Checkpoint detected: $file" >&2
                if hadoop fs -put -f "${local_checkpoint_path}/${file}" "${checkpoint_path}/"; then
                    echo "checkpoint uploaded: $file to ${checkpoint_path}/$file" >&2
                    rm -- "${local_checkpoint_path}/${file}"
                else
                    echo "checkpoint upload failed, keeping local copy: ${local_checkpoint_path}/$file" >&2
                fi
            fi
        done) &

    # Comma-separated list "valid_<lg>,..." covering every evaluated language.
    valid_subsets=""
    for lg in "${predicted_lgs[@]}"; do
        valid_subsets+="${valid_subsets:+,}valid_${lg}"
    done

    # $argslist is expanded unquoted on purpose: it carries the pass-through
    # command-line options as one whitespace-separated string.
    python3 fairseq/train.py "${local_dataset_path}" --ddp-backend=no_c10d \
        --save-dir "${local_checkpoint_path}" \
        --tensorboard-logdir "${local_tensorboard_path}" \
        --restore-file "${restore_file}" \
        --task summarization_from_pretrained_mbart_mspm4 \
        --arch mbart_summ_abs_large \
        --source-lang doc --target-lang sum \
        --langs "${langs}" \
        --dataset-impl mmap \
        --truncate-source \
        --encoder-normalize-before --decoder-normalize-before \
        --layernorm-embedding \
        --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
        --required-batch-size-multiple 1 \
        --dropout 0.1 --attention-dropout 0.1 \
        --weight-decay 0.01 --optimizer adam \
        --lr 3e-5 --min-lr -1 \
        --lr-scheduler polynomial_decay \
        --clip-norm 0.1 \
        --skip-invalid-size-inputs-valid-test \
        --find-unused-parameters \
        --num-workers 10  --adam-betas '(0.9, 0.999)' --adam-eps 1e-08 \
        --total-num-update 781250 --warmup-updates 15625 \
        --max-epoch 50 --fp16 \
        --log-format simple \
        --keep-best-checkpoints 3 \
        --no-epoch-checkpoints \
        --user-dir examples/summarization \
        --valid-subset "${valid_subsets}" \
        --best-checkpoint-metric "${EVAL_LG}_RL_F" \
        --eval-rouge --maximize-best-checkpoint-metric \
        --eval-rouge-args '{"beam": 5, "lenpen": 0.6, "max_len_a": 0.78, "max_len_b": 2, "min_len": 2, "no_repeat_ngram_size": 3}' \
        --eval-rouge-print-samples \
        --adapter-num-layer 2 --batch-size 8 \
        --log-interval 800 \
        --freeze-decoder --freeze-encoder \
        --loose-load --save-interval-updates 4000 \
        --max-sentences-valid 8 \
        --patience 12 \
        --reset-optimizer --reset-dataloader --reset-meters --reset-lr-scheduler \
        --component-config "${model_path}/component_config.json" \
        $argslist

    echo "Put ${local_tensorboard_path} to ${tensorboard_logdir}" >&2
    hadoop fs -put -f "${local_tensorboard_path}"/* "${tensorboard_logdir}/"
    # NOTE(review): presumably this grace period lets the background uploader
    # finish pushing the last checkpoints before the script exits - confirm.
    sleep 300

elif [ "$MODE" == "keep_training" ]; then
    echo "Training..."

    # Fetch the last checkpoint of this experiment from HDFS; training is
    # resumed from it. Verify the file really is there afterwards, so that a
    # failed download cannot silently turn "resume" into a run that has
    # nothing to restore.
    restore_file="${local_pretrained_path}/checkpoint_last.pt"
    hdfs dfs -get "${checkpoint_path}/checkpoint_last.pt" "${local_pretrained_path}"
    if [[ ! -f "${restore_file}" ]]; then
        echo "Cannot resume: ${restore_file} is missing (download from ${checkpoint_path}/checkpoint_last.pt failed?)" >&2
        exit 1
    fi

    # Background uploader: every *.pt file that is closed after writing in
    # the local checkpoint directory is pushed to HDFS. The local copy is
    # removed only after a successful upload, so a failed upload no longer
    # destroys the checkpoint.
    (inotifywait -m "${local_checkpoint_path}" -e close_write |
        while read -r _dir _event file; do
            if [[ "$file" =~ .*pt$ ]]; then
                echo "Checkpoint detected: $file" >&2
                if hadoop fs -put -f "${local_checkpoint_path}/${file}" "${checkpoint_path}/"; then
                    echo "checkpoint uploaded: $file to ${checkpoint_path}/$file" >&2
                    rm -- "${local_checkpoint_path}/${file}"
                else
                    echo "checkpoint upload failed, keeping local copy: ${local_checkpoint_path}/$file" >&2
                fi
            fi
        done) &

    # Comma-separated list "valid_<lg>,..." covering every evaluated language.
    valid_subsets=""
    for lg in "${predicted_lgs[@]}"; do
        valid_subsets+="${valid_subsets:+,}valid_${lg}"
    done

    # Same command line as the train branch, minus the --reset-* flags.
    # $argslist is expanded unquoted on purpose: it carries the pass-through
    # command-line options as one whitespace-separated string.
    python3 fairseq/train.py "${local_dataset_path}" --ddp-backend=no_c10d \
        --save-dir "${local_checkpoint_path}" \
        --tensorboard-logdir "${local_tensorboard_path}" \
        --restore-file "${restore_file}" \
        --task summarization_from_pretrained_mbart_mspm4 \
        --arch mbart_summ_abs_large \
        --source-lang doc --target-lang sum \
        --langs "${langs}" \
        --dataset-impl mmap \
        --truncate-source \
        --encoder-normalize-before --decoder-normalize-before \
        --layernorm-embedding \
        --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
        --required-batch-size-multiple 1 \
        --dropout 0.1 --attention-dropout 0.1 \
        --weight-decay 0.01 --optimizer adam \
        --lr 3e-5 --min-lr -1 \
        --lr-scheduler polynomial_decay \
        --clip-norm 0.1 \
        --skip-invalid-size-inputs-valid-test \
        --find-unused-parameters \
        --num-workers 10  --adam-betas '(0.9, 0.999)' --adam-eps 1e-08 \
        --total-num-update 781250 --warmup-updates 15625 \
        --max-epoch 50 --fp16 \
        --log-format simple \
        --keep-best-checkpoints 3 \
        --no-epoch-checkpoints \
        --user-dir examples/summarization \
        --valid-subset "${valid_subsets}" \
        --best-checkpoint-metric "${EVAL_LG}_RL_F" \
        --eval-rouge --maximize-best-checkpoint-metric \
        --eval-rouge-args '{"beam": 5, "lenpen": 0.6, "max_len_a": 0.78, "max_len_b": 2, "min_len": 2, "no_repeat_ngram_size": 3}' \
        --eval-rouge-print-samples \
        --adapter-num-layer 2 --batch-size 8 \
        --log-interval 800 \
        --freeze-decoder --freeze-encoder \
        --loose-load --save-interval-updates 4000 \
        --max-sentences-valid 8 \
        --patience 12 \
        --component-config "${model_path}/component_config.json" \
        $argslist

    echo "Put ${local_tensorboard_path} to ${tensorboard_logdir}" >&2
    hadoop fs -put -f "${local_tensorboard_path}"/* "${tensorboard_logdir}/"
    # NOTE(review): presumably this grace period lets the background uploader
    # finish pushing the last checkpoints before the script exits - confirm.
    sleep 300

elif [ "$MODE" == "generate" ]; then
    # Extra positional argument of generate mode: suffix of the checkpoint
    # file name, i.e. checkpoint<suffix>.pt is evaluated (e.g. "_best").
    ckpt_suffix="$7"

    # Fetch the test split and dictionaries of the evaluation language unless
    # the split is already present locally. (The path previously carried a
    # stray '}', so this check could never succeed.)
    if [[ ! -e "${local_dataset_path}/test_${EVAL_LG}.doc-sum.doc.bin" ]]; then
        echo "Download resource from ${prefix}/Datasets/multilingual/data-bin/${DATAVER}/${EVAL_LG}/test.* to ${local_dataset_path}/" >&2
        hdfs dfs -get "${prefix}/Datasets/multilingual/data-bin/${DATAVER}/${EVAL_LG}/test.*" "${local_dataset_path}/"
        renameSplit "${local_dataset_path}" "test" "test_${EVAL_LG}"
        echo "Download resource from ${prefix}/Datasets/multilingual/data-bin/${DATAVER}/${EVAL_LG}/dict.* to ${local_dataset_path}/" >&2
        hdfs dfs -get "${prefix}/Datasets/multilingual/data-bin/${DATAVER}/${EVAL_LG}/dict.*" "${local_dataset_path}/"
    fi

    echo "Generating..."

    # Drop any stale local copy before fetching the checkpoint again
    # (-f: no complaint when there is nothing to remove on a first run).
    rm -rf -- "${local_checkpoint_path}/checkpoint${ckpt_suffix}.pt"
    hadoop fs -copyToLocal "${checkpoint_path}/checkpoint${ckpt_suffix}.pt" "${local_checkpoint_path}"
    echo "Load checkpoints from ${checkpoint_path}/checkpoint${ckpt_suffix}.pt to ${local_checkpoint_path}" >&2

    # File-name suffix derived from the pass-through options: dashes are
    # removed and every run of spaces becomes one underscore.
    suffix=$(printf '%s\n' "$argslist" | sed -e "s/-//g"  -e "s/  */_/g")
    gen_output="${local_tensorboard_path}/output${suffix}"
    hypo="${local_tensorboard_path}/test${suffix}.hypo"
    ref="${local_dataset_path}/test.y.${EVAL_LG}"

    # --model-overrides "{'component_config': '${model_path}/component_config.json', 'encoder_version': 'v1', 'arch': 'mbart_summ_abs_large'}" \
    # $argslist is expanded unquoted on purpose (whitespace-separated options).
    python3 fairseq/generate.py "${local_dataset_path}" \
        --path "${local_checkpoint_path}/checkpoint${ckpt_suffix}.pt" \
        --task summarization_from_pretrained_mbart_mspm4 \
        --gen-subset "test_${EVAL_LG}" \
        --source-lang doc --target-lang sum \
        --langs "${langs}" \
        --remove-bpe 'sentencepiece' \
        --max-len-a 0.78 --max-len-b 2 --min-len 2 \
        --lenpen 0.6 \
        --no-repeat-ngram-size 3 \
        --truncate-source \
        --user-dir examples/summarization \
        --batch-size 8 \
        $argslist > "${gen_output}"

    # Keep only the lines starting with "H", restore their order with a
    # version sort, keep the text from the third tab-separated field on, and
    # strip bracketed tags such as "[en_XX]" and "[xxx] ".
    grep -P "^H" "${gen_output}" |
        sort -V | cut -f 3- |
        sed -e "s/\[[a-z]\{2\}_[A-Z]\{2\}\]//g" |
        sed -e "s/\[[a-z]\{,10\}\] //g" > "${hypo}"

    echo "Load ground truth file from ${prefix}/Datasets/multilingual/xgiga/raw/test.y.${EVAL_LG}"
    hadoop fs -get "${prefix}/Datasets/multilingual/xgiga/raw/test.y.${EVAL_LG}" "${local_dataset_path}/"

    if [[ "${EVAL_LG}" == "zh" ]]; then
        # split the reference and hypothesis into chars
        python3 -u ./xnlg/zh_split_words.py < "${hypo}" > "${hypo}.char"
        python3 -u ./xnlg/zh_split_words.py < "${ref}" > "${ref}.char"
        hypo="${hypo}.char"
        ref="${ref}.char"
        python3 ./xnlg/calc_rouge.py --ref "${ref}" --hyp "${hypo}" --zh True
    else
        python3 ./xnlg/calc_rouge.py --ref "${ref}" --hyp "${hypo}"
    fi

    echo "Put ${local_tensorboard_path} to ${tensorboard_logdir}" >&2
    hadoop fs -put -f "${local_tensorboard_path}"/* "${tensorboard_logdir}/"
    sleep 120
elif [ "$MODE" == "dump_embedding" ]; then
    #######################################
    # Encode a raw text file with the sentencepiece model shipped in the
    # mBART archive, fetching and unpacking the archive first when the model
    # directory is missing.
    # Arguments:
    #   $1 - input text file
    #   $2 - output file for the encoded text
    #   $3 - dictionary path; accepted for call compatibility, not used
    # Outputs:
    #   Progress messages and the first encoded line on stdout.
    #######################################
    tokenize(){
        local base_dir input output mbart model
        base_dir=$(pwd)

        input=$1
        output=$2

        echo "tokenize ${input} to ${output} using mbart's spm..."
        # setup MBART
        mbart=/home/tiger/mbart.cc25.v2
        if [[ ! -d "${mbart}" ]]; then
            # NOTE(review): the archive is spelled "mbart.CC25.v2.tar.gz" here
            # while the directory checked above (and the archive name used
            # elsewhere in this script) is lower-case "cc25" - confirm which
            # name exists on HDFS.
            hadoop fs -copyToLocal hdfs://haruna/home/byte_arnold_hl_mlnlc/user/wuxianze.0/Workspace/Multilingual/pretrained/mbart.CC25.v2.tar.gz /home/tiger
            tar -xvzf /home/tiger/mbart.CC25.v2.tar.gz -C /home/tiger
        fi
        model="${mbart}/sentence.bpe.model"
        python3 "${base_dir}/fairseq/scripts/spm_encode.py" --model="${model}" < "${input}" > "${output}"

        echo "first line of ${output}"
        head -n 1 "${output}"
    }

    # INPATH, TEST_LG_TAG and local_embedding_path are not assigned anywhere
    # in this script, so they must come from the environment. Fail with a
    # clear message instead of producing broken redirects / option lists.
    : "${INPATH:?INPATH must be set for MODE=dump_embedding}"
    : "${TEST_LG_TAG:?TEST_LG_TAG must be set for MODE=dump_embedding}"
    : "${local_embedding_path:?local_embedding_path must be set for MODE=dump_embedding}"

    # tokenize raw text, then append " ." to every encoded line
    tokenize "${INPATH}" "${INPATH}.spm" "${local_dataset_path}/dict.doc.txt"
    sed -e "s/$/ ./" "${INPATH}.spm" > "${INPATH}.spm.special"

    echo "Start dumping embedding..."
    # Derived from the pass-through options (dashes removed, runs of spaces
    # turned into underscores); not referenced again within this branch.
    suffix=$(printf '%s\n' "$argslist" | sed -e "s/-//g"  -e "s/  */_/g")

    # $argslist is expanded unquoted on purpose (whitespace-separated options).
    python3 fairseq/interactive_dump_embedding.py "${local_dataset_path}"  \
        --path "${local_checkpoint_path}/checkpoint_best.pt" \
        --task multi_task_from_pretrained_mbart_mspm4 \
        --gen-subset test \
        --source-lang doc --target-lang sum \
        --langs "${langs}" \
        --results-path "${local_embedding_path}" \
        --remove-bpe 'sentencepiece'  \
        --truncate-source \
        --prefix-tokens "${TEST_LG_TAG}" --doc-lang "${TEST_LG_TAG}" \
        --user-dir examples/summarization \
        --batch-size 1 $argslist \
        --input "${INPATH}.spm.special"
fi
