
import argparse
from pathlib import Path

from matplotlib import pyplot as plt

import numpy as np

from datasets.wmt21 import load_human, load_metric
from datasets.base import MetricScores
from binary.estimation import estimate_alpha, plot_alpha_hist


def load_additional():
    """Read the additional metric score files shipped under ``data/wmt21_additional``.

    Each file is read as plain text with one float per line.

    Returns:
        A two-level dict. The outer key is the origin of the additional
        scores ("wmt19_val" or "wmt19_train"), the inner key is the metric
        name, and each value is a ``MetricScores`` tagged with dataset
        "wmt21" and system "Facebook-AI".
    """
    data_dir = Path(__file__).parents[1] / "data" / "wmt21_additional"

    # File name (relative to data_dir) for every (origin, metric) combination.
    file_names = {
        "wmt19_val": {
            "COMET-DA_2020-ref-C": "comet_wmt19_validation_scores",
            "bleurt-20-ref-C": "bleurt_wmt19_validation_scores",
            "sentBLEU-ref-C": "sentBLEU_wmt19_validation_scores",
        },
        "wmt19_train": {
            "COMET-DA_2020-ref-C": "comet_wmt19_scores",
            "bleurt-20-ref-C": "bleurt_wmt19_scores",
            "sentBLEU-ref-C": "sentBLEU_wmt19_scores",
        },
    }

    loaded = {}
    for origin, per_metric_files in file_names.items():
        per_metric_scores = {}
        for metric_name, file_name in per_metric_files.items():
            with (data_dir / file_name).open('r') as handle:
                values = np.array([float(row.strip()) for row in handle])

            per_metric_scores[metric_name] = MetricScores(
                metric=metric_name,
                dataset="wmt21",
                system="Facebook-AI",
                scores=values,
            )
        loaded[origin] = per_metric_scores

    return loaded


def main(est: str):
    """Estimate alpha histograms for the Facebook-AI system and plot them.

    For every additional-score origin returned by ``load_additional`` and
    every metric in the fixed metric list, four ``estimate_alpha`` results
    are computed (human only, machine only, both, both plus the additional
    machine scores) and drawn onto one figure, which is written to
    ``plots/alphas/wmt21/<add_type>/<est>/<metric>_<system>.png`` relative
    to the current working directory.

    Args:
        est: Short estimator key; one of 'naive', 'full' or 'sim'.

    Raises:
        ValueError: If ``est`` is not one of the supported keys.
    """
    # Translate the short CLI key into the name passed to estimate_alpha.
    if est == 'naive':
        estimator = "naive_approximation"
    elif est == 'full':
        estimator = "full_approximation"
    elif est == 'sim':
        estimator = "simulation"
    else:
        raise ValueError(
            f"unknown estimator '{est}', use one of ['naive', 'full', 'sim']")

    mt_system = 'Facebook-AI'
    metric_names = [
        "COMET-DA_2020-ref-C",
        "bleurt-20-ref-C",
        "sentBLEU-ref-C",
    ]

    human = load_human()[mt_system]
    full_metric = load_metric()
    # Keep only the scores of the selected system for each metric.
    metrics = {
        m: full_metric[m][mt_system]
        for m in metric_names
    }
    additional = load_additional()

    for add_type, add_data in additional.items():
        # savefig does not create missing directories, so make sure the
        # target folder exists before any figure is written.
        out_dir = Path("plots") / "alphas" / "wmt21" / add_type / est
        out_dir.mkdir(parents=True, exist_ok=True)

        for metric_name, metric_scores in metrics.items():

            adds = add_data[metric_name]

            paired = human.pair(metric_scores)

            # NOTE(review): the first three estimates do not depend on
            # add_type; they are kept inside the loop to preserve the
            # original call order (relevant if the estimator is stochastic).
            human_hist = estimate_alpha(human_data=human, estimator=estimator)
            machine_hist = estimate_alpha(
                machine_data=metric_scores, rho_eta_data=paired, estimator=estimator)
            combined_hist = estimate_alpha(
                human_data=human,
                machine_data=metric_scores,
                rho_eta_data=paired,
                estimator=estimator,
            )
            with_additional = estimate_alpha(
                human_data=human,
                machine_data=[metric_scores, adds],
                rho_eta_data=paired,
                estimator=estimator,
            )

            fig, ax = plt.subplots()
            fig.set_size_inches(10, 10)
            ax.set_title(f"{metric_name} {mt_system} -- adds from: {add_type}")

            plot_alpha_hist(alphas=human_hist, name="human", ax=ax)
            plot_alpha_hist(alphas=machine_hist, name="machine", ax=ax)
            plot_alpha_hist(alphas=combined_hist, name="both", ax=ax)
            plot_alpha_hist(
                alphas=with_additional, name="with_additional", ax=ax)

            ax.legend()

            # Save through the figure object so the output does not depend
            # on pyplot's notion of the "current" figure.
            fig.savefig(out_dir / f"{metric_name}_{mt_system}.png")

            # Release the figure so open figures do not accumulate.
            plt.close(fig)


if __name__ == "__main__":
    # Command-line entry point: pick the estimation method and run main().
    parser = argparse.ArgumentParser()
    # The option is optional so that the declared default actually applies;
    # argparse ignores `default` on options marked required.
    parser.add_argument(
        '-e',
        '--estimator',
        dest='estimator',
        type=str,
        choices=['naive', 'full', 'sim'],
        default="naive",
        help="estimation method: ['naive', 'full', 'sim']",
    )
    args = parser.parse_args()
    main(est=args.estimator)
