#!/bin/bash
# Creates splits for the POS data under data/ptb-3 and data/ontonotes-5.

# Strict mode: abort on a failing command (-e), on use of an unset variable
# (-u), and when any stage of a pipeline fails (pipefail).
set -euo pipefail 

# Standard split (https://aclweb.org/aclwiki/POS_Tagging_(State_of_the_art)):
#   train: 00-18
#   dev: 19-21
#   test: 22-24
#
# Arguments:
#   $1 - corpus directory; sections are read from "$1/original/NN/*".
# Outputs:
#   Writes {train,dev,test}.textproto under "$1/standard_split".
# Returns:
#   0 if all three conversions succeed; otherwise the exit status of the last
#   conversion that failed.
standard_split() {
    echo "Standard split (${1})..."
    local -r OLD_DIR="${1}/original"
    local -r NEW_DIR="${1}/standard_split"
    mkdir -p "${NEW_DIR}"
    # The three conversions are independent of each other, so they run
    # concurrently; their PIDs are recorded so each one can be checked.
    local -a PIDS=()
    ./parse_to_textproto.py \
        "${OLD_DIR}"/0[0-9]/* \
        "${OLD_DIR}"/1[0-8]/* \
        > "${NEW_DIR}/train.textproto" &
    PIDS+=("$!")
    ./parse_to_textproto.py \
        "${OLD_DIR}"/19/* \
        "${OLD_DIR}"/2[0-1]/* \
        > "${NEW_DIR}/dev.textproto" &
    PIDS+=("$!")
    ./parse_to_textproto.py \
        "${OLD_DIR}"/2[2-4]/* \
        > "${NEW_DIR}/test.textproto" &
    PIDS+=("$!")
    # A bare `wait` always returns 0 and `set -e` does not react to failing
    # background jobs, so every job is reaped by PID to surface its status.
    local STATUS=0
    local PID
    for PID in "${PIDS[@]}"; do
        wait "${PID}" || STATUS=$?
    done
    return "${STATUS}"
}

# Twenty randomly-generated 80%/10%/10% splits, one per seed (the actual
# partitioning is done by random_split.py).
#
# Arguments:
#   $1 - corpus directory; "$1/standard_split" must already be populated.
# Outputs:
#   Writes {train,dev,test}.textproto paths under "$1/random_splits/<seed>/".
# Returns:
#   0 on success; otherwise the status of the first step that failed. The
#   temporary concatenated file is removed in both cases.
random_splits() {
    echo "Random splits (${1})..."
    local -r OLD_DIR="${1}/standard_split"
    local -r NEW_DIR="${1}/random_splits"
    # The first twenty taxicab numbers (http://oeis.org/A001235).
    local -r SEEDS=(1729 4104 13832 20683 32832 39312 40033 46683 64232 65728
                    110656 110808 134379 149389 165464 171288 195841 216027
                    216125 262656)
    # Temporarily creates a single file. Declaration and assignment are split
    # because `local VAR=$(cmd)` discards the exit status of cmd.
    local ALL
    ALL=$(mktemp --tmpdir all.XXXXX.$$.textproto) || return
    # Failures are recorded instead of aborting on the spot so that the
    # temporary file below is always cleaned up.
    local STATUS=0
    local SEED SEED_DIR
    cat "${OLD_DIR}/train.textproto" \
        "${OLD_DIR}/dev.textproto" \
        "${OLD_DIR}/test.textproto" \
        > "${ALL}" || STATUS=$?
    if (( STATUS == 0 )); then
        for SEED in "${SEEDS[@]}"; do
            SEED_DIR="${NEW_DIR}/${SEED}"
            mkdir -p "${SEED_DIR}" || { STATUS=$?; break; }
            ./random_split.py \
                --seed="${SEED}" \
                --input_textproto_path="${ALL}" \
                --output_train_textproto_path="${SEED_DIR}/train.textproto" \
                --output_dev_textproto_path="${SEED_DIR}/dev.textproto" \
                --output_test_textproto_path="${SEED_DIR}/test.textproto" \
                || { STATUS=$?; break; }
        done
    fi
    rm -f -- "${ALL}"
    return "${STATUS}"
}

# Random k-fold splits: runs kfold.py once per seed over the concatenation of
# the standard train/dev/test files. The number of folds and the layout of the
# output files are decided by kfold.py.
#
# Arguments:
#   $1 - corpus directory; "$1/standard_split" must already be populated.
# Outputs:
#   Every kfold.py run is pointed at "$1/random_kfold_splits".
# Returns:
#   0 on success; otherwise the status of the first step that failed. The
#   temporary concatenated file is removed in both cases.
kfold_random_splits() {
    echo "Random kfold splits (${1})..."
    local -r OLD_DIR="${1}/standard_split"
    local -r NEW_DIR="${1}/random_kfold_splits"
    # The first twenty taxicab numbers (http://oeis.org/A001235).
    local -r SEEDS=(1729 4104 13832 20683 32832 39312 40033 46683 64232 65728
                    110656 110808 134379 149389 165464 171288 195841 216027
                    216125 262656)
    # Temporarily creates a single file. Declaration and assignment are split
    # because `local VAR=$(cmd)` discards the exit status of cmd.
    local ALL
    ALL=$(mktemp --tmpdir all.XXXXX.$$.textproto) || return
    # Failures are recorded instead of aborting on the spot so that the
    # temporary file below is always cleaned up.
    local STATUS=0
    local SEED
    cat "${OLD_DIR}/train.textproto" \
        "${OLD_DIR}/dev.textproto" \
        "${OLD_DIR}/test.textproto" \
        > "${ALL}" || STATUS=$?
    if (( STATUS == 0 )); then
        mkdir -p "${NEW_DIR}" || STATUS=$?
    fi
    if (( STATUS == 0 )); then
        for SEED in "${SEEDS[@]}"; do
            ./kfold.py \
                --seed="${SEED}" \
                --input_textproto_path="${ALL}" \
                --output_path="${NEW_DIR}" \
                || { STATUS=$?; break; }
        done
    fi
    rm -f -- "${ALL}"
    return "${STATUS}"
}

# Subsamples: for each size in SIZES, runs subsample.py on the standard
# training file and copies the standard dev and test files next to the result.
#
# Arguments:
#   $1 - corpus directory; "$1/standard_split" must already be populated.
# Outputs:
#   Writes "$1/subsamples/<size>/{train,dev,test}.textproto".
# Returns:
#   0 on success; otherwise the status of the first step that failed.
subsamples() {
    echo "Subsamples ($1)..."
    local -r OLD_DIR="${1}/standard_split"
    local -r NEW_DIR="${1}/subsamples"
    local -r SIZES=(100 1000 10000)
    # Loop variables are local so they do not leak into the caller's scope.
    local SIZE SIZE_DIR
    for SIZE in "${SIZES[@]}"; do
        SIZE_DIR="${NEW_DIR}/${SIZE}"
        # Explicit `|| return` keeps failures visible even when the function
        # is called from a context where `set -e` is suppressed.
        mkdir -p "${SIZE_DIR}" || return
        ./subsample.py \
            --size="${SIZE}" \
            "${OLD_DIR}/train.textproto" \
            "${SIZE_DIR}/train.textproto" || return
        cp "${OLD_DIR}/dev.textproto" "${OLD_DIR}/test.textproto" \
            "${SIZE_DIR}" || return
    done
}

# Runs every enabled split step, in order, for one corpus.
#
# Arguments:
#   $1 - corpus directory, forwarded unchanged to each step.
all_three() {
    local -r corpus_dir="$1"
    local step
    # Currently disabled steps: subsamples, random_splits.
    for step in standard_split kfold_random_splits; do
        "${step}" "${corpus_dir}"
    done
}

# Builds the splits for each corpus directory, one after the other.
main() {
    local corpus
    for corpus in "data/ptb-3" "data/ontonotes-5"; do
        all_three "${corpus}"
    done
}

# Script entry point; main is invoked without any command-line arguments.
main
