import pandas as pd
from pathlib import Path
from sqlalchemy import create_engine, MetaData, Table, Index
from sqlalchemy.types import String, Integer
from tqdm import tqdm
import json
import argparse

def form_sqlite_str(path):
    """Build a SQLAlchemy connection URL for the SQLite file at *path*.

    *path* may be a ``str`` or any path-like object; it is formatted as-is
    directly after the ``sqlite:///`` scheme prefix.
    """
    url_prefix = 'sqlite:///'
    return '{}{}'.format(url_prefix, path)

def sample_test_from_db(db_path, n, *, random_state=None):
    """Draw a class-balanced random sample from a database's ``test`` table.

    Rows with ``c_polarity == 1`` are treated as positives and rows with
    ``c_polarity == 0`` as negatives; rows with any other polarity value are
    ignored. The same number of rows is sampled (without replacement) from
    each class, capped at the size of the smaller class.

    Args:
        db_path: Path to the SQLite database file.
        n: Desired number of rows per class.
        random_state: Optional seed / RNG forwarded to ``DataFrame.sample`` so
            the sample can be reproduced; ``None`` (default) keeps it random.

    Returns:
        ``(test_pos, test_neg)`` DataFrames, each re-indexed from 0.
    """
    engine = create_engine(form_sqlite_str(db_path))
    try:
        test = pd.read_sql_table('test', engine)
    finally:
        # Release pooled connections now rather than at garbage collection;
        # this function is called once per input database.
        engine.dispose()
    test_pos = test[test.c_polarity == 1]
    test_neg = test[test.c_polarity == 0]
    # Never request more rows than the smaller class holds
    # (sample() without replacement would raise).
    n = min(n, test_pos.shape[0], test_neg.shape[0])
    test_pos = test_pos.sample(n=n, random_state=random_state).reset_index(drop=True)
    test_neg = test_neg.sample(n=n, random_state=random_state).reset_index(drop=True)
    print(db_path, test.shape, test_pos.shape, test_neg.shape)
    return test_pos, test_neg

def create_id_index(table_name, engine):
    """Create an index named ``<table_name>_idx`` on the table's ``id`` column.

    The table definition is reflected from the database behind *engine*, so
    the table must already exist and contain an ``id`` column.
    """
    # ``MetaData(bind=...)`` and ``Table(..., autoload=True)`` were deprecated
    # in SQLAlchemy 1.4 and removed in 2.0; ``autoload_with`` works in both.
    metadata = MetaData()
    table = Table(table_name, metadata, autoload_with=engine)
    id_idx = Index(f'{table_name}_idx', table.c['id'])
    id_idx.create(bind=engine)
    
def create_combined_test_df(db_paths, per_class):
    """Combine balanced samples from several databases into one DataFrame.

    Calls ``sample_test_from_db`` on each path, then stacks all positive
    samples (in ``db_paths`` order) followed by all negative samples, and
    assigns an ``id`` column equal to the 0-based row position.

    Args:
        db_paths: Iterable of SQLite database paths.
        per_class: Requested rows per class per database.

    Returns:
        The combined DataFrame with a fresh 0-based index and ``id`` column.

    Raises:
        ValueError: If ``db_paths`` is empty.
    """
    pos_parts, neg_parts = [], []
    for db_path in tqdm(db_paths):
        test_pos, test_neg = sample_test_from_db(db_path, n=per_class)
        pos_parts.append(test_pos)
        neg_parts.append(test_neg)
    if not pos_parts:
        raise ValueError('create_combined_test_df: db_paths is empty')
    # One concat at the end avoids re-copying a growing frame every iteration.
    test = pd.concat(pos_parts + neg_parts, axis=0).reset_index(drop=True)
    test['id'] = test.index
    return test
    
# SQLAlchemy column types passed as ``dtype=`` to ``DataFrame.to_sql`` when
# writing the combined test table. Columns not listed here are typed by pandas.
dtypes = {
    # ``p_*`` columns (presumably post-level fields — confirm against source DBs)
    'p_title': String(),
    'p_created_utc': Integer(),
    'p_subreddit': String(),
    'p_is_self': String(),
    'p_score': Integer(),
    'p_author': String(),
    'p_author_flair_text': String(),
    'p_selftext': String(),
    'link_id': String(),
    # ``c_*`` columns (presumably comment-level fields); ``c_polarity`` is
    # the 0/1 label used for class-balanced sampling.
    'c_created_utc': Integer(),
    'c_score': Integer(),
    'c_author': String(),
    'c_body': String(),
    'c_polarity': Integer(),
    'year': Integer(),
    'month': Integer(),
    # Row id assigned by create_combined_test_df.
    'id': Integer(),
}

def main():
    """Build a class-balanced test set from several SQLite databases.

    Reads a JSON file (``--subs``) whose entries are iterated to build input
    DB filenames ``<data_dir>/<prefix><entry><suffix>.db`` (presumably a list
    of subreddit names). It samples up to ``--sample_size // 2`` rows per
    class from each DB's ``test`` table and writes the combined frame both to
    a SQLite table (with an index on ``id``) and to a CSV file in ``--out_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--subs", type=str, required=True)
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--prefix", type=str, default="")
    parser.add_argument("--suffix", type=str, default="")
    parser.add_argument("--sample_size", type=int, default=3000)
    parser.add_argument("--out_dir", type=str, required=True)
    parser.add_argument("--db_name", type=str, default='test_data.db')
    parser.add_argument("--csv_name", type=str, default='test_data.csv')
    parser.add_argument("--table_name", type=str, default='test')
    args = parser.parse_args()

    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(args.subs, encoding='utf-8') as fh:
        subs = json.load(fh)

    data_dir, out_dir, sample_size = Path(args.data_dir), Path(args.out_dir), args.sample_size
    db_name, csv_name, table_name = args.db_name, args.csv_name, args.table_name
    # SQLite cannot create missing parent directories, and to_csv fails on them too.
    out_dir.mkdir(parents=True, exist_ok=True)

    db_paths = [data_dir / f'{args.prefix}{sub}{args.suffix}.db' for sub in subs]
    per_class = sample_size // 2

    test = create_combined_test_df(db_paths, per_class)
    engine = create_engine(form_sqlite_str(str(out_dir / db_name)))
    try:
        test.to_sql(table_name, engine, dtype=dtypes)
        create_id_index(table_name, engine)
    finally:
        engine.dispose()
    test.to_csv(out_dir / csv_name)


if __name__ == '__main__':
    main()