import pyarrow
from sklearn.model_selection import train_test_split
import argparse

# Seed handed to both train_test_split calls below so the partition is
# reproducible from run to run.
RANDOM_STATE = 42

# All four options are mandatory string paths, so they are declared from a
# single (flag, help text) table instead of four near-identical calls.
parser = argparse.ArgumentParser()
for option, help_text in (
    ("--dataset_path", "Path input dataset file"),
    ("--output_train_path", "Path to output train file"),
    ("--output_dev_path", "Path to output dev file"),
    ("--output_test_path", "Path to output test file"),
):
    parser.add_argument(option, type=str, required=True, help=help_text)
args = parser.parse_args()

# Read every record batch of the input Arrow file and materialise the result
# as one pandas DataFrame.
dataset = pyarrow.RecordBatchFileReader(args.dataset_path).read_all().to_pandas()

# Shares of the full dataset reserved for the test and dev (validation) splits.
test_frac = 0.1
valid_frac = 0.1
holdout_frac = test_frac + valid_frac

# Stage 1: separate the combined test+dev holdout from the training rows.
train, test_dev = train_test_split(dataset, test_size=holdout_frac, random_state=RANDOM_STATE)

# Stage 2: divide the holdout into test and dev.
# train_test_split returns (remainder, part sized by test_size). Because the
# result is unpacked as (test, dev), `dev` is the part sized by `test_size`,
# so test_size must be the validation share of the holdout. Using the test
# share here would silently swap the two sizes whenever test_frac != valid_frac.
# With the current equal fractions the value is 0.5 either way, so the
# generated splits are unchanged.
test, dev = train_test_split(test_dev, test_size=valid_frac / holdout_frac, random_state=RANDOM_STATE)

# Convert all three splits to Arrow tables before any output file is opened.
# NOTE(review): from_pandas is called with its default index handling; the
# shuffled splits presumably no longer carry a plain RangeIndex, so an extra
# index column may end up in the output files — confirm this is intended.
table_train, table_dev, table_test = (
    pyarrow.Table.from_pandas(frame) for frame in (train, dev, test)
)

# Write each table to its destination path (train, then dev, then test),
# creating every writer with the schema of the table it is about to receive.
for destination, table in (
    (args.output_train_path, table_train),
    (args.output_dev_path, table_dev),
    (args.output_test_path, table_test),
):
    with pyarrow.RecordBatchFileWriter(destination, table.schema) as writer:
        writer.write_table(table)

