/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.optimization;

import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.AbstractCachingDiffFunction;
import edu.stanford.nlp.optimization.AbstractStochasticCachingDiffFunction;
import edu.stanford.nlp.optimization.AbstractStochasticCachingDiffUpdateFunction;
import edu.stanford.nlp.optimization.DiffFunction;
import edu.stanford.nlp.optimization.Evaluator;
import edu.stanford.nlp.optimization.HasEvaluators;
import edu.stanford.nlp.optimization.HasFeatureGrouping;
import edu.stanford.nlp.optimization.HasRegularizerParamRange;
import edu.stanford.nlp.optimization.Minimizer;
import edu.stanford.nlp.util.Timing;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;

public class SGDWithAdaGradAndFOBOS<T extends DiffFunction>
implements Minimizer<T>,
HasEvaluators {
    protected double[] x;
    protected double initRate;
    protected double lambda;
    protected double alpha = 1.0;
    protected boolean quiet = false;
    private static final int DEFAULT_NUM_PASSES = 50;
    protected final int numPasses;
    protected int bSize = 1;
    private static final int DEFAULT_TUNING_SAMPLES = Integer.MAX_VALUE;
    private static final int DEFAULT_BATCH_SIZE = 1000;
    private double eps = 0.001;
    private double TOL = 1.0E-4;
    public List<double[]> yList = null;
    public List<double[]> sList = null;
    public double[] diag;
    private int hessSampleSize = -1;
    private double[] s;
    private double[] y = null;
    protected Random gen = new Random(1L);
    protected long maxTime = Long.MAX_VALUE;
    private int evaluateIters = 0;
    private Evaluator[] evaluators;
    private Prior prior = Prior.LASSO;
    private boolean useEvalImprovement = false;
    private boolean useAvgImprovement = false;
    private boolean suppressTestPrompt = false;
    private int terminateOnEvalImprovementNumOfEpoch = 1;
    private double bestEvalSoFar = Double.NEGATIVE_INFINITY;
    private double[] xBest;
    private int noImproveItrCount = 0;
    private boolean useAdaDelta = false;
    private boolean useAdaDiff = false;
    private double rho = 0.95;
    private double[] sumGradSquare;
    private double[] prevGrad;
    private double[] prevDeltaX;
    private double[] sumDeltaXSquare;
    private static final NumberFormat nf = new DecimalFormat("0.000E0");

    public void setHessSampleSize(int hessSize) {
        this.hessSampleSize = hessSize;
    }

    public void terminateOnEvalImprovement(boolean toTerminate) {
        this.useEvalImprovement = toTerminate;
    }

    public void terminateOnAvgImprovement(boolean toTerminate, double tolerance) {
        this.useAvgImprovement = toTerminate;
        this.TOL = tolerance;
    }

    public void suppressTestPrompt(boolean suppressTestPrompt) {
        this.suppressTestPrompt = suppressTestPrompt;
    }

    public void setTerminateOnEvalImprovementNumOfEpoch(int terminateOnEvalImprovementNumOfEpoch) {
        this.terminateOnEvalImprovementNumOfEpoch = terminateOnEvalImprovementNumOfEpoch;
    }

    public boolean toContinue(double[] x, double currEval) {
        if (currEval >= this.bestEvalSoFar) {
            this.bestEvalSoFar = currEval;
            this.noImproveItrCount = 0;
            if (this.xBest == null) {
                this.xBest = Arrays.copyOf(x, x.length);
            } else {
                System.arraycopy(x, 0, this.xBest, 0, x.length);
            }
            return true;
        }
        ++this.noImproveItrCount;
        return this.noImproveItrCount <= this.terminateOnEvalImprovementNumOfEpoch;
    }

    private static Prior getPrior(String priorType) {
        if (priorType.equals("none")) {
            return Prior.NONE;
        }
        if (priorType.equals("lasso")) {
            return Prior.LASSO;
        }
        if (priorType.equals("ridge")) {
            return Prior.RIDGE;
        }
        if (priorType.equals("gaussian")) {
            return Prior.GAUSSIAN;
        }
        if (priorType.equals("ae-lasso")) {
            return Prior.aeLASSO;
        }
        if (priorType.equals("g-lasso")) {
            return Prior.gLASSO;
        }
        if (priorType.equals("sg-lasso")) {
            return Prior.sgLASSO;
        }
        throw new IllegalArgumentException("prior type " + priorType + " not recognized; supported priors " + "are: lasso, ridge, gaussian, ae-lasso, g-lasso, and sg-lasso");
    }

    public SGDWithAdaGradAndFOBOS(double initRate, double lambda, int numPasses) {
        this(initRate, lambda, numPasses, -1);
    }

    public SGDWithAdaGradAndFOBOS(double initRate, double lambda, int numPasses, int batchSize) {
        this(initRate, lambda, numPasses, batchSize, "lasso", 1.0, false, false, 0.001, 0.95);
    }

    public SGDWithAdaGradAndFOBOS(double initRate, double lambda, int numPasses, int batchSize, String priorType, double alpha, boolean useAdaDelta, boolean useAdaDiff, double adaGradEps, double adaDeltaRho) {
        this.initRate = initRate;
        this.prior = SGDWithAdaGradAndFOBOS.getPrior(priorType);
        this.bSize = batchSize;
        this.lambda = lambda;
        this.eps = adaGradEps;
        this.rho = adaDeltaRho;
        this.useAdaDelta = useAdaDelta;
        this.useAdaDiff = useAdaDiff;
        this.alpha = alpha;
        if (numPasses >= 0) {
            this.numPasses = numPasses;
        } else {
            this.numPasses = 50;
            this.sayln("  SGDWithAdaGradAndFOBOS: numPasses=" + numPasses + ", defaulting to " + this.numPasses);
        }
    }

    public void shutUp() {
        this.quiet = true;
    }

    protected String getName() {
        return "SGDWithAdaGradAndFOBOS" + this.bSize + "_lambda" + nf.format(this.lambda) + "_alpha" + nf.format(this.alpha);
    }

    @Override
    public void setEvaluators(int iters, Evaluator[] evaluators) {
        this.evaluateIters = iters;
        this.evaluators = evaluators;
    }

    private static double getNorm(double[] w) {
        double norm = 0.0;
        for (int i = 0; i < w.length; ++i) {
            norm += w[i] * w[i];
        }
        return Math.sqrt(norm);
    }

    private double doEvaluation(double[] x) {
        if (this.evaluators == null) {
            return Double.NEGATIVE_INFINITY;
        }
        double score = Double.NEGATIVE_INFINITY;
        for (Evaluator eval : this.evaluators) {
            double aScore;
            if (!this.suppressTestPrompt) {
                this.sayln("  Evaluating: " + eval.toString());
            }
            if ((aScore = eval.evaluate(x)) == Double.NEGATIVE_INFINITY) continue;
            score = aScore;
        }
        return score;
    }

    private static double pospart(double number) {
        return number > 0.0 ? number : 0.0;
    }

    private double computeLearningRate(int index, double grad) {
        double currentRate = Double.NEGATIVE_INFINITY;
        double prevG = this.prevGrad[index];
        double gradDiff = grad - prevG;
        if (this.useAdaDelta) {
            double deltaXt = this.prevDeltaX[index];
            this.sumDeltaXSquare[index] = this.sumDeltaXSquare[index] * this.rho + (1.0 - this.rho) * deltaXt * deltaXt;
            this.sumGradSquare[index] = this.useAdaDiff ? this.sumGradSquare[index] * this.rho + (1.0 - this.rho) * gradDiff * gradDiff : this.sumGradSquare[index] * this.rho + (1.0 - this.rho) * grad * grad;
            currentRate = Math.sqrt(this.sumDeltaXSquare[index] + this.eps) / Math.sqrt(this.sumGradSquare[index] + this.eps);
        } else {
            if (this.useAdaDiff) {
                int n = index;
                this.sumGradSquare[n] = this.sumGradSquare[n] + gradDiff * gradDiff;
            } else {
                int n = index;
                this.sumGradSquare[n] = this.sumGradSquare[n] + grad * grad;
            }
            currentRate = this.initRate / Math.sqrt(this.sumGradSquare[index] + this.eps);
        }
        return currentRate;
    }

    private void updateX(double[] x, int index, double realUpdate) {
        this.prevDeltaX[index] = realUpdate - x[index];
        x[index] = realUpdate;
    }

    @Override
    public double[] minimize(DiffFunction function, double functionTolerance, double[] initial) {
        return this.minimize(function, functionTolerance, initial, -1);
    }

    @Override
    public double[] minimize(DiffFunction f, double functionTolerance, double[] initial, int maxIterations) {
        boolean have_max;
        int totalSamples = 0;
        this.sayln("Using lambda=" + this.lambda);
        if (f instanceof AbstractStochasticCachingDiffUpdateFunction) {
            AbstractStochasticCachingDiffUpdateFunction func = (AbstractStochasticCachingDiffUpdateFunction)f;
            func.sampleMethod = AbstractStochasticCachingDiffFunction.SamplingMethod.Shuffled;
            totalSamples = func.dataDimension();
            if (this.bSize > totalSamples) {
                System.err.println("WARNING: Total number of samples=" + totalSamples + " is smaller than requested batch size=" + this.bSize + "!!!");
                this.bSize = totalSamples;
                this.sayln("Using batch size=" + this.bSize);
            }
            if (this.bSize <= 0) {
                System.err.println("WARNING: Requested batch size=" + this.bSize + " <= 0 !!!");
                this.bSize = totalSamples;
                this.sayln("Using batch size=" + this.bSize);
            }
        }
        this.x = new double[initial.length];
        double[] testUpdateCache = null;
        double[] currentRateCache = null;
        double[] bCache = null;
        this.sumGradSquare = new double[initial.length];
        this.prevGrad = new double[initial.length];
        this.prevDeltaX = new double[initial.length];
        if (this.useAdaDelta) {
            this.sumDeltaXSquare = new double[initial.length];
            if (this.prior != Prior.NONE && this.prior != Prior.GAUSSIAN) {
                throw new UnsupportedOperationException("useAdaDelta is currently only supported for Prior.NONE or Prior.GAUSSIAN");
            }
        }
        int[][] featureGrouping = null;
        if (this.prior != Prior.LASSO && this.prior != Prior.NONE) {
            testUpdateCache = new double[initial.length];
            currentRateCache = new double[initial.length];
        }
        if (this.prior != Prior.LASSO && this.prior != Prior.RIDGE && this.prior != Prior.GAUSSIAN) {
            if (!(f instanceof HasFeatureGrouping)) {
                throw new UnsupportedOperationException("prior is specified to be ae-lasso or g-lasso, but function does not support feature grouping");
            }
            featureGrouping = ((HasFeatureGrouping)((Object)f)).getFeatureGrouping();
        }
        if (this.prior == Prior.sgLASSO) {
            bCache = new double[initial.length];
        }
        System.arraycopy(initial, 0, this.x, 0, this.x.length);
        int numBatches = 1;
        if (f instanceof AbstractStochasticCachingDiffUpdateFunction && totalSamples > 0) {
            numBatches = totalSamples / this.bSize;
        }
        boolean bl = have_max = maxIterations > 0 || this.numPasses > 0;
        if (!have_max) {
            throw new UnsupportedOperationException("No maximum number of iterations has been specified.");
        }
        maxIterations = Math.max(maxIterations, this.numPasses * numBatches);
        this.sayln("       Batch size of: " + this.bSize);
        this.sayln("       Data dimension of: " + totalSamples);
        this.sayln("       Batches per pass through data:  " + numBatches);
        this.sayln("       Number of passes is = " + this.numPasses);
        this.sayln("       Max iterations is = " + maxIterations);
        Timing total = new Timing();
        Timing current = new Timing();
        total.start();
        current.start();
        int iters = 0;
        double gValue = 0.0;
        double wValue = 0.0;
        double currentRate = 0.0;
        double testUpdate = 0.0;
        double realUpdate = 0.0;
        ArrayList<Double> values = null;
        double oldObjVal = 0.0;
        for (int pass = 0; pass < this.numPasses; ++pass) {
            int size;
            double previousVal;
            double averageImprovement;
            boolean doEval = pass > 0 && this.evaluateIters > 0 && pass % this.evaluateIters == 0;
            double evalScore = Double.NEGATIVE_INFINITY;
            if (doEval) {
                evalScore = this.doEvaluation(this.x);
                if (this.useEvalImprovement && !this.toContinue(this.x, evalScore)) break;
            }
            double objVal = Double.NEGATIVE_INFINITY;
            double objDelta = Double.NEGATIVE_INFINITY;
            this.say("Iter: " + iters + " pass " + pass + " batch 1 ... ");
            int numOfNonZero = 0;
            int numOfNonZeroGroup = 0;
            String gSizeStr = "";
            for (int batch = 0; batch < numBatches; ++batch) {
                ++iters;
                double[] gradients = null;
                if (f instanceof AbstractStochasticCachingDiffUpdateFunction) {
                    AbstractStochasticCachingDiffUpdateFunction func = (AbstractStochasticCachingDiffUpdateFunction)f;
                    if (this.bSize == totalSamples) {
                        objVal = func.valueAt(this.x);
                        gradients = func.getDerivative();
                        objDelta = objVal - oldObjVal;
                        oldObjVal = objVal;
                        if (values == null) {
                            values = new ArrayList<Double>();
                        }
                        values.add(objVal);
                    } else {
                        func.calculateStochasticGradient(this.x, this.bSize);
                        gradients = func.getDerivative();
                    }
                } else if (f instanceof AbstractCachingDiffFunction) {
                    AbstractCachingDiffFunction func = (AbstractCachingDiffFunction)f;
                    gradients = func.derivativeAt(this.x);
                }
                if (this.prior == Prior.NONE || this.prior == Prior.GAUSSIAN) {
                    for (int index = 0; index < this.x.length; ++index) {
                        gValue = gradients[index];
                        currentRate = this.computeLearningRate(index, gValue);
                        wValue = this.x[index];
                        realUpdate = testUpdate = wValue - currentRate * gValue;
                        this.updateX(this.x, index, realUpdate);
                    }
                } else if (this.prior == Prior.LASSO || this.prior == Prior.RIDGE) {
                    double testUpdateSquaredSum = 0.0;
                    Set<Object> paramRange = null;
                    if (f instanceof HasRegularizerParamRange) {
                        paramRange = ((HasRegularizerParamRange)((Object)f)).getRegularizerParamRange(this.x);
                    } else {
                        paramRange = new HashSet();
                        for (int i = 0; i < this.x.length; ++i) {
                            paramRange.add(i);
                        }
                    }
                    Iterator<Object> i$ = paramRange.iterator();
                    while (i$.hasNext()) {
                        int index = (Integer)i$.next();
                        gValue = gradients[index];
                        currentRate = this.computeLearningRate(index, gValue);
                        wValue = this.x[index];
                        testUpdate = wValue - currentRate * gValue;
                        double currentLambda = currentRate * this.lambda;
                        if (this.prior == Prior.LASSO) {
                            realUpdate = Math.signum(testUpdate) * SGDWithAdaGradAndFOBOS.pospart(Math.abs(testUpdate) - currentLambda);
                            this.updateX(this.x, index, realUpdate);
                            if (realUpdate == 0.0) continue;
                            ++numOfNonZero;
                            continue;
                        }
                        if (this.prior != Prior.RIDGE) continue;
                        testUpdateSquaredSum += testUpdate * testUpdate;
                        testUpdateCache[index] = testUpdate;
                        currentRateCache[index] = currentRate;
                    }
                    if (this.prior == Prior.RIDGE) {
                        double testUpdateNorm = Math.sqrt(testUpdateSquaredSum);
                        for (int index = 0; index < testUpdateCache.length; ++index) {
                            realUpdate = testUpdateCache[index] * SGDWithAdaGradAndFOBOS.pospart(1.0 - currentRateCache[index] * this.lambda / testUpdateNorm);
                            this.updateX(this.x, index, realUpdate);
                            if (realUpdate == 0.0) continue;
                            ++numOfNonZero;
                        }
                    }
                } else {
                    for (int gIndex = 0; gIndex < featureGrouping.length; ++gIndex) {
                        int[] gFeatureIndices = featureGrouping[gIndex];
                        double testUpdateSquaredSum = 0.0;
                        double testUpdateAbsSum = 0.0;
                        double M = gFeatureIndices.length;
                        double dm = Math.log(M);
                        for (int index : gFeatureIndices) {
                            gValue = gradients[index];
                            currentRate = this.computeLearningRate(index, gValue);
                            wValue = this.x[index];
                            testUpdate = wValue - currentRate * gValue;
                            testUpdateSquaredSum += testUpdate * testUpdate;
                            testUpdateAbsSum += Math.abs(testUpdate);
                            testUpdateCache[index] = testUpdate;
                            currentRateCache[index] = currentRate;
                        }
                        if (this.prior == Prior.gLASSO) {
                            double testUpdateNorm = Math.sqrt(testUpdateSquaredSum);
                            boolean groupHasNonZero = false;
                            for (int index : gFeatureIndices) {
                                realUpdate = testUpdateCache[index] * SGDWithAdaGradAndFOBOS.pospart(1.0 - currentRateCache[index] * this.lambda * dm / testUpdateNorm);
                                this.updateX(this.x, index, realUpdate);
                                if (realUpdate == 0.0) continue;
                                ++numOfNonZero;
                                groupHasNonZero = true;
                            }
                            if (!groupHasNonZero) continue;
                            ++numOfNonZeroGroup;
                            continue;
                        }
                        if (this.prior == Prior.aeLASSO) {
                            int nonZeroCount = 0;
                            boolean groupHasNonZero = false;
                            for (int index : gFeatureIndices) {
                                double tau = currentRateCache[index] * this.lambda / (1.0 + currentRateCache[index] * this.lambda * M) * testUpdateAbsSum;
                                realUpdate = Math.signum(testUpdateCache[index]) * SGDWithAdaGradAndFOBOS.pospart(Math.abs(testUpdateCache[index]) - tau);
                                this.updateX(this.x, index, realUpdate);
                                if (realUpdate == 0.0) continue;
                                ++numOfNonZero;
                                ++nonZeroCount;
                                groupHasNonZero = true;
                            }
                            if (!groupHasNonZero) continue;
                            ++numOfNonZeroGroup;
                            continue;
                        }
                        if (this.prior != Prior.sgLASSO) continue;
                        double bSquaredSum = 0.0;
                        double b = 0.0;
                        for (int index : gFeatureIndices) {
                            bCache[index] = b = Math.signum(testUpdateCache[index]) * SGDWithAdaGradAndFOBOS.pospart(Math.abs(testUpdateCache[index]) - currentRateCache[index] * this.alpha * this.lambda);
                            bSquaredSum += b * b;
                        }
                        double bNorm = Math.sqrt(bSquaredSum);
                        int nonZeroCount = 0;
                        boolean groupHasNonZero = false;
                        for (int index : gFeatureIndices) {
                            realUpdate = bCache[index] * SGDWithAdaGradAndFOBOS.pospart(1.0 - currentRateCache[index] * (1.0 - this.alpha) * this.lambda * dm / bNorm);
                            this.updateX(this.x, index, realUpdate);
                            if (realUpdate == 0.0) continue;
                            ++numOfNonZero;
                            ++nonZeroCount;
                            groupHasNonZero = true;
                        }
                        if (!groupHasNonZero) continue;
                        ++numOfNonZeroGroup;
                    }
                }
                for (int index = 0; index < this.x.length; ++index) {
                    this.prevGrad[index] = gradients[index];
                }
            }
            try {
                ArrayMath.assertFinite(this.x, "x");
            }
            catch (ArrayMath.InvalidElementException e) {
                System.err.println(e.toString());
                for (int i = 0; i < this.x.length; ++i) {
                    this.x[i] = Double.NaN;
                }
                break;
            }
            this.sayln(String.valueOf(numBatches) + ", n0-fCount:" + numOfNonZero + (this.prior != Prior.LASSO && this.prior != Prior.RIDGE ? ", n0-gCount:" + numOfNonZeroGroup : "") + (evalScore != Double.NEGATIVE_INFINITY ? ", evalScore:" + evalScore : "") + (objVal != Double.NEGATIVE_INFINITY ? ", obj_val:" + nf.format(objVal) + ", obj_delta:" + objDelta : ""));
            if (values != null && this.useAvgImprovement && iters > 5 && Math.abs((averageImprovement = ((previousVal = ((size = values.size()) >= 10 ? (Double)values.get(size - 10) : (Double)values.get(0)).doubleValue()) - objVal) / (double)(size >= 10 ? 10 : size)) / objVal) < this.TOL) {
                this.sayln("Online Optmization completed, due to average improvement: | newest_val - previous_val | / |newestVal| < TOL ");
                break;
            }
            if (iters >= maxIterations) {
                this.sayln("Online Optimization complete.  Stopped after max iterations");
                break;
            }
            if (total.report() < this.maxTime) continue;
            this.sayln("Online Optimization complete.  Stopped after max time");
            break;
        }
        if (this.evaluateIters > 0) {
            double evalScore = this.useEvalImprovement ? this.doEvaluation(this.xBest) : this.doEvaluation(this.x);
            this.sayln("final evalScore is: " + evalScore);
        }
        this.sayln("Completed in: " + Timing.toSecondsString(total.report()) + " s");
        return this.useEvalImprovement ? this.xBest : this.x;
    }

    protected void sayln(String s) {
        if (!this.quiet) {
            System.err.println(s);
        }
    }

    protected void say(String s) {
        if (!this.quiet) {
            System.err.print(s);
        }
    }

    public static enum Prior {
        LASSO,
        RIDGE,
        GAUSSIAN,
        aeLASSO,
        gLASSO,
        sgLASSO,
        NONE;

    }
}

