import pandas as pd
from pathlib import Path
from sqlalchemy import create_engine, MetaData, Table, Index
from sqlalchemy.orm import sessionmaker
from tqdm import tqdm
import json
import argparse
from sqlalchemy.types import String, Integer

# Command-line interface. The integer "boolean" flags take 0 (off) or 1 (on).
parser = argparse.ArgumentParser(
    description="Load monthly train/val/test JSON files into SQLite database(s).")
parser.add_argument("--json_dir", type=str, required=True,
                    help="Directory containing input files named "
                         "{year}_{MM}{json_suffix}_{split}.json.")
parser.add_argument("--db_name", type=str, required=True,
                    help="Path of the output SQLite database (expected to end in .db). "
                         "With --split_db_by_subreddit 1, one database per subreddit is "
                         "written instead, named <db_name without .db>_{subreddit}.db.")
parser.add_argument("--subreddit_list", type=str, default=None,
                    help="Optional JSON file holding a list of subreddit names; only "
                         "rows from these subreddits are kept (case-insensitive).")
parser.add_argument("--json_suffix", type=str, default='',
                    help="String inserted into the input file names between the "
                         "month and the split.")
parser.add_argument("--split_db_by_subreddit", type=int, default=0,
                    help="1 to write a separate database per subreddit, "
                         "0 (default) to write a single database.")
parser.add_argument("--equal_sampling", type=int, default=0,
                    help="1 to resample each subreddit's training rows so every value "
                         "of --equal_sample_col has as many rows as the most frequent "
                         "value (only used with --split_db_by_subreddit 1).")
parser.add_argument("--equal_sample_col", type=str, default='c_polarity',
                    help="Column whose values define the groups balanced by "
                         "--equal_sampling.")
args = parser.parse_args()

json_dir = Path(args.json_dir)
db_path = args.db_name
json_list = args.subreddit_list  # e.g. 'subreddit_list.json'
json_suffix = args.json_suffix

def read_df(json_dir, year, month, split, subreddits, json_suffix):
    """Load one month's JSON file for one split and apply the row filters.

    Reads ``{json_dir}/{year}_{MM}{json_suffix}_{split}.json`` (month
    zero-padded to two digits), drops rows whose ``c_polarity`` equals 2, and,
    when ``subreddits`` is non-empty, keeps only rows whose ``p_subreddit``
    matches one of the given names case-insensitively. Rows with a missing
    ``p_subreddit`` never match the allow-list.

    Args:
        json_dir: Directory containing the monthly JSON files (str or Path).
        year: Year used in the file name.
        month: Month used in the file name (zero-padded to two digits).
        split: Split name used in the file name, e.g. 'train'.
        subreddits: Iterable of subreddit names to keep, or None/empty for no
            subreddit filtering.
        json_suffix: String inserted between the month and the split.

    Returns:
        The filtered DataFrame with a fresh 0..n-1 index.
    """
    file = Path(json_dir) / f'{year}_{str(month).zfill(2)}{json_suffix}_{split}.json'

    df = pd.read_json(file)
    # NOTE(review): rows with c_polarity == 2 are excluded; presumably a
    # neutral/unlabelled class - confirm against the data producer.
    df = df[df['c_polarity'] != 2]
    if subreddits:
        wanted = {name.lower() for name in subreddits}
        # Vectorised and NaN-safe: missing names lower-case to NaN, which is
        # never in `wanted` (the old per-row .lower() crashed on them).
        df = df[df['p_subreddit'].str.lower().isin(wanted)]
    return df.reset_index(drop=True)

def sample_by_group(df, sample_size, group_by, random_state=None):
    """Draw the same number of rows from every group of ``df``.

    Groups with fewer rows than ``sample_size`` are sampled with replacement
    (oversampled); the others are sampled without replacement.

    Args:
        df: Input DataFrame.
        sample_size: Rows to draw per group, or the string 'max' to use the
            size of the largest group.
        group_by: Column name (or list of names) passed to ``DataFrame.groupby``.
        random_state: Optional seed / RandomState / Generator forwarded to
            ``DataFrame.sample``. An integer seed is turned into a single
            RandomState shared by all groups, so groups get different draws.

    Returns:
        A DataFrame with ``sample_size`` rows per group, groups concatenated in
        groupby (sorted key) order. Original index labels are kept, so
        oversampled rows repeat their labels. An empty frame (same columns) is
        returned when ``df`` has no groups.
    """
    import numpy as np  # local import: numpy is already a pandas dependency

    groups = df.groupby(group_by)
    sizes = groups.size()
    if sizes.empty:
        return df.iloc[0:0]
    if isinstance(sample_size, str) and sample_size == 'max':
        sample_size = int(sizes.max())
    if isinstance(random_state, (int, np.integer)):
        random_state = np.random.RandomState(random_state)

    # Collect the per-group samples and concatenate once (repeated concat is quadratic).
    samples = [
        subdf.sample(n=sample_size, replace=len(subdf) < sample_size, random_state=random_state)
        for _, subdf in groups
    ]
    return pd.concat(samples, axis=0)

def df_to_sql(engine, json_dir, year, month, split, subreddits, json_suffix):
    """Append one month's rows for ``split`` to the SQL table named ``split``.

    The month is loaded with ``read_df``. Every row receives an ``id`` that
    continues the running per-split count in the module-level ``counts`` dict,
    so ids stay unique across monthly files. The rows are written with the
    module-level ``dtypes`` column types, and ``counts[split]`` is then
    advanced by the number of rows written.

    Args:
        engine: SQLAlchemy engine / connection accepted by ``DataFrame.to_sql``.
        json_dir, year, month, split, subreddits, json_suffix: Forwarded to
            ``read_df``; ``split`` is also the destination table name.
    """
    df = read_df(json_dir, year, month, split, subreddits, json_suffix)
    # Always (re)assign ids. create_id_index() later indexes the `id` column,
    # so it must exist even when the source JSON has none. Previously ids were
    # only added when the JSON already carried an `id` column.
    df = df.drop(columns=['id'], errors='ignore').reset_index(drop=True)
    df['id'] = df.index + counts[split]
    df.to_sql(split, con=engine, if_exists='append', index=False,
              dtype=dtypes)
    counts[split] += len(df)

def read_multiple_jsons(split, subreddits, json_suffix, years=(2018, 2019, 2020, 2021, 2022), months=range(1, 13)):
    """Read and concatenate one split's monthly files over ``years`` x ``months``.

    Files are read from the module-level ``json_dir`` via ``read_df``. Months
    from July 2023 onward are skipped.

    Args:
        split: Split name, e.g. 'train'.
        subreddits: Subreddit allow-list forwarded to ``read_df`` (or None).
        json_suffix: File-name suffix forwarded to ``read_df``.
        years: Years to read.
        months: Months to read within each year. This was previously ignored,
            and 1..12 was always used.

    Returns:
        One DataFrame with a fresh 0..n-1 index. If the data has an ``id``
        column, it is renumbered to match that index.

    Raises:
        ValueError: if ``years``/``months`` select no files at all.
    """
    frames = []
    for year in years:
        for month in tqdm(months):
            # NOTE(review): presumably no data exists past June 2023 - confirm.
            if year == 2023 and month >= 7:
                continue
            frames.append(read_df(json_dir, year, month, split, subreddits, json_suffix))
    if not frames:
        raise ValueError(f'no monthly files selected for split {split!r}')

    # Single concat instead of growing the frame inside the loop (quadratic).
    res = pd.concat(frames, axis=0, ignore_index=True)
    if 'id' in res.columns:
        res = res.drop(columns=['id'])
        res['id'] = res.index
    return res

def create_id_index(table_name, engine):
    """Create an index named ``{table_name}_idx`` on ``table_name.id``.

    The table definition is reflected from the database behind ``engine``, so
    the table must already exist and contain an ``id`` column (a KeyError is
    raised otherwise).
    """
    # `autoload_with=` replaces `MetaData(bind=...)` + `autoload=True`, which
    # were deprecated in SQLAlchemy 1.4 and removed in 2.0.
    table = Table(table_name, MetaData(), autoload_with=engine)
    Index(f'{table_name}_idx', table.c['id']).create(bind=engine)

# Explicit SQL column types passed to DataFrame.to_sql for every table this
# script writes.
dtypes = {
    'p_title': String(), 'p_created_utc': Integer(), 'p_subreddit': String(),
    'p_is_self': String(), 'p_score': Integer(), 'p_author': String(),
    'p_author_flair_text': String(), 'p_selftext': String(), 'link_id': String(),
    'c_created_utc': Integer(), 'c_score': Integer(), 'c_author': String(),
    'c_body': String(), 'c_polarity': Integer(),
    'year': Integer(), 'month': Integer(), 'id': Integer(),
}

# Rows written so far to each split table; df_to_sql uses it to offset ids.
counts = {'train': 0, 'val': 0, 'test': 0}
# Engine for the single combined database (single-database mode).
engine = create_engine(f'sqlite:///{db_path}', echo=False)

# Optional subreddit allow-list, lower-cased for case-insensitive matching.
if json_list:
    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(json_list, encoding='utf-8') as f:
        subreddits = [e.lower() for e in json.load(f)]
else:
    subreddits = None

if args.split_db_by_subreddit == 0:
    # Single database: stream month by month into one train/val/test table each.
    if args.equal_sampling:
        print('Warning: --equal_sampling is only applied with '
              '--split_db_by_subreddit 1; ignoring it.')
    for year in [2018, 2019, 2020, 2021, 2022, 2023]:
        for month in tqdm(range(1, 13)):
            # Same cut-off as read_multiple_jsons: stop after June 2023
            # (presumably the data ends there - confirm).
            if year == 2023 and month >= 7:
                break
            for split in ['train', 'val', 'test']:
                df_to_sql(engine, json_dir, year, month, split, subreddits, json_suffix)
            print(counts)
    for split in ['train', 'val', 'test']:
        create_id_index(split, engine)
    print(counts)
else:
    # One database per subreddit: load every month up front, then partition.
    df_train = read_multiple_jsons('train', subreddits, json_suffix)
    df_val = read_multiple_jsons('val', subreddits, json_suffix)
    df_test = read_multiple_jsons('test', subreddits, json_suffix)
    # Strip a trailing '.db' only if present (plain [:-3] mangled other names).
    db_stem = db_path[:-len('.db')] if db_path.endswith('.db') else db_path
    for sub in tqdm(df_train['p_subreddit'].unique()):
        sub_frames = {
            split: frame[frame['p_subreddit'] == sub].reset_index(drop=True)
            for split, frame in (('train', df_train), ('val', df_val), ('test', df_test))
        }
        if args.equal_sampling:
            sub_frames['train'] = sample_by_group(
                sub_frames['train'], 'max', args.equal_sample_col).reset_index(drop=True)

        sub_engine = create_engine(f'sqlite:///{db_stem}_{sub}.db', echo=False)
        try:
            for split, subdf in sub_frames.items():
                subdf['id'] = subdf.index  # ids restart at 0 in every database
                subdf.to_sql(split, con=sub_engine, if_exists='append', index=False,
                             dtype=dtypes)
                create_id_index(split, sub_engine)
        finally:
            # Release this database's pooled connections before the next subreddit.
            sub_engine.dispose()

print('Done writing all data.')
