from pathlib import Path
import pandas as pd
import argparse
import json

# Number of nearest subreddits inspected for each subreddit.
NEIGHBOUR_COUNT = 4


def positive_share_by_subreddit(df):
    """Return ``{subreddit: share of rows labelled 'POS'}``.

    Rows are grouped by the ``p_subreddit`` column and the share is taken
    over the rows of that group that have a non-null ``sent_label``.

    Every subreddit with at least one labelled row gets an entry, so a
    subreddit without any ``'POS'`` row maps to ``0.0`` instead of being
    silently left out of the result.
    """
    # Ignore unlabelled rows so they do not dilute the share.
    labelled = df.dropna(subset=["sent_label"])
    is_positive = labelled["sent_label"] == "POS"
    # The mean of a boolean column is the fraction of True values.
    return is_positive.groupby(labelled["p_subreddit"]).mean().to_dict()


def neighbour_group_overlap(pct_pos_per_sub, group1, group2, k=NEIGHBOUR_COUNT):
    """Score how many of each subreddit's nearest neighbours share its group.

    For every subreddit in ``pct_pos_per_sub`` the ``k`` *other* subreddits
    with the closest positive share (smallest absolute difference) are
    selected; ties keep the iteration order of ``pct_pos_per_sub``.  The
    score is the number of those neighbours found in the subreddit's own
    group, divided by ``k``.

    A subreddit belongs to ``group1`` if it is contained in it, otherwise it
    is treated as a member of ``group2``.

    Args:
        pct_pos_per_sub: mapping of subreddit name to positive share.
        group1: container of subreddit names.
        group2: container of subreddit names.
        k: number of neighbours to inspect; must be at least 1.

    Returns:
        dict mapping each subreddit to a float in ``[0, 1]``.  The divisor is
        always ``k``, even when fewer than ``k`` other subreddits exist.

    Raises:
        ValueError: if ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError("k must be at least 1")

    similarity = {}
    for sub, share in pct_pos_per_sub.items():
        # Exclude the subreddit itself explicitly: relying on it sorting
        # first breaks when another subreddit has the identical share.
        others = [other for other in pct_pos_per_sub if other != sub]
        others.sort(key=lambda other: abs(share - pct_pos_per_sub[other]))
        nearest = set(others[:k])
        own_group = group1 if sub in group1 else group2
        similarity[sub] = len(nearest.intersection(own_group)) / k
    return similarity


def main():
    """Command-line entry point: read the inputs, print and store the scores."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_json", type=str, required=True)
    parser.add_argument("--sim_out_path", type=str, required=True)
    parser.add_argument("--group1_json", type=str, required=True)
    parser.add_argument("--group2_json", type=str, required=True)
    args = parser.parse_args()

    input_json = Path(args.input_json)
    sim_out_path = Path(args.sim_out_path)

    df = pd.read_json(input_json)
    pct_pos_per_sub = positive_share_by_subreddit(df)

    # NOTE(review): the group files are presumably JSON lists of subreddit
    # names (or objects keyed by them) -- confirm against the producer.
    with open(args.group1_json) as f:
        group1 = json.load(f)
    with open(args.group2_json) as f:
        group2 = json.load(f)

    similarity = neighbour_group_overlap(pct_pos_per_sub, group1, group2)
    print(similarity)
    with open(sim_out_path, 'w') as f:
        json.dump(similarity, f)


if __name__ == "__main__":
    main()