import pandas as pd
from tqdm import tqdm
import argparse
import json
from pathlib import Path

def main(argv=None) -> None:
    """Split a submissions JSON file into one JSON file per subreddit.

    Reads the list of subreddits to consider, loads the submissions file
    into a DataFrame, and writes ``<output-dir>/<subreddit>.json`` for every
    listed subreddit. Subreddit names are lowercased on both sides, so the
    match is case-insensitive and the output file names are lowercase.
    A file is written for every listed subreddit, even when no submission
    matches it (the file then holds an empty frame).

    Args:
        argv: Optional list of command-line arguments (without the program
            name). When ``None``, argparse falls back to ``sys.argv[1:]``,
            which is the behaviour when run as a script.

    Raises:
        KeyError: If the loaded submissions have no ``subreddit`` column.
    """
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='Load a json file with submission details and create a separate json file for each subreddit')
    parser.add_argument('--input-file', required=True, help='Path to the input submissions json')
    parser.add_argument('--subreddit-file', required=True, help='Path to the JSON file containing the list of subreddits to consider')
    parser.add_argument('--output-dir', required=True, help='Path to the output directory')
    args = parser.parse_args(argv)

    input_file = args.input_file
    subreddit_file = args.subreddit_file
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Read the list of subreddits. JSON text is UTF-8, so do not depend on
    # the platform's default encoding.
    with open(subreddit_file, 'r', encoding='utf-8') as f:
        subreddit_list = {e.lower() for e in json.load(f)}

    # Read the input file
    df = pd.read_json(input_file)
    df['subreddit'] = df['subreddit'].str.lower()

    # Partition the rows once instead of re-scanning the whole dataframe for
    # every subreddit. Row order within each group is preserved.
    grouped = df.groupby('subreddit', sort=False)
    present = set(grouped.groups)
    # Written for listed subreddits that have no submissions; keeps the
    # original columns so the output has the same shape as a filtered frame.
    empty_df = df.iloc[0:0]

    # Save each subreddit's rows as a JSON file in the output directory.
    # Sorting only makes the processing order deterministic.
    for subreddit in tqdm(sorted(subreddit_list)):
        if subreddit in present:
            sub_df = grouped.get_group(subreddit)
        else:
            sub_df = empty_df
        sub_df.to_json(output_dir / f'{subreddit}.json')

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()

# Example usage: python data/gen_submissions_per_subreddit.py --input-file /mnt/e/reddit/emb_psr/2018-01-out.json --subreddit-file data/emb_psr_subs.json --output-dir /mnt/e/reddit/emb_psr/2018-01