from sentence_transformers import SentenceTransformer, util
from pathlib import Path
import pandas as pd
from sqlalchemy import create_engine
from collections import defaultdict
from tqdm import tqdm
import random
import argparse
import json

def form_sqlite_str(path):
    """Build a SQLAlchemy SQLite connection URL for the database file at *path*.

    *path* is formatted verbatim after the ``sqlite:///`` prefix, so a relative
    path gives a URL relative to the working directory and an absolute POSIX
    path gives ``sqlite:////...``.
    """
    scheme_prefix = 'sqlite:///'
    return '{}{}'.format(scheme_prefix, path)

def subreddit_title_sample(df, sample_size):
    """Return up to *sample_size* distinct ``p_title`` values from *df*, in random order.

    The distinct titles are shuffled with the module-level :mod:`random`
    generator and the first *sample_size* are kept. The result is therefore
    reproducible only if ``random`` has been seeded beforehand.
    """
    pool = [*df['p_title'].unique()]
    random.shuffle(pool)
    return pool[:sample_size]

def _load_posts(db_path):
    """Read the ``test`` table of the SQLite database at *db_path* into a DataFrame."""
    engine = create_engine(form_sqlite_str(str(db_path)))
    try:
        return pd.read_sql_table('test', engine)
    finally:
        # The engine is not reused, so release its pooled connections now.
        engine.dispose()


def calc_all_subreddit_similarities(compare_dbs, compare_against_dbs, out_dir, sample_size, bs=64):
    """Compute title-embedding cosine similarities between subreddits and write them as JSON.

    Each entry of *compare_dbs* / *compare_against_dbs* is a dict with keys
    ``'sub'`` (subreddit name) and ``'db'`` (path to a SQLite file whose
    ``test`` table has a ``p_title`` column). Both sequences are iterated more
    than once, so they must be re-iterable (e.g. lists).

    For every subreddit ``s1`` in *compare_dbs*, the directory
    ``out_dir/<s1>/`` receives the following files:

    * ``self_sim.json``: similarity of a random title sample of ``s1``
      against itself.
    * ``<s1>_<s2>.json``: for every ``s2`` in *compare_against_dbs*, the
      similarity of the ``s1`` sample (rows) against all distinct ``s2``
      titles, sorted (columns).
    * ``<s2>_texts.json``: the column texts of ``<s1>_<s2>.json``. When
      ``s2 == s1`` they are written to ``<s1>_all_texts.json`` instead, so
      they do not collide with the sample file.
    * ``<s1>_texts.json``: the sampled ``s1`` titles, which label the rows of
      ``self_sim.json`` and of every ``<s1>_<s2>.json``.

    Args:
        compare_dbs: Subreddits whose title sample is compared (matrix rows).
        compare_against_dbs: Subreddits whose full title sets are compared
            against (matrix columns).
        out_dir: Output root; created (including parents) if missing.
        sample_size: Maximum number of distinct titles sampled per compared
            subreddit.
        bs: Encoding batch size forwarded to ``calc_cos_sim``.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    dfs = {}
    for e in tqdm(compare_dbs):
        sub = e['sub']
        sub_out_dir = out_dir / sub
        sub_out_dir.mkdir(exist_ok=True)
        posts = _load_posts(e['db'])
        sample = subreddit_title_sample(posts, sample_size)
        self_similarity = calc_cos_sim(sample, sample, bs=bs)
        pd.DataFrame(self_similarity).to_json(sub_out_dir / 'self_sim.json')
        dfs[sub] = {'all': posts, 'sample': sample, 'sub_out_dir': sub_out_dir}

    for e in tqdm(compare_against_dbs):
        sub = e['sub']
        if sub in dfs:
            continue
        # Only the full title set of comparison-only subreddits is used below,
        # so no sample is drawn for them.
        dfs[sub] = {'all': _load_posts(e['db'])}

    # The column texts depend only on s2, so compute them once instead of
    # once per (s1, s2) pair.
    against_titles = {
        e2['sub']: sorted(dfs[e2['sub']]['all'].p_title.unique())
        for e2 in compare_against_dbs
    }

    for e in tqdm(compare_dbs):
        s1 = e['sub']
        sub_out_dir = dfs[s1]['sub_out_dir']
        s1_titles = dfs[s1]['sample']
        for e2 in tqdm(compare_against_dbs, desc=s1, leave=False):
            s2 = e2['sub']
            s2_titles = against_titles[s2]
            # TODO(review): calc_cos_sim re-encodes s2_titles for every s1;
            # caching the embeddings per s2 would avoid repeated model work.
            s1_sim_s2 = calc_cos_sim(s1_titles, s2_titles, bs=bs)
            pd.DataFrame(s1_sim_s2).to_json(sub_out_dir / f'{s1}_{s2}.json')
            # The plain `{s2}_texts.json` name equals `{s1}_texts.json` when
            # s2 == s1, which used to overwrite the (unreproducible) sample.
            texts_name = f'{s2}_all_texts.json' if s2 == s1 else f'{s2}_texts.json'
            pd.Series(s2_titles).to_json(sub_out_dir / texts_name)
        # Written last so the sample always ends up in `{s1}_texts.json`.
        pd.Series(s1_titles).to_json(sub_out_dir / f'{s1}_texts.json')

def calc_cos_sim(texts1, texts2, bs=64, device='cuda'):
    """Return the pairwise cosine-similarity matrix between two lists of texts.

    Both lists are embedded with the module-level ``model`` (the
    ``SentenceTransformer`` created by the script entry point) and compared
    with ``sentence_transformers.util.cos_sim``. Entry ``[i, j]`` is the
    similarity between ``texts1[i]`` and ``texts2[j]``.

    Args:
        texts1: Texts for the rows of the result.
        texts2: Texts for the columns. If this is the very same object as
            *texts1*, the embeddings are computed only once.
        bs: Batch size forwarded to ``model.encode``.
        device: Device forwarded to ``model.encode``. Defaults to ``'cuda'``
            as before; pass e.g. ``'cpu'`` on machines without a GPU.

    Returns:
        The similarity matrix produced by ``util.cos_sim``.
    """
    enc1 = model.encode(texts1, device=device, batch_size=bs)
    # Self-similarity callers pass the same list twice; reuse its embeddings.
    if texts2 is texts1:
        enc2 = enc1
    else:
        enc2 = model.encode(texts2, device=device, batch_size=bs)
    return util.cos_sim(enc1, enc2)

if __name__ == "__main__":
    # Guarded so importing this module does not parse argv or load a model.
    # This block still runs at module scope, so `model` below stays the global
    # that calc_cos_sim reads.
    parser = argparse.ArgumentParser(
        description="Compute sentence-embedding cosine similarities between subreddit post titles.")
    parser.add_argument("--model_name", type=str, default="all-mpnet-base-v2")
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--prefix", type=str, default="")
    parser.add_argument("--suffix", type=str, default="")
    parser.add_argument("--subs", type=str, required=True,
                        help="Path to a JSON file listing subreddit names.")
    parser.add_argument("--out_dir", type=str, required=True)
    parser.add_argument("--sample_size", type=int, default=2500)
    parser.add_argument("--batch_size", type=int, default=950)
    parser.add_argument("--seed", type=int, default=None,
                        help="Seed for the random title sample (unseeded if omitted).")
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)

    with open(args.subs, encoding="utf-8") as f:
        subs = json.load(f)

    model = SentenceTransformer(args.model_name)
    data_dir = Path(args.data_dir)
    compare_dbs = [
        {'sub': sub, 'db': data_dir / f'{args.prefix}{sub}{args.suffix}.db'}
        for sub in subs
    ]

    # Every subreddit is compared against every subreddit, including itself.
    calc_all_subreddit_similarities(compare_dbs, compare_dbs, args.out_dir, args.sample_size, bs=args.batch_size)