/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.classify;

import edu.stanford.nlp.classify.BiasedLogisticObjectiveFunction;
import edu.stanford.nlp.classify.ClassifierFactory;
import edu.stanford.nlp.classify.Dataset;
import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.classify.LogPrior;
import edu.stanford.nlp.classify.LogisticClassifier;
import edu.stanford.nlp.classify.LogisticObjectiveFunction;
import edu.stanford.nlp.classify.RVFDataset;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.optimization.DiffFunction;
import edu.stanford.nlp.optimization.Minimizer;
import edu.stanford.nlp.optimization.QNMinimizer;
import edu.stanford.nlp.util.ErasureUtils;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.ReflectionLoading;
import java.util.List;

/**
 * Trains {@link LogisticClassifier}s for binary (exactly two-label) classification problems.
 * <p>
 * The unbiased training paths accept a {@link Dataset} (binary features) or an {@link RVFDataset}
 * (real-valued features). After each training call, the weights, feature index and label array of
 * the classifier just trained are also stored on this factory instance.
 *
 * @param <L> the label type
 * @param <F> the feature type
 */
public class LogisticClassifierFactory<L, F>
implements ClassifierFactory<L, F, LogisticClassifier<L, F>> {
    private static final long serialVersionUID = 1L;
    /** Convergence tolerance used by the overloads that do not take an explicit one. */
    private static final double DEFAULT_TOLERANCE = 1.0E-4;
    /** Minimizer loaded by reflection when an L1 penalty is requested. */
    private static final String OWLQN_MINIMIZER_CLASS = "edu.stanford.nlp.optimization.OWLQNMinimizer";

    private double[] weights;
    private Index<F> featureIndex;
    private L[] classes = ErasureUtils.mkTArray(Object.class, 2);

    /**
     * Trains a classifier in which each datum's contribution to the objective is weighted.
     * Uses a quadratic prior, a tolerance of 1e-4 and a {@link QNMinimizer}.
     *
     * @param data a two-label {@link Dataset} or {@link RVFDataset}
     * @param dataWeights per-datum weights, passed through to {@link LogisticObjectiveFunction}
     * @return the trained classifier
     * @throws IllegalArgumentException if {@code data} does not have exactly two labels, or is
     *     neither a {@link Dataset} nor an {@link RVFDataset}
     */
    public LogisticClassifier<L, F> trainWeightedData(GeneralDataset<L, F> data, float[] dataWeights) {
        prepareBinaryData(data);
        LogPrior prior = new LogPrior(LogPrior.LogPriorType.QUADRATIC);
        LogisticObjectiveFunction lof;
        if (data instanceof Dataset) {
            lof = new LogisticObjectiveFunction(data.numFeatureTypes(), data.getDataArray(), data.getLabelsArray(), prior, dataWeights);
        } else if (data instanceof RVFDataset) {
            lof = new LogisticObjectiveFunction(data.numFeatureTypes(), data.getDataArray(), data.getValuesArray(), data.getLabelsArray(), prior, dataWeights);
        } else {
            // Previously lof stayed null here and failed later inside the minimizer.
            throw unsupportedDataset(data);
        }
        QNMinimizer minim = new QNMinimizer(lof);
        double[] trained = minim.minimize(lof, DEFAULT_TOLERANCE, new double[data.numFeatureTypes()]);
        return finishTraining(data, trained);
    }

    /**
     * Trains an unbiased classifier with a quadratic prior, no L1 penalty and a tolerance of 1e-4.
     */
    @Override
    public LogisticClassifier<L, F> trainClassifier(GeneralDataset<L, F> data) {
        return this.trainClassifier(data, 0.0);
    }

    /** Trains with the given prior, no L1 penalty and a tolerance of 1e-4. */
    public LogisticClassifier<L, F> trainClassifier(GeneralDataset<L, F> data, LogPrior prior, boolean biased) {
        return this.trainClassifier(data, 0.0, DEFAULT_TOLERANCE, prior, biased);
    }

    /** Trains an unbiased classifier with the given L1 penalty and a tolerance of 1e-4. */
    public LogisticClassifier<L, F> trainClassifier(GeneralDataset<L, F> data, double l1reg) {
        return this.trainClassifier(data, l1reg, DEFAULT_TOLERANCE);
    }

    /** Trains an unbiased classifier with a quadratic prior. */
    public LogisticClassifier<L, F> trainClassifier(GeneralDataset<L, F> data, double l1reg, double tol) {
        return this.trainClassifier(data, l1reg, tol, new LogPrior(LogPrior.LogPriorType.QUADRATIC), false);
    }

    /** Trains an unbiased classifier with the given prior. */
    public LogisticClassifier<L, F> trainClassifier(GeneralDataset<L, F> data, double l1reg, double tol, LogPrior prior) {
        return this.trainClassifier(data, l1reg, tol, prior, false);
    }

    /** Trains with a quadratic prior. */
    public LogisticClassifier<L, F> trainClassifier(GeneralDataset<L, F> data, double l1reg, double tol, boolean biased) {
        return this.trainClassifier(data, l1reg, tol, new LogPrior(LogPrior.LogPriorType.QUADRATIC), biased);
    }

    /**
     * Trains a binary logistic classifier.
     *
     * @param data a two-label dataset. For the unbiased path it must be a {@link Dataset} or an
     *     {@link RVFDataset}.
     * @param l1reg L1 penalty. When it is {@code > 0}, an OWL-QN minimizer is loaded by reflection.
     *     Otherwise a {@link QNMinimizer} is used.
     * @param tol convergence tolerance passed to the minimizer
     * @param prior prior over the weights
     * @param biased if true, optimizes a {@link BiasedLogisticObjectiveFunction} instead of a
     *     {@link LogisticObjectiveFunction}
     * @return the trained classifier
     * @throws IllegalArgumentException if {@code data} does not have exactly two labels, or (in the
     *     unbiased path) is neither a {@link Dataset} nor an {@link RVFDataset}
     */
    public LogisticClassifier<L, F> trainClassifier(GeneralDataset<L, F> data, double l1reg, double tol, LogPrior prior, boolean biased) {
        prepareBinaryData(data);
        DiffFunction objective;
        if (biased) {
            // NOTE(review): only feature indices and labels are passed here. For an RVFDataset the
            // real feature values are not used by this objective. Confirm this is intended.
            objective = new BiasedLogisticObjectiveFunction(data.numFeatureTypes(), data.getDataArray(), data.getLabelsArray(), prior);
        } else if (data instanceof Dataset) {
            objective = new LogisticObjectiveFunction(data.numFeatureTypes(), data.getDataArray(), data.getLabelsArray(), prior);
        } else if (data instanceof RVFDataset) {
            objective = new LogisticObjectiveFunction(data.numFeatureTypes(), data.getDataArray(), data.getValuesArray(), data.getLabelsArray(), prior);
        } else {
            throw unsupportedDataset(data);
        }
        Minimizer<DiffFunction> minim = newMinimizer(l1reg, objective);
        double[] trained = minim.minimize(objective, tol, new double[data.numFeatureTypes()]);
        return finishTraining(data, trained);
    }

    /**
     * Not implemented: always returns {@code null}.
     *
     * @deprecated use one of the {@link GeneralDataset}-based {@code trainClassifier} overloads instead
     */
    @Override
    @Deprecated
    public LogisticClassifier<L, F> trainClassifier(List<RVFDatum<L, F>> examples) {
        return null;
    }

    /** Ensures real-valued storage for RVF data, then rejects datasets without exactly two labels. */
    private static <L, F> void prepareBinaryData(GeneralDataset<L, F> data) {
        if (data instanceof RVFDataset) {
            ((RVFDataset<L, F>) data).ensureRealValues();
        }
        int numLabels = data.labelIndex.size();
        if (numLabels != 2) {
            throw new IllegalArgumentException("LogisticClassifier is only for binary classification! (dataset has " + numLabels + " labels)");
        }
    }

    /** Builds the error raised for dataset implementations this factory cannot train on. */
    private static IllegalArgumentException unsupportedDataset(GeneralDataset<?, ?> data) {
        return new IllegalArgumentException("LogisticClassifier cannot train on dataset type " + data.getClass().getName());
    }

    /** Picks OWL-QN (loaded reflectively) when an L1 penalty is requested, else plain QN. */
    private static Minimizer<DiffFunction> newMinimizer(double l1reg, DiffFunction objective) {
        if (l1reg > 0.0) {
            return ReflectionLoading.loadByReflection(OWLQN_MINIMIZER_CLASS, l1reg);
        }
        return new QNMinimizer(objective);
    }

    /**
     * Records the training result on this factory and wraps it in a classifier.
     * A fresh label array is allocated on every call. Previously one shared array was overwritten
     * in place, so retraining could alter the labels of classifiers returned earlier.
     */
    private LogisticClassifier<L, F> finishTraining(GeneralDataset<L, F> data, double[] trainedWeights) {
        L[] labels = ErasureUtils.mkTArray(Object.class, 2);
        labels[0] = data.labelIndex.get(0);
        labels[1] = data.labelIndex.get(1);
        this.weights = trainedWeights;
        this.featureIndex = data.featureIndex;
        this.classes = labels;
        return new LogisticClassifier<L, F>(this.weights, this.featureIndex, this.classes);
    }
}

