import pandas as pd
from scipy.stats import ks_2samp
import value_resonance_scorer as vrs

def load_llm_scores():
    """Load LLM experiment predictions as a wide table, one column per prompt.

    Reads ``llm_results_master_csv.csv``, pivots ``cleaned_prediction`` so that
    every ``prompt_construction`` value becomes its own column (rows keyed by
    ``idx``/``premise``/``hypothesis``/``label``), flattens the two-level
    column header produced by the pivot, and shifts ``idx`` down by
    ``447 + 5388``.

    Returns:
        The reshaped ``pandas.DataFrame``.
    """
    csv_path = '../../data/LLM_Experiment/llm_results_master_csv.csv'
    key_columns = ['idx', 'premise', 'hypothesis', 'label']
    n_keys = len(key_columns)

    long_df = pd.read_csv(csv_path)
    wide_df = long_df.pivot_table(
        index=key_columns,
        columns=['prompt_construction'],
        values=['cleaned_prediction'],
    ).reset_index()

    # After reset_index the header is two-level: the former key columns carry
    # their name in level 0, while each pivoted column carries its prompt name
    # in level 1 (level 0 is 'cleaned_prediction').
    flat_names = [col[0] for col in wide_df.columns[:n_keys]]
    flat_names += [col[1] for col in wide_df.columns[n_keys:]]
    wide_df.columns = flat_names

    # NOTE(review): 447 + 5388 presumably re-bases idx onto another split's
    # numbering; the origin of these constants is not visible here -- confirm.
    wide_df['idx'] = wide_df['idx'] - (447 + 5388)
    return wide_df

def correct_labels(x):
    """Map a textual NLI label to its integer code.

    Substrings are tested in priority order: ``'ent'`` -> 2, ``'con'`` -> 0,
    ``'neu'`` -> 1. The first fragment found in ``x`` wins, so a value that
    contains both ``'ent'`` and ``'con'`` maps to 2.

    Returns:
        The integer code, or ``None`` when no fragment occurs in ``x``.
    """
    for fragment, code in (('ent', 2), ('con', 0), ('neu', 1)):
        if fragment in x:
            return code
    return None

def collect_results(test, model='Complete', model_type = 'RVR'):
    """Build a ValueResonanceClassifier for ``model`` and evaluate it on ``test``.

    Checkpoint selection:
      * ``model == 'Entailment'`` or ``model_type == 'RTE'``: the RTE setup
        using the ``roberta-large-mnli`` tokenizer with an empty model path.
      * ``model`` is one of the hypothesis-variant names: ``../../models/RVR_WVC/``.
      * otherwise: ``../../models/RVR_{model}/``.

    Hypotheses are read from ``test[model]`` when ``model`` is a hypothesis
    variant, and from ``test.hypothesis`` otherwise. Premises and true labels
    come from ``test.premise`` and ``test.label``.

    Returns:
        Whatever ``vrc.predict(..., do_eval=True, ...)`` returns.
    """
    hypothesis_variants = (
        'raw_stem',
        'author_believes',
        'text_expresses',
        'text_expresses_belief',
        'original_hypothesis',
    )
    uses_variant = model in hypothesis_variants

    if model == 'Entailment' or model_type == 'RTE':
        vrc = vrs.ValueResonanceClassifier(model_path='', model_type='rte', tokenizer='roberta-large-mnli')
    elif uses_variant:
        vrc = vrs.ValueResonanceClassifier('../../models/RVR_WVC/')
    else:
        vrc = vrs.ValueResonanceClassifier(f'../../models/RVR_{model}/')

    premises = list(test.premise.values)
    hypothesis_column = test[model] if uses_variant else test.hypothesis
    hypotheses = list(hypothesis_column.values)
    true_labels = list(test.label.values)

    return vrc.predict(
        premises=premises,
        hypotheses=hypotheses,
        do_eval=True,
        true_labels=true_labels,
    )

def summarize_overall_error_reports(res_results):
    """Stack each dataset's overall-performance frame into one table.

    ``res_results`` maps dataset name -> result object; element ``[1]`` of
    each value is taken as that dataset's performance frame. The frames are
    concatenated, every column is rounded to 2 decimals, and a ``Dataset``
    column holding the mapping's keys (in iteration order) becomes the index.
    Assigning that column requires the stacked frame to have exactly one row
    per key; presumably each ``[1]`` is a one-row frame -- confirm with
    ``collect_results``' return.
    """
    frames = [report[1] for report in res_results.values()]
    stacked = pd.concat(frames).apply(lambda column: column.round(2))
    stacked['Dataset'] = list(res_results)
    return stacked.set_index('Dataset')

def run_full_error_analysis(model_type, dfs, df_names, res_roberta_results=None):
    """Run ``collect_results`` on every evaluation frame and summarize the runs.

    Args:
        model_type: Model identifier, forwarded positionally to
            ``collect_results`` as its ``model`` argument.
        dfs: Mapping of dataset name -> evaluation DataFrame.
        df_names: Unused; kept so existing callers keep working.
        res_roberta_results: Optional dict that receives the raw per-dataset
            results (filled in place, keyed by dataset name). When omitted, a
            fresh dict is used for this call.

    Returns:
        The table produced by ``summarize_overall_error_reports``.
    """
    # A ``{}`` default is created once and shared by every call that omits the
    # argument, so results from earlier runs would leak into later summaries.
    # Allocate a new dict per call instead.
    if res_roberta_results is None:
        res_roberta_results = {}
    for df_name, df in dfs.items():
        print(f"Processing DF: {df_name}")
        res_roberta_results[df_name] = collect_results(df, model_type)
    return summarize_overall_error_reports(res_roberta_results)

def execute_error_analysis(run_results, eval_dfs, eval_df_names, model_names):
    """Evaluate every named model on all evaluation frames, storing into ``run_results``.

    For each name in ``model_names``, this calls ``run_full_error_analysis``
    with a fresh per-model dict. It then writes
    ``run_results[name] = {'overall_error_report': ..., 'res_roberta_results': ...}``.
    Returns None; ``run_results`` is mutated in place.
    """
    for model_name in model_names:
        print(f"Processing Model '{model_name}'")
        per_dataset_results = {}
        summary_report = run_full_error_analysis(
            model_name,
            dfs=eval_dfs,
            df_names=eval_df_names,
            res_roberta_results=per_dataset_results,
        )
        run_results[f"{model_name}"] = dict(
            overall_error_report=summary_report,
            res_roberta_results=per_dataset_results,
        )

def collect_results_df(run_results, rr_keys, model_recode_dict, df_names_new):
    """Combine per-model overall error reports into one long-format frame.

    Args:
        run_results: Mapping of model key -> dict holding an
            ``'overall_error_report'`` DataFrame whose index becomes the
            ``Dataset`` column after ``reset_index``.
        rr_keys: Model keys to include, in output order.
        model_recode_dict: Maps each model key to its display name.
        df_names_new: Maps each dataset name to its display name.

    Returns:
        DataFrame with the report columns plus ``Dataset`` and ``model``. Rows
        are stacked in ``rr_keys`` order, and each report's own 0..n-1 index
        is kept, so index values may repeat.

    Raises:
        ValueError: If ``rr_keys`` is empty.
        KeyError: If a model key or dataset name is missing from the recode maps.
    """
    frames = []
    for key in rr_keys:
        report = run_results[f"{key}"]['overall_error_report'].reset_index()
        report['model'] = key
        frames.append(report)

    if not frames:
        raise ValueError("rr_keys must contain at least one model key")

    # Concatenate once: growing a frame with pd.concat inside the loop copies
    # all earlier rows on every iteration (quadratic).
    results_df = pd.concat(frames)
    results_df['model'] = results_df['model'].apply(lambda name: model_recode_dict[name])
    results_df['Dataset'] = results_df['Dataset'].apply(lambda name: df_names_new[name])
    return results_df

def conduct_ks_test(data1, data2):
    """Run a two-sample Kolmogorov-Smirnov test on ``data1`` and ``data2``.

    Returns:
        Tuple ``(statistic, p_value)`` as reported by ``scipy.stats.ks_2samp``.
    """
    outcome = ks_2samp(data1, data2)
    return outcome.statistic, outcome.pvalue

def calculate_summary_statistics(train, test, val):
    """Tabulate label counts and row-wise proportions per data split.

    Integer labels are mapped 0 -> 'Conflict', 1 -> 'Neutral',
    2 -> 'Resonates'. Each cell reads ``"<count> (<proportion>)"``, where the
    proportion is the count divided by that row's total, rounded to 2
    decimals. A ``Total`` column (row sums) and a ``Total`` row (column sums)
    are appended; their cells hold the plain count.

    Args:
        train: Training split with a ``label`` column.
        test: Test split with a ``label`` column.
        val: Validation split; an empty DataFrame means "no validation split".

    Returns:
        DataFrame of strings. Rows are ordered Training, Validation (if
        present), Testing, Total. Columns are ordered Resonates, Neutral,
        Conflict, Total, restricted to labels that occur in the data.

    The input frames are not modified.
    """
    labels = ['Conflict', 'Neutral', 'Resonates']

    # Tag each split with its origin on a copy, so that the caller's frames
    # are left untouched.
    parts = [train.assign(source='Training'), test.assign(source='Testing')]
    if not val.empty:
        parts.append(val.assign(source='Validation'))
    full_data = pd.concat(parts)
    full_data['coded_label'] = full_data.label.apply(lambda x: labels[int(x)])

    present_set = set(full_data['coded_label'])
    present = [lab for lab in labels if lab in present_set]

    # fill_value=0 keeps a label that is absent from one split as a 0 count
    # instead of NaN. A NaN would render as "nan (nan)" and turn all counts
    # into floats ("10.0").
    counts = pd.pivot_table(full_data, values='label', index='source',
                            columns='coded_label', aggfunc='count', fill_value=0)
    # Select columns by name rather than relabeling them by position.
    counts = counts.reindex(columns=present, fill_value=0).astype('int64')
    counts.index.name = None
    counts.columns.name = None

    counts['Total'] = counts.sum(axis=1)
    counts = pd.concat([counts, counts.sum(axis=0).to_frame('Total').T])

    # Compute proportions as a separate frame instead of writing floats back
    # into the int count columns (an incompatible-dtype setitem).
    label_counts = counts.iloc[:, :-1]
    proportions = label_counts.div(label_counts.sum(axis=1), axis=0).round(2)

    combined = label_counts.astype(str) + ' (' + proportions.astype(str) + ')'
    combined['Total'] = counts['Total'].astype(str)

    row_order = [i for i in ['Training', 'Validation', 'Testing', 'Total'] if i in combined.index]
    col_order = [c for c in ['Resonates', 'Neutral', 'Conflict', 'Total'] if c in combined.columns]
    return combined.loc[row_order, col_order]

def check_overlap(df1,df2):
    """Return the rows two frames share, ignoring each frame's first column.

    The first column of each frame is dropped (presumably an id/index column
    -- confirm). The frames are then inner-merged on every remaining column
    they have in common.
    """
    def _without_first_column(frame):
        trailing_names = list(frame.columns[1:])
        return frame[trailing_names]

    left = _without_first_column(df1)
    right = _without_first_column(df2)
    return pd.merge(left, right, how='inner')

def check_dataframe_overlap(dfs):
    """Build a pairwise matrix of overlapping-row counts between named frames.

    Cell ``(a, b)`` is ``len(check_overlap(dfs[a], dfs[b]))``, computed for
    every ordered pair of keys, including each frame paired with itself. The
    result is restricted and ordered to the nine fixed split names below, on
    both axes. Selecting those labels raises ``KeyError`` if ``dfs`` lacks
    any of them.
    """
    split_order = [
        'wvc_train', 'HVE_train', 'complete_train',
        'wvc_val', 'HVE_val', 'complete_val',
        'wvc_test', 'HVE_test', 'complete_test',
    ]
    named_frames = list(dfs.items())
    ordered_pairs = (
        (row_item, col_item)
        for row_item in named_frames
        for col_item in named_frames
    )

    overlap_df = pd.DataFrame()
    for (row_name, row_frame), (col_name, col_frame) in ordered_pairs:
        overlap_df.loc[row_name, col_name] = len(check_overlap(row_frame, col_frame))

    return overlap_df.loc[split_order, split_order]

def summarize_data_distributions(train_df, test_df, val_df=None):
    """Show label-distribution summaries and KS tests for the data splits.

    Args:
        train_df: Training split with a ``label`` column.
        test_df: Test split with a ``label`` column.
        val_df: Optional validation split. ``None`` (the default) or an empty
            DataFrame means "no validation split", and only Training vs.
            Testing is compared.

    Side effects:
        Prints section headers and shows two tables: the
        ``calculate_summary_statistics`` table, and a KS-test table with
        statistics and p-values rounded to 2 decimals. Tables are shown via
        ``display`` when available (IPython/Jupyter), otherwise via ``print``.
        Returns None.
    """
    # Work on copies so that helpers cannot mutate the caller's frames.
    train = train_df.copy()
    test = test_df.copy()
    val = pd.DataFrame() if val_df is None else val_df.copy()
    summary = calculate_summary_statistics(train, test, val)

    if val.empty:
        comparisons = [('Training vs. Testing', train, test)]
    else:
        comparisons = [
            ('Training vs. Validation', train, val),
            ('Training vs. Testing', train, test),
            ('Validation vs. Testing', val, test),
        ]

    # NOTE(review): the KS test targets continuous distributions; on discrete
    # class labels its p-values tend to be conservative. A chi-square test of
    # homogeneity may be more appropriate -- confirm with the analysis owner.
    ks_stats = [conduct_ks_test(left['label'], right['label']) for _, left, right in comparisons]
    ks_results = pd.DataFrame({
        'Comparison': [name for name, _, _ in comparisons],
        'KS Statistic': [round(stat, 2) for stat, _ in ks_stats],
        'p-value': [round(p_value, 2) for _, p_value in ks_stats],
    })

    # ``display`` only exists inside IPython/Jupyter; fall back to print so
    # that the function also works from a plain Python script.
    try:
        show = display
    except NameError:
        show = print

    print('Data Summary:')
    show(summary)
    print('\nKS-Test:')
    show(ks_results)
