
import argparse
from pathlib import Path
from typing import Dict, List

from matplotlib import pyplot as plt

from datasets.spot_the_bot import load_annotated, load_additional
from datasets.base import BinaryHuman, BinaryPaired, MetricScores
from binary.estimation import estimate_alpha, plot_alpha_hist


def leave_out_pairing(
        leave_out: str,
        humans: Dict[str, BinaryHuman],
        metrics: Dict[str, MetricScores],
) -> List[BinaryPaired]:
    """Pair each bot's human annotations with its metric scores, except one bot.

    Args:
        leave_out: Name of the bot to exclude from the pairing.
        humans: Human annotations keyed by bot name.
        metrics: Metric scores keyed by bot name.

    Returns:
        ``humans[name].pair(metrics[name])`` for every bot ``name`` in
        ``humans`` (in ``humans`` iteration order) that is not ``leave_out``
        and whose ``metrics`` entry exists and is not ``None``.
    """
    return [
        annotation.pair(bot_scores)
        for name, annotation in humans.items()
        # Check the exclusion first so the left-out bot is never looked up.
        if name != leave_out
        and (bot_scores := metrics.get(name)) is not None
    ]


def _plot_alpha_histograms(title, hists, save_path):
    """Overlay named alpha histograms on one 10x10 inch figure and save it.

    Args:
        title: Title for the axes.
        hists: Sequence of ``(name, alphas)`` pairs; each is drawn with
            ``plot_alpha_hist`` and labelled ``name`` in the legend.
        save_path: Destination passed to ``Figure.savefig``.
    """
    fig, ax = plt.subplots()
    try:
        fig.set_size_inches(10, 10)
        ax.set_title(title)
        for name, alphas in hists:
            plot_alpha_hist(alphas=alphas, name=name, ax=ax)
        ax.legend()
        # Save this specific figure rather than pyplot's "current" figure.
        fig.savefig(save_path)
    finally:
        # Release the figure even if plotting fails, so the per-(metric, bot)
        # loop in main() does not accumulate open figures.
        plt.close(fig)


def main(est: str, external=False):
    """Plot alpha-estimate histograms for every (metric, bot) pair in ConvAI2.

    For each metric in the annotated ConvAI2 data and each bot that also has
    human annotations, ``estimate_alpha`` is run on the human data alone, on
    the metric scores alone, on both, and (when additional scores exist for
    that metric/bot) on both plus the additional scores. The resulting
    histograms are overlaid in one figure saved to
    ``plots/alphas/stb/convai2/<est>[_ext]/<metric>_<bot>.png``.

    Args:
        est: Short estimator name: ``'naive'``, ``'full'`` or ``'sim'``.
        external: If truthy, the rho/eta data for a bot is built from the
            pairings of all *other* bots (via ``leave_out_pairing``) instead
            of the bot's own pairing, and outputs go to ``<est>_ext``.

    Raises:
        ValueError: If ``est`` is not one of the known estimator names.
    """
    estimators = {
        'naive': "naive_approximation",
        'full': "full_approximation",
        'sim': "simulation",
    }
    if est not in estimators:
        raise ValueError(
            f"unknown estimator '{est}', use one of ['naive', 'full', 'sim']")
    estimator = estimators[est]

    annotated = load_annotated()['convai2']
    additional = load_additional()['convai2']

    # savefig does not create parent directories; make sure the target exists.
    out_dir = Path("plots/alphas/stb/convai2") / (
        f"{est}_ext" if external else est)
    out_dir.mkdir(parents=True, exist_ok=True)

    for metric_name, metric_data in annotated['metric'].items():
        for bot_name, scores in metric_data.items():
            human = annotated['human'].get(bot_name)
            if human is None:
                continue

            adds = additional.get(metric_name, {}).get(bot_name)

            if external:
                # rho/eta data comes from every other bot's pairing for this
                # metric; the current bot's own pairing is excluded.
                paired = leave_out_pairing(
                    leave_out=bot_name,
                    humans=annotated['human'],
                    metrics=metric_data,
                )
            else:
                paired = human.pair(scores)

            # Keep the original call order (human, machine, both, additional)
            # in case the estimator consumes shared random state.
            hists = [
                ("human",
                 estimate_alpha(human_data=human, estimator=estimator)),
                ("machine",
                 estimate_alpha(
                     machine_data=scores,
                     rho_eta_data=paired,
                     estimator=estimator,
                 )),
                ("both",
                 estimate_alpha(
                     human_data=human,
                     machine_data=scores,
                     rho_eta_data=paired,
                     estimator=estimator,
                 )),
            ]
            if adds is not None:
                hists.append((
                    "with_additional",
                    estimate_alpha(
                        human_data=human,
                        machine_data=[scores, adds],
                        rho_eta_data=paired,
                        estimator=estimator,
                    ),
                ))

            if external:
                title = f"convai2 {metric_name} {bot_name} -- external rho/eta"
            else:
                title = f"convai2 {metric_name} {bot_name}"

            _plot_alpha_histograms(
                title=title,
                hists=hists,
                save_path=out_dir / f"{metric_name}_{bot_name}.png",
            )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Plot ConvAI2 alpha-estimate histograms.")
    # Previously this was required=True *and* had a default, which made the
    # default unreachable; it is now optional and falls back to "naive".
    parser.add_argument(
        '-e',
        '--estimator',
        dest='estimator',
        type=str,
        choices=['naive', 'full', 'sim'],
        default="naive",
        help="estimation method (default: %(default)s)",
    )
    parser.add_argument(
        "--external",
        action=argparse.BooleanOptionalAction,
        dest="external",
        default=False,
        help="build rho/eta data from all other bots (leave-one-out)",
    )
    args = parser.parse_args()
    main(est=args.estimator, external=args.external)
