/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.stats;

import com.aliasi.io.LogLevel;
import com.aliasi.io.Reporter;
import com.aliasi.io.Reporters;
import com.aliasi.matrix.DenseVector;
import com.aliasi.matrix.Matrices;
import com.aliasi.matrix.SparseFloatVector;
import com.aliasi.matrix.Vector;
import com.aliasi.stats.AnnealingSchedule;
import com.aliasi.stats.RegressionPrior;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Compilable;
import com.aliasi.util.Math;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Formatter;
import java.util.IllegalFormatException;
import java.util.Locale;

public class LogisticRegression
implements Compilable,
Serializable {
    static final long serialVersionUID = -8585743596322227589L;
    private final Vector[] mWeightVectors;

    /**
     * Constructs a regression model backed by the specified weight vectors.
     *
     * <p>The array is stored as given, without copying, so later changes to
     * the array or its vectors are visible through this model.
     *
     * @param weightVectors weight vectors, all of the same dimensionality
     * @throws IllegalArgumentException if the array is empty or if any vector's
     *     dimensionality differs from that of the first vector
     */
    public LogisticRegression(Vector[] weightVectors) {
        if (weightVectors.length == 0) {
            throw new IllegalArgumentException("Require at least one weight vector.");
        }
        final int expectedDims = weightVectors[0].numDimensions();
        int index = 1;
        while (index < weightVectors.length) {
            int actualDims = weightVectors[index].numDimensions();
            if (actualDims != expectedDims) {
                throw new IllegalArgumentException("All weight vectors must be same dimensionality. Found weightVectors[0].numDimensions()=" + expectedDims + " weightVectors[" + index + "]=" + actualDims);
            }
            ++index;
        }
        this.mWeightVectors = weightVectors;
    }

    /**
     * Constructs a regression model with a single weight vector; with one
     * vector, {@code numOutcomes()} reports two outcomes.
     *
     * <p>No validation is performed on the argument.
     *
     * @param weightVector the single weight vector of the model
     */
    public LogisticRegression(Vector weightVector) {
        Vector[] singleton = new Vector[1];
        singleton[0] = weightVector;
        this.mWeightVectors = singleton;
    }

    /**
     * Returns the dimensionality of the inputs this model accepts, read from
     * the first weight vector.
     *
     * @return number of dimensions of the first weight vector
     */
    public int numInputDimensions() {
        Vector firstWeights = this.mWeightVectors[0];
        return firstWeights.numDimensions();
    }

    /**
     * Returns the number of outcomes, which is one more than the number of
     * weight vectors (the final outcome has no weight vector of its own).
     *
     * @return number of weight vectors plus one
     */
    public int numOutcomes() {
        int numWeightVectors = this.mWeightVectors.length;
        return 1 + numWeightVectors;
    }

    /**
     * Returns a freshly allocated array holding an unmodifiable wrapper of
     * each weight vector, in the same order as they are stored.
     *
     * @return new array of {@code Matrices.unmodifiableVector} wrappers
     */
    public Vector[] weightVectors() {
        Vector[] views = new Vector[this.mWeightVectors.length];
        int slot = 0;
        for (Vector weights : this.mWeightVectors) {
            views[slot++] = Matrices.unmodifiableVector(weights);
        }
        return views;
    }

    /**
     * Computes the conditional probability of each outcome for the given input.
     *
     * <p>The score of outcome {@code k} is the dot product of the input with
     * weight vector {@code k}; the final outcome has a fixed score of zero.
     * Scores are exponentiated after subtracting the largest score (never
     * less than zero, because of the final outcome) and then normalized so
     * that the returned values sum to one.
     *
     * @param x input vector
     * @return array of length {@code numOutcomes()} of outcome probabilities
     * @throws IllegalArgumentException if the input's dimensionality differs
     *     from {@code numInputDimensions()}
     */
    public double[] classify(Vector x) {
        final int modelDims = this.numInputDimensions();
        final int inputDims = x.numDimensions();
        if (modelDims != inputDims) {
            throw new IllegalArgumentException("Vector and classifer must be of same dimensionality. Regression model this.numInputDimensions()=" + modelDims + " Vector x.numDimensions()=" + inputDims);
        }
        final int outcomeCount = this.numOutcomes();
        final int lastOutcome = outcomeCount - 1;
        // Freshly allocated, so the final outcome's score is already 0.0.
        final double[] probs = new double[outcomeCount];

        // Linear scores, tracking the largest one (floor of 0.0 from the final outcome).
        double largest = 0.0;
        for (int k = 0; k < lastOutcome; ++k) {
            double score = x.dotProduct(this.mWeightVectors[k]);
            probs[k] = score;
            if (score > largest) {
                largest = score;
            }
        }

        // Exponentiate shifted scores and accumulate the normalizer in index order.
        double normalizer = 0.0;
        for (int k = 0; k < probs.length; ++k) {
            double unnormalized = java.lang.Math.exp(probs[k] - largest);
            probs[k] = unnormalized;
            normalizer += unnormalized;
        }

        for (int k = 0; k < probs.length; ++k) {
            probs[k] /= normalizer;
        }
        return probs;
    }

    /**
     * Writes this model to the specified object output by writing an
     * {@code Externalizer} wrapping it.
     *
     * @param out object output to which the model is written
     * @throws IOException if writing to the output fails
     */
    public void compileTo(ObjectOutput out) throws IOException {
        Externalizer proxy = new Externalizer(this);
        out.writeObject(proxy);
    }

    /**
     * Returns an {@code Externalizer} wrapping this model, to be serialized
     * in its place.
     *
     * @return externalizer for this model
     */
    Object writeReplace() {
        Externalizer proxy = new Externalizer(this);
        return proxy;
    }

    /**
     * Estimates a regression model, reporting progress to a print writer.
     *
     * <p>A non-null writer is wrapped in a reporter at debug level; a null
     * writer is passed on as a null reporter. All work is delegated to the
     * overload that accepts a {@code Reporter}.
     *
     * @param xs training input vectors
     * @param cs outcome categories, parallel to {@code xs}
     * @param prior prior over weights
     * @param annealingSchedule learning-rate schedule
     * @param minImprovement convergence threshold
     * @param minEpochs minimum number of epochs
     * @param maxEpochs maximum number of epochs
     * @param progressWriter destination for progress reports, or {@code null}
     * @return the estimated model
     * @deprecated use the overload taking a {@code Reporter} instead
     */
    @Deprecated
    public static LogisticRegression estimate(Vector[] xs, int[] cs, RegressionPrior prior, AnnealingSchedule annealingSchedule, double minImprovement, int minEpochs, int maxEpochs, PrintWriter progressWriter) {
        Reporter reporter = null;
        if (progressWriter != null) {
            reporter = Reporters.writer(progressWriter).setLevel(LogLevel.DEBUG);
        }
        return LogisticRegression.estimate(xs, cs, prior, annealingSchedule, reporter, minImprovement, minEpochs, maxEpochs);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    /**
     * Estimates a regression model from labeled training vectors by updating
     * the weights one training instance at a time over repeated epochs.
     *
     * <p>In each epoch, every training instance first triggers a prior-based
     * shrinkage of the weights (when an informative prior is present) and then
     * a gradient step based on the model's current conditional probabilities.
     * After the epoch, the log (base 2) likelihood plus log prior is reported
     * to the annealing schedule as a negated error; if the schedule rejects
     * the update, the weights are rolled back to their values at the start of
     * the epoch. Training stops after {@code maxEpochs} epochs, or earlier
     * once at least {@code minEpochs} epochs have run and the rolling average
     * of the relative change in log likelihood plus prior drops below
     * {@code minImprovement}.
     *
     * @param xs training input vectors; must be non-empty and all of the same
     *     dimensionality
     * @param cs outcome category of each training vector, parallel to
     *     {@code xs}; the number of outcomes is taken to be the largest value
     *     plus one
     * @param prior prior over weights; {@code null} is treated as the absence
     *     of a prior (no shrinkage, log prior of zero)
     * @param annealingSchedule supplies the learning rate for each epoch and
     *     decides whether each epoch's update is accepted
     * @param reporter destination for progress messages; {@code null} means
     *     silent
     * @param minImprovement threshold on the rolling average relative change
     *     below which training is considered converged
     * @param minEpochs minimum number of epochs to complete before convergence
     *     may end training
     * @param maxEpochs maximum number of epochs to run
     * @return the estimated model
     * @throws IllegalArgumentException if there are no training instances, if
     *     {@code xs} and {@code cs} differ in length, or if the input vectors
     *     differ in dimensionality
     */
    public static LogisticRegression estimate(Vector[] xs, int[] cs, RegressionPrior prior, AnnealingSchedule annealingSchedule, Reporter reporter, double minImprovement, int minEpochs, int maxEpochs) {
        if (reporter == null) {
            reporter = Reporters.silent();
        }
        if (xs.length < 1) {
            String msg = "Require at least one training instance.";
            reporter.fatal(msg);
            throw new IllegalArgumentException(msg);
        }
        if (xs.length != cs.length) {
            String msg = "Require same number of training instances as outcomes. Found xs.length=" + xs.length + " cs.length=" + cs.length;
            reporter.fatal(msg);
            throw new IllegalArgumentException(msg);
        }
        int numTrainingInstances = xs.length;
        int numOutcomesMinus1 = LogisticRegression.max(cs);
        int numOutcomes = numOutcomesMinus1 + 1;
        int numDimensions = xs[0].numDimensions();
        // A null prior is tolerated everywhere below, so it must be tolerated here too.
        if (prior != null) {
            prior.verifyNumberOfDimensions(numDimensions);
        }
        for (int i = 1; i < xs.length; ++i) {
            if (xs[i].numDimensions() == numDimensions) {
                continue;
            }
            String msg = "Number of dimensions must match for all input vectors. Found xs[0].numDimensions()=" + numDimensions + " xs[" + i + "].numDimensions()=" + xs[i].numDimensions();
            reporter.fatal(msg);
            throw new IllegalArgumentException(msg);
        }

        // Weights start as freshly constructed dense vectors, one per outcome but the last.
        DenseVector[] weightVectors = new DenseVector[numOutcomesMinus1];
        for (int k = 0; k < numOutcomesMinus1; ++k) {
            weightVectors[k] = new DenseVector(numDimensions);
        }

        boolean hasSparseInputs = LogisticRegression.isSparse(xs);
        boolean hasPrior = prior != null && !prior.isUniform();
        reporter.info("Logistic Regression Progress Report");
        reporter.info("Number of dimensions=" + numDimensions);
        reporter.info("Number of Outcomes=" + numOutcomes);
        reporter.info("Number of Parameters=" + (long) numOutcomesMinus1 * (long) numDimensions);
        reporter.info("Prior:\n" + prior);
        reporter.info("Annealing Schedule=" + annealingSchedule);
        reporter.info("Minimum Epochs=" + minEpochs);
        reporter.info("Maximum Epochs=" + maxEpochs);
        reporter.info("Minimum Improvement Per Period=" + minImprovement);
        reporter.info("Has Sparse Inputs=" + hasSparseInputs);
        reporter.info("Has Informative Prior=" + hasPrior);

        // Per-dimension index of the last instance at which the prior was applied;
        // only needed when regularization is applied lazily for sparse inputs.
        int[] lastRegularizations = hasPrior && hasSparseInputs ? new int[numDimensions] : null;
        double lastLog2LikelihoodAndPrior = -(Double.MAX_VALUE / 2.0);
        // The model shares the weight vector objects, so in-place updates are
        // immediately visible through regression.classify().
        LogisticRegression regression = new LogisticRegression(weightVectors);
        double rollingAverageRelativeDiff = 1.0;
        double bestLog2LikelihoodAndPrior = Double.NEGATIVE_INFINITY;

        for (int epoch = 0; epoch < maxEpochs; ++epoch) {
            // Snapshot for rollback if the annealing schedule rejects this epoch.
            DenseVector[] weightVectorCopies = LogisticRegression.copy(weightVectors);
            if (lastRegularizations != null) {
                Arrays.fill(lastRegularizations, 0);
            }
            double learningRate = annealingSchedule.learningRate(epoch);
            for (int j = 0; j < numTrainingInstances; ++j) {
                Vector xsJ = xs[j];
                int csJ = cs[j];
                if (hasPrior) {
                    if (hasSparseInputs) {
                        LogisticRegression.adjustWeightsWithPrior(weightVectors, xsJ.nonZeroDimensions(), j, prior, learningRate, numTrainingInstances, lastRegularizations);
                    } else {
                        LogisticRegression.adjustWeightsWithPriorDense(weightVectors, prior, learningRate, numTrainingInstances);
                    }
                }
                double[] conditionalProbs = regression.classify(xsJ);
                for (int k = 0; k < numOutcomesMinus1; ++k) {
                    LogisticRegression.adjustWeightsWithConditionalProbs(weightVectors[k], conditionalProbs[k], learningRate, xsJ, k, csJ);
                }
            }
            if (hasPrior) {
                if (hasSparseInputs) {
                    // Catch up regularization for dimensions not touched late in the epoch.
                    LogisticRegression.adjustWeightsWithPriorAll(weightVectors, prior, learningRate, numTrainingInstances, lastRegularizations);
                } else {
                    LogisticRegression.adjustWeightsWithPriorDense(weightVectors, prior, learningRate, numTrainingInstances);
                }
            }

            double log2Likelihood = LogisticRegression.log2Likelihood(xs, cs, regression);
            double log2Prior = prior == null ? 0.0 : prior.log2Prior(weightVectors);
            double log2LikelihoodAndPrior = log2Likelihood + log2Prior;
            if (log2LikelihoodAndPrior > bestLog2LikelihoodAndPrior) {
                bestLog2LikelihoodAndPrior = log2LikelihoodAndPrior;
            }
            if (!annealingSchedule.receivedError(epoch, learningRate, -log2LikelihoodAndPrior)) {
                reporter.debug("Annealing rejected update at learningRate=" + learningRate + " error=" + -log2LikelihoodAndPrior);
                weightVectors = weightVectorCopies;
                regression = new LogisticRegression(weightVectors);
            }

            double relativeDiff = Math.relativeAbsoluteDifference(lastLog2LikelihoodAndPrior, log2LikelihoodAndPrior);
            rollingAverageRelativeDiff = (9.0 * rollingAverageRelativeDiff + relativeDiff) / 10.0;
            lastLog2LikelihoodAndPrior = log2LikelihoodAndPrior;

            if (reporter.isDebugEnabled()) {
                Formatter formatter = null;
                try {
                    formatter = new Formatter(Locale.ENGLISH);
                    formatter.format("epoch=%5d lr=%11.9f ll=%11.4f lp=%11.4f llp=%11.4f llp*=%11.4f", epoch, learningRate, log2Likelihood, log2Prior, log2LikelihoodAndPrior, bestLog2LikelihoodAndPrior);
                    reporter.debug(formatter.toString());
                } catch (IllegalFormatException e) {
                    reporter.warn("Illegal format in Logistic Regression");
                } finally {
                    if (formatter != null) {
                        formatter.close();
                    }
                }
            }

            // epoch is zero-based, so epoch + 1 epochs have completed at this point.
            boolean ranMinimumEpochs = epoch + 1 >= minEpochs;
            if (ranMinimumEpochs && rollingAverageRelativeDiff < minImprovement) {
                reporter.info("Converged with rollingAverageRelativeDiff=" + rollingAverageRelativeDiff);
                break;
            }
        }
        return regression;
    }

    /**
     * Returns the log (base 2) likelihood of the given categories for the
     * given inputs under the specified model: the sum over instances of the
     * log (base 2) of the probability the model assigns to that instance's
     * category.
     *
     * @param inputs input vectors
     * @param cats category of each input, parallel to {@code inputs}
     * @param regression model used to classify the inputs
     * @return summed log (base 2) probability of the categories
     * @throws IllegalArgumentException if {@code inputs} and {@code cats}
     *     differ in length
     */
    public static double log2Likelihood(Vector[] inputs, int[] cats, LogisticRegression regression) {
        final int numInstances = inputs.length;
        if (numInstances != cats.length) {
            throw new IllegalArgumentException("Inputs and categories must be same length. Found inputs.length=" + inputs.length + " cats.length=" + cats.length);
        }
        double total = 0.0;
        int j = 0;
        while (j < numInstances) {
            double[] probs = regression.classify(inputs[j]);
            double probOfTruth = probs[cats[j]];
            total += Math.log2(probOfTruth);
            ++j;
        }
        return total;
    }

    /**
     * Applies the prior's shrinkage to the listed dimensions of every weight
     * vector, scaled by the number of instances since each dimension was last
     * regularized, and then records the current instance as the last
     * regularization point for those dimensions.
     *
     * <p>Weights that are exactly zero are left alone, and a weight is never
     * moved across zero: it is clipped to zero instead.
     *
     * @param weightVectors weight vectors, updated in place
     * @param dimensions dimensions to regularize
     * @param curInstance index of the current training instance
     * @param prior source of the per-dimension gradient
     * @param learningRate learning rate for the current epoch
     * @param numTrainingInstances number of training instances, dividing the rate
     * @param lastRegularizations per-dimension index of last regularization,
     *     updated in place
     */
    private static void adjustWeightsWithPrior(DenseVector[] weightVectors, int[] dimensions, int curInstance, RegressionPrior prior, double learningRate, int numTrainingInstances, int[] lastRegularizations) {
        final double ratePerInstance = learningRate / (double)numTrainingInstances;
        for (int k = 0; k < weightVectors.length; ++k) {
            DenseVector weights = weightVectors[k];
            for (int dim : dimensions) {
                double current = weights.value(dim);
                if (current != 0.0) {
                    int skippedInstances = curInstance - lastRegularizations[dim];
                    double gradient = prior.gradient(current, dim);
                    double delta = (double)skippedInstances * gradient * ratePerInstance;
                    double shifted = current - delta;
                    double clipped;
                    if (current > 0.0) {
                        clipped = java.lang.Math.max(0.0, shifted);
                    } else {
                        clipped = java.lang.Math.min(0.0, shifted);
                    }
                    weights.setValue(dim, clipped);
                }
            }
        }
        for (int dim : dimensions) {
            lastRegularizations[dim] = curInstance;
        }
    }

    /**
     * Applies the prior's shrinkage to every dimension of every weight vector,
     * scaled by the number of instances between each dimension's last
     * regularization and the total number of training instances.
     *
     * <p>Weights that are exactly zero are left alone, and a weight is never
     * moved across zero: it is clipped to zero instead. The dimension count
     * is read from the first weight vector.
     *
     * @param weightVectors weight vectors, updated in place
     * @param prior source of the per-dimension gradient
     * @param learningRate learning rate for the current epoch
     * @param numTrainingInstances number of training instances
     * @param lastRegularizations per-dimension index of last regularization
     */
    private static void adjustWeightsWithPriorAll(DenseVector[] weightVectors, RegressionPrior prior, double learningRate, int numTrainingInstances, int[] lastRegularizations) {
        final double ratePerInstance = learningRate / (double)numTrainingInstances;
        final int dimCount = weightVectors[0].numDimensions();
        for (int k = 0; k < weightVectors.length; ++k) {
            DenseVector weights = weightVectors[k];
            int dim = 0;
            while (dim < dimCount) {
                double current = weights.value(dim);
                if (current != 0.0) {
                    int skippedInstances = numTrainingInstances - lastRegularizations[dim];
                    double gradient = prior.gradient(current, dim);
                    double delta = (double)skippedInstances * gradient * ratePerInstance;
                    double shifted = current - delta;
                    double clipped;
                    if (current > 0.0) {
                        clipped = java.lang.Math.max(0.0, shifted);
                    } else {
                        clipped = java.lang.Math.min(0.0, shifted);
                    }
                    weights.setValue(dim, clipped);
                }
                ++dim;
            }
        }
    }

    /**
     * Applies one step of the prior's shrinkage to every dimension of every
     * weight vector, using the learning rate divided by the number of
     * training instances.
     *
     * <p>Weights that are exactly zero are left alone, and a weight is never
     * moved across zero: it is clipped to zero instead.
     *
     * @param weightVectors weight vectors, updated in place
     * @param prior source of the per-dimension gradient
     * @param learningRate learning rate for the current epoch
     * @param numTrainingInstances number of training instances
     */
    private static void adjustWeightsWithPriorDense(DenseVector[] weightVectors, RegressionPrior prior, double learningRate, int numTrainingInstances) {
        final double ratePerInstance = learningRate / (double)numTrainingInstances;
        for (int k = 0; k < weightVectors.length; ++k) {
            DenseVector weights = weightVectors[k];
            for (int dim = 0; dim < weights.numDimensions(); ++dim) {
                double current = weights.value(dim);
                if (current != 0.0) {
                    double gradient = prior.gradient(current, dim);
                    double delta = gradient * ratePerInstance;
                    double shifted = current - delta;
                    double clipped;
                    if (current > 0.0) {
                        clipped = java.lang.Math.max(0.0, shifted);
                    } else {
                        clipped = java.lang.Math.min(0.0, shifted);
                    }
                    weights.setValue(dim, clipped);
                }
            }
        }
    }

    /**
     * Moves one outcome's weight vector along the input vector in proportion
     * to the difference between the predicted probability and the truth
     * (one if {@code k} is the true category, zero otherwise).
     *
     * <p>Nothing is done when that difference is exactly zero.
     *
     * @param weightVectorsK weight vector for outcome {@code k}, updated in place
     * @param conditionalProb model's current probability of outcome {@code k}
     * @param learningRate learning rate for the current epoch
     * @param xsJ training input vector
     * @param k outcome whose weight vector is being updated
     * @param csJ true category of the training instance
     */
    private static void adjustWeightsWithConditionalProbs(DenseVector weightVectorsK, double conditionalProb, double learningRate, Vector xsJ, int k, int csJ) {
        double residual = conditionalProb;
        if (k == csJ) {
            residual = conditionalProb - 1.0;
        }
        if (residual != 0.0) {
            weightVectorsK.increment(-learningRate * residual, xsJ);
        }
    }

    /**
     * Reports whether the inputs should be treated as sparse, by counting how
     * many are instances of {@code SparseFloatVector} and comparing that count
     * against half the array length.
     *
     * <p>NOTE(review): the threshold uses integer division, so it rounds down;
     * for example, a single non-sparse input yields {@code 0 >= 0} and is
     * reported as sparse. Confirm whether that is intended.
     *
     * @param xs input vectors
     * @return {@code true} if the sparse count is at least {@code xs.length / 2}
     */
    private static boolean isSparse(Vector[] xs) {
        int numSparse = 0;
        for (Vector x : xs) {
            if (x instanceof SparseFloatVector) {
                numSparse++;
            }
        }
        int threshold = xs.length / 2;
        return numSparse >= threshold;
    }

    /**
     * Returns the largest value in the given array.
     *
     * @param xs non-empty array of values
     * @return the maximum element
     * @throws ArrayIndexOutOfBoundsException if the array is empty
     */
    private static int max(int[] xs) {
        int largest = xs[0];
        for (int candidate : xs) {
            if (candidate > largest) {
                largest = candidate;
            }
        }
        return largest;
    }

    /**
     * Returns a new array whose elements are new {@code DenseVector}s, each
     * constructed from the corresponding element of the input.
     *
     * @param xs vectors to copy
     * @return new array of newly constructed vectors, in the same order
     */
    private static DenseVector[] copy(DenseVector[] xs) {
        final int count = xs.length;
        DenseVector[] duplicates = new DenseVector[count];
        int k = 0;
        for (DenseVector original : xs) {
            duplicates[k++] = new DenseVector(original);
        }
        return duplicates;
    }

    /**
     * Serialization helper that writes a model as its outcome count, its
     * dimension count, and then every weight value, and reads a model back
     * from that same layout.
     */
    static class Externalizer
    extends AbstractExternalizable {
        static final long serialVersionUID = -2256261505231943102L;
        /** Model to be written; {@code null} when created by the no-arg constructor. */
        final LogisticRegression mRegression;

        /** Creates an externalizer with no model, for use when reading. */
        public Externalizer() {
            this(null);
        }

        /**
         * Creates an externalizer for the given model.
         *
         * @param regression model to write
         */
        public Externalizer(LogisticRegression regression) {
            this.mRegression = regression;
        }

        /**
         * Writes the number of outcomes (weight vectors plus one), the number
         * of dimensions of the first weight vector, and then the values of
         * each weight vector in order, dimension by dimension.
         *
         * @param out destination of the serialized model
         * @throws IOException if writing fails
         */
        public void writeExternal(ObjectOutput out) throws IOException {
            Vector[] vectors = this.mRegression.mWeightVectors;
            out.writeInt(vectors.length + 1);
            final int dimCount = vectors[0].numDimensions();
            out.writeInt(dimCount);
            for (Vector weights : vectors) {
                int dim = 0;
                while (dim < dimCount) {
                    out.writeDouble(weights.value(dim));
                    ++dim;
                }
            }
        }

        /**
         * Reads the outcome count, the dimension count, and the weight values
         * in the order they were written, and builds a model over newly
         * created dense vectors.
         *
         * @param in source of the serialized model
         * @return the reconstructed {@code LogisticRegression}
         * @throws IOException if reading fails
         */
        public Object read(ObjectInput in) throws IOException {
            final int outcomeCount = in.readInt();
            final int dimCount = in.readInt();
            Vector[] restored = new Vector[outcomeCount - 1];
            int c = 0;
            while (c < restored.length) {
                DenseVector weights = new DenseVector(dimCount);
                for (int dim = 0; dim < dimCount; ++dim) {
                    weights.setValue(dim, in.readDouble());
                }
                restored[c] = weights;
                ++c;
            }
            return new LogisticRegression(restored);
        }
    }
}

