from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, precision_recall_curve
from sklearn.metrics import classification_report
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import roc_curve, auc
from torch.nn.functional import one_hot
from collections import defaultdict
from matplotlib.axes import Axes
import matplotlib.pyplot as plt
from typing import Union, Tuple
from copy import deepcopy
from torch import Tensor
import seaborn as sns
import pandas as pd
import torchmetrics
import numpy as np
import torch
import math
import time
import io


class ScalarMetrics:
    """
    Computes scalar classification metrics with scikit-learn's `classification_report` and (optionally) writes
    them to tensorboard.
    """

    def __init__(self, class_names: list):
        """
        Parameters
        ----------
        class_names: Names of the classes. They are forwarded as `target_names` to `classification_report`;
                     entry i is assumed to name class index i.
        """
        self.class_names = class_names

    def metric_dict_to_tensorboard(self, metric_dict: dict, time_step: int, writer: SummaryWriter, pre_key: str = '') -> None:
        """
        Writes a dictionary, which contains scalar metrics and their names, to tensorboard.

        Entries whose value is itself a dictionary are flattened to 'key/sub_key'. Every other entry is written
        directly under its own key, so any top-level scalar (not only 'accuracy') is supported.

        Parameters
        ----------
        metric_dict: Dictionary containing the scalar metrics and their names.
        time_step: Current time step (needed for tensorboard).
        writer: A tensorboard writer object, which writes the metrics to tensorboard.
        pre_key: Prefix which is prepended to every tensorboard key.
        """

        # flatten first, so that nothing is written if the dictionary turns out to be malformed
        output_dict = dict()
        for key, value in metric_dict.items():
            if isinstance(value, dict):
                for sub_key, sub_value in value.items():
                    output_dict[key + '/' + sub_key] = sub_value
            else:
                output_dict[key] = value

        for key, value in output_dict.items():
            writer.add_scalar(pre_key + key, value, time_step)
        writer.flush()

    def __call__(self, preds: Tensor, tars: Tensor, writer: Union[None, SummaryWriter] = None,
                 time_step: Union[None, int] = None, pre_key: str = '') -> dict:
        """
        Computes a set of scalar metrics and (optionally) writes them to tensorboard.

        Parameters
        ----------
        preds: Pytorch tensor of predictions. Tensors with more than one dimension are reduced with an argmax
               over dim 1, one dimensional tensors are thresholded at 0.5.
        tars: Pytorch tensor of targets.
        writer: Tensorboard writer object (optionally).
        time_step: Current time step (optionally, needed for tensorboard).
        pre_key: Prefix which is prepended to every tensorboard key.

        Returns
        -------
        metric_dict: Dictionary, which contains metrics and their corresponding names.
        """

        tars = tars.detach().cpu().numpy()

        if preds.dim() > 1:
            predicted_classes = torch.argmax(preds, dim=1).detach().cpu().numpy()
        else:
            predicted_classes = (preds > 0.5).long().detach().cpu().numpy()

        # Pin the label set to the known class indices. Without it scikit-learn derives the labels from the data
        # and rejects `target_names` whenever a class is missing from the current batch.
        # NOTE(review): assumes class indices are 0..len(class_names)-1, as in ConfMatrix - confirm.
        labels = list(range(len(self.class_names))) if self.class_names is not None else None
        metric_dict = classification_report(tars, predicted_classes, labels=labels,
                                            target_names=self.class_names, output_dict=True)

        # write the results to tensorboard
        if writer is not None and time_step is not None:
            self.metric_dict_to_tensorboard(metric_dict, time_step, writer, pre_key)
        return metric_dict


class ROCCurve:
    """
    Computes ROC curves and AUC values for binary and multiclass problems, plots them and (optionally) writes
    the plot to tensorboard.
    """

    def single_class_roc_curve(self, predictions: torch.Tensor, targets: torch.Tensor) -> tuple:
        """
        Computes metrics needed for plotting the roc curve and auc values for a binary classification
        problem.

        Parameters
        ----------
        predictions: Model predictions.
        targets: Class labels.

        Returns
        -------
        fpr: False positive rate.
        tpr: True positive rate.
        class_auc: AUC value.
        """

        # convert the data types
        predictions = predictions.detach().cpu().numpy()
        targets = targets.detach().cpu().numpy()

        # compute the false positive rate and true positive rate for the current class
        fpr, tpr, _ = roc_curve(targets, predictions)
        class_auc = auc(fpr, tpr)
        return fpr, tpr, class_auc

    def micro_roc_curve(self, predictions: torch.Tensor, targets: torch.Tensor) -> tuple:
        """
        Computes the micro average roc curves and micro average auc.

        Parameters
        ----------
        predictions: Model predictions.
        targets: Class labels (same shape as the predictions, e.g. one hot encoded).

        Returns
        -------
        fpr: False positive rate.
        tpr: True positive rate.
        class_auc: AUC value.
        """

        # flatten everything, so that each (sample, class) pair is treated as one binary decision
        flat_predictions = torch.ravel(predictions)
        flat_targets = torch.ravel(targets)
        return self.single_class_roc_curve(flat_predictions, flat_targets)

    def macro_roc_curve(self, output_dict: dict) -> tuple:
        """
        Computes the macro average roc curves and macro average auc.

        Parameters
        ----------
        output_dict: Dictionary, containing the class specific roc curve values ([fpr, tpr, auc] per class).

        Returns
        -------
        fpr: False positive rate.
        tpr: True positive rate.
        macro_auc: AUC value.
        """

        curves = list(output_dict.values())

        # aggregate all unique false positive rates
        fpr = np.unique(np.concatenate([curve[0] for curve in curves]))

        # interpolate all ROC curves at these points and average them
        tpr = np.zeros_like(fpr)
        for curve in curves:
            tpr += np.interp(fpr, curve[0], curve[1])
        tpr /= len(output_dict)

        # compute the auc
        macro_auc = auc(fpr, tpr)
        return fpr, tpr, macro_auc

    def multi_class_roc_curve(self, preds: torch.Tensor, tars: torch.Tensor, num_classes: int, class_names: Union[None, list]) -> dict:
        """
        Computes metrics needed for plotting the roc curve and auc values for a multiclass classification
        problem.

        Parameters
        ----------
        preds: Model predictions, indexed as preds[:, class].
        tars: Class labels (class indices).
        num_classes: Number of classes in the multiclass classification problem.
        class_names: Names of the classes. If None, the class index is used as name.

        Returns
        -------
        output_dict: Dictionary mapping each class name as well as 'macro' and 'micro' to [fpr, tpr, auc].
        """

        # Convert the targets into their one hot representation. The width is fixed to num_classes, otherwise a
        # batch that lacks the highest class index would yield fewer columns than there are predictions.
        tars = one_hot(tars.long(), num_classes=num_classes)

        output_dict = dict()
        for cls_idx in range(num_classes):

            # get the name of the current class
            class_name = class_names[cls_idx] if class_names is not None else str(cls_idx)

            # compute fpr, tpr and auc for the current class (one vs rest)
            fpr, tpr, class_auc = self.single_class_roc_curve(preds[:, cls_idx], tars[:, cls_idx])

            # store the computed values in a dictionary
            output_dict[class_name] = [fpr, tpr, class_auc]

        # the macro average has to be computed before 'micro' is added, as it averages all stored curves
        fpr, tpr, macro_auc = self.macro_roc_curve(output_dict)
        output_dict['macro'] = [fpr, tpr, macro_auc]

        # compute the micro average roc curve
        fpr, tpr, class_auc = self.micro_roc_curve(preds, tars)
        output_dict['micro'] = [fpr, tpr, class_auc]
        return output_dict

    def compute_plot_label(self, key: str, auc: float) -> str:
        """
        Computes the label for the legend when plotting a roc curve.

        Parameters
        ----------
        key: Key which determines which kind of roc curve is computed.
        auc: AUC value.

        Returns
        -------
        label: Label for the plot legend.
        """

        # Micro or macro-average roc curve
        if key in ['micro', 'macro']:
            return "{0} ROC curve (area = {1:0.2f})".format(key, auc)

        # roc curve for binary classification problems
        if key == 'binary':
            return "ROC curve (area = {0:0.2f})".format(auc)

        # roc curve for a certain class in a multiclass classification problem
        return "ROC curve of class '{0}' (area = {1:0.2f})".format(key, auc)

    def adapt_plot_properties(self) -> None:
        """
        Adapts the properties of a roc curve plot. Operates on the current pyplot figure.
        """

        # compute labels and the plot legend (the legend is built before the unlabeled diagonal is drawn)
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.legend(loc="lower right")

        # adapt the axis and plot the 'random diagonal'
        plt.plot([0, 1], [0, 1], "k--")
        plt.ylim([0.0, 1.05])
        plt.xlim([0.0, 1.0])

    def plot_roc_curve(self, roc_dict: dict, colors: Union[None, list] = None,  writer: Union[None, SummaryWriter] = None,
                       time_step: Union[None, int] = None, pre_key: str = ''):
        """
        Plots provided roc curves and (optionally) stores them in a tensorboard writer object.

        Parameters
        ----------
        roc_dict: Dictionary, which contains different roc curves (fpr, tpr and auc values).
        colors: Colors, in which the roc curves should be plotted. Curves without a matching entry (e.g. the
                'macro' and 'micro' curves if only one color per class is given) use matplotlib's default colors.
        writer: Tensorboard writer object.
        time_step: Current time step (needed for tensorboard plot).
        pre_key: Prefix which is prepended to the tensorboard key.

        Returns
        -------
        fig: The matplotlib figure containing the roc curves.
        """

        fig = plt.figure()
        for i, (key, values) in enumerate(roc_dict.items()):
            plot_kwargs = {'label': self.compute_plot_label(key, values[2])}

            # only pass a color if one was provided for this curve; indexing blindly raised an IndexError
            # whenever fewer colors than curves were given
            if colors is not None and i < len(colors):
                plot_kwargs['color'] = colors[i]
            plt.plot(values[0], values[1], **plot_kwargs)

        # adapt the plot properties to look good
        self.adapt_plot_properties()

        # write the results to tensorboard
        if writer is not None and time_step is not None:
            ClassificationMetrics.write_plot_to_tensorboard(fig, writer, time_step, pre_key + 'ROC Curve')
        return fig

    def __call__(self, predictions: torch.Tensor, targets: torch.Tensor, class_names: Union[None, list] = None,
                 colors: Union[None, list] = None, writer: Union[None, SummaryWriter] = None,
                 time_step: Union[None, int] = None, pre_key: str = ''):
        """
        Plots the roc curve(s) given a set of predictions as well as corresponding targets and (optionally) stores the curves in
        tensorboard.

        Parameters
        ----------
        predictions: Model predictions. A one dimensional tensor is treated as a binary problem, a two
                     dimensional tensor as a multiclass problem with one column per class.
        targets: Class labels.
        class_names: Names of the classes (optionally).
        colors: Colors used for plotting the roc curves (optionally).
        writer: Tensorboard writer object (optionally).
        time_step: Current time step (optionally, needed for tensorboard).
        pre_key: Prefix which is prepended to the tensorboard key.

        Returns
        -------
        fig: The matplotlib figure containing the roc curve(s).

        Raises
        ------
        AssertionError: If a two dimensional prediction tensor with exactly two columns is provided.
        """

        # get the number of classes
        if predictions.dim() == 2:
            assert predictions.size()[1] != 2, "Binary predictions need to be provided in a one dimensional Tensor!"
            num_classes = predictions.size()[1]
        else:
            num_classes = 2

        # compute the roc curves for the binary case
        if num_classes == 2:
            fpr, tpr, class_auc = self.single_class_roc_curve(predictions, targets)
            roc_dict = {'binary': [fpr, tpr, class_auc]}

        # compute the roc curves for the multidimensional case
        else:
            roc_dict = self.multi_class_roc_curve(predictions, targets, num_classes, class_names)

        # plot the roc curve(s)
        return self.plot_roc_curve(roc_dict, colors, writer, time_step, pre_key)


class PRCurve:
    """
    Computes precision and recall as functions of the decision threshold, plots them and (optionally) writes
    the plot to tensorboard. While plotting, the thresholds at which the desired precision / recall values
    are reached are stored in `precision_thresholds` / `recall_thresholds`.
    """

    def __init__(self, desired_pr=(0.99, 0.95, 0.90), desired_re=(0.99, 0.95, 0.90)):
        """
        Parameters
        ----------
        desired_pr: Precision values for which the corresponding thresholds are looked up.
        desired_re: Recall values for which the corresponding thresholds are looked up.
        """
        self.desired_precisions = desired_pr
        self.desired_recalls = desired_re
        self.precision_thresholds = dict()
        self.recall_thresholds = dict()

    def single_class_pr_curve(self, predictions: torch.Tensor, targets: torch.Tensor) -> tuple:
        """
        Computes precision and recall over all thresholds for a binary classification problem.

        Parameters
        ----------
        predictions: Model predictions.
        targets: Class labels.

        Returns
        -------
        pr: Precision values.
        re: Recall values.
        thresholds: Thresholds, padded to the same length as pr and re.
        """

        # convert the data types
        predictions = predictions.detach().cpu().numpy()
        targets = targets.detach().cpu().numpy()

        pr, re, thresholds = precision_recall_curve(targets, predictions)

        # scikit-learn returns one precision / recall value more than thresholds: element i belongs to
        # thresholds[i] and the final (precision=1, recall=0) point has no threshold. The padding therefore has
        # to go to the END; padding at the front shifted every value by one threshold.
        upper = max(1.0, float(thresholds[-1])) if len(thresholds) > 0 else 1.0
        thresholds = np.append(thresholds, upper)
        return pr, re, thresholds

    def interpolate_at_value(self, org_base: np.ndarray, org_values: np.ndarray, inter_points: np.ndarray):
        """
        Linearly interpolates `org_values` (given at positions `org_base`) at the positions `inter_points`.

        `np.interp` requires increasing sample positions. The samples are therefore sorted by `org_base` first,
        which makes the lookup valid for decreasing (recall) or non monotonic (precision) bases as well. For an
        already increasing base the result is unchanged.

        Parameters
        ----------
        org_base: Positions at which the original values are given.
        org_values: Original values.
        inter_points: Positions at which the values should be interpolated.

        Returns
        -------
        int_values: Interpolated values.
        """

        org_base = np.asarray(org_base, dtype=float)
        org_values = np.asarray(org_values, dtype=float)

        # a stable sort keeps the original order of equal positions
        order = np.argsort(org_base, kind='stable')
        int_values = np.interp(inter_points, org_base[order], org_values[order])
        return int_values

    def macro_pr_curve(self, output_dict: dict) -> tuple:
        """
        Computes the macro average of precision and recall over a common set of thresholds.

        Parameters
        ----------
        output_dict: Dictionary, containing [precisions, recalls, thresholds] per class.

        Returns
        -------
        pr: Averaged precision values.
        re: Averaged recall values.
        tr: Union of all thresholds.
        """

        curves = list(output_dict.values())

        # aggregate all unique thresholds
        tr = np.unique(np.concatenate([curve[2] for curve in curves]))

        re = np.zeros_like(tr)
        pr = np.zeros_like(tr)
        for precisions, recalls, thresholds in curves:
            pr += self.interpolate_at_value(thresholds, precisions, tr)
            re += self.interpolate_at_value(thresholds, recalls, tr)
        pr /= len(output_dict)
        re /= len(output_dict)
        return pr, re, tr

    def micro_pr_curve(self, predictions: torch.Tensor, targets: torch.Tensor) -> tuple:
        """
        Computes the micro average of precision and recall by flattening predictions and targets.

        Parameters
        ----------
        predictions: Model predictions.
        targets: Class labels (same shape as the predictions, e.g. one hot encoded).

        Returns
        -------
        pr: Precision values.
        re: Recall values.
        thresholds: Thresholds.
        """

        # flatten everything, so that each (sample, class) pair is treated as one binary decision
        flat_predictions = torch.ravel(predictions)
        flat_targets = torch.ravel(targets)
        return self.single_class_pr_curve(flat_predictions, flat_targets)

    def multi_class_pr_curve(self, preds: torch.Tensor, tars: torch.Tensor, num_classes: int, class_names: Union[None, list]) -> dict:
        """
        Computes precision and recall curves for every class of a multiclass problem as well as their macro
        and micro average.

        Parameters
        ----------
        preds: Model predictions, indexed as preds[:, class].
        tars: Class labels (class indices).
        num_classes: Number of classes in the multiclass classification problem.
        class_names: Names of the classes. If None, the class index is used as name.

        Returns
        -------
        output_dict: Dictionary mapping each class name as well as 'macro' and 'micro' to
                     [precisions, recalls, thresholds].
        """

        # Convert the targets into their one hot representation. The width is fixed to num_classes, otherwise a
        # batch that lacks the highest class index would yield fewer columns than there are predictions.
        tars = one_hot(tars.long(), num_classes=num_classes)

        output_dict = dict()
        for cls_idx in range(num_classes):

            # get the name of the current class
            class_name = class_names[cls_idx] if class_names is not None else str(cls_idx)

            # compute precision, recall and thresholds for the current class (one vs rest)
            pr, re, thresholds = self.single_class_pr_curve(preds[:, cls_idx], tars[:, cls_idx])

            # store the computed values in a dictionary
            output_dict[class_name] = [pr, re, thresholds]

        # the macro average has to be computed before 'micro' is added, as it averages all stored curves
        pr, re, thresholds = self.macro_pr_curve(output_dict)
        output_dict['macro'] = [pr, re, thresholds]

        # compute the micro average pr curve
        pr, re, thresholds = self.micro_pr_curve(preds, tars)
        output_dict['micro'] = [pr, re, thresholds]
        return output_dict

    def compute_plot_label(self, key: str) -> str:
        """
        Computes the label (used as subplot title) for a pr curve.

        Parameters
        ----------
        key: Key which determines which kind of pr curve is computed.

        Returns
        -------
        label: Label for the plot.
        """

        # Micro or macro-average pr curve
        if key in ['micro', 'macro']:
            return "{0} Precision - Recall curve".format(key)

        # pr curve for binary classification problems
        if key == 'binary':
            return "Precision - Recall curve"

        # pr curve for a certain class in a multiclass classification problem
        return "Precision - Recall curve of class '{0}'".format(key)

    def adapt_plot_properties(self, ax) -> None:
        """
        Adapts the properties of a pr curve plot.

        Parameters
        ----------
        ax: Matplotlib axes which should be adapted.
        """

        # compute labels and the plot legend
        ax.set_xlabel("Thresholds")
        ax.set_ylabel("Value")
        ax.legend(loc="center right")
        ax.grid()

    def plot_pr_curve(self, pr_dict: dict, writer: Union[None, SummaryWriter] = None,
                       time_step: Union[None, int] = None, pre_key: str = ''):
        """
        Plots precision and recall over the thresholds (one subplot per curve) and (optionally) stores the
        figure in a tensorboard writer object. The thresholds for the desired precision / recall values are
        stored in `self.precision_thresholds` / `self.recall_thresholds` under the key of the curve.

        Parameters
        ----------
        pr_dict: Dictionary, which maps curve names to [precisions, recalls, thresholds].
        writer: Tensorboard writer object.
        time_step: Current time step (needed for tensorboard plot).
        pre_key: Prefix which is prepended to the tensorboard key.

        Returns
        -------
        fig: The matplotlib figure containing the curves.
        """

        if len(pr_dict) > 1:
            fig, axs = plt.subplots(int(math.ceil(float(len(pr_dict)) / 2)), 2)
            axs = axs.ravel()
        else:
            fig, axs = plt.subplots(1, 1)
            axs = [axs]

        fig.set_figheight(10)
        fig.set_figwidth(10)

        for i, (key, values) in enumerate(pr_dict.items()):
            ax = axs[i]
            precisions, recalls, thresholds = values[0], values[1], values[2]

            ax.set_title(self.compute_plot_label(key))
            ax.plot(thresholds, precisions, color='r', label='Precision')
            ax.plot(thresholds, recalls, color='b', label='Recall')

            # look up the thresholds at which the desired precision / recall values are reached
            self.precision_thresholds[key] = self.interpolate_at_value(precisions, thresholds, self.desired_precisions)
            self.recall_thresholds[key] = self.interpolate_at_value(recalls, thresholds, self.desired_recalls)

            # adapt the plot properties to look good
            self.adapt_plot_properties(ax)

        # an odd number of curves leaves one cell of the two column grid unused; hide its empty frame
        for spare_ax in axs[len(pr_dict):]:
            spare_ax.set_visible(False)

        # write the results to tensorboard
        if writer is not None and time_step is not None:
            ClassificationMetrics.write_plot_to_tensorboard(fig, writer, time_step, pre_key + 'PR Curve')
        return fig

    def __call__(self, predictions: torch.Tensor, targets: torch.Tensor, class_names: Union[None, list] = None,
                 writer: Union[None, SummaryWriter] = None, time_step: Union[None, int] = None, pre_key: str = ''):
        """
        Plots the pr curve(s) given a set of predictions as well as corresponding targets and (optionally)
        stores them in tensorboard.

        Parameters
        ----------
        predictions: Model predictions. A one dimensional tensor is treated as a binary problem, a two
                     dimensional tensor as a multiclass problem with one column per class.
        targets: Class labels.
        class_names: Names of the classes (optionally).
        writer: Tensorboard writer object (optionally).
        time_step: Current time step (optionally, needed for tensorboard).
        pre_key: Prefix which is prepended to the tensorboard key.

        Returns
        -------
        fig: The matplotlib figure containing the pr curve(s).

        Raises
        ------
        AssertionError: If a two dimensional prediction tensor with exactly two columns is provided.
        """

        # get the number of classes
        if predictions.dim() == 2:
            assert predictions.size()[1] != 2, "Binary predictions need to be provided in a one dimensional Tensor!"
            num_classes = predictions.size()[1]
        else:
            num_classes = 2

        # compute the pr curves for the binary case
        if num_classes == 2:
            pr, re, thresholds = self.single_class_pr_curve(predictions, targets)
            pr_dict = {'binary': [pr, re, thresholds]}

        # compute the pr curves for the multidimensional case
        else:
            pr_dict = self.multi_class_pr_curve(predictions, targets, num_classes, class_names)

        # plot the pr curve(s)
        return self.plot_pr_curve(pr_dict, writer, time_step, pre_key)


class ConfMatrix:
    """
    Computes confusion matrices with different normalizations, plots them as heatmaps and (optionally)
    writes the plot to tensorboard.
    """

    def __init__(self, class_names: Union[None, list]):
        """
        Parameters
        ----------
        class_names: Names of the classes; entry i is used as tick label of class index i. If None, the label
                     set is derived from the data and the default tick labels are kept.
        """
        self.class_names = class_names

        # maps str(normalization argument) to the title of the corresponding subplot
        self.normalization_mapping = {
            'None': 'No Normalization',
            'pred': 'Normalized Predictions',
            'true': 'Normalized Targets',
            'all': 'Full Normalization'
        }

    def compute_confusion_matrices(self, predictions: torch.Tensor, targets: torch.Tensor) -> dict:
        """
        Computes different confusion matrices, which are differently normalized.

        Parameters
        ----------
        predictions: Model predictions. Tensors with more than one dimension are reduced with an argmax over
                     dim 1, one dimensional tensors are thresholded at 0.5.
        targets: Labels.

        Returns
        -------
        matrix_dict: Dictionary containing the confusion matrices, keyed by 'None', 'true', 'pred' and 'all'.
        """

        # adapt the predictions to match the required dimensions
        if predictions.dim() > 1:
            predictions = torch.argmax(predictions, dim=1)
        else:
            predictions = (predictions > 0.5).long()
        targets = targets.long()

        # convert the predictions and targets to the correct data types
        predictions = predictions.detach().cpu().numpy()
        targets = targets.detach().cpu().numpy()

        # fix the label set to all known classes, so that the matrix shape does not depend on the batch;
        # without class names scikit-learn derives the labels from the data
        labels = list(range(len(self.class_names))) if self.class_names is not None else None

        # compute the confusion matrices with different types of normalization and store them into a dictionary
        matrix_dict = dict()
        for normalization in [None, 'true', 'pred', 'all']:
            matrix_dict[str(normalization)] = confusion_matrix(y_true=targets, y_pred=predictions,
                                                               normalize=normalization, labels=labels)
        return matrix_dict

    def plot_confusion_matrix(self, normalization_method: str, cm: np.ndarray, ax: Axes) -> None:
        """
        Plots a confusion matrix.

        Parameters
        ----------
        normalization_method: kind of normalization method (a key of `normalization_mapping`).
        cm: Confusion matrix.
        ax: Matplotlib axes.
        """

        cm = np.round(cm, decimals=2)

        # NOTE: sns.set changes seaborn's global plotting style, not only this axes
        sns.set(font_scale=1)
        sns.heatmap(cm, annot=True, ax=ax, cmap="Blues", fmt="g", cbar=False)

        ax.tick_params(axis='both', which='major', labelsize=10)
        if self.class_names is not None:
            ax.xaxis.set_ticklabels(self.class_names)
            ax.yaxis.set_ticklabels(self.class_names)
        ax.set_title(self.normalization_mapping[normalization_method])

    def __call__(self, predictions: torch.Tensor, targets: torch.Tensor, writer: Union[None, SummaryWriter] = None,
                 time_step: Union[None, int] = None, pre_key: str = ''):
        """
        Computes and plot different confusion matrices and (optionally) writes them to tensorboard.

        Parameters
        ----------
        predictions: Model predictions.
        targets: Groundtruth labels.
        writer: Tensorboard writer object (optionally).
        time_step: Current time step (optionally, needed for tensorboard).
        pre_key: Prefix which is prepended to the tensorboard key.

        Returns
        -------
        fig: The matplotlib figure containing the four confusion matrices.
        """

        matrix_dict = self.compute_confusion_matrices(predictions, targets)

        fig, axes = plt.subplots(2, 2, figsize=(12, 10))

        # scikit-learn puts the targets on the rows and the predictions on the columns; the heatmap draws rows
        # along the y axis and columns along the x axis. Hence predictions label the bottom, targets the left.
        fig.text(0.5, 0.02, 'Predictions', ha='center', fontsize=16)
        fig.text(0.02, 0.5, 'Targets', va='center', rotation='vertical', fontsize=16)

        for ax, key in zip(axes.ravel(), list(matrix_dict.keys())):
            plt.sca(ax)
            self.plot_confusion_matrix(key, matrix_dict[key], ax)

        # write the results to tensorboard
        if writer is not None and time_step is not None:
            ClassificationMetrics.write_plot_to_tensorboard(fig, writer, time_step, pre_key + 'Confusion Matrices')
        return fig


class PredictionStats:
    """
    Relates prediction confidences and (optionally) uncertainties to the accuracy and the expected calibration
    error. Samples are grouped into equally spaced bins and the resulting plots are (optionally) written to
    tensorboard.
    """

    def __init__(self, class_names: list):
        """
        Parameters
        ----------
        class_names: Names of the classes; entry i names class index i.
        """
        self.class_names = class_names

    def results_per_class(self, predictions: list, predicted_class: list, targets: list, uncertainties: Union[None, list]):
        """
        Groups predictions, predicted classes and (optionally) uncertainties by their target.

        Parameters
        ----------
        predictions: Prediction values.
        predicted_class: Predicted classes.
        targets: Targets, used as grouping key.
        uncertainties: Uncertainty values or None.

        Returns
        -------
        pred_dict: Predictions per target.
        cls_dict: Predicted classes per target.
        unc_dict: Uncertainties per target (None, if no uncertainties are provided).
        """

        pred_dict, cls_dict = defaultdict(list), defaultdict(list)
        unc_dict = defaultdict(list) if uncertainties is not None else None

        for idx, pred in enumerate(predictions):
            tar = targets[idx]
            pred_dict[tar].append(pred)
            cls_dict[tar].append(predicted_class[idx])
            if unc_dict is not None:
                unc_dict[tar].append(uncertainties[idx])
        return pred_dict, cls_dict, unc_dict

    def accurracy(self, predicted_classes: list, targets: list) -> tuple:
        """
        Computes the accuracy as well as the number of correct and incorrect predictions.

        Parameters
        ----------
        predicted_classes: Predicted classes.
        targets: Targets.

        Returns
        -------
        accurracy: Fraction of correct predictions (None, if there are no predictions).
        correct: Number of correct predictions.
        incorrect: Number of incorrect predictions.
        """

        total = len(predicted_classes)
        correct = sum(1 for pred, tar in zip(predicted_classes, targets) if pred == tar)
        incorrect = total - correct
        accurracy = float(correct) / total if total > 0 else None
        return accurracy, correct, incorrect

    def compute_bin_indices(self, results: list, num_bins: int = 20):
        """
        Divides the value range of the results into equally spaced bins.

        Parameters
        ----------
        results: Values which should be binned.
        num_bins: Number of bin edges between the minimum and the maximum value.

        Returns
        -------
        bins: Bin edges.
        bin_indices: 1-based index of the bin edge at or below each value.
        """

        # compute the bin values
        bins = np.linspace(np.min(results), np.max(results), num=num_bins).tolist()

        # compute the bin indices
        bin_indices = np.digitize(results, bins=bins).tolist()
        return bins, bin_indices

    def compute_bin_count(self, bin_dict: dict):
        """
        Applies `np.bincount` to the values of every entry of the dictionary.

        Parameters
        ----------
        bin_dict: Dictionary containing lists of non negative integers.

        Returns
        -------
        bin_count_dict: Dictionary containing the counts per key.
        """
        return {key: np.bincount(values) for key, values in bin_dict.items()}

    def prepare_results(self, predictions: torch.Tensor, targets: torch.Tensor, uncertainties: Union[None, torch.Tensor]):
        """
        Converts the tensors into lists and determines the predicted classes.

        For predictions with more than one dimension, the predicted class is the argmax over dim 1 and the
        returned prediction is the value assigned to the target class. One dimensional predictions are
        thresholded at 0.5 and returned unchanged.

        Parameters
        ----------
        predictions: Model predictions.
        targets: Targets.
        uncertainties: Uncertainties or None.

        Returns
        -------
        predictions, predicted_classes, uncertainties, targets: The prepared results as lists.
        """

        targets = targets.detach().cpu().tolist()
        if uncertainties is not None:
            uncertainties = uncertainties.detach().cpu().tolist()

        if predictions.dim() > 1:
            predicted_classes = torch.argmax(predictions, dim=1).detach().cpu().tolist()
            rows = predictions.detach().cpu().tolist()
            predictions = [row[tar] for row, tar in zip(rows, targets)]
        else:
            predicted_classes = (predictions > 0.5).long().detach().cpu().tolist()
            predictions = predictions.detach().cpu().tolist()
        return predictions, predicted_classes, uncertainties, targets

    def sort_bin_results(self, bin_indices: list, bin_thresholds, predictions: list, predicted_classes: list,
                         uncertainties: Union[list, None], targets: list):
        """
        Groups all results by the lower edge of the bin they belong to.

        Parameters
        ----------
        bin_indices: 1-based bin index of every sample.
        bin_thresholds: Bin edges.
        predictions: Prediction values.
        predicted_classes: Predicted classes.
        uncertainties: Uncertainties or None.
        targets: Targets.

        Returns
        -------
        prediction_bin_dict, class_bin_dict, target_bin_dict, uncertainty_bin_dict: Results per bin edge
        (the last one is None, if no uncertainties are provided).
        """

        prediction_bin_dict, class_bin_dict, target_bin_dict = defaultdict(list), defaultdict(list), defaultdict(list)
        uncertainty_bin_dict = defaultdict(list) if uncertainties is not None else None

        for idx, bin_idx in enumerate(bin_indices):
            # the indices are 1-based, hence the lower edge of the bin is at bin_idx - 1
            threshold = bin_thresholds[bin_idx - 1]
            prediction_bin_dict[threshold].append(predictions[idx])
            class_bin_dict[threshold].append(predicted_classes[idx])
            target_bin_dict[threshold].append(targets[idx])
            if uncertainty_bin_dict is not None:
                uncertainty_bin_dict[threshold].append(uncertainties[idx])
        return prediction_bin_dict, class_bin_dict, target_bin_dict, uncertainty_bin_dict

    def select_class(self, cls: int, pred_dict: dict, tar_dict: dict):
        """
        Restricts binned results to the samples whose target equals the given class. All bins are kept, bins
        without a matching sample contain empty lists.

        Parameters
        ----------
        cls: Class index.
        pred_dict: Binned values.
        tar_dict: Binned targets (same keys as pred_dict).

        Returns
        -------
        pred_dict_cls: Binned values of the selected class.
        tar_dict_cls: Binned targets of the selected class.
        """

        pred_dict_cls, tar_dict_cls = dict(), dict()
        for threshold in list(pred_dict.keys()):
            values = pred_dict[threshold]
            targets = tar_dict[threshold]
            indices = [idx for idx, value in enumerate(targets) if value == cls]
            pred_dict_cls[threshold] = [values[idx] for idx in indices]
            tar_dict_cls[threshold] = [targets[idx] for idx in indices]
        return pred_dict_cls, tar_dict_cls

    def compute_bar_plot(self, threshold_center, thresholds, correct_list, incorrect_list, accurracy_list, width):
        """
        Plots the number of correct / incorrect predictions per bin as stacked bars and the accuracy per bin as
        a line on a second y axis.

        Parameters
        ----------
        threshold_center: Centers of the bins (x positions of the bars).
        thresholds: Lower edges of the bins (used as x ticks).
        correct_list: Number of correct predictions per bin.
        incorrect_list: Number of incorrect predictions per bin.
        accurracy_list: Accuracy per bin (None marks a bin without predictions).
        width: Width of the bars.

        Returns
        -------
        fig: The matplotlib figure.
        """

        # Number of decimals needed to display the smallest non zero threshold. The loop is bounded and only
        # looks at magnitudes, so zero or negative thresholds can no longer cause an endless loop.
        magnitudes = [abs(x) for x in thresholds if x != 0]
        val = min(magnitudes) if magnitudes else 1.0
        r = 0
        while val < 1 and r < 12:
            r += 1
            val *= 10
        thresholds_pd = [round(x, r) for x in thresholds]

        # bins without predictions have no accuracy; NaN leaves a gap in the line
        accurracy_list = [np.nan if value is None else value for value in accurracy_list]

        fig, ax = plt.subplots()
        fig.set_figheight(4)
        fig.set_figwidth(8)
        ax.bar(threshold_center, height=correct_list, linewidth=1.0, width=width,
               color=(0.2, 0.4, 0.6, 0.4), edgecolor='grey', alpha=0.4, label='Correct Predictions')
        ax.bar(threshold_center, height=incorrect_list, bottom=correct_list, linewidth=1.0, width=width,
               color='lightcoral', edgecolor='grey', alpha=0.4, label='Incorrect Predictions')
        ax.set_xticks(thresholds_pd)

        ax2 = ax.twinx()
        ax2.set_ylabel('Accurracy', color='darkblue')
        ax2.scatter(threshold_center, accurracy_list, color='darkblue', marker='o', s=20)
        ax2.plot(threshold_center, accurracy_list, color='darkblue', label='Accurracy')
        ax2.tick_params(axis='y', labelcolor='darkblue')

        ax.set_ylabel('Number \nof Predictions', color='grey')
        ax.tick_params(axis='y', labelcolor='grey')
        ax.tick_params(axis='x', rotation=45)
        ax.legend(loc='upper left')
        ax2.legend(loc='upper right')
        fig.tight_layout()
        return fig

    def compute_accurracy_barplot(self, pred_cls_dict: dict, tar_dict: dict):
        """
        Computes the accuracy per bin and plots it together with the number of correct / incorrect predictions.

        Parameters
        ----------
        pred_cls_dict: Binned predicted classes.
        tar_dict: Binned targets (same keys as pred_cls_dict).

        Returns
        -------
        fig: The matplotlib figure (None, if there are no bins).
        """

        thresholds = sorted(pred_cls_dict.keys())
        if not thresholds:
            return None

        correct_list, incorrect_list, accurracy_list = list(), list(), list()
        for threshold in thresholds:
            accurracy, correct, incorrect = self.accurracy(pred_cls_dict[threshold], tar_dict[threshold])
            correct_list.append(correct)
            incorrect_list.append(incorrect)
            accurracy_list.append(accurracy)

        # Only populated bins are present, so neighbouring keys can be several bin widths apart. The smallest
        # gap is the best available estimate of one bin width. A single bin has no gap at all.
        width = float(np.min(np.diff(thresholds))) if len(thresholds) > 1 else 1.0
        threshold_center = [x + width / 2 for x in thresholds]
        return self.compute_bar_plot(threshold_center, thresholds, correct_list, incorrect_list, accurracy_list, width)

    def dict_to_tensorboard(self, d: dict, key: str, pre_key: str, writer: SummaryWriter, time_step: int) -> None:
        """
        Writes all figures of a dictionary to tensorboard. Entries which are None are skipped.

        Parameters
        ----------
        d: Dictionary mapping 'Complete' or a class name to a matplotlib figure (or None).
        key: Tensorboard key of the plots.
        pre_key: Prefix which is prepended to the tensorboard key.
        writer: Tensorboard writer object.
        time_step: Current time step.
        """
        for name, fig in d.items():
            if fig is None:
                continue
            key_act = key if name == 'Complete' else key + '/ Class {0}'.format(name)
            ClassificationMetrics.write_plot_to_tensorboard(fig, writer, time_step, pre_key + key_act)

    def confidence_uncertainty_plot(self, confidence_dict: dict, uncertainty_dict):
        """
        Plots the mean uncertainty over the mean confidence of every non empty bin.

        Parameters
        ----------
        confidence_dict: Binned confidences (predictions).
        uncertainty_dict: Binned uncertainties (same keys as confidence_dict).

        Returns
        -------
        fig: The matplotlib figure (None, if all bins are empty).
        """

        uncertainties, confidences = list(), list()
        for threshold in sorted(confidence_dict.keys()):
            unc_values, conf_values = uncertainty_dict[threshold], confidence_dict[threshold]
            if len(unc_values) > 0 and len(conf_values) > 0:
                uncertainties.append(sum(unc_values) / len(unc_values))
                confidences.append(sum(conf_values) / len(conf_values))

        if not uncertainties:
            return None

        fig, ax = plt.subplots()
        fig.set_figheight(8)
        fig.set_figwidth(8)
        ax.set_ylabel('Uncertainty', color='darkblue')
        ax.set_xlabel('Confidence', color='darkblue')
        ax.scatter(confidences, uncertainties, color='darkblue', marker='o', s=20)
        ax.plot(confidences, uncertainties, color='darkblue')
        ax.grid()
        return fig

    def uncertainty_ece_plot(self, prediction_dict: dict, targets: dict):
        """
        Plots the expected calibration error of every non empty uncertainty bin.

        This is a best effort plot: if the calibration error cannot be computed or plotted, None is returned
        instead of raising.

        Parameters
        ----------
        prediction_dict: Predictions (confidences) binned by uncertainty.
        targets: Targets binned by uncertainty (same keys as prediction_dict).

        Returns
        -------
        fig: The matplotlib figure (None, if all bins are empty or the computation failed).
        """

        uncertainties, expected_calibration_errors = list(), list()

        try:
            for threshold in sorted(prediction_dict.keys()):
                preds, tars = prediction_dict[threshold], targets[threshold]
                if len(preds) > 0 and len(tars) > 0:
                    # torchmetrics documents float confidences and integer ground truth labels
                    ece = torchmetrics.functional.calibration_error(torch.tensor(preds, dtype=torch.float),
                                                                    torch.tensor(tars, dtype=torch.long),
                                                                    n_bins=15, norm='l1', task='binary')
                    expected_calibration_errors.append(float(ece))
                    uncertainties.append(threshold)

            if not uncertainties:
                return None

            fig, ax = plt.subplots()
            fig.set_figheight(8)
            fig.set_figwidth(8)
            ax.set_xlabel('Uncertainty', color='darkblue')
            ax.set_ylabel('Expected \nCalibration Error', color='darkblue')
            ax.scatter(uncertainties, expected_calibration_errors, color='darkblue', marker='o', s=20)
            ax.plot(uncertainties, expected_calibration_errors, color='darkblue')
            ax.grid()
        except Exception:
            # deliberately swallowed (see docstring); KeyboardInterrupt / SystemExit still propagate
            fig = None
        return fig

    def _log_and_close(self, figures: dict, key: str, pre_key: str, writer, time_step) -> None:
        """Writes the figures to tensorboard (if writer and time step are given) and closes them afterwards."""
        if writer is not None and time_step is not None:
            self.dict_to_tensorboard(figures, key, pre_key, writer, time_step)

        # __call__ does not hand the figures to the caller, so they would otherwise pile up inside pyplot
        for fig in figures.values():
            if fig is not None:
                plt.close(fig)

    def __call__(self, predictions: torch.Tensor, targets: torch.Tensor, uncertainties: Union[None, torch.Tensor],
                 time_step: Union[None, int] = None, writer: Union[None, SummaryWriter] = None, pre_key: str = ''):
        """
        Computes the statistic plots, for all samples ('Complete') as well as per class, and (optionally)
        writes them to tensorboard:
        confidence vs uncertainty, confidence vs accuracy, uncertainty vs expected calibration error and
        uncertainty vs accuracy. The plots involving uncertainties are skipped, if no uncertainties are given.

        Parameters
        ----------
        predictions: Model predictions.
        targets: Targets.
        uncertainties: Uncertainties of the predictions or None.
        time_step: Current time step (optionally, needed for tensorboard).
        writer: Tensorboard writer object (optionally).
        pre_key: Prefix which is prepended to every tensorboard key.
        """

        # prepare the data, such that it can be further processed
        predictions, predicted_classes, uncertainties, targets = self.prepare_results(predictions, targets, uncertainties)

        # divide the predictions into bins and sort the results based on the prediction thresholds
        bin_thresholds, prediction_bins = self.compute_bin_indices(predictions)
        pred_dict, cls_dict, tar_dict, un_dict = self.sort_bin_results(prediction_bins, bin_thresholds, predictions,
                                                                       predicted_classes, uncertainties, targets)

        # plot confidence vs uncertainty
        if uncertainties is not None:
            pred_unc_plot_dict = dict()
            pred_unc_plot_dict['Complete'] = self.confidence_uncertainty_plot(pred_dict, un_dict)
            for cls, name in enumerate(self.class_names):
                pred_dict_cls, _ = self.select_class(cls, pred_dict, tar_dict)
                un_dict_cls, _ = self.select_class(cls, un_dict, tar_dict)
                pred_unc_plot_dict[name] = self.confidence_uncertainty_plot(pred_dict_cls, un_dict_cls)
            self._log_and_close(pred_unc_plot_dict, 'Confidence vs Uncertainty', pre_key, writer, time_step)

        # plot confidence vs accurracy
        confidence_bar_plots = dict()
        confidence_bar_plots['Complete'] = self.compute_accurracy_barplot(cls_dict, tar_dict)
        for cls, name in enumerate(self.class_names):
            cls_dict_cls, tar_dict_cls = self.select_class(cls, cls_dict, tar_dict)
            confidence_bar_plots[name] = self.compute_accurracy_barplot(cls_dict_cls, tar_dict_cls)
        self._log_and_close(confidence_bar_plots, 'Confidence vs Accurracy', pre_key, writer, time_step)

        if uncertainties is not None:

            # divide the uncertainties into bins and sort the results based on the uncertainty thresholds
            bin_thresholds, uncertainty_bins = self.compute_bin_indices(uncertainties)
            pred_dict, cls_dict, tar_dict, un_dict = self.sort_bin_results(uncertainty_bins, bin_thresholds, predictions,
                                                                           predicted_classes, uncertainties, targets)

            # plot uncertainty vs expected calibration error; the per class plots use the confidences
            # (pred_dict), exactly like the 'Complete' plot, not the predicted class labels
            ece_bar_plots = dict()
            ece_bar_plots['Complete'] = self.uncertainty_ece_plot(pred_dict, tar_dict)
            for cls, name in enumerate(self.class_names):
                pred_dict_cls, tar_dict_cls = self.select_class(cls, pred_dict, tar_dict)
                ece_bar_plots[name] = self.uncertainty_ece_plot(pred_dict_cls, tar_dict_cls)
            self._log_and_close(ece_bar_plots, 'Uncertainty vs Expected Calibration Error', pre_key, writer, time_step)

            # plot uncertainties vs accurracy
            uncertainty_bar_plots = dict()
            uncertainty_bar_plots['Complete'] = self.compute_accurracy_barplot(cls_dict, tar_dict)
            for cls, name in enumerate(self.class_names):
                cls_dict_cls, tar_dict_cls = self.select_class(cls, cls_dict, tar_dict)
                uncertainty_bar_plots[name] = self.compute_accurracy_barplot(cls_dict_cls, tar_dict_cls)
            self._log_and_close(uncertainty_bar_plots, 'Uncertainty vs Accurracy', pre_key, writer, time_step)

        # the writer is optional, flushing unconditionally crashed with the default writer=None
        if writer is not None:
            writer.flush()


class ClassificationMetrics:
    """
    Convenience wrapper that evaluates a set of classification metrics in one call: scalar metrics, ROC curve,
    precision-recall curve and confusion matrix.
    """

    def __init__(self, class_names: Union[None, list], pr_thresh: tuple = (0.9, ), re_thresh: tuple = (0.9,)):
        """
        Creates the individual metric objects.

        Parameters
        ----------
        class_names: Names of the classes, handed to the scalar metrics and the confusion matrix.
        pr_thresh: Passed to the precision-recall curve as `desired_pr`.
        re_thresh: Passed to the precision-recall curve as `desired_re`.
        """

        self.scalar_metrics = ScalarMetrics(class_names)
        self.conf_matrix = ConfMatrix(class_names)
        self.roc_curve = ROCCurve()
        self.pr_curve = PRCurve(desired_pr=pr_thresh, desired_re=re_thresh)

    @classmethod
    def write_plot_to_tensorboard(cls, fig, writer, time_step, key: str) -> None:
        """
        Writes a matplotlib figure to tensorboard.

        Parameters
        ----------
        fig: Matplotlib figure.
        writer: Tensorboard writer object.
        time_step: Current time step.
        key: Key within tensorboard.
        """

        # the figure has to be rendered before its pixel buffer can be read
        fig.canvas.draw()

        # `tostring_rgb` was deprecated in matplotlib 3.8 and removed in 3.10. `buffer_rgba` exposes the rendered
        # canvas as a (height, width, 4) uint8 buffer, so no manual reshape is needed and the shape always
        # matches the actual buffer size.
        rgba = np.asarray(fig.canvas.buffer_rgba())

        # drop the alpha channel and scale to [0, 1]
        img = rgba[..., :3] / 255.0

        writer.add_image(key, img, time_step, dataformats='HWC')
        writer.flush()

    def __call__(self, predictions: torch.Tensor, targets: torch.Tensor, class_names: Union[None, list] = None,
                 writer: Union[None, SummaryWriter] = None, time_step: Union[None, int] = None,
                 colors: Union[None, list] = None, pre_key: str = '') -> dict:
        """
        Computes all metrics for the given predictions and (optionally) writes them to tensorboard.

        Parameters
        ----------
        predictions: Model predictions. One-dimensional predictions are thresholded at 0.5 for the confusion
            matrix, otherwise the argmax over dimension 1 is used.
        targets: Groundtruth labels.
        class_names: Names of the classes (optionally), forwarded to the roc and precision-recall curves.
        colors: Colors used for plotting the roc curves (optionally).
        writer: Tensorboard writer object (optionally).
        time_step: Current time step (optionally, needed for tensorboard).
        pre_key: String forwarded unchanged to every metric; the scalar metrics prepend it to their
            tensorboard keys.

        Returns
        -------
        output_dict: Dictionary with the keys 'scalar', 'roc_curve', 'pr_curve' and 'confusion_matrix', holding
            the result of the corresponding metric object.
        """

        output_dict = dict()

        # these metrics work on the raw (non-thresholded) predictions
        output_dict['scalar'] = self.scalar_metrics(predictions, targets, writer, time_step, pre_key)
        output_dict['roc_curve'] = self.roc_curve(predictions, targets, class_names, colors, writer, time_step, pre_key)
        output_dict['pr_curve'] = self.pr_curve(predictions, targets, class_names, writer, time_step, pre_key)

        # the confusion matrix needs hard class decisions
        if predictions.dim() == 1:
            predictions = (predictions > 0.5).int()
        else:
            predictions = torch.argmax(predictions, dim=1)
        output_dict['confusion_matrix'] = self.conf_matrix(predictions, targets, writer, time_step, pre_key)
        return output_dict
