#!/bin/bash
# Runs baseline sweep.

# Strict mode: stop on any failing command, on use of an unset variable, and
# when any stage of a pipeline fails.
set -o errexit -o nounset -o pipefail

# Locations, all relative to the directory the script is run from: the input
# data, the evaluation scripts, and the root under which results are written.
readonly DATA=../data EVALUATION=../evaluation OUTPUT=output

# Defaults.
readonly BEAM_WIDTH=4 EPOCHS=60 PATIENCE=12
readonly SED_EM_ITERATIONS=10 MAX_ENSEMBLE_SIZE=10

# Space-separated target languages (set with -t); empty means "all languages".
LANGS=""

# Trains the individual models that make up each ensemble.
#
# For every "<lang>_train.tsv" under ${DATA}/{low,medium,high}, launches
# trans/train.py once per ensemble member (1..MAX_ENSEMBLE_SIZE), passing the
# member index as --dynet-seed and writing to
# ${OUTPUT}/<level>/<lang>/<member>. The members of one language run in
# parallel; the next language starts only after all of them have finished.
#
# Globals read: DATA, OUTPUT, LANGS, MAX_ENSEMBLE_SIZE, SED_EM_ITERATIONS,
#               EPOCHS, BEAM_WIDTH, PATIENCE.
# Returns: non-zero if any training job of a language failed.
train() {
    local level train_file dev_file test_file lang seed output_path pid
    local failed=0
    local -a pids
    for level in low medium high; do
        for train_file in "${DATA}/${level}/"*"_train.tsv"; do
            # An unmatched glob is left as the literal pattern; skip it
            # instead of training on a file that does not exist.
            [[ -e "${train_file}" ]] || continue
            lang="${train_file##*/}"
            lang="${lang%_train.tsv}"
            if [[ -n "${LANGS}" && " ${LANGS} " != *" ${lang} "* ]]; then
                # If we've flagged target languages, we wish to ignore all non-targets.
                continue
            fi
            dev_file="${train_file%_train.tsv}_dev.tsv"
            test_file="${train_file%_train.tsv}_test.tsv"
            pids=()
            for ((seed = 1; seed <= MAX_ENSEMBLE_SIZE; seed++)); do
                output_path="${OUTPUT}/${level}/${lang}/${seed}"
                if [[ -z "${LANGS}" && -d "${output_path}" && -n "$(ls -A -- "${output_path}")" ]]; then
                    # We've already created an instance here. So only would do it again if we're looking for a targeted language.
                    continue
                fi
                # We apply NFD unicode normalization.
                python trans/train.py \
                    --dynet-seed "${seed}" \
                    --train "${train_file}" \
                    --dev "${dev_file}" \
                    --test "${test_file}" \
                    --sed-em-iterations "${SED_EM_ITERATIONS}" \
                    --output "${output_path}" \
                    --epochs "${EPOCHS}" \
                    --beam-width "${BEAM_WIDTH}" \
                    --patience "${PATIENCE}" \
                    --nfd &
                pids+=("$!")
            done
            # A bare `wait` always succeeds, and `set -e` does not see
            # background failures, so collect each job's status explicitly.
            for pid in ${pids[@]+"${pids[@]}"}; do
                wait "${pid}" || failed=1
            done
            if ((failed)); then
                echo "train: a training job failed for ${level}/${lang}" >&2
                return 1
            fi
        done
    done
}

# Combines the per-member predictions of each language into an ensemble.
#
# For every "<lang>_train.tsv" under ${DATA}/{low,medium,high}, runs
# trans/ensembling.py once for the dev split and once for the test split,
# giving it the gold file, every "<split>_beam${BEAM_WIDTH}.predictions"
# found under ${OUTPUT}/<level>/<lang>/*/, and the output directory
# ${OUTPUT}/<level>/<lang>/ensemble.
#
# Globals read: DATA, OUTPUT, LANGS, BEAM_WIDTH.
ensemble() {
    local level train_file lang output_path split
    for level in low medium high; do
        for train_file in "${DATA}/${level}/"*"_train.tsv"; do
            # An unmatched glob is left as the literal pattern; skip it
            # rather than ensembling a language called "*".
            [[ -e "${train_file}" ]] || continue
            lang="${train_file##*/}"
            lang="${lang%_train.tsv}"
            if [[ -n "${LANGS}" && " ${LANGS} " != *" ${lang} "* ]]; then
                # If we've passed target languages, we wish to ignore all non-targets.
                continue
            fi
            output_path="${OUTPUT}/${level}/${lang}/ensemble"
            if [[ -z "${LANGS}" && -d "${output_path}" && -n "$(ls -A -- "${output_path}")" ]]; then
                # We've already trained on this data up to this ensemble. So no need to do again unless we're looking for a specific language.
                continue
            fi
            for split in dev test; do
                # The unquoted "*" deliberately expands to one argument per
                # member directory that holds predictions for this split.
                python trans/ensembling.py \
                    --gold "${DATA}/${level}/${lang}_${split}.tsv" \
                    --systems "${OUTPUT}/${level}/${lang}/"*"/${split}_beam${BEAM_WIDTH}.predictions" \
                    --output "${output_path}"
            done
        done
    done
}

# Scores the full-size ensembles.
#
# Step 1: for every language and split, pastes the gold file next to the
# ensemble predictions and keeps columns 2 and 4 of the result, writing
# "<split>_${MAX_ENSEMBLE_SIZE}ensemble.tsv" beside the predictions.
# Step 2: for every split and level, prints a "<split> <level>:" heading and
# runs ${EVALUATION}/evaluate_all.py on all TSVs produced for that level.
#
# Globals read: DATA, OUTPUT, EVALUATION, MAX_ENSEMBLE_SIZE.
evaluate() {
    local level train_file lang split stem
    # Creates two-column TSV with gold and hypothesis data.
    for level in low medium high; do
        for train_file in "${DATA}/${level}/"*"_train.tsv"; do
            # An unmatched glob is left as the literal pattern; skip it so
            # paste and the redirect are not pointed at "*" paths.
            [[ -e "${train_file}" ]] || continue
            lang="${train_file##*/}"
            lang="${lang%_train.tsv}"
            for split in dev test; do
                stem="${OUTPUT}/${level}/${lang}/ensemble/${split}_${MAX_ENSEMBLE_SIZE}ensemble"
                paste \
                    "${train_file%_train.tsv}_${split}.tsv" \
                    "${stem}.predictions" \
                    | cut -f2,4 \
                    > "${stem}.tsv"
            done
        done
    done
    # Calls evaluation script.
    for split in dev test; do
        for level in low medium high; do
            echo "${split} ${level}:"
            # The unquoted "*" deliberately expands to one TSV per language.
            python "${EVALUATION}/./evaluate_all.py" \
                "${OUTPUT}/${level}/"*"/ensemble/${split}_${MAX_ENSEMBLE_SIZE}ensemble.tsv"
            echo
        done
    done
}

# Entry point: makes sure the output root exists, then runs the three
# pipeline stages in order (train, ensemble, evaluate).
main() {
    local stage
    mkdir -p "${OUTPUT}"
    for stage in train ensemble evaluate; do
        "${stage}"
    done
}

# Prints a short usage summary to stdout.
usage() {
    printf 'Usage: %s [-r] [-t "LANG [LANG ...]"]\n' "${0##*/}"
    printf '  -r  Reset the run: delete all prior output after confirmation.\n'
    printf '  -t  Retrain only the given space-separated languages.\n'
}

# Removes every directory named $1 found below ${OUTPUT}, so that the
# language is rebuilt from scratch. Does nothing when ${OUTPUT} is missing.
purge_language_outputs() {
    local lang="$1"
    echo "${OUTPUT}/"*"/${lang}"
    # Nothing has been produced yet; find would fail on the missing path.
    [[ -d "${OUTPUT}" ]] || return 0
    # -prune keeps find from descending into directories about to be removed.
    find "${OUTPUT}" -mindepth 1 -type d -name "${lang}" -prune -exec rm -rf -- {} +
}

# Asks on stdin for a single-key confirmation and deletes ${OUTPUT} once "y"
# is pressed. Any other key repeats the prompt.
# Returns: non-zero if stdin ends before "y" was received.
confirm_reset() {
    local key
    echo "This will delete all prior data. If you wish to continue, press y."
    while :; do
        if ! read -r -n 1 key; then
            echo "No confirmation received; nothing was deleted." >&2
            return 1
        fi
        if [[ "${key}" == y ]]; then
            echo
            echo "Resetting run."
            # ${OUTPUT:?} refuses to run rm on an empty path.
            rm -rf -- "${OUTPUT:?}"
            return
        fi
        echo "This will delete all prior training data. If you wish to continue, press y."
    done
}

# Handles the command-line flags:
#   -t "LANGS"  sets LANGS and purges existing output of those languages;
#   -r          resets the whole run after confirmation.
# Exits with status 2 on an unknown flag or a missing -t argument instead of
# silently starting a full sweep.
parse_options() {
    local option lang
    local -a targets
    local OPTIND=1
    while getopts "rt:" option; do
        case "${option}" in
        t)
            # Specifies target language for retraining.
            LANGS="${OPTARG}"
            read -r -a targets <<< "${LANGS}"
            for lang in ${targets[@]+"${targets[@]}"}; do
                purge_language_outputs "${lang}"
            done
            ;;
        r)
            # Resets entire run. This will delete all prior data.
            confirm_reset || exit 1
            ;;
        *)
            usage >&2
            exit 2
            ;;
        esac
    done
}

parse_options "$@"

main
