/*
 * Decompiled with CFR 0.152.
 */
package fig.prob;

import fig.basic.NumUtils;
import fig.basic.Pair;
import fig.basic.StatFig;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

public class SampleUtils {
    // ---- Cache used by samplePoisson -------------------------------------
    // oldm is the rate the cached values below were last computed for. It
    // starts at -1.0, and samplePoisson recomputes the cache whenever the
    // requested rate differs from it.
    private static double oldm = -1.0;
    // For rates below 12 samplePoisson stores exp(-rate) here; otherwise it
    // stores rate * log(rate) - logGamma(rate + 1).
    private static double g;
    // sqrt(2 * rate), written by samplePoisson's large-rate branch.
    private static double sq;
    // log(rate), written by samplePoisson's large-rate branch.
    private static double alxm;

    // ---- Cache used by sampleBinomial's rejection branch ------------------
    // nold / pold are the n and (folded) p the cached values were computed for.
    private static int nold;
    private static double pold;
    // 1 - p, log(p) and log(1 - p) for the cached p.
    private static double pc;
    private static double plog;
    private static double pclog;
    // n as a double, and logGamma(n + 1) for the cached n.
    private static double en;
    private static double oldg;

    // NOTE(review): all of the above are plain mutable statics and neither
    // samplePoisson nor sampleBinomial synchronizes on them, so concurrent
    // callers share (and can interleave updates to) this state.
    static {
        nold = -1;
        pold = -1.0;
    }

    /**
     * Returns a random permutation of the integers {@code 0 .. n-1}.
     *
     * <p>The array is first filled with the identity permutation and then
     * shuffled front to back (Fisher-Yates): each position is swapped with a
     * position chosen by {@code random.nextInt} among itself and everything
     * after it. Exactly {@code max(n - 1, 0)} calls to {@code nextInt} are made.
     *
     * @param random source of randomness
     * @param n      number of elements; must be non-negative
     * @return a freshly allocated array of length {@code n} holding each of
     *         {@code 0 .. n-1} exactly once
     */
    public static int[] samplePermutation(Random random, int n) {
        int[] shuffled = new int[n];
        for (int idx = 0; idx < n; idx++) {
            shuffled[idx] = idx;
        }
        // The last position needs no draw: only one candidate is left for it.
        for (int pos = 0; pos + 1 < n; pos++) {
            int pick = pos + random.nextInt(n - pos);
            int held = shuffled[pos];
            shuffled[pos] = shuffled[pick];
            shuffled[pick] = held;
        }
        return shuffled;
    }

    /**
     * Randomly splits {@code items} into a selection of {@code k} elements and
     * the remaining {@code items.size() - k} elements.
     *
     * <p>A partial shuffle of the index range picks the {@code k} selected
     * indices using exactly {@code k} calls to {@code random.nextInt}. Both
     * returned lists keep the elements in their original relative order.
     *
     * @param random source of randomness
     * @param items  list to partition; it is only read, never modified
     * @param k      number of elements to place in the first list
     * @return a pair of new lists: (selected elements, remaining elements)
     * @throws RuntimeException         if {@code k} exceeds {@code items.size()}
     * @throws IllegalArgumentException if {@code k} is negative
     */
    public static <T> Pair<List<T>, List<T>> samplePartition(Random random, List<T> items, int k) {
        int size = items.size();
        if (k > size) {
            throw new RuntimeException("Tried to select " + k + " (too many) of " + size + " items");
        }
        if (k < 0) {
            // Previously surfaced as an unexplained IllegalArgumentException from subList.
            throw new IllegalArgumentException("Tried to select a negative number of items: " + k);
        }

        // Partial shuffle over primitive indices: after the loop the first k
        // slots hold the selected indices.
        int[] order = new int[size];
        for (int idx = 0; idx < size; idx++) {
            order[idx] = idx;
        }
        for (int pos = 0; pos < k; pos++) {
            int pick = pos + random.nextInt(size - pos);
            int held = order[pos];
            order[pos] = order[pick];
            order[pick] = held;
        }

        // Mark the selected indices, then walk the items once in index order.
        // This yields both halves already sorted by original position.
        boolean[] selectedMask = new boolean[size];
        for (int pos = 0; pos < k; pos++) {
            selectedMask[order[pos]] = true;
        }
        List<T> selected = new ArrayList<T>(k);
        List<T> remaining = new ArrayList<T>(size - k);
        for (int idx = 0; idx < size; idx++) {
            if (selectedMask[idx]) {
                selected.add(items.get(idx));
            } else {
                remaining.add(items.get(idx));
            }
        }
        return new Pair<List<T>, List<T>>(selected, remaining);
    }

    /**
     * Draws a single index from the discrete distribution given by {@code probs}.
     *
     * <p>One uniform value is drawn and compared against the running sum of
     * {@code probs}; the first index whose cumulative sum strictly exceeds the
     * draw is returned.
     *
     * @param random source of randomness
     * @param probs  probability of each index
     * @return the sampled index in {@code [0, probs.length)}
     * @throws RuntimeException if the draw is not below the total of
     *         {@code probs} (for example when the entries sum to less than the
     *         drawn value)
     */
    public static int sampleMultinomial(Random random, double[] probs) {
        double threshold = random.nextDouble();
        double cumulative = 0.0;
        for (int idx = 0; idx < probs.length; idx++) {
            cumulative += probs[idx];
            if (threshold < cumulative) {
                return idx;
            }
        }
        throw new RuntimeException(String.valueOf(cumulative) + " < " + threshold);
    }

    /**
     * Builds a random vector of dimension {@code n} and rescales it by its norm.
     *
     * <p>Each component is drawn as {@code random.nextDouble() - 0.5} and the
     * whole vector is then divided by {@code NumUtils.l2Norm} of itself.
     * NOTE(review): because the components come from a cube rather than a
     * Gaussian, the resulting direction is not uniform on the sphere — confirm
     * that callers do not rely on uniformity.
     *
     * @param random source of randomness
     * @param n      dimension of the vector
     * @return a new array of length {@code n}
     */
    public static double[] sampleUnitVector(Random random, int n) {
        double[] vec = new double[n];
        for (int d = 0; d < n; d++) {
            vec[d] = random.nextDouble() - 0.5;
        }
        double length = NumUtils.l2Norm(vec);
        for (int d = 0; d < n; d++) {
            vec[d] /= length;
        }
        return vec;
    }

    /**
     * Draws a Gamma-distributed value with shape {@code a}, scaled by {@code 1 / rate}.
     *
     * <p>The structure (d = shape - 1/3, c = 1/sqrt(9d), cubed Gaussian
     * transform, 0.0331 squeeze test followed by a log test) matches the
     * Marsaglia-Tsang rejection method. Shapes below 1 are handled by sampling
     * with shape {@code a + 1} and multiplying by {@code U^(1/a)} for a
     * uniform {@code U}.
     *
     * @param random source of randomness
     * @param a      shape parameter
     * @param rate   the accepted sample is divided by this value
     * @return the sampled value
     */
    public static double sampleGamma(Random random, double a, double rate) {
        double shape = a;
        double boost = 1.0;
        if (shape < 1.0) {
            // Draw the correction factor with the original shape, then bump the shape.
            boost = Math.exp(Math.log(random.nextDouble()) / shape);
            shape += 1.0;
        }
        double d = shape - 1.0 / 3.0;
        double c = 1.0 / Math.sqrt(9.0 * d);
        double cubed;
        while (true) {
            double z = SampleUtils.sampleGaussian(random);
            double base = 1.0 + c * z;
            if (base <= 0.0) {
                // Outside the support of the transform: redraw the Gaussian.
                continue;
            }
            cubed = base * base * base;
            double zSq = z * z;
            double u = random.nextDouble();
            // Cheap squeeze test first; fall back to the exact log test.
            if (u < 1.0 - 0.0331 * zSq * zSq) {
                break;
            }
            if (Math.log(u) < 0.5 * zSq + d * (1.0 - cubed + Math.log(cubed))) {
                break;
            }
        }
        return boost * d * cubed / rate;
    }

    /**
     * Draws a Gamma-distributed value with integer shape {@code ia} (an Erlang
     * variate), scaled by {@code 1 / rate}.
     *
     * <p>For {@code ia < 6} the sample is minus the log of a product of
     * {@code ia} uniforms. For larger shapes a rejection loop is used: a point
     * in the unit disc gives a tangent {@code y}, a candidate
     * {@code x = s*y + (ia-1)} is formed, and it is accepted against the ratio
     * {@code (1 + y^2) * exp((ia-1) * log(x / (ia-1)) - s*y)}.
     *
     * @param random source of randomness
     * @param ia     integer shape; asserted to be at least 1
     * @param rate   the sample is divided by this value
     * @return the sampled value
     */
    public static double sampleErlang(Random random, int ia, double rate) {
        assert (ia >= 1);
        double sample;
        if (ia < 6) {
            double product = 1.0;
            for (int draw = 0; draw < ia; draw++) {
                product *= random.nextDouble();
            }
            sample = -Math.log(product);
        } else {
            // Both quantities depend only on ia, so compute them once.
            double am = ia - 1;
            double s = Math.sqrt(2.0 * am + 1.0);
            double acceptance;
            do {
                double y;
                do {
                    double v1;
                    double v2;
                    // Rejection-sample a point inside the unit disc.
                    do {
                        v1 = 2.0 * random.nextDouble() - 1.0;
                        v2 = 2.0 * random.nextDouble() - 1.0;
                    } while (v1 * v1 + v2 * v2 > 1.0);
                    y = v2 / v1;
                    sample = s * y + am;
                } while (sample <= 0.0);
                acceptance = (1.0 + y * y) * Math.exp(am * Math.log(sample / am) - s * y);
            } while (random.nextDouble() > acceptance);
        }
        return sample / rate;
    }

    /**
     * Draws a standard normal value using the Box-Muller transform
     * {@code sqrt(-2 ln x1) * cos(2 pi x2)} of two uniform draws.
     *
     * <p>{@code Random.nextDouble()} may return exactly {@code 0.0}, for which
     * {@code ln x1} is negative infinity and the result would be infinite.
     * Such a draw is discarded and redrawn; every other draw is used
     * unchanged, so results for a given seed are the same as before except in
     * that one degenerate case.
     *
     * @param random source of randomness
     * @return a finite sample from the standard normal distribution
     */
    public static double sampleGaussian(Random random) {
        double x1;
        do {
            x1 = random.nextDouble();
        } while (x1 == 0.0);
        double x2 = random.nextDouble();
        return Math.sqrt(-2.0 * Math.log(x1)) * Math.cos(Math.PI * 2 * x2);
    }

    /**
     * Draws a Poisson-distributed count with mean {@code rate}.
     *
     * <p>For {@code rate < 12} uniforms are multiplied together until the
     * product is no longer above {@code exp(-rate)}; the result is the number
     * of uniforms drawn minus one. For larger rates a rejection loop proposes
     * {@code sq * tan(pi * U) + rate}, floors it, and accepts against a ratio
     * built from {@code NumUtils.logGamma}. Quantities that depend only on
     * {@code rate} are cached in the static fields {@code oldm}, {@code g},
     * {@code sq} and {@code alxm} and recomputed when the rate changes.
     * NOTE(review): this looks like the classic Numerical Recipes "poidev"
     * routine — confirm before relying on that.
     *
     * @param random source of randomness
     * @param rate   mean of the distribution
     * @return the sampled count, truncated through {@code int} and returned as a double
     */
    public static double samplePoisson(Random random, double rate) {
        double count;
        if (rate < 12.0) {
            if (rate != oldm) {
                oldm = rate;
                g = Math.exp(-rate);
            }
            count = -1.0;
            double product = 1.0;
            do {
                count += 1.0;
                product *= random.nextDouble();
            } while (product > g);
        } else {
            if (rate != oldm) {
                oldm = rate;
                sq = Math.sqrt(2.0 * rate);
                alxm = Math.log(rate);
                g = rate * alxm - NumUtils.logGamma(rate + 1.0);
            }
            double acceptance;
            do {
                double y;
                // Redraw until the proposal is non-negative.
                do {
                    y = Math.tan(Math.PI * random.nextDouble());
                    count = sq * y + rate;
                } while (count < 0.0);
                count = Math.floor(count);
                acceptance = 0.9 * (1.0 + y * y) * Math.exp(count * alxm - NumUtils.logGamma(count + 1.0) - g);
            } while (random.nextDouble() > acceptance);
        }
        return (int)count;
    }

    /**
     * Draws the number of successes in {@code n} trials that each succeed with
     * probability {@code pp}.
     *
     * <p>Internally the method works with {@code p = min(pp, 1 - pp)} and, if
     * that differs from {@code pp}, returns {@code n} minus the sampled value.
     * One of three strategies is used:
     * <ul>
     *   <li>{@code n < 25}: simulate the {@code n} trials directly;</li>
     *   <li>{@code n * p < 1}: multiply uniforms until the product drops below
     *       {@code exp(-n * p)}, capped at {@code n};</li>
     *   <li>otherwise: a rejection loop with tangent-transformed proposals and
     *       an acceptance ratio built from {@code NumUtils.logGamma}.</li>
     * </ul>
     * NOTE(review): this looks like the Numerical Recipes "bnldev" routine —
     * confirm before relying on that.
     *
     * <p>The rejection branch caches values derived from {@code n} and
     * {@code p} in the static fields {@code nold}, {@code en}, {@code oldg},
     * {@code pold}, {@code pc}, {@code plog} and {@code pclog}.
     *
     * @param random source of randomness
     * @param n      number of trials
     * @param pp     success probability of a single trial
     * @return the sampled number of successes
     */
    public static int sampleBinomial(Random random, int n, double pp) {
        double bnl;
        // Work with the smaller of pp and 1 - pp; the result is mirrored at the end.
        double p = pp <= 0.5 ? pp : 1.0 - pp;
        // Mean of the (folded) distribution.
        double am = (double)n * p;
        if (n < 25) {
            // Few trials: simulate each one.
            bnl = 0.0;
            int j = 1;
            while (j <= n) {
                if (random.nextDouble() < p) {
                    bnl += 1.0;
                }
                ++j;
            }
        } else if (am < 1.0) {
            // Small mean: count how many uniforms can be multiplied together
            // before the product falls below exp(-am), with the count capped at n.
            // This local g shadows the static field of the same name.
            double g = Math.exp(-am);
            double t = 1.0;
            int j = 0;
            while (j <= n) {
                if ((t *= random.nextDouble()) < g) break;
                ++j;
            }
            bnl = j <= n ? j : n;
        } else {
            double em;
            // Refresh the n-dependent cache only when n changed since the last call.
            if (n != nold) {
                en = n;
                oldg = NumUtils.logGamma(en + 1.0);
                nold = n;
            }
            // Refresh the p-dependent cache only when p changed since the last call.
            if (p != pold) {
                pc = 1.0 - p;
                plog = Math.log(p);
                pclog = Math.log(pc);
                pold = p;
            }
            // This local sq shadows the static field of the same name.
            double sq = Math.sqrt(2.0 * am * pc);
            while (true) {
                double angle;
                double y;
                // Propose sq * tan(pi * U) + am; redraw while it falls outside [0, n + 1).
                if ((em = sq * (y = Math.tan(angle = Math.PI * random.nextDouble())) + am) < 0.0 || em >= en + 1.0) {
                    continue;
                }
                em = Math.floor(em);
                // Acceptance ratio for the floored proposal; accept unless a fresh uniform exceeds it.
                double t = 1.2 * sq * (1.0 + y * y) * Math.exp(oldg - NumUtils.logGamma(em + 1.0) - NumUtils.logGamma(en - em + 1.0) + em * plog + (en - em) * pclog);
                if (!(random.nextDouble() > t)) break;
            }
            bnl = em;
        }
        // Undo the folding: if 1 - pp was sampled, mirror the count.
        if (p != pp) {
            bnl = (double)n - bnl;
        }
        return (int)bnl;
    }

    /**
     * Distributes {@code n} draws over the categories of {@code probs} by
     * sampling one conditional binomial per category, left to right.
     *
     * <p>Category {@code i} receives a binomial draw out of the draws still
     * unassigned, with success probability {@code probs[i]} divided by the
     * probability mass not yet consumed. The last category receives whatever
     * is left.
     *
     * @param random source of randomness
     * @param n      total number of draws to distribute
     * @param probs  probability of each category
     * @return a new array of per-category counts, same length as {@code probs}
     * @throws IllegalArgumentException if {@code probs} is empty while
     *         {@code n} is non-zero (previously an
     *         {@code ArrayIndexOutOfBoundsException} from indexing {@code -1})
     */
    public static int[] sampleMultinomialNaive(Random random, int n, double[] probs) {
        int K = probs.length;
        if (K == 0) {
            // No categories: only the empty assignment of zero draws is valid.
            if (n == 0) {
                return new int[0];
            }
            throw new IllegalArgumentException("Cannot distribute " + n + " draws over an empty probability vector");
        }
        int[] counts = new int[K];
        int remaining = n;
        double massRemaining = 1.0;
        for (int i = 0; i < K - 1; i++) {
            counts[i] = SampleUtils.sampleBinomial(random, remaining, probs[i] / massRemaining);
            remaining -= counts[i];
            massRemaining -= probs[i];
        }
        counts[K - 1] = remaining;
        return counts;
    }

    /**
     * Distributes {@code n} draws over the categories of {@code probs} by
     * recursive halving.
     *
     * <p>Cumulative sums of {@code probs} are precomputed, then a
     * {@code MultinomialSampler} repeatedly splits the category range in two
     * and uses one binomial draw to decide how many of the draws fall into the
     * left half.
     *
     * @param random source of randomness
     * @param n      total number of draws to distribute
     * @param probs  probability of each category
     * @return a new array of per-category counts, same length as {@code probs}
     * @throws IllegalArgumentException if {@code probs} is empty while
     *         {@code n} is non-zero (previously the recursion was entered with
     *         an empty range, which never reaches its base case)
     */
    public static int[] sampleMultinomial(Random random, int n, double[] probs) {
        int K = probs.length;
        if (K == 0) {
            // No categories: only the empty assignment of zero draws is valid.
            if (n == 0) {
                return new int[0];
            }
            throw new IllegalArgumentException("Cannot distribute " + n + " draws over an empty probability vector");
        }
        MultinomialSampler sampler = new MultinomialSampler();
        sampler.random = random;
        // accumProbs[i] holds the sum of probs[0 .. i-1]; accumProbs[0] is 0.
        sampler.accumProbs = new double[K + 1];
        for (int i = 0; i < K; i++) {
            sampler.accumProbs[i + 1] = sampler.accumProbs[i] + probs[i];
        }
        sampler.counts = new int[K];
        sampler.sample(n, 0, K);
        return sampler.counts;
    }

    /**
     * Small demo: repeatedly samples multinomial counts for a fixed
     * four-category distribution with a fixed seed, accumulates each
     * category's count in a {@code StatFig}, and prints one summary line per
     * category to standard output.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        final int trials = 100000;
        final int drawsPerTrial = 1000;

        Random random = new Random(1L);
        double[] probs = new double[]{0.2, 0.7, 0.08, 0.02};
        int K = probs.length;

        StatFig[] figs = new StatFig[K];
        for (int c = 0; c < K; c++) {
            figs[c] = new StatFig();
        }

        for (int t = 0; t < trials; t++) {
            int[] counts = SampleUtils.sampleMultinomial(random, drawsPerTrial, probs);
            for (int c = 0; c < K; c++) {
                figs[c].add(counts[c]);
            }
        }

        for (int c = 0; c < K; c++) {
            System.out.println(String.valueOf(c) + " (" + probs[c] + "): " + figs[c]);
        }
    }

    /**
     * Helper that splits a number of draws across a range of categories by
     * recursive halving. The caller fills in all three fields before calling
     * {@code sample}.
     */
    static class MultinomialSampler {
        /** Source of randomness handed to {@code SampleUtils.sampleBinomial}. */
        Random random;
        /** Cumulative probabilities; the mass of categories {@code [i, j)} is {@code accumProbs[j] - accumProbs[i]}. */
        double[] accumProbs;
        /** Output: number of draws assigned to each category. */
        int[] counts;

        MultinomialSampler() {
        }

        /**
         * Assigns {@code n} draws to the categories in {@code [i, j)}.
         *
         * @param n number of draws to distribute over the range
         * @param i first category of the range (inclusive)
         * @param j end of the range (exclusive); asserted to be greater than {@code i}
         */
        private void sample(int n, int i, int j) {
            assert (i < j) : String.valueOf(i) + " " + j;
            if (j - i == 1) {
                // Single category left: it takes every remaining draw.
                counts[i] = n;
                return;
            }
            int mid = (i + j) / 2;
            double leftMass = accumProbs[mid] - accumProbs[i];
            double rangeMass = accumProbs[j] - accumProbs[i];
            // Number of draws landing in the left half, given they land in [i, j).
            int leftCount = SampleUtils.sampleBinomial(random, n, leftMass / rangeMass);
            sample(leftCount, i, mid);
            sample(n - leftCount, mid, j);
        }
    }
}

