import pandas as pd
import sys
import numpy as np
import value_resonance_scorer as vrs
import data_processing as dp
from allennlp_models import pretrained
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

def load_rte_models():
    """Load the pretrained pair-classification predictors used for RTE scoring.

    Each predictor is obtained through ``pretrained.load_predictor``.

    Returns:
        dict: predictor objects keyed by a short model name. The keys appear
        in the order ``roberta_mnli``, ``roberta_snli``,
        ``decomposable_attention_elmo``, ``binary_gender_bias_roberta_snli``,
        ``adversarial_binary_gender_bias_roberta_snli``.
    """
    # (short name, pretrained model id), listed in the order they are loaded.
    load_plan = [
        ('adversarial_binary_gender_bias_roberta_snli',
         "pair-classification-adversarial-binary-gender-bias-mitigated-roberta-snli"),
        ('binary_gender_bias_roberta_snli',
         "pair-classification-binary-gender-bias-mitigated-roberta-snli"),
        ('decomposable_attention_elmo',
         "pair-classification-decomposable-attention-elmo"),
        ('roberta_snli', "pair-classification-roberta-snli"),
        ('roberta_mnli', "pair-classification-roberta-mnli"),
    ]

    loaded = {}
    for short_name, model_id in load_plan:
        loaded[short_name] = pretrained.load_predictor(model_id)

    # The returned mapping lists the models in the reverse of the load order.
    return {short_name: loaded[short_name] for short_name, _ in reversed(load_plan)}

def predict_te(x, model):
    """Run ``model`` on one premise/hypothesis pair and return an integer code.

    Args:
        x: mapping-like row providing ``'premise'`` and ``'hypothesis'``.
        model: object whose ``predict(premise=..., hypothesis=...)`` returns a
            mapping containing a ``'label'`` entry.

    Returns:
        int: 2 for ``'entailment'``, 0 for ``'contradiction'``, and 1 for any
        other label.
    """
    prediction = model.predict(premise=x['premise'], hypothesis=x['hypothesis'])
    label = prediction['label']

    # Anything that is neither entailment nor contradiction falls through to 1.
    for known_label, code in (('entailment', 2), ('contradiction', 0)):
        if label == known_label:
            return code
    return 1

def evaluate_models(rte_models, eval_df):
    """Add one prediction column per model to ``eval_df``, modifying it in place.

    For every entry of ``rte_models`` a column named after the key is filled
    with the integer code returned by ``predict_te`` for each row. A further
    ``'res_rob'`` column holds the output of
    ``ValueResonanceClassifier.score_entailment`` passed through
    ``dp.correct_labels``.

    Args:
        rte_models: mapping of column name -> predictor accepted by ``predict_te``.
        eval_df: DataFrame with ``'premise'`` and ``'hypothesis'`` columns.

    Returns:
        None. The results are written into ``eval_df``.
    """
    for model_key, model in rte_models.items():
        eval_df[f'{model_key}'] = eval_df.apply(
            lambda row: predict_te(row, model), axis=1
        )

    # The value-resonance classifier is built only after the RTE models have run.
    classifier = vrs.ValueResonanceClassifier('../../models/RVR_WVC/')
    raw_scores = eval_df.apply(
        lambda row: classifier.score_entailment(row['premise'], row['hypothesis']),
        axis=1,
    )
    eval_df['res_rob'] = raw_scores
    eval_df['res_rob'] = eval_df['res_rob'].apply(lambda score: dp.correct_labels(score))

def score_df(preds, labels):
    """Compute overall classification metrics for one set of predictions.

    When the predictions are strings and the gold labels are not, the strings
    'contradiction' / 'neutral' / 'entailment' are mapped to 0 / 1 / 2 before
    scoring. Only the first element of each input is inspected to decide this.

    Args:
        preds: sequence of predicted labels (list, array or Series).
        labels: sequence of gold labels, the same length as ``preds``.

    Returns:
        pandas.DataFrame: one row with the columns 'accuracy', 'precision',
        'recall' and 'F1' (precision/recall/F1 use weighted averaging) plus
        'conf_mat', the confusion matrix flattened into a 1-D array.

    Raises:
        ValueError: if string predictions contain a value other than the three
            known label names.
    """
    averaging = 'weighted'
    label_codes = {'contradiction': 0, 'neutral': 1, 'entailment': 2}
    y_pred = preds
    y_true = labels

    # Peek at the first element by position. Indexing with [0] is label-based
    # on a Series (KeyError when the index lacks 0) and fails on empty input.
    first_pred = next(iter(y_pred), None)
    first_true = next(iter(y_true), None)

    # isinstance also recognises str subclasses such as numpy.str_.
    if isinstance(first_pred, str) and not isinstance(first_true, str):
        unknown = sorted({str(val) for val in y_pred if val not in label_codes})
        if unknown:
            # Dropping these silently would leave y_pred shorter than y_true
            # and only fail later inside scikit-learn with an unclear message.
            raise ValueError(f'Unrecognised prediction label(s): {unknown}')
        y_pred = [label_codes[val] for val in y_pred]

    accuracy_scores = pd.DataFrame({
        'accuracy': [accuracy_score(y_true, y_pred)],
        'precision': [precision_score(y_true, y_pred, average=averaging)],
        'recall': [recall_score(y_true, y_pred, average=averaging)],
        'F1': [f1_score(y_true, y_pred, average=averaging)],
    })
    conf_matrix = confusion_matrix(y_true, y_pred)
    # reshape(1, -1) yields a single row, so the cell holds the flattened matrix.
    accuracy_scores['conf_mat'] = list(conf_matrix.reshape(1, -1))
    return accuracy_scores

def combine_scores(eval_df, llm_scores):
    """Attach the LLM score columns to the RTE evaluation frame.

    Both frames are indexed on ('idx', 'premise', 'hypothesis', 'label') and
    joined with ``DataFrame.join`` (``eval_df`` on the left), after which the
    key columns are restored as ordinary columns.

    Args:
        eval_df: DataFrame containing the four key columns plus RTE scores.
        llm_scores: DataFrame containing the four key columns plus LLM scores.

    Returns:
        pandas.DataFrame: a new frame with the key columns first, followed by
        the columns of ``eval_df`` and then those of ``llm_scores``.
    """
    join_keys = ['idx', 'premise', 'hypothesis', 'label']
    rte_part = eval_df.set_index(join_keys)
    llm_part = llm_scores.set_index(join_keys)
    joined = rte_part.join(llm_part)
    return joined.reset_index().copy()

def compare_models(scored_df):
    """Score every model column of ``scored_df`` against its ``label`` column.

    Every column from the fifth onwards is treated as one model's predictions
    and passed to ``score_df``.

    Args:
        scored_df: DataFrame whose first four columns are not model
            predictions and which has a ``label`` column of gold labels.

    Returns:
        pandas.DataFrame: one row per model column with the columns 'model',
        'accuracy', 'precision', 'recall', 'F1' and 'conf_mat', indexed from 0.
        The frame is empty (same columns) when there are no model columns.
    """
    columns = ['model', 'accuracy', 'precision', 'recall', 'F1', 'conf_mat']

    # Collect the per-model rows and concatenate once. Growing the result by
    # concatenating onto an empty frame is deprecated in pandas (the empty
    # frame's dtypes will start to affect the result) and is quadratic.
    per_model = []
    for model_key in scored_df.columns[4:]:
        acc_df = score_df(scored_df[f'{model_key}'].values, scored_df.label)
        per_model.append(acc_df.assign(model=model_key)[columns])

    if not per_model:
        return pd.DataFrame(columns=columns)
    return pd.concat(per_model, ignore_index=True)

def compare_models_by_class(scored_df):
    """Compute one-vs-rest metrics per class for every model column.

    Every column from the fifth onwards is treated as one model's predictions.
    For each of the three label codes (prefix 'R' for 2, 'N' for 1, 'C' for 0)
    the gold labels and predictions are reduced to "is this class / is not",
    and accuracy, F1, precision and recall are computed on those indicators.

    Args:
        scored_df: DataFrame whose first four columns are not model
            predictions and which has a ``label`` column of gold labels.

    Returns:
        pandas.DataFrame: one row per model column, indexed from 0, with the
        columns 'model', then '<class>_a', '<class>_f1', '<class>_p' and
        '<class>_r' for each class prefix. When there are no model columns an
        empty frame with only the 'model', '_a' and '_f1' columns is returned.
    """
    class_codes = (('R', 2), ('N', 1), ('C', 0))
    # Listed in the column order of the result: accuracy, F1, precision, recall.
    metrics = (
        ('a', accuracy_score),
        ('f1', f1_score),
        ('p', precision_score),
        ('r', recall_score),
    )

    y_true = scored_df.loc[:, 'label']
    # The gold indicators do not depend on the model, so build them once.
    true_flags = {
        prefix: [1 if y == code else 0 for y in y_true]
        for prefix, code in class_codes
    }

    rows = []
    for col in scored_df.columns[4:]:
        y_pred = scored_df.loc[:, col]
        pred_flags = {
            prefix: [1 if y == code else 0 for y in y_pred]
            for prefix, code in class_codes
        }
        row = {'model': col}
        for suffix, metric in metrics:
            for prefix, _ in class_codes:
                row[f'{prefix}_{suffix}'] = metric(true_flags[prefix], pred_flags[prefix])
        rows.append(row)

    if not rows:
        return pd.DataFrame({'model': [], 'R_a': [], 'N_a': [], 'C_a': [],
                             'R_f1': [], 'N_f1': [], 'C_f1': []})
    # Build the frame in one step rather than concatenating onto an empty
    # frame inside the loop, which also left every row with index 0.
    return pd.DataFrame(rows)

def select_top_models(comp_rdf, accuracy_scores_3class):
    """Pick the best model of each family and build a rounded summary table.

    Models are grouped into three families by name: the LLM prompt variants,
    the RTE models, and the single 'res_rob' model.

    Args:
        comp_rdf: DataFrame with a 'model' column plus 'accuracy' and 'F1'.
        accuracy_scores_3class: DataFrame with a 'model' column plus per-class
            metric columns. Columns whose name contains 'a' or 'f1' are used.

    Returns:
        tuple: ``(top_perf_df, summary_df)``.

        ``top_perf_df`` holds the full ``comp_rdf`` row of 'res_rob', then of
        the best LLM model, then of the best RTE model, ranked by accuracy and
        then F1. Each row keeps index 0.

        ``summary_df`` has one row per family with 'model' as a display name
        followed by the metrics rounded to 2 decimals. For the LLM and RTE
        rows each metric is the column-wise maximum over the family, so the
        values in one row may come from different models.
    """
    llm_scores = ['llm_output_annot', 'llm_output_annot_instr', 'llm_output_annot_instr_reas', 'llm_output_annot_reas',
                  'llm_output_complete', 'llm_output_complete_instr', 'llm_output_complete_instr_reas', 'llm_output_complete_reas', 'simple']
    rte_scores = ['roberta_mnli', 'roberta_snli', 'decomposable_attention_elmo', 'binary_gender_bias_roberta_snli',
                  'adversarial_binary_gender_bias_roberta_snli']
    overall_cols = ['accuracy', 'F1']
    # Substring match: selects the per-class accuracy ('_a') and F1 columns.
    class_cols = [col for col in accuracy_scores_3class.columns if ('a' in col) or ('f1' in col)]

    def best_row(mask):
        # Highest accuracy first, ties broken by F1; keep only the top row.
        ranked = comp_rdf.loc[mask, :].sort_values(overall_cols, ascending=False)
        return ranked.reset_index(drop=True).head(1)

    def column_maxima(frame):
        # One-row frame holding the maximum of every column.
        return pd.DataFrame(frame.apply(lambda col: col.max())).T.reset_index(drop=True)

    def labelled(overall, by_class, display_name):
        combined = overall.join(by_class)
        combined['model'] = display_name
        return combined

    overall_is_llm = comp_rdf.model.isin(llm_scores)
    overall_is_rte = comp_rdf.model.isin(rte_scores)
    overall_is_res = comp_rdf.model == 'res_rob'
    class_is_llm = accuracy_scores_3class.model.isin(llm_scores)
    class_is_rte = accuracy_scores_3class.model.isin(rte_scores)
    class_is_res = accuracy_scores_3class.model == 'res_rob'

    top_perf_df = pd.concat([best_row(overall_is_res),
                             best_row(overall_is_llm),
                             best_row(overall_is_rte)])

    res_roberta_complete = labelled(
        comp_rdf.loc[overall_is_res, overall_cols].reset_index(drop=True),
        accuracy_scores_3class.loc[class_is_res, class_cols].reset_index(drop=True),
        'Res-RoBERTa WVC',
    )
    # The backslash is escaped: an unescaped '\t' in a normal string literal
    # is a TAB character, which would corrupt the LaTeX \textit command.
    top_llm_complete = labelled(
        column_maxima(comp_rdf.loc[overall_is_llm, overall_cols]),
        column_maxima(accuracy_scores_3class.loc[class_is_llm, class_cols]),
        '\\textit{LLM:} Top Competitor',
    )
    top_rte_complete = labelled(
        column_maxima(comp_rdf.loc[overall_is_rte, overall_cols]),
        column_maxima(accuracy_scores_3class.loc[class_is_rte, class_cols]),
        '\\textit{RTE:} Top Competitor',
    )

    summary_df = pd.concat([res_roberta_complete, top_llm_complete, top_rte_complete]).reset_index(drop=True)
    # 'model' was added last, so every column before it is a metric.
    metric_cols = list(summary_df.columns)[:-1]
    summary_df[metric_cols] = summary_df[metric_cols].apply(lambda col: col.round(2))

    return top_perf_df, summary_df[['model'] + metric_cols]

def _prompt_display_name(col):
    """Build the display name for one prompt-construction key from its substrings."""
    name = ''
    if 'compl' in col:
        name = name + 'Complete RVR Definition'
    elif 'annot' in col:
        name = name + 'Annotator Instructions'
    elif 'simple' in col:
        name = name + 'Simplified RVR Definition'

    has_instructions = 'inst' in col
    has_reasoning = 'reas' in col
    if has_instructions and has_reasoning:
        name = name + ' (Instructions and Reasoning)'
    elif has_instructions:
        name = name + ' (Instructions)'
    elif has_reasoning:
        name = name + ' (Reasoning)'
    return name


def calc_perf_scores(LLM_output):
    """Score the predictions of every prompt construction in ``LLM_output``.

    The input is pivoted so that each distinct ``prompt_construction`` value
    becomes a column of ``cleaned_prediction`` values keyed by (idx, label).
    Labels and predictions are compared as strings.

    Args:
        LLM_output: DataFrame with the columns 'prompt_construction', 'idx',
            'label' and 'cleaned_prediction'.

    Returns:
        pandas.DataFrame: one row per prompt construction, sorted by 'F1' in
        descending order and indexed from 0. Columns are 'type' (display
        name), 'accuracy', 'precision', 'recall', 'F1' (weighted averages),
        then '<class>_acc' and '<class>_F1' for the prefixes 'R', 'N' and 'C',
        computed only on the rows whose gold label is '2', '1' and '0'
        respectively. With no prompt constructions, an empty frame with only
        the first five columns is returned.
    """
    results_df=LLM_output[['prompt_construction','idx','label','cleaned_prediction']].pivot(columns='prompt_construction',index=['idx','label']).reset_index()
    results_df.label=results_df.label.astype(str)
    # After the pivot the columns are two-level; keep the second level
    # (the prompt-construction value) as the flat column name.
    cols=[i[1] for i in results_df.columns[2:]]
    results_df.columns=['idx','label']+cols
    for col in cols:
        results_df[col]=results_df[col].astype(str)

    # The row subsets per gold class are the same for every prompt column.
    class_masks = [(prefix, results_df.label == code)
                   for prefix, code in (('R', '2'), ('N', '1'), ('C', '0'))]

    rows = []
    for col in cols:
        y_true = results_df.label
        y_pred = results_df[col]
        row = {
            'type': _prompt_display_name(col),
            'accuracy': accuracy_score(y_true, y_pred),
            'precision': precision_score(y_true, y_pred, average='weighted'),
            'recall': recall_score(y_true, y_pred, average='weighted'),
            'F1': f1_score(y_true, y_pred, average='weighted'),
        }
        for prefix, mask in class_masks:
            class_true = results_df.loc[mask, :].label
            class_pred = results_df.loc[mask, :][col]
            row[f'{prefix}_acc'] = accuracy_score(class_true, class_pred)
            row[f'{prefix}_F1'] = f1_score(class_true, class_pred, average='weighted')
        rows.append(row)

    if rows:
        # Build the frame in one step instead of concatenating onto an empty
        # frame inside the loop (deprecated dtype handling, quadratic growth).
        LLM_performance_scores = pd.DataFrame(rows)
    else:
        LLM_performance_scores = pd.DataFrame({'type': [],
                                               'accuracy': [],
                                               'precision': [],
                                               'recall': [],
                                               'F1': []})

    return LLM_performance_scores.sort_values('F1',ascending=False).reset_index(drop=True)