import argparse
from collections import defaultdict
import json
import random


def _non_negative_int(text):
    """argparse ``type=`` callable: parse ``text`` as an int, rejecting negatives.

    A negative ``k`` would otherwise be used as a slice bound / sample size and
    silently select the wrong number of examples.
    """
    value = int(text)  # ValueError here is reported by argparse as invalid input
    if value < 0:
        raise argparse.ArgumentTypeError('must be >= 0, got %d' % value)
    return value


def parse_args(argv=None):
    """Parse the command line.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Namespace with ``input``, ``output``, ``k`` and ``seed`` attributes.
    """
    parser = argparse.ArgumentParser(
        description='Sample up to K examples per class from a JSON-lines file.')
    parser.add_argument('input', type=str)
    parser.add_argument('output', type=str)
    parser.add_argument('-k', type=_non_negative_int, default=4,
                        help='Number of examples / class.')
    parser.add_argument('--seed', type=int, default=None,
                        help='Random seed for reproducible sampling '
                             '(default: nondeterministic).')
    return parser.parse_args(argv)


def group_by_label(lines):
    """Group JSON-lines records by their ``'label'`` field.

    Args:
        lines: Iterable of strings, each holding one JSON object.
            Blank / whitespace-only lines are skipped.

    Returns:
        ``defaultdict(list)`` mapping each label to its records, in input order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``'label'`` key.
    """
    by_class = defaultdict(list)
    for line in lines:
        if not line.strip():
            # Tolerate empty lines (commonly a trailing newline at EOF).
            continue
        instance = json.loads(line)
        by_class[instance['label']].append(instance)
    return by_class


def sample_per_class(by_class, k, seed=None):
    """Draw up to ``k`` random records from every class and shuffle the result.

    The input mapping and its lists are not modified.

    Args:
        by_class: Mapping of label -> list of records.
        k: Maximum number of records to keep per class (must be >= 0).
        seed: Optional seed; the same seed and input give the same output.

    Returns:
        A new list with the selected records in random order.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        raise ValueError('k must be >= 0, got %d' % k)
    rng = random.Random(seed)  # private generator: no global RNG state touched
    dataset = []
    for instances in by_class.values():
        # Classes with fewer than k records contribute everything they have.
        dataset.extend(rng.sample(instances, min(k, len(instances))))
    rng.shuffle(dataset)  # mix classes so the output is not grouped by label
    return dataset


def main(argv=None):
    """Read the input file, sample per class, and write the output file."""
    args = parse_args(argv)

    with open(args.input, 'r', encoding='utf-8') as f:
        by_class = group_by_label(f)

    dataset = sample_per_class(by_class, args.k, seed=args.seed)

    with open(args.output, 'w', encoding='utf-8') as g:
        for instance in dataset:
            print(json.dumps(instance), file=g)


if __name__ == '__main__':
    main()
