use peroxide::fuga::*;
use peroxide::{seq, zeros};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use statrs::distribution::{Beta, Binomial, ContinuousCDF, Discrete};
use std::io;
use std::io::Read;

/// A probability distribution discretised onto a grid of `n_approx` points.
///
/// `values[i]` is a grid point (the builders in this file use bin midpoints of
/// `[0, 1]`) and `probas[i]` is the probability mass assigned to it. This is also the
/// shape of the JSON document the program prints to stdout.
#[derive(Serialize, Deserialize)]
struct Approximation {
    /// Number of grid points; the builders in this file produce `values` and
    /// `probas` of this length.
    n_approx: u64,
    /// Grid points.
    values: Vec<f64>,
    /// Probability mass at each grid point, index-aligned with `values`.
    probas: Vec<f64>,
}

/// Counts from a binomial experiment: `pos` successes out of `tot` trials.
///
/// Deserialised from the JSON input. `pos` is expected not to exceed `tot`; the
/// code consuming it computes `tot - pos` on unsigned integers.
#[derive(Serialize, Deserialize)]
struct BinomialExperiment {
    /// Number of successes / positives.
    pos: u64,
    /// Total number of trials.
    tot: u64,
}

/// Specification of a discretised prior for one probability parameter.
///
/// `approximate` turns this into an `Approximation`: uniform on `[0, 1]` when
/// `binom.tot == 0`, otherwise `Beta(pos + 1, tot - pos + 1)`.
#[derive(Serialize, Deserialize)]
struct PriorInformation {
    /// Pseudo-observations the prior is built from.
    binom: BinomialExperiment,
    /// Number of grid points used to discretise the prior.
    n_approx: u64,
}

/// Full input to the program, read as JSON from stdin.
///
/// `posterior` combines the priors for `rho`, `eta` and `oracle` (the prior for
/// `alpha`) with the observed `metric` counts. Under the model in `metric_proba`, a
/// single metric trial is positive with probability
/// `alpha * rho + (1 - alpha) * (1 - eta)`.
/// NOTE(review): this reads like prevalence / sensitivity / specificity — confirm.
#[derive(Serialize, Deserialize)]
struct Experiment {
    /// Prior information for `rho`.
    rho: PriorInformation,
    /// Prior information for `eta`.
    eta: PriorInformation,
    /// Prior information for `alpha`, the parameter whose posterior is printed.
    oracle: PriorInformation,
    /// Observed metric: `pos` positives out of `tot` trials. When `tot == 0` the
    /// program outputs the `oracle` prior unchanged.
    metric: BinomialExperiment,
}

/// Discretises a `Beta(beta_a, beta_b)` distribution on `[0, 1]` into `n_approx`
/// equal-width bins.
///
/// Bin `i` spans `[i / n, (i + 1) / n]`; it is represented by its midpoint in
/// `values` and by the probability mass the Beta distribution assigns to it
/// (difference of the CDF at the bin edges) in `probas`.
///
/// # Panics
/// Panics if `n_approx == 0`, or if `beta_a` / `beta_b` are rejected by
/// `statrs::distribution::Beta::new`.
fn approx_beta(beta_a: f64, beta_b: f64, n_approx: u64) -> Approximation {
    // `n_approx - 1` style ranges underflow for 0, and there would be no bin to
    // carry the probability mass anyway.
    assert!(n_approx > 0, "n_approx must be at least 1");
    let n = n_approx as f64;
    let beta = Beta::new(beta_a, beta_b).expect("invalid Beta shape parameters");

    // Evaluate the CDF once per bin edge (n + 1 edges) instead of twice for every
    // interior edge; adjacent differences then give the mass of each bin.
    let edge_cdf: Vec<f64> = (0..=n_approx).map(|i| beta.cdf(i as f64 / n)).collect();
    // The CDF is monotone, so a negative difference can only be floating-point
    // noise; clamp it so no bin reports a negative probability.
    let probas = edge_cdf
        .windows(2)
        .map(|w| (w[1] - w[0]).max(0.0))
        .collect();
    let values = (0..n_approx).map(|i| (i as f64 + 0.5) / n).collect();

    Approximation {
        n_approx,
        values,
        probas,
    }
}

/// Discretises the uniform distribution on `[0, 1]` into `n_approx` equal-width
/// bins.
///
/// `values` holds the bin midpoints `(i + 0.5) / n_approx` and every entry of
/// `probas` equals `1 / n_approx`.
///
/// # Panics
/// Panics if `n_approx == 0`: there would be no bin to carry the probability mass
/// (and the previous `n_approx - 1` computation underflowed).
fn approx_uniform(n_approx: u64) -> Approximation {
    assert!(n_approx > 0, "n_approx must be at least 1");
    let n = n_approx as f64;

    Approximation {
        n_approx,
        values: (0..n_approx).map(|i| (i as f64 + 0.5) / n).collect(),
        probas: vec![1f64 / n; n_approx as usize],
    }
}

/// Builds the discretised prior described by `prior`.
///
/// With no pseudo-observations (`binom.tot == 0`) the prior is uniform on `[0, 1]`;
/// otherwise it is `Beta(pos + 1, tot - pos + 1)`, i.e. a uniform prior updated
/// with `pos` successes out of `tot` trials. Either way it is discretised onto
/// `prior.n_approx` grid points.
///
/// # Panics
/// Panics if `binom.pos > binom.tot`. Such input does not describe a binomial
/// experiment. Previously, `tot - pos` underflowed: a panic in debug builds, and a
/// silently wrapped, absurd Beta parameter in release builds.
fn approximate(prior: &PriorInformation) -> Approximation {
    let (pos, tot) = (prior.binom.pos, prior.binom.tot);
    assert!(
        pos <= tot,
        "invalid binomial experiment: pos ({}) exceeds tot ({})",
        pos,
        tot
    );

    if tot == 0 {
        approx_uniform(prior.n_approx)
    } else {
        let failures = tot - pos; // cannot underflow: checked above
        approx_beta((pos + 1) as f64, (failures + 1) as f64, prior.n_approx)
    }
}

/// Likelihood of observing `m_obs` positives out of `m_tot` trials for a fixed
/// `alpha`, marginalised over the discretised `rho` and `eta` priors.
///
/// For each grid pair `(r, e)` the per-trial positive rate is
/// `alpha * (r + e - 1) + (1 - e)`. The binomial probability of `m_obs` at that
/// rate is weighted by the prior masses of `e` and `r`, and the weighted terms
/// are summed.
fn metric_proba(
    m_obs: u64,
    m_tot: u64,
    alpha: f64,
    rho: &Approximation,
    eta: &Approximation,
) -> f64 {
    let rho_grid = rho.values.iter().zip(&rho.probas);
    // A single running total is threaded through both folds so the summation
    // order matches a plain nested loop.
    rho_grid.fold(0.0f64, |total, (r, pr)| {
        eta.values
            .iter()
            .zip(&eta.probas)
            .fold(total, |running, (e, pe)| {
                let rate = (alpha * (r + e - 1f64)) + (1f64 - e);
                let likelihood = Binomial::new(rate, m_tot).unwrap().pmf(m_obs);
                running + likelihood * pe * pr
            })
    })
}

/// Posterior over the `alpha` grid given `m_obs` positives out of `m_tot` trials.
///
/// For each grid value of `alpha`, the prior mass is multiplied by the marginal
/// likelihood from `metric_proba`. The products are computed in parallel with
/// rayon and then L1-normalised via peroxide's `Normed::normalize`. The grid
/// (`values`, `n_approx`) is copied unchanged from `alpha`.
///
/// NOTE(review): if every product is zero, the normalisation divides by zero.
/// Confirm how `Normed::normalize` behaves in that case.
fn metric_posterior(
    m_obs: u64,
    m_tot: u64,
    rho: &Approximation,
    eta: &Approximation,
    alpha: &Approximation,
) -> Approximation {
    let unnormalised: Vec<f64> = alpha
        .probas
        .par_iter()
        .zip(alpha.values.par_iter())
        .map(|(prior_mass, a)| prior_mass * metric_proba(m_obs, m_tot, *a, rho, eta))
        .collect();

    let probas = Normed::normalize(&unnormalised, Norm::L1);

    Approximation {
        n_approx: alpha.n_approx,
        values: alpha.values.clone(),
        probas,
    }
}

/// Computes the posterior over `alpha` (whose prior is `experiment.oracle`).
///
/// When the metric experiment has no trials there is nothing to update on, and the
/// discretised prior is returned as-is. Otherwise the `rho` and `eta` priors are
/// built and the metric observation is folded in via `metric_posterior`.
///
/// # Panics
/// Panics if `experiment.metric.pos > experiment.metric.tot`. No binomial model can
/// produce more positives than trials, so the posterior would otherwise silently
/// degenerate.
fn posterior(experiment: &Experiment) -> Approximation {
    let metric = &experiment.metric;
    assert!(
        metric.pos <= metric.tot,
        "invalid metric experiment: pos ({}) exceeds tot ({})",
        metric.pos,
        metric.tot
    );

    let alpha = approximate(&experiment.oracle);
    if metric.tot == 0 {
        return alpha;
    }

    let rho = approximate(&experiment.rho);
    let eta = approximate(&experiment.eta);
    metric_posterior(metric.pos, metric.tot, &rho, &eta, &alpha)
}

/// Reads an `Experiment` as JSON from stdin, computes the posterior over `alpha`,
/// and prints it as pretty-printed JSON to stdout.
///
/// # Errors
/// Returns an error (reported on stderr, with a non-zero exit status) if stdin
/// cannot be read, if the input is not a valid `Experiment` document, or if the
/// result cannot be serialised. Invalid user input is not a program bug, so the
/// error is propagated instead of panicking.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    // read experiment data as json from stdin
    let mut json_in = String::new();
    io::stdin().read_to_string(&mut json_in)?;
    let exp: Experiment = serde_json::from_str(&json_in)?;

    // compute posterior of alpha based on experiment data
    let post = posterior(&exp);

    // print posterior as json to stdout
    let serialized = serde_json::to_string_pretty(&post)?;
    println!("{}", serialized);
    Ok(())
}
