/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.sentiment;

import edu.stanford.nlp.neural.rnn.RNNCoreAnnotations;
import edu.stanford.nlp.neural.rnn.TopNGramRecord;
import edu.stanford.nlp.sentiment.SentimentCostAndGradient;
import edu.stanford.nlp.sentiment.SentimentModel;
import edu.stanford.nlp.sentiment.SentimentUtils;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.IntCounter;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.ConfusionMatrix;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.StringUtils;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;

/**
 * Accumulates accuracy statistics for a {@link SentimentModel} over a set of
 * gold-labeled trees and prints a summary to {@code System.err}.
 *
 * <p>Statistics tracked: per-node and per-root exact-match counts, per-node and
 * per-root confusion matrices (indexed {@code [gold][predicted]}), exact-match
 * counts bucketed by span length, and optionally a {@link TopNGramRecord}.
 *
 * <p>Nodes whose gold class is negative are excluded from every count
 * (presumably these are unlabeled nodes; confirm against the tree reader).
 *
 * <p>Instances hold mutable counters and are not synchronized.
 */
public class Evaluate {
    final SentimentCostAndGradient cag;
    final SentimentModel model;
    /** Groups of class indices treated as equivalent for the "approximate" accuracies; may be null. */
    final int[][] equivalenceClasses;
    /** Display names for {@link #equivalenceClasses}; may be null. */
    final String[] equivalenceClassNames;
    int labelsCorrect;
    int labelsIncorrect;
    /** Node-level confusion counts, indexed {@code [gold][predicted]}. */
    int[][] labelConfusion;
    int rootLabelsCorrect;
    int rootLabelsIncorrect;
    /** Root-level confusion counts, indexed {@code [gold][predicted]}. */
    int[][] rootLabelConfusion;
    /** Number of correctly labeled nodes, keyed by span length (count of preterminals below the node). */
    IntCounter<Integer> lengthLabelsCorrect;
    /** Number of incorrectly labeled nodes, keyed by span length. */
    IntCounter<Integer> lengthLabelsIncorrect;
    /** N-gram record, or null when {@code testOptions.ngramRecordSize <= 0} at the last {@link #reset()}. */
    TopNGramRecord ngrams;
    static final int NUM_NGRAMS = 5;
    private static final NumberFormat NF = new DecimalFormat("0.000000");

    private static final String USAGE =
            "Usage: java edu.stanford.nlp.sentiment.Evaluate -model <path> -treebank <path> [-filterUnknown] [model options]";

    /**
     * Creates an evaluator for the given model with all counters zeroed.
     *
     * @param model the model to evaluate; its {@code op} options supply the
     *              class count, equivalence classes and test options
     */
    public Evaluate(SentimentModel model) {
        this.model = model;
        this.cag = new SentimentCostAndGradient(model, null);
        this.equivalenceClasses = model.op.equivalenceClasses;
        this.equivalenceClassNames = model.op.equivalenceClassNames;
        this.reset();
    }

    /**
     * Clears all accumulated statistics and re-creates the confusion matrices,
     * length counters and (if enabled in the test options) the n-gram record.
     */
    public void reset() {
        int numClasses = this.model.op.numClasses;
        this.labelsCorrect = 0;
        this.labelsIncorrect = 0;
        this.labelConfusion = new int[numClasses][numClasses];
        this.rootLabelsCorrect = 0;
        this.rootLabelsIncorrect = 0;
        this.rootLabelConfusion = new int[numClasses][numClasses];
        this.lengthLabelsCorrect = new IntCounter<Integer>();
        this.lengthLabelsIncorrect = new IntCounter<Integer>();
        if (this.model.op.testOptions.ngramRecordSize > 0) {
            this.ngrams = new TopNGramRecord(numClasses,
                    this.model.op.testOptions.ngramRecordSize,
                    this.model.op.testOptions.ngramRecordMaximumLength);
        } else {
            this.ngrams = null;
        }
    }

    /**
     * Evaluates every tree in the list, adding to the running statistics.
     *
     * @param trees gold-labeled trees to evaluate
     */
    public void eval(List<Tree> trees) {
        for (Tree tree : trees) {
            this.eval(tree);
        }
    }

    /**
     * Forward-propagates the model over one tree and folds its node, root,
     * length and n-gram statistics into the running totals.
     *
     * @param tree a gold-labeled tree
     */
    public void eval(Tree tree) {
        // Presumably this is what attaches the predicted classes read below -- confirm in SentimentCostAndGradient.
        this.cag.forwardPropagateTree(tree);
        this.countTree(tree);
        this.countRoot(tree);
        this.countLengthAccuracy(tree);
        if (this.ngrams != null) {
            this.ngrams.countTree(tree);
        }
    }

    /**
     * Recursively records, for every non-leaf node with a non-negative gold
     * class, whether the prediction matched, bucketed by the node's span length.
     *
     * @param tree the subtree to process
     * @return the span length of {@code tree}: 0 for a leaf, 1 for a
     *         preterminal, otherwise the sum of the children's lengths
     */
    private int countLengthAccuracy(Tree tree) {
        if (tree.isLeaf()) {
            return 0;
        }
        Integer gold = RNNCoreAnnotations.getGoldClass(tree);
        Integer predicted = RNNCoreAnnotations.getPredictedClass(tree);
        int length;
        if (tree.isPreTerminal()) {
            length = 1;
        } else {
            length = 0;
            for (Tree child : tree.children()) {
                length += this.countLengthAccuracy(child);
            }
        }
        if (gold >= 0) {
            if (gold.equals(predicted)) {
                this.lengthLabelsCorrect.incrementCount(length);
            } else {
                this.lengthLabelsIncorrect.incrementCount(length);
            }
        }
        return length;
    }

    /**
     * Recursively updates the node-level correct/incorrect counts and the
     * node-level confusion matrix for every non-leaf node of {@code tree}
     * whose gold class is non-negative. Children are visited before the node.
     */
    private void countTree(Tree tree) {
        if (tree.isLeaf()) {
            return;
        }
        for (Tree child : tree.children()) {
            this.countTree(child);
        }
        Integer gold = RNNCoreAnnotations.getGoldClass(tree);
        Integer predicted = RNNCoreAnnotations.getPredictedClass(tree);
        if (gold >= 0) {
            if (gold.equals(predicted)) {
                ++this.labelsCorrect;
            } else {
                ++this.labelsIncorrect;
            }
            this.labelConfusion[gold][predicted]++;
        }
    }

    /**
     * Updates the root-level correct/incorrect counts and the root-level
     * confusion matrix using only the top node of {@code tree}; skipped when
     * the root's gold class is negative.
     */
    private void countRoot(Tree tree) {
        Integer gold = RNNCoreAnnotations.getGoldClass(tree);
        Integer predicted = RNNCoreAnnotations.getPredictedClass(tree);
        if (gold >= 0) {
            if (gold.equals(predicted)) {
                ++this.rootLabelsCorrect;
            } else {
                ++this.rootLabelsIncorrect;
            }
            this.rootLabelConfusion[gold][predicted]++;
        }
    }

    /**
     * @return the fraction of counted nodes whose predicted class equals the
     *         gold class; {@code NaN} if no nodes have been counted
     */
    public double exactNodeAccuracy() {
        return (double) this.labelsCorrect / (double) (this.labelsCorrect + this.labelsIncorrect);
    }

    /**
     * @return the fraction of counted roots whose predicted class equals the
     *         gold class; {@code NaN} if no roots have been counted
     */
    public double exactRootAccuracy() {
        return (double) this.rootLabelsCorrect / (double) (this.rootLabelsCorrect + this.rootLabelsIncorrect);
    }

    /**
     * Computes exact-match accuracy per span length.
     *
     * @return a counter mapping each observed span length to
     *         {@code correct / (correct + incorrect)} for nodes of that length
     */
    public Counter<Integer> lengthAccuracies() {
        Set<Integer> keys = Generics.newHashSet();
        keys.addAll(this.lengthLabelsCorrect.keySet());
        keys.addAll(this.lengthLabelsIncorrect.keySet());
        ClassicCounter<Integer> results = new ClassicCounter<Integer>();
        for (Integer key : keys) {
            double correct = this.lengthLabelsCorrect.getCount(key);
            double incorrect = this.lengthLabelsIncorrect.getCount(key);
            results.setCount(key, correct / (correct + incorrect));
        }
        return results;
    }

    /**
     * Prints the per-length accuracies to {@code System.err}, one line per
     * length in ascending order.
     */
    public void printLengthAccuracies() {
        Counter<Integer> accuracies = this.lengthAccuracies();
        // A TreeSet gives ascending length order for the printout.
        Set<Integer> keys = Generics.newTreeSet();
        keys.addAll(accuracies.keySet());
        System.err.println("Label accuracy at various lengths:");
        for (Integer key : keys) {
            System.err.println(StringUtils.padLeft(Integer.toString(key), 4) + ": " + NF.format(accuracies.getCount(key)));
        }
    }

    /**
     * Prints a confusion matrix to {@code System.err} under the given name.
     *
     * @param name      label printed before the matrix
     * @param confusion counts indexed {@code [gold][predicted]}
     */
    private static void printConfusionMatrix(String name, int[][] confusion) {
        System.err.println(name + " confusion matrix");
        ConfusionMatrix<Integer> confusionMatrix = new ConfusionMatrix<Integer>();
        confusionMatrix.setUseRealLabels(true);
        for (int gold = 0; gold < confusion.length; ++gold) {
            for (int predicted = 0; predicted < confusion[gold].length; ++predicted) {
                // Predicted index goes first, gold second -- presumably matching
                // ConfusionMatrix.add(guess, gold, increment); confirm against its Javadoc.
                confusionMatrix.add(predicted, gold, confusion[gold][predicted]);
            }
        }
        System.err.println(confusionMatrix);
    }

    /**
     * Computes, for each equivalence class, the fraction of items whose gold
     * class is in that equivalence class and whose predicted class is in the
     * same equivalence class.
     *
     * @param confusion counts indexed {@code [gold][predicted]}
     * @param classes   each row lists the class indices forming one equivalence class
     * @return one accuracy per row of {@code classes}; {@code NaN} for a row
     *         with no observed gold items
     */
    private static double[] approxAccuracy(int[][] confusion, int[][] classes) {
        double[] results = new double[classes.length];
        for (int i = 0; i < classes.length; ++i) {
            int correct = 0;
            int total = 0;
            for (int goldClass : classes[i]) {
                // Any prediction inside the same equivalence class counts as correct.
                for (int predictedClass : classes[i]) {
                    correct += confusion[goldClass][predictedClass];
                }
                // Every prediction for this gold class counts toward the total.
                for (int count : confusion[goldClass]) {
                    total += count;
                }
            }
            results[i] = (double) correct / (double) total;
        }
        return results;
    }

    /**
     * Like {@link #approxAccuracy(int[][], int[][])} but pools the correct and
     * total counts over all equivalence classes into a single accuracy.
     *
     * @param confusion counts indexed {@code [gold][predicted]}
     * @param classes   each row lists the class indices forming one equivalence class
     * @return pooled accuracy; {@code NaN} if no gold items were observed
     */
    private static double approxCombinedAccuracy(int[][] confusion, int[][] classes) {
        int correct = 0;
        int total = 0;
        for (int[] equivalenceClass : classes) {
            for (int goldClass : equivalenceClass) {
                for (int predictedClass : equivalenceClass) {
                    correct += confusion[goldClass][predictedClass];
                }
                for (int count : confusion[goldClass]) {
                    total += count;
                }
            }
        }
        return (double) correct / (double) total;
    }

    /**
     * Prints the full evaluation report to {@code System.err}: node and root
     * counts and accuracies, both confusion matrices, approximate accuracies
     * (when equivalence classes and their names are configured), the n-gram
     * record (when one was collected) and per-length accuracies (when enabled
     * in the test options).
     */
    public void printSummary() {
        System.err.println("EVALUATION SUMMARY");
        System.err.println("Tested " + (this.labelsCorrect + this.labelsIncorrect) + " labels");
        System.err.println("  " + this.labelsCorrect + " correct");
        System.err.println("  " + this.labelsIncorrect + " incorrect");
        System.err.println("  " + NF.format(this.exactNodeAccuracy()) + " accuracy");
        System.err.println("Tested " + (this.rootLabelsCorrect + this.rootLabelsIncorrect) + " roots");
        System.err.println("  " + this.rootLabelsCorrect + " correct");
        System.err.println("  " + this.rootLabelsIncorrect + " incorrect");
        System.err.println("  " + NF.format(this.exactRootAccuracy()) + " accuracy");
        Evaluate.printConfusionMatrix("Label", this.labelConfusion);
        Evaluate.printConfusionMatrix("Root label", this.rootLabelConfusion);
        if (this.equivalenceClasses != null && this.equivalenceClassNames != null) {
            double[] approxLabelAccuracy = Evaluate.approxAccuracy(this.labelConfusion, this.equivalenceClasses);
            for (int i = 0; i < this.equivalenceClassNames.length; ++i) {
                System.err.println("Approximate " + this.equivalenceClassNames[i] + " label accuracy: " + NF.format(approxLabelAccuracy[i]));
            }
            System.err.println("Combined approximate label accuracy: " + NF.format(Evaluate.approxCombinedAccuracy(this.labelConfusion, this.equivalenceClasses)));
            double[] approxRootLabelAccuracy = Evaluate.approxAccuracy(this.rootLabelConfusion, this.equivalenceClasses);
            for (int i = 0; i < this.equivalenceClassNames.length; ++i) {
                System.err.println("Approximate " + this.equivalenceClassNames[i] + " root label accuracy: " + NF.format(approxRootLabelAccuracy[i]));
            }
            System.err.println("Combined approximate root label accuracy: " + NF.format(Evaluate.approxCombinedAccuracy(this.rootLabelConfusion, this.equivalenceClasses)));
            System.err.println();
        }
        // Same condition eval() uses, so the record is printed exactly when one was collected,
        // even if the ngramRecordSize option was changed after the last reset().
        if (this.ngrams != null) {
            System.err.println(this.ngrams);
        }
        if (this.model.op.testOptions.printLengthAccuracies) {
            this.printLengthAccuracies();
        }
    }

    /**
     * Returns the value following the flag at {@code flagIndex}.
     *
     * @throws IllegalArgumentException if the flag is the last argument
     */
    private static String requireValue(String[] args, int flagIndex) {
        if (flagIndex + 1 >= args.length) {
            throw new IllegalArgumentException("Missing value for argument " + args[flagIndex] + ". " + USAGE);
        }
        return args[flagIndex + 1];
    }

    /**
     * Command-line entry point: loads a serialized model, reads a gold-labeled
     * treebank, evaluates the model on it and prints the summary.
     *
     * <p>Recognized flags (case-insensitive): {@code -model <path>} and
     * {@code -treebank <path>} (both required) and {@code -filterUnknown}.
     * All other arguments are passed to the model's options via
     * {@code model.op.setOption}.
     *
     * @param args command-line arguments
     * @throws IllegalArgumentException if a required flag or its value is
     *         missing, or if an argument is not recognized by the model options
     */
    public static void main(String[] args) {
        String modelPath = null;
        String treePath = null;
        boolean filterUnknown = false;
        List<String> remainingArgs = Generics.newArrayList();
        int argIndex = 0;
        while (argIndex < args.length) {
            String arg = args[argIndex];
            if (arg.equalsIgnoreCase("-model")) {
                modelPath = requireValue(args, argIndex);
                argIndex += 2;
            } else if (arg.equalsIgnoreCase("-treebank")) {
                treePath = requireValue(args, argIndex);
                argIndex += 2;
            } else if (arg.equalsIgnoreCase("-filterUnknown")) {
                filterUnknown = true;
                ++argIndex;
            } else {
                remainingArgs.add(arg);
                ++argIndex;
            }
        }
        // Fail fast with a clear message instead of passing null paths to the loaders.
        if (modelPath == null) {
            throw new IllegalArgumentException("No model specified with -model. " + USAGE);
        }
        if (treePath == null) {
            throw new IllegalArgumentException("No treebank specified with -treebank. " + USAGE);
        }
        String[] newArgs = remainingArgs.toArray(new String[remainingArgs.size()]);
        SentimentModel model = SentimentModel.loadSerialized(modelPath);
        int optionIndex = 0;
        while (optionIndex < newArgs.length) {
            int newIndex = model.op.setOption(newArgs, optionIndex);
            // setOption returning the same index means it consumed nothing, i.e. the flag is unknown.
            if (optionIndex == newIndex) {
                System.err.println("Unknown argument " + newArgs[optionIndex]);
                throw new IllegalArgumentException("Unknown argument " + newArgs[optionIndex]);
            }
            optionIndex = newIndex;
        }
        List<Tree> trees = SentimentUtils.readTreesWithGoldLabels(treePath);
        if (filterUnknown) {
            trees = SentimentUtils.filterUnknownRoots(trees);
        }
        Evaluate eval = new Evaluate(model);
        eval.eval(trees);
        eval.printSummary();
    }
}

