
from datasets.wmt21 import load_human, load_metric

from binary.estrs import estimate_alpha, estimate_rho_eta, run_simulated
from spot_the_bot.compare_pairwise import render_latex
from datasets.base import BinaryHuman

import jinja2
import numpy as np
import scipy.stats as stats
from sklearn.model_selection import train_test_split

# Approximation parameters forwarded to binary.estrs: as rho_approx /
# eta_approx / alpha_approx to run_simulated() and as approx_rho /
# approx_eta / approx_alpha to estimate_alpha(). Their precise meaning is
# defined in binary.estrs — NOTE(review): confirm (e.g. grid size / samples).
RHO_APPROX = 1000
ETA_APPROX = 1000
ALPHA_APPROX = 2000

# Values used for BOTH rho and eta when building the theoretical tables for
# BLEURT and sentBLEU respectively. Presumably rounded from the estimates
# printed by experiment_info() — TODO confirm.
BLEURT_RHO_ETA = 0.6
BLEU_RHO_ETA = 0.52
# NOTE(review): presumably the number of human-judged segments per system;
# verify against the data returned by load_human().
N_HUMAN = 527

# alpha used for the theoretical tables. Presumably close to the average of
# the per-system success rates printed by experiment_info() — TODO confirm.
AVG_ALPHA = 0.65


# Loaded at import time, so importing this module triggers dataset loading.
# HUMAN is indexed by system name; METRIC by metric name, then system name.
HUMAN = load_human()
METRIC = load_metric()

# Systems compared in the showcase, and the short labels passed to
# render_latex() as short_names.
SYSTEMS = ["Facebook-AI", "VolcTrans-GLAT", "HuaweiTSC"]
SHORT_MT = {
    "Facebook-AI": "FBAI",
    "VolcTrans-GLAT": "VT",
    "HuaweiTSC": "HT",
}
# Automatic metrics (keys into METRIC) evaluated in the showcase.
AUTOS = ["bleurt-20-ref-C", "sentBLEU-ref-C"]


def experiment_info():
    """Print per-system success rates, then rho/eta estimates per metric.

    Success rates are the mean of each system's human binary scores. For every
    metric in AUTOS, rho and eta are estimated with estimate_rho_eta on the
    human/metric pairing of each system using the "min_diff" strategy.
    Output goes to stdout only; nothing is returned.
    """
    print("SUCCESS RATES:")
    for system in SYSTEMS:
        success_rate = HUMAN[system].binary_scores.mean()
        print(system, f"{success_rate:.3f}")
    print()

    for metric_name in AUTOS:
        print(metric_name)
        scores_by_system = METRIC[metric_name]
        for system in SYSTEMS:
            paired = HUMAN[system].pair(scores_by_system[system])
            estimate = estimate_rho_eta(
                paired=paired,
                selection_strategy="min_diff",
            )
            print(system, f"rho: {estimate.rho:.3f}", f"eta: {estimate.eta:.3f}")
        print()


def theoretical_table(
    alpha: float,
    rho: float,
    eta: float,
    confidence: float = 0.95,
) -> str:
    """Render a LaTeX table of simulated epsilon values for one (alpha, rho, eta).

    For every combination of oracle (human) sample count and metric sample
    count, ``run_simulated`` is called and its result summarised as ``eps``.
    ``eps`` is the upper end of the central ``confidence`` interval of a
    zero-mean normal with standard deviation ``sqrt(2 * var(approx))``. The
    rho/eta estimation set size is fixed to ``N_HUMAN``.

    Args:
        alpha: Value passed to the simulation as ``alpha``.
        rho: Value passed to the simulation as ``rho``.
        eta: Value passed to the simulation as ``eta``.
        confidence: Coverage of the interval used to derive ``eps``. The
            default 0.95 matches the previously hard-coded ``1. - 0.05``.

    Returns:
        The ``epsilon_table_v2.tex`` template (loaded from ``./templates``)
        rendered with the table cells.
    """
    n_metrics = [0, 1000, 10000]
    n_oracles = [0, 50, 100, N_HUMAN]

    data = {}
    # Loop order is kept (oracle outer, metric inner) so the simulations run
    # in the same sequence as before.
    for n_oracle in n_oracles:
        for n_metric in n_metrics:
            if n_oracle == 0 and n_metric == 0:
                # Nothing to simulate without any samples. Use the placeholder
                # 1.000 (presumably the trivial bound for a rate in [0, 1]).
                data[(n_oracle, n_metric)] = "1.000"
                continue
            # NOTE(review): assumes run_simulated returns an array-like of
            # simulated estimates supporting .var() — confirm in binary.estrs.
            approx = run_simulated(
                alpha=alpha,
                rho=rho,
                eta=eta,
                n_oracle=n_oracle,
                n_metric=n_metric,
                n_rho_eta=N_HUMAN,
                rho_approx=RHO_APPROX,
                eta_approx=ETA_APPROX,
                alpha_approx=ALPHA_APPROX,
            )
            # Factor 2: presumably the variance of the difference of two
            # independent estimates with equal variance — TODO confirm.
            stdev = np.sqrt(2 * approx.var())
            if stdev == 0:
                # scipy returns NaN for scale=0. A degenerate (constant)
                # simulation has zero spread, so report eps = 0.
                eps = 0.0
            else:
                _, eps = stats.norm.interval(confidence, loc=0.0, scale=stdev)
            data[(n_oracle, n_metric)] = f"{eps:.3f}"

    env = jinja2.Environment(
        loader=jinja2.FileSystemLoader("./templates")
    )
    template = env.get_template("epsilon_table_v2.tex")

    # Two leading columns, then one centred column per metric sample count.
    col_format = "c | c |" + " c" * len(n_metrics)

    return template.render(
        col_format=col_format,
        n_os=n_oracles,
        n_ms=n_metrics,
        data=data,
        include_rho_eta=False,
    )


def theoretical_tables():
    """Render and save the theoretical epsilon tables for BLEURT and sentBLEU.

    Each table uses AVG_ALPHA as alpha and the metric's *_RHO_ETA constant as
    both rho and eta. The BLEURT table is computed and written first, then
    the sentBLEU table.
    """
    jobs = (
        (BLEURT_RHO_ETA, "./res/wmt21_showcase/bleurt_theoretical.tex"),
        (BLEU_RHO_ETA, "./res/wmt21_showcase/sentbleu_theoretical.tex"),
    )
    for rho_eta, out_path in jobs:
        tex = theoretical_table(alpha=AVG_ALPHA, rho=rho_eta, eta=rho_eta)
        with open(out_path, 'w') as fout:
            fout.write(f"{tex}\n")


def simple_subsample(h: BinaryHuman, n: int) -> BinaryHuman:
    """Return a new BinaryHuman keeping only the first ``n`` scores of ``h``.

    No shuffling is done; system and dataset are copied unchanged.
    """
    head = h.binary_scores[:n]
    return BinaryHuman(system=h.system, dataset=h.dataset, binary_scores=head)


def random_subsample(h: BinaryHuman, n: int) -> BinaryHuman:
    """Return a new BinaryHuman with ``n`` scores of ``h`` drawn without replacement.

    The draw uses sklearn's train_test_split with a fixed seed, so it is
    deterministic. Calls on score arrays of the same length pick the same
    positions. System and dataset are copied unchanged.
    """
    split = train_test_split(
        h.binary_scores,
        shuffle=True,
        train_size=n,
        random_state=0xdeadbeef,
    )
    kept = split[0]  # the "train" part is the subsample; the rest is discarded
    return BinaryHuman(system=h.system, dataset=h.dataset, binary_scores=kept)


def matching_subsample(h: BinaryHuman, n: int) -> BinaryHuman:
    """Return a new BinaryHuman with a stratified subsample of ``n`` scores of ``h``.

    Stratification is on the scores themselves. As a result, the share of
    each score value in the subsample (approximately) matches that of the
    full set. The split uses a fixed seed and is therefore deterministic.
    ``n`` must be smaller than the number of scores; train_test_split rejects
    a train_size that leaves nothing for the discarded part.
    """
    scores = h.binary_scores
    split = train_test_split(
        scores,
        stratify=scores,
        shuffle=True,
        train_size=n,
        random_state=0xdeadbeef,
    )
    kept = split[0]
    return BinaryHuman(system=h.system, dataset=h.dataset, binary_scores=kept)


def compare_alphas(metric: str):
    """Estimate each system's alpha under several sample budgets and write LaTeX tables.

    For every ``(n_oracle, n_metric)`` budget, ``estimate_alpha`` is run for
    each system in ``SYSTEMS``. The per-system results are rendered with
    ``render_latex`` and written to
    ``./res/wmt21_showcase/{metric}_no{n_oracle}_nm{n_metric}.tex``.

    Args:
        metric: Key into ``METRIC`` naming the automatic metric (e.g. an
            entry of ``AUTOS``).
    """
    experiments = [
        (N_HUMAN, 0),
        (100, 0),
        (100, 1000),
    ]
    metric_scores = METRIC[metric]
    for n_oracle, n_metric in experiments:
        bot_data = {}
        for mt in SYSTEMS:
            m_data = metric_scores[mt]
            h_data = HUMAN[mt]

            # Compare against the real number of human labels instead of a
            # hard-coded count. train_test_split rejects train_size >= n_samples,
            # and a budget covering every label should just use all of them.
            if n_oracle < len(h_data.binary_scores):
                human_data = matching_subsample(h_data, n_oracle)
            else:
                human_data = h_data

            # NOTE(review): n_metric only switches metric data on/off and names
            # the output file. The full metric_scores[mt] is passed regardless
            # of its value — confirm it holds ~n_metric segments.
            if n_metric == 0:
                machine_data = None
                rho_eta_data = None
            else:
                machine_data = m_data
                # rho/eta are estimated on the full human/metric pairing, not
                # on the human subsample.
                rho_eta_data = h_data.pair(m_data)

            bot_data[mt] = estimate_alpha(
                human_data=human_data,
                machine_data=machine_data,
                rho_eta_data=rho_eta_data,
                approx_alpha=ALPHA_APPROX,
                approx_rho=RHO_APPROX,
                approx_eta=ETA_APPROX,
            )

        latex = render_latex(
            bot_data=bot_data,
            short_names=SHORT_MT,
        )

        out_path = f'./res/wmt21_showcase/{metric}_no{n_oracle}_nm{n_metric}.tex'
        with open(out_path, 'w') as fout:
            fout.write(f"{latex}\n")


def main():
    """Run the whole showcase.

    Prints the data summary, writes the theoretical tables, then writes the
    alpha comparison tables for every metric in AUTOS.
    """
    experiment_info()

    print("computing theoretical tables")
    theoretical_tables()

    for metric_name in AUTOS:
        print(f"comparisons based on {metric_name}")
        compare_alphas(metric_name)


if __name__ == '__main__':
    main()
