
from typing import Optional

from datasets.wmt21 import load_human, load_metric
from datasets.spot_the_bot import load_annotated, load_additional

from binary.estrs import estimate_alpha

from spot_the_bot.compare_pairwise import render_latex
from emnlp2022.stb_showcase import matching_subsample


# Spot-the-Bot system identifiers mapped to the short labels handed to
# `render_latex` as `short_names`.
STB_SHORT = {
    "bert_rank": "BR",
    "huggingface": "HF",
    "kvmemnn": "KV",
    "lost_in_conversation": "LiC",
    "suckybot": "S2S",
    "model": "BL",
}
# System order used when iterating; dicts preserve insertion order, so this
# is exactly the key order written above.
STB_SYSTEMS = list(STB_SHORT)

# WMT system identifiers, in iteration order.
WMT_SYSTEMS = [
    "Facebook-AI",
    "HuaweiTSC",
    "Nemo",
    "Online-W",
    "UEdin",
    "VolcTrans-AT",
    "VolcTrans-GLAT",
]
# WMT systems have no abbreviation: every name maps to itself.
WMT_SHORT = dict(zip(WMT_SYSTEMS, WMT_SYSTEMS))


def get_sys(domain: str):
    """Return the system list and short-name mapping for a domain.

    Args:
        domain: Either ``"wmt"`` or ``"stb"``.

    Returns:
        A ``(systems, short_names)`` tuple: ``(WMT_SYSTEMS, WMT_SHORT)`` for
        ``"wmt"`` and ``(STB_SYSTEMS, STB_SHORT)`` for ``"stb"``.

    Raises:
        ValueError: If ``domain`` is neither ``"wmt"`` nor ``"stb"``.
    """
    if domain == "stb":
        return STB_SYSTEMS, STB_SHORT
    if domain == "wmt":
        return WMT_SYSTEMS, WMT_SHORT

    # Anything else is unsupported.
    raise ValueError(
        f"'domain' needs to be one of ['wmt', 'stb'], got unknown '{domain}'")


def load(
    domain: str,
):
    """Load the human, metric and additional data for a domain.

    Args:
        domain: Either ``"wmt"`` or ``"stb"``.

    Returns:
        A dict with the keys ``"human"``, ``"metric"`` and ``"additional"``.
        For ``"wmt"`` the first two come from ``load_human()`` and
        ``load_metric()`` and ``"additional"`` is ``None``. For ``"stb"`` all
        three are taken from the ``'convai2'`` entries of ``load_annotated()``
        and ``load_additional()``.

    Raises:
        ValueError: If ``domain`` is neither ``"wmt"`` nor ``"stb"``.
    """
    if domain == "wmt":
        # WMT has no additional data.
        return {
            "human": load_human(),
            "metric": load_metric(),
            "additional": None,
        }

    if domain == "stb":
        stb_annotated = load_annotated()
        stb_extra = load_additional()
        # Only the 'convai2' portion of the Spot-the-Bot data is used here.
        convai2 = stb_annotated['convai2']
        return {
            "human": convai2['human'],
            "metric": convai2['metric'],
            "additional": stb_extra['convai2'],
        }

    raise ValueError(
        f"'domain' needs to be one of ['wmt', 'stb'], got unknown '{domain}'")


def compute_table(
    domain: str,
    metric: Optional[str],
    subsample: Optional[int],
) -> str:
    """Run ``estimate_alpha`` for every system of a domain and render the result.

    For each system name returned by ``get_sys(domain)``, the human data (and,
    when ``metric`` is given, the matching metric data) is looked up in
    ``load(domain)`` and passed to ``estimate_alpha``. The per-system results
    are then handed to ``render_latex`` together with the short names.

    Args:
        domain: Either ``'wmt'`` or ``'stb'``.
        metric: Key into the loaded metric data. When ``None``,
            ``estimate_alpha`` receives ``machine_data=None`` and
            ``rho_eta_data=None``.
        subsample: When not ``None``, the human data is passed through
            ``matching_subsample(human_data, subsample)`` before estimation.
            The paired human/metric data is always built from the
            un-subsampled human data.

    Returns:
        Whatever ``render_latex`` returns (annotated as ``str``).

    Raises:
        ValueError: If ``domain`` is neither ``'wmt'`` nor ``'stb'``.
    """
    # Reject unknown domains before any data is loaded.
    if domain not in {'wmt', 'stb'}:
        raise ValueError(
            f"'domain' needs to be one of ['wmt', 'stb'], got unknown '{domain}'")

    data = load(domain)
    systems, shorts = get_sys(domain)

    system_approxes = {}
    for name in systems:
        human = data['human'][name]

        # Without a metric there is neither machine data nor paired data.
        machine = None
        paired = None
        if metric is not None:
            machine = data['metric'][metric][name]
            paired = human.pair(machine)
            extra = data['additional']
            if extra is not None:
                # Pass the additional metric data alongside the paired one.
                machine = [machine, extra[metric][name]]

        # Subsampling happens after pairing, so it only affects `human_data`.
        if subsample is not None:
            human = matching_subsample(human, subsample)

        system_approxes[name] = estimate_alpha(
            human_data=human,
            machine_data=machine,
            rho_eta_data=paired,
        )

    return render_latex(
        bot_data=system_approxes,
        short_names=shorts,
    )


if __name__ == "__main__":
    # Command-line entry point: parse --domain / --metric / --subsample,
    # build the table with compute_table and print it to stdout.
    import argparse

    cli = argparse.ArgumentParser()
    # `dest` is derived from the option name and options are optional with a
    # default of None unless stated otherwise, so only deviations are spelled
    # out.
    cli.add_argument("--domain", type=str, required=True)
    cli.add_argument("--metric", type=str)
    cli.add_argument("--subsample", type=int)
    opts = cli.parse_args()

    print(
        compute_table(
            domain=opts.domain,
            metric=opts.metric,
            subsample=opts.subsample,
        )
    )
