
from datasets.spot_the_bot import load_annotated, load_additional

from binary.estrs import estimate_alpha, estimate_rho_eta, run_simulated
from spot_the_bot.compare_pairwise import render_latex
from emnlp2022.wmt21_showcase import matching_subsample

import jinja2
import numpy as np
import scipy.stats as stats

# Values forwarded as the ``rho_approx`` / ``eta_approx`` / ``alpha_approx``
# (resp. ``approx_*``) arguments of the ``binary.estrs`` routines.
# NOTE(review): presumably approximation resolutions — confirm in binary.estrs.
RHO_APPROX = 1_000
ETA_APPROX = 1_000
ALPHA_APPROX = 2_000

# Each value is used for rho AND eta at once when building the theoretical
# tables: one setting for a "better" metric, one for the USR setting.
BETTER_RHO_ETA = 0.9
USR_RHO_ETA = 0.52

# Alpha passed to every theoretical table.
AVG_ALPHA = 0.3

# Annotated data is loaded once, at import time.
ANN = load_annotated()
# Additional ConvAI2 data; accessed below as ADD[metric][system].
ADD = load_additional()["convai2"]

# ConvAI2 slices of the annotated data:
#   HUMAN[system]          -> human ratings for one system
#   METRIC[metric][system] -> automated metric ratings for one system
HUMAN, METRIC = (ANN["convai2"][part] for part in ("human", "metric"))

# Systems under comparison and the short labels handed to the LaTeX renderer.
SYSTEMS = ["model", "lost_in_conversation", "kvmemnn"]
SHORT_MT = dict(
    model="BL",
    lost_in_conversation="LiC",
    kvmemnn="KV",
)
# Automated metrics the experiments iterate over.
AUTOS = ["usr_ret"]


def experiment_info():
    """Print summary statistics of the loaded data to stdout.

    First the mean of the human ``binary_scores`` for every system in
    ``SYSTEMS``; then, for every metric in ``AUTOS``, the ``rho`` and ``eta``
    estimated from the human ratings paired with that metric's ratings.
    """
    print("SUCCESS RATES:")
    for system in SYSTEMS:
        success_rate = HUMAN[system].binary_scores.mean()
        print(system, f"{success_rate:.3f}")
    print()

    for metric_name in AUTOS:
        print(metric_name)
        for system in SYSTEMS:
            paired_ratings = HUMAN[system].pair(METRIC[metric_name][system])
            estimate = estimate_rho_eta(
                paired=paired_ratings,
                selection_strategy="min_diff",
            )
            print(
                system,
                f"rho: {estimate.rho:.3f}",
                f"eta: {estimate.eta:.3f}",
            )
        print()


def theoretical_table(
    alpha: float,
    rho: float,
    eta: float,
):
    """Render a LaTeX table of simulated epsilon values.

    For every combination of oracle sample count (0, 50, 100, 600) and metric
    sample count (0, 1000, 10000) a simulation is run with the given
    ``alpha``, ``rho`` and ``eta``. The table cell is the upper end of the
    95% normal interval centred at 0 whose standard deviation is
    ``sqrt(2 * variance)`` of the simulated result. The cell with no samples
    at all is fixed to "1.000" without running a simulation.

    Args:
        alpha: forwarded to ``run_simulated``.
        rho: forwarded to ``run_simulated``.
        eta: forwarded to ``run_simulated``.

    Returns:
        The rendered ``epsilon_table_v2.tex`` template from ``./templates``.
    """
    metric_counts = [0, 1000, 10000]
    oracle_counts = [0, 50, 100, 600]

    def epsilon_cell(n_oracle, n_metric):
        # Nothing to simulate without any samples.
        if n_oracle == 0 and n_metric == 0:
            return "1.000"
        simulated = run_simulated(
            alpha=alpha,
            rho=rho,
            eta=eta,
            n_oracle=n_oracle,
            n_metric=n_metric,
            n_rho_eta=670,
            rho_approx=RHO_APPROX,
            eta_approx=ETA_APPROX,
            alpha_approx=ALPHA_APPROX,
        )
        spread = np.sqrt(2 * simulated.var())
        upper = stats.norm.interval(1.0 - 0.05, loc=0.0, scale=spread)[1]
        return f"{upper:.3f}"

    # Oracle counts vary in the outer loop, metric counts in the inner one.
    cells = {
        (n_oracle, n_metric): epsilon_cell(n_oracle, n_metric)
        for n_oracle in oracle_counts
        for n_metric in metric_counts
    }

    loader = jinja2.FileSystemLoader("./templates")
    table_template = jinja2.Environment(loader=loader).get_template(
        "epsilon_table_v2.tex"
    )

    # Two leading columns followed by one centred column per metric count.
    column_spec = "c | c |" + " c" * len(metric_counts)

    return table_template.render(
        col_format=column_spec,
        n_os=oracle_counts,
        n_ms=metric_counts,
        data=cells,
        include_rho_eta=False,
    )


def theoretical_tables():
    """Write the two theoretical epsilon tables to ``./res/stb_showcase``.

    One table uses ``BETTER_RHO_ETA`` for both rho and eta and is written to
    ``better_theoretical.tex``; the other uses ``USR_RHO_ETA`` and is written
    to ``usr_theoretical.tex``. Both use ``AVG_ALPHA``. Each file contains the
    rendered table followed by a newline.
    """
    import os

    out_dir = "./res/stb_showcase"
    # Create the output directory up front: otherwise a fresh checkout fails
    # with FileNotFoundError only after the first simulation has finished.
    os.makedirs(out_dir, exist_ok=True)

    settings = (
        ("better_theoretical.tex", BETTER_RHO_ETA),
        ("usr_theoretical.tex", USR_RHO_ETA),
    )
    for file_name, rho_eta in settings:
        tex = theoretical_table(
            alpha=AVG_ALPHA,
            rho=rho_eta,
            eta=rho_eta,
        )
        with open(f"{out_dir}/{file_name}", 'w') as fout:
            fout.write(f"{tex}\n")


def compare_alphas(metric: str):
    """Estimate alpha per system for several sample budgets and save as LaTeX.

    Three experiments are run, each described by a pair ``(no, nm)``:
    ``no`` human ratings and, when ``nm`` is non-zero, additionally the
    metric ratings. For every experiment the per-system estimates are
    rendered with ``render_latex`` and written to
    ``./res/stb_showcase/{metric}_no{no}_nm{nm}.tex``.

    Args:
        metric: key of the automated metric in ``METRIC`` and ``ADD``.
    """
    import os

    out_dir = './res/stb_showcase'
    # Create the output directory before any estimation runs, so a missing
    # directory cannot discard the results with a FileNotFoundError.
    os.makedirs(out_dir, exist_ok=True)

    experiments = [
        (600, 0),
        (100, 0),
        (100, 10000),
    ]
    metric_scores = METRIC[metric]
    add_scores = ADD[metric]
    for no, nm in experiments:
        bot_data = {}
        for mt in SYSTEMS:
            m_data = metric_scores[mt]
            adds = add_scores[mt]
            h_data = HUMAN[mt]
            rho_eta_data = h_data.pair(m_data)

            approx = estimate_alpha(
                # Subsample the human ratings only below 600.
                # NOTE(review): 600 is presumably the full number of human
                # ratings per system — confirm against the dataset.
                human_data=matching_subsample(h_data, no) if no < 600 else h_data,
                # nm == 0 means a human-only estimate: no metric ratings and
                # no rho/eta data are passed.
                machine_data=None if nm == 0 else [m_data, adds],
                rho_eta_data=None if nm == 0 else rho_eta_data,
                approx_alpha=ALPHA_APPROX,
                approx_rho=RHO_APPROX,
                approx_eta=ETA_APPROX,
            )
            bot_data[mt] = approx

        latex = render_latex(
            bot_data=bot_data,
            short_names=SHORT_MT,
        )

        with open(f'{out_dir}/{metric}_no{no}_nm{nm}.tex', 'w') as fout:
            fout.write(f"{latex}\n")


def main():
    """Run the full showcase.

    Prints the data summary, writes the theoretical tables, then writes the
    alpha comparisons for every metric in ``AUTOS``.
    """
    experiment_info()

    print("computing theoretical tables")
    theoretical_tables()

    for metric_name in AUTOS:
        print(f"comparisons based on {metric_name}")
        compare_alphas(metric_name)


# Run the experiments only when executed as a script, not on import.
if __name__ == '__main__':
    main()
