from typing import List, Tuple
import pyarrow
from tqdm import tqdm
from numba import njit
import argparse
import re

_RE_COMBINE_WHITESPACE = re.compile(r"\s+")

# Command-line interface: both the input corpus and the output file path
# must be given explicitly.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--dataset_path",
    type=str,
    required=True,
    help="Path to input dataset",
)
parser.add_argument(
    "--output_path",
    type=str,
    required=True,
    help="Path to output file",
)
# Parsed at import time; the rest of the script reads args.dataset_path and
# args.output_path.
args = parser.parse_args()

def parse_sample(text: str) -> Tuple[List[str], List[str], List[bool]]:
    """Split text into lowercase tokens, separator labels and case flags.

    A token is a maximal run of alphanumeric characters (``str.isalnum``).
    The characters between one token and the next -- punctuation plus any
    whitespace, with every whitespace run collapsed to a single space --
    form the label of the preceding token. The last token's label is
    whatever follows it, or ``"<eos>"`` when the text ends with a token.

    Leading and trailing whitespace (such as the newline kept by
    ``readlines``) is ignored, so a line's labels do not depend on how the
    line was terminated, and leading blanks do not create an empty token.

    If the text starts with a non-alphanumeric character, the first token is
    the empty string and its label is that leading punctuation.

    Args:
        text: The text to parse, e.g. one line of the corpus.

    Returns:
        ``(tokens, labels, upper_case)`` -- three lists of equal length where
        ``labels[i]`` is the separator following ``tokens[i]`` and
        ``upper_case[i]`` tells whether ``tokens[i]`` began with an uppercase
        character. All three lists are empty for empty or whitespace-only
        text.
    """
    tokens = []
    labels = []
    upper_case = []

    # Collapse whitespace runs to one space and drop surrounding whitespace
    # (str.split() uses the same Unicode whitespace set as the regex \s).
    # Without the strip, a trailing "\n" would turn the final label into " "
    # (or ". ") instead of "<eos>" (or "."), and a leading blank would yield
    # an empty first token.
    text = " ".join(text.split())
    if not text:
        return tokens, labels, upper_case

    is_token = True
    token_upper = text[0].isupper()
    actual_token = ""
    actual_interpunction = ""

    for char in text:
        if char.isalnum():
            if not is_token:
                # A new token starts: the pending separator belongs to the
                # previous token.
                labels.append(actual_interpunction)
                actual_interpunction = ""
                token_upper = char.isupper()

            is_token = True
            actual_token += char.lower()
        else:
            if is_token:
                # The current token ends here.
                tokens.append(actual_token)
                upper_case.append(token_upper)
                actual_token = ""

            is_token = False
            actual_interpunction += char

    if is_token:
        # Text ended inside a token: flush it with the end-of-sequence label.
        tokens.append(actual_token)
        upper_case.append(token_upper)
        labels.append("<eos>")
    else:
        # Text ended with separators: they label the last token.
        labels.append(actual_interpunction)

    return tokens, labels, upper_case

# Read the corpus as UTF-8 explicitly (the platform default encoding is not
# guaranteed to be UTF-8) and keep only lines with more than four words.
# str.split() with no argument counts real words; split(" ") would also count
# the empty strings produced by repeated spaces (so a line of five blanks
# passed) and would not split on tabs.
with open(args.dataset_path, "r", encoding="utf-8") as f:
    lines_filtered = [line for line in f if len(line.split()) > 4]

# Column-oriented buffers: each kept line contributes one list-valued row to
# every column.
dataset = {
    'tokens': [],
    'punctuation': [],
    'uppercase': []
}

for line in tqdm(lines_filtered):
    tokens, punctuation, uppercase = parse_sample(line)

    dataset['tokens'].append(tokens)
    dataset['punctuation'].append(punctuation)
    dataset['uppercase'].append(uppercase)

# Write all rows to a single Arrow IPC file.
table = pyarrow.Table.from_pydict(dataset)
with pyarrow.RecordBatchFileWriter(args.output_path, table.schema) as writer:
    writer.write_table(table)

