#!/bin/bash
#
# bash sh/run/en-de/iwslt17/transfo_base.sh --t=finetune --sdir=standard/k0 --pretrained=checkpoints/en-de/wmt17/transfo_base/checkpoint.avg_last10.pt


# Read script arguments of the form --name=value and turn each one into a
# shell variable called "name" holding "value".
#
# * Only the FIRST "=" separates name from value, so values may themselves
#   contain "=" (paths, JSON-ish overrides, ...).
# * An argument without "=" yields a variable whose value is the argument
#   itself (e.g. "--flag" sets flag="--flag").
# * Arguments whose name part does not contain "--" are ignored.
# * Parameter expansion is used instead of echo|cut so the argument is never
#   word-split or glob-expanded, and no subshells are forked.
for argument in "$@"; do
    key=${argument%%=*}     # everything before the first "="
    value=${argument#*=}    # everything after the first "="
    if [[ $key == *"--"* ]]; then
        v="${key/--/}"
        # Quoted as a single word so the value is assigned verbatim.
        declare "$v=$value"
    fi
done

# Fixed experiment configuration. The language pair, task and architecture
# are hard-coded here; datadir, lenpen and sdir may have been supplied on the
# command line and are only given defaults / normalised below.
src=en
tgt=de
lang="${src}-${tgt}"
task=translation
architecture=transformer_vaswani_wmt_en_fr

# Path of this script as used for the recursive "bash $script ..." calls.
script="sh/run/${lang}/iwslt17/transfo_base.sh"
test_suites="data/${lang}/data-bin/wmt17/test_suites"
contrapro=data/en-de/test_suites/ContraPro

# Keep a caller-provided value, otherwise fall back to the default.
datadir="${datadir:-data/${lang}/data-bin/iwslt17/standard}"
lenpen="${lenpen:-0.6}"

# A save dir that does not already start with "checkpoints/" is taken to be
# relative to this experiment's checkpoint root.
case "$sdir" in
    checkpoints/*) ;;
    *) sdir="checkpoints/${lang}/iwslt17/${sdir}" ;;
esac

# Runtime settings. Values that may come from the command line keep the
# caller's value when non-empty and otherwise receive the default shown.
num_workers=8
n_best_checkpoints=5

# Checkpoint used by the generation modes.
checkpoint_path="${sdir}/checkpoint_best.pt"
# checkpoint_path=$sdir/checkpoint.avg_last$n_best_checkpoints.pt

# Restrict visible GPUs only when --cuda was given.
[ -z "$cuda" ] || export CUDA_VISIBLE_DEVICES=$cuda

seed="${seed:-0}"
pretrained="${pretrained:-None}"
restore="${restore:-checkpoint_last.pt}"
testlog="${testlog:-test}"
# The default is quoted so its closing brace does not end the expansion.
mover="${mover:-"{}"}"
# --mt and --uf are the short command-line spellings of these two settings.
maxtok="${mt:-16000}"
updatefreq="${uf:-1}"

#######################################
# Rebuild source / reference / hypothesis text files from a fairseq-generate
# log and score the hypotheses with multi-bleu-detok.
# Lines are selected by their S / T / H prefix, restored to numeric sentence
# order, stripped of the leading id column (and, for H lines, of the second
# column, which the test-suites mode reads as the score), then detokenized
# with sacremoses.
# Globals:  sdir, testlog (read)
# Outputs:  $sdir/logs/$testlog.out.{src,ref,sys} and $sdir/logs/$testlog.score;
#           the score is also echoed to stdout.
#######################################
score_detok_bleu() {
    local prefix="$sdir/logs/$testlog"
    grep ^S "${prefix}.log" | sed 's/^S-//g' | sort -nk 1 | cut -f2- | sacremoses detokenize > "${prefix}.out.src"
    grep ^T "${prefix}.log" | sed 's/^T-//g' | sort -nk 1 | cut -f2- | sacremoses detokenize > "${prefix}.out.ref"
    grep ^H "${prefix}.log" | sed 's/^H-//g' | sort -nk 1 | cut -f3- | sacremoses detokenize > "${prefix}.out.sys"
    tools/mosesdecoder/scripts/generic/multi-bleu-detok.perl "${prefix}.out.ref" < "${prefix}.out.sys" | tee "${prefix}.score"
}

#######################################
# Run the action selected with --t=<mode>.
# Modes: finetune, boom, results, test, score, score-split, score-ref,
#        average, test-suites.
# Globals:  reads the configuration variables set earlier in this script
#           (t, sdir, datadir, src, tgt, task, ...); the boom and test-suites
#           modes overwrite datadir.
# Returns:  status of the last command of the selected mode; 1 (with a
#           message on stderr) when t is missing or not a known mode.
#######################################
run_mode() {
    # "$t" is quoted so that a missing or multi-word --t simply falls through
    # to the error branch instead of producing test-syntax errors.
    case "$t" in
    ###########################################################################
    finetune)
        # Fine-tune from $pretrained; the JSON training log is appended to
        # $sdir/logs/train.log.
        mkdir -p "$sdir/logs"
        fairseq-train "$datadir" \
            --save-dir "$sdir" \
            --seed "$seed" \
            --source-lang "$src" \
            --target-lang "$tgt" \
            --num-workers "$num_workers" \
            --finetune-from-model "$pretrained" \
            --task "$task" \
            --arch "$architecture" \
            --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
            --lr-scheduler inverse_sqrt --warmup-updates 4000 --min-lr 1e-09 \
            --lr 0.0005 --warmup-init-lr 1e-07 \
            --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
            --max-tokens "$maxtok" \
            --update-freq "$updatefreq" \
            --patience 5 \
            --keep-best-checkpoints "$n_best_checkpoints" \
            --no-epoch-checkpoints \
            --log-format json \
            | tee -a "$sdir/logs/train.log"
        ;;
    ###########################################################################
    boom)
        # # BLEU on PRO (normal and shuffled)
        # for s in "" ".shuffled"; do
        #     datadir=$test_suites/large_pronoun_testset/k3$s
        #     d=large_pronoun_testset$s
        #     # score reference
        #     bash $script --t=test --sdir=$sdir --testlog=$d --cuda=$cuda --datadir=$datadir --lenpen=$lenpen
        # done
        # # BLEU on test set
        # bash $script --t=test --sdir=$sdir --cuda=$cuda --lenpen=$lenpen
        # BLEU on shuffled test set
        datadir=data/$lang/data-bin/iwslt17/test_shuffled
        bash "$script" --t=test "--sdir=$sdir" --testlog=test.shuffled "--cuda=$cuda" "--datadir=$datadir" "--lenpen=$lenpen"
        ;;
    ###########################################################################
    results)
        # Print the stored score files: test, test.shuffled,
        # large_pronoun_testset, large_pronoun_testset.shuffled (in this order).
        for base in test large_pronoun_testset; do
            for s in "" ".shuffled"; do
                d=$base$s
                echo "RESULTS FOR $sdir/logs/$d.score"
                echo ""
                cat "$sdir/logs/$d.score"
                echo "-----------------------------------"
                echo ""
            done
        done
        ;;
    ###########################################################################
    test)
        # The log directory may not exist yet (e.g. a checkpoint directory
        # that was not produced by the finetune mode); tee needs it.
        mkdir -p "$sdir/logs"
        fairseq-generate "$datadir" \
            --task "$task" \
            --source-lang "$src" \
            --target-lang "$tgt" \
            --path "$checkpoint_path" \
            --batch-size 128 \
            --remove-bpe \
            --beam 4 \
            --lenpen "$lenpen" \
            --temperature 1 \
            --num-workers "$num_workers" \
            | tee "$sdir/logs/$testlog.log"
        # score with multi-bleu-detok on sacremoses-detokenized text
        score_detok_bleu
        ;;
    ###########################################################################
    score)
        # Re-score an existing generation log without decoding again.
        score_detok_bleu
        ;;
    ###########################################################################
    score-split)
        # Like "score", but instead of detokenizing, every two consecutive
        # lines are joined with a space ("paste - -") and the result is scored
        # with fairseq-score.
        grep ^S "$sdir/logs/$testlog.log" | sed 's/^S-//g' | sort -nk 1 | cut -f2- | paste -d " "  - - > "$sdir/logs/$testlog.out.src"
        grep ^T "$sdir/logs/$testlog.log" | sed 's/^T-//g' | sort -nk 1 | cut -f2- | paste -d " "  - - > "$sdir/logs/$testlog.out.ref"
        grep ^H "$sdir/logs/$testlog.log" | sed 's/^H-//g' | sort -nk 1 | cut -f3- | paste -d " "  - - > "$sdir/logs/$testlog.out.sys"
        fairseq-score \
            --sys "$sdir/logs/$testlog.out.sys" \
            --ref "$sdir/logs/$testlog.out.ref" \
            | tee "$sdir/logs/$testlog.score"
        ;;
    ###########################################################################
    score-ref)
        # Score the references (no free decoding) and keep the raw log.
        mkdir -p "$sdir/logs"
        # $mover is quoted so an override string containing spaces reaches
        # fairseq-generate as one argument.
        fairseq-generate "$datadir" \
            --task "$task" \
            --source-lang "$src" \
            --target-lang "$tgt" \
            --path "$checkpoint_path" \
            --model-overrides "$mover" \
            --score-reference \
            --batch-size 64 \
            --remove-bpe \
            --num-workers "$num_workers" \
            | tee "$sdir/logs/$testlog.log"
        ;;
    ###########################################################################
    average)
        # Average all checkpoint.best_* files found in the save dir; the glob
        # is deliberately left outside the quotes so it still expands.
        python scripts/average_checkpoints.py \
            --inputs "$sdir"/checkpoint.best_* \
            --output "$sdir/checkpoint.$n_best_checkpoints.best.average.pt"
        ;;
    ###########################################################################
    test-suites)
        # Evaluate on ContraPro.
        # NOTE(review): only the ".shuffled" variant is scored here, while the
        # summary loop below prints both variants; the unshuffled .result file
        # has to exist from an earlier run - confirm this is intended.
        for s in ".shuffled"; do
            datadir=$test_suites/large_pronoun/k3$s
            d=large_pronoun$s
            # score reference
            bash "$script" --t=score-ref "--src=$src" "--tgt=$tgt" "--sdir=$sdir" "--datadir=$datadir" "--testlog=$d" "--cuda=$cuda"
            # evaluate
            echo "extract scores..."
            grep ^H "$sdir/logs/$d.log" | sed 's/^H-//g' | sort -nk 1 | cut -f2 > "$sdir/logs/$d.full_score"
            # Keep every 4th score only (presumably the current sentence after
            # three context sentences in the k3 data - TODO confirm).
            awk 'NR % 4 == 0' "$sdir/logs/$d.full_score" > "$sdir/logs/$d.score"
            echo "evaluate model performance on test-suite by comparing scores..."
            python3 "$contrapro/evaluate.py" --reference "$contrapro/contrapro.json" --scores "$sdir/logs/$d.score" --maximize > "$sdir/logs/$d.result"
        done
        echo "-----------------------------------"
        echo ""
        # print results
        for s in "" ".shuffled"; do
            d=large_pronoun$s
            echo "Results for $d"
            echo "file: $sdir/logs/$d.result"
            grep total "$sdir/logs/$d.result"
            echo "-----------------------------------"
            echo ""
        done
        ;;
    ###########################################################################
    *)
        # Unknown or missing mode: report on stderr and signal failure so
        # calling scripts can detect the mistake.
        echo "Argument t is not valid." >&2
        return 1
        ;;
    esac
}

run_mode