import pyarrow
from tqdm import trange
import argparse

def format_example(tokens, labels):
    """Return one output line for a single example.

    Each non-empty token is followed by a ``_PUNC_<label>`` marker built from
    the label at the same position; spaces inside a label are spelled out as
    ``SPACE`` so that the result stays whitespace-separated.

    Args:
        tokens: Sequence of token strings for one example.
        labels: Sequence of punctuation label strings, paired positionally
            with ``tokens`` (extra items on either side are ignored by
            ``zip``).

    Returns:
        The tokens and markers joined by single spaces, without a trailing
        newline. An example with no non-empty tokens yields ``""``.
    """
    pieces = []
    for token, label_raw in zip(tokens, labels):
        # Empty tokens are dropped together with their label.
        if not token:
            continue

        pieces.append(token)
        pieces.append("_PUNC_" + label_raw.replace(" ", "SPACE"))

    return " ".join(pieces)


def parse_args(argv=None):
    """Parse command-line arguments.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Namespace with ``dataset_path`` and ``output_path`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, required=True, help="Path to input dataset in pyarrow format")
    parser.add_argument("--output_path", type=str, required=True, help="Path to output dataset in model's format")
    return parser.parse_args(argv)


def main(argv=None):
    """Convert the input dataset to one text line per example.

    Reads the ``tokens`` and ``punctuation`` columns from the file at
    ``--dataset_path`` and writes the formatted lines to ``--output_path``.
    """
    args = parse_args(argv)

    dataset = pyarrow.RecordBatchFileReader(args.dataset_path).read_all().to_pydict()
    all_tokens = dataset['tokens']
    all_labels = dataset['punctuation']

    # Explicit UTF-8: the default encoding is locale-dependent and can fail
    # (or produce a platform-specific file) when tokens contain non-ASCII text.
    with open(args.output_path, 'w', encoding='utf-8') as f:
        for i in trange(len(all_tokens)):
            text = format_example(all_tokens[i], all_labels[i])
            f.write(f"{text}\n")


if __name__ == "__main__":
    main()