
import numpy as np

from binary.estrs import run_simulated


# Passed to run_simulated as `alpha_approx` by compute_curve.
# NOTE(review): presumably a sample/approximation size for alpha — confirm
# against binary.estrs.run_simulated.
ALPHA_APPROX = 2000
# Passed to run_simulated as BOTH `rho_approx` and `eta_approx` by
# compute_curve (one shared value for the two quantities).
RHO_ETA_APPROX = 1000


def compute_curve(
    n_human: int,
    n_metric: int,
    alpha: float,
    rho: float,
    eta: float,
    gamma: float,
    *,
    n_rho_eta_min: int = 100,
    n_rho_eta_max: int = 5000,
    n_points: int = 50,
):
    """Sweep ``n_rho_eta`` over an integer grid and record ``epsilon`` at each point.

    For every grid value, ``run_simulated`` is called with the given
    parameters (plus the module-level ``*_APPROX`` constants), and
    ``epsilon(gamma=gamma)`` of the returned object is stored.

    Args:
        n_human: Forwarded to ``run_simulated`` as ``n_oracle``.
        n_metric: Forwarded to ``run_simulated`` as ``n_metric``.
        alpha: Forwarded unchanged to ``run_simulated``.
        rho: Forwarded unchanged to ``run_simulated``.
        eta: Forwarded unchanged to ``run_simulated``.
        gamma: Passed to ``epsilon`` on each simulated result.
        n_rho_eta_min: First grid value (inclusive). Defaults to 100.
        n_rho_eta_max: Last grid value (inclusive). Defaults to 5000.
        n_points: Number of grid values. Defaults to 50.

    Returns:
        A dict echoing the inputs (``alpha``, ``rho``, ``eta``, ``gamma``,
        ``n_human``, ``n_metric``) plus ``n_rho_eta`` (list of ``int`` grid
        values) and ``epsilon`` (list of ``float``, one per grid value),
        made of plain Python types so it can be passed to ``json.dumps``.
    """
    # Casting linspace to an integer dtype drops any fractional part; the
    # default (100, 5000, 50) grid has a step of exactly 100, so no value
    # is altered by the cast.
    n_rho_eta = np.linspace(n_rho_eta_min, n_rho_eta_max, n_points, dtype=int)
    epsilons = np.zeros(len(n_rho_eta), dtype=float)

    for ix, nre in enumerate(n_rho_eta):
        approx = run_simulated(
            alpha=alpha,
            rho=rho,
            eta=eta,
            n_oracle=n_human,
            n_metric=n_metric,
            # Convert from a NumPy integer to a plain int for the callee.
            n_rho_eta=int(nre),
            rho_approx=RHO_ETA_APPROX,
            eta_approx=RHO_ETA_APPROX,
            alpha_approx=ALPHA_APPROX,
        )
        epsilons[ix] = approx.epsilon(gamma=gamma)

    return {
        "alpha": alpha,
        "rho": rho,
        "eta": eta,
        "gamma": gamma,
        "n_human": n_human,
        "n_metric": n_metric,
        # tolist() yields native int/float rather than NumPy scalars.
        "n_rho_eta": n_rho_eta.tolist(),
        "epsilon": epsilons.tolist(),
    }


if __name__ == "__main__":
    # CLI entry point: read the simulation settings, compute the epsilon
    # curve, and print it to stdout as indented JSON.
    import argparse
    import json

    # (flag, dest, type) for each mandatory option, in help-output order.
    # Every dest is spelled exactly like the matching compute_curve keyword.
    mandatory_options = (
        ("--human", "n_human", int),
        ("--metric", "n_metric", int),
        ("--alpha", "alpha", float),
        ("--rho", "rho", float),
        ("--eta", "eta", float),
    )

    cli = argparse.ArgumentParser()
    for flag, dest, value_type in mandatory_options:
        cli.add_argument(flag, type=value_type, required=True, dest=dest)
    cli.add_argument("--gamma", type=float, default=0.05, dest="gamma")

    # The namespace keys match compute_curve's keyword names one-to-one,
    # so it can be forwarded wholesale.
    curve = compute_curve(**vars(cli.parse_args()))
    print(json.dumps(curve, indent=2))
