
from typing import Optional

import numpy as np

from binary.estrs import simulate_fixed, run_simulated


# Forwarded as ``alpha_approx`` to both simulate_fixed and run_simulated.
ALPHA_APPROX = 2_000
# Forwarded to run_simulated as both ``rho_approx`` and ``eta_approx``.
RHO_ETA_APPROX = 1_000


def compute_curve(
    n_human: int,
    n_rho_eta: Optional[int],
    alpha: float,
    rho: float,
    eta: float,
    gamma: float,
    *,
    n_metric_min: int = 200,
    n_metric_max: int = 10000,
    n_metric_num: int = 50,
) -> dict:
    """Compute ``epsilon`` across an evenly spaced grid of ``n_metric`` values.

    For every ``n_metric`` in the grid an approximation object is built with
    ``simulate_fixed`` (when ``n_rho_eta`` is ``None``) or ``run_simulated``
    (otherwise), and ``approx.epsilon(gamma=gamma)`` is recorded.

    Args:
        n_human: Passed to the simulator as ``n_oracle``.
        n_rho_eta: Passed to ``run_simulated`` as ``n_rho_eta``; ``None``
            selects ``simulate_fixed`` instead.
        alpha: Forwarded unchanged to the simulator.
        rho: Forwarded unchanged to the simulator.
        eta: Forwarded unchanged to the simulator.
        gamma: Forwarded to ``approx.epsilon``.
        n_metric_min: First (smallest) grid value, inclusive.
        n_metric_max: Last grid value, inclusive.
        n_metric_num: Number of grid points. The defaults reproduce the
            grid 200, 400, ..., 10000 (50 points).

    Returns:
        A JSON-serializable dict holding the input parameters, the
        ``n_metric`` grid (Python ints) and the matching ``epsilon`` values
        (Python floats), in the same order.

    Raises:
        ValueError: If ``n_metric_num < 1`` or the bounds do not satisfy
            ``1 <= n_metric_min <= n_metric_max``.
    """
    if n_metric_num < 1:
        raise ValueError(f"n_metric_num must be >= 1, got {n_metric_num}")
    if not 1 <= n_metric_min <= n_metric_max:
        raise ValueError(
            "expected 1 <= n_metric_min <= n_metric_max, got "
            f"n_metric_min={n_metric_min}, n_metric_max={n_metric_max}"
        )

    # Round to the nearest integer instead of truncating (what dtype=int does),
    # so float error such as 599.9999 cannot silently become 599 on
    # non-default grids. For the default grid every point is exact.
    n_metric = np.rint(
        np.linspace(n_metric_min, n_metric_max, n_metric_num)
    ).astype(int)
    epsilons = np.zeros(len(n_metric), dtype=float)

    for ix, nm in enumerate(n_metric):
        if n_rho_eta is None:
            approx = simulate_fixed(
                alpha=alpha,
                rho=rho,
                eta=eta,
                n_oracle=n_human,
                n_metric=int(nm),
                alpha_approx=ALPHA_APPROX,
            )
        else:
            approx = run_simulated(
                alpha=alpha,
                rho=rho,
                eta=eta,
                n_oracle=n_human,
                n_metric=int(nm),
                n_rho_eta=n_rho_eta,
                rho_approx=RHO_ETA_APPROX,
                eta_approx=RHO_ETA_APPROX,
                alpha_approx=ALPHA_APPROX,
            )
        epsilons[ix] = approx.epsilon(gamma=gamma)

    return {
        "alpha": alpha,
        "rho": rho,
        "eta": eta,
        "gamma": gamma,
        "n_human": n_human,
        "n_rho_eta": n_rho_eta,
        # .tolist() yields plain Python ints/floats rather than NumPy scalars.
        "n_metric": n_metric.tolist(),
        "epsilon": epsilons.tolist(),
    }


if __name__ == "__main__":
    import argparse
    import json

    # Command-line front end: read the curve parameters, compute the curve
    # and write it to stdout as indented JSON.
    cli = argparse.ArgumentParser()
    cli.add_argument("--human", type=int, required=True, dest="n_human")
    cli.add_argument(
        "--rho_eta", type=int, required=False, dest="n_rho_eta", default=None)
    # --alpha, --rho and --eta are mandatory floats whose dest equals the flag.
    for flag in ("alpha", "rho", "eta"):
        cli.add_argument(f"--{flag}", type=float, required=True, dest=flag)
    cli.add_argument(
        "--gamma", type=float, required=False, dest="gamma", default=0.05)

    # Each dest registered above is exactly one compute_curve parameter name,
    # so the parsed namespace unpacks directly into keyword arguments.
    options = vars(cli.parse_args())
    print(json.dumps(compute_curve(**options), indent=2))
