
from typing import List

import numpy as np
import scipy.stats as stats

from jinja2 import Environment, FileSystemLoader

from binary.estimation import FullApproximationAlphaEstimator


def interval(
    histogram: np.ndarray,
    belief_level: float = 0.05,
):
    """Return the central credible interval of a histogram.

    The interval bounds are bin indices divided by the number of bins, which
    assumes the bins evenly partition [0, 1]. Each bound is the left edge of
    the first bin whose cumulative mass strictly exceeds the corresponding tail
    probability.

    Args:
        histogram: Per-bin probability mass. It is presumably 1-D and sums to
            about 1.
        belief_level: Total tail mass excluded, split evenly between the two
            tails.

    Returns:
        ``(lower, upper)`` as fractions of the number of bins.

    Raises:
        ValueError: if ``histogram`` has no bins.
    """
    histogram = np.asarray(histogram)
    n_approx = histogram.shape[0]
    if n_approx == 0:
        raise ValueError("'histogram' must contain at least one bin")

    p_lower = belief_level / 2.
    p_upper = 1. - p_lower

    cdf = np.cumsum(histogram)

    def _first_bin_above(p: float) -> int:
        # Index of the first bin whose cumulative mass strictly exceeds p.
        # If rounding leaves the total mass at or below p, the quantile lies
        # at the top of the support, so fall back to the last bin rather than
        # raising IndexError.
        above = np.flatnonzero(cdf > p)
        return int(above[0]) if above.size else n_approx - 1

    lower_ix = _first_bin_above(p_lower)
    upper_ix = _first_bin_above(p_upper)

    return lower_ix / n_approx, upper_ix / n_approx


def theoretical_hist_for_params(
    alpha: float,
    rho: float,
    eta: float,
    n_human: int,
    n_metric: int,
    n_rho_eta: int,
    alpha_approx: int = 2000,
    rho_eta_approx: int = 1000,
):
    """Run the full-approximation estimator on the *expected* counts for a setting.

    Rather than sampling, every observed count is set to its expected value
    under (alpha, rho, eta):

    * human positives: ``round(alpha * n_human)``
    * metric positives: ``round(alpha_obs * n_metric)`` with
      ``alpha_obs = alpha * rho + (1 - alpha) * (1 - eta)``
    * the ``n_rho_eta`` calibration samples are split into ``int(n_rho_eta * alpha)``
      for rho and the rest for eta. Their positive counts are ``int(rho * rho_total)``
      and ``int(eta * eta_total)``.

    Presumably rho and eta are the metric's true-positive and true-negative
    rates. That reading is inferred from the ``alpha_obs`` formula; confirm it
    against ``binary.estimation``.

    Args:
        alpha, rho, eta: Parameters used to derive the expected counts.
        n_human: Number of human-labelled samples. If 0, the human term is
            unused.
        n_metric: Number of metric-labelled samples. If 0, the metric term is
            unused.
        n_rho_eta: Number of samples used to estimate rho and eta.
        alpha_approx, rho_eta_approx: Resolution settings forwarded to
            ``FullApproximationAlphaEstimator``.

    Returns:
        Whatever the chosen estimator method (``human_only``, ``machine_only``
        or ``human_machine_both``) returns. This is presumably a 1-D histogram
        over alpha bins, since ``interval()`` consumes it that way.

    Raises:
        ValueError: if a sample count is negative, or if both ``n_human`` and
            ``n_metric`` are 0.
    """
    # Validate before constructing the estimator so bad input fails fast.
    if n_human < 0 or n_metric < 0:
        raise ValueError("'n_human' and 'n_metric' must be non-negative")
    if n_human == 0 and n_metric == 0:
        raise ValueError("'n_human' and 'n_metric' cannot both be 0")

    expected_human_pos = round(alpha * n_human) if n_human > 0 else None

    # alpha * rho + (1 - alpha) * (1 - eta), rearranged.
    alpha_obs = (alpha * (rho + eta - 1.)) + (1. - eta)
    expected_metric_pos = round(alpha_obs * n_metric) if n_metric > 0 else None

    est = FullApproximationAlphaEstimator(
        alpha_approx=alpha_approx,
        rho_eta_approx=rho_eta_approx,
    )

    # Split the calibration budget proportionally to alpha.
    rho_total = int(n_rho_eta * alpha)
    eta_total = n_rho_eta - rho_total

    rho_pos = int(rho * rho_total)
    eta_pos = int(eta * eta_total)

    if n_metric == 0:
        # Only human labels are available (n_human > 0 is guaranteed above).
        return est.human_only(
            human_pos=expected_human_pos, human_total=n_human)
    if n_human == 0:
        return est.machine_only(
            n_pos=expected_metric_pos,
            n_total=n_metric,
            rho_pos=rho_pos,
            rho_total=rho_total,
            eta_pos=eta_pos,
            eta_total=eta_total,
        )
    return est.human_machine_both(
        n_pos=expected_metric_pos,
        n_total=n_metric,
        rho_pos=rho_pos,
        rho_total=rho_total,
        eta_pos=eta_pos,
        eta_total=eta_total,
        human_pos=expected_human_pos,
        human_total=n_human,
    )


def theoretical_interval_for_params(
    alpha: float,
    rho: float,
    eta: float,
    n_human: int,
    n_metric: int,
    n_rho_eta: int,
    alpha_approx: int = 2000,
    rho_eta_approx: int = 1000,
    belief_level: float = 0.05,
):
    """Credible interval of the theoretical alpha histogram for one setting.

    This chains ``theoretical_hist_for_params`` (which receives every argument
    except ``belief_level``) into ``interval`` and returns its ``(lower, upper)``
    pair.
    """
    estimator_args = dict(
        alpha=alpha,
        rho=rho,
        eta=eta,
        n_human=n_human,
        n_metric=n_metric,
        n_rho_eta=n_rho_eta,
        alpha_approx=alpha_approx,
        rho_eta_approx=rho_eta_approx,
    )
    posterior = theoretical_hist_for_params(**estimator_args)
    return interval(posterior, belief_level=belief_level)


def theoretical_epsilon_for_params(
    alpha: float,
    rho: float,
    eta: float,
    n_human: int,
    n_metric: int,
    n_rho_eta: int,
    alpha_approx: int = 2000,
    rho_eta_approx: int = 1000,
    belief_level: float = 0.05,
):
    """Normal-approximation epsilon for comparing two alpha estimates.

    The procedure has three steps:

    1. Take the variance of the theoretical alpha histogram, evaluated at the
       bin midpoints.
    2. Double it. The intent is presumably the variance of the difference of
       two independent estimates with equal variance.
    3. Return the upper end of the central ``1 - belief_level`` interval of a
       zero-mean normal with that variance.

    Args:
        alpha, rho, eta, n_human, n_metric, n_rho_eta, alpha_approx,
        rho_eta_approx: Forwarded to ``theoretical_hist_for_params``.
        belief_level: Two-sided tail mass of the normal interval.

    Returns:
        ``norm.ppf(1 - belief_level / 2) * sqrt(2 * var)``. This is 0 when the
        histogram is a point mass.
    """
    hist = theoretical_hist_for_params(
        alpha=alpha,
        rho=rho,
        eta=eta,
        n_human=n_human,
        n_metric=n_metric,
        n_rho_eta=n_rho_eta,
        alpha_approx=alpha_approx,
        rho_eta_approx=rho_eta_approx,
    )

    # Bin midpoints, assuming the bins evenly partition [0, 1] (the same
    # convention interval() uses for its bin edges).
    n_bins = hist.shape[0]
    alphas = (np.arange(n_bins) + .5) / n_bins
    mean = alphas @ hist
    second_moment = (alphas * alphas) @ hist
    # Variance of a single alpha estimate. Cancellation in E[X^2] - E[X]^2 can
    # make it slightly negative for near point masses, and sqrt would then
    # return nan, so clamp it at 0.
    var = max(second_moment - mean * mean, 0.)

    # Standard deviation for 2 alpha estimates with the same variance.
    stdev = np.sqrt(2. * var)

    # Equal to stats.norm.interval(1 - belief_level, loc=0, scale=stdev)[1]
    # for stdev > 0. Unlike interval(), it is still defined (0) for stdev == 0,
    # where scipy rejects scale=0 and returns nan.
    return stats.norm.ppf(1. - belief_level / 2.) * stdev


def latex_table_for_params(
    alpha: float,
    rho: float,
    eta: float,
    table_label: str,
    alpha_approx: int = 2000,
    rho_eta_approx: int = 1000,
    belief_level: float = 0.05,
    log_n_min: int = 7,
    log_n_max: int = 20,
    template_dir: str = "./templates",
) -> str:
    """Render the ``epsilon_table.tex`` template for one (alpha, rho, eta) setting.

    For every pair ``no = 2**i`` and ``nm = 2**j``, with
    ``log_n_min <= i, j <= log_n_max``, the function computes
    ``theoretical_epsilon_for_params`` using ``no`` human samples, ``nm``
    metric samples and ``no`` rho/eta samples. It stores the result,
    formatted to 3 decimals, under ``data[(no, nm)]``.

    The defaults reproduce the original 2^7..2^20 grid. The template
    presumably expects exactly those keys, so confirm it before changing
    the range.

    Args:
        alpha, rho, eta: Setting to tabulate.
        table_label: Passed to the template as ``label``.
        alpha_approx, rho_eta_approx, belief_level: Forwarded to
            ``theoretical_epsilon_for_params``.
        log_n_min, log_n_max: Inclusive range of base-2 exponents for both
            sample counts.
        template_dir: Directory that contains ``epsilon_table.tex``.

    Returns:
        The rendered template text.
    """
    # Load the template before the sweep so a missing or broken template
    # fails immediately instead of after hundreds of estimator runs.
    env = Environment(loader=FileSystemLoader(template_dir))
    template = env.get_template("epsilon_table.tex")

    log_ns = range(log_n_min, log_n_max + 1)
    data = {}
    for log_n_human in log_ns:
        for log_n_machine in log_ns:
            no = 2**log_n_human
            nm = 2**log_n_machine
            print(f"computing no = 2^{log_n_human} and nm = 2^{log_n_machine}")
            eps = theoretical_epsilon_for_params(
                alpha=alpha,
                rho=rho,
                eta=eta,
                n_human=no,
                n_metric=nm,
                n_rho_eta=no,
                alpha_approx=alpha_approx,
                rho_eta_approx=rho_eta_approx,
                belief_level=belief_level,
            )
            data[(no, nm)] = f"{eps:.3f}"

    return template.render(
        alpha=f"{alpha:.2f}",
        rho=f"{rho:.2f}",
        eta=f"{eta:.2f}",
        gamma=f"{belief_level:.2f}",
        label=table_label,
        data=data,
    )


def __ugly_pad(n: int, tgt_len: int):
    """Render ``n`` as text padded with spaces to ``tgt_len`` characters.

    Exactly one space is placed on the right and the rest of the padding on the
    left. If ``str(n)`` is already at least ``tgt_len`` characters long, it is
    returned unchanged.
    """
    text = str(n)
    if len(text) < tgt_len:
        # Add the single trailing space first; right-justifying then puts all
        # remaining padding on the left.
        text = (text + " ").rjust(tgt_len)
    return text


def print_mat(
    alpha: float,
    rho: float,
    eta: float,
    ns: List[int],
    interval_sizes: np.array,
):
    """Format a square matrix of interval sizes as a text table.

    Despite its name, this function prints nothing; it returns the table text.
    The output starts with a parameter preamble. Row ``i`` is labelled
    ``ns[i]`` (human samples) and column ``j`` is ``ns[j]`` (metric samples).
    Cell ``(i, j)`` shows ``interval_sizes[i, j]`` with 4 decimals.
    """
    header = "\t|" + "".join(__ugly_pad(n, 8) + "|" for n in ns)
    out = [
        f"alpha:\t{alpha:.2f}",
        f"rho:\t{rho:.2f}",
        f"eta:\t{eta:.2f}",
        "",
        "downwards: human samples",
        "left to right: metric samples",
        "",
        "used the same number as human samples to estimate rho / eta",
        "",
        header,
    ]

    n_cols = len(ns)
    for row, n_human in enumerate(ns):
        row_cells = "".join(
            f" {interval_sizes[row, col]:.4f} |" for col in range(n_cols)
        )
        out.append(f"{n_human}\t|{row_cells}")

    return "\n".join(out)


def main():
    """Sweep (alpha, rho, eta) and write one interval-size table per setting.

    The sweep covers:

    * alpha from 0.05 to 0.95 in steps of 0.05
    * rho and eta from 0.50 to 1.00 in steps of 0.05

    For each setting, the function fills a matrix over
    ``ns = [0, 2^4, ..., 2^20]`` human x metric sample counts, using
    ``upper - lower`` from ``theoretical_interval_for_params`` with
    ``n_rho_eta = n_human``. It writes the matrix formatted by ``print_mat``
    to ``plots/intervals/<name>.txt``, creating that directory if needed.
    """
    from pathlib import Path

    out_dir = Path("plots") / "intervals"
    # open() does not create missing parent directories; create it once up
    # front so a long sweep cannot crash on its first write.
    out_dir.mkdir(parents=True, exist_ok=True)

    ns = [0] + [2**i for i in range(4, 21)]

    for _alpha in range(5, 99, 5):
        alpha = _alpha / 100
        for _rho in range(50, 101, 5):
            rho = _rho / 100
            for _eta in range(50, 101, 5):
                eta = _eta / 100
                res = np.zeros((len(ns), len(ns)))
                print(f"alpha: {alpha:.2f} rho: {rho:.2f} eta: {eta:.2f}")
                for ix, n_human in enumerate(ns):
                    for jx, n_machine in enumerate(ns):
                        if n_human == 0 and n_machine == 0:
                            # No data at all (the estimator rejects this case):
                            # record a width of 1.
                            res[ix, jx] = 1.
                            continue
                        lower, upper = theoretical_interval_for_params(
                            alpha=alpha,
                            rho=rho,
                            eta=eta,
                            n_human=n_human,
                            n_metric=n_machine,
                            n_rho_eta=n_human,
                        )
                        res[ix, jx] = upper - lower

                table = print_mat(alpha, rho, eta, ns, res)

                # Dots are replaced so the file name contains only one dot,
                # the one before the extension.
                fname = f"alpha_{alpha:.2f}_rho_{rho:.2f}_eta_{eta:.2f}"
                fname = fname.replace(".", "_")
                with open(out_dir / f"{fname}.txt", 'w', encoding="utf-8") as fout:
                    fout.write(table)
                    fout.write('\n')


# Run the full parameter sweep only when executed as a script, not on import.
if __name__ == "__main__":
    main()
