#!/bin/bash
#
# Train, test and score a Hierarchical Attention Network Transformer model
# (Miculicich et al., 2018).

#######################################
# Read script arguments of the form --name=value and assign each one to a
# global shell variable called `name` (e.g. --src=en sets src=en).
# Everything after the first '=' is the value, so values may contain '='.
# An option without '=' (e.g. --flag) is set to the empty string.
# Arguments that do not start with "--" are ignored; names that are not
# valid shell identifiers are reported on stderr and skipped.
# Example:
# sh/run_han_transformer_iwslt16_wmt14_fr.sh --cuda=1 --t=train --src=en --tgt=fr --k=5 --sdir=en2fr_iwslt16_wmt14_han/k5 --pretrained=en2fr_iwslt16_wmt14_han/k0/checkpoint.avg5.pt --datadir=wmt14
# Arguments: the script's positional parameters
# Globals:   one variable per accepted --name=value (written)
#######################################
parse_named_args() {
    # Unusual local names so that e.g. --name=... or --value=... are not
    # captured by a function-local variable instead of becoming globals.
    local __arg __name __value
    for __arg in "$@"; do
        [[ $__arg == --* ]] || continue
        __name=${__arg%%=*}
        __name=${__name#--}
        if [[ $__arg == *=* ]]; then
            __value=${__arg#*=}
        else
            __value=
        fi
        if [[ ! $__name =~ ^[A-Za-z_][A-Za-z0-9_]*$ ]]; then
            printf 'Ignoring argument with an invalid name: %s\n' "$__arg" >&2
            continue
        fi
        # printf -v assigns through normal scoping; no local of this name
        # exists, so the variable is created at global scope.
        printf -v "$__name" '%s' "$__value"
    done
}

parse_named_args "$@"

#######################################
# Apply the parsed arguments: select the CUDA device, fill in defaults and
# resolve checkpoint paths relative to checkpoints/.
# Globals:   cuda, datadir (read); seed, sdir, pretrained (read and written);
#            data_dir (written); CUDA_VISIBLE_DEVICES (exported when cuda set)
# Outputs:   a "setting CUDA_VISIBLE_DEVICES=..." line when cuda is set
#######################################
configure_run() {
    # Set cuda device if argument was given (exported once).
    if [[ -n $cuda ]]; then
        echo "setting CUDA_VISIBLE_DEVICES=$cuda"
        export CUDA_VISIBLE_DEVICES="$cuda"
    fi

    seed=${seed:-0}
    data_dir=data/data-bin/iwslt16.dnmt.fr-en/${datadir:-standard}

    # Prefix paths with checkpoints/ unless already prefixed, so re-running
    # with a full path does not produce checkpoints/checkpoints/...
    if [[ $sdir != checkpoints/* ]]; then sdir=checkpoints/$sdir; fi
    # An empty --pretrained stays empty rather than becoming "checkpoints/".
    if [[ -n $pretrained && $pretrained != checkpoints/* ]]; then
        pretrained=checkpoints/$pretrained
    fi
}

configure_run

# Make a pipeline fail when any stage fails, so `fairseq-* | tee` reports
# fairseq's exit status instead of tee's.
set -o pipefail

#######################################
# Run the task selected with --t: train, test, score or average.
# Globals:   t, sdir, data_dir, seed, src, tgt, k, pretrained (read)
# Outputs:   tool output on stdout, mirrored into $sdir/logs/*.log;
#            diagnostics on stderr
# Returns:   the task's exit status; 1 for an unknown task or missing inputs
#######################################
run_task() {
    local log_dir="$sdir/logs"
    case "$t" in
        train)
            mkdir -p "$log_dir" || return
            # --seed is passed exactly once: a second, hard-coded --seed would
            # silently override --seed=N because argparse keeps the last value.
            fairseq-train "$data_dir" \
                --task translation_han \
                --save-dir "$sdir" \
                --seed "$seed" \
                --source-lang "$src" \
                --target-lang "$tgt" \
                --n-context-sents "$k" \
                --pretrained-transformer-checkpoint "$pretrained" \
                --freeze-transfo-params \
                --arch han_transformer_iwslt_wmt_en_fr \
                --optimizer adam \
                --lr-scheduler inverse_sqrt \
                --lr 1e-4 \
                --warmup-init-lr 1e-7 \
                --warmup-updates 4000 \
                --patience 5 \
                --max-tokens 8192 \
                --keep-best-checkpoints 5 \
                --no-epoch-checkpoints \
                --log-format json \
                | tee "$log_dir/train.log"
            ;;
        test)
            # The log directory may not exist yet if training logged elsewhere.
            mkdir -p "$log_dir" || return
            fairseq-generate "$data_dir" \
                --task translation_han \
                --source-lang "$src" \
                --target-lang "$tgt" \
                --path "$sdir/checkpoint_best.pt" \
                --batch-size 64 \
                --remove-bpe \
                --beam 4 \
                --lenpen 1 \
                --temperature 1.3 \
                --num-workers 8 \
                | tee "$log_dir/test.log"
            ;;
        score)
            local test_log="$log_dir/test.log"
            if [[ ! -f $test_log ]]; then
                printf 'Missing %s; run with --t=test first.\n' "$test_log" >&2
                return 1
            fi
            # Extract source/reference/hypothesis lines ("S-<id>", "T-<id>",
            # "H-<id>" prefixes), restore corpus order by numeric id, and drop
            # the id column (and, for hypotheses, the extra score column).
            sed -n 's/^S-//p' "$test_log" | sort -t $'\t' -k1,1n | cut -f2- > "$log_dir/gen.out.src"
            sed -n 's/^T-//p' "$test_log" | sort -t $'\t' -k1,1n | cut -f2- > "$log_dir/gen.out.ref"
            sed -n 's/^H-//p' "$test_log" | sort -t $'\t' -k1,1n | cut -f3- > "$log_dir/gen.out.sys"
            fairseq-score \
                --sentence-bleu \
                --sys "$log_dir/gen.out.sys" \
                --ref "$log_dir/gen.out.ref" \
                | tee "$log_dir/score.log"
            ;;
        average)
            # Collect the best-loss checkpoints explicitly so an empty match
            # is reported instead of passing the literal glob to Python.
            local -a ckpts=()
            local f
            for f in "$sdir"/checkpoint.best_loss_*; do
                [[ -e $f ]] || continue
                ckpts+=("$f")
            done
            if (( ${#ckpts[@]} == 0 )); then
                printf 'No checkpoint.best_loss_* files found in %s\n' "$sdir" >&2
                return 1
            fi
            python scripts/average_checkpoints.py \
                --inputs "${ckpts[@]}" \
                --output "$sdir/checkpoint.avg_best_5.pt"
            ;;
        *)
            echo "Argument is not valid. Type 'train', 'test', 'score' or 'average'." >&2
            return 1
            ;;
    esac
}

run_task