
from dataclasses import dataclass
from typing import List, Union, Optional
from itertools import product

import stan
from sklearn.metrics import roc_curve
import scipy.stats as stats
import numpy as np
from matplotlib import pyplot as plt


from datasets.base import BinaryPaired, BinaryHuman, MetricScores


@dataclass
class RhoEtaResult:
    """Outcome of calibrating a binary threshold for a scalar metric.

    Produced by ``estimate_rho_eta``. ``rho`` is the fraction of
    human-positive samples that the thresholded metric also labels positive.
    ``eta`` is the fraction of human-negative samples that it labels negative.
    A sample counts as metric-positive when ``score >= threshold``.
    """
    # name of the metric the scores came from
    metric: str
    # name of the dataset the paired judgements came from
    dataset: str
    # rho_pos / rho_total
    rho: float
    # eta_pos / eta_total
    eta: float
    # metric scores >= threshold are treated as positive
    threshold: float
    # human-positive samples that the metric also marks positive
    rho_pos: int
    # number of human-positive samples
    rho_total: int
    # human-negative samples that the metric also marks negative
    eta_pos: int
    # number of human-negative samples
    eta_total: int
    # human binary labels used for the calibration
    binary_human: np.ndarray
    # raw scalar metric scores, aligned with binary_human
    scalar_metric: np.ndarray
    # scalar_metric >= threshold
    binary_metric: np.ndarray


def estimate_rho_eta(
    paired: Union[BinaryPaired, List[BinaryPaired]],
    selection_strategy: str = "max_sum",
) -> RhoEtaResult:
    """Calibrate a threshold on scalar metric scores against binary human labels.

    The ROC curve of the metric scores w.r.t. the human labels is computed and
    one operating point on it is selected:

    * ``"max_sum"``  -- maximise ``rho + eta`` (Youden's J statistic),
    * ``"min_diff"`` -- minimise ``|rho - eta|``.

    ``rho`` is the true-positive rate and ``eta`` the true-negative rate of
    the thresholded metric (``score >= threshold``) at that point.

    Args:
        paired: one paired dataset, or a list of paired datasets that all share
            the same metric and dataset; a list is concatenated.
        selection_strategy: ``"max_sum"`` or ``"min_diff"``.

    Returns:
        RhoEtaResult holding the threshold and the counts behind rho and eta.

    Raises:
        ValueError: on an empty list, mixed metrics/datasets, an unknown
            selection strategy, or human labels that lack either class.
    """
    if isinstance(paired, list):
        if not paired:
            raise ValueError("tried estimating rho/eta from an empty list")
        if len({p.metric for p in paired}) > 1:
            raise ValueError(
                "tried estimating rho/eta with data from different metrics")
        if len({p.dataset for p in paired}) > 1:
            raise ValueError(
                "tried estimating rho/eta with data from different datasets")

        # builtin bool: the np.bool alias was removed in NumPy 1.24
        y_true = np.concatenate([
            np.asarray(pair.human_binary, dtype=bool) for pair in paired
        ])
        y_pred = np.concatenate([np.asarray(pair.scores) for pair in paired])
        metric = paired[0].metric
        dataset = paired[0].dataset
    else:
        # must be a boolean mask: integer 0/1 labels would act as indices in
        # ``pred_labels[y_true]`` and ``~`` would be a bitwise not
        y_true = np.asarray(paired.human_binary, dtype=bool)
        y_pred = np.asarray(paired.scores)
        metric = paired.metric
        dataset = paired.dataset

    # rho/eta are ratios over the positive/negative human labels respectively,
    # so both classes must be present (also rejects empty input)
    if y_true.all() or not y_true.any():
        raise ValueError(
            "need both positive and negative human labels to estimate rho/eta")

    fpr, tpr, thresholds = roc_curve(
        y_true=y_true,
        y_score=y_pred,
    )
    rhos = tpr
    etas = 1. - fpr

    if selection_strategy == 'max_sum':
        selected = (rhos + etas).argmax()
    elif selection_strategy == 'min_diff':
        selected = np.abs(rhos - etas).argmin()
    else:
        raise ValueError(
            f"unknown selection strategy: '{selection_strategy}', "
            f"use one of ['max_sum', 'min_diff']")
    threshold = thresholds[selected]

    pred_labels = y_pred >= threshold

    rho_pos = np.sum(pred_labels[y_true])
    rho_total = np.sum(y_true)
    eta_pos = np.sum(~pred_labels[~y_true])
    eta_total = np.sum(~y_true)

    # sanity check: recounting at the chosen threshold reproduces the ROC point
    assert np.allclose(rho_pos / rho_total, rhos[selected])
    assert np.allclose(eta_pos / eta_total, etas[selected])

    return RhoEtaResult(
        metric=metric,
        dataset=dataset,
        rho=rho_pos / rho_total,
        eta=eta_pos / eta_total,
        threshold=threshold,
        rho_pos=rho_pos,
        rho_total=rho_total,
        eta_pos=eta_pos,
        eta_total=eta_total,
        binary_human=y_true,
        scalar_metric=y_pred,
        binary_metric=pred_labels,
    )


def estimate_alpha(
    human_data: Optional[Union[BinaryHuman, List[BinaryHuman]]] = None,
    machine_data: Optional[Union[MetricScores, List[MetricScores]]] = None,
    rho_eta_data: Optional[Union[RhoEtaResult, BinaryPaired, List[BinaryPaired]]] = None,
    estimator: str = "naive_approximation",
    n_approx: int = 2000,
    rho_eta_approx: int = 1000,
    selection_strategy: str = "max_sum",
):
    """Estimate the distribution of alpha from human and/or machine judgements.

    Human data contributes its positive/total label counts directly. Machine
    scores are binarised at the threshold of a ``RhoEtaResult``. That result
    is computed with ``estimate_rho_eta`` unless one is passed in. The
    resulting counts are combined with the result's rho/eta counts by the
    chosen estimator.

    Args:
        human_data: binary human labels (one object or a list of them).
        machine_data: scalar metric scores (one object or a list of them).
        rho_eta_data: a precomputed ``RhoEtaResult``, or paired data to
            compute one from; required whenever ``machine_data`` is given.
        estimator: ``"naive_approximation"``, ``"simulation"`` or
            ``"full_approximation"``.
        n_approx: resolution of the alpha grid / histogram.
        rho_eta_approx: rho/eta grid resolution; only used by
            ``"full_approximation"``.
        selection_strategy: threshold selection passed to ``estimate_rho_eta``
            when ``rho_eta_data`` is not already a ``RhoEtaResult``.

    Returns:
        The estimator's output: a normalised probability vector of length
        ``n_approx`` over alpha in [0, 1].

    Raises:
        ValueError: if neither data source is given, the estimator is unknown,
            or ``rho_eta_data`` is missing while ``machine_data`` is given.
    """
    if human_data is None and machine_data is None:
        raise ValueError(
            "need to provide either 'human_data' or 'machine_data'"
            " for alpha estimation")

    if estimator == "simulation":
        est = SimulationAlphaEstimator(
            n_bins=n_approx,
            return_fit=False,
        )
    elif estimator == "naive_approximation":
        est = NaiveApproximationAlphaEstimator(
            n_approx=n_approx,
        )
    elif estimator == "full_approximation":
        est = FullApproximationAlphaEstimator(
            alpha_approx=n_approx,
            rho_eta_approx=rho_eta_approx,
        )
    else:
        raise ValueError(f"unknown estimation strategy '{estimator}',"
                         f" use one of "
                         f"['naive_approximation', 'simulation', 'full_approximation']")

    if human_data is None:
        human_pos = None
        human_total = None
    else:
        humans = human_data if isinstance(human_data, list) else [human_data]
        # plain Python ints, independent of the label array's dtype
        human_pos = int(sum(h.binary_scores.sum() for h in humans))
        human_total = int(sum(h.binary_scores.shape[0] for h in humans))

    estimator_args = {
        'human_pos': human_pos,
        'human_total': human_total,
    }

    if machine_data is None:
        return est.human_only(**estimator_args)

    if rho_eta_data is None:
        raise ValueError(
            "you need to provide 'rho_eta_data' to estimate"
            " alpha from machine data")

    if isinstance(rho_eta_data, RhoEtaResult):
        rho_eta_res = rho_eta_data
    else:
        rho_eta_res = estimate_rho_eta(
            paired=rho_eta_data,
            selection_strategy=selection_strategy,
        )

    estimator_args["rho_pos"] = int(rho_eta_res.rho_pos)
    estimator_args["rho_total"] = int(rho_eta_res.rho_total)
    estimator_args["eta_pos"] = int(rho_eta_res.eta_pos)
    estimator_args["eta_total"] = int(rho_eta_res.eta_total)

    if isinstance(machine_data, list):
        machine_scores = np.array([
            score
            for ms in machine_data
            for score in ms.scores
        ])
    else:
        machine_scores = np.asarray(machine_data.scores)

    # same ">=" convention as the threshold calibration in estimate_rho_eta
    estimator_args["n_pos"] = int((machine_scores >= rho_eta_res.threshold).sum())
    estimator_args["n_total"] = int(machine_scores.shape[0])

    if human_data is None:
        return est.machine_only(**estimator_args)
    return est.human_machine_both(**estimator_args)


def plot_alpha_hist(
    alphas: np.ndarray,
    name: str,
    ax: Optional[plt.Axes] = None,
):
    """Draw a probability vector over alpha as a bar histogram on [0, 1].

    Args:
        alphas: probabilities for ``len(alphas)`` equal-width bins covering
            [0, 1], e.g. the output of ``estimate_alpha``.
        name: legend label for the bars.
        ax: axes to draw on; the current pyplot axes are used when omitted.
    """
    n_bins = alphas.shape[0]
    width = 1. / n_bins
    # left edge of each bin; with align="edge" bar i spans
    # [i / n_bins, (i + 1) / n_bins] instead of being centred on its left edge
    lefts = np.arange(n_bins) / n_bins

    target = plt if ax is None else ax
    target.bar(
        x=lefts,
        height=alphas,
        width=width,
        label=name,
        alpha=0.6,
        align="edge",
    )


class SimulationAlphaEstimator:
    """Estimate the posterior of alpha by sampling Stan models.

    ``alpha`` is the success probability of the human labels. In the machine
    models the observed positive rate is
    ``alpha * rho + (1 - alpha) * (1 - eta)``. Here ``rho`` and ``eta`` have
    Beta(pos + 1, total - pos + 1) priors built from their calibration counts.
    Results are either the raw fit or a normalised histogram of the alpha
    draws over ``n_bins`` equal-width bins on [0, 1].
    """

    HUMAN_ONLY = """
    data {
        int<lower=1> n_total;
        int<lower=0> n_pos;
    }
    parameters {
        real<lower=0, upper=1> alpha;
    }
    model {
        alpha ~ uniform(0, 1);
        n_pos ~ binomial(n_total, alpha);
    }
    """

    MACHINE_ONLY = """
    data {
        int<lower=1> n_total;
        int<lower=0> n_pos;
        real<lower=0> rho_a;
        real<lower=0> rho_b;
        real<lower=0> eta_a;
        real<lower=0> eta_b;
    }
    parameters {
        real<lower=0, upper=1> alpha;
        real<lower=0, upper=1> rho;
        real<lower=0, upper=1> eta;
    }
    transformed parameters {
        real<lower=0, upper=1> obs_alpha = (alpha * (rho + eta - 1)) + (1 - eta);
    }
    model {
        alpha ~ uniform(0, 1);
        rho ~ beta(rho_a + 1, rho_b + 1);
        eta ~ beta(eta_a + 1, eta_b + 1);
        n_pos ~ binomial(n_total, obs_alpha);
    }
    """

    HUMAN_MACHINE = """
    data {
        int<lower=1> human_total;
        int<lower=0> human_pos;
        int<lower=1> n_total;
        int<lower=0> n_pos;
        real<lower=0> rho_a;
        real<lower=0> rho_b;
        real<lower=0> eta_a;
        real<lower=0> eta_b;
    }
    parameters {
        real<lower=0, upper=1> alpha;
        real<lower=0, upper=1> rho;
        real<lower=0, upper=1> eta;
    }
    transformed parameters {
        real<lower=0, upper=1> obs_alpha = (alpha * (rho + eta - 1)) + (1 - eta);
    }
    model {
        alpha ~ uniform(0, 1);
        rho ~ beta(rho_a + 1, rho_b + 1);
        eta ~ beta(eta_a + 1, eta_b + 1);
        n_pos ~ binomial(n_total, obs_alpha);
        human_pos ~ binomial(human_total, alpha);
    }
    """

    def __init__(
        self,
        num_chains: int = 10,
        num_samples: int = 10000,
        seed: int = 0xdead,
        n_bins: int = 2000,
        return_fit: bool = False
    ):
        """
        Args:
            num_chains: number of chains passed to ``posterior.sample``.
            num_samples: draws per chain passed to ``posterior.sample``.
            seed: ``random_seed`` passed to ``stan.build``.
            n_bins: number of histogram bins on [0, 1] for the alpha draws.
            return_fit: return the raw fit instead of the histogram.
        """
        self.num_chains = num_chains
        self.num_samples = num_samples
        self.seed = seed
        self.n_bins = n_bins
        self.return_fit = return_fit

    @staticmethod
    def _rho_eta_data(
        rho_pos: int,
        rho_total: int,
        eta_pos: int,
        eta_total: int,
    ) -> dict:
        """Beta pseudo-counts (successes ``*_a``, failures ``*_b``) for rho and eta."""
        return {
            "rho_a": rho_pos,
            "rho_b": rho_total - rho_pos,
            # eta's successes are eta_pos (not eta_total), so that
            # Beta(eta_a + 1, eta_b + 1) is centred on eta_pos / eta_total
            "eta_a": eta_pos,
            "eta_b": eta_total - eta_pos,
        }

    def __run_sim(self, program: str, data: dict):
        """Build ``program`` with ``data``, sample it and summarise alpha.

        Returns the fit when ``return_fit`` is set. Otherwise returns the
        alpha draws binned into ``n_bins`` equal-width bins on [0, 1],
        normalised to sum to one.
        """
        posterior = stan.build(
            program_code=program,
            data=data,
            random_seed=self.seed,
        )
        fit = posterior.sample(
            num_chains=self.num_chains, num_samples=self.num_samples)

        if self.return_fit:
            return fit

        # flatten all draws of the scalar alpha; NOTE(review): this works
        # whether the fit returns shape (1, n_draws) or (n_draws,)
        draws = np.asarray(fit['alpha']).reshape(-1)
        hist, _ = np.histogram(
            a=draws,
            bins=self.n_bins,
            range=(0., 1.),
            density=False,
        )
        return hist / hist.sum()

    def human_only(self, human_pos: int, human_total: int, **kwargs):
        """Posterior of alpha from human label counts alone."""
        data = {
            "n_pos": human_pos,
            "n_total": human_total,
        }
        return self.__run_sim(
            program=SimulationAlphaEstimator.HUMAN_ONLY, data=data)

    def machine_only(
        self,
        n_pos: int,
        n_total: int,
        rho_pos: int,
        rho_total: int,
        eta_pos: int,
        eta_total: int,
        **kwargs,
    ):
        """Posterior of alpha from thresholded machine counts, with uncertain rho/eta."""
        data = {
            "n_pos": n_pos,
            "n_total": n_total,
            **self._rho_eta_data(rho_pos, rho_total, eta_pos, eta_total),
        }
        return self.__run_sim(
            program=SimulationAlphaEstimator.MACHINE_ONLY, data=data)

    def human_machine_both(
        self,
        n_pos: int,
        n_total: int,
        rho_pos: int,
        rho_total: int,
        eta_pos: int,
        eta_total: int,
        human_pos: int,
        human_total: int,
        **kwargs,
    ):
        """Posterior of alpha combining human counts and thresholded machine counts."""
        data = {
            "n_pos": n_pos,
            "n_total": n_total,
            **self._rho_eta_data(rho_pos, rho_total, eta_pos, eta_total),
            "human_pos": human_pos,
            "human_total": human_total,
        }
        return self.__run_sim(
            program=SimulationAlphaEstimator.HUMAN_MACHINE, data=data)


class NaiveApproximationAlphaEstimator:
    """Grid approximation of the alpha posterior with rho/eta as point estimates.

    alpha is evaluated at the ``n_approx`` midpoints of equal-width bins on
    [0, 1] under a uniform prior. rho and eta are fixed at their empirical
    ratios ``pos / total`` rather than integrated over.
    """

    def __init__(self, n_approx: int = 2000):
        self.n_approx = n_approx

        # bin midpoints of an even grid on [0, 1]
        self.alphas = (np.arange(self.n_approx) + .5) / self.n_approx
        # uniform prior; it cancels in the normalisation and is kept for reference
        self.alpha_prior = np.ones(self.n_approx) / self.n_approx
        self.alpha_log_prior = (-np.log(self.n_approx)) * np.ones(self.n_approx)

    @staticmethod
    def softmax(ls):
        """Normalise log-weights ``ls`` into probabilities that sum to one."""
        ls = np.asarray(ls, dtype=float)
        # shifting by the max does not change the result, but keeps exp() from
        # underflowing every entry to 0 (and dividing 0 by 0) when all
        # log-likelihoods are very negative, as happens for large counts
        nums = np.exp(ls - np.max(ls))
        return nums / np.sum(nums)

    def _machine_log_likelihood(
            self,
            n_pos: int,
            n_total: int,
            rho_pos: int,
            rho_total: int,
            eta_pos: int,
            eta_total: int,
    ):
        """log P(n_pos | alpha) at each grid alpha, with rho/eta at point estimates."""
        rho = rho_pos / rho_total
        eta = eta_pos / eta_total

        # equals alpha * rho + (1 - alpha) * (1 - eta): detected human positives
        # plus human negatives the metric wrongly marks positive
        alpha_obs = (self.alphas * (rho + eta - 1.)) + (1. - eta)
        return stats.binom.logpmf(k=n_pos, n=n_total, p=alpha_obs)

    def human_only(self, human_pos: int, human_total: int, **kwargs):
        """Posterior over the alpha grid from human label counts alone."""
        logits = stats.binom.logpmf(k=human_pos, n=human_total, p=self.alphas)
        # uniform prior cancels, so no prior term is added
        return self.softmax(logits)

    def machine_only(
            self,
            n_pos: int,
            n_total: int,
            rho_pos: int,
            rho_total: int,
            eta_pos: int,
            eta_total: int,
            **kwargs,
    ):
        """Posterior over the alpha grid from thresholded machine counts."""
        logits = self._machine_log_likelihood(
            n_pos, n_total, rho_pos, rho_total, eta_pos, eta_total)
        # uniform prior cancels, so no prior term is added
        return self.softmax(logits)

    def human_machine_both(
            self,
            n_pos: int,
            n_total: int,
            rho_pos: int,
            rho_total: int,
            eta_pos: int,
            eta_total: int,
            human_pos: int,
            human_total: int,
            **kwargs,
    ):
        """Posterior over the alpha grid combining human and machine counts."""
        logits = stats.binom.logpmf(k=human_pos, n=human_total, p=self.alphas)
        logits += self._machine_log_likelihood(
            n_pos, n_total, rho_pos, rho_total, eta_pos, eta_total)
        # uniform prior cancels, so no prior term is added
        return self.softmax(logits)


class FullApproximationAlphaEstimator:
    """Grid approximation of the alpha posterior that marginalises over rho and eta.

    rho and eta are discretised into ``rho_eta_approx`` equal-width bins on
    [0, 1]. Each bin is weighted by its Beta(pos + 1, total - pos + 1)
    probability mass. alpha is evaluated at ``alpha_approx`` bin midpoints
    under a uniform prior.
    """

    def __init__(self, alpha_approx=2000, rho_eta_approx=1000):
        self.alpha_approx = alpha_approx
        self.rho_eta_approx = rho_eta_approx

        # bin midpoints of even grids on [0, 1]
        self.alphas = (np.arange(self.alpha_approx) + .5) / self.alpha_approx

        self.rhos = (np.arange(self.rho_eta_approx) + .5) / self.rho_eta_approx
        self.etas = (np.arange(self.rho_eta_approx) + .5) / self.rho_eta_approx

        # left/right edges of the rho/eta bins
        self.rho_eta_lower = np.arange(self.rho_eta_approx) / self.rho_eta_approx
        self.rho_eta_upper = (np.arange(self.rho_eta_approx) + 1.) / self.rho_eta_approx

    def _beta_bin_probs(self, pos, total):
        """Mass of Beta(pos + 1, total - pos + 1) falling in each rho/eta bin."""
        a = pos + 1
        b = total - pos + 1
        return (stats.beta.cdf(x=self.rho_eta_upper, a=a, b=b)
                - stats.beta.cdf(x=self.rho_eta_lower, a=a, b=b))

    @staticmethod
    def __p_n_machine_marginalized(
        n_pos: int,
        n_total: int,
        rhos: np.ndarray,
        p_rhos: np.ndarray,
        etas: np.ndarray,
        p_etas: np.ndarray,
        alphas: np.ndarray,
        batch_size: int = 2 ** 15,
    ):
        """Normalised sum over (rho, eta) of P(rho) P(eta) Binom(n_pos | n_total, alpha_obs).

        ``alpha_obs = alpha * (rho + eta - 1) + (1 - eta)`` is evaluated for
        every grid alpha. The (rho, eta) pairs are processed in chunks of
        ``batch_size`` to bound the size of the ``len(alphas) x batch``
        intermediate.
        """
        acc = np.zeros_like(alphas)

        # every (rho, eta) pair in row-major order -- the same order as
        # itertools.product(rhos, etas) and np.kron(p_rhos, p_etas) -- built
        # with numpy instead of materialising len(rhos) * len(etas) tuples
        comb_rhos = np.repeat(rhos, len(etas))
        comb_etas = np.tile(etas, len(rhos))
        p_comb = np.kron(p_rhos, p_etas)

        mult = comb_rhos + comb_etas - 1.
        adds = 1. - comb_etas

        for batch_start in range(0, comb_rhos.shape[0], batch_size):
            batch = slice(batch_start, batch_start + batch_size)
            mb = mult[np.newaxis, batch]
            ma = adds[np.newaxis, batch]

            alpha_obs = (alphas[:, np.newaxis] * mb) + ma
            ps = stats.binom.pmf(k=n_pos, n=n_total, p=alpha_obs)

            acc += ps @ p_comb[batch]

        return acc / acc.sum()

    def human_only(self, human_pos: int, human_total: int, **kwargs):
        """Posterior over the alpha grid from human label counts alone."""
        un_normed = stats.binom.pmf(k=human_pos, n=human_total, p=self.alphas)
        return un_normed / un_normed.sum()

    def machine_only(
            self,
            n_pos: int,
            n_total: int,
            rho_pos: int,
            rho_total: int,
            eta_pos: int,
            eta_total: int,
            **kwargs,
    ):
        """Posterior over the alpha grid from machine counts, rho/eta marginalised."""
        p_rhos = self._beta_bin_probs(rho_pos, rho_total)
        p_etas = self._beta_bin_probs(eta_pos, eta_total)
        # the bins tile [0, 1], so each Beta's mass should sum to one
        assert np.allclose(p_rhos.sum(), 1.)
        assert np.allclose(p_etas.sum(), 1.)
        return FullApproximationAlphaEstimator.__p_n_machine_marginalized(
            n_pos=n_pos,
            n_total=n_total,
            rhos=self.rhos,
            p_rhos=p_rhos,
            etas=self.etas,
            p_etas=p_etas,
            alphas=self.alphas,
        )

    def human_machine_both(
            self,
            n_pos: int,
            n_total: int,
            rho_pos: int,
            rho_total: int,
            eta_pos: int,
            eta_total: int,
            human_pos: int,
            human_total: int,
            **kwargs,
    ):
        """Posterior over the alpha grid combining human and machine counts."""
        human = self.human_only(human_pos=human_pos, human_total=human_total)
        machine = self.machine_only(
            n_pos=n_pos,
            n_total=n_total,
            rho_pos=rho_pos,
            rho_total=rho_total,
            eta_pos=eta_pos,
            eta_total=eta_total,
        )
        # independent likelihoods under a uniform prior multiply pointwise
        res = human * machine
        return res / res.sum()
