
import json
from typing import Optional, Union, List
from dataclasses import dataclass
import subprocess

import numpy as np
import scipy.stats as stats

from binary.estimation import estimate_rho_eta, RhoEtaResult
from datasets.base import BinaryHuman, BinaryPaired, MetricScores


@dataclass
class Approximation:
    """Discrete approximation of a distribution on a finite grid.

    ``values[i]`` is a support point and ``probas[i]`` the probability mass
    attached to it. The moment helpers combine the two arrays with dot
    products and therefore assume both have the same length.
    """
    # Grid resolution; stored next to the arrays and serialised with them.
    n_approx: int
    # Support points of the approximation.
    values: np.ndarray
    # Probability mass attached to each entry of ``values``.
    probas: np.ndarray

    def json(self) -> dict:
        """Return a JSON-serialisable dict with native Python numbers.

        ``ndarray.tolist()`` converts NumPy scalars to built-in ``int`` /
        ``float``; a plain ``list(array)`` would keep NumPy scalar types,
        which ``json.dumps`` rejects for integer or float32 arrays.
        """
        return {
            "n_approx": int(self.n_approx),
            "values": np.asarray(self.values).tolist(),
            "probas": np.asarray(self.probas).tolist(),
        }

    @staticmethod
    def from_json(d: dict) -> 'Approximation':
        """Rebuild an ``Approximation`` from the dict layout produced by ``json``.

        Raises:
            KeyError: if ``n_approx``, ``values`` or ``probas`` is missing.
        """
        return Approximation(
            n_approx=d['n_approx'],
            values=np.array(d['values']),
            probas=np.array(d['probas']),
        )

    def mean(self) -> float:
        """Return the probability-weighted sum of the support points."""
        return self.probas @ self.values

    def var(self) -> float:
        """Return the variance computed as ``E[X^2] - E[X]^2``.

        The subtraction can come out marginally negative through rounding
        when the mass is concentrated on (almost) a single point, which
        would turn ``stdev`` into NaN; such results are clamped to zero.
        """
        m = self.mean()
        sq = (self.values * self.values) @ self.probas
        return max(sq - (m*m), 0.0)

    def stdev(self) -> float:
        """Return the square root of ``var()``."""
        return np.sqrt(self.var())

    def epsilon(self, gamma: float):
        """Return the upper end of a central normal interval.

        The interval has coverage ``1 - gamma`` and belongs to a zero-mean
        normal distribution whose variance is twice ``var()``.
        NOTE(review): the factor 2 presumably models the difference of two
        independent quantities that share this variance — confirm.
        """
        stdev = np.sqrt(2*self.var())
        _, eps = stats.norm.interval(1. - gamma, loc=0., scale=stdev)
        return eps


@dataclass
class BinomialExperiment:
    """Outcome counts of a binomial experiment."""
    # Number of positive outcomes.
    pos: int
    # Total number of trials.
    tot: int

    def json(self) -> dict:
        """Return the two counts as a plain dict keyed by field name."""
        return dict(pos=self.pos, tot=self.tot)

    @staticmethod
    def from_json(d: dict) -> 'BinomialExperiment':
        """Build an instance by passing the dict entries as keyword arguments.

        Raises:
            TypeError: if ``d`` has missing or unexpected keys.
        """
        restored = BinomialExperiment(**d)
        return restored


@dataclass
class PriorInformation:
    """Binomial counts for one parameter together with its grid resolution."""
    # Observed counts that inform the parameter.
    binom: BinomialExperiment
    # Number of grid cells used when the parameter is discretised.
    n_approx: int

    def json(self) -> dict:
        """Return a nested dict: serialised counts plus the grid resolution."""
        payload = {"binom": self.binom.json()}
        payload["n_approx"] = self.n_approx
        return payload

    @staticmethod
    def from_json(d: dict) -> 'PriorInformation':
        """Rebuild an instance from the dict layout produced by ``json``.

        Raises:
            KeyError: if ``binom`` or ``n_approx`` is missing.
        """
        counts = BinomialExperiment.from_json(d['binom'])
        resolution = d['n_approx']
        return PriorInformation(binom=counts, n_approx=resolution)


@dataclass
class Experiment:
    """Complete experiment description: three priors plus the metric counts.

    The JSON form of this object is what ``run_experiment`` sends to the
    external binary.
    """
    rho: PriorInformation
    eta: PriorInformation
    oracle: PriorInformation
    metric: BinomialExperiment

    def json(self) -> dict:
        """Serialise every component under its own field name."""
        field_names = ('rho', 'eta', 'oracle', 'metric')
        return {name: getattr(self, name).json() for name in field_names}

    @staticmethod
    def from_json(d: dict) -> 'Experiment':
        """Rebuild an experiment from the dict layout produced by ``json``.

        Raises:
            KeyError: if any of ``rho``, ``eta``, ``oracle`` or ``metric`` is
                missing.
        """
        priors = {
            key: PriorInformation.from_json(d[key])
            for key in ('rho', 'eta', 'oracle')
        }
        return Experiment(
            metric=BinomialExperiment.from_json(d['metric']),
            **priors,
        )


def approximate_uniform(n_approx: int) -> Approximation:
    """Return a uniform approximation on ``n_approx`` equal cells of [0, 1].

    The support points are the cell midpoints ``(i + 0.5) / n_approx`` and
    every cell receives mass ``1 / n_approx``.
    """
    midpoints = (np.arange(n_approx) + .5) / n_approx
    equal_mass = np.ones(n_approx) / n_approx
    return Approximation(n_approx=n_approx, values=midpoints, probas=equal_mass)


def approximate_beta(beta_a: float, beta_b: float, n_approx: int) -> Approximation:
    """Discretise a Beta(``beta_a``, ``beta_b``) distribution on [0, 1].

    The unit interval is cut into ``n_approx`` equal cells. Each cell is
    represented by its midpoint and carries the Beta probability of the cell,
    computed as the CDF difference between its upper and lower edge.
    """
    cell_index = np.arange(n_approx)
    left_edges = cell_index / n_approx
    right_edges = (cell_index + 1.) / n_approx
    centres = (cell_index + .5) / n_approx

    distribution = stats.beta(a=beta_a, b=beta_b)
    cell_mass = distribution.cdf(right_edges) - distribution.cdf(left_edges)

    return Approximation(n_approx=n_approx, values=centres, probas=cell_mass)


def compare(bigger: Approximation, smaller: Approximation) -> float:
    if bigger.n_approx != smaller.n_approx:
        raise ValueError(
            f"need both inputs to have the same approximation granularity,"
            f" got {bigger.n_approx} and {smaller.n_approx}")

    acc = 0.
    for i in range(smaller.n_approx):
        for j in range(i+1, bigger.n_approx):
            acc += smaller.probas[i]*bigger.probas[j]

    return acc


def build_experiment(
    n_pos: int,
    n_total: int,
    rho_pos: int,
    rho_total: int,
    eta_pos: int,
    eta_total: int,
    human_pos: int,
    human_total: int,
    alpha_approx: int,
    rho_approx: int,
    eta_approx: int,
):
    """Assemble an ``Experiment`` from flat counts and grid resolutions.

    Mapping of arguments to fields:
      * ``rho_pos`` / ``rho_total`` / ``rho_approx`` -> ``Experiment.rho``
      * ``eta_pos`` / ``eta_total`` / ``eta_approx`` -> ``Experiment.eta``
      * ``human_pos`` / ``human_total`` / ``alpha_approx`` -> ``Experiment.oracle``
      * ``n_pos`` / ``n_total`` -> ``Experiment.metric``
    """
    def _prior(pos: int, tot: int, resolution: int) -> PriorInformation:
        # Wrap one pair of counts and its grid size.
        return PriorInformation(
            binom=BinomialExperiment(pos=pos, tot=tot),
            n_approx=resolution,
        )

    rho_prior = _prior(rho_pos, rho_total, rho_approx)
    eta_prior = _prior(eta_pos, eta_total, eta_approx)
    oracle_prior = _prior(human_pos, human_total, alpha_approx)
    metric_counts = BinomialExperiment(pos=n_pos, tot=n_total)

    return Experiment(
        rho=rho_prior,
        eta=eta_prior,
        oracle=oracle_prior,
        metric=metric_counts,
    )


def run_experiment(
    experiment: Experiment,
    binary_path: str = "./bin/toe",
) -> Approximation:
    """Run the external solver on ``experiment`` and parse its answer.

    The experiment is serialised to JSON and written to the solver's stdin.
    The solver's stdout is expected to hold the JSON form of an
    ``Approximation``.

    Args:
        experiment: the experiment description to evaluate.
        binary_path: path of the solver executable. The default is relative
            to the current working directory.

    Returns:
        The ``Approximation`` decoded from the solver's stdout.

    Raises:
        RuntimeError: if the solver exits with a non-zero status; the message
            carries the exit code and the captured stderr.
        FileNotFoundError: if ``binary_path`` does not exist.
        json.JSONDecodeError: if the solver's stdout is not valid JSON.
    """
    send_to_stdin = json.dumps(experiment.json())

    run_res = subprocess.run(
        binary_path,
        input=send_to_stdin,
        capture_output=True,
        text=True,
        encoding='utf-8',
    )

    # Without this check a crashing solver only shows up later as an opaque
    # JSON decoding error on its (usually empty) stdout.
    if run_res.returncode != 0:
        raise RuntimeError(
            f"{binary_path!r} exited with status {run_res.returncode}: "
            f"{run_res.stderr.strip()}"
        )

    res_json = json.loads(run_res.stdout)

    return Approximation.from_json(res_json)


def likelihood_metric(
    pos: int,
    tot: int,
    alpha: float,
    rho: float,
    eta: float,
) -> float:
    """Return the binomial probability of ``pos`` positives in ``tot`` trials.

    The per-trial success probability is
    ``alpha * (rho + eta - 1) + (1 - eta)``.
    NOTE(review): this reads like the rate at which a metric with
    true-positive rate ``rho`` and true-negative rate ``eta`` reports a
    positive when the true positive rate is ``alpha`` — confirm.
    """
    slope = rho + eta - 1.
    offset = 1. - eta
    success_rate = (alpha * slope) + offset
    return stats.binom.pmf(k=pos, n=tot, p=success_rate)


def run_fixed(
    metric: BinomialExperiment,
    alpha: PriorInformation,
    rho: float,
    eta: float,
) -> Approximation:
    """Compute a grid posterior for alpha with ``rho`` and ``eta`` held fixed.

    The prior is a Beta(pos + 1, tot - pos + 1) distribution built from the
    counts in ``alpha`` and discretised on ``alpha.n_approx`` cells. Each
    cell's prior mass is multiplied by the binomial likelihood of the metric
    counts at that cell's midpoint, and the result is normalised to sum to
    one.

    Args:
        metric: observed metric counts (positives and total).
        alpha: counts and grid resolution that define the prior.
        rho: fixed value of rho used in the likelihood.
        eta: fixed value of eta used in the likelihood.

    Returns:
        The normalised posterior on the same grid as the prior.
        If the likelihood is zero on every cell the normaliser is zero and
        the returned masses are NaN.
    """
    prior = approximate_beta(
        beta_a=alpha.binom.pos + 1,
        beta_b=alpha.binom.tot - alpha.binom.pos + 1,
        n_approx=alpha.n_approx,
    )

    # The likelihood is plain arithmetic plus a scipy pmf, both of which
    # broadcast, so the whole grid is evaluated in one call rather than one
    # scalar scipy call per cell.
    likelihood = likelihood_metric(
        pos=metric.pos,
        tot=metric.tot,
        alpha=prior.values,
        rho=rho,
        eta=eta,
    )

    post = prior.probas * likelihood
    post /= post.sum()

    return Approximation(
        n_approx=prior.n_approx,
        values=prior.values,
        probas=post,
    )


def estimate_alpha(
        human_data: Optional[Union[BinaryHuman, List[BinaryHuman]]] = None,
        machine_data: Optional[Union[MetricScores, List[MetricScores]]] = None,
        rho_eta_data: Optional[Union[RhoEtaResult, BinaryPaired, List[BinaryPaired]]] = None,
        approx_alpha: int = 2000,
        approx_rho: int = 2000,
        approx_eta: int = 2000,
) -> Approximation:
    """Collect counts from the supplied data and run the external estimator.

    Every data source is optional; a missing source contributes zero counts.

    Args:
        human_data: one object or a list of objects exposing a
            ``binary_scores`` array; its sum and length give the oracle
            positives and total.
        machine_data: one object or a list of objects exposing ``scores``;
            scores ``>= rho_eta_data.threshold`` count as positive. Requires
            ``rho_eta_data``.
        rho_eta_data: either a ready ``RhoEtaResult`` or raw paired data,
            which is passed through ``estimate_rho_eta`` first.
        approx_alpha: grid resolution for alpha.
        approx_rho: grid resolution for rho.
        approx_eta: grid resolution for eta.

    Returns:
        The ``Approximation`` produced by ``run_experiment``.

    Raises:
        ValueError: if ``machine_data`` is given without ``rho_eta_data``.
    """
    if (machine_data is not None) and (rho_eta_data is None):
        raise ValueError(
            "if you provide 'machine_data' you also have to provide 'rho_eta_data'")

    # --- oracle (human) counts -------------------------------------------
    human_pos, human_total = 0, 0
    if human_data is not None:
        # isinstance (not an exact type check) so list subclasses are
        # treated as collections too.
        human_sets = human_data if isinstance(human_data, list) else [human_data]
        human_pos = sum(h.binary_scores.sum() for h in human_sets)
        human_total = sum(h.binary_scores.shape[0] for h in human_sets)

    # --- rho / eta counts ------------------------------------------------
    rho_pos = rho_total = eta_pos = eta_total = 0
    if rho_eta_data is not None:
        # Anything that is not already a result (subclasses included) is raw
        # paired data and has to be estimated first.
        if not isinstance(rho_eta_data, RhoEtaResult):
            rho_eta_data = estimate_rho_eta(rho_eta_data)

        rho_pos = rho_eta_data.rho_pos
        rho_total = rho_eta_data.rho_total
        eta_pos = rho_eta_data.eta_pos
        eta_total = rho_eta_data.eta_total

    # --- metric counts ---------------------------------------------------
    n_pos, n_total = 0, 0
    if machine_data is not None:
        if isinstance(machine_data, list):
            # Flatten the scores of all sets into one array.
            machine_scores = np.array([
                score
                for ms in machine_data
                for score in ms.scores
            ])
        else:
            # asarray keeps arrays untouched and lets plain sequences work.
            machine_scores = np.asarray(machine_data.scores)

        n_pos = (machine_scores >= rho_eta_data.threshold).sum()
        n_total = machine_scores.shape[0]

    # The counts may be NumPy integers; convert to built-in ints so the
    # experiment can be serialised to JSON.
    experiment = build_experiment(
        n_pos=int(n_pos),
        n_total=int(n_total),
        rho_pos=int(rho_pos),
        rho_total=int(rho_total),
        eta_pos=int(eta_pos),
        eta_total=int(eta_total),
        human_pos=int(human_pos),
        human_total=int(human_total),
        alpha_approx=int(approx_alpha),
        rho_approx=int(approx_rho),
        eta_approx=int(approx_eta),
    )

    return run_experiment(experiment=experiment)


def simulated_experiment(
    alpha: float,
    rho: float,
    eta: float,
    n_oracle: int,
    n_metric: int,
    n_rho_eta: int,
    rho_approx: int,
    eta_approx: int,
    alpha_approx: int,
) -> Experiment:
    """Build an ``Experiment`` whose counts are rounded expected values.

    No sampling takes place: each count is the rate multiplied by the sample
    size and rounded to the nearest integer.

    * Oracle positives: ``alpha * n_oracle``.
    * Metric positives: ``(alpha * (rho + eta - 1) + (1 - eta)) * n_metric``.
    * The ``n_rho_eta`` paired samples are split into ``alpha * n_rho_eta``
      trials for rho and the remainder for eta; their positives are
      ``rho`` and ``eta`` times the respective trial counts.
    """
    # Oracle counts.
    oracle_hits = int(round(alpha*n_oracle))

    # Metric counts at the rate implied by alpha, rho and eta.
    metric_rate = (alpha*(rho + eta - 1.)) + (1. - eta)
    metric_hits = int(round(metric_rate*n_metric))

    # Split of the paired sample between the rho and eta trials.
    rho_trials = int(round(alpha * n_rho_eta))
    eta_trials = n_rho_eta - rho_trials
    rho_hits = int(round(rho*rho_trials))
    eta_hits = int(round(eta*eta_trials))

    counts = dict(
        n_pos=metric_hits,
        n_total=n_metric,
        rho_pos=rho_hits,
        rho_total=rho_trials,
        eta_pos=eta_hits,
        eta_total=eta_trials,
        human_pos=oracle_hits,
        human_total=n_oracle,
    )
    return build_experiment(
        **counts,
        rho_approx=rho_approx,
        eta_approx=eta_approx,
        alpha_approx=alpha_approx,
    )


def run_simulated(
    alpha: float,
    rho: float,
    eta: float,
    n_oracle: int,
    n_metric: int,
    n_rho_eta: int,
    rho_approx: int,
    eta_approx: int,
    alpha_approx: int,
) -> Approximation:
    """Build a simulated experiment and evaluate it with ``run_experiment``.

    All arguments are forwarded unchanged to ``simulated_experiment``.
    """
    settings = dict(
        alpha=alpha,
        rho=rho,
        eta=eta,
        n_oracle=n_oracle,
        n_metric=n_metric,
        n_rho_eta=n_rho_eta,
        rho_approx=rho_approx,
        eta_approx=eta_approx,
        alpha_approx=alpha_approx,
    )
    return run_experiment(simulated_experiment(**settings))


def simulate_fixed(
    alpha: float,
    rho: float,
    eta: float,
    n_oracle: int,
    n_metric: int,
    alpha_approx: int,
):
    """Run ``run_fixed`` on counts derived from rounded expected values.

    The oracle counts are ``round(alpha * n_oracle)`` positives out of
    ``n_oracle``. The metric counts are ``round(p * n_metric)`` positives out
    of ``n_metric``, where ``p = alpha * (rho + eta - 1) + (1 - eta)``.
    ``rho`` and ``eta`` are passed to ``run_fixed`` as fixed values.
    """
    oracle_counts = BinomialExperiment(
        pos=int(round(alpha*n_oracle)),
        tot=n_oracle,
    )
    prior_info = PriorInformation(binom=oracle_counts, n_approx=alpha_approx)

    metric_rate = (alpha * (rho + eta - 1.)) + (1. - eta)
    metric_counts = BinomialExperiment(
        pos=int(round(metric_rate*n_metric)),
        tot=n_metric,
    )

    return run_fixed(metric=metric_counts, alpha=prior_info, rho=rho, eta=eta)
