#!/usr/bin/env bash
# Build a merged mono-lingual + translated corpus and binarize it with fairseq.
#
# Usage:
#   bash mlsum_predata_trans_mono.sh DATADIR MODE [DIVISIONS]
# Example:
#   bash mlsum_predata_trans_mono.sh /mnt/bd/lab-wxz/clt/mlsum/en MSPM
#
# Arguments:
#   DATADIR   - directory holding the raw <split>.<lang>.doc / .sum files
#   MODE      - tokenization mode; "MSPM" triggers the SentencePiece encoding
#   DIVISIONS - optional; "test" processes the test split, anything else
#               (or nothing) processes train and dev
#
# -e aborts on the first failing command; pipefail extends that to pipelines,
# so a failing encoder is no longer hidden by a succeeding `sed` behind it.
# -u is deliberately left off: DICT is only assigned in MSPM mode but is
# expanded unconditionally further down.
set -eo pipefail

DATADIR="${1:-}"
MODE="${2:-}"
DIVISIONS="${3:-}"

# Name of the merged corpus and the languages it is made of.
MERGED_NAME="en_fr_mono_retran"
SUMM_LANG="en"
TRAN_LANG="fr"
# Directory holding the translated (en -> fr) files.
TRANDIR="/mnt/bd/lab-wxz/clt/mlsum/en2fr"

# Language code -> language tag that is prepended to every encoded line.
declare -A LANG_TAGS=([en]="en_XX" [fr]="fr_XX" [zh]="zh_CN")

# The raw data directory is mandatory. Report the problem on stderr and leave
# with a non-zero status so that callers can tell the run did not happen
# (a bare `exit` here would report success).
if [[ -z "$DATADIR" ]]; then
    echo "Lose the raw data dir!" >&2
    exit 1
fi
# Pick the dataset splits to process: only "test" when DIVISIONS is exactly
# "test", otherwise (including when it is empty) train and dev.
case "$DIVISIONS" in
    test) PARTS=(test) ;;
    *) PARTS=(train dev) ;;
esac
echo "Parts are : ${PARTS[*]}"

# DATA is where the raw <split>.<lang>.doc / .sum files are read from;
# TOKEN is where the tokenized files for this MODE and merged corpus go.
DATA="$DATADIR"
TOKEN="$DATADIR/$MODE/${MERGED_NAME}"
# Create the output directory (including parents) unless it already exists.
[[ -d "$TOKEN" ]] || mkdir -p "$TOKEN"

# Add mBART sentence/language markers to already-encoded text.
# Reads lines from stdin and writes them to stdout. On every line each
# "< q >" separator becomes " </s>", the line is wrapped as "<s> ... </s>",
# and finally "[TAG] " is prepended.
# Arguments: $1 - language tag, e.g. en_XX
add_mbart_markers() {
    local tag="$1"
    sed -e "s/< q >/ <\/s>/g" -e "s/^/<s> /" -e "s/$/ <\/s>/" -e "s/^/[${tag}] /"
}

# Encode one text file with spm_encode.py and add the mBART markers.
# Arguments: $1 - SentencePiece model file
#            $2 - language tag, e.g. en_XX
#            $3 - input text file
#            $4 - output file (overwritten)
# Returns:   non-zero if the input cannot be read or any pipeline stage fails.
# The body is a subshell (parentheses, not braces) so pipefail stays scoped to
# this pipeline; without it a failing encoder is hidden by sed's zero exit
# status and an empty or truncated output file is produced silently.
spm_encode_tagged() (
    set -o pipefail
    python3 fairseq/scripts/spm_encode.py --model="$1" < "$3" \
        | add_mbart_markers "$2" > "$4"
)

if [[ "$MODE" == "MSPM" ]]; then
    MBART=/home/tiger/mbart.cc25.v2
    MODEL="$MBART/sentence.bpe.model"
    DICT="$MBART/dict_extend.txt"

    # Output directory for the encoded translated files (loop-invariant).
    TRANTOKEN="$TRANDIR/$MODE/${MERGED_NAME}"
    mkdir -p "$TRANTOKEN"

    echo "MSPM encoding for dataset..."
    for SPLIT in "${PARTS[@]}"; do
        MERGED_PREFIX="$TOKEN/$SPLIT.${MERGED_NAME}.spm"
        echo "$MERGED_PREFIX.doc"
        # The merged files are built by appending, so always start from a clean
        # slate. Removing unconditionally also covers a leftover .sum without a
        # matching .doc, which would otherwise receive duplicate lines.
        rm -f -- "$MERGED_PREFIX".*

        # 1) Original data in the summarization language.
        LG="${SUMM_LANG}"
        TAG="${LANG_TAGS[$LG]:?no language tag defined for $LG}"
        for EXT in doc sum; do
            SRC="$DATA/$SPLIT.$LG.$EXT"
            DST="$TOKEN/$SPLIT.$LG.spm.$EXT"
            echo "  encoding $SRC to $DST ..."
            spm_encode_tagged "$MODEL" "$TAG" "$SRC" "$DST"
            cat "$DST" >> "$MERGED_PREFIX.$EXT"
        done

        # 2) Translated data, appended after the original-language lines.
        # NOTE(review): the input is always the *train* translation
        # (train.en2.<lang>.*.noempty), whatever SPLIT is being processed, so
        # dev/test also receive the translated train set. Kept as in the
        # original script; confirm that this is intended.
        LG="${TRAN_LANG}"
        TAG="${LANG_TAGS[$LG]:?no language tag defined for $LG}"
        for EXT in doc sum; do
            SRC="$TRANDIR/train.en2.$LG.$EXT.noempty"
            DST="$TRANTOKEN/$SPLIT.$LG.spm.$EXT"
            echo "  encoding $SRC to $DST ..."
            spm_encode_tagged "$MODEL" "$TAG" "$SRC" "$DST"
            cat "$DST" >> "$MERGED_PREFIX.$EXT"
        done
    done
fi

# Binarize the merged, tokenized files with fairseq/preprocess.py.
echo "Generating data-bin for dataset..."
INPUT="$TOKEN"
BINDIR="$DATADIR/data-bin/$MODE/${MERGED_NAME}"

# File-name infix of the tokenized files: "spm" for SPM and MSPM, else "bpe".
case "$MODE" in
    SPM | MSPM) TOKENTYPE="spm" ;;
    *) TOKENTYPE="bpe" ;;
esac

echo "Binarized $INPUT ($TOKENTYPE) to $BINDIR with dict $DICT"

# Assemble the command line piece by piece; the argument order of each variant
# is the same as when the two invocations were spelled out separately.
PREPROCESS_CMD=(python3 fairseq/preprocess.py --source-lang doc --target-lang sum)
if [ "${#PARTS[@]}" -ne 2 ]; then
    # Single split: binarize the test set.
    PREPROCESS_CMD+=(
        --testpref "$INPUT/test.${MERGED_NAME}.$TOKENTYPE"
        --destdir "$BINDIR"
        --thresholdtgt 0
        --thresholdsrc 0
    )
else
    # Two splits: binarize train and dev (as validation) together.
    PREPROCESS_CMD+=(
        --trainpref "$INPUT/train.${MERGED_NAME}.$TOKENTYPE"
        --validpref "$INPUT/dev.${MERGED_NAME}.$TOKENTYPE"
        --destdir "$BINDIR"
    )
fi
PREPROCESS_CMD+=(--srcdict "$DICT" --tgtdict "$DICT" --workers 30)
"${PREPROCESS_CMD[@]}"