
from typing import Dict, Optional, List, Tuple

from sklearn.metrics import RocCurveDisplay, roc_auc_score

from matplotlib import pyplot as plt
import numpy as np

from datasets.base import BinaryHuman, MetricScores

from binary.estimation import estimate_rho_eta, RhoEtaResult


def metric_roc_per_system(
    human_annotations: Dict[str, BinaryHuman],
    metric_scores: Dict[str, MetricScores],
    selection_strategy: str = "min_diff",
    save_path: Optional[str] = None,
):
    """Plot one ROC curve per system and mark each system's selected (ρ, η).

    Only systems present in both ``human_annotations`` and
    ``metric_scores`` are plotted, in sorted name order. For each system the
    human annotations are paired with the metric scores and passed to
    ``estimate_rho_eta``. The result's ``binary_human`` / ``scalar_metric``
    values are drawn as an ROC curve. The estimated operating point is
    marked with a black ``x`` at ``(1 - eta, rho)``, i.e. at
    (false positive rate, true positive rate).

    Args:
        human_annotations: Binary human judgements keyed by system name.
        metric_scores: Metric scores keyed by system name.
        selection_strategy: Forwarded unchanged to ``estimate_rho_eta``.
        save_path: If given, the figure is written to this path.
    """
    common = sorted(
        set(human_annotations.keys()).intersection(metric_scores.keys()))

    # Run the estimation before creating the figure (and save through
    # ``fig`` below), so nothing estimate_rho_eta might do with pyplot's
    # global state can change which figure is drawn on or written to disk.
    rho_etas: Dict[str, RhoEtaResult] = {}
    for system in common:
        human: BinaryHuman = human_annotations[system]
        metric: MetricScores = metric_scores[system]
        rho_etas[system] = estimate_rho_eta(
            paired=human.pair(metric), selection_strategy=selection_strategy)

    fig, ax = plt.subplots()
    fig.set_size_inches(10, 10)

    # Chance-level classifier: TPR == FPR.
    ax.plot([0.0, 1.0], [0.0, 1.0], linestyle="--", label="random", c='b')

    # rho == eta  <=>  TPR == 1 - FPR, i.e. the anti-diagonal.
    ax.plot([0.0, 1.0], [1.0, 0.0], linestyle=":", c='r')

    # Actual ROC curves, one per system.
    for system, res in rho_etas.items():
        RocCurveDisplay.from_predictions(
            y_true=res.binary_human,
            y_pred=res.scalar_metric,
            name=system,
            ax=ax,
        )

    # Selected operating points in ROC coordinates (FPR = 1 - eta, TPR = rho).
    ax.scatter(
        x=[1. - res.eta for res in rho_etas.values()],
        y=[res.rho for res in rho_etas.values()],
        marker="x",
        c='k',
        label="Selected",
        s=52,
    )

    # Limits and labels are applied after the ROC curves because
    # RocCurveDisplay.plot sets its own axis limits/labels in recent
    # scikit-learn versions, which would otherwise override these.
    ax.set_xlim(0.0, 1.01)
    ax.set_ylim(0.0, 1.01)
    ax.set_xlabel("False Positive Rate (1- η)")
    ax.set_ylabel("True Positive Rate (ρ)")
    ax.legend()

    if save_path is not None:
        # Save this function's figure explicitly, not pyplot's "current" one.
        fig.savefig(save_path)


def __aggregate_predictions(
    human_annotations: Dict[str, BinaryHuman],
    metric_scores: Dict[str, MetricScores],
):
    """Pool human labels and metric scores across all shared systems.

    Systems are those keys present in both dicts, visited in sorted order.
    For each one, ``human_annotations[system].pair(metric_scores[system])``
    is called, and its ``human_binary`` and ``scores`` sequences are
    appended to two running lists.

    Returns:
        Tuple ``(y_true, y_pred)`` of NumPy arrays built from the
        concatenated ``human_binary`` values and ``scores`` values
        respectively. Both are empty when no system is shared.
    """
    shared_systems = sorted(human_annotations.keys() & metric_scores.keys())

    labels = []
    predictions = []
    for system in shared_systems:
        joined = human_annotations[system].pair(metric_scores[system])
        labels.extend(joined.human_binary)
        predictions.extend(joined.scores)

    return np.array(labels), np.array(predictions)


def metric_roc_aggregated(
    human_annotations: Dict[str, BinaryHuman],
    metrics: Dict[str, Dict[str, MetricScores]],
    save_path: Optional[str] = None,
):
    """Plot one pooled ROC curve per metric on a single shared axis.

    For every metric, the human labels and metric scores of all systems
    appearing in both ``human_annotations`` and that metric's per-system
    scores are concatenated by ``__aggregate_predictions``. The result is
    drawn as one ROC curve labelled with the metric name.

    Args:
        human_annotations: Binary human judgements keyed by system name.
            The first entry's ``dataset`` attribute is used in the title.
        metrics: Metric scores keyed first by metric name, then by system
            name.
        save_path: If given, the figure is written to this path.

    Raises:
        IndexError: If ``human_annotations`` is empty (no title source).
    """
    first = list(human_annotations.values())[0]

    fig, ax = plt.subplots()
    fig.set_size_inches(10, 10)
    ax.set_title(f"{first.dataset} metrics")

    for metric_name, metric_data in metrics.items():
        y_true, y_pred = __aggregate_predictions(
            human_annotations=human_annotations,
            metric_scores=metric_data,
        )

        RocCurveDisplay.from_predictions(
            y_true=y_true,
            y_pred=y_pred,
            name=metric_name,
            ax=ax,
        )

    if save_path is not None:
        # Save this function's figure explicitly rather than pyplot's global
        # "current" figure, which the calls above are not guaranteed to keep.
        fig.savefig(save_path)


def metric_aurocs(
    human_annotations: Dict[str, BinaryHuman],
    metrics: Dict[str, Dict[str, MetricScores]],
) -> List[Tuple[str, float]]:
    """Rank metrics by ROC AUC computed on predictions pooled over systems.

    For each metric, labels and scores are pooled across shared systems by
    ``__aggregate_predictions`` and scored with ``roc_auc_score``. Errors
    raised by ``roc_auc_score`` propagate unchanged.

    Returns:
        ``(metric_name, auroc)`` pairs sorted by AUROC, highest first. Ties
        keep the iteration order of ``metrics``.
    """

    def _auroc(per_system_scores):
        labels, predictions = __aggregate_predictions(
            human_annotations=human_annotations,
            metric_scores=per_system_scores,
        )
        return roc_auc_score(
            y_true=labels,
            y_score=predictions,
            average=None,
        )

    ranking = [(name, _auroc(data)) for name, data in metrics.items()]
    ranking.sort(key=lambda entry: entry[1], reverse=True)
    return ranking
