import argparse
import json
from pathlib import Path
import pandas as pd
from tqdm import tqdm
from sqlalchemy import create_engine, MetaData, Table, Index
from sqlalchemy.types import String, Integer


# Columns kept from each input frame, in the order they appear in the output.
cols = [
    'author',
    'author_flair_text',
    'subreddit',
    'title',
    'selftext',
    'score',
    'is_self',
    'id',
    'created_utc',
]

# SQL column types handed to ``DataFrame.to_sql`` via its ``dtype`` argument.
# ``link_id`` is not in ``cols``; it is added to the frame in ``main``.
dtypes = {
    'author': String(),
    'author_flair_text': String(),
    'subreddit': String(),
    'title': String(),
    'selftext': String(),
    'score': Integer(),
    'is_self': String(),
    'link_id': String(),
    'created_utc': Integer(),
    'id': Integer(),
}

def form_sqlite_str(path):
    """Return the SQLite connection URL for the database file at ``path``.

    Args:
        path: Location of the database file (string or path-like); it is
            rendered with its default string formatting.

    Returns:
        The string ``sqlite:///`` followed by ``path``.
    """
    url_template = 'sqlite:///{}'
    return url_template.format(path)

def create_id_index(table_name, engine):
    """Create an index on the ``id`` column of an existing table.

    The table definition is reflected from the database behind ``engine``
    and an index named ``<table_name>_idx`` is created on its ``id`` column.

    Args:
        table_name: Name of the table to index. It must already exist in the
            database and have an ``id`` column.
        engine: SQLAlchemy engine connected to the database that holds the
            table.
    """
    # ``MetaData(bind=...)`` and ``Table(..., autoload=True)`` were deprecated
    # in SQLAlchemy 1.4 and removed in 2.0. ``autoload_with`` reflects the
    # table on both 1.x and 2.x, so the metadata no longer needs to be bound.
    metadata = MetaData()
    table = Table(table_name, metadata, autoload_with=engine)
    id_idx = Index(f'{table_name}_idx', table.c['id'])
    id_idx.create(bind=engine)

def _combine_subreddit_frames(subreddit, input_dirs):
    """Load ``<subreddit>.json`` from every input directory and stack the rows.

    Directories that do not contain a file for the subreddit are reported on
    stdout and skipped.

    Args:
        subreddit: Subreddit name used as the JSON file stem.
        input_dirs: Iterable of ``Path`` objects to search.

    Returns:
        A DataFrame with the rows of all files found, or an empty DataFrame
        with the columns in ``cols`` when no file exists.
    """
    frames = []
    for input_dir in input_dirs:
        input_file = input_dir / f'{subreddit}.json'
        if not input_file.exists():
            print(f'File {input_file} does not exist. Skipping...')
            continue
        frames.append(pd.read_json(input_file))
    if not frames:
        return pd.DataFrame(columns=cols)
    # Concatenate once: calling pd.concat inside the loop re-copies the
    # accumulated frame on every iteration.
    return pd.concat(frames)


def main():
    """Combine per-directory subreddit JSON files and export JSON and SQLite.

    For every subreddit named in ``--subreddit-file``, the matching
    ``<subreddit>.json`` files from the directories listed in
    ``--input-dirs-json`` are concatenated and reduced to the columns in
    ``cols``. The original ``id`` column is copied to ``link_id`` and ``id``
    is replaced by the row position. The result is written to
    ``<output-dir>/json/<subreddit>.json`` and to the table ``--table-name``
    in ``<output-dir>/db/<subreddit>.db``, with an index on ``id``.
    """
    parser = argparse.ArgumentParser(description='Combines the subreddit json data from multiple directories (multiple months) into a single json file for each subreddit.')
    parser.add_argument('--input-dirs-json', required=True, help='Path to the json list containing the list of input directories.')
    parser.add_argument('--subreddit-file', required=True, help='Path to the JSON file containing the list of subreddits to consider.')
    parser.add_argument('--output-dir', required=True, help='Path to the output directory.')
    parser.add_argument('--table-name', type=str, default='test', help='Name of the sqlite table.')
    args = parser.parse_args()

    table_name = args.table_name
    subreddit_file = args.subreddit_file
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    json_out_dir = output_dir / 'json'
    json_out_dir.mkdir(parents=True, exist_ok=True)
    db_out_dir = output_dir / 'db'
    db_out_dir.mkdir(parents=True, exist_ok=True)

    # Read the list of subreddits (lower-cased and de-duplicated)
    with open(subreddit_file, 'r', encoding='utf-8') as f:
        subreddit_list = {e.lower() for e in json.load(f)}

    # Read the list of input directories
    with open(args.input_dirs_json, 'r', encoding='utf-8') as f:
        input_dirs = [Path(e) for e in json.load(f)]

    # Combine the subreddit json data from multiple directories into a single
    # json file for each subreddit. Sorting makes the processing order
    # reproducible, since set iteration order is not.
    for subreddit in tqdm(sorted(subreddit_list)):
        combined_df = _combine_subreddit_frames(subreddit, input_dirs)
        combined_df = combined_df[cols].reset_index(drop=True)
        # Keep the original id under ``link_id``; ``id`` becomes the row number.
        combined_df['link_id'] = combined_df['id']
        combined_df['id'] = combined_df.index
        combined_df.to_json(json_out_dir / f'{subreddit}.json')
        # Save combined_df as the sqlite table ``table_name`` in
        # db_out_dir / f'{subreddit}.db'
        engine = create_engine(form_sqlite_str(db_out_dir / f'{subreddit}.db'))
        try:
            # NOTE: to_sql defaults to if_exists='fail', so this raises when
            # the table already exists (e.g. on a re-run into the same
            # output directory).
            combined_df.to_sql(table_name, engine, dtype=dtypes)
            create_id_index(table_name, engine)
        finally:
            # One engine is created per subreddit; release its connections
            # instead of letting them accumulate for the whole run.
            engine.dispose()


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == '__main__':
    main()

# Example usage: python data/combine_submissions_months.py --input-dirs-json data/submission_month_dirs.json --subreddit-file data/emb_psr_subs.json --output-dir /mnt/e/reddit/emb_psr/combined