import pandas as pd
from tqdm import tqdm
import argparse
import json
from functools import partial
from collections import defaultdict

def get_sub2domain_map(path, filters=None):
    """Count linked domains per subreddit in a line-delimited JSON file.

    Args:
        path: Path to a UTF-8 text file holding one JSON object per line.
        filters: Optional iterable of predicates. Each is called with the
            decoded submission dict, and a submission is counted only when
            every predicate returns a truthy value. ``None`` (the default)
            applies no filtering.

    Returns:
        A nested ``defaultdict`` mapping subreddit -> domain -> number of
        submissions that passed all filters.

    Lines that are blank are skipped silently. Lines without a ``'domain'``
    key, and lines that fail to decode or otherwise raise while being
    processed, are skipped with a message printed to stdout.
    """
    # Materialise once per call; avoids a shared mutable default and lets
    # callers pass any iterable (including a one-shot generator).
    filters = () if filters is None else tuple(filters)
    sub2domain = defaultdict(lambda: defaultdict(int))
    with open(path, encoding="utf8") as f:
        for ln in tqdm(f):
            ln = ln.strip()
            if not ln:
                continue
            try:
                di = json.loads(ln)
                if 'domain' not in di:
                    print('No domain in this submission. Skipping a line.')
                    continue
                if not all(filt(di) for filt in filters):
                    continue
                sub2domain[di['subreddit']][di['domain']] += 1
            except Exception:
                # Deliberately best-effort: malformed JSON, missing keys, or
                # a failing filter only drop the offending line. Catching
                # Exception (not a bare except) lets KeyboardInterrupt and
                # SystemExit stop the run.
                print('Encountered an error. Skipping a line.')
                continue
    return sub2domain


def score_filter(di, high=5):
    """Return whether the submission's score reaches the threshold.

    Args:
        di: Submission dict; its ``'score'`` entry is compared.
        high: Minimum score (inclusive) required to pass.

    Returns:
        The result of ``di['score'] >= high``.
    """
    submission_score = di['score']
    reaches_threshold = submission_score >= high
    return reaches_threshold

def sticky_filter(di):
    """Return True when the submission is not stickied.

    Args:
        di: Submission dict; its ``'stickied'`` entry is read.

    Returns:
        ``False`` for a stickied submission, ``True`` otherwise.

    Raises:
        KeyError: If ``di`` has no ``'stickied'`` key.
    """
    # Logical negation, not bitwise: ~True == -2 and ~False == -1 are both
    # truthy, so `~` would let every submission through.
    return not di['stickied']

def filter_subreddit(di, subreddits):
    """Return whether the submission's subreddit is in ``subreddits``.

    Args:
        di: Submission dict; a missing ``'subreddit'`` key is treated as
            the empty string.
        subreddits: Container of subreddit names to test membership
            against. The submission's name is lower-cased before the test.

    Returns:
        The result of the membership test.
    """
    raw_name = di.get('subreddit', '')
    lowered = raw_name.lower()
    return lowered in subreddits

def main():
    """Command-line entry point.

    Parses ``--input-file``, ``--subreddit-file`` and ``--output-file``,
    loads the subreddit list (lower-cased into a set), builds the
    subreddit-to-domain count map from the input file using the sticky,
    subreddit and score filters, and writes the map to the output file as
    JSON.
    """
    cli = argparse.ArgumentParser(description='Load a pushshift submissions file and form a subreddit to domain map')
    cli.add_argument('--input-file', required=True, help='Path to the input file where every line is a JSON object')
    cli.add_argument('--subreddit-file', required=True, help='Path to the JSON file containing the list of subreddits to retain')
    cli.add_argument('--output-file', required=True, help='Path to the output JSON file')
    options = cli.parse_args()

    # Subreddit names are compared case-insensitively, so store them lowered.
    with open(options.subreddit_file, 'r') as handle:
        wanted_subreddits = {name.lower() for name in json.load(handle)}

    active_filters = [
        sticky_filter,
        partial(filter_subreddit, subreddits=wanted_subreddits),
        score_filter,
    ]
    domain_counts = get_sub2domain_map(options.input_file, filters=active_filters)

    with open(options.output_file, 'w') as handle:
        json.dump(domain_counts, handle)

# Run the command-line entry point only when executed as a script, not when
# this module is imported.
if __name__ == '__main__':
    main()

# Example usage: python get_post_urls.py --input-file /mnt/d/reddit/submissions/RS_2018-01 --subreddit-file bots_subs.json --output-file /mnt/e/reddit/sub2domain_maps/2018-01.json