import os
import argparse
import random

def readTxt(inputfile):
    """Load a text file as a list of whitespace-stripped lines.

    Each line has its leading and trailing whitespace (including the
    newline) removed. A short summary with the line count is printed.

    Args:
        inputfile: path of the text file to read.

    Returns:
        A list with one stripped string per line of the file.
    """
    with open(inputfile, 'r') as fin:
        # Iterate the handle directly; each item is one raw line.
        datas = list(map(str.strip, fin))
    print("read {} lines from file {}".format(len(datas), inputfile))
    return datas

def saveTxt(datas: list, outputfile: str):
    """Write every string in ``datas`` to ``outputfile``, one per line.

    The file is opened in write mode, so existing content is replaced.
    A newline is appended to each entry, and a short summary with the
    number of entries is printed afterwards.

    Args:
        datas: the strings to write.
        outputfile: path of the file to create or overwrite.
    """
    with open(outputfile, 'w') as fout:
        fout.writelines(entry + '\n' for entry in datas)
    print("save {} lines to file {}".format(len(datas), outputfile))

if __name__ == "__main__":
    # Command-line entry point: read every regular file in --input-dir and,
    # when --same-split is given, write a train/dev partition of each file to
    # <input-dir>/train and <input-dir>/dev, using one shared random choice of
    # line indices so that the files stay aligned line by line.
    parser = argparse.ArgumentParser()
    parser.add_argument("--same-split", help="apply the same split for all datasets", action="store_true")
    parser.add_argument("--input-dir", type=str, required=True)
    parser.add_argument("--dev-size", type=int)

    args = parser.parse_args()

    datas_dict = dict()
    for filename in os.listdir(args.input_dir):
        path = os.path.join(args.input_dir, filename)
        # Skip sub-directories: the train/ and dev/ folders created below live
        # inside input_dir, so a second run would otherwise try to open them.
        if not os.path.isfile(path):
            continue
        datas_dict[filename] = readTxt(path)

    # NOTE: only the shared-split mode produces output here; without
    # --same-split the files are read and nothing is written.
    if args.same_split:
        if args.dev_size is None:
            parser.error("--dev-size is required with --same-split")
        if not datas_dict:
            parser.error("no files found in {}".format(args.input_dir))

        # A shared split only makes sense when all files have the same
        # number of lines; otherwise the indices would not line up.
        lengths = {name: len(lines) for name, lines in datas_dict.items()}
        if len(set(lengths.values())) != 1:
            parser.error(
                "--same-split needs files with equal line counts, got {}".format(lengths)
            )
        num_lines = next(iter(lengths.values()))
        if not 0 <= args.dev_size <= num_lines:
            parser.error(
                "--dev-size must be between 0 and {}".format(num_lines)
            )

        all_ids = list(range(num_lines))
        random.shuffle(all_ids)
        # Take the last dev_size shuffled ids. Slicing from an explicit start
        # (rather than [-dev_size:]) keeps dev empty when dev_size is 0.
        # A set gives O(1) membership tests in the per-line loop below.
        dev_ids = set(all_ids[num_lines - args.dev_size:])

        train_dir = os.path.join(args.input_dir, "train")
        dev_dir = os.path.join(args.input_dir, "dev")
        os.makedirs(train_dir, exist_ok=True)
        os.makedirs(dev_dir, exist_ok=True)

        for key, lines in datas_dict.items():
            # Original line order is preserved within each subset.
            train_set = [line for i, line in enumerate(lines) if i not in dev_ids]
            dev_set = [line for i, line in enumerate(lines) if i in dev_ids]
            saveTxt(train_set, os.path.join(train_dir, key))
            saveTxt(dev_set, os.path.join(dev_dir, key))
