/*
 * Decompiled with CFR 0.152.
 */
package fig.prob;

import fig.basic.Fmt;
import fig.basic.ListUtils;
import fig.prob.DiagMultGaussian;
import fig.prob.DiagMultGaussianSuffStats;
import fig.prob.Distrib;
import fig.prob.Gaussian;
import fig.prob.MargDistrib;
import fig.prob.MargMeanGaussian;
import fig.prob.SuffStats;
import java.util.Random;

/**
 * A distribution over {@link DiagMultGaussian} objects that factorizes across dimensions.
 *
 * <p>Dimension {@code i} pairs the {@code i}-th component of {@link #meanDistrib} (a distribution
 * over that dimension's mean) with the fixed value {@code varSpikes[i]}, which is handed to every
 * sampled {@link DiagMultGaussian} unchanged (presumably its per-dimension variance — confirm
 * against the {@code DiagMultGaussian(double[], double[])} constructor).
 *
 * <p>Posterior, marginal and expected log-likelihood computations are delegated to the
 * per-dimension {@link MargMeanGaussian} returned by {@link #getComponent(int)}; log-likelihoods
 * are summed over dimensions.
 */
public class MargMeanDiagMultGaussian
implements MargDistrib<DiagMultGaussian> {
    /** Distribution over the mean vector; {@link #getComponent(int)} reads one component per dimension. */
    private final DiagMultGaussian meanDistrib;
    /** Fixed per-dimension values passed to sampled objects; its length defines {@link #dim()}. */
    private final double[] varSpikes;

    /**
     * Creates the distribution from an explicit mean distribution and per-dimension spikes.
     *
     * @param meanDistrib distribution over the mean vector
     * @param varSpikes   per-dimension fixed values; stored by reference, not copied, and its
     *                    length determines {@link #dim()}
     */
    public MargMeanDiagMultGaussian(DiagMultGaussian meanDistrib, double[] varSpikes) {
        this.meanDistrib = meanDistrib;
        this.varSpikes = varSpikes;
    }

    /**
     * Creates a {@code numDim}-dimensional distribution in which every dimension uses the same
     * spike value, and whose mean distribution is built from the single {@link Gaussian}
     * {@code meanDistrib} (presumably replicated across all dimensions by
     * {@code DiagMultGaussian(int, Gaussian)} — confirm).
     *
     * @param numDim      number of dimensions
     * @param meanDistrib per-dimension distribution over the mean
     * @param varSpike    spike value shared by all dimensions
     */
    public MargMeanDiagMultGaussian(int numDim, Gaussian meanDistrib, double varSpike) {
        this.meanDistrib = new DiagMultGaussian(numDim, meanDistrib);
        this.varSpikes = ListUtils.newDouble(numDim, varSpike);
    }

    /**
     * Returns the one-dimensional view of dimension {@code i}.
     *
     * @param i dimension index in {@code [0, dim())}
     * @return a new {@link MargMeanGaussian} over that dimension's mean and spike value
     */
    public MargMeanGaussian getComponent(int i) {
        return new MargMeanGaussian(this.meanDistrib.getComponent(i), this.varSpikes[i]);
    }

    /**
     * Conditions each dimension independently on the matching component of {@code _stats} and
     * reassembles the per-dimension posteriors into a new distribution.
     *
     * @param _stats statistics; must be a {@link DiagMultGaussianSuffStats}
     * @return a new {@link MargMeanDiagMultGaussian} holding the posterior
     * @throws ClassCastException if {@code _stats} is not a {@link DiagMultGaussianSuffStats}
     */
    @Override
    public MargDistrib getPosterior(SuffStats _stats) {
        DiagMultGaussianSuffStats stats = (DiagMultGaussianSuffStats) _stats;
        int n = this.dim();
        Gaussian[] posteriorMeans = new Gaussian[n];
        double[] posteriorSpikes = new double[n];
        for (int i = 0; i < n; i++) {
            MargMeanGaussian posterior = this.getComponent(i).getPosterior(stats.getComponent(i));
            posteriorMeans[i] = posterior.getMeanDistrib();
            posteriorSpikes[i] = posterior.getVarSpike();
        }
        return new MargMeanDiagMultGaussian(new DiagMultGaussian(posteriorMeans), posteriorSpikes);
    }

    /**
     * Returns the sum over dimensions of each component's marginal log-likelihood.
     *
     * @param stats statistics; must be a {@link DiagMultGaussianSuffStats}
     * @throws ClassCastException if {@code stats} is not a {@link DiagMultGaussianSuffStats}
     */
    @Override
    public double margLogLikelihood(SuffStats stats) {
        // Cast once up front instead of on every iteration.
        DiagMultGaussianSuffStats diagStats = (DiagMultGaussianSuffStats) stats;
        double sum = 0.0;
        for (int i = 0; i < this.dim(); i++) {
            sum += this.getComponent(i).margLogLikelihood(diagStats.getComponent(i));
        }
        return sum;
    }

    /**
     * Scores {@code predStats} under the posterior obtained by conditioning on {@code condStats}.
     *
     * @throws ClassCastException if {@code condStats} is not a {@link DiagMultGaussianSuffStats}
     */
    @Override
    public double predLogLikelihood(SuffStats condStats, SuffStats predStats) {
        return this.getPosterior((DiagMultGaussianSuffStats) condStats).margLogLikelihood(predStats);
    }

    /**
     * Delegates to {@code meanDistrib.logProb(stats)}.
     * NOTE(review): presumably {@code stats} summarize mean vectors — confirm with callers.
     */
    @Override
    public double logProb(SuffStats stats) {
        return this.meanDistrib.logProb(stats);
    }

    /**
     * Scores only the mean of {@code distrib} under {@link #meanDistrib}; the variances of
     * {@code distrib} are not consulted.
     */
    @Override
    public double logProbObject(DiagMultGaussian distrib) {
        return this.meanDistrib.logProbObject(distrib.getMean());
    }

    /**
     * Cross entropy of the mean distributions only.
     * NOTE(review): the {@code varSpikes} of both distributions are ignored here.
     *
     * @throws ClassCastException if {@code _distrib} is not a {@link MargMeanDiagMultGaussian}
     */
    @Override
    public double crossEntropy(Distrib<DiagMultGaussian> _distrib) {
        MargMeanDiagMultGaussian distrib = (MargMeanDiagMultGaussian) _distrib;
        return this.meanDistrib.crossEntropy(distrib.meanDistrib);
    }

    /**
     * Returns the sum over dimensions of each component's expected log-likelihood.
     *
     * @throws ClassCastException if {@code _stats} is not a {@link DiagMultGaussianSuffStats}
     */
    @Override
    public double expectedLogLikelihood(SuffStats _stats) {
        DiagMultGaussianSuffStats stats = (DiagMultGaussianSuffStats) _stats;
        double sum = 0.0;
        for (int i = 0; i < this.dim(); i++) {
            sum += this.getComponent(i).expectedLogLikelihood(stats.getComponent(i));
        }
        return sum;
    }

    /**
     * Draws a mean vector from {@link #meanDistrib} and pairs it with the fixed spike values.
     *
     * <p>The spike array is copied so that the sample does not share (and cannot mutate) this
     * distribution's internal state, nor share state with other samples.
     */
    @Override
    public DiagMultGaussian sampleObject(Random random) {
        return new DiagMultGaussian(this.meanDistrib.sample(random), this.varSpikes.clone());
    }

    /** Returns the number of dimensions, i.e. the length of {@code varSpikes}. */
    public int dim() {
        return this.varSpikes.length;
    }

    @Override
    public String toString() {
        return String.format("mean(%s),var(%s)", this.meanDistrib, Fmt.D(this.varSpikes));
    }
}

