import pandas as pd
from pathlib import Path
import numpy as np
from tqdm import tqdm
import json
from sqlalchemy import create_engine
import argparse

def get_self_sim_stats(sub_dir):
    """Summarise a subreddit's within-subreddit similarity matrix.

    Reads ``self_sim.json`` from *sub_dir* and zeroes the main diagonal, so
    an item's similarity with itself is excluded. Each column is then summed
    and divided by ``n - 1`` (the number of other rows). The result is the
    mean and sample standard deviation of those per-column averages.

    Args:
        sub_dir: ``pathlib.Path`` of the subreddit's directory.

    Returns:
        dict with keys ``'mean'`` and ``'std'``. ``std`` uses ``ddof=1``,
        the ``pandas.Series.std`` default.
    """
    self_sim = pd.read_json(sub_dir / 'self_sim.json')
    # Zero the diagonal on an explicit copy. Writing through
    # ``DataFrame.values`` only mutates the frame when pandas happens to hand
    # back a writable view; under copy-on-write the array is read-only.
    values = self_sim.to_numpy(dtype=float, copy=True)
    np.fill_diagonal(values, 0)
    n_rows = values.shape[0]
    # Wrap back into pandas so the column sum skips NaN and std keeps ddof=1,
    # matching the previous DataFrame/Series arithmetic.
    per_item = pd.DataFrame(values).sum(axis=0) / (n_rows - 1)
    return {'mean': per_item.mean(), 'std': per_item.std()}

def calculate_self_similarities(in_dir, subs):
    """Map each subreddit name in *subs* to its self-similarity statistics.

    Each subreddit's data is read from ``in_dir / <sub>``. Progress is
    reported with tqdm.

    Returns:
        dict ``{sub: get_self_sim_stats(in_dir / sub)}`` in *subs* order.
    """
    return {name: get_self_sim_stats(in_dir / name) for name in tqdm(subs)}

def read_sim_df(in_dir, sub1, sub2, sim_thresh):
    """Load the *sub1*-vs-*sub2* similarity matrix and reduce it per *sub2* text.

    Reads ``in_dir/sub1/{sub1}_{sub2}.json`` and averages it down each
    column (``mean(axis=0)``). The result is aligned by index with the titles
    in ``in_dir/sub1/{sub2}_texts.json``.

    Returns:
        DataFrame with columns ``p_title``, ``{sub1}_similarity`` (the
        column mean) and ``{sub1}_similar`` (mean >= *sim_thresh*).
    """
    pair_dir = in_dir / sub1
    col_means = pd.read_json(pair_dir / f'{sub1}_{sub2}.json').mean(axis=0)
    titles = pd.read_json(pair_dir / f'{sub2}_texts.json', typ='series')
    columns = {
        'p_title': titles,
        f'{sub1}_similarity': col_means,
        f'{sub1}_similar': col_means >= sim_thresh,
    }
    return pd.DataFrame(columns)

def form_sqlite_str(path):
    """Return the SQLAlchemy SQLite URL ``sqlite:///<path>`` for *path*."""
    return 'sqlite:///{}'.format(path)

def read_similarity_df(sub, all_subs, db_dir, in_dir, self_sims, prefix=''):
    """Load *sub*'s ``test`` table and join similarity columns from every other subreddit.

    Args:
        sub: subreddit whose SQLite database (``db_dir/{prefix}{sub}.db``) is read.
        all_subs: every subreddit name; each one other than *sub* contributes
            ``{other}_similarity`` / ``{other}_similar`` columns.
        db_dir: ``pathlib.Path`` of the directory holding the ``.db`` files.
        in_dir: ``pathlib.Path`` root of the similarity JSON files.
        self_sims: ``{sub: {'mean', 'std'}}`` used to derive similarity thresholds.
        prefix: optional filename prefix for the database file.

    Returns:
        The ``test`` table inner-joined (on ``p_title``) with each other
        subreddit's similarity columns.
    """
    db_path = db_dir / f'{prefix}{sub}.db'
    engine = create_engine(form_sqlite_str(str(db_path)))
    try:
        test = pd.read_sql_table('test', engine)
    finally:
        # Release the pooled SQLite connection rather than leaking one per call.
        engine.dispose()
    # Iterate the ``all_subs`` argument. The previous loop read the module
    # global ``subs`` and silently ignored this parameter.
    for other in tqdm(all_subs):
        if other == sub:
            continue
        test = add_similarity_columns(test, in_dir, other, sub, self_sims)
    return test

def add_similarity_columns(test, in_dir, sub1, sub2, self_sims):
    """Inner-join the *sub1*-vs-*sub2* similarity columns onto *test* by ``p_title``.

    A text counts as "similar" to *sub1* when its mean similarity is at least
    sub1's self-similarity mean minus one standard deviation.
    """
    stats = self_sims[sub1]
    threshold = stats['mean'] - stats['std']
    sim = read_sim_df(in_dir, sub1, sub2, sim_thresh=threshold)
    return test.merge(sim, on='p_title')

def calc_emb_psr(df, sub, all_subs=None):
    """Return, per subreddit, the fraction of *df*'s rows flagged as similar to it.

    Args:
        df: frame holding boolean ``'{name}_similar'`` columns.
        sub: the subreddit *df* belongs to; its own entry is fixed at 1.0.
        all_subs: subreddit names to report on. When ``None`` they are
            inferred from every column of *df* ending in ``_similar``. The old
            version read the module-level ``subs`` global instead, which only
            exists when the file runs as a script.

    Returns:
        pd.Series indexed by ``'{name}_similar'``, with ``'{sub}_similar'``
        first and the column means of the other subreddits after it.
    """
    suffix = '_similar'
    if all_subs is None:
        # Note: ``_similarity`` columns do not end with ``_similar``, so they are skipped.
        all_subs = [col[:-len(suffix)] for col in df.columns
                    if isinstance(col, str) and col.endswith(suffix)]
    other_cols = [f'{name}{suffix}' for name in all_subs if name != sub]
    return pd.Series({f'{sub}{suffix}': 1.0, **df[other_cols].mean().to_dict()})

def calc_emb_psr_from_json(json_path, sub):
    """Read a saved per-subreddit frame from *json_path* and pass it to ``calc_emb_psr``."""
    return calc_emb_psr(pd.read_json(json_path), sub)

# Interactive (e.g. notebook) stand-in for the argparse `args` below.
# Attribute names must match the parser's option names (e.g. `subreddits_json`).
# from types import SimpleNamespace
# args = SimpleNamespace(in_dir='E:/sub_similarities/', db_dir='out/mix_and_opp/',
#                        out_dir='out/sub_sim_jsons/', preds_json='data/mix_test_25_3000_with_opp.json',
#                        subreddits_json='subreddits.json', out_preds_json='data/test_emb_psr_sim_out.json',
#                        emb_psr_out_json='emb_psr_out.json')

# Command-line interface: six required string options plus an optional
# path to the JSON list of subreddit names.
parser = argparse.ArgumentParser()
for required_flag in ("--in_dir", "--db_dir", "--out_dir",
                      "--preds_json", "--out_preds_json", "--emb_psr_out_json"):
    parser.add_argument(required_flag, type=str, required=True)
del required_flag  # keep the loop variable out of the module namespace
parser.add_argument("--subreddits_json", type=str, default="subreddits.json")
args = parser.parse_args()

in_dir = Path(args.in_dir)
db_dir = Path(args.db_dir)
out_dir = Path(args.out_dir)
# parents=True so a nested --out_dir (e.g. out/sub_sim_jsons/) is created in full.
out_dir.mkdir(parents=True, exist_ok=True)

# The parser defines ``--subreddits_json``, so argparse stores it as
# ``args.subreddits_json``. Reading ``args.subreddit_json`` raised AttributeError.
with open(args.subreddits_json, encoding='utf-8') as f:
    subs = json.load(f)

# Per-subreddit {'mean': ..., 'std': ...} self-similarity stats, also
# written to in_dir/self_sims.json.
self_sims = calculate_self_similarities(in_dir, subs)
with open(in_dir / 'self_sims.json', 'w', encoding='utf-8') as f:
    json.dump(self_sims, f)

# Build the embedding-PSR matrix (rows and columns are subreddits). Each
# subreddit's enriched test frame is saved to out_dir/<sub>.json on the way.
emb_psr = {}
for sub in subs:
    print(sub)
    test = read_similarity_df(sub, subs, db_dir, in_dir, self_sims)
    emb_psr[sub] = calc_emb_psr(test, sub)
    test.to_json(out_dir / f'{sub}.json')
emb_psr = (
    pd.DataFrame(emb_psr)
    .T
    .rename(columns=lambda label: label.replace('_similar', ''))
    .sort_index()
    .sort_index(axis=1)
)
emb_psr.to_json(args.emb_psr_out_json)

# Stack every subreddit's saved frame and keep one row per (title, subreddit).
cols = ['p_title', 'p_subreddit'] + [f'{s}_similarity' for s in subs] + [f'{s}_similar' for s in subs]
# Concatenate once. Growing a frame with pd.concat inside the loop re-copies
# all previously accumulated rows on every iteration (quadratic cost).
frames = [pd.read_json(out_dir / f'{sub}.json') for sub in tqdm(subs)]
all_test = pd.concat(frames, axis=0)
all_test = all_test[cols].drop_duplicates(subset=['p_title', 'p_subreddit'], ignore_index=True)
# Left join keeps every prediction row, including those with no similarity data.
preds = pd.read_json(args.preds_json)
preds_enhanced = preds.merge(all_test, how='left', on=['p_title', 'p_subreddit'])
preds_enhanced.to_json(args.out_preds_json)