/*
 * Decompiled with CFR 0.152.
 */
package fig.prob;

import fig.basic.NumUtils;
import fig.prob.Beta;
import fig.prob.BetaInterface;
import fig.prob.Binomial;
import fig.prob.BinomialSuffStats;
import fig.prob.DegenerateBeta;
import fig.prob.DirichletUtils;
import fig.prob.Distrib;
import fig.prob.MargDistrib;
import fig.prob.SuffStats;
import java.util.Random;

public class MargBinomial
implements MargDistrib<Binomial> {
    private BetaInterface prior;

    public MargBinomial(BetaInterface prior) {
        this.prior = prior;
    }

    @Override
    public double margLogLikelihood(SuffStats _stats) {
        if (this.prior instanceof DegenerateBeta) {
            return this.expectedLogLikelihood(_stats);
        }
        return this.predLogLikelihood(null, _stats);
    }

    @Override
    public double predLogLikelihood(SuffStats _condStats, SuffStats _predStats) {
        BinomialSuffStats condStats = (BinomialSuffStats)_condStats;
        BinomialSuffStats predStats = (BinomialSuffStats)_predStats;
        Beta bprior = (Beta)this.prior;
        double sum = 0.0;
        sum += DirichletUtils.logGammaRatio(bprior.getAlpha() + (condStats == null ? 0.0 : condStats.getTrueCount()), predStats.getTrueCount());
        sum += DirichletUtils.logGammaRatio(bprior.getBeta() + (condStats == null ? 0.0 : condStats.getFalseCount()), predStats.getFalseCount());
        return sum -= DirichletUtils.logGammaRatio(bprior.totalCount() + (condStats == null ? 0.0 : condStats.totalCount()), predStats.totalCount());
    }

    @Override
    public double logProb(SuffStats stats) {
        return this.prior.logProb(stats);
    }

    @Override
    public double logProbObject(Binomial bin) {
        return this.prior.logProbObject(bin.getProb());
    }

    @Override
    public double crossEntropy(Distrib<Binomial> _that) {
        MargBinomial that = (MargBinomial)_that;
        return this.prior.crossEntropy(that.prior);
    }

    public double expectedLog(boolean b) {
        return this.prior.expectedLog(b);
    }

    @Override
    public double expectedLogLikelihood(SuffStats _stats) {
        BinomialSuffStats stats = (BinomialSuffStats)_stats;
        double sum = 0.0;
        sum += stats.getTrueCount() * this.prior.expectedLog(true);
        NumUtils.assertIsFinite(sum += stats.getFalseCount() * this.prior.expectedLog(false));
        return sum;
    }

    public BetaInterface getPrior() {
        return this.prior;
    }

    @Override
    public MargBinomial getPosterior(SuffStats _stats) {
        BinomialSuffStats stats = (BinomialSuffStats)_stats;
        return new MargBinomial(new Beta(this.prior.getAlpha() + stats.getTrueCount(), this.prior.getBeta() + stats.getFalseCount()));
    }

    public MargBinomial modeSpike() {
        return new MargBinomial(this.prior.modeSpike());
    }

    public MargBinomial perturb(Random random) {
        Beta dprior = (Beta)this.prior;
        return new MargBinomial(dprior.perturb(random));
    }

    public MargBinomial degeneratePerturb(Random random) {
        Beta dprior = (Beta)this.prior;
        return new MargBinomial(new DegenerateBeta(dprior.sample(random)));
    }

    @Override
    public Binomial sampleObject(Random random) {
        return new Binomial((Double)this.prior.sampleObject(random));
    }

    public String toString() {
        return String.format("MargBinomial(%s)", this.prior);
    }
}

