/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.stats;

import com.aliasi.stats.OnlineNormalEstimator;

/**
 * Incrementally tracks samples from several chains and computes the
 * potential scale reduction statistic (R-hat) from them.
 *
 * <p>Every chain owns its own {@link OnlineNormalEstimator}. Each sample
 * passed to {@link #update(int, double)} is recorded both in its chain's
 * estimator and in one global estimator spanning all chains. The statistic
 * returned by {@link #rHat()} is derived from the per-chain estimators only.
 */
public class PotentialScaleReduction {
    // Receives every sample from every chain; exposed via globalEstimator().
    private final OnlineNormalEstimator mGlobalEstimator;
    // One estimator per chain, indexed by chain number.
    private final OnlineNormalEstimator[] mChainEstimators;

    /**
     * Creates a statistic with the given number of empty chains.
     *
     * @param numChains number of chains to track; must be at least two
     * @throws IllegalStateException if {@code numChains} is less than two
     */
    public PotentialScaleReduction(int numChains) {
        if (numChains < 2) {
            throw new IllegalStateException(
                "Need at least two chains. Found numChains=" + numChains);
        }
        OnlineNormalEstimator[] chains = new OnlineNormalEstimator[numChains];
        for (int i = 0; i < chains.length; ++i) {
            chains[i] = new OnlineNormalEstimator();
        }
        mChainEstimators = chains;
        mGlobalEstimator = new OnlineNormalEstimator();
    }

    /**
     * Creates a statistic with one chain per row of {@code yss}. The
     * constructor then feeds every value in row {@code m} to chain {@code m}.
     *
     * @param yss samples, one row per chain; must have at least two rows
     * @throws IllegalStateException if {@code yss} has fewer than two rows
     */
    public PotentialScaleReduction(double[][] yss) {
        this(yss.length);
        int chain = 0;
        for (double[] samples : yss) {
            for (double y : samples) {
                update(chain, y);
            }
            ++chain;
        }
    }

    /**
     * Returns the number of chains being tracked.
     *
     * @return the chain count fixed at construction
     */
    public int numChains() {
        return mChainEstimators.length;
    }

    /**
     * Returns the live estimator for one chain. The object is shared with
     * this instance rather than copied.
     *
     * @param chain zero-based chain index
     * @return the estimator accumulating that chain's samples
     * @throws ArrayIndexOutOfBoundsException if {@code chain} is out of range
     */
    public OnlineNormalEstimator estimator(int chain) {
        return mChainEstimators[chain];
    }

    /**
     * Returns the live estimator that receives samples from all chains.
     *
     * @return the global estimator
     */
    public OnlineNormalEstimator globalEstimator() {
        return mGlobalEstimator;
    }

    /**
     * Records one sample for the given chain and also in the global
     * estimator.
     *
     * @param chain zero-based chain index
     * @param y sample value
     * @throws ArrayIndexOutOfBoundsException if {@code chain} is out of
     *     range; in that case the global estimator is left unchanged
     */
    public void update(int chain, double y) {
        OnlineNormalEstimator target = mChainEstimators[chain];
        target.handle(y);
        mGlobalEstimator.handle(y);
    }

    /**
     * Computes the potential scale reduction statistic for the samples
     * recorded so far.
     *
     * <p>Let {@code M} be the number of chains. Let {@code N} be the smallest
     * per-chain sample count. The computation is:
     * <pre>
     *   W     = average over chains of the unbiased within-chain variance
     *   B/N   = (sum over chains of (chainMean - meanOfChainMeans)^2) / (M - 1)
     *   var+  = (N - 1) * W / N + B/N
     *   R-hat = sqrt(var+ / W)
     * </pre>
     * The mean of chain means gives every chain equal weight, however many
     * samples it holds. The global estimator is not consulted.
     *
     * <p>No guard is applied when a chain has fewer than two samples or when
     * {@code W} is zero. The result may then be {@code NaN} or infinite.
     * NOTE(review): the exact outcome depends on what
     * {@code OnlineNormalEstimator} reports for tiny samples; confirm the
     * intended behavior.
     *
     * @return the R-hat statistic
     */
    public double rHat() {
        final int m = numChains();
        final double[] chainMeans = new double[m];
        long n = Long.MAX_VALUE;
        double varianceTotal = 0.0;
        double meanTotal = 0.0;
        for (int i = 0; i < m; ++i) {
            OnlineNormalEstimator chain = mChainEstimators[i];
            n = Math.min(n, chain.numSamples());
            varianceTotal += chain.varianceUnbiased();
            chainMeans[i] = chain.mean();
            meanTotal += chainMeans[i];
        }
        final double within = varianceTotal / m;
        final double grandMean = meanTotal / m;

        double deviationTotal = 0.0;
        for (double chainMean : chainMeans) {
            deviationTotal += square(chainMean - grandMean);
        }
        // This is B/N directly: the N factor in B cancels the 1/N in var+.
        final double betweenOverN = deviationTotal / (m - 1.0);

        // Keep left-to-right evaluation: ((N - 1) * W) / N.
        final double pooledVariance = (n - 1L) * within / n + betweenOverN;
        return Math.sqrt(pooledVariance / within);
    }

    /**
     * Returns {@code x * x}.
     *
     * @param x value to square
     * @return the square of {@code x}
     */
    static double square(double x) {
        return x * x;
    }
}

