#!/usr/bin/env python
"""Randomly generates a 10-fold split for a text-format PB."""

import argparse
import logging
import random
import os

import tagdata_pb2
import textproto

from sklearn.model_selection import KFold, train_test_split
from pathlib import Path
from tqdm import tqdm


def _write_sentences(sentences, path):
    """Write ``sentences`` to ``path`` as a text-format ``tagdata_pb2.Sentences``."""
    message = tagdata_pb2.Sentences()
    message.sentences.extend(sentences)
    with open(path, "w", encoding="utf-8") as sink:
        textproto.write_sentences(message, sink)


def main(args):
    """Shuffle the input sentences and write a 10-fold train/dev/test split.

    For every fold ``k`` a directory ``<output_path>/<seed>_<k>`` is created,
    containing ``train.textproto``, ``dev.textproto`` and ``test.textproto``.
    The test portion is the k-th KFold shard of the shuffled sentences; the
    dev portion is sampled from the remaining sentences and has the size of
    one shard (``len(sentences) // n_splits``); the rest is training data.

    Args:
        args: parsed CLI namespace providing ``seed`` (int),
            ``input_textproto_path`` and ``output_path``.

    Raises:
        ValueError: if the input holds fewer sentences than there are folds.
    """
    logging.info("Seed: %d", args.seed)
    random.seed(args.seed)
    # Shuffling is done once with `random.shuffle` below, so KFold itself
    # must not shuffle; folds are then contiguous slices of the shuffled list.
    kfold = KFold(n_splits=10, shuffle=False)
    n_splits = kfold.get_n_splits()

    base_path = Path(args.output_path)

    with open(args.input_textproto_path, "r", encoding="utf-8") as source:
        sentences = textproto.read_sentences(source)

    # We have to copy into a list so as to have __setitem__ (random.shuffle
    # permutes in place).
    sentences = list(sentences.sentences)
    length = len(sentences)
    if length < n_splits:
        raise ValueError(
            f"Need at least {n_splits} sentences for a {n_splits}-fold split, "
            f"got {length} from {args.input_textproto_path}"
        )
    shard_size = length // n_splits
    random.shuffle(sentences)

    folds = enumerate(kfold.split(sentences))
    for fold_no, (train_dev_index, test_index) in tqdm(
        folds, total=n_splits, desc=str(args.seed)
    ):
        path_per_repeat = base_path / f"{args.seed}_{fold_no}"
        path_per_repeat.mkdir(parents=True, exist_ok=True)

        # Carve a dev set the size of one shard out of the non-test indices;
        # a fixed random_state keeps the carve-out reproducible per seed.
        train_index, dev_index = train_test_split(
            train_dev_index, test_size=shard_size, random_state=args.seed
        )

        _write_sentences(
            [sentences[i] for i in train_index],
            path_per_repeat / "train.textproto",
        )
        _write_sentences(
            [sentences[i] for i in dev_index],
            path_per_repeat / "dev.textproto",
        )
        _write_sentences(
            [sentences[i] for i in test_index],
            path_per_repeat / "test.textproto",
        )


if __name__ == "__main__":
    # Script entry point: set up logging, read the CLI flags, then run main().
    logging.basicConfig(level="INFO", format="%(levelname)s: %(message)s")
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--seed", type=int, required=True, help="random seed")
    cli.add_argument(
        "--input_textproto_path",
        required=True,
        help="input text-format PB",
    )
    cli.add_argument(
        "--output_path",
        required=True,
        help="output directory to create directories per split",
    )
    parsed_args = cli.parse_args()
    main(parsed_args)
