/*
 * Decompiled with CFR 0.152.
 */
package fig.prob;

import fig.basic.Fmt;
import fig.basic.ListUtils;
import fig.prob.DiagMultGaussianSuffStats;
import fig.prob.Distrib;
import fig.prob.Gaussian;
import fig.prob.SuffStats;
import java.util.Random;

/**
 * A multivariate Gaussian with a diagonal covariance, modelled as one independent
 * {@link Gaussian} per dimension.
 *
 * <p>Arrays passed to the constructors are stored by reference, not copied. {@link #getMean()}
 * and {@link #getVar()} return those same arrays, so changes made through them by callers are
 * visible to this distribution.
 */
public class DiagMultGaussian implements Distrib<double[]> {
    /** Per-dimension means; its length defines {@link #dim()}. */
    private final double[] mean;
    /** Per-dimension variances (as passed to {@link Gaussian}); one entry per dimension. */
    private final double[] var;

    /**
     * Creates a distribution with the given means and a single variance shared by every dimension.
     *
     * @param mean per-dimension means (stored by reference)
     * @param var variance used for every dimension (expanded via {@code ListUtils.newDouble})
     */
    public DiagMultGaussian(double[] mean, double var) {
        this.mean = mean;
        this.var = ListUtils.newDouble(mean.length, var);
    }

    /**
     * Creates a distribution from explicit per-dimension means and variances.
     *
     * @param mean per-dimension means (stored by reference)
     * @param var per-dimension variances (stored by reference)
     * @throws IllegalArgumentException if {@code mean} and {@code var} differ in length
     */
    public DiagMultGaussian(double[] mean, double[] var) {
        // Fail fast here rather than with an index error (or silent truncation) later.
        if (mean.length != var.length) {
            throw new IllegalArgumentException(
                    "mean and var dimensions differ: " + mean.length + " vs " + var.length);
        }
        this.mean = mean;
        this.var = var;
    }

    /**
     * Creates a distribution whose i-th dimension copies the mean and variance of
     * {@code distrib[i]}.
     *
     * @param distrib one univariate Gaussian per dimension
     */
    public DiagMultGaussian(Gaussian[] distrib) {
        this.mean = new double[distrib.length];
        this.var = new double[distrib.length];
        for (int i = 0; i < distrib.length; i++) {
            this.mean[i] = distrib[i].getMean();
            this.var[i] = distrib[i].getVar();
        }
    }

    /**
     * Creates a {@code numDim}-dimensional distribution in which every dimension has the
     * mean and variance of {@code distrib}.
     *
     * @param numDim number of dimensions
     * @param distrib univariate Gaussian replicated across all dimensions
     */
    public DiagMultGaussian(int numDim, Gaussian distrib) {
        this.mean = ListUtils.newDouble(numDim, distrib.getMean());
        this.var = ListUtils.newDouble(numDim, distrib.getVar());
    }

    /**
     * Returns the log-probability of a point: the sum of the per-dimension
     * {@link Gaussian#logProb} values.
     *
     * @param x point to evaluate; must have length {@link #dim()}
     * @return summed per-dimension log-probability
     * @throws IllegalArgumentException if {@code x.length != dim()}
     */
    public double logProb(double[] x) {
        checkDim(x.length, "point");
        double total = 0.0;
        for (int i = 0; i < dim(); i++) {
            total += Gaussian.logProb(mean[i], var[i], x[i]);
        }
        return total;
    }

    /**
     * Returns the log-probability of a batch of points summarised by sufficient statistics.
     * The per-dimension sum, sum of squares and point count are passed to
     * {@link Gaussian#logProb}, and the results are summed over dimensions.
     *
     * @param _stats must be a {@link DiagMultGaussianSuffStats}
     * @return summed per-dimension log-probability
     */
    @Override
    public double logProb(SuffStats _stats) {
        DiagMultGaussianSuffStats stats = (DiagMultGaussianSuffStats) _stats;
        double total = 0.0;
        for (int i = 0; i < dim(); i++) {
            total += Gaussian.logProb(
                    mean[i], var[i], stats.getSum(i), stats.getSumSq(i), stats.numPoints());
        }
        return total;
    }

    /** Same as {@link #logProb(double[])}; satisfies the {@link Distrib} interface. */
    @Override
    public double logProbObject(double[] x) {
        return logProb(x);
    }

    /**
     * Draws one sample by sampling each dimension independently with {@link Gaussian#sample}.
     *
     * @param random source of randomness
     * @return a new array of length {@link #dim()}
     */
    public double[] sample(Random random) {
        double[] x = new double[dim()];
        for (int i = 0; i < x.length; i++) {
            x[i] = Gaussian.sample(random, mean[i], var[i]);
        }
        return x;
    }

    /** Same as {@link #sample(Random)}; satisfies the {@link Distrib} interface. */
    @Override
    public double[] sampleObject(Random random) {
        return sample(random);
    }

    /**
     * Returns the sum, over dimensions, of the per-dimension {@link Gaussian#crossEntropy}
     * between this distribution and {@code _that}.
     *
     * @param _that must be a {@link DiagMultGaussian} with the same dimension
     * @return summed per-dimension cross entropy
     * @throws IllegalArgumentException if the dimensions differ
     */
    @Override
    public double crossEntropy(Distrib<double[]> _that) {
        DiagMultGaussian that = (DiagMultGaussian) _that;
        checkDim(that.dim(), "other distribution");
        double total = 0.0;
        for (int i = 0; i < dim(); i++) {
            total += getComponent(i).crossEntropy(that.getComponent(i));
        }
        return total;
    }

    /**
     * Returns a new univariate {@link Gaussian} for dimension {@code i}.
     *
     * @param i dimension index
     * @return a fresh Gaussian with this dimension's mean and variance
     */
    public Gaussian getComponent(int i) {
        return new Gaussian(mean[i], var[i]);
    }

    /** Returns the internal (not copied) per-dimension mean array. */
    public double[] getMean() {
        return mean;
    }

    /** Returns the internal (not copied) per-dimension variance array. */
    public double[] getVar() {
        return var;
    }

    /** Returns the mean of dimension {@code i}. */
    public double getMean(int i) {
        return mean[i];
    }

    /** Returns the variance of dimension {@code i}. */
    public double getVar(int i) {
        return var[i];
    }

    /** Returns the number of dimensions (the length of the mean array). */
    public int dim() {
        return mean.length;
    }

    @Override
    public String toString() {
        return String.format("mean(%s),var(%s)", Fmt.D(mean), Fmt.D(var));
    }

    /**
     * Throws if {@code actual} does not equal this distribution's dimension.
     *
     * @param actual dimension of the argument being checked
     * @param what short description of the argument, used in the error message
     */
    private void checkDim(int actual, String what) {
        if (actual != dim()) {
            throw new IllegalArgumentException(
                    what + " has dimension " + actual + ", expected " + dim());
        }
    }
}

