/*
 * Decompiled with CFR 0.152.
 */
package fig.prob;

import Jama.Matrix;
import fig.prob.Distrib;
import fig.prob.NormalInverseWishart;
import fig.prob.SuffStats;
import java.util.Random;

public class NormalInverseWishartDistrib
implements Distrib<NormalInverseWishart> {
    /** Degrees-of-freedom parameter; the constructor requires {@code nu > d + 1}. */
    private final double nu;
    /** Scale matrix (d x d); enters the log density as {@code -1/2 * tr(delta * covar^-1)}. */
    private final Matrix delta;
    /** Location of the mean, stored as a column vector. */
    private final Matrix scriptV;
    /** Scales the quadratic penalty on {@code mean - scriptV}; must be > 0. */
    private final double kappa;

    /**
     * Creates a normal-inverse-Wishart distribution.
     *
     * @param kappa   scaling of the mean's quadratic term; must be strictly positive
     * @param scriptV location of the mean, as a column vector
     * @param nu      degrees of freedom; must be strictly greater than {@code d + 1}, where
     *                {@code d} is the dimension of {@code delta}
     * @param delta   square scale matrix
     * @throws IllegalArgumentException if {@code kappa} is not > 0 (including NaN), if
     *         {@code delta} is not square, or if {@code nu} is not > d + 1 (including NaN)
     */
    public NormalInverseWishartDistrib(double kappa, Matrix scriptV, double nu, Matrix delta) {
        // Written as !(x > bound) so that NaN parameters are rejected as well.
        if (!(kappa > 0.0)) {
            throw new IllegalArgumentException("kappa " + kappa + " should be > 0");
        }
        int d = delta.getColumnDimension();
        if (delta.getRowDimension() != d) {
            throw new IllegalArgumentException("delta must be square, got "
                    + delta.getRowDimension() + " x " + d);
        }
        // nu > d + 1 keeps the denominator of expectedVariance() strictly positive.
        if (!(nu > (double) (d + 1))) {
            throw new IllegalArgumentException("nu " + nu + " should be > d + 1, d = " + d);
        }
        this.nu = nu;
        this.delta = delta;
        this.scriptV = scriptV;
        this.kappa = kappa;
    }

    /**
     * Convenience constructor taking plain arrays; {@code scriptV} is wrapped as a column
     * vector and {@code delta} as a row-major matrix.
     *
     * @see #NormalInverseWishartDistrib(double, Matrix, double, Matrix)
     */
    public NormalInverseWishartDistrib(double kappa, double[] scriptV, double nu, double[][] delta) {
        this(kappa, new Matrix(scriptV, scriptV.length), nu, new Matrix(delta));
    }

    /**
     * Returns the log density of {@code (mean, covar)}, up to an additive constant that does
     * not depend on the arguments:
     *
     * <pre>
     *   -((nu + d) / 2 + 1) * log|covar|
     *   - 1/2 * tr(delta * covar^-1)
     *   - kappa / 2 * (mean - scriptV)' * covar^-1 * (mean - scriptV)
     * </pre>
     *
     * @param mean  candidate mean, of length {@code dim()}
     * @param covar candidate covariance, {@code dim() x dim()}
     * @return the unnormalized log density
     * @throws IllegalArgumentException if the argument dimensions do not match {@code dim()}
     *         or if {@code covar} does not have a strictly positive determinant
     */
    public double unNormalizedLogProb(double[] mean, double[][] covar) {
        int d = this.dim();
        if (mean.length != d) {
            throw new IllegalArgumentException("mean has length " + mean.length + ", expected " + d);
        }
        Matrix mu = new Matrix(mean, mean.length);
        Matrix lambda = new Matrix(covar);
        if (lambda.getRowDimension() != d || lambda.getColumnDimension() != d) {
            throw new IllegalArgumentException("covar is " + lambda.getRowDimension() + " x "
                    + lambda.getColumnDimension() + ", expected " + d + " x " + d);
        }

        double logDeterminant;
        Matrix inverse;
        if (isIdentity(lambda)) {
            // Fast path: |I| = 1 and I^-1 = I, so the LU decomposition can be skipped.
            logDeterminant = 0.0;
            inverse = lambda;
        } else {
            double determinant = lambda.det();
            // Also catches NaN; a non-positive determinant would make the log undefined.
            if (!(determinant > 0.0)) {
                throw new IllegalArgumentException(
                        "covar must have a positive determinant, got " + determinant);
            }
            logDeterminant = Math.log(determinant);
            inverse = lambda.inverse();
        }

        double exponent = -((this.nu + (double) d) / 2.0 + 1.0);
        double normalizer = exponent * logDeterminant;
        double t1 = -0.5 * this.delta.times(inverse).trace();
        double t2 = -this.kappa / 2.0 * NormalInverseWishartDistrib.norm(inverse, mu.minus(this.scriptV));
        return normalizer + t1 + t2;
    }

    /** Not implemented; always throws {@link UnsupportedOperationException}. */
    @Override
    public double logProb(SuffStats stats) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    @Override
    public double logProbObject(NormalInverseWishart x) {
        throw new UnsupportedOperationException("Not supported right now");
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    @Override
    public NormalInverseWishart sampleObject(Random random) {
        throw new UnsupportedOperationException("Not supported right now");
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    @Override
    public double crossEntropy(Distrib<NormalInverseWishart> _that) {
        throw new UnsupportedOperationException("Not supported");
    }

    /**
     * Returns true iff {@code lambda} is square with exactly 1.0 on the diagonal and exactly
     * 0.0 elsewhere.
     */
    private static boolean isIdentity(Matrix lambda) {
        int n = lambda.getRowDimension();
        if (n != lambda.getColumnDimension()) {
            return false;
        }
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                double expected = (i == j) ? 1.0 : 0.0;
                if (lambda.get(i, j) != expected) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Computes the quadratic form {@code vector' * kernel * vector}.
     *
     * @param kernel square n x n matrix
     * @param vector n x 1 column vector
     * @return the scalar value of the quadratic form
     * @throws IllegalArgumentException if the dimensions are incompatible
     */
    public static double norm(Matrix kernel, Matrix vector) {
        int n = kernel.getRowDimension();
        if (kernel.getColumnDimension() != n) {
            throw new IllegalArgumentException("kernel must be square, got "
                    + n + " x " + kernel.getColumnDimension());
        }
        if (vector.getColumnDimension() != 1 || vector.getRowDimension() != n) {
            throw new IllegalArgumentException("vector must be " + n + " x 1, got "
                    + vector.getRowDimension() + " x " + vector.getColumnDimension());
        }
        Matrix result = vector.transpose().times(kernel).times(vector);
        return result.get(0, 0);
    }

    /**
     * Returns {@code delta * nu / (nu - d - 1)}.
     *
     * <p>NOTE(review): for an inverse-Wishart with scale matrix {@code delta}, the textbook mean
     * is {@code delta / (nu - d - 1)}. {@link #unNormalizedLogProb} uses {@code delta}
     * unscaled in {@code tr(delta * covar^-1)}, so the extra factor of {@code nu} here looks
     * inconsistent with it. Confirm the intended parameterization before relying on this value.
     *
     * @return a new matrix; {@code delta} itself is not modified
     */
    public Matrix expectedVariance() {
        double coefficient = this.nu / (this.nu - (double) this.dim() - 1.0);
        return this.delta.times(coefficient);
    }

    /** Dimension d, taken as the column count of {@code delta}. */
    public int dim() {
        return this.delta.getColumnDimension();
    }

    /** Returns the scale matrix itself, not a copy; mutating it mutates this distribution. */
    public Matrix getDelta() {
        return this.delta;
    }

    /** Returns {@code kappa}. */
    public double getKappa() {
        return this.kappa;
    }

    /** Returns {@code nu}. */
    public double getNu() {
        return this.nu;
    }

    /** Returns the mean location vector itself, not a copy. */
    public Matrix getScriptV() {
        return this.scriptV;
    }

    @Override
    public String toString() {
        return "NIW(nu=" + this.nu + ", kappa=" + this.kappa + ")";
    }
}

