
from typing import Dict

from datasets.wmt21 import load_human, load_metric
from binary.estrs import estimate_alpha, Approximation

import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator


# Sizes handed to estimate_alpha as approx_alpha / approx_rho / approx_eta.
# NOTE(review): presumably the number of grid points used to discretise each
# quantity (plot_approx draws bars of width 1 / n_approx) -- confirm against
# binary.estrs.
ALPHA_APPROX = 2000
RHO_APPROX = 1000
ETA_APPROX = 1000


def plot_approx(a: Approximation, name: str, ax: plt.Axes):
    """Draw one approximation as a bar chart with a dashed line at its mean.

    Args:
        a: Approximation to plot. ``a.values`` are used as bar positions,
            ``a.probas`` as bar heights, and each bar is ``1 / a.n_approx``
            wide. ``a.mean()`` gives the position of the vertical line.
        name: Legend label for the bars.
        ax: Axes to draw on. Its legend is refreshed on every call so that
            several approximations can be overlaid on the same axes.
    """
    bars = ax.bar(
        x=a.values,
        height=a.probas,
        width=1. / a.n_approx,
        label=name,
        alpha=.6,
    )
    # axvline does not take part in the axes colour cycle, so without an
    # explicit colour every mean line would share the same default colour and
    # could not be matched to its bars. Reuse the bar colour (RGB only, so the
    # line stays opaque); fall back to the default when no bar was drawn.
    patches = bars.patches
    line_color = patches[0].get_facecolor()[:3] if patches else None
    ax.axvline(x=a.mean(), linestyle="--", color=line_color)
    ax.legend()


def main():
    """Estimate alpha for a fixed set of WMT21 systems and plot the results.

    For every MT system an approximation is computed once from the human
    data alone (``machine_data`` and ``rho_eta_data`` are ``None``) and once
    per automatic metric, where the metric data and its pairing with the
    human data are passed as well. The figure has one subplot per metric;
    each subplot overlays the approximations of all systems.

    The figure is left as pyplot's current figure; nothing is saved or shown
    here.
    """
    human = load_human()
    metric = load_metric()

    mt_systems = [
        "Facebook-AI",
        "VolcTrans-GLAT",
        "HuaweiTSC",
    ]
    # Short system names used as legend labels.
    mt_short = {
        "Facebook-AI": "FBAI",
        "VolcTrans-GLAT": "VT",
        "HuaweiTSC": "HT",
    }

    # "human" is a pseudo-metric: the estimate that uses human data only.
    metrics = [
        "human",
        "bleurt-20-ref-C",
        "sentBLEU-ref-C",
    ]
    # Subplot titles.
    metric_short = {
        "human": "human only",
        "bleurt-20-ref-C": "bleurt-20",
        "sentBLEU-ref-C": "sentBLEU",
    }

    # result[system][metric] -> approximation returned by estimate_alpha
    result: Dict[str, Dict[str, Approximation]] = {}
    for mt in mt_systems:
        human_data = human[mt]
        result[mt] = {}
        for m in metrics:
            if m == "human":
                machine_data = None
                rho_eta_data = None
            else:
                machine_data = metric[m][mt]
                rho_eta_data = human_data.pair(machine_data)
            result[mt][m] = estimate_alpha(
                human_data=human_data,
                machine_data=machine_data,
                rho_eta_data=rho_eta_data,
                approx_alpha=ALPHA_APPROX,
                approx_rho=RHO_APPROX,
                approx_eta=ETA_APPROX,
            )

    # One row per metric. squeeze=False always yields a 2-D array of axes, so
    # indexing by metric position keeps working for any number of metrics
    # (including a single one, where pyplot would otherwise return a bare Axes).
    n_rows = len(metrics)
    fig, axes_grid = plt.subplots(nrows=n_rows, ncols=1, squeeze=False)
    axes = axes_grid[:, 0]
    # 4 inches of height per row, i.e. 10 x 12 for the three metrics above.
    fig.set_size_inches(10, 4 * n_rows)

    for mt in mt_systems:
        for ax_ix, m in enumerate(metrics):
            plot_approx(
                a=result[mt][m],
                name=mt_short[mt],
                ax=axes[ax_ix],
            )

    for m, ax in zip(metrics, axes):
        ax.set_title(metric_short[m])
        ax.xaxis.set_major_locator(MultipleLocator(.1))
        ax.xaxis.set_minor_locator(MultipleLocator(.02))
        # Small margin on both sides so bars at 0 and 1 are not clipped.
        ax.set_xlim(-.01, 1.01)


if __name__ == "__main__":
    import os

    out_path = "res/wmt21_alpha_plots.png"
    # savefig does not create missing directories. Make sure the target
    # directory exists up front so a fresh checkout does not fail with
    # FileNotFoundError only after all estimates have been computed.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    main()
    # main() leaves its figure as pyplot's current figure, which is the one
    # plt.savefig writes.
    plt.savefig(out_path)
