#!/usr/bin/env bash

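# Usage:   bash xgiga_zmbart_dump_embedding_w_text.sh DATAVER TEST_LG_TAG EXTEND_DICT INPATH EMBDIR [extra args...]
# Example: bash xgiga_zmbart_dump_embedding_w_text.sh mspm4_xgiga zh_CN None /opt/tiger/sumtest/tmpdata/test.x.zh test_x_zh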

# Positional arguments:
#   $1  DATAVER      dataset version (selects data-bin/<DATAVER>/... on HDFS)
#   $2  TEST_LG_TAG  language tag given to --prefix-tokens and --doc-lang
#   $3  EXTEND_DICT  name of an extra dictionary to append, or the literal "None"
#   $4  INPATH       raw text file to tokenize and embed
#   $5  EMBDIR       sub-directory under .../embedding/ that receives the results
#   $6+              extra options appended to the dumper's command line
DATAVER=${1}
TEST_LG_TAG=${2}
EXTEND_DICT=${3}
INPATH=${4}
EMBDIR=${5}

# The dataset sub-directory is fixed to "en"; only TEST_LG_TAG varies per run.
TEST_LG="en"

NAME="zmbart"
# Space-join every argument from $6 onward, each one padded as " <arg> ".
argslist=""
for extra_arg in "${@:6}"; do
    argslist+=" ${extra_arg} "
done
echo ${argslist} >&2

#######################################
# Encode a text file with mBART's sentencepiece BPE model.
# On first use, fetches mbart.CC25.tar.gz from HDFS into /home/tiger and
# unpacks it there (the code assumes this yields /home/tiger/mbart.cc25).
# An already-downloaded tarball is reused rather than fetched again.
# Arguments:
#   $1 - input text file
#   $2 - output file that receives the encoded text
#   $3 - dictionary path; accepted for call-site compatibility, not used
# Outputs:
#   Progress messages and the first encoded line on stdout; errors on stderr.
# Returns:
#   0 on success; 1 if the model could not be fetched/unpacked or encoding failed.
#######################################
tokenize() {
    local base_dir input output mbart tarball model
    base_dir=$(pwd)
    input=$1
    output=$2

    echo "tokenize ${input} to ${output} using mbart's spm..."
    # setup MBART: download and unpack once, then reuse the extracted directory.
    mbart=/home/tiger/mbart.cc25
    tarball=/home/tiger/mbart.CC25.tar.gz
    if [[ ! -d "$mbart" ]]; then
        if [[ ! -f "$tarball" ]]; then
            if ! hadoop fs -copyToLocal hdfs://haruna/home/byte_arnold_hl_mlnlc/user/wuxianze.0/Workspace/Multilingual/pretrained/mbart.CC25.tar.gz /home/tiger; then
                echo "[ERROR] failed to download mbart.CC25.tar.gz" >&2
                return 1
            fi
        fi
        if ! tar -xvzf "$tarball" -C /home/tiger; then
            echo "[ERROR] failed to unpack ${tarball}" >&2
            return 1
        fi
    fi

    model=${mbart}/sentence.bpe.model
    if ! python3 "${base_dir}/fairseq/scripts/spm_encode.py" --model="$model" < "$input" > "$output"; then
        echo "[ERROR] spm_encode failed on ${input}" >&2
        return 1
    fi

    echo "first line of ${output}"
    head -n 1 "$output"
}

# Refuse to start the slow install/download steps when the input-file argument
# ($4, INPATH) is missing; the tokenize/dump steps at the end need it.
if [[ -z "${INPATH}" ]]; then
    echo "Usage: bash $(basename "$0") DATAVER TEST_LG_TAG EXTEND_DICT INPATH EMBDIR [extra args...]" >&2
    exit 2
fi

# INPATH is only read after the cd below; anchor a relative path to the
# caller's working directory so it still refers to the file the caller meant.
if [[ "${INPATH}" != /* ]]; then
    INPATH="${PWD}/${INPATH}"
fi

# Relative paths used later (fairseq/, RELEASE-1.5.5, examples/...) are resolved
# against the script's own directory. `return` is not valid at the top level of
# an executed script (bash just prints an error and continues), so use `exit`.
cd "$(dirname "$0")" || exit 1
echo "Install fairseq" >&2

sudo mkdir -p /usr/lib/python3.7/site-packages/
sudo pip3 install -e fairseq
# pip3 install -r requirements.txt -i http://pypi.byted.org/simple/ --trusted-host=pypi.byted.org
pip3 install torch==1.8.0 -i http://pypi.byted.org/simple/ --trusted-host=pypi.byted.org

# Register the bundled ROUGE-1.5.5 directory with pyrouge and make its script executable.
export PYROUGE_HOME_DIR="${PWD}/RELEASE-1.5.5"
export PYROUGE_TEMP_PATH=/opt/tiger

pyrouge_set_rouge_path "$PYROUGE_HOME_DIR"
chmod +x "$PYROUGE_HOME_DIR/ROUGE-1.5.5.pl"

# HDFS locations of the preprocessed dataset and the pretrained artifacts.
prefix=hdfs://haruna/home/byte_arnold_hl_mlnlc/user/wuxianze.0

dataset_path=${prefix}/Datasets/multilingual/data-bin/${DATAVER}/${TEST_LG}
pretrained_path=${prefix}/Workspace/Multilingual/pretrained

# Local working tree: downloads, model, outputs and embeddings all live here.
local_root=~/xgiga_dumpEmb_ZmBART_${NAME}
resource_root=${local_root}/resource
output_path=${local_root}/output
model_path=${local_root}/model
local_embedding_path=${local_root}/embedding/${EMBDIR}
local_dataset_path=${resource_root}/dataset
mkdir -p "${resource_root}" "${output_path}" "${model_path}" \
    "${local_embedding_path}" "${local_dataset_path}"

echo "Download resource from ${dataset_path} to ${local_dataset_path}" >&2
# The glob is quoted so that hadoop, not the local shell, expands it on HDFS.
hadoop fs -copyToLocal "${dataset_path}/*" "${local_dataset_path}"

# Replace the dataset's own dict.*.txt with the pretrained dict_extend.txt.
echo "remove ${local_dataset_path}/dict.*.txt"
# -f: a missing match is not an error (e.g. dataset shipped without dictionaries).
rm -f "${local_dataset_path}"/dict.*.txt
hdfs dfs -get "${pretrained_path}/dict_extend.txt" "${local_dataset_path}"
echo "treat ${local_dataset_path}/dict_extend.txt as ${local_dataset_path}/dict.*.txt"
if [[ -f "${local_dataset_path}/dict_extend.txt" ]]; then
    cp "${local_dataset_path}/dict_extend.txt" "${local_dataset_path}/dict.doc.txt"
    cp "${local_dataset_path}/dict_extend.txt" "${local_dataset_path}/dict.sum.txt"
else
    echo "[ERROR] Load dictionary dict_extend.txt failed!" >&2
fi

local_pretrained_path=${model_path}/ZmBART
# Test for the checkpoint file itself, not just its directory, so that a
# previously failed download is retried instead of silently skipped.
if [[ ! -f "${local_pretrained_path}/checkpoint_ZmBART.pt" ]]; then
    echo "Load pretrained model from ${pretrained_path}/checkpoint_ZmBART.pt to ${local_pretrained_path}/checkpoint_ZmBART.pt" >&2
    mkdir -p "${local_pretrained_path}"
    hdfs dfs -get "${pretrained_path}/checkpoint_ZmBART.pt" "${local_pretrained_path}/"
else
    echo "Pretrained model in ${local_pretrained_path}" >&2
fi

# Optionally append an extra dictionary to both dict.doc.txt and dict.sum.txt.
if [[ "${EXTEND_DICT}" != "None" ]]; then
    echo "Load extend dictionary from ${pretrained_path}/${EXTEND_DICT}.txt to ${local_dataset_path}" >&2
    hdfs dfs -get "${pretrained_path}/${EXTEND_DICT}.txt" "${local_dataset_path}"
    # Verify the download before appending, rather than after the file was used.
    if [[ -e "${local_dataset_path}/${EXTEND_DICT}.txt" ]]; then
        echo "write the extended dictionary into ${local_dataset_path}/dict.*.txt"
        cat "${local_dataset_path}/${EXTEND_DICT}.txt" >> "${local_dataset_path}/dict.doc.txt"
        cat "${local_dataset_path}/${EXTEND_DICT}.txt" >> "${local_dataset_path}/dict.sum.txt"
    else
        echo "[ERROR] Load extend dictionary ${EXTEND_DICT}.txt failed!" >&2
    fi
fi

# Tokenize the raw input text with mBART's sentencepiece model; stop early on
# failure instead of dumping embeddings for a missing or partial file.
if ! tokenize "$INPATH" "$INPATH.spm" "${local_dataset_path}/dict.doc.txt"; then
    echo "[ERROR] tokenization of ${INPATH} failed" >&2
    exit 1
fi
# Append " ." to every encoded line to build the ".special" input variant.
sed -e 's/$/ ./' "$INPATH.spm" > "$INPATH.spm.special"

echo "Finish download files" >&2
# Language codes for --langs (presumably the 25 mBART cc25 languages — confirm).
langs=ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN

echo "Start dumping embedding..."
# NOTE(review): suffix is not referenced anywhere in this script; confirm whether it is still needed.
suffix=$(echo "$argslist" | sed -e "s/-//g"  -e "s/  */_/g")

# Only the first 1000 encoded lines are embedded.
head -n 1000 "$INPATH.spm.special" > "$INPATH.spm.special.first1000"

# $argslist is deliberately left unquoted: it is a space-joined string of the
# extra CLI arguments and must be word-split into separate options.
# shellcheck disable=SC2086
python3 fairseq/interactive_dump_embedding.py "${local_dataset_path}" \
    --path "${local_pretrained_path}/checkpoint_ZmBART.pt" \
    --task multi_task_from_pretrained_mbart_mspm4 \
    --gen-subset test \
    --source-lang doc --target-lang sum \
    --langs "$langs" \
    --results-path "${local_embedding_path}" \
    --remove-bpe 'sentencepiece' \
    --truncate-source \
    --prefix-tokens "${TEST_LG_TAG}" --doc-lang "${TEST_LG_TAG}" \
    --user-dir examples/summarization \
    --batch-size 1 $argslist \
    --input "$INPATH.spm.special.first1000" \
    --model-overrides "{'layernorm_embedding': True}"
# sleep 600
