
import jinja2

import numpy as np
import scipy.stats as stats

from binary.estrs import run_simulated, simulate_fixed


def latex_table(
    alpha: float,
    rho: float,
    eta: float,
    gamma: float = 0.05,
    rho_approx: int = 1000,
    eta_approx: int = 1000,
    alpha_approx: int = 2000,
    fixed_rho_eta: bool = False
) -> str:
    """Render a LaTeX table of epsilon values over a grid of sample sizes.

    For every pair of oracle-sample and metric-sample counts a simulation is
    run (``simulate_fixed`` when ``fixed_rho_eta`` is set, ``run_simulated``
    otherwise). The variance of the simulation result is turned into the
    upper end of the central ``1 - gamma`` interval of a zero-mean normal
    distribution with standard deviation ``sqrt(2 * variance)``, and that
    value is formatted with three decimals.

    Args:
        alpha: Forwarded to the simulation as ``alpha``.
        rho: Forwarded to the simulation as ``rho``.
        eta: Forwarded to the simulation as ``eta``.
        gamma: One minus the confidence level of the normal interval.
        rho_approx: Forwarded to ``run_simulated`` only.
        eta_approx: Forwarded to ``run_simulated`` only.
        alpha_approx: Forwarded to whichever simulation is used.
        fixed_rho_eta: Select ``simulate_fixed`` instead of
            ``run_simulated``; the template receives the negation as
            ``include_rho_eta``.

    Returns:
        The rendered ``epsilon_table_v2.tex`` template, loaded from
        ``./templates`` (resolved against the current working directory).
    """
    oracle_sizes = [0, 10, 50, 100, 250, 500, 1000, 2500, 5000, 10000]
    metric_sizes = [0, 1000, 2500, 5000, 10000, 50000, 100000]

    def _epsilon_cell(n_oracle: int, n_metric: int) -> str:
        # Format the table entry for a single (oracle, metric) size pair.
        if n_oracle == n_metric == 0:
            # No samples at all: nothing is simulated, the cell is fixed.
            return "1.000"

        if fixed_rho_eta:
            samples = simulate_fixed(
                alpha=alpha,
                rho=rho,
                eta=eta,
                n_oracle=n_oracle,
                n_metric=n_metric,
                alpha_approx=alpha_approx,
            )
        else:
            # The oracle sample count is also passed as n_rho_eta.
            samples = run_simulated(
                alpha=alpha,
                rho=rho,
                eta=eta,
                n_oracle=n_oracle,
                n_metric=n_metric,
                n_rho_eta=n_oracle,
                rho_approx=rho_approx,
                eta_approx=eta_approx,
                alpha_approx=alpha_approx,
            )

        # NOTE(review): the factor 2 presumably models the difference of two
        # independent estimates with equal variance -- confirm.
        spread = np.sqrt(2*samples.var())
        upper = stats.norm.interval(1. - gamma, loc=0., scale=spread)[1]
        return f"{upper:.3f}"

    # Oracle sizes vary in the outer loop, metric sizes in the inner one, so
    # the simulations run in the same order as a plain nested for-loop.
    cells = {
        (n_oracle, n_metric): _epsilon_cell(n_oracle, n_metric)
        for n_oracle in oracle_sizes
        for n_metric in metric_sizes
    }

    loader = jinja2.FileSystemLoader("./templates")
    template = jinja2.Environment(loader=loader).get_template(
        "epsilon_table_v2.tex")

    # Two leading label columns followed by one column per metric size.
    column_spec = "c | c |" + " c" * len(metric_sizes)

    return template.render(
        col_format=column_spec,
        n_os=oracle_sizes,
        n_ms=metric_sizes,
        data=cells,
        include_rho_eta=not fixed_rho_eta,
    )


if __name__ == "__main__":
    # Command-line entry point: build the epsilon table for the given
    # alpha/rho/eta and write the rendered LaTeX to a file.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--alpha", dest="alpha", type=float, required=True)
    parser.add_argument("--rho", dest="rho", type=float, required=True)
    parser.add_argument("--eta", dest="eta", type=float, required=True)
    # The option used to be declared with both required=True and a default;
    # argparse never applies a default to a required option, so "tab.tex"
    # was unreachable. Making the option optional lets the default take
    # effect while every previously valid command line keeps working.
    parser.add_argument(
        "-o", "--out", dest="out", type=str, default="tab.tex")
    # -f/--fixed switches latex_table to its fixed rho/eta simulation.
    parser.add_argument("-f", "--fixed", dest="fixed", action="store_true")
    args = parser.parse_args()

    tab = latex_table(
        alpha=args.alpha,
        rho=args.rho,
        eta=args.eta,
        fixed_rho_eta=args.fixed,
    )

    # Explicit encoding so the output does not depend on the platform's
    # locale; jinja2 reads template files as UTF-8 by default.
    with open(args.out, "w", encoding="utf-8") as fout:
        fout.write(f"{tab}\n")
