# The values in these confusion matrices were copied from the output of another script (a link to that file was shared by AD and AL).
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np


# Confusion matrices keyed by dataset, then by model.
# Each matrix is a 5x5 nested list of counts: row i is the true level and
# column j the predicted level (levels one..five), matching the axis labels
# used by plot_single_confusion_matrix. The "accuracy" comment under each
# matrix equals its trace divided by its total count.
# All entries are plain nested lists; the plotting code converts them to
# NumPy arrays itself.
confusion_matrices = {
    'arc': {},
    'race': {},
    'cupa': {},
}

# CUPA GPT 4
# Best parameters set:
# {'logistic__C': 1, 'logistic__penalty': 'l2', 'tfidf__max_df': 0.2, 'tfidf__min_df': 0.02, 'tfidf__ngram_range': (3, 4)}
confusion_matrices['cupa']['gpt_4'] = [
    [46,  9,  7,  0,  4],
    [ 8, 31, 13,  2, 12],
    [ 2, 10, 32, 14,  8],
    [ 0,  4, 21, 23, 18],
    [ 0,  3,  7, 15, 41],
]
# accuracy: 0.52 (n=330)

# CUPA GPT 3.5
# Best parameters set:
# {'logistic__C': 0.5, 'logistic__penalty': 'l1', 'tfidf__max_df': 0.3, 'tfidf__min_df': 0.005, 'tfidf__ngram_range': (3, 3)}
confusion_matrices['cupa']['gpt_3_5'] = [
    [52,  6,  1,  0,  7],
    [ 3, 36, 14,  0, 13],
    [ 3, 10, 32,  0, 21],
    [ 0,  4,  4,  0, 58],
    [ 0,  1,  2,  0, 63],
]
# accuracy: 0.55 (n=330)

# RACE GPT 4
# Best parameters set:
# {'logistic__C': 1, 'logistic__penalty': 'l2', 'tfidf__max_df': 0.3, 'tfidf__min_df': 0.02, 'tfidf__ngram_range': (3, 3)}
confusion_matrices['race']['gpt_4'] = [
    [39,  6,  2,  2,  1],
    [ 6, 24, 11,  6,  2],
    [ 2,  6, 20, 15,  7],
    [ 0,  1, 11, 16, 22],
    [ 2,  3,  2, 12, 30],
]
# accuracy: 0.52 (n=248)

# RACE GPT 3.5
# Best parameters set:
# {'logistic__C': 1, 'logistic__penalty': 'l1', 'tfidf__max_df': 0.4, 'tfidf__min_df': 0.01, 'tfidf__ngram_range': (3, 4)}
confusion_matrices['race']['gpt_3_5'] = [
    [15,  7, 13,  1, 14],
    [ 3,  9, 31,  1,  5],
    [ 1, 19, 21,  6,  3],
    [ 0,  4,  5, 15, 26],
    [ 1,  0,  0,  7, 41],
]
# accuracy: 0.41 (n=248)

# ARC GPT 4
# Best parameters set:
# {'tfidf__max_df': 0.3, 'tfidf__min_df': 0.025, 'tfidf__ngram_range': (3, 4)}
confusion_matrices['arc']['gpt_4'] = [
    [97, 11,  2,  1,  5],
    [13, 63, 27,  7,  5],
    [ 1, 27, 61, 14, 13],
    [ 1,  5, 34, 46, 30],
    [ 3,  6,  6, 31, 69],
]
# accuracy: 0.58 (n=578)

# ARC GPT 3.5
# Best parameters set:
# {'tfidf__max_df': 0.6, 'tfidf__min_df': 0.01, 'tfidf__ngram_range': (3, 4)}
confusion_matrices['arc']['gpt_3_5'] = [
    [64, 30,  9,  2, 11],
    [38, 50, 14,  7,  6],
    [21, 17, 31, 22, 25],
    [ 4,  5, 23, 39, 45],
    [ 2,  2, 10, 39, 62],
]
# accuracy: 0.43 (n=578)


def plot_single_confusion_matrix(model, dataset, output_dir='output_figures/for_paper'):
    """Plot one confusion matrix as an annotated heatmap and save it as a PDF.

    The matrix is read from the module-level ``confusion_matrices`` dict.
    Rows are drawn as the true level and columns as the predicted level, and
    every cell is annotated with its count.

    Args:
        model: Model key, ``'gpt_3_5'`` or ``'gpt_4'``.
        dataset: Dataset key, ``'arc'``, ``'cupa'`` or ``'race'``.
        output_dir: Directory the PDF is written to; created if missing.
            The file is named ``confusion_matrix_{dataset}_{model}.pdf``.

    Raises:
        KeyError: If ``dataset``/``model`` is missing from
            ``confusion_matrices`` or from the colour-map / display-name
            tables below.
    """
    confusion_matrix = np.array(confusion_matrices[dataset][model])
    n_rows, n_cols = confusion_matrix.shape

    cmaps = {
        'arc': 'Blues',
        'cupa': 'Reds',
        'race': 'Oranges',
    }
    dataset_name_for_plot = {
        'arc': 'ARC',
        'cupa': 'CUP&A',
        'race': 'RACE',
    }
    model_name_for_plot = {
        'gpt_3_5': 'GPT-3.5',
        'gpt_4': 'GPT-4',
    }
    level_labels = ['one', 'two', 'three', 'four', 'five']

    # Resolve every lookup before creating the figure, so an unknown key
    # raises without leaving an orphaned open figure behind.
    cmap = cmaps[dataset]
    title = f"Confusion Matrix - {model_name_for_plot[model]} - {dataset_name_for_plot[dataset]}"

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        im = ax.imshow(confusion_matrix, cmap=cmap)
        cbar = fig.colorbar(im, ax=ax, shrink=0.5)
        cbar.ax.set_ylabel('Count', rotation=-90, va="bottom")

        # x positions index columns (predicted level), y positions index
        # rows (true level).
        ax.set_xticks(np.arange(n_cols))
        ax.set_yticks(np.arange(n_rows))
        ax.set_xticklabels(level_labels)
        ax.set_yticklabels(level_labels)
        ax.set_xlabel('Predicted level')
        ax.set_ylabel('True level')

        # Annotate each cell with its count; ax.text takes (x=column, y=row).
        for i in range(n_rows):
            for j in range(n_cols):
                ax.text(j, i, confusion_matrix[i, j], ha="center", va="center")

        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(out_dir / f'confusion_matrix_{dataset}_{model}.pdf')
    finally:
        # Close explicitly: pyplot keeps every figure alive until closed, so
        # repeated calls (e.g. the main loop) would otherwise accumulate them.
        plt.close(fig)


if __name__ == '__main__':
    # Render every (model, dataset) pair: all datasets for GPT-3.5 first,
    # then all datasets for GPT-4.
    for model_key, dataset_key in (
        (m, d) for m in ('gpt_3_5', 'gpt_4') for d in ('arc', 'race', 'cupa')
    ):
        plot_single_confusion_matrix(model_key, dataset_key)
