/*
 * Decompiled with CFR 0.152.
 */
package fig.prob;

import fig.basic.NumUtils;
import fig.basic.TDoubleMap;
import fig.prob.DegenerateSparseDirichlet;
import fig.prob.DirichletUtils;
import fig.prob.Distrib;
import fig.prob.MargDistrib;
import fig.prob.SparseDirichlet;
import fig.prob.SparseDirichletInterface;
import fig.prob.SparseMultinomial;
import fig.prob.SparseMultinomialSuffStats;
import fig.prob.SuffStats;
import java.util.Random;

/**
 * Marginal distribution over {@link SparseMultinomial} parameters, represented by a sparse
 * Dirichlet-style prior ({@link SparseDirichletInterface}).
 *
 * <p>Most operations delegate directly to the wrapped prior. {@link #predLogLikelihood} and
 * {@link #getPosterior} need the prior's concentration parameters and therefore cast the prior
 * to {@link SparseDirichlet}; they throw {@link ClassCastException} for any other prior type.
 */
public class MargSparseMultinomial
implements MargDistrib<SparseMultinomial> {
    /** Prior over the multinomial parameters; fixed at construction. */
    private final SparseDirichletInterface prior;

    /**
     * Creates a marginal distribution backed by the given prior.
     *
     * @param prior the prior over multinomial parameters; must not be {@code null}
     * @throws NullPointerException if {@code prior} is {@code null}
     */
    public MargSparseMultinomial(SparseDirichletInterface prior) {
        // Fail fast: every method dereferences the prior, so a null would otherwise only
        // surface as an NPE far from where the bad value was supplied.
        if (prior == null) {
            throw new NullPointerException("prior");
        }
        this.prior = prior;
    }

    /**
     * Returns the log marginal likelihood of the counts in {@code _stats}, with the multinomial
     * parameters integrated out under the prior.
     *
     * <p>For a {@link DegenerateSparseDirichlet} prior this returns
     * {@link #expectedLogLikelihood(SuffStats)}; otherwise it returns the predictive
     * log-likelihood of {@code _stats} conditioned on
     * {@link SparseMultinomialSuffStats#emptyStats}.
     *
     * @param _stats observed counts; must be a {@link SparseMultinomialSuffStats}
     * @return the log marginal likelihood
     */
    @Override
    public double margLogLikelihood(SuffStats _stats) {
        if (this.prior instanceof DegenerateSparseDirichlet) {
            // NOTE(review): presumably a degenerate prior is a point mass, so the expected
            // log-likelihood under it equals the marginal -- confirm in DegenerateSparseDirichlet.
            return this.expectedLogLikelihood(_stats);
        }
        return this.predLogLikelihood(SparseMultinomialSuffStats.emptyStats, _stats);
    }

    /**
     * Returns the log-probability of the counts in {@code _predStats}, given that the counts in
     * {@code _condStats} have already been observed, with the multinomial parameters integrated
     * out.
     *
     * <p>The result is the sum, over keys k present in {@code _predStats}, of
     * {@code logGammaRatio(alpha_k + c_k, n_k)}, minus {@code logGammaRatio(A + C, N)}. Here
     * alpha is the prior concentration, c are the conditioning counts, n are the predicted
     * counts, and A, C, N are the corresponding totals. Assuming
     * {@code DirichletUtils.logGammaRatio(a, n)} computes {@code logGamma(a + n) - logGamma(a)}
     * (not visible here -- confirm), this is the Dirichlet-multinomial predictive
     * log-likelihood.
     *
     * @param _condStats previously observed counts; must be a {@link SparseMultinomialSuffStats}
     * @param _predStats counts to score; must be a {@link SparseMultinomialSuffStats}
     * @return the predictive log-likelihood
     * @throws ClassCastException if the prior is not a {@link SparseDirichlet} or either
     *     argument is not a {@link SparseMultinomialSuffStats}
     */
    @Override
    public double predLogLikelihood(SuffStats _condStats, SuffStats _predStats) {
        SparseMultinomialSuffStats condStats = (SparseMultinomialSuffStats) _condStats;
        SparseMultinomialSuffStats predStats = (SparseMultinomialSuffStats) _predStats;
        SparseDirichlet dprior = (SparseDirichlet) this.prior;
        double sum = 0.0;
        // Only keys that appear in predStats contribute a per-key term.
        for (TDoubleMap.Entry e : predStats) {
            sum += DirichletUtils.logGammaRatio(
                    dprior.getConcentration(e.getKey()) + condStats.getCount(e.getKey()),
                    e.getValue());
        }
        // Normalizing term over the total pseudo-counts plus conditioning counts.
        return sum - DirichletUtils.logGammaRatio(
                dprior.totalCount() + condStats.totalCount(), predStats.totalCount());
    }

    /**
     * Returns a new marginal whose prior is {@code ((SparseDirichlet) prior).withExtraCounts}
     * applied to {@code suffStats}.
     *
     * @param suffStats counts to fold into the prior; must be a
     *     {@link SparseMultinomialSuffStats}
     * @return a new {@code MargSparseMultinomial} backed by the updated prior
     * @throws ClassCastException if the prior is not a {@link SparseDirichlet} or
     *     {@code suffStats} is not a {@link SparseMultinomialSuffStats}
     */
    @Override
    public MargDistrib<SparseMultinomial> getPosterior(SuffStats suffStats) {
        SparseDirichlet dprior = (SparseDirichlet) this.prior;
        return new MargSparseMultinomial(
                dprior.withExtraCounts((SparseMultinomialSuffStats) suffStats));
    }

    /**
     * Returns {@code prior.logProb(stats)}.
     *
     * @param stats sufficient statistics passed straight to the prior
     * @return the prior's log-probability for {@code stats}
     */
    @Override
    public double logProb(SuffStats stats) {
        return this.prior.logProb(stats);
    }

    /**
     * Returns the prior's log-density at the probability vector of {@code x}.
     *
     * @param x the multinomial whose {@code getProbs()} are scored
     * @return {@code prior.logProbObject(x.getProbs())}
     */
    @Override
    public double logProbObject(SparseMultinomial x) {
        return this.prior.logProbObject(x.getProbs());
    }

    /**
     * Draws a probability map from the prior and wraps it in a {@link SparseMultinomial}.
     *
     * @param random source of randomness passed to the prior
     * @return a newly sampled multinomial
     */
    @Override
    public SparseMultinomial sampleObject(Random random) {
        // NOTE(review): assumes the prior's sampleObject returns a TDoubleMap -- the cast
        // fails with ClassCastException otherwise.
        return new SparseMultinomial((TDoubleMap) this.prior.sampleObject(random));
    }

    /**
     * Returns {@code prior.expectedLog(key)}.
     *
     * @param key the outcome whose expected log-probability is requested
     * @return the prior's expected log-probability for {@code key}
     */
    public double expectedLog(Object key) {
        return this.prior.expectedLog(key);
    }

    /**
     * Returns the prior's cross-entropy with the other distribution's prior.
     *
     * @param distrib the other distribution; must be a {@code MargSparseMultinomial}
     * @return {@code prior.crossEntropy(other.prior)}
     * @throws ClassCastException if {@code distrib} is not a {@code MargSparseMultinomial}
     */
    @Override
    public double crossEntropy(Distrib<SparseMultinomial> distrib) {
        return this.prior.crossEntropy(((MargSparseMultinomial) distrib).prior);
    }

    /**
     * Returns the sum, over entries of {@code _stats}, of
     * {@code count * prior.expectedLog(key)}. The result is checked with
     * {@link NumUtils#assertIsFinite}.
     *
     * @param _stats counts to score; must be a {@link SparseMultinomialSuffStats}
     * @return the expected log-likelihood of the counts under the prior
     * @throws ClassCastException if {@code _stats} is not a {@link SparseMultinomialSuffStats}
     */
    @Override
    public double expectedLogLikelihood(SuffStats _stats) {
        SparseMultinomialSuffStats stats = (SparseMultinomialSuffStats) _stats;
        double sum = 0.0;
        for (TDoubleMap.Entry e : stats) {
            sum += e.getValue() * this.prior.expectedLog(e.getKey());
        }
        NumUtils.assertIsFinite(sum);
        return sum;
    }

    /**
     * Returns a new marginal backed by {@code prior.modeSpike()}.
     *
     * <p>NOTE(review): presumably a degenerate prior concentrated at the prior's mode -- confirm
     * in {@link SparseDirichletInterface}.
     *
     * @return a new {@code MargSparseMultinomial} wrapping the mode-spike prior
     */
    public MargSparseMultinomial modeSpike() {
        return new MargSparseMultinomial(this.prior.modeSpike());
    }

    /** Returns the prior's string representation. */
    @Override
    public String toString() {
        return this.prior.toString();
    }
}

