#!/bin/bash
#
# Train, test, score or checkpoint-average (--t=train|test|score|average) a
# standard Transformer model to reproduce the results of Wang et al. (2019)
# on IWSLT16 fr-en.

# Read script arguments and assign them to variables: every `--name=value`
# argument becomes the shell variable `name` holding `value`. Example:
# sh/run_transformer_iwslt16_fr.sh --t=train --src=en --tgt=fr --sdir=en2fr_iwslt16_wmt14_han/k0 --datadir=wmt14
for argument in "$@"; do
    # Split on the FIRST '=' only, so values may themselves contain '='.
    key=${argument%%=*}
    value=${argument#*=}
    # Only arguments that start with '--' are treated as options.
    if [[ $key == --* ]]; then
        name=${key#--}
        # `declare` only accepts valid shell identifiers; report anything else
        # instead of letting `declare` fail with a cryptic message.
        if [[ $name =~ ^[A-Za-z_][A-Za-z0-9_]*$ ]]; then
            declare "$name=$value"
        else
            printf 'Ignoring argument with invalid name: %s\n' "$argument" >&2
        fi
    fi
done

# Set variables: derive the run settings from the parsed arguments, filling
# in defaults for anything that was not given.

# Restrict the visible GPUs only when --cuda was given.
if [[ -n "${cuda:-}" ]]; then
    export CUDA_VISIBLE_DEVICES=$cuda
fi
# Random seed; 0 unless --seed was given.
seed=${seed:-0}
# Dataset directory; --datadir selects the subdirectory (default: standard).
data_dir=data/data-bin/iwslt16.dnmt.fr-en/${datadir:-standard}
# Checkpoint directories always live under checkpoints/; prepend it if missing.
if [[ ${sdir:-} != checkpoints/* ]]; then
    sdir=checkpoints/${sdir:-}
fi

# Dispatch on the requested mode (--t=train|test|score|average).
# All modes read from and write into the checkpoint directory $sdir.

# Make a pipeline report the failure of any stage, not just the trailing
# `tee`; otherwise a crash in fairseq would still exit with status 0.
set -o pipefail

# `average` writes this checkpoint; `test` decodes with it.
avg_n_epochs=5
avg_checkpoint=checkpoint.avg$avg_n_epochs.pt

case "${t:-}" in
    train)
        # Train; the log is appended to, so repeated runs keep earlier output.
        mkdir -p "$sdir/logs"
        fairseq-train "$data_dir" \
            --save-dir "$sdir" \
            --seed "$seed" \
            --source-lang "$src" \
            --target-lang "$tgt" \
            --arch transformer_iwslt_fr_en \
            --optimizer adam \
            --lr-scheduler inverse_sqrt \
            --lr 1e-4 \
            --warmup-init-lr 1e-7 \
            --warmup-updates 4000 \
            --patience 5 \
            --max-tokens 2048 \
            --log-format json \
            | tee -a "$sdir/logs/train.log"
        ;;
    test)
        # Decode with the averaged checkpoint (create it with --t=average).
        # logs/ may not exist yet if no training was run in this directory.
        # NOTE(review): --temperature 1.3 is kept from the original setup;
        # confirm it matches the configuration being reproduced.
        mkdir -p "$sdir/logs"
        fairseq-generate "$data_dir" \
            --task translation \
            --source-lang "$src" \
            --target-lang "$tgt" \
            --path "$sdir/$avg_checkpoint" \
            --batch-size 64 \
            --remove-bpe \
            --beam 4 \
            --lenpen 1 \
            --temperature 1.3 \
            --num-workers 8 \
            | tee "$sdir/logs/test.log"
        ;;
    score)
        # Split the decoding log into source / reference / hypothesis files,
        # one sentence per line ordered by sentence id, then score them.
        test_log=$sdir/logs/test.log
        if [[ ! -f "$test_log" ]]; then
            printf 'No decoding log found at %s; run --t=test first.\n' "$test_log" >&2
            exit 1
        fi
        # After stripping the prefix, field 1 is the sentence id. H- lines carry
        # one extra field (presumably the hypothesis score) before the text.
        sed -n 's/^S-//p' "$test_log" | sort -nk 1 | cut -f2- > "$sdir/logs/gen.out.src"
        sed -n 's/^T-//p' "$test_log" | sort -nk 1 | cut -f2- > "$sdir/logs/gen.out.ref"
        sed -n 's/^H-//p' "$test_log" | sort -nk 1 | cut -f3- > "$sdir/logs/gen.out.sys"
        fairseq-score \
            --sentence-bleu \
            --sys "$sdir/logs/gen.out.sys" \
            --ref "$sdir/logs/gen.out.ref" \
            | tee "$sdir/logs/score.log"
        ;;
    average)
        # Average epoch checkpoints from $sdir into $avg_checkpoint; presumably
        # the last $avg_n_epochs epochs up to epoch 115 (see the script).
        python scripts/average_checkpoints.py \
            --inputs "$sdir" \
            --num-epoch-checkpoints "$avg_n_epochs" \
            --checkpoint-upper-bound=115 \
            --output "$sdir/$avg_checkpoint"
        ;;
    *)
        printf '%s\n' "Argument is not valid. Type 'train', 'test', 'score' or 'average'." >&2
        exit 2
        ;;
esac

