/*
 * Decompiled with CFR 0.152.
 */
package fig.prob;

import fig.basic.LogInfo;
import fig.basic.StatFig;
import fig.basic.TDoubleMap;
import fig.prob.DiagMultGaussian;
import fig.prob.DiagMultGaussianSuffStats;
import fig.prob.Dirichlet;
import fig.prob.DirichletInterface;
import fig.prob.Distrib;
import fig.prob.Gamma;
import fig.prob.Gaussian;
import fig.prob.GaussianSuffStats;
import fig.prob.MargDistrib;
import fig.prob.MargMeanDiagMultGaussian;
import fig.prob.MargMeanGaussian;
import fig.prob.MargMultinomial;
import fig.prob.MargSparseMultinomial;
import fig.prob.MultinomialSuffStats;
import fig.prob.SparseDirichlet;
import fig.prob.SparseMultinomialSuffStats;
import fig.prob.SuffStats;
import java.util.Random;

/**
 * Static helpers for {@link Distrib} / {@link MargDistrib} objects, plus Monte Carlo sanity
 * checks that print a closed-form quantity (cross entropy, expected log-likelihood) next to a
 * sample-average estimate of the same quantity.
 */
public class DistribUtils {
    /** Small tolerance value (1e-8). NOTE(review): not referenced within this class; confirm its intended use. */
    public static final double margin = 1.0E-8;

    /** Default number of Monte Carlo samples drawn by the verification helpers. */
    private static final int DEFAULT_NUM_SAMPLES = 100000;

    /**
     * Returns {@code margLogLikelihood(pred + cond) - margLogLikelihood(pred)}, where "+" is
     * {@link SuffStats#add}.
     *
     * <p>{@code predSuffStats} is temporarily modified with {@code add(condSuffStats)} and then
     * restored with {@code sub(condSuffStats)}. The restore runs in a {@code finally} so the
     * caller's stats are not left modified if {@code margLogLikelihood} throws.
     *
     * <p>NOTE(review): this assumes {@code add}/{@code sub} mutate in place and are exact
     * inverses; with floating-point counts the restored value may differ by rounding. Also, by
     * the parameter names one might expect log p(pred | cond), which would subtract
     * {@code margLogLikelihood(cond)} instead of {@code margLogLikelihood(pred)}. Confirm the
     * intended semantics against callers.
     *
     * @param margDistrib distribution whose marginal log-likelihood is evaluated
     * @param condSuffStats statistics temporarily added to {@code predSuffStats}
     * @param predSuffStats base statistics; restored before this method returns
     * @return the increase in marginal log-likelihood caused by adding {@code condSuffStats}
     */
    public static double predLogLikelihood(MargDistrib margDistrib, SuffStats condSuffStats, SuffStats predSuffStats) {
        double oldLogProb = margDistrib.margLogLikelihood(predSuffStats);
        predSuffStats.add(condSuffStats);
        try {
            double newLogProb = margDistrib.margLogLikelihood(predSuffStats);
            return newLogProb - oldLogProb;
        } finally {
            // Undo the temporary accumulation so the caller's stats are left as they were.
            predSuffStats.sub(condSuffStats);
        }
    }

    /**
     * Computes KL(d1 || d2) as {@code d1.crossEntropy(d1) - d1.crossEntropy(d2)}.
     *
     * <p>{@link #verifyCrossEntropy} checks {@code d1.crossEntropy(d2)} against the sample mean
     * of {@code d2.logProbObject(x)} for {@code x} drawn from {@code d1}. That makes it the
     * expected log-probability E_{d1}[log d2(x)]. The difference returned here is therefore
     * E_{d1}[log d1(x)] - E_{d1}[log d2(x)].
     *
     * @param d1 the reference distribution (expectations are taken under it)
     * @param d2 the comparison distribution
     * @return the KL divergence of {@code d2} from {@code d1}
     */
    public static double KL(Distrib d1, Distrib d2) {
        return d1.crossEntropy(d1) - d1.crossEntropy(d2);
    }

    /**
     * Prints {@code d1.crossEntropy(d2)}, a Monte Carlo estimate of it using
     * {@value #DEFAULT_NUM_SAMPLES} samples from {@code d1}, and their difference.
     *
     * @param d1 distribution to sample from
     * @param d2 distribution whose log-probability is averaged
     */
    public static <T> void verifyCrossEntropy(Distrib<T> d1, Distrib<T> d2) {
        verifyCrossEntropy(d1, d2, DEFAULT_NUM_SAMPLES);
    }

    /**
     * Prints {@code d1.crossEntropy(d2)}, the mean of {@code d2.logProbObject(x)} over
     * {@code numSamples} draws {@code x} from {@code d1}, and their difference, separated by
     * spaces on stdout. An unseeded {@link Random} is used, so the estimate varies between runs.
     *
     * @param d1 distribution to sample from
     * @param d2 distribution whose log-probability is averaged
     * @param numSamples number of Monte Carlo samples to draw
     */
    public static <T> void verifyCrossEntropy(Distrib<T> d1, Distrib<T> d2, int numSamples) {
        Random random = new Random();
        StatFig fig = new StatFig();
        for (int i = 0; i < numSamples; i++) {
            fig.add(d2.logProbObject(d1.sampleObject(random)));
        }
        double exact = d1.crossEntropy(d2);
        double estimate = fig.mean();
        System.out.println(exact + " " + estimate + " " + (exact - estimate));
    }

    /**
     * Same as {@link #verifyExpectedLogLikelihood(MargDistrib, SuffStats, int)} with
     * {@value #DEFAULT_NUM_SAMPLES} samples.
     */
    public static void verifyExpectedLogLikelihood(MargDistrib prior, SuffStats stats) {
        verifyExpectedLogLikelihood(prior, stats, DEFAULT_NUM_SAMPLES);
    }

    /**
     * Prints {@code prior.expectedLogLikelihood(stats)}, the mean of {@code param.logProb(stats)}
     * over {@code numSamples} parameter draws {@code param} from {@code prior}, and their
     * difference, separated by spaces on stdout. An unseeded {@link Random} is used.
     *
     * @param prior distribution over parameters; each sample is cast to {@link Distrib}
     * @param stats sufficient statistics scored under each sampled parameter
     * @param numSamples number of Monte Carlo samples to draw
     */
    public static void verifyExpectedLogLikelihood(MargDistrib prior, SuffStats stats, int numSamples) {
        Random random = new Random();
        StatFig fig = new StatFig();
        for (int i = 0; i < numSamples; i++) {
            Distrib param = (Distrib) prior.sampleObject(random);
            fig.add(param.logProb(stats));
        }
        double exact = prior.expectedLogLikelihood(stats);
        double estimate = fig.mean();
        System.out.println(exact + " " + estimate + " " + (exact - estimate));
    }

    /**
     * Runs the verification helpers on a fixed set of example distributions and prints the
     * results to stdout. Nothing is asserted, so the output must be inspected by hand. The last
     * check compares {@code Gamma.expectedLog()} against the mean of {@code log(sample)}.
     */
    public static void verifyPassed() {
        verifyCrossEntropy(new Gaussian(2.0, 0.3), new Gaussian(8.0, 1.7));
        verifyCrossEntropy(new Gamma(2.0, 0.3), new Gamma(8.0, 1.7));
        verifyCrossEntropy(new Dirichlet(10, 0.3), new Dirichlet(10, 1.7));

        TDoubleMap<String> m1 = new TDoubleMap<String>();
        m1.put("A", 3.0);
        m1.put("B", 8.0);
        m1.put("C", 0.0);
        TDoubleMap<String> m2 = new TDoubleMap<String>();
        m2.put("A", 3.0);
        m2.put("B", 0.0); // was m1.put(...): copy-paste typo that clobbered m1's B entry
        m2.put("C", 1.0);
        SparseDirichlet d1 = new SparseDirichlet(10, 0.3, m1);
        SparseDirichlet d2 = new SparseDirichlet(10, 1.7, m2);
        verifyCrossEntropy(d1, d2);

        verifyExpectedLogLikelihood(new MargMultinomial(new Dirichlet(5, 1.3)), new MultinomialSuffStats(new double[]{4.0, 21.0, 0.3, 2.0, 4.0}));
        TDoubleMap<String> m = new TDoubleMap<String>();
        m.put("A", 3.0);
        m.put("B", 8.0);
        m.put("C", 4.0);
        verifyExpectedLogLikelihood(new MargSparseMultinomial(new SparseDirichlet(10, 1.3, m)), new SparseMultinomialSuffStats(m));
        verifyExpectedLogLikelihood(new MargMeanGaussian(new Gaussian(0.0, 1.0), 1.0), new GaussianSuffStats(0.0, 0.0, 1.0));
        verifyExpectedLogLikelihood(new MargMeanGaussian(new Gaussian(2.0, 10.0), 0.7), new GaussianSuffStats(2.0, 10.0, 0.3), 1000000);
        verifyExpectedLogLikelihood(new MargMeanDiagMultGaussian(new DiagMultGaussian(new double[]{3.0, 4.0, -2.0}, new double[]{0.7, 1.5, 4.4}), new double[]{1.7, 3.5, 0.4}), new DiagMultGaussianSuffStats(new double[]{1.0, 1.0, 0.0}, new double[]{8.0, 4.0, 3.0}, 0.37), 1000000);

        Random random = new Random();
        Gamma g = new Gamma(2.0, 0.3);
        StatFig fig = new StatFig();
        for (int i = 0; i < 10000; i++) {
            fig.add(Math.log(g.sample(random)));
        }
        System.out.println(g.expectedLog() + " " + fig.mean());
    }

    /**
     * Ad-hoc driver: builds two Dirichlets and logs exp(expectedLog) for components 2 and 3.
     * It does this before and after replacing each Dirichlet with the result of
     * {@code modeSpike()}.
     */
    public static void main(String[] args) {
        LogInfo.init();
        LogInfo.msPerLine = 0;
        DirichletInterface d1 = new Dirichlet(new double[]{10001.0, 1.0, 2.0, 1.0});
        DirichletInterface d2 = new Dirichlet(new double[]{50001.0, 1.0, 6.0, 1.0});
        logExpLogComponents(d1);
        logExpLogComponents(d2);
        d1 = d1.modeSpike();
        d2 = d2.modeSpike();
        logExpLogComponents(d1);
        logExpLogComponents(d2);
    }

    /** Logs {@code exp(d.expectedLog(2))} and {@code exp(d.expectedLog(3))}, space-separated. */
    private static void logExpLogComponents(DirichletInterface d) {
        LogInfo.logs(Math.exp(d.expectedLog(2)) + " " + Math.exp(d.expectedLog(3)));
    }
}

