/*
 * Decompiled with CFR 0.152.
 */
package fig.prob;

import fig.basic.Fmt;
import fig.prob.Distrib;
import fig.prob.MultinomialSuffStats;
import fig.prob.SuffStats;
import java.util.Random;

/**
 * A discrete distribution over the indices {@code 0 .. probs.length - 1},
 * where index {@code i} has probability {@code probs[i]}.
 *
 * <p>The probability array is held by reference: the constructor does not copy
 * it and {@link #getProbs()} returns the same array, so writes made to the
 * array from outside are visible to this object.
 */
public class Multinomial implements Distrib<Integer> {
    /**
     * How far below 1 the accumulated probability mass may fall while still
     * being treated as rounding drift (rather than an unnormalized array) by
     * {@link #sample(Random, double[])}.
     */
    private static final double SAMPLE_SUM_TOLERANCE = 1e-6;

    private final double[] probs;

    /**
     * Creates a distribution backed by the given array.
     *
     * @param probs probability of each index; stored as-is, not copied
     */
    public Multinomial(double[] probs) {
        this.probs = probs;
    }

    /**
     * Returns the natural log of {@code probs[x]}.
     *
     * @param probs probability of each index
     * @param x index to look up
     * @return {@code Math.log(probs[x])}; {@code -Infinity} when the entry is 0
     * @throws ArrayIndexOutOfBoundsException if {@code x} is not a valid index
     */
    public static double logProb(double[] probs, int x) {
        return Math.log(probs[x]);
    }

    /**
     * Returns the natural log of the probability of index {@code x} under
     * this distribution.
     *
     * @param x index to look up
     * @return log-probability of {@code x}
     */
    public double logProb(int x) {
        return Multinomial.logProb(this.probs, x);
    }

    /**
     * Returns the sum over indices {@code i} of
     * {@code stats.getCount(i) * logProb(i)}.
     *
     * <p>Indices whose count is zero contribute nothing. Skipping them
     * explicitly matters when the matching probability is 0, because
     * {@code 0 * log(0)} would otherwise evaluate to NaN and poison the sum.
     *
     * @param stats sufficient statistics; must be a {@link MultinomialSuffStats}
     * @return the count-weighted sum of log-probabilities
     * @throws ClassCastException if {@code stats} is not a
     *         {@link MultinomialSuffStats}
     */
    @Override
    public double logProb(SuffStats stats) {
        MultinomialSuffStats multinomialStats = (MultinomialSuffStats) stats;
        double sum = 0.0;
        for (int i = 0; i < this.probs.length; ++i) {
            double count = multinomialStats.getCount(i);
            if (count != 0.0) {
                sum += count * this.logProb(i);
            }
        }
        return sum;
    }

    /**
     * Boxed variant of {@link #logProb(int)}.
     *
     * @param x index to look up; unboxed, so must not be {@code null}
     * @return log-probability of {@code x}
     */
    @Override
    public double logProbObject(Integer x) {
        return this.logProb(x);
    }

    /**
     * Draws an index by walking the cumulative sum of {@code probs} until it
     * exceeds a uniform draw from {@code random.nextDouble()}.
     *
     * <p>If the walk ends without exceeding the draw but the total mass is
     * within {@code SAMPLE_SUM_TOLERANCE} of 1, the shortfall is attributed to
     * floating-point rounding and the last index with positive probability is
     * returned instead of failing.
     *
     * @param random source of randomness
     * @param probs probability of each index; expected to sum to 1
     * @return the sampled index
     * @throws RuntimeException if the draw is not covered by the accumulated
     *         mass and that mass is not close to 1 (for example an empty or
     *         unnormalized array)
     */
    public static int sample(Random random, double[] probs) {
        double v = random.nextDouble();
        double sum = 0.0;
        for (int i = 0; i < probs.length; ++i) {
            sum += probs[i];
            if (v < sum) {
                return i;
            }
        }
        // The draw landed at or beyond the accumulated mass. When the mass is
        // essentially 1 this is rounding drift, so assign the draw to the last
        // outcome that can actually occur.
        if (sum >= 1.0 - SAMPLE_SUM_TOLERANCE) {
            for (int i = probs.length - 1; i >= 0; --i) {
                if (probs[i] > 0.0) {
                    return i;
                }
            }
        }
        throw new RuntimeException(String.valueOf(sum) + " < " + v);
    }

    /**
     * Draws an index from this distribution.
     *
     * @param random source of randomness
     * @return the sampled index
     * @see #sample(Random, double[])
     */
    public int sample(Random random) {
        return Multinomial.sample(random, this.probs);
    }

    /**
     * Boxed variant of {@link #sample(Random)}.
     *
     * @param random source of randomness
     * @return the sampled index
     */
    @Override
    public Integer sampleObject(Random random) {
        return this.sample(random);
    }

    /**
     * Returns the sum over indices {@code i} of
     * {@code this.probs[i] * log(that.probs[i])}. No negation is applied to
     * the sum.
     *
     * <p>Terms where {@code this.probs[i]} is 0 are skipped, so an outcome
     * that has zero probability under both distributions contributes 0
     * rather than NaN ({@code 0 * log(0)}).
     *
     * @param _that the other distribution; must be a {@code Multinomial} with
     *        at least as many entries as this one
     * @return the probability-weighted sum of the other distribution's
     *         log-probabilities
     * @throws ClassCastException if {@code _that} is not a {@code Multinomial}
     */
    @Override
    public double crossEntropy(Distrib<Integer> _that) {
        Multinomial that = (Multinomial) _that;
        double sum = 0.0;
        for (int i = 0; i < this.probs.length; ++i) {
            double p = this.probs[i];
            if (p != 0.0) {
                sum += p * Math.log(that.probs[i]);
            }
        }
        return sum;
    }

    /**
     * Returns the backing probability array itself, not a copy.
     *
     * @return the array passed to the constructor
     */
    public double[] getProbs() {
        return this.probs;
    }

    /**
     * Returns {@code "Multinomial(...)"} with the probabilities rendered by
     * {@code Fmt.D}.
     */
    @Override
    public String toString() {
        return String.format("Multinomial(%s)", Fmt.D(this.probs));
    }
}

