/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.classify;

import edu.stanford.nlp.classify.AbstractLinearClassifierFactory;
import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.ling.BasicDatum;
import edu.stanford.nlp.optimization.GoldenSectionLineSearch;
import edu.stanford.nlp.util.Function;

/**
 * Factory that produces the weight matrix of a linear classifier from
 * count-based (Naive-Bayes-style) estimates.
 *
 * <p>For every feature {@code f} and class {@code c} the weight is
 * {@code log( P(c|f) / P(c) )}, where {@code P(c|f)} is smoothed by adding
 * {@code sigma} to each feature/class count. If
 * {@code interpretAlwaysOnFeatureAsPrior} is set, a feature whose occurrence
 * count equals the number of training items instead receives the log of the
 * class frequency, {@code log( #items in c / #items )}.
 *
 * <p>{@code sigma} can optionally be chosen by cross-validation; see
 * {@link #setTuneSigmaCV(int)}.
 *
 * @param <L> label type
 * @param <F> feature type
 */
public class NBLinearClassifierFactory<L, F>
extends AbstractLinearClassifierFactory<L, F> {
    private static final boolean VERBOSE = false;
    /** Additive smoothing applied to each feature/class count. */
    private double sigma;
    /** Whether a feature seen once per training item is turned into a class-prior weight. */
    private final boolean interpretAlwaysOnFeatureAsPrior;
    /** Tiny constant that keeps the class-prior estimate strictly positive. */
    private static final double epsilon = 1.0E-30;
    /** When true, {@code sigma} is re-estimated by cross-validation before training. */
    private boolean tuneSigma = false;
    /** Number of cross-validation folds used when tuning {@code sigma}. */
    private int folds;
    private static final long serialVersionUID = 1L;

    /**
     * Trains weights from a dataset by delegating to
     * {@link #trainWeights(int[][], int[])} with the dataset's data and label arrays.
     *
     * @param data the training dataset
     * @return weight matrix indexed as {@code [feature][class]}
     */
    @Override
    protected double[][] trainWeights(GeneralDataset<L, F> data) {
        return this.trainWeights(data.getDataArray(), data.getLabelsArray());
    }

    /**
     * Trains weights on the full data, first tuning {@code sigma} by
     * cross-validation if {@link #setTuneSigmaCV(int)} was called.
     *
     * @param data   one row per training item; each entry is a feature index
     * @param labels class index of each training item
     * @return weight matrix indexed as {@code [feature][class]}
     */
    double[][] trainWeights(int[][] data, int[] labels) {
        if (this.tuneSigma) {
            this.tuneSigma(data, labels);
        }
        // An empty held-out range [0, 0) means every item is used for training.
        return this.estimateWeights(data, labels, 0, 0, this.sigma, data.length);
    }

    /**
     * Computes weights from all items except those in the held-out range
     * {@code [testMin, testMax)}, using {@code trialSigma} for smoothing.
     *
     * @param data       one row per item; each entry is a feature index
     * @param labels     class index of each item
     * @param testMin    first held-out item index (inclusive)
     * @param testMax    end of the held-out range (exclusive)
     * @param trialSigma smoothing value to use instead of the configured sigma
     * @param foldSize   number of held-out items; {@code data.length - foldSize}
     *                   is the occurrence count at which a feature is treated as
     *                   always-on
     * @return weight matrix indexed as {@code [feature][class]}
     */
    double[][] weights(int[][] data, int[] labels, int testMin, int testMax, double trialSigma, int foldSize) {
        return this.estimateWeights(data, labels, testMin, testMax, trialSigma, data.length - foldSize);
    }

    /**
     * Shared estimator behind {@link #trainWeights(int[][], int[])} and
     * {@link #weights(int[][], int[], int, int, double, int)}: gathers counts
     * over all items outside {@code [heldOutStart, heldOutEnd)} and turns them
     * into log-ratio weights.
     */
    private double[][] estimateWeights(int[][] data, int[] labels, int heldOutStart, int heldOutEnd,
            double smoothing, int alwaysOnCount) {
        int numFeatures = this.numFeatures();
        int numClasses = this.numClasses();
        double[][] weights = new double[numFeatures][numClasses];

        int numItems = 0;                                                 // training items seen
        double[] itemsPerClass = new double[numClasses];                  // training items per class
        double totalActive = 0.0;                                         // feature occurrences overall
        double[] activePerClass = new double[numClasses];                 // feature occurrences per class
        double[] featureCount = new double[numFeatures];                  // occurrences of each feature
        double[][] featureClassCount = new double[numFeatures][numClasses]; // occurrences of feature in class

        for (int d = 0; d < data.length; ++d) {
            // Only whole items are held out; the fold range is never applied to
            // feature positions inside an item.
            if (d >= heldOutStart && d < heldOutEnd) {
                continue;
            }
            int label = labels[d];
            ++numItems;
            itemsPerClass[label] += 1.0;
            for (int feature : data[d]) {
                totalActive += 1.0;
                activePerClass[label] += 1.0;
                featureCount[feature] += 1.0;
                featureClassCount[feature][label] += 1.0;
            }
        }

        for (int c = 0; c < numClasses; ++c) {
            // P(c) over feature occurrences; epsilon keeps it non-zero for unseen classes.
            double pClass = (activePerClass[c] + epsilon) / (totalActive + (double)numClasses * epsilon);
            for (int f = 0; f < numFeatures; ++f) {
                if (this.interpretAlwaysOnFeatureAsPrior && featureCount[f] == (double)alwaysOnCount) {
                    weights[f][c] = Math.log(itemsPerClass[c] / (double)numItems);
                    continue;
                }
                double pClassGivenFeature = (featureClassCount[f][c] + smoothing)
                        / (featureCount[f] + smoothing * (double)numClasses);
                weights[f][c] = Math.log(pClassGivenFeature / pClass);
            }
        }
        return weights;
    }

    /**
     * Picks {@code sigma} by minimizing the cross-validated negative
     * log-probability of the correct labels with a golden-section line search,
     * and stores the result in {@code this.sigma}.
     *
     * <p>The data is split into {@code folds} contiguous folds of
     * {@code data.length / folds} items; if there are fewer items than folds,
     * each item forms its own fold.
     */
    private void tuneSigma(final int[][] data, final int[] labels) {
        Function<Double, Double> cvSigmaToPerplexity = new Function<Double, Double>(){

            @Override
            public Double apply(Double trialSigma) {
                int nbCV;
                int foldSize;
                double sumScore = 0.0;
                System.err.println("Trying sigma = " + trialSigma);
                if (data.length >= NBLinearClassifierFactory.this.folds) {
                    foldSize = data.length / NBLinearClassifierFactory.this.folds;
                    nbCV = NBLinearClassifierFactory.this.folds;
                } else {
                    foldSize = 1;
                    nbCV = data.length;
                }
                for (int j = 0; j < nbCV; ++j) {
                    int testMin = j * foldSize;
                    int testMax = testMin + foldSize;
                    LinearClassifier<L, F> c = new LinearClassifier<L, F>(
                            NBLinearClassifierFactory.this.weights(data, labels, testMin, testMax, trialSigma, foldSize),
                            NBLinearClassifierFactory.this.featureIndex,
                            NBLinearClassifierFactory.this.labelIndex);
                    // The score is kept per fold so that each fold contributes
                    // exactly once to the total.
                    double foldScore = 0.0;
                    for (int i = testMin; i < testMax; ++i) {
                        foldScore -= c.logProbabilityOf(
                                new BasicDatum<L, F>(NBLinearClassifierFactory.this.featureIndex.objects(data[i])))
                                .getCount(NBLinearClassifierFactory.this.labelIndex.get(labels[i]));
                    }
                    sumScore += foldScore;
                }
                System.err.printf(": %8g%n", sumScore);
                return sumScore;
            }
        };
        GoldenSectionLineSearch gsls = new GoldenSectionLineSearch(true);
        // NOTE(review): the arguments are presumably (function, tolerance, low, high) -
        // confirm against GoldenSectionLineSearch.minimize.
        this.sigma = gsls.minimize(cvSigmaToPerplexity, 0.01, 1.0E-4, 2.0);
        System.out.println("Sigma used: " + this.sigma);
    }

    /** Creates a factory with {@code sigma = 1.0} and no always-on-feature handling. */
    public NBLinearClassifierFactory() {
        this(1.0);
    }

    /**
     * Creates a factory with the given smoothing and no always-on-feature handling.
     *
     * @param sigma additive smoothing applied to each feature/class count
     */
    public NBLinearClassifierFactory(double sigma) {
        this(sigma, false);
    }

    /**
     * Creates a factory.
     *
     * @param sigma additive smoothing applied to each feature/class count
     * @param interpretAlwaysOnFeatureAsPrior if true, a feature whose occurrence
     *        count equals the number of training items is given the log class
     *        frequency as its weight
     */
    public NBLinearClassifierFactory(double sigma, boolean interpretAlwaysOnFeatureAsPrior) {
        this.sigma = sigma;
        this.interpretAlwaysOnFeatureAsPrior = interpretAlwaysOnFeatureAsPrior;
    }

    /**
     * Enables tuning of {@code sigma} by cross-validation at training time.
     *
     * @param folds number of cross-validation folds; must be at least 1
     * @throws IllegalArgumentException if {@code folds} is less than 1
     */
    public void setTuneSigmaCV(int folds) {
        if (folds < 1) {
            // A zero fold count would otherwise surface later as a division by
            // zero, and a negative one would make the search objective constant.
            throw new IllegalArgumentException("folds must be >= 1, got " + folds);
        }
        this.tuneSigma = true;
        this.folds = folds;
    }
}

