import pandas as pd
import numpy as np
from pathlib import Path
from tqdm import tqdm
import argparse
from sklearn.metrics import classification_report
from collections import defaultdict

def _calc_prob_score(preds, sub):
    """Return a confidence-weighted correctness score for model ``sub``.

    Let ``p = preds[f'prob_{sub}']``. The score is
    ``max(sum(p on correct rows) - sum(p on wrong rows), 0) / sum(p)``,
    where a row is correct when ``prediction_{sub}`` equals ``c_polarity``.

    A zero total probability (e.g. empty ``preds``) is not guarded and
    results in a division by zero.
    """
    hit = preds[f'prediction_{sub}'] == preds['c_polarity']
    confidence = preds[f'prob_{sub}']
    margin = confidence[hit].sum() - confidence[~hit].sum()
    # Negative margins (more confidence on wrong answers) are clipped to zero.
    return max(margin, 0) / confidence.sum()

def calc_prob_score(preds, sub):
    """Return the product of ``_calc_prob_score`` on the two gold classes.

    The score is computed separately on rows with ``c_polarity == 0`` and
    ``c_polarity == 1``, and the two results are multiplied.
    """
    negative_score, positive_score = (
        _calc_prob_score(preds[preds['c_polarity'] == label], sub) for label in (0, 1)
    )
    return negative_score * positive_score

def calc_prob_score_raw(preds, sub):
    """Return ``(score on c_polarity == 0 rows, score on c_polarity == 1 rows)``.

    Both entries are ``_calc_prob_score`` values for model ``sub``.
    """
    is_negative = preds['c_polarity'] == 0
    is_positive = preds['c_polarity'] == 1
    negative_score = _calc_prob_score(preds[is_negative], sub)
    positive_score = _calc_prob_score(preds[is_positive], sub)
    return negative_score, positive_score

def calc_nll_loss(preds, sub):
    """Return the mean negative log-likelihood of the gold label under model ``sub``.

    ``prob_{sub}`` is treated as the probability of the *predicted* label, so
    for wrong predictions the gold label's probability is taken as ``1 - prob``
    (a binary-label assumption). ``1e-5`` is added inside the log to avoid
    ``log(0)``. The input frame is not modified.
    """
    p_predicted = preds[f'prob_{sub}']
    correct = preds[f'prediction_{sub}'] == preds['c_polarity']
    p_gold = p_predicted.where(correct, 1 - p_predicted)
    total_nll = -np.log(p_gold + 1e-5).sum()
    return total_nll / p_gold.shape[0]

# Short model/subreddit names mapped to the full subreddit names used in column names.
sub_translate = {'Donut': 'Bad_Cop_No_Donut', 'Donald': 'The_Donald'}

def substitute_sub(sub):
    """Return the full subreddit name for ``sub``, or ``sub`` unchanged if it has no alias."""
    return sub_translate.get(sub, sub)

def calc_dist_metrics(preds, sub, suffix, t=0.0, similarity_thresh=0.05):
    """Compute per-subreddit metrics for the model trained on ``sub``.

    Reads the columns ``c_polarity`` (gold label), ``p_subreddit`` (grouping
    key), ``prediction_{sub}`` and ``prob_{sub}``. When ``similarity_thresh``
    is not None it also reads ``{substitute_sub(sub)}_similar``.

    Args:
        preds: DataFrame of predictions. It is not modified.
        sub: Model / training-subreddit name whose prediction columns are scored.
        suffix: Appended to every output column name (e.g. ``f1_{suffix}``).
        t: If > 0, predictions whose ``prob_{sub}`` is below ``t`` are replaced
            by the opposite of the gold label, i.e. they are scored as errors.
        similarity_thresh: If None, every subreddit is scored. Otherwise rows
            whose similarity flag is falsy are dropped (missing flags count as
            similar). Subreddits whose kept fraction is below this value get
            placeholder metrics (zeros, ``nll_loss`` = 100.0).

    Returns:
        DataFrame indexed by subreddit. Its columns are ``Accuracy_``, ``f1_``
        (macro), ``data_points_``, ``nll_loss_``, ``neg_f1_`` (F1 of label 0)
        and ``sim_ratio_``, each followed by ``suffix``. Subreddits with no rows
        left after similarity filtering are absent from the result.
    """
    pred_col = f'prediction_{sub}'
    if t > 0.0:
        # Copy before overwriting so the caller's frame is untouched. Otherwise
        # repeated calls (e.g. several thresholds in a row) compound their flips.
        preds = preds.copy()
        confident = preds[f'prob_{sub}'] >= t
        opposite_preds = preds['c_polarity'].apply(lambda e: 0 if e == 1 else 1)
        # Pass the full, index-aligned Series. A partial Series would introduce
        # NaNs via alignment and upcast the prediction column to float.
        preds[pred_col] = preds[pred_col].where(confident, opposite_preds)

    sub_ratio = None  # only computed when similarity filtering runs
    if similarity_thresh is None:
        sel_subs = set(preds['p_subreddit'].unique())
    else:
        sub_sim = preds[f'{substitute_sub(sub)}_similar'].fillna(1.0)
        n_per_sub = preds['p_subreddit'].value_counts()
        preds = preds[sub_sim.apply(bool)]
        # Fraction of each subreddit's rows that survived the similarity filter.
        sub_ratio = (preds['p_subreddit'].value_counts() / n_per_sub).fillna(0.0)
        sel_subs = set(sub_ratio[sub_ratio >= similarity_thresh].index)

    res = {}
    for s, subdf in preds.groupby('p_subreddit'):
        sim_ratio = None if sub_ratio is None else sub_ratio[s]
        if s in sel_subs:
            clf_res = classification_report(subdf['c_polarity'], subdf[pred_col], output_dict=True, zero_division=0)
            # Labels absent from this subreddit (e.g. no '0') fall back to 0.0 metrics.
            clf_res = defaultdict(lambda: defaultdict(float), clf_res)
            res[s] = {f'Accuracy_{suffix}': clf_res['accuracy'],
                      f'f1_{suffix}': clf_res['macro avg']['f1-score'],
                      f'data_points_{suffix}': clf_res['macro avg']['support'],
                      f'nll_loss_{suffix}': calc_nll_loss(subdf, sub),
                      f'neg_f1_{suffix}': clf_res['0']['f1-score'],
                      f'sim_ratio_{suffix}': sim_ratio}
        else:
            res[s] = {f'Accuracy_{suffix}': 0.0, f'f1_{suffix}': 0.0, f'data_points_{suffix}': 0.0,
                      f'nll_loss_{suffix}': 100.0, f'neg_f1_{suffix}': 0.0, f'sim_ratio_{suffix}': sim_ratio}
    return pd.DataFrame(res).transpose()

def calc_dist_metrics_multi_thresh(preds, sub, thresh=(0.85,), similarity_thresh=0.05):
    """Return ``calc_dist_metrics`` results for several confidence thresholds side by side.

    The first set of columns is unthresholded (suffix ``'all'``). Then, for
    each ``t`` in ``thresh``, a set with suffix ``f'{t}'`` is added (e.g.
    ``'0.85'``). ``thresh`` defaults to an immutable tuple; any iterable of
    floats is accepted.
    """
    frames = [calc_dist_metrics(preds, sub, suffix='all', similarity_thresh=similarity_thresh)]
    frames.extend(
        calc_dist_metrics(preds, sub, suffix=f'{t}', t=t, similarity_thresh=similarity_thresh)
        for t in thresh
    )
    # A single concat instead of growing the frame inside the loop.
    return pd.concat(frames, axis=1)

def calc_dist_metrics_all_subs(preds, subs, select_suffix='0.85', metric='f1', similarity_thresh=0.05):
    """Build a train-model x evaluated-subreddit matrix of one metric.

    Each row corresponds to a model in ``subs``, renamed with ``substitute_sub``.
    Each column is a ``p_subreddit`` value. The values come from column
    ``f'{metric}_{select_suffix}'`` of ``calc_dist_metrics_multi_thresh``
    (default thresholds), so ``select_suffix`` must be ``'all'`` or ``'0.85'``.
    The metric is computed on a copy of ``preds``.
    """
    working = preds.copy()
    column = f'{metric}_{select_suffix}'
    # Entry [sub][s]: the metric obtained when the model trained on "sub"
    # predicts on subreddit s.
    per_model = {
        substitute_sub(sub): calc_dist_metrics_multi_thresh(working, sub, similarity_thresh=similarity_thresh)[column]
        for sub in subs
    }
    return pd.DataFrame(per_model).transpose()


# Example invocation (paths are illustrative):
#   python <script>.py --in_dir llama_preds/old_and_opp_for_mix/ \
#       --preds_json data/mix_test_25_3000_with_opp_sim.json --bots_out_json bots_out.json
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--in_dir", type=str, required=True,
                        help="directory of per-model prediction CSVs; the file stem before the first '.' names the model")
    parser.add_argument("--preds_json", type=str, required=True,
                        help="JSON (pandas.read_json) with c_polarity / p_subreddit columns, row-aligned with the CSVs")
    parser.add_argument("--bots_out_json", type=str, required=True,
                        help="output path; the table is written in CSV format")
    args = parser.parse_args()

    # Sorted so the output row order does not depend on filesystem enumeration order.
    pred_files = sorted(Path(args.in_dir).glob('*.csv'))
    subs = [p.stem.split(".")[0] for p in pred_files]
    preds = pd.read_json(args.preds_json)

    per_model = []
    for p, sub in tqdm(zip(pred_files, subs), total=len(pred_files)):
        pred = pd.read_csv(p)
        pred['is_correct'] = pred['prediction'] == pred['y']
        pred = pred.drop(columns=['id', 'y'])
        # Suffix every remaining column (prediction, prob, is_correct, ...) with the model name.
        pred.columns = [f'{col}_{sub}' for col in pred.columns]
        per_model.append(pred)
    # The column-wise concat aligns on the index only, and the dropped 'id' column is
    # not checked. Every CSV is therefore assumed to list rows in the same order as
    # preds_json (TODO confirm).
    preds = pd.concat([preds, *per_model], axis=1)

    pred_cols = [f'is_correct_{s}' for s in subs]
    preds['total_correct'] = preds[pred_cols].sum(axis=1)
    bots = calc_dist_metrics_all_subs(preds, subs, select_suffix='all', similarity_thresh=None).sort_index(axis=1)
    # NOTE: despite the argument name, this writes CSV, not JSON.
    bots.to_csv(args.bots_out_json)