
from typing import Optional
from dataclasses import dataclass

import numpy as np

from binary.estrs import simulate_fixed, run_simulated


# Default ``alpha_approx`` of compute_curve(); forwarded to both simulators.
ALPHA_APPROX: int = 2_000
# Default ``rho_eta_approx`` of compute_curve(); only run_simulated() receives
# it (as both ``rho_approx`` and ``eta_approx``).
RHO_ETA_APPROX: int = 1_000


@dataclass
class CurveData:
    """Result of compute_curve(): its inputs plus the sampled curve.

    ``rho_eta`` is the grid of tied rho = eta values and ``epsilon`` holds the
    matching epsilon(gamma) values, one per grid point.
    """

    # Forwarded to the simulator as ``n_oracle`` by compute_curve().
    n_human: int
    # Forwarded to the simulator as ``n_metric`` by compute_curve().
    n_metric: int
    # None selects simulate_fixed(); an int selects run_simulated().
    n_rho_eta: Optional[int]
    # Forwarded to the simulator as ``alpha``.
    alpha: float
    # Argument given to ``approx.epsilon`` at every grid point.
    gamma: float
    # Grid of rho (= eta) values the curve was evaluated at.
    rho_eta: np.ndarray
    # epsilon(gamma) for each entry of ``rho_eta`` (same length).
    epsilon: np.ndarray

    def json(self) -> dict:
        """Return a JSON-serializable dict containing every field.

        Arrays are converted with ``ndarray.tolist()`` so they hold plain
        Python numbers that ``json.dumps`` accepts. The integer dtypes are
        included here, whose numpy scalars ``json`` would reject. All fields,
        including ``alpha``, are emitted, so the result round-trips through
        :meth:`from_json`.
        """
        return {
            'n_human': self.n_human,
            'n_metric': self.n_metric,
            'n_rho_eta': self.n_rho_eta,
            'alpha': self.alpha,
            'gamma': self.gamma,
            'rho_eta': np.asarray(self.rho_eta).tolist(),
            'epsilon': np.asarray(self.epsilon).tolist(),
        }

    @classmethod
    def from_json(cls, d: dict) -> 'CurveData':
        """Rebuild a CurveData from a dict produced by :meth:`json`.

        The list-valued ``rho_eta`` / ``epsilon`` entries are converted back
        to numpy arrays. ``d`` itself is left unmodified.

        Raises:
            TypeError: if ``d`` is missing a field or carries unknown keys.
        """
        data = dict(d)  # copy so the caller's dict keeps its lists
        data['rho_eta'] = np.array(data['rho_eta'])
        data['epsilon'] = np.array(data['epsilon'])
        return cls(**data)


def compute_curve(
    n_human: int,
    n_metric: int,
    n_rho_eta: Optional[int],
    alpha: float,
    gamma: float,
    alpha_approx: int = ALPHA_APPROX,
    rho_eta_approx: int = RHO_ETA_APPROX,
):
    """Compute epsilon(gamma) along a fixed grid of tied rho = eta values.

    The grid holds 20 points: 0.51, then 0.525 to 0.95 in steps of 0.025,
    then 0.99. For every grid value ``v`` a simulator is called with
    ``rho = eta = v``:

    * ``n_rho_eta is None`` -> ``simulate_fixed`` (no rho/eta approximation
      arguments are passed);
    * otherwise -> ``run_simulated``, with ``n_rho_eta`` and with
      ``rho_eta_approx`` used for both ``rho_approx`` and ``eta_approx``.

    Args:
        n_human: forwarded to the simulator as ``n_oracle``.
        n_metric: forwarded to the simulator as ``n_metric``.
        n_rho_eta: selects the simulator (see above).
        alpha: forwarded to the simulator as ``alpha``.
        gamma: argument passed to ``approx.epsilon`` at each grid point.
        alpha_approx: forwarded to the simulator as ``alpha_approx``.
        rho_eta_approx: only used when ``run_simulated`` is selected.

    Returns:
        CurveData holding the inputs, the grid (``rho_eta``) and one
        epsilon value per grid point.
    """
    grid = np.linspace(.5, .975, 20)
    # Endpoints are swapped: 0.5 -> 0.51 and 0.975 -> 0.99.
    # NOTE(review): presumably 0.5 itself is avoided as a degenerate case.
    grid[0], grid[-1] = .51, .99

    def _approximate(value):
        # Keyword arguments common to both simulators; rho and eta are tied.
        shared = dict(
            alpha=alpha,
            rho=value,
            eta=value,
            n_oracle=n_human,
            n_metric=n_metric,
            alpha_approx=alpha_approx,
        )
        if n_rho_eta is None:
            return simulate_fixed(**shared)
        return run_simulated(
            n_rho_eta=n_rho_eta,
            rho_approx=rho_eta_approx,
            eta_approx=rho_eta_approx,
            **shared,
        )

    eps = np.empty_like(grid)
    for pos, value in enumerate(grid):
        eps[pos] = _approximate(value).epsilon(gamma)

    return CurveData(
        n_human=n_human,
        n_metric=n_metric,
        n_rho_eta=n_rho_eta,
        alpha=alpha,
        gamma=gamma,
        rho_eta=grid,
        epsilon=eps,
    )


if __name__ == "__main__":
    # Command-line entry point: compute one curve with compute_curve() and
    # print CurveData.json() to stdout as indented JSON.
    import argparse
    import json

    parser = argparse.ArgumentParser(
        description="Compute epsilon(gamma) over a grid of rho = eta values "
                    "and print the resulting curve as JSON.",
    )
    parser.add_argument(
        "--human", type=int, required=True, dest="n_human",
        help="number of human samples (passed as n_oracle)")
    parser.add_argument(
        "--metric", type=int, required=True, dest="n_metric",
        help="number of metric samples")
    parser.add_argument(
        "--rho_eta", type=int, required=False, dest="n_rho_eta", default=None,
        help="n_rho_eta for run_simulated; omit to use simulate_fixed")
    parser.add_argument(
        "--alpha", type=float, required=True, dest="alpha",
        help="alpha passed to the simulator")
    parser.add_argument(
        "--gamma", type=float, required=False, dest="gamma", default=.05,
        help="gamma at which epsilon is evaluated (default: %(default)s)")
    # The approximation sizes were previously fixed to the module defaults.
    # The "approx_" prefix keeps abbreviations such as --alph / --rho
    # unambiguous with the existing --alpha / --rho_eta flags.
    parser.add_argument(
        "--approx_alpha", type=int, dest="alpha_approx", default=ALPHA_APPROX,
        help="alpha_approx passed to compute_curve (default: %(default)s)")
    parser.add_argument(
        "--approx_rho_eta", type=int, dest="rho_eta_approx",
        default=RHO_ETA_APPROX,
        help="rho_eta_approx passed to compute_curve; only used with "
             "--rho_eta (default: %(default)s)")
    args = parser.parse_args()

    curve = compute_curve(
        n_human=args.n_human,
        n_metric=args.n_metric,
        n_rho_eta=args.n_rho_eta,
        alpha=args.alpha,
        gamma=args.gamma,
        alpha_approx=args.alpha_approx,
        rho_eta_approx=args.rho_eta_approx,
    )

    print(json.dumps(curve.json(), indent=2))

