import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
import model_evaluation as me


def plot_confmat(results, models, model_names):
    """Draw one annotated confusion-matrix heatmap per model, side by side.

    Args:
        results: DataFrame with at least the columns ``model``, ``conf_mat``,
            ``accuracy`` and ``F1``. For every entry of ``models`` the first
            row whose ``model`` equals it is used; its ``conf_mat`` value must
            be an array that can be reshaped to 3x3.
        models: Model identifiers to look up in ``results.model``.
        model_names: Display names used in the panel titles; indexed in
            parallel with ``models``.

    Returns:
        The matplotlib Figure holding the heatmaps.

    Raises:
        IndexError: If a model has no matching row in ``results``.
    """
    # squeeze=False keeps ``axs`` a 2-D array even for a single model; without
    # it plt.subplots returns a bare Axes, which has no ``flatten``.
    fig, axs = plt.subplots(1, len(models), figsize=(len(models)*2, 3),
                            constrained_layout=True, squeeze=False)
    last_panel = len(models) - 1
    for m, (model, ax) in enumerate(zip(models, axs.flatten())):
        model_rows = results.loc[results.model == model]
        # The tick labels below are hard-coded for three classes, so the
        # matrix is reshaped to exactly 3x3 (the previous ``-3`` is not a
        # documented reshape placeholder). np.flip reverses both axes.
        matrix = np.flip(model_rows['conf_mat'].values[0].reshape(3, 3))
        accuracy = model_rows['accuracy'].values[0]
        f1 = model_rows['F1'].values[0]

        ax.yaxis.tick_right()
        ax.xaxis.set_label_position("top")
        g = sns.heatmap(matrix, ax=ax, annot=True, cmap='Blues', cbar=False, fmt='g')

        # Only the first panel carries the y-axis label.
        g.set_ylabel("True" if m == 0 else "")
        # Class tick labels sit on the right, so only the last panel shows them.
        y_labels = ['R', 'N', 'C'] if m == last_panel else ['', '', '']
        g.set_yticks([0.5, 1.5, 2.5], y_labels, rotation=0)

        g.set_xlabel("Predicted")
        g.set_xticks([0.5, 1.5, 2.5], ['R', 'N', 'C'])
        g.tick_params(bottom=False, right=False)
        g.set_title(f"{model_names[m]} \n(Acc--{round(accuracy, 2)}, F1--{round(f1, 2)})", wrap=True, pad=25)
    return fig

def plot_bar_charts(results_df):
    """Plot grouped bar charts of accuracy (top) and F1 (bottom) per dataset.

    Args:
        results_df: DataFrame with the columns ``Dataset``, ``model``,
            ``accuracy`` and ``F1``. Only models whose name appears in the
            fixed ``hue_order`` below are drawn.

    Returns:
        The matplotlib Figure. Note that the figure is also displayed via
        ``plt.show()`` before returning.
    """
    sns.set(style="whitegrid")

    # Define color palette: reds for the two competitors, greens for the
    # three Res-RoBERTa variants (same order as ``hue_order``).
    num_reds = 2
    num_greens = 3
    # plt.get_cmap replaces plt.cm.get_cmap, which newer matplotlib removed.
    reds = plt.get_cmap('Reds', num_reds + 4)
    greens = plt.get_cmap('Greens', num_greens + 1)
    # Integer inputs index the resampled lookup tables directly.
    colors = list(reds([(2 * i + 2) for i in range(num_reds)])) + list(greens([(i + 1) for i in range(num_greens)]))

    fig, axs = plt.subplots(2, 1, figsize=(6, 7), sharex=True)

    hue_order = ['RoBERTa MNLI',
                 'GPT-3 (Annotator \nInstr. w/ Instructions)',
                 'Res-RoBERTa WVC',
                 'Res-RoBERTa HVE',
                 'Res-RoBERTa Complete']

    panels = (('accuracy', 'Accuracy'), ('F1', '(Weighted) F1 Score'))
    for ax, (metric, y_label) in zip(axs, panels):
        sns.barplot(data=results_df,
                    x="Dataset", y=metric, hue="model", hue_order=hue_order,
                    ax=ax,
                    palette=colors)
        ax.spines.right.set_visible(False)
        ax.spines.top.set_visible(False)
        ax.set_ylabel(y_label)
        _label_bars(ax)

    # Top panel: no x label/ticks (shared with the bottom panel) and no legend.
    axs[0].set_xlabel('')
    axs[0].tick_params(
        axis='x',
        which='both',
        bottom=False,
        top=False,
        labelbottom=False)
    axs[0].get_legend().remove()

    axs[1].set_xlabel('Test Dataset')

    # Get handles and labels
    handles, labels = axs[1].get_legend_handles_labels()

    # Handle legends
    handle_legends(axs, handles, labels)

    plt.tight_layout()
    plt.show()
    return fig


def _label_bars(ax, fontsize=9):
    """Write each bar's height above it, two decimals, leading zero dropped."""
    for bar in ax.patches:
        bar_height = bar.get_height()
        # Skip bars without a value. Zero-width patches are skipped too:
        # some seaborn versions reportedly add empty placeholder patches for
        # the legend (TODO confirm for the pinned version).
        if not np.isfinite(bar_height) or bar.get_width() == 0:
            continue
        # '0.87' -> '.87', '0.05' -> '.05', '1.00' stays '1.00'. The previous
        # '.{h*100:.0f}' format rendered 0.05 as '.5' and 1.0 as '.100'.
        text = f'{bar_height:.2f}'.lstrip('0')
        ax.text(bar.get_x() + bar.get_width() / 2, bar_height, text, ha='center', va='bottom',
                fontsize=fontsize)

def handle_legends(axs, handles, labels):
    """Add a two-column, grouped legend to the last axes in ``axs``.

    The legend lists a green outline patch titled 'Res-RoBERTa Models'
    followed by ``handles[2:]``, then a red outline patch titled
    'Competitors' followed by ``handles[:2]``. This assumes the first two
    handles/labels belong to the competitor models.

    Args:
        axs: Axes (or array of Axes) of the figure; the legend is attached to
            the last one in flattened order.
        handles: Legend handles, competitors first.
        labels: Legend labels, parallel to ``handles``.
    """
    green_patch = mpl.patches.Rectangle((0, 0), 1, 1, fill=False, edgecolor='green', visible=True)
    red_patch = mpl.patches.Rectangle((0, 0), 1, 1, fill=False, edgecolor='red', visible=True)

    # Target the axes explicitly instead of relying on pyplot's implicit
    # "current axes". For a grid freshly created with plt.subplots, the last
    # axes is also the current one, so the legend lands in the same place.
    target_ax = np.ravel(axs)[-1]
    target_ax.legend(handles=[green_patch] + handles[2:] + [red_patch] + handles[:2],
                     labels=['Res-RoBERTa Models'] + labels[2:] + ['Competitors'] + labels[:2],
                     ncol=2,
                     fancybox=True,
                     loc='lower left',
                     shadow=True,
                     frameon=True,
                     fontsize=10.5,
                     framealpha=1,
                     facecolor="whitesmoke")

def plot_box_plots(results_df):
    """Plot box plots comparing Res-RoBERTa models with GPT-3 per dataset.

    The left panel shows accuracy, the right panel F1; both share the y axis.
    Model names are collapsed into two groups ('Res-RoBERTa' and 'GPT-3');
    the RTE / RoBERTa MNLI baseline is excluded.

    Args:
        results_df: DataFrame with the columns ``Dataset``, ``model``,
            ``accuracy`` and ``F1``.

    Returns:
        The matplotlib Figure.
    """
    fig, axs = plt.subplots(1, 2, figsize=(12, 5), sharey=True)

    r_index = 13
    g_index = 9

    # plt.get_cmap replaces plt.cm.get_cmap, which newer matplotlib removed.
    t20c = plt.get_cmap('tab20c', 20)
    t20b = plt.get_cmap('tab20b', 20)

    # Integer lookup into the 20-entry tables; selects the same entries as
    # sampling np.linspace(0, 1, 20) and indexing the result.
    colors = [t20c(g_index), t20b(r_index)]
    # Pin the hue order so the tab20c entry always goes to Res-RoBERTa and
    # the tab20b entry to GPT-3, regardless of row order in ``results_df``.
    hue_order = ['Res-RoBERTa', 'GPT-3']

    panels = (('accuracy', 'Accuracy'), ('F1', '(Weighted) F1 Score'))
    for ax, (metric, y_label) in zip(axs, panels):
        rr_comp = results_df[['Dataset', 'model', metric]].melt(id_vars=['Dataset', 'model'], value_vars=[metric])
        rr_comp['model'] = rr_comp['model'].apply(_model_family)
        # Rows mapped to NaN (the excluded baseline) are dropped explicitly.
        rr_comp = rr_comp.dropna(subset=['model'])

        sns.boxplot(data=rr_comp, ax=ax, x='Dataset', y='value', hue='model',
                    hue_order=hue_order, palette=colors)
        ax.set_ylim([0.4, 1])
        ax.set_xlabel('Test Dataset')
        ax.set_ylabel(y_label)

    axs[0].get_legend().remove()
    axs[1].yaxis.set_label_position("right")

    axs[1].legend(title='', loc='lower left', fontsize=10.5, fancybox=True, shadow=True, frameon=True,
                  framealpha=1, facecolor="whitesmoke")

    plt.tight_layout()
    return fig


def _model_family(model_name):
    """Collapse a model name to 'Res-RoBERTa', 'GPT-3', or NaN (excluded)."""
    if 'Res' in model_name:
        return 'Res-RoBERTa'
    # The RTE competitor is displayed as 'RoBERTa MNLI' once names have been
    # mapped, so both spellings must be excluded; checking only 'RTE' let the
    # renamed baseline fall through into the 'GPT-3' group.
    if 'RTE' in model_name or 'MNLI' in model_name:
        return np.nan
    return 'GPT-3'

def map_model_names(model_name):
    """Translate an internal model identifier into its display name.

    Known identifiers are looked up in a fixed table; anything not in the
    table is returned unchanged.
    """
    # Names that are already in display form map to themselves.
    display_names = {
        name: name
        for name in (
            'RoBERTa MNLI',
            'Res-RoBERTa WVC',
            'Res-RoBERTa HVE',
            'Res-RoBERTa Complete',
        )
    }
    # Competitor placeholders are replaced by the concrete model labels.
    display_names['Top Competitor (RTE)'] = 'RoBERTa MNLI'
    display_names['Top Competitor (GPT-3)'] = 'GPT-3 (Annotator \nInstr. w/ Instructions)'

    if model_name in display_names:
        return display_names[model_name]
    return model_name

def generate_comparison_charts(results_df,
                               llm_results_path='../../data/LLM_Experiment/llm_results_master_csv.csv'):
    """Combine model results with LLM results and draw the comparison charts.

    LLM results are read from ``llm_results_path`` and scored with
    ``me.calc_perf_scores`` for the complete data and for the 'WVC',
    'Touche' and 'Noise' subsets. Only score rows whose ``type`` equals
    'Annotator Instructions (Instructions)' are kept, and they are labelled
    as the model 'Top Competitor (GPT-3)'.

    Args:
        results_df: DataFrame of model results. All of its columns must also
            be present in the LLM score frames (after ``Dataset`` and
            ``model`` are added), since those are subset to
            ``results_df.columns`` before concatenation.
        llm_results_path: CSV file with the raw LLM results; it must contain
            a ``dataset`` column. Defaults to the previously hard-coded path,
            which is relative to the current working directory.

    Returns:
        Tuple ``(bar, box)`` with the figures from ``plot_bar_charts`` and
        ``plot_box_plots``.
    """
    llm_results = pd.read_csv(llm_results_path)
    instruction_type = 'Annotator Instructions (Instructions)'

    def _instruction_scores(frame, dataset_label):
        # Score ``frame`` and keep only the instruction-prompt rows.
        scores = me.calc_perf_scores(frame)
        scores = scores.loc[scores.type == instruction_type].copy()
        scores['Dataset'] = dataset_label
        return scores

    # Calculate performance scores (complete data first, then the subsets;
    # each subset is passed as a copy).
    c_perf = _instruction_scores(llm_results, 'Complete')
    subset_perfs = [
        _instruction_scores(llm_results.loc[llm_results.dataset == dataset_key, :].copy(), dataset_label)
        for dataset_key, dataset_label in (('WVC', 'WVC'), ('Touche', 'Touche HV'), ('Noise', 'Noise'))
    ]

    llm_rdf = pd.concat(subset_perfs + [c_perf])
    llm_rdf['model'] = 'Top Competitor (GPT-3)'

    # Concatenate dataframes. ignore_index gives the combined frame a unique
    # index; the inputs' own indices would otherwise produce duplicate labels
    # in the frame handed to the plotting functions.
    r_df_wllms = pd.concat([results_df, llm_rdf[results_df.columns]], ignore_index=True)
    r_df_wllms['model'] = r_df_wllms['model'].apply(map_model_names)

    # Plot bar charts
    bar = plot_bar_charts(r_df_wllms)
    box = plot_box_plots(r_df_wllms)
    return bar, box