import argparse
import json
import pyarrow
import numpy as np
from collections import Counter

def select_classes(marks, top_fraction):
    """Pick the punctuation classes covering ``top_fraction`` of non-dominant marks.

    The most frequent mark (the "dominating class") is always kept. The
    remaining marks are then taken in descending frequency order until their
    cumulative count reaches ``top_fraction`` of all non-dominant marks. The
    threshold is checked after appending, so the mark that crosses it is
    included. As a consequence, at least one non-dominant class is kept
    whenever one exists.

    Args:
        marks: Iterable of hashable punctuation labels (e.g. a 1-D NumPy array).
        top_fraction: Target share of the non-dominant marks to cover.
            Values >= 1 keep every class.

    Returns:
        List of selected classes, most frequent first, as native Python
        objects so the result is JSON-serializable.

    Raises:
        ValueError: If ``marks`` is empty.
    """
    counts = Counter(marks)
    if not counts:
        raise ValueError("no punctuation marks found in the dataset")

    ranked = counts.most_common()
    dominant_mark, dominant_count = ranked[0]
    # Total support of every class except the dominant one. Whenever the loop
    # below executes, at least one such mark exists, so this is >= 1.
    remaining = sum(counts.values()) - dominant_count

    classes = [dominant_mark]
    covered = 0
    for mark, count in ranked[1:]:
        covered += count
        classes.append(mark)
        if covered / remaining >= top_fraction:
            break

    # NumPy scalars such as np.int64 are not JSON-serializable; unwrap them
    # to plain Python values (np.str_ becomes str).
    return [m.item() if isinstance(m, np.generic) else m for m in classes]


def main():
    """Command-line entry point: read the dataset, select classes, write JSON."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, required=True, help="Path to dataset")
    parser.add_argument("--top_fraction", type=float, required=True, help="What fraction of punctuation marks to keep")
    parser.add_argument("--output_path", type=str, required=True, help="Path to output json file")
    args = parser.parse_args()

    # Read the Arrow IPC file given on the command line (not a hard-coded path).
    column = pyarrow.RecordBatchFileReader(args.dataset_path).read_pandas()["punctuation"]
    # NOTE(review): presumably each row holds a sequence of per-token
    # punctuation labels; flatten them into one array — confirm against the
    # dataset builder.
    marks = np.concatenate(column.values)

    classes = select_classes(marks, args.top_fraction)

    print(f"Got total of {len(classes)} classes")

    with open(args.output_path, "w", encoding="utf-8") as f:
        json.dump(classes, f)


if __name__ == "__main__":
    main()
