/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.classify;

import com.aliasi.classify.ScoredClassification;
import com.aliasi.stats.Statistics;
import com.aliasi.util.Math;
import com.aliasi.util.Pair;
import com.aliasi.util.ScoredObject;
import java.util.Arrays;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * A {@code ConditionalClassification} is a {@link ScoredClassification} that
 * additionally carries a conditional probability for each ranked category.
 *
 * <p>The constructors check three things:
 * <ul>
 * <li>there is one conditional probability per category;</li>
 * <li>every conditional probability lies in {@code [0.0, 1.0]};</li>
 * <li>the probabilities sum to {@code 1.0} within a tolerance (default {@code 0.01}).</li>
 * </ul>
 *
 * <p>The static factories {@link #createLogProbs(String[], double[])} and
 * {@link #createProbs(String[], double[])} build instances from unnormalized
 * base-2 log probabilities or linear probability ratios. They normalize the
 * values and order the categories using {@code ScoredObject.reverseComparator()}.
 */
public class ConditionalClassification
extends ScoredClassification {

    /** Conditional probabilities indexed by rank; a private copy of the constructor argument. */
    private final double[] mConditionalProbs;

    /** Default allowed deviation of the probability sum from 1.0. */
    private static final double TOLERANCE = 0.01;

    /**
     * Positive log-joint values strictly below this bound are treated as
     * floating-point rounding noise and clamped to 0.0.
     */
    private static final double LOG_PROB_ROUNDING_SLACK = 1.0E-10;

    /**
     * Constructs a classification whose scores are its conditional probabilities,
     * using the default tolerance of 0.01.
     *
     * @param categories categories in rank order
     * @param conditionalProbs conditional probability of each category, used as the score too
     * @throws IllegalArgumentException if the probabilities fail validation
     */
    public ConditionalClassification(String[] categories, double[] conditionalProbs) {
        this(categories, conditionalProbs, conditionalProbs, TOLERANCE);
    }

    /**
     * Constructs a classification with separate scores and conditional
     * probabilities, using the default tolerance of 0.01.
     *
     * @param categories categories in rank order
     * @param scores score of each category, passed to {@link ScoredClassification}
     * @param conditionalProbs conditional probability of each category
     * @throws IllegalArgumentException if the probabilities fail validation
     */
    public ConditionalClassification(String[] categories, double[] scores, double[] conditionalProbs) {
        this(categories, scores, conditionalProbs, TOLERANCE);
    }

    /**
     * Constructs a classification whose scores are its conditional
     * probabilities, with an explicit sum tolerance.
     *
     * @param categories categories in rank order
     * @param conditionalProbs conditional probability of each category, used as the score too
     * @param tolerance allowed absolute deviation of the probability sum from 1.0
     * @throws IllegalArgumentException if the tolerance or probabilities fail validation
     */
    public ConditionalClassification(String[] categories, double[] conditionalProbs, double tolerance) {
        this(categories, conditionalProbs, conditionalProbs, tolerance);
    }

    /**
     * Constructs a classification with separate scores and conditional
     * probabilities and an explicit sum tolerance.
     *
     * @param categories categories in rank order
     * @param scores score of each category, passed to {@link ScoredClassification}
     * @param conditionalProbs conditional probability of each category; copied
     * @param tolerance allowed absolute deviation of the probability sum from 1.0
     * @throws IllegalArgumentException if {@code tolerance} is negative or NaN, if
     *     {@code conditionalProbs} and {@code categories} differ in length, if any
     *     conditional probability is outside {@code [0.0, 1.0]} or NaN, or if the
     *     probabilities do not sum to 1.0 within {@code tolerance}
     */
    public ConditionalClassification(String[] categories, double[] scores, double[] conditionalProbs, double tolerance) {
        super(categories, scores);
        if (tolerance < 0.0 || Double.isNaN(tolerance)) {
            String msg = "Tolerance must be a non-negative number. Found tolerance=" + tolerance;
            throw new IllegalArgumentException(msg);
        }
        if (conditionalProbs.length != categories.length) {
            String msg = "Arrays must be same length. Found categories.length=" + categories.length
                + " conditionalProbs.length=" + conditionalProbs.length;
            throw new IllegalArgumentException(msg);
        }
        for (int i = 0; i < conditionalProbs.length; ++i) {
            double p = conditionalProbs[i];
            // Phrased positively so that NaN fails the range test as well.
            if (p >= 0.0 && p <= 1.0) {
                continue;
            }
            String msg = "Conditional probabilities must be  between 0.0 and 1.0. Found conditionalProbs[" + i + "]=" + p;
            throw new IllegalArgumentException(msg);
        }
        double sum = Math.sum(conditionalProbs);
        // Phrased positively so that a NaN sum is rejected rather than silently accepted.
        if (!(sum >= 1.0 - tolerance && sum <= 1.0 + tolerance)) {
            String msg = "Conditional probabilities must sum to 1.0. Acceptable tolerance=" + tolerance + " Found sum=" + sum;
            throw new IllegalArgumentException(msg);
        }
        // Copy so later mutation of the caller's array cannot bypass the checks above.
        this.mConditionalProbs = conditionalProbs.clone();
    }

    /**
     * Returns the conditional probability of the category at the given rank.
     *
     * @param rank zero-based rank
     * @return the conditional probability at {@code rank}
     * @throws IllegalArgumentException if {@code rank} is out of range
     */
    public double conditionalProbability(int rank) {
        if (rank < 0 || rank > this.mConditionalProbs.length - 1) {
            String msg = "Require rank in range 0.." + (this.mConditionalProbs.length - 1) + " Found rank=" + rank;
            throw new IllegalArgumentException(msg);
        }
        return this.mConditionalProbs[rank];
    }

    /**
     * Returns the conditional probability of the named category.
     *
     * @param category category name to look up
     * @return the conditional probability of the first rank whose category equals {@code category}
     * @throws IllegalArgumentException if no rank has that category; the message lists valid categories
     */
    public double conditionalProbability(String category) {
        int size = this.size();
        for (int rank = 0; rank < size; ++rank) {
            if (this.category(rank).equals(category)) {
                return this.conditionalProbability(rank);
            }
        }
        StringBuilder msg = new StringBuilder();
        msg.append(category).append(" is not a valid category for this classification.  Valid categories are:");
        for (int rank = 0; rank < size; ++rank) {
            if (rank > 0) {
                msg.append(',');
            }
            msg.append(' ').append(this.category(rank));
        }
        throw new IllegalArgumentException(msg.toString());
    }

    /**
     * Returns a header line followed by one line per rank of the form
     * {@code rank=category score conditionalProbability}.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("Rank  Category  Score  P(Category|Input)\n");
        for (int i = 0; i < this.size(); ++i) {
            sb.append(i).append('=').append(this.category(i))
              .append(' ').append(this.score(i))
              .append(' ').append(this.conditionalProbability(i))
              .append('\n');
        }
        return sb.toString();
    }

    /**
     * Creates a classification from unnormalized base-2 log probabilities.
     * The values are converted to normalized conditional probabilities and the
     * categories are reordered by {@link #sort(String[], double[])}.
     *
     * @param categories category names, parallel to {@code logProbabilities}
     * @param logProbabilities base-2 log probability of each category; not positive
     * @return the resulting classification
     * @throws IllegalArgumentException if the arrays differ in length or any
     *     log probability is positive or NaN
     */
    public static ConditionalClassification createLogProbs(String[] categories, double[] logProbabilities) {
        ConditionalClassification.verifyLengths(categories, logProbabilities);
        ConditionalClassification.verifyLogProbs(logProbabilities);
        double[] linearProbs = ConditionalClassification.logJointToConditional(logProbabilities);
        Pair<String[], double[]> catsProbs = ConditionalClassification.sort(categories, linearProbs);
        return new ConditionalClassification(catsProbs.a(), catsProbs.b());
    }

    /**
     * Creates a classification from non-negative, finite probability ratios.
     * If every ratio is zero, the categories get a uniform distribution in their
     * given order. Otherwise the ratios are converted to base-2 logs and passed to
     * {@link #createLogProbs(String[], double[])}.
     *
     * @param categories category names, parallel to {@code probabilityRatios}
     * @param probabilityRatios unnormalized probability of each category
     * @return the resulting classification
     * @throws IllegalArgumentException if any ratio is negative, infinite or NaN,
     *     or if the arrays differ in length
     */
    public static ConditionalClassification createProbs(String[] categories, double[] probabilityRatios) {
        for (int i = 0; i < probabilityRatios.length; ++i) {
            if (!(probabilityRatios[i] < 0.0) && !Double.isInfinite(probabilityRatios[i]) && !Double.isNaN(probabilityRatios[i])) continue;
            String msg = "Probability ratios must be non-negative and finite. Found probabilityRatios[" + i + "]=" + probabilityRatios[i];
            throw new IllegalArgumentException(msg);
        }
        if (Math.sum(probabilityRatios) == 0.0) {
            double[] conditionalProbs = new double[probabilityRatios.length];
            Arrays.fill(conditionalProbs, 1.0 / (double)probabilityRatios.length);
            return new ConditionalClassification(categories, conditionalProbs);
        }
        double[] logProbs = new double[probabilityRatios.length];
        for (int i = 0; i < probabilityRatios.length; ++i) {
            logProbs[i] = Math.log2(probabilityRatios[i]);
        }
        return ConditionalClassification.createLogProbs(categories, logProbs);
    }

    /**
     * Checks that every value is a non-positive number.
     *
     * @param logProbabilities values to check
     * @throws IllegalArgumentException on the first value that is positive or NaN
     */
    static void verifyLogProbs(double[] logProbabilities) {
        for (double x : logProbabilities) {
            if (!Double.isNaN(x) && !(x > 0.0)) continue;
            String msg = "Log probs must be non-positive numbers. Found x=" + x;
            throw new IllegalArgumentException(msg);
        }
    }

    /**
     * Checks that the two arrays have the same length.
     *
     * @throws IllegalArgumentException if the lengths differ
     */
    static void verifyLengths(String[] categories, double[] logProbabilities) {
        if (categories.length != logProbabilities.length) {
            String msg = "Arrays must be same length. Found categories.length=" + categories.length + " logProbabilities.length=" + logProbabilities.length;
            throw new IllegalArgumentException(msg);
        }
    }

    /**
     * Reorders parallel category/value arrays with
     * {@code ScoredObject.reverseComparator()}. Arrays.sort on objects is
     * stable, so ties keep their input order. The inputs are not modified.
     *
     * @return a pair of newly allocated (categories, values) arrays in sorted order
     * @throws IllegalArgumentException if the arrays differ in length
     */
    static Pair<String[], double[]> sort(String[] categories, double[] vals) {
        ConditionalClassification.verifyLengths(categories, vals);
        ScoredObject[] scoredObjects = new ScoredObject[categories.length];
        for (int i = 0; i < categories.length; ++i) {
            scoredObjects[i] = new ScoredObject<String>(categories[i], vals[i]);
        }
        String[] categoriesSorted = new String[scoredObjects.length];
        double[] valsSorted = new double[categories.length];
        Arrays.sort(scoredObjects, ScoredObject.reverseComparator());
        for (int i = 0; i < scoredObjects.length; ++i) {
            categoriesSorted[i] = (String)scoredObjects[i].getObject();
            valsSorted[i] = scoredObjects[i].score();
        }
        return new Pair<String[], double[]>(categoriesSorted, valsSorted);
    }

    /**
     * Converts base-2 log-joint probabilities into conditional probabilities.
     * Each value is shifted by the maximum before exponentiation, then the
     * ratios are normalized with {@code Statistics.normalize}. If every input
     * is {@code -Infinity} (all joint probabilities zero), a uniform
     * distribution is returned instead. This mirrors how {@code createProbs}
     * treats all-zero ratios.
     *
     * <p>Side effect: positive entries below 1.0E-10 are clamped to 0.0
     * <em>in the argument array</em>. That mutation is visible to the caller and
     * is kept deliberately. NOTE(review): callers outside this class may
     * depend on it; verify before changing.
     *
     * @param logJointProbs base-2 log-joint probabilities; not positive (beyond rounding slack)
     * @return conditional probabilities parallel to the input
     * @throws IllegalArgumentException if any value is positive (after clamping) or NaN
     */
    static double[] logJointToConditional(double[] logJointProbs) {
        for (int i = 0; i < logJointProbs.length; ++i) {
            if (logJointProbs[i] > 0.0 && logJointProbs[i] < LOG_PROB_ROUNDING_SLACK) {
                logJointProbs[i] = 0.0;
            }
            // NaN fails this comparison and falls through to the error.
            if (logJointProbs[i] <= 0.0) {
                continue;
            }
            StringBuilder sb = new StringBuilder();
            sb.append("Joint probs must be zero or negative. Found log2JointProbs[" + i + "]=" + logJointProbs[i]);
            for (int k = 0; k < logJointProbs.length; ++k) {
                sb.append("\nlogJointProbs[" + k + "]=" + logJointProbs[k]);
            }
            throw new IllegalArgumentException(sb.toString());
        }
        double max = Math.maximum(logJointProbs);
        double[] probRatios = new double[logJointProbs.length];
        if (max == Double.NEGATIVE_INFINITY) {
            // Every joint probability is zero: -Inf - -Inf would be NaN, so fall back to uniform.
            Arrays.fill(probRatios, 1.0 / (double)probRatios.length);
            return probRatios;
        }
        for (int i = 0; i < logJointProbs.length; ++i) {
            // Exponent is <= 0 (or -Infinity), so each ratio lies in [0.0, 1.0].
            probRatios[i] = java.lang.Math.pow(2.0, logJointProbs[i] - max);
        }
        return Statistics.normalize(probRatios);
    }
}

