#!/bin/bash
# STANDARD
# bash sh/run/en-fr/iwslt17/han.sh --t=train --cuda=0 --k=3 --sdir=standard/k3 --pretrained=checkpoints/en-fr/iwslt17/standard/k0/checkpoint_best.pt
# SPLIT
# bash sh/run/en-fr/iwslt17/han.sh --t=train --cuda=0 --k=3 --sdir=split/k3 --pretrained=checkpoints/en-fr/iwslt17/standard/k0/checkpoint_best.pt --datadir=data/en-fr/data-bin/iwslt17/split
# bash sh/run/en-fr/iwslt17/han.sh --t=finetune --cuda=1 --k=3 --sdir=fromsplit/k3 --pretrained=checkpoints/en-fr/iwslt17/split/k3/checkpoint_best.pt
# FROMNEI
# bash sh/run/en-fr/iwslt17/han.sh --t=train --cuda=0 --k=3 --sdir=split/fromnei/k3 --restore=checkpoints/en-fr/nei/split/k3/checkpoint_best.pt --datadir=data/en-fr/data-bin/iwslt17/split
# bash sh/run/en-fr/iwslt17/han.sh --t=finetune --cuda=1 --k=1 --sdir=fromsplit/fromnei/k1 --pretrained=checkpoints/en-fr/iwslt17/split/fromnei/k1/checkpoint_best.pt
# bash sh/run/en-fr/iwslt17/han.sh --t=finetune --cuda=1 --k=1 --sdir=standard/fromnei/k1 --pretrained=checkpoints/en-fr/nei/standard/k1/checkpoint_best.pt
# CUR
# bash sh/run/en-fr/iwslt17/han.sh --t=finetune --cuda=0 --k=3 --sdir=standard/cur/k3 --pretrained=checkpoints/en-fr/iwslt17/standard/k1/checkpoint_best.pt

# Read script arguments of the form --name=value and assign each one to a
# shell variable called "name" (e.g. --cuda=0 sets cuda=0).
#  - Only arguments starting with "--" are considered; others are ignored.
#  - The value is everything after the FIRST '=', so values may contain '='
#    (e.g. --mover={'a':'b'} or --x=a=b).
#  - An argument without '=' (e.g. --foo) sets foo to the whole argument.
#  - Names that are not valid shell identifiers are reported and skipped.
# Variables are assigned in the caller's (global) scope via printf -v; the
# function's own locals are underscore-prefixed so option names do not clash.
parse_args() {
    local _arg _key _name _value
    for _arg in "$@"; do
        _key=${_arg%%=*}     # text before the first '='
        _value=${_arg#*=}    # text after the first '='
        [[ $_key == --* ]] || continue
        _name=${_key#--}
        if [[ $_name =~ ^[A-Za-z_][A-Za-z0-9_]*$ ]]; then
            printf -v "$_name" '%s' "$_value"
        else
            echo "Ignoring argument with invalid name: $_arg" >&2
        fi
    done
}
parse_args "$@"

# Set variables. Values passed as --name=value arguments take precedence over
# the defaults below; src/tgt are fixed, so --src/--tgt arguments are ignored.
src=en
tgt=fr
lang=$src-$tgt
script=sh/run/$lang/iwslt17/han.sh   # this script, re-invoked by t=test-suites
task=translation_han
architecture=han_transformer_wmt_en_fr
test_suites=data/$lang/data-bin/wmt14/test_suites
bawden=data/$lang/bawden
datadir=${datadir:-data/$lang/data-bin/iwslt17/standard}
lenpen=${lenpen:-0.6}
# A --sdir that does not already start with checkpoints/ is placed under
# checkpoints/$lang/iwslt17/.
if [[ $sdir != checkpoints/* ]]; then
    sdir=checkpoints/$lang/iwslt17/$sdir
fi

num_workers=8
# Moses detokenizer script (not referenced elsewhere in this script; the
# test/score tasks detokenize with sacremoses).
detokenizer=tools/mosesdecoder/scripts/tokenizer/detokenizer.perl
n_best_checkpoints=5
# Checkpoint evaluated by t=test / t=score-ref. Alternatives:
#   checkpoint_path=$sdir/checkpoint_last.pt
#   checkpoint_path=$sdir/checkpoint.avg_last$n_best_checkpoints.pt
checkpoint_path=$sdir/checkpoint_best.pt
if [[ -n $cuda ]]; then
    export CUDA_VISIBLE_DEVICES=$cuda
fi
seed=${seed:-0}
pretrained=${pretrained:-None}
restore=${restore:-checkpoint_last.pt}
testlog=${testlog:-test}
mover=${mover:-'{}'}       # passed to fairseq-generate --model-overrides
maxtok=${mt:-8000}         # --mt=N, passed as --max-tokens
updatefreq=${uf:-2}        # --uf=N, passed as --update-freq


###############################################################################
# Task dispatcher. The value of --t selects what to run:
#   train        train.py from $pretrained with --freeze-transfo-params,
#                restoring from $restore
#   finetune     train.py --finetune-from-model $pretrained, fixed LR 2e-4
#                with --lr-shrink 0.99
#   test         fairseq-generate on $datadir with $checkpoint_path, then
#                detokenize and score with multi-bleu-detok
#   score        re-score an existing $sdir/logs/$testlog.log
#   score-split  re-score a log by joining each pair of consecutive sentences
#                (paste - -) and running fairseq-score
#   score-ref    fairseq-generate --score-reference on $datadir
#   average      average the checkpoint.best_* files in $sdir
#   test-suites  test set, shuffled test set, and contrastive test suites
#   results      print previously computed test/test-suite results
###############################################################################

# Print selected tab-separated fields of one line type of a fairseq-generate
# log, ordered by sentence id.
#   $1 - log file written by fairseq-generate
#   $2 - line tag: S (source), T (reference) or H (hypothesis)
#   $3 - cut(1) field list applied after the "<tag>-" prefix is removed,
#        e.g. "2-" for S/T text, "3-" for H text, "2" for the H score column
_gen_column() {
    local log=$1 tag=$2 fields=$3
    grep "^${tag}-" "$log" | sed "s/^${tag}-//" | sort -nk 1 | cut -f"$fields"
}

# Detokenize source, reference and hypotheses of $sdir/logs/$1.log into
# $sdir/logs/$1.out.{src,ref,sys} and score them with multi-bleu-detok; the
# score is printed and also written to $sdir/logs/$1.score.
# NOTE(review): sacremoses is called without a language option although
# ref/sys are $tgt text — confirm this is intended before comparing runs.
_score_log() {
    local prefix=$sdir/logs/$1
    _gen_column "$prefix.log" S 2- | sacremoses detokenize > "$prefix.out.src"
    _gen_column "$prefix.log" T 2- | sacremoses detokenize > "$prefix.out.ref"
    _gen_column "$prefix.log" H 3- | sacremoses detokenize > "$prefix.out.sys"
    tools/mosesdecoder/scripts/generic/multi-bleu-detok.perl "$prefix.out.ref" \
        < "$prefix.out.sys" | tee "$prefix.score"
}

# Fail early with a clear message when --k (passed to train.py as
# --n-context-sents) is missing. $1 is the task name, used in the message.
_require_k() {
    if [[ -z $k ]]; then
        echo "Argument k is required for t=$1 (e.g. --k=3)." >&2
        return 1
    fi
}

# Print the contrastive test-suite results written by t=test-suites.
_print_suite_results() {
    local d s
    for d in lexical_choice; do
        echo "Results for $d"
        cat "$sdir/logs/$d.result"
        echo "-----------------------------------"
        echo ""
    done
    for s in "" ".shuffled"; do
        d=large_pronoun$s
        echo "Results for $d"
        grep total "$sdir/logs/$d.result"
        echo "-----------------------------------"
        echo ""
    done
}

# Run the task named by $1 (normally the --t argument). Returns non-zero for
# an unknown task or a missing required argument.
run_task() {
    local mode=$1
    local prefix suite_dir orig d s
    local -a ckpts
    case "$mode" in
    train)
        _require_k "$mode" || return
        mkdir -p "$sdir/logs"
        python3 -u train.py "$datadir" \
            --save-dir "$sdir" \
            --seed "$seed" \
            --source-lang "$src" \
            --target-lang "$tgt" \
            --num-workers "$num_workers" \
            --task "$task" \
            --arch "$architecture" \
            --pretrained-transformer-checkpoint "$pretrained" \
            --restore-file "$restore" \
            --n-context-sents "$k" \
            --freeze-transfo-params \
            --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
            --lr-scheduler inverse_sqrt --warmup-updates 4000 --min-lr 1e-09 \
            --lr 1e-03 --warmup-init-lr 1e-07 \
            --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
            --max-tokens "$maxtok" \
            --update-freq "$updatefreq" \
            --patience 5 \
            --no-epoch-checkpoints \
            --log-format json \
            | tee -a "$sdir/logs/train.log"
        ;;
    finetune)
        _require_k "$mode" || return
        mkdir -p "$sdir/logs"
        python3 -u train.py "$datadir" \
            --save-dir "$sdir" \
            --seed "$seed" \
            --source-lang "$src" \
            --target-lang "$tgt" \
            --num-workers "$num_workers" \
            --finetune-from-model "$pretrained" \
            --task "$task" \
            --arch "$architecture" \
            --n-context-sents "$k" \
            --freeze-transfo-params \
            --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
            --lr-scheduler fixed --lr 2e-04 --fa 1 --lr-shrink 0.99 \
            --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
            --max-tokens "$maxtok" \
            --update-freq "$updatefreq" \
            --patience 5 \
            --no-epoch-checkpoints \
            --log-format json \
            | tee -a "$sdir/logs/train.log"
        ;;
    results)
        for d in lexical_choice; do
            echo "RESULTS FOR $d"
            cat "$sdir/logs/$d.result"
            echo "-----------------------------------"
            echo ""
        done
        for s in "" ".shuffled"; do
            d=test$s
            echo "RESULTS FOR $sdir/logs/$d.score"
            echo ""
            cat "$sdir/logs/$d.score"
            echo "-----------------------------------"
            echo ""
        done
        for s in "" ".shuffled"; do
            d=large_pronoun$s
            echo "RESULTS FOR $sdir/logs/$d.result"
            echo ""
            grep total "$sdir/logs/$d.result"
            echo "-----------------------------------"
            echo ""
        done
        ;;
    test)
        mkdir -p "$sdir/logs"
        fairseq-generate "$datadir" \
            --task "$task" \
            --source-lang "$src" \
            --target-lang "$tgt" \
            --path "$checkpoint_path" \
            --batch-size 128 \
            --remove-bpe \
            --beam 4 \
            --lenpen "$lenpen" \
            --temperature 1 \
            --num-workers "$num_workers" \
            | tee "$sdir/logs/$testlog.log"
        _score_log "$testlog"
        ;;
    score)
        _score_log "$testlog"
        ;;
    score-split)
        prefix=$sdir/logs/$testlog
        _gen_column "$prefix.log" S 2- | paste -d " " - - > "$prefix.out.src"
        _gen_column "$prefix.log" T 2- | paste -d " " - - > "$prefix.out.ref"
        _gen_column "$prefix.log" H 3- | paste -d " " - - > "$prefix.out.sys"
        fairseq-score \
            --sys "$prefix.out.sys" \
            --ref "$prefix.out.ref" \
            | tee "$prefix.score"
        ;;
    score-ref)
        mkdir -p "$sdir/logs"
        fairseq-generate "$datadir" \
            --task "$task" \
            --source-lang "$src" \
            --target-lang "$tgt" \
            --path "$checkpoint_path" \
            --model-overrides "$mover" \
            --score-reference \
            --batch-size 64 \
            --remove-bpe \
            --num-workers "$num_workers" \
            | tee "$sdir/logs/$testlog.log"
        ;;
    average)
        ckpts=("$sdir"/checkpoint.best_*)
        if [[ ! -e ${ckpts[0]} ]]; then
            echo "No checkpoint.best_* files found in $sdir." >&2
            return 1
        fi
        python3 scripts/average_checkpoints.py \
            --inputs "${ckpts[@]}" \
            --output "$sdir/checkpoint.$n_best_checkpoints.best.average.pt"
        ;;
    test-suites)
        # Evaluate on the test set (default datadir of the child invocation;
        # NOTE(review): --datadir is not forwarded — confirm for split models).
        bash "$script" --t=test --sdir="$sdir" --cuda="$cuda" --lenpen="$lenpen"
        # Evaluate on the shuffled test set.
        bash "$script" --t=test --sdir="$sdir" --testlog=test.shuffled --cuda="$cuda" \
            --datadir="data/$lang/data-bin/iwslt17/test_shuffled" --lenpen="$lenpen"
        # Evaluate on Bawden's test suites.
        for d in lexical_choice; do
            suite_dir=$test_suites/$d
            # Score the reference translations of the suite.
            bash "$script" --t=score-ref --src="$src" --tgt="$tgt" --sdir="$sdir" \
                --datadir="$suite_dir" --testlog="$d" --cuda="$cuda"
            # --mover="{'n_context_sents':'1'}"
            echo "extract scores..."
            _gen_column "$sdir/logs/$d.log" H 2 > "$sdir/logs/$d.full_score"
            # Keep every 2nd score line (NOTE(review): presumably matches the
            # suite's example layout — confirm against the suite format).
            awk 'NR % 2 == 0' "$sdir/logs/$d.full_score" > "$sdir/logs/$d.score"
            echo "evaluate model performance on test-suite by comparing scores..."
            orig=$bawden/discourse-mt-test-sets
            python3 "$orig/scripts/evaluate.py" "$orig/test-sets/$d.json" "$d" \
                "$sdir/logs/$d.score" --maximise > "$sdir/logs/$d.result"
        done
        # Evaluate on the large pronoun test suite (original and shuffled context).
        for s in "" ".shuffled"; do
            suite_dir=$test_suites/large_pronoun/k3$s
            d=large_pronoun$s
            bash "$script" --t=score-ref --src="$src" --tgt="$tgt" --sdir="$sdir" \
                --datadir="$suite_dir" --testlog="$d" --cuda="$cuda"
            echo "extract scores..."
            _gen_column "$sdir/logs/$d.log" H 2 > "$sdir/logs/$d.full_score"
            # Keep every 4th score line (NOTE(review): confirm against the
            # suite's example layout).
            awk 'NR % 4 == 0' "$sdir/logs/$d.full_score" > "$sdir/logs/$d.score"
            echo "evaluate model performance on test-suite by comparing scores..."
            orig=$bawden/Large-contrastive-pronoun-testset-EN-FR/OpenSubs
            python3 "$orig/scripts/evaluate.py" --reference "$orig/testset-$lang.json" \
                --scores "$sdir/logs/$d.score" --maximize \
                --results-file "$sdir/logs/$d.results" > "$sdir/logs/$d.result"
        done
        echo "-----------------------------------"
        echo ""
        _print_suite_results
        ;;
    *)
        echo "Argument t is not valid: '$mode' (expected train, finetune, test, score, score-split, score-ref, average, test-suites or results)." >&2
        return 1
        ;;
    esac
}

run_task "$t"