#!/bin/bash

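# Usage: bash han.sh --t=<mode> [--name=value ...]; each --name=value sets the shell variable "name".
# Example: bash sh/run/en-ru/truecase_voita_opensubs/context_aware/han.sh --t=test-suites --cuda=1 --lenpen=0.6 --sdir=standard/k1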

# Read script arguments of the form --name=value and assign each value to a
# shell variable called "name" (e.g. --sdir=standard/k1 sets sdir).
# Everything after the FIRST "=" is the value, so values may contain "=".
# An argument without "=" (e.g. --flag) sets the variable to the whole
# argument text. Arguments that do not start with "--" are ignored, and
# names that are not plain shell identifiers are skipped with a warning.
parse_args() {
    local _pa_arg _pa_key _pa_val _pa_name
    local -r _pa_ident='^[A-Za-z_][A-Za-z0-9_]*$'
    for _pa_arg in "$@"; do
        # "--" must be a prefix of the argument, not just appear somewhere in it.
        [[ $_pa_arg == --* ]] || continue
        _pa_key=${_pa_arg%%=*}
        _pa_val=${_pa_arg#*=}
        _pa_name=${_pa_key#--}
        # Reject non-identifiers: assigning to e.g. "a[$(cmd)]" would
        # evaluate the subscript.
        if [[ ! $_pa_name =~ $_pa_ident ]]; then
            echo "Ignoring invalid argument: $_pa_arg" >&2
            continue
        fi
        # printf -v stores the value verbatim (never as an array) and, since
        # the name is not local here, assigns the global variable.
        printf -v "$_pa_name" '%s' "$_pa_val"
    done
}
parse_args "$@"

# Fixed settings for this language pair / corpus / model combination.
src=en
tgt=ru
lang=$src-$tgt
corpus=truecase_voita_opensubs/context_aware

script=sh/run/$lang/$corpus/han.sh
task=translation_han
architecture=han_transformer_wmt_en_fr
# test_suites=data/$lang/data-bin/wmt17/test_suites
# contrapro=data/en-de/test_suites/ContraPro

# A non-empty --datadir wins over the standard binarized corpus directory.
datadir=${datadir:-data/$lang/data-bin/$corpus/standard}
# A --sdir that does not already start with "checkpoints/" is placed under
# checkpoints/<lang>/<corpus>/.
case $sdir in
    checkpoints/*) ;;
    *) sdir=checkpoints/$lang/$corpus/$sdir ;;
esac

num_workers=8
n_best_checkpoints=5
checkpoint_path=$sdir/checkpoint_best.pt
# checkpoint_path=$sdir/checkpoint.avg_last$n_best_checkpoints.pt

# Command-line overrides; an empty or missing value falls back to the default.
pretrained=${pretrained:-None}
seed=${seed:-0}
testlog=${testlog:-test}
lenpen=${lenpen:-0.6}
mover=${mover:-'{}'}
maxtok=${mt:-16000}
updatefreq=${uf:-2}
siu=${siu:-2000}

# Restrict the visible GPUs only when --cuda was given.
if [[ -n $cuda ]]; then
    export CUDA_VISIBLE_DEVICES=$cuda
fi

# Run the action selected with --t=<mode>:
#   train        train the HAN model with fairseq-train; requires --k
#   finetune     fine-tune from --pretrained with a fixed learning rate;
#                requires --k and --pretrained
#   test         decode $datadir with fairseq-generate, then score the
#                lowercased output with multi-bleu.perl
#   score-ref    score the reference translations (--score-reference)
#   test-suites  evaluate on the ellipsis_vp consistency test suite, reading
#                the log previously written by score-ref with --testlog=ellipsis_vp
#   results      print the stored test BLEU score and ellipsis_vp result
#   average      average the last $n_best_checkpoints interval checkpoints
# A missing or unknown --t is reported on stderr and exits with status 1.

# Exit unless --k (number of context sentences) was given. With an empty $k,
# "--n-context-sents" would consume the following flag as its value.
require_k() {
    if [[ -z "${k:-}" ]]; then
        echo "Argument k (number of context sentences) is required for t=$t." >&2
        exit 1
    fi
}

case "${t:-}" in
    train)
        require_k
        mkdir -p "$sdir/logs"
        fairseq-train "$datadir" \
            --save-dir "$sdir" \
            --seed "$seed" \
            --source-lang "$src" \
            --target-lang "$tgt" \
            --num-workers "$num_workers" \
            --task "$task" \
            --arch "$architecture" \
            --pretrained-transformer-checkpoint "$pretrained" \
            --n-context-sents "$k" \
            --freeze-transfo-params \
            --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
            --lr-scheduler inverse_sqrt --warmup-updates 4000 --min-lr 1e-09 \
            --lr 5e-04 --warmup-init-lr 1e-07 \
            --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
            --max-tokens "$maxtok" \
            --update-freq "$updatefreq" \
            --max-source-positions 200 \
            --max-target-positions 200 \
            --patience 5 \
            --save-interval-updates "$siu" \
            --keep-interval-updates 10 \
            --no-epoch-checkpoints \
            --log-format json \
            | tee -a "$sdir/logs/train.log"
        ;;
    finetune)
        require_k
        # The default pretrained=None is not a model file; fail early with a
        # clear message instead of inside fairseq-train.
        if [[ "$pretrained" == "None" ]]; then
            echo "Argument pretrained (model to fine-tune from) is required for t=$t." >&2
            exit 1
        fi
        mkdir -p "$sdir/logs"
        fairseq-train "$datadir" \
            --save-dir "$sdir" \
            --seed "$seed" \
            --source-lang "$src" \
            --target-lang "$tgt" \
            --num-workers "$num_workers" \
            --finetune-from-model "$pretrained" \
            --task "$task" \
            --arch "$architecture" \
            --n-context-sents "$k" \
            --freeze-transfo-params \
            --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
            --lr-scheduler fixed --lr 2e-04 --fa 1 --lr-shrink 0.99 \
            --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
            --max-tokens "$maxtok" \
            --update-freq "$updatefreq" \
            --max-source-positions 200 \
            --max-target-positions 200 \
            --patience 5 \
            --save-interval-updates "$siu" \
            --keep-interval-updates 10 \
            --no-epoch-checkpoints \
            --log-format json \
            | tee -a "$sdir/logs/train.log"
        ;;
    test)
        # tee cannot create the log if the logs directory does not exist yet.
        mkdir -p "$sdir/logs"
        fairseq-generate "$datadir" \
            --task "$task" \
            --source-lang "$src" \
            --target-lang "$tgt" \
            --path "$checkpoint_path" \
            --batch-size 64 \
            --remove-bpe \
            --beam 4 \
            --lenpen "$lenpen" \
            --temperature 1.2 \
            --num-workers "$num_workers" \
            | tee "$sdir/logs/$testlog.log"
        # Score with multi-bleu: for source (S), reference (T) and hypothesis
        # (H) lines, sort by sentence id, keep every 4th line and lowercase it.
        # NOTE(review): every 4th line presumably selects the current sentence
        # of each 4-sentence context group; confirm against the data layout.
        grep '^S' "$sdir/logs/$testlog.log" | sed 's/^S-//g' | sort -nk 1 \
            | awk 'NR % 4 == 0' | cut -f2- | awk '{print tolower($0)}' \
            > "$sdir/logs/$testlog.out.src"
        grep '^T' "$sdir/logs/$testlog.log" | sed 's/^T-//g' | sort -nk 1 \
            | awk 'NR % 4 == 0' | cut -f2- | awk '{print tolower($0)}' \
            > "$sdir/logs/$testlog.out.ref"
        # H lines carry an extra score field before the text, hence -f3-.
        grep '^H' "$sdir/logs/$testlog.log" | sed 's/^H-//g' | sort -nk 1 \
            | awk 'NR % 4 == 0' | cut -f3- | awk '{print tolower($0)}' \
            > "$sdir/logs/$testlog.out.sys"
        tools/mosesdecoder/scripts/generic/multi-bleu.perl "$sdir/logs/$testlog.out.ref" \
            < "$sdir/logs/$testlog.out.sys" | tee "$sdir/logs/$testlog.score"
        ;;
    score-ref)
        mkdir -p "$sdir/logs"
        # "$mover" is quoted so overrides containing spaces reach fairseq as
        # a single argument.
        fairseq-generate "$datadir" \
            --task "$task" \
            --source-lang "$src" \
            --target-lang "$tgt" \
            --path "$checkpoint_path" \
            --model-overrides "$mover" \
            --score-reference \
            --batch-size 64 \
            --remove-bpe \
            --num-workers "$num_workers" \
            | tee "$sdir/logs/$testlog.log"
        ;;
    test-suites)
        # # evaluate on test-set
        # bash $script --t=test --sdir=$sdir --cuda=$cuda --lenpen=$lenpen --mover=$mover
        # Evaluate on the ellipsis_vp consistency test set.
        datadir=data/en-ru/data-bin/truecase_voita_opensubs/testset_consistency/ellipsis_vp
        d=ellipsis_vp
        # Scoring the references is disabled below, so its log must already
        # exist; produce it with:
        # bash $script --t=score-ref --src=$src --tgt=$tgt --sdir=$sdir --datadir=$datadir --testlog=$d --cuda=$cuda --mover=$mover
        if [[ ! -f "$sdir/logs/$d.log" ]]; then
            echo "Missing $sdir/logs/$d.log: run --t=score-ref --datadir=$datadir --testlog=$d first." >&2
            exit 1
        fi
        echo "extract scores..."
        # Field 2 of each H line (after sorting by id) is its score.
        grep '^H' "$sdir/logs/$d.log" | sed 's/^H-//g' | sort -nk 1 | cut -f2 > "$sdir/logs/$d.full_score"
        # Keep every 4th score and drop its first character
        # (presumably the minus sign of a negative log-probability; confirm).
        awk 'NR % 4 == 0' "$sdir/logs/$d.full_score" | cut -c2- > "$sdir/logs/$d.score"
        echo "evaluate model performance on test-suite by comparing scores..."
        repo=data/en-ru/test_suites/good-translation-wrong-in-context/
        python3 "$repo/scripts/evaluate_consistency.py" --repo-dir "$repo" --test "$d" \
            --scores "$sdir/logs/$d.score" --results-file "$sdir/logs/$d.results" \
            > "$sdir/logs/$d.result"
        echo "-----------------------------------"
        # bash $script --t=results --sdir=$sdir
        ;;
    results)
        d=test
        echo "RESULTS FOR $sdir/logs/$d.score"
        echo ""
        cat "$sdir/logs/$d.score"
        echo "-----------------------------------"
        d=ellipsis_vp
        echo "RESULTS FOR $sdir/logs/$d.result"
        echo ""
        cat "$sdir/logs/$d.result"
        echo "-----------------------------------"
        echo ""
        ;;
    average)
        # Let the averaging script select the $n_best_checkpoints most recent
        # interval checkpoints in $sdir, so the output really holds the last N
        # rather than every checkpoint kept by --keep-interval-updates.
        python scripts/average_checkpoints.py \
            --inputs "$sdir" \
            --num-update-checkpoints "$n_best_checkpoints" \
            --output "$sdir/checkpoint.avg_last$n_best_checkpoints.pt"
        ;;
    *)
        echo "Argument t is not valid." >&2
        exit 1
        ;;
esac