import pyarrow
from tqdm import tqdm
import argparse

# Command-line interface: the input Arrow file and the output text file must
# both be given explicitly.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--dataset_path",
    type=str,
    required=True,
    help="Path to input dataset in pyarrow format",
)
parser.add_argument(
    "--output_path",
    type=str,
    required=True,
    help="Path to output dataset in model's format",
)
args = parser.parse_args()

# Read every record batch of the Arrow file into one table, then convert it
# to plain Python objects: {column_name: [value_per_row, ...]}.
dataset = (
    pyarrow.RecordBatchFileReader(args.dataset_path)
    .read_all()
    .to_pydict()
)

def generator(data=None):
    """Yield ``(token, label_raw)`` pairs for every token of every row.

    Args:
        data: Mapping with parallel ``'tokens'`` and ``'punctuation'``
            columns, where row ``i`` of each column is a sequence aligned
            token-by-token. Defaults to the module-level ``dataset`` loaded
            from ``--dataset_path``.

    Yields:
        ``(token, label_raw)`` tuples, in row order and then token order.

    Raises:
        IndexError: if ``data['punctuation']`` has fewer rows than
            ``data['tokens']``. This is raised lazily, when the missing row
            is reached.
    """
    if data is None:
        data = dataset
    labels_column = data['punctuation']
    for i, tokens in enumerate(data['tokens']):
        # Index the label column rather than zipping the two columns, so a
        # short 'punctuation' column still raises instead of being silently
        # dropped. The inner zip() stops at the shorter of the two per-row
        # sequences, so extra tokens or labels in a row are ignored.
        yield from zip(tokens, labels_column[i])

# generator() yields one item per token, not per row, so the progress-bar
# total must count tokens; len(dataset['tokens']) is only the row count.
n_tokens = sum(len(tokens) for tokens in dataset['tokens'])

# Write one "<token>\t<label>" line per token. The encoding is set explicitly
# because open()'s default is locale-dependent and may not be able to encode
# non-ASCII tokens.
with open(args.output_path, 'w', encoding='utf-8') as f:
    for token, label_raw in tqdm(generator(), total=n_tokens):
        f.write(f"{token}\t{label_raw}\n")