/*
 * Decompiled with CFR 0.152.
 */
package fig.prob;

import Jama.Matrix;
import fig.prob.SuffStats;

/**
 * Sufficient statistics for a multivariate Gaussian: the number of points
 * observed, the sum of those points (stored as a {@code dim x 1} column vector),
 * and the sum of their outer products {@code x * x^T} (a {@code dim x dim} matrix).
 *
 * <p>Mutators replace the stored matrices with freshly allocated ones (via
 * {@code plus}/{@code minus}) instead of updating them in place, so a matrix
 * previously returned by {@link #getMtxSum()} or {@link #getMtxOuterProduct()}
 * is not affected by later updates. No synchronization is performed.
 */
public class MultGaussianSuffStats
implements SuffStats {
    /** Column vector ({@code dim x 1}) holding the sum of all observed points. */
    private Matrix sum;
    /** {@code dim x dim} matrix holding the sum of {@code x * x^T} over all observed points. */
    private Matrix outerproducts;
    /** Number of points currently accounted for. */
    private int n;

    /**
     * Creates empty statistics (zero count, zero sums) for points of the given dimension.
     *
     * @param numDim dimensionality of the points; must be at least 1
     * @throws IllegalArgumentException if {@code numDim < 1}
     */
    public MultGaussianSuffStats(int numDim) {
        if (numDim < 1) {
            throw new IllegalArgumentException("numDim must be >= 1, got " + numDim);
        }
        // Matrix(int, int) allocates an all-zero matrix of the given shape.
        this.sum = new Matrix(numDim, 1);
        this.outerproducts = new Matrix(numDim, numDim);
        this.n = 0;
    }

    /**
     * Creates statistics holding exactly one point.
     *
     * @param x the point; its values are copied into a new column vector
     */
    public MultGaussianSuffStats(double[] x) {
        this.sum = new Matrix(x, x.length);
        this.outerproducts = this.sum.times(this.sum.transpose());
        this.n = 1;
    }

    /**
     * Copy constructor: creates independent statistics equal to {@code stats}.
     *
     * @param stats the statistics to copy
     */
    public MultGaussianSuffStats(MultGaussianSuffStats stats) {
        this.sum = stats.sum.copy();
        this.outerproducts = stats.outerproducts.copy();
        // The point count is part of the statistics and must be carried over too.
        this.n = stats.n;
    }

    /**
     * Adds a single point to the statistics.
     *
     * @param _x the point to add; its length must equal {@link #dim()}
     */
    public void add(double[] _x) {
        Matrix x = new Matrix(_x, _x.length);
        // Compute both results before assigning so a dimension mismatch
        // (rejected by Jama's plus) leaves this object unchanged.
        Matrix newSum = this.sum.plus(x);
        Matrix newOuter = this.outerproducts.plus(x.times(x.transpose()));
        this.sum = newSum;
        this.outerproducts = newOuter;
        ++this.n;
    }

    /**
     * Merges another set of statistics into this one.
     *
     * @param _stats statistics to add; must be a {@code MultGaussianSuffStats}
     *               of the same dimension
     */
    @Override
    public void add(SuffStats _stats) {
        MultGaussianSuffStats stats = (MultGaussianSuffStats)_stats;
        Matrix newSum = this.sum.plus(stats.sum);
        Matrix newOuter = this.outerproducts.plus(stats.outerproducts);
        this.sum = newSum;
        this.outerproducts = newOuter;
        this.n += stats.n;
    }

    /**
     * Removes a single point from the statistics (the inverse of {@link #add(double[])}).
     *
     * @param _x the point to remove; its length must equal {@link #dim()}
     */
    public void sub(double[] _x) {
        Matrix x = new Matrix(_x, _x.length);
        Matrix newSum = this.sum.minus(x);
        Matrix newOuter = this.outerproducts.minus(x.times(x.transpose()));
        this.sum = newSum;
        this.outerproducts = newOuter;
        --this.n;
    }

    /**
     * Removes another set of statistics from this one (the inverse of
     * {@link #add(SuffStats)}).
     *
     * @param _stats statistics to subtract; must be a {@code MultGaussianSuffStats}
     *               of the same dimension
     */
    @Override
    public void sub(SuffStats _stats) {
        MultGaussianSuffStats stats = (MultGaussianSuffStats)_stats;
        Matrix newSum = this.sum.minus(stats.sum);
        Matrix newOuter = this.outerproducts.minus(stats.outerproducts);
        this.sum = newSum;
        this.outerproducts = newOuter;
        // Subtracting statistics removes their points from the count.
        this.n -= stats.n;
    }

    /**
     * Reweighting is not supported for these statistics.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public SuffStats reweight(double scale) {
        throw new UnsupportedOperationException("unsupported");
    }

    /**
     * Returns the sum of all observed points.
     *
     * @return a fresh array of length {@link #dim()}
     */
    public double[] getSum() {
        // sum is a dim x 1 column vector, so its column-packed copy is exactly its entries in order.
        return this.sum.getColumnPackedCopy();
    }

    /**
     * Returns the sum of outer products as a 2-D array.
     *
     * @return the internal backing array of the stored matrix (not a copy);
     *         writing to it modifies these statistics
     */
    public double[][] getOuterProduct() {
        return this.outerproducts.getArray();
    }

    /**
     * Returns the stored outer-product matrix.
     *
     * @return the stored {@code dim x dim} matrix itself (not a copy)
     */
    public Matrix getMtxOuterProduct() {
        return this.outerproducts;
    }

    /**
     * Returns one component of the sum of all observed points.
     *
     * @param i component index in {@code [0, dim())}
     * @return the {@code i}-th entry of the sum
     */
    public double getSum(int i) {
        return this.sum.get(i, 0);
    }

    /**
     * Returns the stored sum vector.
     *
     * @return the stored {@code dim x 1} column vector itself (not a copy)
     */
    public Matrix getMtxSum() {
        return this.sum;
    }

    /**
     * Returns one entry of the sum of outer products.
     *
     * @param i row index
     * @param j column index
     * @return entry {@code (i, j)} of the outer-product sum
     */
    public double getOuterProduct(int i, int j) {
        return this.outerproducts.get(i, j);
    }

    /**
     * Returns the number of points currently accounted for.
     *
     * @return the point count
     */
    public int numPoints() {
        return this.n;
    }

    /**
     * Returns the dimensionality of the points.
     *
     * @return the row count of the outer-product matrix
     */
    public int dim() {
        return this.outerproducts.getRowDimension();
    }
}

