
import json
from typing import Dict
import jinja2

from datasets.spot_the_bot import load_annotated, load_additional
from binary.estrs import estimate_alpha, Approximation, compare


def render_latex(
        bot_data: Dict[str, Approximation],
        short_names: Dict[str, str],
) -> str:
    """Render the pairwise bot comparison table from a Jinja2 template.

    Bots are ordered by descending ``mean()`` of their approximation. Every
    off-diagonal cell ``(row, column)`` holds the difference of the two means
    followed, in parentheses, by the value returned by ``compare(row, column)``;
    diagonal cells hold ``"-"``.

    Args:
        bot_data: Maps each bot name to its ``Approximation``.
        short_names: Maps each bot name to the label handed to the template.

    Returns:
        The rendered ``pairwise_bots.tex`` template. The template is loaded
        from ``./templates``, i.e. relative to the current working directory.
    """
    # Evaluate mean() once per bot and reuse it for sorting, the per-bot
    # summary and every pairwise difference, instead of recomputing it for
    # each cell of the table.
    means = {b: approx.mean() for b, approx in bot_data.items()}
    bots = sorted(means, key=means.get, reverse=True)

    perf = {b: f"{m:.2f}" for b, m in means.items()}

    table_data = {}
    for bot1 in bots:
        for bot2 in bots:
            if bot1 == bot2:
                table_data[(bot1, bot2)] = "-"
                continue
            eps = means[bot1] - means[bot2]
            p = compare(bot_data[bot1], bot_data[bot2])
            table_data[(bot1, bot2)] = f"{eps:.2f} ({p:.3f})"

    env = jinja2.Environment(
        loader=jinja2.FileSystemLoader('./templates'),
    )
    template = env.get_template('pairwise_bots.tex')

    return template.render(
        bots=bots,
        # One label column, a vertical rule, then one centred column per bot.
        layout="c |" + " c" * len(bots),
        data=table_data,
        short_names=short_names,
        perf=perf,
    )


def main():
    """Estimate alpha per bot on the ``convai2`` data and write the results.

    For every bot, ``estimate_alpha`` is run in three settings:

    * ``oracle_only`` -- human data only (no machine data, no rho/eta data),
    * ``matched`` -- human data, the ``usr_ret`` metric data and the pairing
      of the two,
    * ``additional`` -- like ``matched``, but the additional metric data is
      passed together with the annotated metric data.

    One pairwise table per setting is written to ``res/stb_pairwise_*.tex``
    and all estimates are dumped to ``res/spot_the_bot_estimates.json``.
    """
    import os

    annotated = load_annotated()['convai2']
    additional = load_additional()['convai2']

    # Create the output directory up front so that a finished estimation run
    # cannot fail at the very end because ``res/`` does not exist.
    os.makedirs("res", exist_ok=True)

    metric_name = "usr_ret"
    bots = [
        "bert_rank",
        "huggingface",
        "kvmemnn",
        "lost_in_conversation",
        "suckybot",
        "model",
    ]
    short_names = {
        "bert_rank": "BR",
        "huggingface": "HF",
        "kvmemnn": "KV",
        "lost_in_conversation": "LiC",
        "suckybot": "S2S",
        "model": "BL",
    }

    # Approximation sizes shared by all three estimation settings.
    approx_sizes = {
        "approx_alpha": 2000,
        "approx_rho": 1000,
        "approx_eta": 1000,
    }

    result = {}
    for bot in bots:
        human = annotated['human'][bot]
        metric = annotated['metric'][metric_name][bot]
        paired = human.pair(metric)
        extra = additional[metric_name][bot]

        result[bot] = {
            'oracle_only': estimate_alpha(
                human_data=human,
                machine_data=None,
                rho_eta_data=None,
                **approx_sizes,
            ),
            'matched': estimate_alpha(
                human_data=human,
                machine_data=metric,
                rho_eta_data=paired,
                **approx_sizes,
            ),
            "additional": estimate_alpha(
                human_data=human,
                machine_data=[metric, extra],
                rho_eta_data=paired,
                **approx_sizes,
            ),
        }

    # (estimation setting, output file) for each pairwise table.
    tables = (
        ("oracle_only", "res/stb_pairwise_oracle.tex"),
        ("matched", "res/stb_pairwise_matched.tex"),
        ("additional", "res/stb_pairwise_additional.tex"),
    )
    for setting, path in tables:
        per_bot = {bot: estimates[setting] for bot, estimates in result.items()}
        with open(path, 'w', encoding='utf-8') as fout:
            fout.write(f"{render_latex(per_bot, short_names)}\n")

    serialised = {
        bot: {
            exp: app.json()
            for exp, app in estimates.items()
        }
        for bot, estimates in result.items()
    }
    with open('res/spot_the_bot_estimates.json', 'w', encoding='utf-8') as fout:
        json.dump(fp=fout, obj=serialised, indent=2)


# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
