import json
import pandas as pd
from functools import partial
from tqdm import tqdm
from pathlib import Path
import argparse
import pdb
import sqlalchemy
from sqlalchemy import create_engine
from sqlalchemy.types import String, Integer
from sklearn.model_selection import train_test_split

# Module-level scratch set; not referenced by any other code visible in this file.
ssss = set()

# File name passed to get_num_rows() to look up the per-split row counts.
COUNT_FILE = 'count.json'

# SQLAlchemy column type keyed by column name: `p_*` are submission fields,
# `c_*` are comment fields, the remaining keys are added after the merge.
dtypes = {
    'p_title': String(),
    'p_created_utc': Integer(),
    'p_subreddit': String(),
    'p_is_self': String(),
    'p_score': Integer(),
    'p_author': String(),
    'p_author_flair_text': String(),
    'p_selftext': String(),
    'link_id': String(),
    'c_created_utc': Integer(),
    'c_score': Integer(),
    'c_author': String(),
    'c_body': String(),
    'c_polarity': Integer(),
    'year': Integer(),
    'month': Integer(),
    'id': Integer(),
}

def read_jsonlist(path, filters=None):
    """Lazily yield the JSON objects stored one per line in ``path``.

    Args:
        path: Location of a UTF-8 text file holding one JSON document per line.
        filters: Optional iterable of predicates. Each predicate is called
            with the decoded object, and the object is yielded only when
            every predicate returns a truthy value. ``None`` (the default)
            applies no filtering.

    Yields:
        Every decoded object that passes all filters, in file order.

    Blank lines are ignored. A line that cannot be decoded, or for which a
    filter raises, is reported on stdout and skipped.
    """
    # Materialise once so a one-shot iterable can be reused for every line.
    filters = tuple(filters) if filters is not None else ()
    with open(path, encoding="utf8") as f:
        for ln in tqdm(f):
            ln = ln.strip()
            if not ln:
                continue
            try:
                di = json.loads(ln)
                # all() stops at the first rejecting filter.
                keep = all(filt(di) for filt in filters)
            except Exception:
                # Only ordinary errors are swallowed. A bare `except:` would
                # also trap KeyboardInterrupt, and, if it wrapped the yield,
                # GeneratorExit when the consumer closes the generator early.
                print('Encountered an error. Skipping a line.')
                continue
            # The yield sits outside the try so exceptions raised at the
            # suspension point are never mistaken for a bad input line.
            if keep:
                yield di

def score_filter(di, low=-1, high=5):
    """Tell whether the record's score lies outside the band ``[low, high)``.

    Args:
        di: Mapping with a ``'score'`` entry.
        low: Scores strictly below this value pass.
        high: Scores at or above this value pass.

    Returns:
        True when ``di['score'] >= high`` or ``di['score'] < low``.
    """
    score = di['score']
    if score >= high:
        return True
    return score < low

def sticky_filter(di):
    """Return True for records that are not stickied.

    Args:
        di: Mapping with a ``'stickied'`` entry.

    Returns:
        ``False`` when ``di['stickied']`` is truthy, ``True`` otherwise.

    Raises:
        KeyError: If ``di`` has no ``'stickied'`` entry.
    """
    # Logical negation is required here: the bitwise `~` maps Python bools to
    # -1 / -2, which are both truthy, so a `~`-based filter never rejects.
    return not di['stickied']

def time_filter(post, comment, thresh=60*60*10):
    """Check that a comment was created within ``thresh`` of its post.

    Args:
        post: Mapping with a ``'created_utc'`` entry.
        comment: Mapping with a ``'created_utc'`` entry.
        thresh: Largest accepted difference between the two timestamps
            (default 36000; presumably seconds, i.e. ten hours -- confirm).

    Returns:
        True when ``comment['created_utc'] - post['created_utc'] <= thresh``.
    """
    elapsed = comment['created_utc'] - post['created_utc']
    return elapsed <= thresh

def filter_subreddit(di, subreddits):
    """Check whether the record's subreddit is one of ``subreddits``.

    Args:
        di: Mapping that may carry a ``'subreddit'`` entry; a missing entry
            is treated as the empty string.
        subreddits: Container of subreddit names. The record's name is
            lower-cased before the membership test, so the container is
            expected to hold lower-case names.

    Returns:
        True when the lower-cased subreddit name is in ``subreddits``.
    """
    name = di.get('subreddit', '')
    return name.lower() in subreddits

def filter_posts(di, posts):
    """Check whether the record's ``link_id`` is one of ``posts``.

    Args:
        di: Mapping with a ``'link_id'`` entry.
        posts: Container of accepted link ids.

    Returns:
        True when ``di['link_id']`` is contained in ``posts``.
    """
    link_id = di['link_id']
    return link_id in posts

def filter_top_level_comments(di):
    """Check whether the record's parent id carries the ``'t3_'`` prefix.

    Args:
        di: Mapping with a string ``'parent_id'`` entry.

    Returns:
        True when ``di['parent_id']`` starts with ``'t3_'``, the prefix this
        file puts in front of submission ids.
    """
    parent_id = di['parent_id']
    return parent_id.startswith('t3_')

def read_submissions_and_comments(rs_loc, rc_loc, subs, rs_cols, rc_cols):
    """Load filtered submissions and comments and join them on ``link_id``.

    Submissions are kept when they pass ``score_filter``, ``sticky_filter``
    and belong to one of ``subs``. Comments are kept when they reference one
    of the kept submissions, pass ``filter_top_level_comments`` and pass
    ``score_filter``.

    Args:
        rs_loc: Path of the submissions JSON-lines file.
        rc_loc: Path of the comments JSON-lines file.
        subs: Container of lower-case subreddit names to keep.
        rs_cols: Submission fields to retain; must include ``'id'``.
        rc_cols: Comment fields to retain; must include ``'link_id'``.

    Returns:
        A tuple ``(rs, rc, merged)`` of DataFrames. In ``rs`` the ``id``
        column is renamed ``link_id`` (with the ``'t3_'`` prefix added) and
        every other column gets a ``p_`` prefix. In ``rc`` every column but
        ``link_id`` gets a ``c_`` prefix. ``merged`` is the inner join of
        the two on ``link_id``.

    Raises:
        KeyError: If no submission record survives filtering, since the
            resulting frame has no ``'id'`` column.
    """
    print('Reading submissions')
    rs_filters = [score_filter, sticky_filter, partial(filter_subreddit, subreddits=subs)]
    rs = pd.DataFrame(list(read_jsonlist(rs_loc, filters=rs_filters)))
    # Comments refer to their submission as 't3_<id>'.
    post_ids = {'t3_' + post_id for post_id in rs['id']}
    # Keep the requested columns in the order they were requested. Going
    # through a set here would make the column order depend on string hash
    # randomisation and therefore differ from run to run.
    rs = rs[[col for col in dict.fromkeys(rs_cols) if col in rs.columns]]

    print('Reading comments')
    rc_filters = [partial(filter_posts, posts=post_ids), filter_top_level_comments, score_filter]
    rc = pd.DataFrame(list(read_jsonlist(rc_loc, filters=rc_filters)))
    rc = rc[[col for col in dict.fromkeys(rc_cols) if col in rc.columns]]

    print('Merging submissions and comments')
    # rename() returns new frames, so the assignment below does not write
    # into a selection of the frames built above.
    rs = rs.rename(columns=lambda col: 'link_id' if col == 'id' else f'p_{col}')
    rc = rc.rename(columns=lambda col: 'link_id' if col == 'link_id' else f'c_{col}')
    rs['link_id'] = 't3_' + rs['link_id']
    merged = rs.merge(rc, on='link_id')
    return rs, rc, merged

class ScoreConverter:
    """Namespace for turning a numeric score into a three-way class label."""

    @classmethod
    def to_polarity(cls, score, low=-2, high=25):
        """Bucket ``score`` into one of three integer labels.

        Args:
            score: Value to classify.
            low: Upper bound (inclusive) of the label-0 bucket.
            high: Lower bound (inclusive) of the label-1 bucket.

        Returns:
            0 when ``score <= low``, 1 when ``score >= high``, 2 otherwise.
            The low bucket is tested first.
        """
        if score <= low:
            return 0
        if score >= high:
            return 1
        return 2

def get_num_rows(count_file):
    """Read the per-split row counts stored in a JSON file.

    Args:
        count_file: Path (``str`` or ``Path``) of a JSON file holding an
            object with ``'train'``, ``'val'`` and ``'test'`` entries.

    Returns:
        A tuple ``(train, val, test)`` taken from the file, or ``(0, 0, 0)``
        when the file does not exist.

    Raises:
        KeyError: If the file exists but lacks one of the three entries.
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    count_file = Path(count_file)
    if not count_file.exists():
        return 0, 0, 0
    # The context manager closes the handle even when decoding fails.
    with open(count_file, 'r') as f:
        di = json.load(f)
    return di['train'], di['val'], di['test']

def remove_null_comments(df):
    """Drop rows whose comment body is a deletion placeholder.

    Args:
        df: DataFrame with a ``c_body`` column.

    Returns:
        A new DataFrame without the rows whose ``c_body`` equals
        ``'[deleted]'`` or ``'[removed]'``, re-indexed from 0. An empty
        input yields an empty frame with the same columns.
    """
    # isin() always produces a boolean mask, including for an empty column,
    # whereas an element-wise apply() on an empty Series does not.
    is_placeholder = df['c_body'].isin(['[deleted]', '[removed]'])
    return df[~is_placeholder].reset_index(drop=True)

def add_extra_cols(df, year, month, count):
    """Attach ``year``, ``month`` and ``id`` columns to ``df`` in place.

    Args:
        df: DataFrame to extend; it is modified and also returned.
        year: Value stored in every row of the ``year`` column.
        month: Value stored in every row of the ``month`` column.
        count: Offset added to the frame's index to form the ``id`` column.

    Returns:
        The same ``df`` object, now carrying the three extra columns.
    """
    new_columns = {
        'year': year,
        'month': month,
        'id': df.index + count,
    }
    for column_name, values in new_columns.items():
        df[column_name] = values
    return df

class ExtractAndMerge():
    """Build train/val/test JSON files from one submissions/comments pair."""

    def __init__(self, subs, rs_cols, rc_cols, engine, dtypes):
        """Store the extraction settings.

        Args:
            subs: Iterable of subreddit names to keep; stored as a set.
            rs_cols: Submission fields to retain.
            rc_cols: Comment fields to retain.
            engine: Stored on the instance; not used by the methods below.
            dtypes: Stored on the instance; not used by the methods below.
        """
        self.subs = set(subs)
        self.rs_cols = rs_cols
        self.rc_cols = rc_cols
        self.engine = engine
        self.dtypes = dtypes

    def extract_and_merge(self, rs_file, rc_file, out_file, year, month):
        """Merge one pair of dump files and write the three splits as JSON.

        The merged frame gets a ``c_polarity`` column derived from
        ``c_score``, loses rows with placeholder comment bodies, and is split
        80/10/10 into train/val/test with a fixed random seed. Each split is
        re-indexed, given ``year``/``month``/``id`` columns (ids offset by the
        counts read from ``COUNT_FILE``) and written to
        ``<out_file>_train.json``, ``<out_file>_val.json`` and
        ``<out_file>_test.json``.

        Args:
            rs_file: Path of the submissions file.
            rc_file: Path of the comments file.
            out_file: Output path prefix.
            year: Value for the ``year`` column.
            month: Value for the ``month`` column.
        """
        _, _, merged = read_submissions_and_comments(
            rs_file, rc_file, self.subs, self.rs_cols, self.rc_cols
        )
        merged['c_polarity'] = merged.c_score.apply(ScoreConverter.to_polarity)
        merged = remove_null_comments(merged)

        # 80% train, then the remaining 20% is halved into val and test.
        train_part, held_out = train_test_split(merged, test_size=0.2, random_state=701)
        val_part, test_part = train_test_split(held_out, test_size=0.5, random_state=701)
        parts = [
            part.reset_index(drop=True)
            for part in (train_part, val_part, test_part)
        ]

        offsets = get_num_rows(COUNT_FILE)
        parts = [
            add_extra_cols(part, year, month, offset)
            for part, offset in zip(parts, offsets)
        ]

        prefix = str(out_file)
        suffixes = ('_train.json', '_val.json', '_test.json')
        for part, suffix in zip(parts, suffixes):
            part.to_json(prefix + suffix)
        print(f'JSON has been saved at location : {out_file}')


# Script body: parse the command line, then extract and merge one month of data.
parser = argparse.ArgumentParser()
for flag, options in (
    ("--year", dict(type=int, required=True)),
    ("--month", dict(type=int, required=True)),
    ("--data_dir", dict(type=str, required=True)),
    ("--out_file_suffix", dict(type=str, default='')),
    ("--subreddits_json", dict(type=str, default='subreddit_list.json')),
):
    parser.add_argument(flag, **options)
args = parser.parse_args()
print(args)

# Subreddit names are lower-cased once here; filter_subreddit lower-cases
# each record's name before comparing.
with open(args.subreddits_json) as f:
    subs = [name.lower() for name in json.load(f)]

# Fields retained from submission and comment records respectively.
rs_cols = [
    'author', 'author_flair_text', 'subreddit', 'title', 'selftext',
    'score', 'is_self', 'id', 'created_utc',
]
rc_cols = ['author', 'body', 'score', 'link_id', 'created_utc']
print(f'Filtering {len(subs)} number of subreddits.')

data_dir = Path(args.data_dir)
year = args.year
em = ExtractAndMerge(subs, rs_cols, rc_cols, None, dtypes)

# Zero-padded month used in the input and output file names.
i = str(args.month).zfill(2)
print(i)

out_dir = data_dir / 'out'
out_dir.mkdir(exist_ok=True)
em.extract_and_merge(
    data_dir / 'submissions' / f'RS_{year}-{i}',
    data_dir / 'comments' / f'RC_{year}-{i}',
    out_dir / f'{year}_{i}{args.out_file_suffix}',
    year,
    args.month,
)
