
from typing import Optional
from dataclasses import dataclass

import numpy as np
import scipy.stats as stats

from matplotlib import pyplot as plt
from matplotlib.ticker import MultipleLocator

from binary.estrs import simulate_fixed

# Default ``alpha`` forwarded unchanged to ``simulate_fixed`` by ``curve_data``.
ALPHA: float = 0.65
# Default significance level for ``curve_data``: the normal interval computed
# there has coverage ``1 - GAMMA``.
GAMMA: float = 0.05
# Default ``alpha_approx`` forwarded unchanged to ``simulate_fixed``.
ALPHA_APPROX: int = 2_000


@dataclass
class CurveData:
    """One curve of the distinguishable-difference plot.

    ``rho_eta`` and ``epsilons`` are parallel arrays: ``epsilons[i]`` is the
    value computed for accuracy ``rho_eta[i]``.
    """

    # Value passed as ``n_metric`` to the simulation; shown in the legend.
    n_metric: int
    # Accuracy grid (x-axis values).
    rho_eta: np.ndarray
    # Interval half-widths, one per entry of ``rho_eta`` (y-axis values).
    epsilons: np.ndarray


def curve_data(
    n_metric: int,
    gamma: float = GAMMA,
    alpha: float = ALPHA,
    alpha_approx: int = ALPHA_APPROX,
) -> CurveData:
    """Compute the distinguishable-difference curve for one ``n_metric``.

    For every accuracy ``r`` on a fixed grid, ``simulate_fixed`` is run with
    ``rho = eta = r`` and ``n_oracle = 0``. Epsilon is the upper end of the
    two-sided ``1 - gamma`` normal interval centred at 0 whose standard
    deviation is ``sqrt(2 * approx.var())``.

    Args:
        n_metric: Forwarded as ``n_metric`` to ``simulate_fixed``; presumably
            the metric test-set size (it is labelled ``|T_M|`` in the plot).
        gamma: Significance level; the interval has coverage ``1 - gamma``.
        alpha: Forwarded unchanged to ``simulate_fixed``.
        alpha_approx: Forwarded unchanged to ``simulate_fixed``.

    Returns:
        A ``CurveData`` holding the accuracy grid and the matching epsilons.
    """
    # 20 evenly spaced points: 0.5, 0.525, 0.55, ..., 0.975 (step 0.025).
    rho_eta = np.linspace(.5, .975, 20)
    rho_eta[0] = .51  # 0.5 is useless
    rho_eta[-1] = .99  # looks nicer if everything goes to almost 1

    epsilons = np.zeros_like(rho_eta)

    for ix, r in enumerate(rho_eta):
        approx = simulate_fixed(
            alpha=alpha,
            rho=r,
            eta=r,
            n_oracle=0,
            n_metric=n_metric,
            alpha_approx=alpha_approx,
        )
        # NOTE(review): the factor 2 presumably gives the std. dev. of the
        # difference of two independent estimates with equal variance — confirm.
        stdev = np.sqrt(2*approx.var())
        # Only the upper bound is needed; the interval is symmetric around 0.
        _, eps = stats.norm.interval(1. - gamma, loc=0., scale=stdev)
        epsilons[ix] = eps

    return CurveData(
        n_metric=n_metric,
        rho_eta=rho_eta,
        epsilons=epsilons,
    )


def plot_curve_data(curve: CurveData, ax: Optional[plt.Axes] = None):
    """Draw ``curve.epsilons`` against ``curve.rho_eta`` as one labelled line.

    Args:
        curve: The data to draw; ``curve.n_metric`` goes into the legend label.
        ax: Axes to draw on. When omitted, pyplot's current axes are used.
    """
    target = plt if ax is None else ax
    legend_text = rf"$|\mathcal{{T}}_M| = {curve.n_metric}$"
    target.plot(curve.rho_eta, curve.epsilons, label=legend_text)


def main(output_path: str = "res/main_fig.pdf"):
    """Plot epsilon-vs-accuracy curves for two metric set sizes and save them.

    Args:
        output_path: Destination file. The format follows the suffix, as in
            ``Figure.savefig``. Missing parent directories are created.
    """
    from pathlib import Path

    # LaTeX text rendering requires a working TeX installation.
    plt.rcParams.update({"text.usetex": True})
    plt.rcParams.update({"font.size": 21.0})

    fig, ax = plt.subplots()
    fig.set_size_inches(10, 10)

    for n_metric in [1000, 10000]:
        curve = curve_data(n_metric=n_metric)
        plot_curve_data(curve, ax)

    ax.legend()
    ax.xaxis.set_major_locator(MultipleLocator(.1))
    ax.xaxis.set_minor_locator(MultipleLocator(.01))
    ax.yaxis.set_major_locator(MultipleLocator(.1))
    ax.yaxis.set_minor_locator(MultipleLocator(.01))
    ax.set_xlim(.49, 1.01)
    ax.set_ylim(0.0, 0.41)
    ax.set_xlabel("Accuracy of automated metric")
    ax.set_ylabel("Distinguishable Difference")

    fig.tight_layout()

    out = Path(output_path)
    # savefig fails if the target directory does not exist yet.
    out.parent.mkdir(parents=True, exist_ok=True)
    # Save the figure built above explicitly rather than pyplot's "current" one.
    fig.savefig(out)


if __name__ == "__main__":
    # Select the Qt backend before any figure is created.
    # NOTE(review): main() only writes a file and never calls plt.show(),
    # so an interactive backend may be unnecessary (and needs Qt installed).
    import matplotlib as mpl

    mpl.use("QtAgg")
    main()

