import json
from tqdm import tqdm
from collections import defaultdict
import pandas as pd
from pathlib import Path
import argparse

def form_sqlite_str(path):
    """Return *path* prefixed with ``sqlite:///``.

    *path* is rendered with its default string formatting, so both plain
    strings and path-like objects are accepted.
    """
    return 'sqlite:///{}'.format(path)

def read_json(fpath):
    """Load and return the JSON document stored in the file at *fpath*.

    The file is decoded as UTF-8 explicitly rather than with the platform's
    default encoding, so non-ASCII content is read the same way on every OS.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(fpath, encoding='utf-8') as f:
        return json.load(f)

def add_user_counts(uc1, uc2):
    """Combine two ``user -> {subreddit -> count}`` mappings by summing counts.

    Every user present in either input appears in the result. For each user,
    every subreddit present in either of that user's mappings appears, with
    a missing count treated as 0. Neither input is modified. A tqdm progress
    bar is shown while iterating over the users.
    """
    merged = {}
    all_users = set(uc1) | set(uc2)
    for user in tqdm(all_users):
        left = uc1.get(user, {})
        right = uc2.get(user, {})
        totals = {}
        for sub in set(left) | set(right):
            totals[sub] = left.get(sub, 0) + right.get(sub, 0)
        merged[user] = totals
    return merged

def get_user_counts_for_list(user_count_list):
    """Read every JSON file in *user_count_list* and merge them into one mapping.

    Args:
        user_count_list: iterable of paths to JSON files, each of which is
            loaded with ``read_json`` and combined with ``add_user_counts``.

    Returns:
        The merged mapping, or an empty dict when *user_count_list* yields
        no paths (instead of failing with ``IndexError``).
    """
    paths = iter(user_count_list)
    first = next(paths, None)
    if first is None:
        # Nothing to merge, e.g. the input directory contained no JSON files.
        return {}
    # The first file is loaded as-is; only subsequent files need merging.
    res = read_json(first)
    for f in paths:
        res = add_user_counts(res, read_json(f))
    return res

def cvt_uc_counts_to_sub2user_list(uc_counts):
    """Invert ``user -> {subreddit -> count}`` into ``subreddit -> set of users``.

    Only the keys of each per-user mapping are used; the count values are
    ignored. The result is a ``defaultdict(set)``, so looking up a subreddit
    that is not present yields (and stores) an empty set. A tqdm progress bar
    is shown while iterating over the users.
    """
    users_by_sub = defaultdict(set)
    # Lazily flatten the nested mapping into (subreddit, user) pairs.
    pairs = (
        (sub, user)
        for user, per_sub in tqdm(uc_counts.items())
        for sub in per_sub
    )
    for sub, user in pairs:
        users_by_sub[sub].add(user)
    return users_by_sub

def calculate_coocc_mat(sub2user):
    """Build the subreddit-by-subreddit user co-occurrence matrix.

    Args:
        sub2user: mapping ``subreddit -> set of users``.

    Returns:
        A ``pandas.DataFrame`` whose rows and columns are the keys of
        *sub2user* (in the mapping's iteration order). Cell ``[s1, s2]`` is
        the number of users common to both subreddits; the diagonal holds
        each subreddit's own user count. An empty mapping gives an empty
        DataFrame.
    """
    subs = list(sub2user)
    counts = {s: {} for s in subs}
    for i, s1 in enumerate(subs):
        users1 = sub2user[s1]
        # The matrix is symmetric, so intersect each unordered pair once and
        # mirror the value. Visiting pairs in this order also fills every
        # inner dict in the same key order as `subs`.
        for s2 in subs[i:]:
            shared = len(users1.intersection(sub2user[s2]))
            counts[s1][s2] = shared
            counts[s2][s1] = shared
    return pd.DataFrame(counts)


# Command-line interface: input/output locations plus the user and subreddit
# filtering thresholds applied before the cooccurrence matrices are built.
parser = argparse.ArgumentParser(
    description="Build subreddit cooccurrence matrices from json files mapping users to subreddit counts.")
parser.add_argument("--in_dir", type=str, required=True,
    help="Directory containing the json files mapping users to subreddit counts.")
parser.add_argument("--out_dir", type=str, required=True,
    help="Directory to save the cooccurrence matrix in.")
parser.add_argument("--min_comments", type=int, default=10,
    help="Users must have at least these many comments in a subreddit to be included in the user subreddit cooccurrence matrix.")
parser.add_argument("--min_subreddits", type=int, default=2,
    help="Users must have 'min_comments' number of comments in 'min_subreddits' number of subreddits to be included in the user subreddit cooccurrence matrix.")
parser.add_argument("--ignore_n_top_subs", type=int, default=200,
    help="Ignore these many subreddits with the most number of commenters (after applying the above filtering)")
# The help text names the option that is actually skipped (ignore_n_top_subs).
parser.add_argument("--n_top_subs", type=int, default=2000,
    help="Number of subreddits with most commenters (after skipping the top 'ignore_n_top_subs') to include in the cooccurrence matrix.")
parser.add_argument("--required_subs", type=str, default='./required_subs.json',
    help="Subreddits in this json (list) will always be included in the cooccurrence matrix.")
args = parser.parse_args()


# Resolve the parsed command-line arguments into module-level settings.
in_dir = Path(args.in_dir)
out_dir = Path(args.out_dir)
# parents=True: create missing intermediate directories too, instead of
# failing with FileNotFoundError when --out_dir is a nested new path.
out_dir.mkdir(parents=True, exist_ok=True)
min_comments = args.min_comments
min_subreddits = args.min_subreddits
ignore_n_top_subs = args.ignore_n_top_subs
n_top_subs = args.n_top_subs
# Required subreddit names, lower-cased. Loaded through the shared
# read_json helper rather than a second hand-written open/json.load.
subs = {e.lower() for e in read_json(args.required_subs)}


# --- 1. Merge all per-file user -> subreddit -> count mappings --------------
# Sorted so the file list (and its printout) is the same on every run.
# NOTE(review): if out_dir is the same directory as in_dir, the json files
# written below would be picked up as inputs on a later run - confirm that
# the two directories are always distinct.
user_count_list = sorted(in_dir.glob('*.json'))
print('The following json files are being used as the input files:')
print(user_count_list)
uc = get_user_counts_for_list(user_count_list)
print('Length of the final user count dict: ' + str(len(uc)))
with open(out_dir / 'raw_user_sub_counts.json', 'w') as f:
    json.dump(uc, f)

# --- 2. Filter users --------------------------------------------------------
# For each user keep only subreddits with at least `min_comments` comments,
# then keep the user only if at least `min_subreddits` subreddits remain.
ucf = {}
for u, sub_counts in uc.items():
    active = {s: v for s, v in sub_counts.items() if v >= min_comments}
    if len(active) >= min_subreddits:
        ucf[u] = active
print('After filtering users using "min_comments" and "min_subreddits", we are left with these many users: ' + str(len(ucf)))
sub2user = cvt_uc_counts_to_sub2user_list(ucf)
print('Total number of subreddits :' + str(len(sub2user)))

# --- 3. Choose subreddits ---------------------------------------------------
# Rank subreddits by number of (filtered) commenters, skip the
# `ignore_n_top_subs` largest, and take the next `n_top_subs`.
sub_user_counts = sorted(((k, len(v)) for k, v in sub2user.items()), key=lambda e: e[1], reverse=True)
chosen_subs = [e[0] for e in sub_user_counts[ignore_n_top_subs:ignore_n_top_subs + n_top_subs]]
print('Number of subreddits chosen by commenter rank: ' + str(len(chosen_subs)))

# Required subreddits are always included. One with no users left after
# filtering gets an empty user set, i.e. an all-zero row/column. `.get` is
# used so the lookup does not insert keys into the defaultdict.
# NOTE(review): required names were lower-cased when loaded, while the keys
# of sub2user keep whatever casing the input files use - confirm they match.
required_sub2user = {s: sub2user.get(s, set()) for s in sorted(subs)}

# The full matrix covers the rank-chosen subreddits plus any required ones
# not already among them. Previously the chosen subreddits were computed but
# never used, which made the "full" matrix a duplicate of the required one.
full_sub2user = {s: sub2user[s] for s in chosen_subs}
for s, users in required_sub2user.items():
    full_sub2user.setdefault(s, users)

# --- 4. Write the cooccurrence matrices -------------------------------------
calculate_coocc_mat(required_sub2user).to_json(out_dir / 'comat_required.json')
calculate_coocc_mat(full_sub2user).to_json(out_dir / 'comat_full.json')