/*
 * Decompiled with CFR 0.152.
 */
package fig.prob;

import Jama.CholeskyDecomposition;
import Jama.Matrix;
import fig.prob.Distrib;
import fig.prob.Gaussian;
import fig.prob.MultGaussianSuffStats;
import fig.prob.SampleUtils;
import fig.prob.SuffStats;
import java.util.Arrays;
import java.util.Random;

/**
 * Multivariate Gaussian distribution over {@code double[]} vectors, parameterized by a mean
 * vector and a covariance matrix (both held as Jama {@link Matrix} objects).
 *
 * <p>Not thread-safe: the Cholesky factor is cached lazily on first use, and the static
 * helper caches ({@link #getStdNormal}, {@link #getZeroVector}, {@link #getIdentityMtx}) are
 * unsynchronized.
 */
public class MultGaussian
implements Distrib<double[]> {
    /** Natural log of 2*pi, used by the density normalizer. */
    private static final double LOG_2_PI = Math.log(2.0 * Math.PI);

    // Mean stored as a dim x 1 column vector.
    private Matrix mean;
    // dim x dim covariance matrix; a private copy of the constructor argument.
    private Matrix covar;
    // Lazily computed Cholesky factorization of covar (see getChol()).
    private CholeskyDecomposition chol = null;
    // Cached result of getStdNormal(n) for the last n requested.
    private static MultGaussian stdNormal = null;
    // Cached results of getZeroVector(n) / getIdentityMtx(n) for the last n requested.
    private static double[] zeroVector;
    private static double[][] identityMtx;

    /**
     * Creates a Gaussian with the given mean and covariance.
     *
     * <p>Both arguments are copied, so later mutation of the caller's arrays does not affect
     * this distribution.
     *
     * @param mean  mean vector of length d
     * @param covar d x d covariance matrix (expected to be symmetric positive definite for
     *              {@link #sample} and {@link #logProb} to be meaningful)
     * @throws IllegalArgumentException if {@code covar} is not {@code mean.length x mean.length}
     *                                  or its rows have unequal lengths
     */
    public MultGaussian(double[] mean, double[][] covar) {
        if (covar.length != mean.length) {
            throw new IllegalArgumentException(
                "covariance has " + covar.length + " rows but mean has length " + mean.length);
        }
        this.mean = new Matrix(mean, mean.length);
        // Matrix(double[][]) would alias the caller's array; take a private copy instead.
        this.covar = Matrix.constructWithCopy(covar);
        if (this.covar.getColumnDimension() != mean.length) {
            throw new IllegalArgumentException(
                "covariance has " + this.covar.getColumnDimension()
                    + " columns but mean has length " + mean.length);
        }
    }

    /**
     * Returns the total log density of the points summarized by {@code _stats}, i.e.
     * {@code sum_i log N(x_i; mean, covar)}.
     *
     * <p>NOTE(review): assumes {@code getMtxOuterProduct()} is {@code sum_i x_i x_i^T},
     * {@code getMtxSum()} is {@code sum_i x_i} as a d x 1 column, and {@code numPoints()} is the
     * number of points; confirm against MultGaussianSuffStats.
     *
     * <p>Returns NaN if {@code det(covar) < 0}. Returns infinity if the covariance is
     * singular.
     *
     * @param _stats sufficient statistics; must be a {@code MultGaussianSuffStats}
     * @return summed log density
     */
    @Override
    public double logProb(SuffStats _stats) {
        MultGaussianSuffStats stats = (MultGaussianSuffStats)_stats;
        double numPoints = stats.numPoints();
        double d = stats.dim();
        double logDet = Math.log(this.covar.det());
        // Per point: -d/2 * log(2*pi) - 1/2 * log|covar|.
        double normalizer = numPoints * (-0.5 * d * LOG_2_PI - 0.5 * logDet);
        Matrix inv = this.covar.inverse();
        // sum_i (x_i - mu)^T inv (x_i - mu), expanded in terms of the sufficient statistics:
        //   tr(inv * sum_i x_i x_i^T) + n * mu^T inv mu - sum^T inv mu - mu^T inv sum.
        double t1 = MultGaussian.aggregatePtwiseProduct(stats.getMtxOuterProduct(), inv);
        double t2 = this.mean.transpose().times(inv).times(this.mean).get(0, 0) * numPoints;
        double t3 = -stats.getMtxSum().transpose().times(inv).times(this.mean).get(0, 0);
        double t4 = -this.mean.transpose().times(inv).times(stats.getMtxSum()).get(0, 0);
        return normalizer - 0.5 * (t1 + t2 + t3 + t4);
    }

    /**
     * Returns the log density of a single point.
     *
     * @param x point of length {@link #dim()}
     * @return log density at {@code x}
     */
    @Override
    public double logProbObject(double[] x) {
        return this.logProb(new MultGaussianSuffStats(x));
    }

    /** Returns the Cholesky factorization of the covariance, computing it on first use. */
    private CholeskyDecomposition getChol() {
        if (this.chol == null) {
            this.chol = this.covar.chol();
        }
        return this.chol;
    }

    /**
     * Draws one sample as {@code mean + L z}. Here {@code L} is the lower Cholesky factor of
     * the covariance and {@code z} is a vector of independent draws from
     * {@code SampleUtils.sampleGaussian}.
     *
     * <p>Assumes the covariance is symmetric positive definite.
     *
     * @param random source of randomness
     * @return a new array of length {@link #dim()}
     */
    public double[] sample(Random random) {
        Matrix L = this.getChol().getL();
        int d = this.dim();
        Matrix z = new Matrix(d, 1);
        for (int i = 0; i < d; i++) {
            z.set(i, 0, SampleUtils.sampleGaussian(random));
        }
        Matrix result = L.times(z);
        result.plusEquals(this.mean);
        return result.getColumnPackedCopy();
    }

    /** Delegates to {@link #sample(Random)}. */
    @Override
    public double[] sampleObject(Random random) {
        return this.sample(random);
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double crossEntropy(Distrib<double[]> _that) {
        throw new UnsupportedOperationException("unsupported");
    }

    /** Prints 10000 samples from a fixed 2-D Gaussian to stdout (manual smoke test). */
    public static void main(String[] args) {
        double[] mean = new double[]{1.0, 2.0};
        double[][] covar = new double[][]{
            {1.0, 1.0},
            {1.0, 4.0},
        };
        MultGaussian g = new MultGaussian(mean, covar);
        Random random = new Random();
        for (int i = 0; i < 10000; i++) {
            System.out.println(Arrays.toString(g.sample(random)));
        }
    }

    /**
     * Returns the element-wise (Frobenius) inner product {@code sum_ij m1[i][j] * m2[i][j]}.
     * Dimensions are checked only by assertions (enabled with {@code -ea}).
     *
     * @param m1 first matrix
     * @param m2 second matrix, same shape as {@code m1}
     * @return sum of element-wise products
     */
    public static double aggregatePtwiseProduct(Matrix m1, Matrix m2) {
        assert (m1.getRowDimension() == m2.getRowDimension());
        assert (m1.getColumnDimension() == m2.getColumnDimension());
        int rows = m1.getRowDimension();
        int cols = m1.getColumnDimension();
        double sum = 0.0;
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                sum += m1.get(i, j) * m2.get(i, j);
            }
        }
        return sum;
    }

    /** Returns the dimensionality (number of rows of the covariance matrix). */
    public int dim() {
        return this.covar.getRowDimension();
    }

    /**
     * Returns a zero-mean, identity-covariance Gaussian of dimension {@code n}.
     *
     * <p>The instance is cached and reused while consecutive calls request the same {@code n}.
     *
     * @param n dimension
     * @return shared standard normal instance
     */
    public static MultGaussian getStdNormal(int n) {
        if (stdNormal == null || stdNormal.dim() != n) {
            stdNormal = new MultGaussian(MultGaussian.getZeroVector(n), MultGaussian.getIdentityMtx(n));
        }
        return stdNormal;
    }

    /**
     * Returns an all-zero vector of length {@code n}.
     *
     * <p>The array is cached and shared between calls with the same {@code n}. Callers must
     * not modify it.
     *
     * @param n length
     * @return shared zero vector
     */
    public static double[] getZeroVector(int n) {
        if (zeroVector == null || zeroVector.length != n) {
            // Java zero-initializes new double arrays, so no explicit fill is needed.
            zeroVector = new double[n];
        }
        return zeroVector;
    }

    /**
     * Returns an {@code n x n} identity matrix.
     *
     * <p>The array is cached and shared between calls with the same {@code n}. Callers must
     * not modify it.
     *
     * @param n dimension
     * @return shared identity matrix
     */
    public static double[][] getIdentityMtx(int n) {
        if (identityMtx != null && identityMtx.length == n) {
            return identityMtx;
        }
        identityMtx = new double[n][n];
        for (int i = 0; i < n; i++) {
            identityMtx[i][i] = 1.0;
        }
        return identityMtx;
    }

    /**
     * Returns a copy of the mean vector.
     *
     * @return new array of length {@link #dim()}
     */
    public double[] getMean() {
        // mean is a dim x 1 column, so getArray()[0] would be just {mean[0]}.
        return this.mean.getColumnPackedCopy();
    }

    /**
     * Returns the internal array backing the covariance matrix (not a copy).
     *
     * <p>Mutating it changes this distribution but does not invalidate an already-computed
     * Cholesky factor.
     *
     * @return covariance as a 2-D array
     */
    public double[][] getCovar() {
        return this.covar.getArray();
    }

    /** Returns the internal covariance {@link Matrix} (not a copy). */
    public Matrix getCovarMatrix() {
        return this.covar;
    }
}

