import json
import pandas as pd
import argparse
from collections import defaultdict

parser = argparse.ArgumentParser(
    description="Compute Hits@n of a similarity ranking against category labels."
)
parser.add_argument(
    "--cat2sub_json",
    type=str,
    required=True,
    help="JSON file mapping each category to a list of its members.",
)
parser.add_argument(
    "--similarity_json",
    type=str,
    required=True,
    help="JSON file mapping each member to a {neighbour: similarity score} object.",
)
parser.add_argument(
    "--top_n",
    type=int,
    default=None,
    help="Number of nearest neighbours to inspect (default: category size minus one).",
)
args = parser.parse_args()
# A negative value would turn the later `closest[:top_n]` slice into "all but the last k".
if args.top_n is not None and args.top_n < 0:
    parser.error("--top_n must be non-negative")

default_top_n = args.top_n

# JSON text is UTF-8 by spec; don't depend on the platform's locale encoding.
with open(args.cat2sub_json, encoding="utf-8") as f:
    cat2sub = json.load(f)
# Invert to member -> category. A member listed under several categories keeps the last one.
sub2cat = {sub: cat for cat, subs in cat2sub.items() for sub in subs}

with open(args.similarity_json, encoding="utf-8") as f:
    # Rank each member's neighbours from most to least similar.
    sims = {
        sub: sorted(scores.items(), key=lambda item: item[1], reverse=True)
        for sub, scores in json.load(f).items()
    }

def calc_hits_at_n(sims, sub2cat, default_top_n, cat2sub=None):
    """Compute the mean Hits@n of a similarity ranking against category labels.

    For every query item in ``sims``, the query itself is removed from its
    ranked neighbour list and the list is truncated to the first ``n``
    entries. The query's score is the fraction of those neighbours that
    share its category. Per-query scores are averaged overall and per
    category; the averages are printed and the overall mean is returned.

    Queries that cannot be scored are skipped rather than crashing:
    queries with no category in ``sub2cat``, and queries whose truncated
    neighbour list is empty (e.g. a singleton category when
    ``default_top_n`` is ``None``, or ``default_top_n == 0``).
    Neighbours with no category are counted as misses.

    Args:
        sims: Mapping from query item to a sequence of ``(neighbour, score)``
            pairs, already sorted from most to least similar.
        sub2cat: Mapping from item to its category label.
        default_top_n: Number of neighbours to inspect. If ``None``, each
            query uses the size of its own category minus one, i.e. the
            number of other items that could possibly be a hit.
        cat2sub: Optional mapping from category to its list of items, used to
            size categories when ``default_top_n`` is ``None``. When omitted,
            category sizes are counted from ``sub2cat``.

    Returns:
        The mean Hits@n over all scored queries (NaN if none were scored).

    Raises:
        ValueError: If ``default_top_n`` is negative.
    """
    if default_top_n is not None and default_top_n < 0:
        raise ValueError(f"top_n must be non-negative, got {default_top_n}")

    # Category sizes are only needed for the per-query default n.
    if cat2sub is not None:
        cat_sizes = {cat: len(subs) for cat, subs in cat2sub.items()}
    else:
        cat_sizes = defaultdict(int)
        for cat in sub2cat.values():
            cat_sizes[cat] += 1

    res = []
    res_per_cat = defaultdict(list)
    for sub, closest in sims.items():
        if sub not in sub2cat:
            # No ground-truth label for this query, so it cannot be scored.
            continue
        original_cat = sub2cat[sub]

        if default_top_n is None:
            # Other members of the same category: the best achievable number of hits.
            top_n = max(cat_sizes.get(original_cat, 0) - 1, 0)
        else:
            top_n = default_top_n

        # Drop the query itself; it would otherwise trivially match its own category.
        neighbours = [entry[0] for entry in closest if entry[0] != sub][:top_n]
        if not neighbours:
            # The hit ratio is undefined (0 / 0); skip instead of raising ZeroDivisionError.
            continue

        # A neighbour without a category label can never share the query's category.
        n_hits = sum(
            name in sub2cat and sub2cat[name] == original_cat for name in neighbours
        )
        hit_ratio = n_hits / len(neighbours)
        res.append(hit_ratio)
        res_per_cat[original_cat].append(hit_ratio)

    # Explicit float dtype so an empty result yields NaN rather than an object-dtype Series.
    mean_hits = pd.Series(res, dtype=float).mean()

    print(f"Mean Hits@n: {mean_hits}")

    # print results per category
    for cat, hits in res_per_cat.items():
        print(f"Category: {cat}, Mean Hits@n: {pd.Series(hits).mean()}")
    return mean_hits

# Run the evaluation on the data loaded above; the function prints its results.
calc_hits_at_n(sims=sims, sub2cat=sub2cat, default_top_n=default_top_n)

# res = []
# res_per_cat = defaultdict(list)
# for sub, closest in sims.items():
#     original_cat = sub2cat[sub]
#     # remove "sub" from the closest list
#     closest = [e for e in closest if e[0] != sub]
#     closest_cats = [sub2cat[e[0]] == original_cat for e in closest[:top_n]]
#     res.append(sum(closest_cats) / len(closest_cats))
#     res_per_cat[original_cat].append(sum(closest_cats) / len(closest_cats))

# mean_hits = pd.Series(res).mean()

# print(f"Mean hits at {top_n}: {mean_hits}")

# # print results per category
# for cat, hits in res_per_cat.items():
#     print(f"Category: {cat}, Mean hits at {top_n}: {pd.Series(hits).mean()}")

# Example usage: python similarity/calc_hits_at_n.py --cat2sub_json '/mnt/f/Github/C3-contextual-community-comparison/cat2sub.json' --similarity_json '/mnt/f/Github/influence_risk_analysis/sims.json' --top_n 6
# Omit --top_n to inspect, for each item, as many neighbours as there are other items in its category.