from sentence_transformers import SentenceTransformer, util
from pathlib import Path
import pandas as pd
from sqlalchemy import create_engine
from collections import defaultdict
from tqdm import tqdm
import random
import argparse
import json
import numpy as np

def form_sqlite_str(path):
    """Build the SQLAlchemy connection URL for the SQLite file located at *path*."""
    url_template = 'sqlite:///{}'
    return url_template.format(path)

def subreddit_title_sample(df, sample_size, title_col='p_title'):
    """Return up to *sample_size* distinct values of column *title_col*, in random order.

    Duplicates are dropped first (``Series.unique`` keeps first-appearance
    order). The list is then shuffled with the module-level ``random``
    generator, so results are reproducible under ``random.seed``.
    """
    distinct_titles = [*df[title_col].unique()]
    random.shuffle(distinct_titles)
    return distinct_titles[:sample_size]

def _load_and_encode(db_path, encoder, sample_size, sample_size_compare_sub, title_col, device, bs):
    """Load one subreddit database and embed two random samples of its titles.

    Reads table ``test`` from the SQLite file at *db_path*. It draws a small
    sample (*sample_size*) and a large sample (*sample_size_compare_sub*) of
    distinct values from *title_col*, and encodes both with *encoder*.

    Returns a dict with keys ``'all'`` (the full DataFrame), ``'sample'``,
    ``'sample_large'``, ``'sample_encoded'`` and ``'sample_large_encoded'``.
    """
    engine = create_engine(form_sqlite_str(str(db_path)))
    try:
        posts = pd.read_sql_table('test', engine)
    finally:
        # create_engine() builds a connection pool; release it once the table is in memory.
        engine.dispose()
    sample = subreddit_title_sample(posts, sample_size, title_col=title_col)
    sample_large = subreddit_title_sample(posts, sample_size_compare_sub, title_col=title_col)
    return {
        'all': posts,
        'sample': sample,
        'sample_large': sample_large,
        'sample_encoded': encoder.encode(sample, device=device, batch_size=bs),
        'sample_large_encoded': encoder.encode(sample_large, device=device, batch_size=bs),
    }


def _mean_similarity(ref_encoded, cand_encoded):
    """Return, per candidate embedding, its mean cosine similarity to all reference embeddings.

    ``calc_cos_sim`` gives a (len(ref), len(cand)) matrix, so averaging over
    axis 0 yields one score per candidate. The matrix is converted to a CPU
    NumPy array first so the scores drop straight into a DataFrame column.
    """
    sim = calc_cos_sim(ref_encoded, cand_encoded)
    if hasattr(sim, 'detach'):  # torch tensor -> CPU NumPy array
        sim = sim.detach().cpu().numpy()
    return np.asarray(sim).mean(axis=0)


def _write_scored_titles(all_df, titles, scores, threshold, title_col, path):
    """Attach per-title scores to *all_df*, write the result as JSON to *path*, and return it.

    ``similar`` marks titles whose score is at least *threshold*. The inner
    merge keeps only rows of *all_df* whose title was scored.
    """
    title_sims = pd.DataFrame({title_col: titles, 'similarity': scores,
                               'similar': scores >= threshold})
    merged = all_df.merge(title_sims, on=title_col, how='inner')
    merged.to_json(path)
    return merged


def calc_all_subreddit_similarities(compare_dbs, compare_against_dbs, sample_size, sample_size_compare_sub, out_dir, title_col='p_title', bs=64, encoder=None, device='cuda'):
    """Compute embedding-based (Emb-PSR) similarity between subreddits.

    Each subreddit ``s1`` in *compare_dbs* contributes a small embedded title
    sample that serves as the reference. Every title in the large sample of
    ``s1`` itself, and of each other subreddit ``s2`` in *compare_against_dbs*,
    is scored by its mean cosine similarity to that reference. A title is
    marked ``similar`` when its score is at least ``mean - std`` of ``s1``'s
    own (self-similarity) scores.

    Args:
        compare_dbs: list of ``{'sub': name, 'db': path}`` dicts used as references.
        compare_against_dbs: list of dicts of the same shape, scored against each reference.
        sample_size: number of distinct titles in each reference sample.
        sample_size_compare_sub: number of distinct titles in each scored (large) sample.
        out_dir: output directory (``str`` or ``Path``).
        title_col: column holding the post title.
        bs: encoding batch size.
        encoder: object with a SentenceTransformer-style ``encode`` method.
            Defaults to the module-level ``model``.
        device: device passed to ``encoder.encode``.

    Files written per reference ``s1``:
        - ``{s1}_{s1}.json`` and ``{s1}_{s2}.json`` hold the scored rows.
        - ``{s1}_titles.json`` holds the reference sample. Its presence makes
          later runs skip ``s1``.
        - ``self_sims.json`` holds the mean/std of each subreddit's
          self-similarity. It is merged with any existing copy and rewritten
          after every reference, so resumed or interrupted runs keep earlier
          entries.

    Returns:
        ``sims[s2][s1]``, the fraction of merged ``s2`` rows marked similar to
        ``s1``. Only references computed in this call are included.
    """
    if encoder is None:
        encoder = model  # module-level SentenceTransformer created by the script section
    out_dir = Path(out_dir)

    dfs = {}
    print('Calculating embeddings')
    for e in tqdm(compare_dbs):
        print(e)
        dfs[e['sub']] = _load_and_encode(e['db'], encoder, sample_size, sample_size_compare_sub,
                                         title_col, device, bs)

    for e in tqdm(compare_against_dbs):
        if e['sub'] in dfs:
            continue
        dfs[e['sub']] = _load_and_encode(e['db'], encoder, sample_size, sample_size_compare_sub,
                                         title_col, device, bs)

    self_sims_path = out_dir / 'self_sims.json'
    self_sims = {}
    if self_sims_path.exists():
        # Subs skipped below (already done) would otherwise vanish when the file is rewritten.
        with open(self_sims_path) as f:
            self_sims = json.load(f)

    print('Calculating Emb-PSR similarities')
    sims = defaultdict(dict)
    for e in tqdm(compare_dbs):
        s1 = e['sub']
        # If this file exists, the similarities for this subreddit were already calculated.
        titles_marker = out_dir / f'{s1}_titles.json'
        if titles_marker.exists():
            print(f'{s1} already exists')
            continue
        print('Comparing with: ', s1)
        s1_data = dfs[s1]
        s1_encoded = s1_data['sample_encoded']

        # Self similarity: score s1's large sample against its own reference sample.
        self_scores = _mean_similarity(s1_encoded, s1_data['sample_large_encoded'])
        # ddof=1 matches the unbiased std torch.Tensor.std() computed previously.
        mean, std = float(self_scores.mean()), float(self_scores.std(ddof=1))
        self_sims[s1] = {'mean': mean, 'std': std}
        threshold = mean - std
        _write_scored_titles(s1_data['all'], s1_data['sample_large'], self_scores, threshold,
                             title_col, out_dir / f'{s1}_{s1}.json')
        del self_scores

        for e2 in tqdm(compare_against_dbs):
            s2 = e2['sub']
            print(s2)
            if s2 == s1:
                continue
            s2_data = dfs[s2]
            s2_scores = _mean_similarity(s1_encoded, s2_data['sample_large_encoded'])
            merged = _write_scored_titles(s2_data['all'], s2_data['sample_large'], s2_scores,
                                          threshold, title_col, out_dir / f'{s1}_{s2}.json')
            sims[s2][s1] = float(merged['similar'].mean())

        pd.Series(s1_data['sample']).to_json(titles_marker)
        # Persist progress now so an interrupted run still keeps self_sims for finished subs.
        with open(self_sims_path, 'w') as f:
            json.dump(self_sims, f)

    with open(self_sims_path, 'w') as f:
        json.dump(self_sims, f)
    return sims

def calc_cos_sim(enc1, enc2):
    """Return the pairwise cosine-similarity matrix between two sets of embeddings.

    Thin wrapper over ``sentence_transformers.util.cos_sim``. Per that
    library's documentation, entry ``[i][j]`` compares ``enc1[i]`` with
    ``enc2[j]``.
    """
    return util.cos_sim(enc1, enc2)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Compute embedding-based similarities between subreddit post titles.')
    parser.add_argument("--model_name", type=str, default="all-MiniLM-L6-v2",
                        help="SentenceTransformer model name (e.g. all-mpnet-base-v2)")
    parser.add_argument("--data_dir", type=str, required=True,
                        help="directory containing the per-subreddit .db files")
    parser.add_argument("--prefix", type=str, default="",
                        help="prepended to each subreddit name to form the .db filename")
    parser.add_argument("--suffix", type=str, default="",
                        help="appended to each subreddit name (before .db) to form the filename")
    parser.add_argument("--subs", type=str, required=True,
                        help="JSON file listing subreddit names")
    parser.add_argument("--out_dir", type=str, required=True)
    parser.add_argument("--sample_size", type=int, default=2500)
    parser.add_argument("--batch_size", type=int, default=950)
    parser.add_argument("--title_col", type=str, default='p_title')
    parser.add_argument("--sample_size_compare_sub", type=int, default=25000)
    args = parser.parse_args()

    with open(args.subs) as f:
        subs = json.load(f)

    # Module-level name: calc_all_subreddit_similarities falls back to this global `model`.
    model = SentenceTransformer(args.model_name)

    data_dir = Path(args.data_dir)
    compare_dbs = []
    for sub in subs:
        db_path = data_dir / f'{args.prefix}{sub}{args.suffix}.db'
        if db_path.exists():
            compare_dbs.append({'sub': sub, 'db': db_path})
        else:
            print(f'Skipping {sub}: {db_path} not found')

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    sims = calc_all_subreddit_similarities(compare_dbs, compare_dbs, args.sample_size, args.sample_size_compare_sub, out_dir, bs=args.batch_size, title_col=args.title_col)

    sims_path = out_dir / 'sims.json'
    if sims_path.exists():
        # Subs skipped on a resumed run are absent from `sims`; keep their earlier values.
        with open(sims_path) as f:
            previous = json.load(f)
        for s2, row in sims.items():
            previous.setdefault(s2, {}).update(row)
        sims = previous

    with open(sims_path, 'w') as f:
        json.dump(sims, f)

# Example usage: python similarity/emb_psr_single_step_calc.py --data_dir=out/mix_and_opp/ --subs=subreddits.json --out_dir=out/sub_sim_jsons/ --sample_size=2500 --batch_size=950