

from matplotlib import pyplot as plt
from sklearn.metrics import RocCurveDisplay, roc_auc_score

import numpy as np

from datasets.wmt21 import load_human, load_metric
from binary.estimation import RhoEtaResult, estimate_rho_eta
from binary.roc import metric_roc_per_system

from emnlp2022.showcase_appendix import WMT_SYSTEMS


def plot_result(
    rho_eta_res: RhoEtaResult,
    name: str,
    ax: plt.Axes,
    color: str,
) -> None:
    """Draw a ROC curve for one result and mark its ``(1 - eta, rho)`` point.

    Args:
        rho_eta_res: Result object providing ``binary_human`` (used as the
            true labels), ``scalar_metric`` (used as the scores), and the
            ``rho`` / ``eta`` values for the marker.
        name: Passed to the ROC display as the curve's ``name``.
        ax: Axes on which both the curve and the marker are drawn.
        color: Matplotlib color shared by the curve and the marker.
    """
    # Only the drawing side effect on `ax` is needed; the returned display
    # object is intentionally discarded.
    RocCurveDisplay.from_predictions(
        y_true=rho_eta_res.binary_human,
        y_pred=rho_eta_res.scalar_metric,
        name=name,
        ax=ax,
        c=color,
    )

    # Mark the single point (1 - eta, rho) in the same color as the curve.
    # NOTE(review): presumably rho/eta are the estimated true-positive and
    # true-negative rates, so this is the operating point in ROC space --
    # confirm against binary.estimation.
    ax.scatter(
        x=np.array([1. - rho_eta_res.eta]),
        y=np.array([rho_eta_res.rho]),
        marker="x",
        c=color,
    )


if __name__ == "__main__":
    # Script entry point: plot BLEURT ROC curves for a few WMT21 systems and
    # save the figure to disk.
    from pathlib import Path

    import matplotlib
    matplotlib.use("QtAgg")
    plt.rcParams.update({"font.size": 21.0})

    systems = [
        "Facebook-AI",
        "VolcTrans-GLAT",
        "HuaweiTSC",
    ]

    # Restrict the human data and every metric's data to the chosen systems.
    human = load_human()
    human = {
        s: human[s]
        for s in systems
    }
    metric = load_metric()
    metric = {
        m: {s: m_data[s] for s in systems}
        for m, m_data in metric.items()
    }
    bleurt = metric["bleurt-20-ref-C"]

    # NOTE(review): presumably draws the per-system ROC curves onto the
    # current pyplot figure -- confirm in binary.roc.
    metric_roc_per_system(human, bleurt)

    plt.tight_layout()

    # savefig does not create missing directories, so make sure the target
    # directory exists before writing.
    output_file = "res/appendixC/bleurt_roc.png"
    Path(output_file).parent.mkdir(parents=True, exist_ok=True)
    # To inspect the figure interactively, call plt.show() here instead.
    plt.savefig(output_file)
