
import scipy.stats as stats

import matplotlib
from matplotlib import pyplot as plt

from binary.estimation import plot_alpha_hist, NaiveApproximationAlphaEstimator

# Select the Qt-based interactive backend so the plt.show() calls below open a window.
matplotlib.use("QtAgg")


def bad_luck_curve_shift(
    alpha: float,
    rho: float,
    eta: float,
    n_human: int,
    n_machine: int,
):
    """Plot how an unlucky machine sample shifts the alpha estimate.

    The observed machine-positive rate is computed as
    ``alpha * (rho + eta - 1) + (1 - eta)``, which is algebraically
    ``alpha * rho + (1 - alpha) * (1 - eta)``. The ends of the central 95%
    interval of ``Binomial(n_machine, observed_rate)`` are used as "bad luck"
    machine positive counts. The estimator is run at each end both
    machine-only and combined with the expected human labels. The
    human-only histogram is overlaid for reference.

    Args:
        alpha: Fraction of positives; ``alpha * n_human`` (rounded) is used
            as the human positive count.
        rho: Passed to the estimator as ``rho_pos``.
        eta: Passed to the estimator as ``eta_pos``.
        n_human: Number of human-labelled items.
        n_machine: Number of machine-labelled items.
    """
    alpha_obs = (alpha * (rho + eta - 1.)) + (1. - eta)
    # round() rather than int(): truncation can lose a whole count when the
    # float product lands just below an integer (e.g. 0.29 * 100).
    expected_human_pos = round(alpha * n_human)

    # Endpoints of the central 95% range of Binomial(n_machine, alpha_obs).
    bad_luck_lower, bad_luck_upper = stats.binom.interval(
        .95, n_machine, alpha_obs)

    est = NaiveApproximationAlphaEstimator(n_approx=5000)

    # Machine-side arguments shared by every estimator call below.
    machine_kwargs = dict(
        n_total=n_machine,
        rho_pos=rho,
        rho_total=1.,
        eta_pos=eta,
        eta_total=1.,
    )
    human_kwargs = dict(
        human_pos=expected_human_pos,
        human_total=n_human,
    )

    h1 = est.machine_only(n_pos=int(bad_luck_lower), **machine_kwargs)
    h2 = est.machine_only(n_pos=int(bad_luck_upper), **machine_kwargs)
    h3 = est.human_machine_both(
        n_pos=int(bad_luck_lower), **machine_kwargs, **human_kwargs)
    h4 = est.human_machine_both(
        n_pos=int(bad_luck_upper), **machine_kwargs, **human_kwargs)

    hhist = est.human_only(**human_kwargs)

    fig, ax = plt.subplots()

    fig.set_size_inches(10, 10)
    ax.set_title(
        f"n_human={n_human} n_machine={n_machine} alpha={alpha:.2f}"
        f" rho={rho:.2f} eta={eta:.2f}")

    # Each label now matches the bound and data source of its histogram.
    plot_alpha_hist(hhist, name="human only", ax=ax)
    plot_alpha_hist(h1, name="bad luck lower (no human)", ax=ax)
    plot_alpha_hist(h2, name="bad luck upper (no human)", ax=ax)
    plot_alpha_hist(h3, name="bad luck lower", ax=ax)
    plot_alpha_hist(h4, name="bad luck upper", ax=ax)

    ax.legend()

    plt.show()


def rho_error_curve_shift(
    alpha: float,
    rho: float,
    eta: float,
    n_human: int,
    n_machine: int,
    delta: float = 0.05,
):
    """Plot how a wrong ``rho`` shifts the combined alpha estimate.

    The machine positive count is fixed at its expected value under the
    given ``rho``/``eta``, using the observed rate
    ``alpha * (rho + eta - 1) + (1 - eta)``. The human+machine estimator is
    then run with ``rho``, ``rho - delta`` and ``rho + delta`` as
    ``rho_pos``. The human-only histogram is overlaid for reference.

    Args:
        alpha: Fraction of positives; ``alpha * n_human`` (rounded) is used
            as the human positive count.
        rho: Value used to generate the machine count; also the centre of
            the ``rho_pos`` values tried.
        eta: Passed to the estimator as ``eta_pos``.
        n_human: Number of human-labelled items.
        n_machine: Number of machine-labelled items.
        delta: Size of the ``rho`` perturbation (default 0.05, the
            previously hard-coded value). ``rho +/- delta`` is not
            range-checked here.
    """
    # round() rather than int(): truncation can lose a whole count when the
    # float product lands just below an integer (e.g. 0.29 * 100).
    expected_human_pos = round(alpha * n_human)

    alpha_obs_true = (alpha * (rho + eta - 1.)) + (1. - eta)
    expected_machine = round(alpha_obs_true * n_machine)

    est = NaiveApproximationAlphaEstimator(n_approx=5000)

    # Calls run in the original order (rho, rho - delta, rho + delta, then
    # human-only) in case the estimator draws random samples.
    rho_values = (rho, rho - delta, rho + delta)
    machine_hists = [
        est.human_machine_both(
            n_pos=expected_machine,
            n_total=n_machine,
            rho_pos=rho_pos,
            rho_total=1.,
            eta_pos=eta,
            eta_total=1.,
            human_pos=expected_human_pos,
            human_total=n_human,
        )
        for rho_pos in rho_values
    ]
    hhist = est.human_only(human_pos=expected_human_pos, human_total=n_human)

    fig, ax = plt.subplots()

    fig.set_size_inches(10, 10)
    ax.set_title(
        f"n_human={n_human} n_machine={n_machine} alpha={alpha:.2f}"
        f" rho={rho:.2f} eta={eta:.2f}")

    plot_alpha_hist(hhist, name="human only", ax=ax)
    for rho_pos, hist in zip(rho_values, machine_hists):
        plot_alpha_hist(hist, name=f"rho={rho_pos:.2f}", ax=ax)

    ax.legend()

    plt.show()


def eta_error_curve_shift(
        alpha: float,
        rho: float,
        eta: float,
        n_human: int,
        n_machine: int,
        delta: float = 0.05,
):
    """Plot how a wrong ``eta`` shifts the combined alpha estimate.

    The machine positive count is fixed at its expected value under the
    given ``rho``/``eta``, using the observed rate
    ``alpha * (rho + eta - 1) + (1 - eta)``. The human+machine estimator is
    then run with ``eta``, ``eta - delta`` and ``eta + delta`` as
    ``eta_pos``. The human-only histogram is overlaid for reference.

    Args:
        alpha: Fraction of positives; ``alpha * n_human`` (rounded) is used
            as the human positive count.
        rho: Passed to the estimator as ``rho_pos``.
        eta: Value used to generate the machine count; also the centre of
            the ``eta_pos`` values tried.
        n_human: Number of human-labelled items.
        n_machine: Number of machine-labelled items.
        delta: Size of the ``eta`` perturbation (default 0.05, the
            previously hard-coded value). ``eta +/- delta`` is not
            range-checked here.
    """
    # round() rather than int(): truncation can lose a whole count when the
    # float product lands just below an integer (e.g. 0.29 * 100).
    expected_human_pos = round(alpha * n_human)

    alpha_obs_true = (alpha * (rho + eta - 1.)) + (1. - eta)
    expected_machine = round(alpha_obs_true * n_machine)

    est = NaiveApproximationAlphaEstimator(n_approx=5000)

    # Calls run in the original order (eta, eta - delta, eta + delta, then
    # human-only) in case the estimator draws random samples.
    eta_values = (eta, eta - delta, eta + delta)
    machine_hists = [
        est.human_machine_both(
            n_pos=expected_machine,
            n_total=n_machine,
            rho_pos=rho,
            rho_total=1.,
            eta_pos=eta_pos,
            eta_total=1.,
            human_pos=expected_human_pos,
            human_total=n_human,
        )
        for eta_pos in eta_values
    ]
    hhist = est.human_only(human_pos=expected_human_pos, human_total=n_human)

    fig, ax = plt.subplots()

    fig.set_size_inches(10, 10)
    ax.set_title(
        f"n_human={n_human} n_machine={n_machine} alpha={alpha:.2f}"
        f" rho={rho:.2f} eta={eta:.2f}")

    plot_alpha_hist(hhist, name="human only", ax=ax)
    for eta_pos, hist in zip(eta_values, machine_hists):
        plot_alpha_hist(hist, name=f"eta={eta_pos:.2f}", ax=ax)

    ax.legend()

    plt.show()


def main():
    """Run the eta-error experiment with the default configuration.

    The same settings can drive the other experiments: swap
    ``eta_error_curve_shift`` for ``bad_luck_curve_shift`` or
    ``rho_error_curve_shift`` in the call below.
    """
    settings = {
        "alpha": .6,
        "rho": .6,
        "eta": .6,
        "n_human": 1000,
        "n_machine": 10000,
    }
    eta_error_curve_shift(**settings)


if __name__ == "__main__":
    # Run the plotting experiment only when executed as a script, not on import.
    main()

