
import json

from datasets.wmt21 import load_human, load_metric

from spot_the_bot.compare_pairwise import render_latex

from binary.estrs import estimate_alpha


def main():
    """Estimate alpha for each WMT21 MT system and write the results to ``res/``.

    For every system in ``mt_systems``, ``estimate_alpha`` is run twice:

    * ``"oracle_only"`` -- with only the human data (no machine data and no
      paired data);
    * ``"both"`` -- with the human data, the ``bleurt-20-ref-C`` metric data
      for that system, and the human/metric pairing ``h.pair(m)``.

    Files written (the ``res`` directory is created if it does not exist):

    * ``res/wmt21_pairwise_oracle.tex`` -- ``render_latex`` of the
      oracle-only estimates;
    * ``res/wmt21_pairwise_both.tex`` -- ``render_latex`` of the combined
      estimates;
    * ``res/wmt21_estimates.json`` -- every estimate serialised through its
      ``.json()`` method, keyed by system name and then by estimate kind.
    """
    import os  # local import: only needed for the output directory/paths

    # Create the output directory before the (potentially slow) estimation,
    # so a missing ``res/`` cannot discard all computed results at write time.
    out_dir = "res"
    os.makedirs(out_dir, exist_ok=True)

    human = load_human()
    metric = load_metric()

    metric_name = "bleurt-20-ref-C"
    mt_systems = [
        "Facebook-AI",
        "HuaweiTSC",
        "Nemo",
        "Online-W",
        "UEdin",
        "VolcTrans-AT",
        "VolcTrans-GLAT",
    ]
    names = {n: n for n in mt_systems}  # TODO come up with short names

    # Approximation resolutions shared by both estimate_alpha calls.
    approx = {"approx_eta": 1000, "approx_rho": 1000, "approx_alpha": 2000}

    estimates = {}
    for mt in mt_systems:
        h = human[mt]
        m = metric[metric_name][mt]
        paired = h.pair(m)

        estimates[mt] = {
            "oracle_only": estimate_alpha(
                human_data=h,
                machine_data=None,
                rho_eta_data=None,
                **approx,
            ),
            "both": estimate_alpha(
                human_data=h,
                machine_data=m,
                rho_eta_data=paired,
                **approx,
            ),
        }

    # One LaTeX table per estimate kind.
    for kind, filename in (
        ("oracle_only", "wmt21_pairwise_oracle.tex"),
        ("both", "wmt21_pairwise_both.tex"),
    ):
        table = {mt: per_mt[kind] for mt, per_mt in estimates.items()}
        tex_path = os.path.join(out_dir, filename)
        with open(tex_path, "w", encoding="utf-8") as fout:
            fout.write(f"{render_latex(table, names)}\n")

    serialised = {
        mt: {kind: estimate.json() for kind, estimate in per_mt.items()}
        for mt, per_mt in estimates.items()
    }
    json_path = os.path.join(out_dir, "wmt21_estimates.json")
    with open(json_path, "w", encoding="utf-8") as fout:
        json.dump(fp=fout, obj=serialised, indent=2)


if __name__ == "__main__":
    main()
