/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.classify;

import com.aliasi.classify.Classification;
import com.aliasi.classify.ClassificationHandlerCorpusAdapter2;
import com.aliasi.classify.Classified;
import com.aliasi.classify.Classifier;
import com.aliasi.classify.ConditionalClassification;
import com.aliasi.classify.JointClassification;
import com.aliasi.classify.JointClassifier;
import com.aliasi.corpus.ClassificationHandler;
import com.aliasi.corpus.Corpus;
import com.aliasi.corpus.ObjectHandler;
import com.aliasi.corpus.TextHandler;
import com.aliasi.io.Reporter;
import com.aliasi.io.Reporters;
import com.aliasi.stats.Statistics;
import com.aliasi.tokenizer.Tokenizer;
import com.aliasi.tokenizer.TokenizerFactory;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Compilable;
import com.aliasi.util.Counter;
import com.aliasi.util.Exceptions;
import com.aliasi.util.Factory;
import com.aliasi.util.Iterators;
import com.aliasi.util.ObjectToCounterMap;
import com.aliasi.util.Strings;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.ObjectStreamException;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Collections;
import java.util.Formatter;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class TradNaiveBayesClassifier
implements ClassificationHandler<CharSequence, Classification>,
Classifier<CharSequence, JointClassification>,
JointClassifier<CharSequence>,
ObjectHandler<Classified<CharSequence>>,
Serializable,
Compilable {
    /** Serialization version identifier for this class. */
    static final long serialVersionUID = -300327951207213311L;
    /** The categories as a set; exposed read-only through {@code categorySet()}. */
    private final Set<String> mCategorySet;
    /** The categories as an array; positions in this array index the per-category count arrays below. */
    private final String[] mCategories;
    /** Factory producing the tokenizers used for both training and classification. */
    private final TokenizerFactory mTokenizerFactory;
    /** Additive prior count applied to each category in {@code probCatByIndex}. */
    private final double mCategoryPrior;
    /** Additive prior count applied to each token in {@code probTokenByIndexArray}. */
    private final double mTokenInCategoryPrior;
    /** Maps each token seen in training to its (length-normalized) counts, one slot per category. */
    private Map<String, double[]> mTokenToCountsMap;
    /** Sum of all token counts accumulated for each category. */
    private double[] mTotalCountsPerCategory;
    /** Accumulated training-case count (weight) for each category. */
    private double[] mCaseCounts;
    /** Sum of the per-category case counts. */
    private double mTotalCaseCount;
    /** Length normalization target; {@code Double.NaN} switches length normalization off. */
    private double mLengthNorm;

    /**
     * Renders the full model state as multi-line text: the category list, both
     * priors, the total and per-category case counts, and then, for every known
     * token, its count in each category.
     *
     * @return a human-readable dump of this classifier's counts
     */
    @Override
    public String toString() {
        StringBuilder buf = new StringBuilder();
        buf.append("categories=").append(Arrays.asList(mCategories)).append("\n");
        buf.append("category Prior=").append(mCategoryPrior).append("\n");
        buf.append("token in category prior=").append(mTokenInCategoryPrior).append("\n");
        buf.append("total case count=").append(mTotalCaseCount).append("\n");
        for (int cat = 0; cat < mCategories.length; ++cat) {
            buf.append("category count(").append(mCategories[cat]).append(")=")
               .append(mCaseCounts[cat]).append("\n");
        }
        // Walking the entry set visits tokens in the same order as the key set.
        for (Map.Entry<String, double[]> entry : mTokenToCountsMap.entrySet()) {
            String token = entry.getKey();
            double[] perCategory = entry.getValue();
            buf.append("token=").append(token).append("\n");
            for (int cat = 0; cat < mCategories.length; ++cat) {
                buf.append("  tokenCount(").append(mCategories[cat]).append(",").append(token)
                   .append(")=").append(perCategory[cat]).append("\n");
            }
        }
        return buf.toString();
    }

    /**
     * Rebuilds a classifier directly from already-computed state. The arguments
     * are stored as given (no validation, no defensive copies); only the category
     * set is derived, from the category array.
     */
    private TradNaiveBayesClassifier(String[] categories, TokenizerFactory tokenizerFactory, double categoryPrior, double tokenInCategoryPrior, Map<String, double[]> tokenToCountsMap, double[] totalCountsPerCategory, double[] caseCounts, double totalCaseCount, double lengthNorm) {
        // Configuration.
        mTokenizerFactory = tokenizerFactory;
        mCategoryPrior = categoryPrior;
        mTokenInCategoryPrior = tokenInCategoryPrior;
        mLengthNorm = lengthNorm;

        // Categories, in array and set form.
        mCategories = categories;
        mCategorySet = new HashSet<String>(Arrays.asList(categories));

        // Accumulated training counts.
        mTokenToCountsMap = tokenToCountsMap;
        mTotalCountsPerCategory = totalCountsPerCategory;
        mCaseCounts = caseCounts;
        mTotalCaseCount = totalCaseCount;
    }

    /**
     * Creates an untrained classifier with default settings: a category prior
     * of 0.5, a token-in-category prior of 0.5, and a length norm of
     * {@code Double.NaN} (which disables length normalization).
     *
     * @param categorySet the categories; at least two are required
     * @param tokenizerFactory factory for the tokenizers applied to text
     * @throws IllegalArgumentException if fewer than two categories are given
     */
    public TradNaiveBayesClassifier(Set<String> categorySet, TokenizerFactory tokenizerFactory) {
        this(categorySet, tokenizerFactory, 0.5, 0.5, Double.NaN);
    }

    public TradNaiveBayesClassifier(Set<String> categorySet, TokenizerFactory tokenizerFactory, double categoryPrior, double tokenInCategoryPrior, double lengthNorm) {
        if (categorySet.size() < 2) {
            String msg = "Require at least two categorySet. Found categorySet.size()=" + categorySet.size();
            throw new IllegalArgumentException(msg);
        }
        Exceptions.finiteNonNegative("categoryPrior", categoryPrior);
        Exceptions.finiteNonNegative("tokenInCategoryPrior", tokenInCategoryPrior);
        this.setLengthNorm(lengthNorm);
        this.mTotalCaseCount = 0.0;
        this.mCategorySet = new HashSet<String>(categorySet);
        this.mCategories = this.mCategorySet.toArray(Strings.EMPTY_STRING_ARRAY);
        Arrays.sort(this.mCategories);
        this.mTokenizerFactory = tokenizerFactory;
        this.mCategoryPrior = categoryPrior;
        this.mTokenInCategoryPrior = tokenInCategoryPrior;
        this.mTokenToCountsMap = new HashMap<String, double[]>();
        this.mTotalCountsPerCategory = new double[this.mCategories.length];
        this.mCaseCounts = new double[this.mCategories.length];
    }

    /**
     * Returns a read-only view of this classifier's categories.
     *
     * @return an unmodifiable view of the category set
     */
    public Set<String> categorySet() {
        return Collections.unmodifiableSet(this.mCategorySet);
    }

    /**
     * Sets the length norm applied during training and classification.
     *
     * @param lengthNorm a finite positive value, or {@code Double.NaN} to turn
     *     length normalization off
     * @throws IllegalArgumentException if the value is zero, negative or infinite
     */
    public void setLengthNorm(double lengthNorm) {
        boolean finitePositive = lengthNorm > 0.0 && !Double.isInfinite(lengthNorm);
        if (!Double.isNaN(lengthNorm) && !finitePositive) {
            throw new IllegalArgumentException("Length norm must be finite and positive, or Double.NaN. Found lengthNorm=" + lengthNorm);
        }
        mLengthNorm = lengthNorm;
    }

    /**
     * Classifies the input by summing, per category, the log2 probabilities of
     * its known tokens, optionally rescaling that sum to the length norm, and
     * then adding the log2 category probability.
     *
     * <p>Unknown tokens contribute nothing to the sum but do count toward the
     * token total used for length normalization.
     *
     * @param in the text to classify
     * @return the joint classification built from the per-category log2 scores
     */
    @Override
    public JointClassification classify(CharSequence in) {
        final int numCats = mCategories.length;
        double[] log2Scores = new double[numCats];
        char[] chars = Strings.toCharArray(in);
        Tokenizer tokenizer = mTokenizerFactory.tokenizer(chars, 0, chars.length);

        int numTokens = 0;
        for (String token : tokenizer) {
            ++numTokens;
            double[] counts = mTokenToCountsMap.get(token);
            if (counts != null) {
                for (int cat = 0; cat < numCats; ++cat) {
                    log2Scores[cat] += com.aliasi.util.Math.log2(probTokenByIndexArray(cat, counts));
                }
            }
        }

        if (numTokens > 0 && !Double.isNaN(mLengthNorm)) {
            double scale = mLengthNorm / (double) numTokens;
            for (int cat = 0; cat < log2Scores.length; ++cat) {
                log2Scores[cat] *= scale;
            }
        }

        for (int cat = 0; cat < log2Scores.length; ++cat) {
            log2Scores[cat] += com.aliasi.util.Math.log2(probCatByIndex(cat));
        }
        return JointClassification.create(mCategories, log2Scores);
    }

    /**
     * Returns the current length norm.
     *
     * @return the length norm, or {@code Double.NaN} when length normalization is off
     */
    public double lengthNorm() {
        return this.mLengthNorm;
    }

    /**
     * Reports whether the token has an entry in the token-to-counts map.
     *
     * @param token the token to look up
     * @return {@code true} if the token has a counts entry
     */
    public boolean isKnownToken(String token) {
        return this.mTokenToCountsMap.containsKey(token);
    }

    /**
     * Returns a read-only view of the tokens that have a counts entry. The view
     * is backed by the live map, so later training is reflected in it.
     *
     * @return an unmodifiable view of the known tokens
     */
    public Set<String> knownTokenSet() {
        return Collections.unmodifiableSet(this.mTokenToCountsMap.keySet());
    }

    /**
     * Returns the smoothed probability estimate of a known token in a category.
     *
     * @param token a token that has a counts entry
     * @param cat a category of this classifier
     * @return the estimated probability of the token given the category
     * @throws IllegalArgumentException if the category is unknown (checked
     *     first) or the token is unknown
     */
    public double probToken(String token, String cat) {
        int catIndex = getIndex(cat);
        double[] counts = mTokenToCountsMap.get(token);
        if (counts != null) {
            return probTokenByIndexArray(catIndex, counts);
        }
        throw new IllegalArgumentException("Requires known token. Found token=" + token);
    }

    /**
     * Writes a compiled form of this classifier to the given output by
     * serializing a {@code Compiler} wrapper, which emits log2 probabilities
     * rather than raw counts.
     *
     * @param out the object output to write to
     * @throws IOException if writing fails
     */
    @Override
    public void compileTo(ObjectOutput out) throws IOException {
        out.writeObject(new Compiler(this));
    }

    /**
     * Returns the smoothed probability estimate of a category.
     *
     * @param cat a category of this classifier
     * @return the estimated probability of the category
     * @throws IllegalArgumentException if the category is unknown
     */
    public double probCat(String cat) {
        return probCatByIndex(getIndex(cat));
    }

    /**
     * Trains on one classified text by unpacking it and delegating to the
     * two-argument handler.
     *
     * @param classifiedObject the text together with its classification
     */
    @Override
    public void handle(Classified<CharSequence> classifiedObject) {
        CharSequence text = classifiedObject.getObject();
        Classification label = classifiedObject.getClassification();
        handle(text, label);
    }

    /**
     * Trains on the text with the classification's best category and a count
     * of 1.0.
     *
     * @param cSeq the training text
     * @param classification the classification whose best category is trained
     */
    @Override
    @Deprecated
    public void handle(CharSequence cSeq, Classification classification) {
        this.train(cSeq, classification, 1.0);
    }

    /**
     * Trains on the text fractionally, crediting each category with its
     * conditional probability times {@code countMultiplier}.
     *
     * <p>Categories are taken in the classification's rank order and the scan
     * stops at the first one whose weighted count falls below {@code minCount}.
     * Every selected category is resolved to an index before any count is
     * changed, so an unknown category leaves the model untouched instead of
     * half-updated.
     *
     * @param cSeq the training text
     * @param classification conditional classification supplying category weights
     * @param countMultiplier finite non-negative scale applied to each probability
     * @param minCount finite non-negative threshold below which categories are dropped
     * @throws IllegalArgumentException if a numeric argument is negative, NaN or
     *     infinite, or if a selected category is not one of this classifier's
     */
    public void trainConditional(CharSequence cSeq, ConditionalClassification classification, double countMultiplier, double minCount) {
        if (countMultiplier < 0.0 || Double.isNaN(countMultiplier) || Double.isInfinite(countMultiplier)) {
            String msg = "Count multipliers must be finite and non-negative. Found countMultiplier=" + countMultiplier;
            throw new IllegalArgumentException(msg);
        }
        if (minCount < 0.0 || Double.isNaN(minCount) || Double.isInfinite(minCount)) {
            String msg = "Minimum count must be finite non-negative. Found minCount=" + minCount;
            throw new IllegalArgumentException(msg);
        }

        // Number of leading categories whose weighted count reaches the threshold.
        int numCats = 0;
        while (numCats < classification.size()
                && classification.conditionalProbability(numCats) * countMultiplier >= minCount) {
            ++numCats;
        }

        // Resolve all categories up front: getIndex throws on an unknown
        // category, and that must happen before any count is modified.
        int[] catIndexes = new int[numCats];
        for (int j = 0; j < numCats; ++j) {
            catIndexes[j] = getIndex(classification.category(j));
        }

        ObjectToCounterMap<String> tokenFrequencies = tokenCountMap(cSeq);
        double lengthMultiplier = lengthMultiplier(tokenFrequencies);

        double[] lengthNormCatMultipliers = new double[numCats];
        for (int j = 0; j < numCats; ++j) {
            double count = countMultiplier * classification.conditionalProbability(j);
            mTotalCaseCount += count;
            mCaseCounts[catIndexes[j]] += count;
            lengthNormCatMultipliers[j] = lengthMultiplier * count;
        }

        for (Map.Entry<String, Counter> entry : tokenFrequencies.entrySet()) {
            String token = entry.getKey();
            double tokenCount = entry.getValue().doubleValue();
            double[] tokenCounts = mTokenToCountsMap.get(token);
            if (tokenCounts == null) {
                tokenCounts = new double[mCategories.length];
                mTokenToCountsMap.put(token, tokenCounts);
            }
            for (int j = 0; j < numCats; ++j) {
                double addend = tokenCount * lengthNormCatMultipliers[j];
                tokenCounts[catIndexes[j]] += addend;
                mTotalCountsPerCategory[catIndexes[j]] += addend;
            }
        }
    }

    /**
     * Trains on the text for the classification's best category with the given
     * count. A negative count removes previously trained mass.
     *
     * <p>A zero count is a no-op. A decrement that would drive the category's
     * case count or any token count below zero is rejected, and any token
     * updates already applied during the call are undone before the exception
     * is thrown.
     *
     * @param cSeq the training text
     * @param classification the classification whose best category is trained
     * @param count the finite training weight; negative to decrement
     * @throws IllegalArgumentException if the count is NaN or infinite, the
     *     category is unknown, or a decrement would produce a negative count
     */
    public void train(CharSequence cSeq, Classification classification, double count) {
        // A NaN or infinite count would slip past the checks below and
        // permanently corrupt every count it is added to.
        if (Double.isNaN(count) || Double.isInfinite(count)) {
            String msg = "Count must be finite. Found count=" + count;
            throw new IllegalArgumentException(msg);
        }
        if (count == 0.0) {
            return;
        }
        String cat = classification.bestCategory();
        int catIndex = getIndex(cat);
        if (mCaseCounts[catIndex] < -count) {
            String msg = "Decrement caused negative token count.Revert to previous state. cSeq=" + cSeq + " classification=" + cat + " count=" + count;
            throw new IllegalArgumentException(msg);
        }
        mCaseCounts[catIndex] += count;
        mTotalCaseCount += count;

        double lengthNormCount = lengthMultiplier(tokenCountMap(cSeq)) * count;
        char[] cs = Strings.toCharArray(cSeq);
        Tokenizer tokenizer = mTokenizerFactory.tokenizer(cs, 0, cs.length);
        int pos = 0; // number of token occurrences already applied
        for (String token : tokenizer) {
            double[] tokenCounts = mTokenToCountsMap.get(token);
            boolean wouldGoNegative = lengthNormCount < 0.0
                && (tokenCounts == null || tokenCounts[catIndex] < -lengthNormCount);
            if (wouldGoNegative) {
                // Undo the case-count update, then re-tokenize and undo the
                // first pos token updates, restoring the state at entry.
                mCaseCounts[catIndex] -= count;
                mTotalCaseCount -= count;
                Tokenizer undoTokenizer = mTokenizerFactory.tokenizer(cs, 0, cs.length);
                int fixPos = 0;
                for (String appliedToken : undoTokenizer) {
                    if (fixPos >= pos) {
                        break;
                    }
                    ++fixPos;
                    double[] appliedCounts = mTokenToCountsMap.get(appliedToken);
                    appliedCounts[catIndex] -= lengthNormCount;
                    mTotalCountsPerCategory[catIndex] -= lengthNormCount;
                }
                String msg = "Decrement caused negative token count.Revert to previous state. cSeq=" + cSeq + " classification=" + cat + " count=" + count;
                throw new IllegalArgumentException(msg);
            }
            ++pos;
            if (tokenCounts == null) {
                tokenCounts = new double[mCategories.length];
                mTokenToCountsMap.put(token, tokenCounts);
            }
            tokenCounts[catIndex] += lengthNormCount;
            mTotalCountsPerCategory[catIndex] += lengthNormCount;
        }
    }

    /**
     * Returns the log2 of the input's marginal probability, i.e. the log2 of
     * the sum over categories of the joint probabilities from {@code classify}.
     *
     * <p>The sum is taken relative to the largest joint log probability so the
     * exponentials stay in a representable range.
     *
     * @param input the text to score
     * @return log2 of the summed joint probabilities
     */
    public double log2CaseProb(CharSequence input) {
        JointClassification joint = classify(input);
        int numRanks = joint.size();
        double[] log2Joint = new double[numRanks];
        double largest = Double.NEGATIVE_INFINITY;
        for (int rank = 0; rank < numRanks; ++rank) {
            log2Joint[rank] = joint.jointLog2Probability(rank);
            if (log2Joint[rank] > largest) {
                largest = log2Joint[rank];
            }
        }
        double scaledSum = 0.0;
        for (double log2P : log2Joint) {
            scaledSum += Math.pow(2.0, log2P - largest);
        }
        return largest + com.aliasi.util.Math.log2(scaledSum);
    }

    /**
     * Returns the log2 density of the current parameter estimates under the
     * classifier's symmetric Dirichlet priors: one term for the category
     * distribution plus one term per category for its token distribution.
     *
     * <p>Token probabilities come from {@code probTokenByIndexArray}, the same
     * estimate used for classification, so each per-category vector is
     * normalized over the vocabulary. The earlier code smoothed the denominator
     * with the number of categories instead of the number of tokens, which
     * produced vectors that did not sum to one.
     *
     * @return the log2 prior density of the model parameters
     */
    public double log2ModelProb() {
        final int numCats = mCategories.length;
        double[] catProbs = new double[numCats];
        for (int i = 0; i < numCats; ++i) {
            catProbs[i] = probCatByIndex(i);
        }
        double sum = Statistics.dirichletLog2Prob(mCategoryPrior, catProbs);

        // The buffer is reused for every category; each pass overwrites all slots.
        double[] wordProbs = new double[mTokenToCountsMap.size()];
        for (int catIndex = 0; catIndex < numCats; ++catIndex) {
            int j = 0;
            for (double[] counts : mTokenToCountsMap.values()) {
                wordProbs[j++] = probTokenByIndexArray(catIndex, counts);
            }
            sum += Statistics.dirichletLog2Prob(mTokenInCategoryPrior, wordProbs);
        }
        return sum;
    }

    /**
     * Serialization hook: substitutes a {@code Serializer} wrapping this
     * classifier as the object actually written to the stream.
     *
     * @return the serialization proxy for this classifier
     */
    private Object writeReplace() throws ObjectStreamException {
        return new Serializer(this);
    }

    /**
     * Additively smoothed estimate of a token's probability in a category: the
     * token's count plus the token prior, divided by the category's total token
     * count plus the prior times the number of known tokens.
     *
     * @param catIndex index of the category
     * @param tokenCounts the token's per-category counts
     * @return the smoothed probability estimate
     */
    private double probTokenByIndexArray(int catIndex, double[] tokenCounts) {
        double numerator = tokenCounts[catIndex] + mTokenInCategoryPrior;
        double denominator = mTotalCountsPerCategory[catIndex]
            + (double) mTokenToCountsMap.size() * mTokenInCategoryPrior;
        return numerator / denominator;
    }

    /**
     * Additively smoothed estimate of a category's probability: its case count
     * plus the category prior, divided by the total case count plus the prior
     * times the number of categories.
     *
     * @param catIndex index of the category
     * @return the smoothed probability estimate
     */
    private double probCatByIndex(int catIndex) {
        double numerator = mCaseCounts[catIndex] + mCategoryPrior;
        double denominator = mTotalCaseCount + (double) mCategories.length * mCategoryPrior;
        return numerator / denominator;
    }

    /**
     * Tokenizes the text and tallies how many times each token occurs.
     *
     * @param cSeq the text to tokenize
     * @return a map from each token to its number of occurrences
     */
    private ObjectToCounterMap<String> tokenCountMap(CharSequence cSeq) {
        char[] chars = Strings.toCharArray(cSeq);
        ObjectToCounterMap<String> tally = new ObjectToCounterMap<String>();
        Iterator<String> tokens = mTokenizerFactory.tokenizer(chars, 0, chars.length).iterator();
        while (tokens.hasNext()) {
            tally.increment(tokens.next());
        }
        return tally;
    }

    /**
     * Computes the factor that rescales a text's token counts to the length
     * norm: the length norm divided by the text's total token count.
     *
     * @param tokenCountMap per-token occurrence counts for one text
     * @return 1.0 when length normalization is off or the text has no tokens,
     *     otherwise the length norm divided by the number of tokens
     */
    private double lengthMultiplier(ObjectToCounterMap<String> tokenCountMap) {
        if (Double.isNaN(mLengthNorm)) {
            return 1.0;
        }
        int totalTokens = 0;
        for (Counter occurrences : tokenCountMap.values()) {
            totalTokens += occurrences.intValue();
        }
        if (totalTokens == 0) {
            return 1.0;
        }
        return mLengthNorm / (double) totalTokens;
    }

    /**
     * Looks up a category's position in the category array by binary search.
     *
     * @param cat the category name
     * @return the category's array index
     * @throws IllegalArgumentException if the category is not found
     */
    private int getIndex(String cat) {
        int position = Arrays.binarySearch(mCategories, cat);
        if (position >= 0) {
            return position;
        }
        throw new IllegalArgumentException("Unknown category.  Require category in category set. Found category=" + cat + " category set=" + mCategorySet);
    }

    /**
     * Deprecated form of {@code emIterator} for handler-style corpora: wraps
     * the labeled and unlabeled corpora in object-handler adapters and returns
     * an {@code EmIterator} over them.
     *
     * @param initialClassifier the classifier the iteration starts from
     * @param classifierFactory factory for the classifiers built during iteration
     * @param labeledData corpus of classified texts
     * @param unlabeledData corpus of unclassified texts
     * @param minTokenCount threshold passed through to the iterator
     * @return an iterator over successive classifiers
     * @throws IOException if reading a corpus fails
     */
    @Deprecated
    public static Iterator<TradNaiveBayesClassifier> em(TradNaiveBayesClassifier initialClassifier, Factory<TradNaiveBayesClassifier> classifierFactory, Corpus<ClassificationHandler<CharSequence, Classification>> labeledData, Corpus<TextHandler> unlabeledData, double minTokenCount) throws IOException {
        return new EmIterator(initialClassifier, classifierFactory, new ClassificationHandlerCorpusAdapter2<CharSequence>(labeledData), new TextHandlerCorpusAdapter2(unlabeledData), minTokenCount);
    }

    /**
     * Returns an {@code EmIterator} over the given labeled and unlabeled
     * corpora; each call to {@code next()} yields a classifier.
     *
     * @param initialClassifier the classifier the iteration starts from
     * @param classifierFactory factory for the classifiers built during iteration
     * @param labeledData corpus of classified texts
     * @param unlabeledData corpus of unclassified texts
     * @param minTokenCount threshold passed through to the iterator
     * @return an iterator over successive classifiers
     * @throws IOException if reading a corpus fails
     */
    public static Iterator<TradNaiveBayesClassifier> emIterator(TradNaiveBayesClassifier initialClassifier, Factory<TradNaiveBayesClassifier> classifierFactory, Corpus<ObjectHandler<Classified<CharSequence>>> labeledData, Corpus<ObjectHandler<CharSequence>> unlabeledData, double minTokenCount) throws IOException {
        return new EmIterator(initialClassifier, classifierFactory, labeledData, unlabeledData, minTokenCount);
    }

    /**
     * Deprecated form of {@code emTrain} for handler-style corpora: wraps the
     * labeled and unlabeled corpora in object-handler adapters and delegates to
     * {@code emTrain} with the remaining arguments unchanged.
     *
     * @return the classifier returned by {@code emTrain}
     * @throws IOException if reading a corpus fails
     */
    @Deprecated
    public static TradNaiveBayesClassifier em(TradNaiveBayesClassifier initialClassifier, Factory<TradNaiveBayesClassifier> classifierFactory, Corpus<ClassificationHandler<CharSequence, Classification>> labeledData, Corpus<TextHandler> unlabeledData, double minTokenCount, int maxEpochs, double minImprovement, Reporter reporter) throws IOException {
        return TradNaiveBayesClassifier.emTrain(initialClassifier, classifierFactory, new ClassificationHandlerCorpusAdapter2<CharSequence>(labeledData), new TextHandlerCorpusAdapter2(unlabeledData), minTokenCount, maxEpochs, minImprovement, reporter);
    }

    /**
     * Runs the EM iterator until the relative change in total log probability
     * (model log probability plus data log probability) drops below
     * {@code minImprovement}, the iterator is exhausted, or {@code maxEpochs}
     * classifiers have been produced.
     *
     * @param initialClassifier the classifier the iteration starts from
     * @param classifierFactory factory for the classifiers built during iteration
     * @param labeledData corpus of classified texts
     * @param unlabeledData corpus of unclassified texts
     * @param minTokenCount threshold passed through to the iterator
     * @param maxEpochs maximum number of classifiers to draw from the iterator
     * @param minImprovement relative-difference threshold that signals convergence
     * @param reporter destination for progress messages; {@code null} for silent
     * @return the last classifier produced, or {@code null} if none was
     *     produced (for example when {@code maxEpochs <= 0})
     * @throws IOException if reading a corpus fails
     */
    public static TradNaiveBayesClassifier emTrain(TradNaiveBayesClassifier initialClassifier, Factory<TradNaiveBayesClassifier> classifierFactory, Corpus<ObjectHandler<Classified<CharSequence>>> labeledData, Corpus<ObjectHandler<CharSequence>> unlabeledData, double minTokenCount, int maxEpochs, double minImprovement, Reporter reporter) throws IOException {
        if (reporter == null) {
            reporter = Reporters.silent();
        }
        double lastLogProb = Double.NEGATIVE_INFINITY;
        Iterator<TradNaiveBayesClassifier> it = TradNaiveBayesClassifier.emIterator(initialClassifier, classifierFactory, labeledData, unlabeledData, minTokenCount);
        TradNaiveBayesClassifier classifier = null;
        for (int epoch = 0; it.hasNext() && epoch < maxEpochs; ++epoch) {
            classifier = it.next();
            double modelLogProb = classifier.log2ModelProb();
            double dataLogProb = TradNaiveBayesClassifier.dataProb(classifier, labeledData, unlabeledData);
            double logProb = modelLogProb + dataLogProb;
            // In the first epoch lastLogProb is -Infinity, so the relative
            // difference is NaN and the convergence test below cannot fire.
            double relativeDiff = TradNaiveBayesClassifier.relativeDiff(lastLogProb, logProb);
            if (reporter.isDebugEnabled()) {
                reporter.debug(String.format("epoch=%4d   dataLogProb=%15.2f   modelLogProb=%15.2f   logProb=%15.2f   diff=%15.12f", epoch, dataLogProb, modelLogProb, logProb, relativeDiff));
            }
            if (!Double.isNaN(lastLogProb) && relativeDiff < minImprovement) {
                reporter.info("Converged");
                return classifier;
            }
            lastLogProb = logProb;
        }
        return classifier;
    }

    /**
     * Visits the training portion of both corpora with a
     * {@code CaseProbAccumulator} built for the classifier and returns the
     * accumulator's {@code mCaseProb} value.
     *
     * <p>NOTE(review): the accumulator class is not visible here; presumably it
     * sums the log2 case probability of each visited text — confirm.
     *
     * @param classifier the classifier used to score the data
     * @param labeledData corpus of classified texts
     * @param unlabeledData corpus of unclassified texts
     * @return the accumulated value after both visits
     * @throws IOException if reading a corpus fails
     */
    static double dataProb(TradNaiveBayesClassifier classifier, Corpus<ObjectHandler<Classified<CharSequence>>> labeledData, Corpus<ObjectHandler<CharSequence>> unlabeledData) throws IOException {
        CaseProbAccumulator accum = new CaseProbAccumulator(classifier);
        labeledData.visitTrain(accum.supHandler());
        unlabeledData.visitTrain(accum);
        return accum.mCaseProb;
    }

    /**
     * Returns the symmetric relative difference of two values: twice the
     * absolute difference divided by the sum of the absolute values. The
     * result is NaN when both arguments are zero or when the division is
     * infinity over infinity.
     *
     * @param x the first value
     * @param y the second value
     * @return {@code 2|x - y| / (|x| + |y|)}
     */
    static double relativeDiff(double x, double y) {
        double twiceGap = 2.0 * Math.abs(x - y);
        double magnitude = Math.abs(x) + Math.abs(y);
        return twiceGap / magnitude;
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    /**
     * Presents a corpus of {@code TextHandler}s as a corpus of
     * {@code ObjectHandler<CharSequence>}s. Each visit wraps the supplied
     * object handler in a text handler and forwards the visit to the
     * underlying corpus.
     */
    @Deprecated
    static class TextHandlerCorpusAdapter2
    extends Corpus<ObjectHandler<CharSequence>> {
        /** The wrapped text-handler corpus. */
        @Deprecated
        private final Corpus<TextHandler> mCorpus;

        /**
         * Wraps the given text-handler corpus.
         *
         * @param corpus the corpus to adapt
         */
        @Deprecated
        public TextHandlerCorpusAdapter2(Corpus<TextHandler> corpus) {
            mCorpus = corpus;
        }

        /** Visits the training section of the wrapped corpus through an adapted handler. */
        @Override
        public void visitTrain(ObjectHandler<CharSequence> handler) throws IOException {
            TextHandler adapted = new HandlerAdapter(handler);
            mCorpus.visitTrain(adapted);
        }

        /** Visits the test section of the wrapped corpus through an adapted handler. */
        @Override
        public void visitTest(ObjectHandler<CharSequence> handler) throws IOException {
            TextHandler adapted = new HandlerAdapter(handler);
            mCorpus.visitTest(adapted);
        }

        /**
         * Text handler that turns each character slice into a {@code String}
         * and passes it to the wrapped object handler.
         */
        @Deprecated
        static class HandlerAdapter
        implements TextHandler {
            private final ObjectHandler<CharSequence> mHandler;

            HandlerAdapter(ObjectHandler<CharSequence> handler) {
                mHandler = handler;
            }

            @Override
            public void handle(char[] cs, int start, int len) {
                String text = new String(cs, start, len);
                mHandler.handle(text);
            }
        }
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    /**
     * Classifier that works from precomputed log2 probabilities. It holds a
     * table of per-category token log2 probabilities and the per-category
     * log2 category probabilities; there is no way to train it further.
     */
    private static class CompiledTradNaiveBayesClassifier
    implements Classifier<CharSequence, JointClassification>,
    JointClassifier<CharSequence> {
        private final TokenizerFactory mTokenizerFactory;
        private final String[] mCategories;
        private final Map<String, double[]> mTokenToLog2ProbsInCats;
        private final double[] mLog2CatProbs;
        private final double mLengthNorm;

        /** Stores the given tables as-is; no copies are made. */
        CompiledTradNaiveBayesClassifier(String[] categories, TokenizerFactory tokenizerFactory, Map<String, double[]> tokenToLog2ProbsInCats, double[] log2CatProbs, double lengthNorm) {
            mTokenizerFactory = tokenizerFactory;
            mCategories = categories;
            mTokenToLog2ProbsInCats = tokenToLog2ProbsInCats;
            mLog2CatProbs = log2CatProbs;
            mLengthNorm = lengthNorm;
        }

        /**
         * Sums the stored log2 token probabilities of the input's known tokens
         * per category, rescales the sums to the length norm when one is set,
         * and adds the log2 category probabilities. Unknown tokens add nothing
         * but still count toward the token total.
         *
         * @param in the text to classify
         * @return the joint classification built from the per-category scores
         */
        @Override
        public JointClassification classify(CharSequence in) {
            final int numCats = mCategories.length;
            double[] log2Scores = new double[numCats];
            char[] chars = Strings.toCharArray(in);
            Tokenizer tokenizer = mTokenizerFactory.tokenizer(chars, 0, chars.length);

            int numTokens = 0;
            for (String token : tokenizer) {
                ++numTokens;
                double[] tokenLog2Probs = mTokenToLog2ProbsInCats.get(token);
                if (tokenLog2Probs != null) {
                    for (int cat = 0; cat < log2Scores.length; ++cat) {
                        log2Scores[cat] += tokenLog2Probs[cat];
                    }
                }
            }

            if (numTokens > 0 && !Double.isNaN(mLengthNorm)) {
                double scale = mLengthNorm / (double) numTokens;
                for (int cat = 0; cat < log2Scores.length; ++cat) {
                    log2Scores[cat] *= scale;
                }
            }

            for (int cat = 0; cat < log2Scores.length; ++cat) {
                log2Scores[cat] += mLog2CatProbs[cat];
            }
            return JointClassification.create(mCategories, log2Scores);
        }
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    /**
     * Two-category classifier that works from precomputed values. Instead of
     * one score per category it keeps, for each token and for the category
     * prior, the difference between the two categories' log probabilities,
     * and turns the summed difference into the two probabilities with a
     * logistic transform.
     *
     * <p>The stored differences are natural logs (the log2 inputs are divided
     * by {@code LOG2_E}), despite the "Log2" in the field names.
     */
    private static class CompiledBinaryTradNaiveBayesClassifier
    implements Classifier<CharSequence, JointClassification>,
    JointClassifier<CharSequence> {
        private final TokenizerFactory mTokenizerFactory;
        /** Per token: ln p(token|category 0) - ln p(token|category 1). */
        private final Map<String, Double> mTokenToLog2ProbDiff;
        /** ln p(category 0) - ln p(category 1). */
        private final double mLog2CatProbDiff;
        private final double mLengthNorm;
        /** Result ordering used when category 0 is the more probable. */
        private final String[] mCats01;
        /** Result ordering used when category 1 is at least as probable. */
        private final String[] mCats10;

        /**
         * Builds the difference tables from per-category log2 probabilities.
         *
         * @param categories the two categories, at indexes 0 and 1
         * @param tokenizerFactory factory for the tokenizers applied to text
         * @param tokenToLog2ProbsInCats per-token log2 probabilities; slots 0 and 1 are read
         * @param log2CatProbs log2 category probabilities; slots 0 and 1 are read
         * @param lengthNorm the length norm, or {@code Double.NaN} for none
         */
        CompiledBinaryTradNaiveBayesClassifier(String[] categories, TokenizerFactory tokenizerFactory, Map<String, double[]> tokenToLog2ProbsInCats, double[] log2CatProbs, double lengthNorm) {
            this.mTokenizerFactory = tokenizerFactory;
            this.mTokenToLog2ProbDiff = new HashMap<String, Double>();
            for (Map.Entry<String, double[]> entry : tokenToLog2ProbsInCats.entrySet()) {
                double[] log2Probs = entry.getValue();
                // Dividing a log2 value by log2(e) converts it to a natural log.
                double logProbDiff = (log2Probs[0] - log2Probs[1]) / com.aliasi.util.Math.LOG2_E;
                this.mTokenToLog2ProbDiff.put(entry.getKey(), logProbDiff);
            }
            this.mLog2CatProbDiff = (log2CatProbs[0] - log2CatProbs[1]) / com.aliasi.util.Math.LOG2_E;
            this.mLengthNorm = lengthNorm;
            this.mCats01 = new String[]{categories[0], categories[1]};
            this.mCats10 = new String[]{categories[1], categories[0]};
        }

        /**
         * Sums the stored differences of the input's known tokens, rescales
         * the sum to the length norm when one is set, adds the category
         * difference, and converts the total to a two-way classification.
         *
         * @param in the text to classify
         * @return the classification with the more probable category first
         */
        @Override
        public JointClassification classify(CharSequence in) {
            double logDiff = 0.0;
            char[] cs = Strings.toCharArray(in);
            Tokenizer tokenizer = this.mTokenizerFactory.tokenizer(cs, 0, cs.length);
            int tokenCount = 0;
            for (String token : tokenizer) {
                ++tokenCount;
                Double tokLogDiff = this.mTokenToLog2ProbDiff.get(token);
                if (tokLogDiff != null) {
                    logDiff += tokLogDiff.doubleValue();
                }
            }
            if (!Double.isNaN(this.mLengthNorm) && tokenCount > 0) {
                logDiff *= this.mLengthNorm / (double)tokenCount;
            }
            return this.classification(logDiff + this.mLog2CatProbDiff);
        }

        /**
         * Converts a natural-log odds value {@code d = ln(p0/p1)} into the
         * classification with {@code p0 = e^d / (1 + e^d)} and
         * {@code p1 = 1 - p0}.
         *
         * <p>Both log probabilities are computed directly as
         * {@code ln p0 = min(d, 0) - ln(1 + e^-|d|)} and
         * {@code ln p1 = min(-d, 0) - ln(1 + e^-|d|)}. Exponentiating
         * {@code -|d|} cannot overflow, so large differences no longer yield
         * NaN (from infinity over infinity), and the smaller probability keeps
         * a finite log instead of collapsing to negative infinity.
         *
         * @param logDiff natural log of the odds of category 0 over category 1
         * @return the classification with the more probable category first;
         *     ties list category 1 first
         */
        JointClassification classification(double logDiff) {
            double logNormalizer = Math.log1p(Math.exp(-Math.abs(logDiff)));
            double lnP0 = (logDiff >= 0.0 ? 0.0 : logDiff) - logNormalizer;
            double lnP1 = (logDiff >= 0.0 ? -logDiff : 0.0) - logNormalizer;
            // Multiplying a natural log by log2(e) converts it to log2.
            double log2P0 = lnP0 * com.aliasi.util.Math.LOG2_E;
            double log2P1 = lnP1 * com.aliasi.util.Math.LOG2_E;
            return log2P0 > log2P1
                ? new JointClassification(this.mCats01, new double[]{log2P0, log2P1})
                : new JointClassification(this.mCats10, new double[]{log2P1, log2P0});
        }
    }

    /**
     * Externalizable proxy that writes a classifier in compiled form (log2
     * probabilities instead of counts) and reads it back as one of the
     * compiled classifier classes.
     *
     * <p>Stream layout, in order: category count; each category (UTF); the
     * tokenizer factory; token count; for each token its text (UTF) followed
     * by one log2 probability per category; one log2 category probability per
     * category; the length norm.
     */
    static class Compiler
    extends AbstractExternalizable {
        static final long serialVersionUID = 5689464666886334529L;
        private final TradNaiveBayesClassifier mClassifier;

        /** No-argument constructor; leaves the classifier reference null. */
        public Compiler() {
            this(null);
        }

        /**
         * Wraps the classifier to be written.
         *
         * @param classifier the classifier to compile
         */
        public Compiler(TradNaiveBayesClassifier classifier) {
            mClassifier = classifier;
        }

        /**
         * Writes the wrapped classifier in the layout described on this class.
         *
         * @param objOut the output to write to
         * @throws IOException if writing fails
         * @throws IllegalArgumentException if a token's log2 probability is positive
         */
        public void writeExternal(ObjectOutput objOut) throws IOException {
            final TradNaiveBayesClassifier model = mClassifier;
            final int numCats = model.mCategories.length;

            objOut.writeInt(numCats);
            for (String category : model.mCategories) {
                objOut.writeUTF(category);
            }

            AbstractExternalizable.compileOrSerialize(model.mTokenizerFactory, objOut);

            objOut.writeInt(model.mTokenToCountsMap.size());
            for (Map.Entry<String, double[]> entry : model.mTokenToCountsMap.entrySet()) {
                String token = entry.getKey();
                double[] tokenCounts = entry.getValue();
                objOut.writeUTF(token);
                for (int cat = 0; cat < numCats; ++cat) {
                    double prob = model.probTokenByIndexArray(cat, tokenCounts);
                    double log2Prob = com.aliasi.util.Math.log2(prob);
                    if (log2Prob > 0.0) {
                        // A positive log means the estimate exceeded 1; refuse to write it.
                        throw new IllegalArgumentException("key=" + token + " i=" + cat + " log2Prob=" + log2Prob + " prob=" + prob + " token counts[" + cat + "]=" + tokenCounts[cat] + " totalCatCount=" + model.mTotalCountsPerCategory[cat] + " mTokenToCountsMap.size()=" + model.mTokenToCountsMap.size());
                    }
                    objOut.writeDouble(log2Prob);
                }
            }

            for (int cat = 0; cat < numCats; ++cat) {
                objOut.writeDouble(com.aliasi.util.Math.log2(model.probCatByIndex(cat)));
            }
            objOut.writeDouble(model.mLengthNorm);
        }

        /**
         * Reads the layout produced by {@code writeExternal}.
         *
         * @param in the input to read from
         * @return the binary compiled classifier when exactly two categories
         *     were read, otherwise the general compiled classifier
         * @throws ClassNotFoundException if the tokenizer factory's class is missing
         * @throws IOException if reading fails
         */
        public Object read(ObjectInput in) throws ClassNotFoundException, IOException {
            final int numCategories = in.readInt();
            String[] categories = new String[numCategories];
            for (int cat = 0; cat < numCategories; ++cat) {
                categories[cat] = in.readUTF();
            }

            TokenizerFactory tokenizerFactory = (TokenizerFactory) in.readObject();

            final int numTokens = in.readInt();
            HashMap<String, double[]> tokenToLog2ProbsInCats = new HashMap<String, double[]>(numTokens * 3 / 2);
            for (int remaining = numTokens; remaining > 0; --remaining) {
                String token = in.readUTF();
                double[] log2ProbsInCats = new double[numCategories];
                for (int cat = 0; cat < numCategories; ++cat) {
                    log2ProbsInCats[cat] = in.readDouble();
                }
                tokenToLog2ProbsInCats.put(token, log2ProbsInCats);
            }

            double[] log2CatProbs = new double[numCategories];
            for (int cat = 0; cat < numCategories; ++cat) {
                log2CatProbs[cat] = in.readDouble();
            }
            double lengthNorm = in.readDouble();

            if (categories.length == 2) {
                return new CompiledBinaryTradNaiveBayesClassifier(categories, tokenizerFactory, tokenToLog2ProbsInCats, log2CatProbs, lengthNorm);
            }
            return new CompiledTradNaiveBayesClassifier(categories, tokenizerFactory, tokenToLog2ProbsInCats, log2CatProbs, lengthNorm);
        }
    }

    static class Serializer
    extends AbstractExternalizable {
        static final long serialVersionUID = -4786039228920809976L;
        private final TradNaiveBayesClassifier mClassifier;

        /**
         * Wraps the classifier whose state is to be written.
         *
         * @param classifier the classifier to serialize
         */
        public Serializer(TradNaiveBayesClassifier classifier) {
            this.mClassifier = classifier;
        }

        /** No-argument constructor; leaves the classifier reference null. */
        public Serializer() {
            this(null);
        }

        /**
         * Reads a classifier's full trainable state and rebuilds it.
         *
         * <p>Values are consumed in this order: category count; each category
         * (UTF); the tokenizer factory; category prior; token-in-category
         * prior; token count; for each token its text (UTF) followed by one
         * count per category; total token counts per category; case counts
         * per category; total case count; length norm.
         *
         * @param in the input to read from
         * @return the reconstructed {@code TradNaiveBayesClassifier}
         * @throws ClassNotFoundException if the tokenizer factory's class is missing
         * @throws IOException if reading fails
         */
        public Object read(ObjectInput in) throws ClassNotFoundException, IOException {
            final int numCats = in.readInt();
            String[] categories = new String[numCats];
            for (int cat = 0; cat < numCats; ++cat) {
                categories[cat] = in.readUTF();
            }

            TokenizerFactory tokenizerFactory = (TokenizerFactory) in.readObject();
            double catPrior = in.readDouble();
            double tokenInCatPrior = in.readDouble();

            final int numTokens = in.readInt();
            HashMap<String, double[]> tokenToCountsMap = new HashMap<String, double[]>(numTokens * 3 / 2);
            for (int remaining = numTokens; remaining > 0; --remaining) {
                String token = in.readUTF();
                double[] counts = new double[categories.length];
                for (int cat = 0; cat < categories.length; ++cat) {
                    counts[cat] = in.readDouble();
                }
                tokenToCountsMap.put(token, counts);
            }

            double[] totalCountsPerCategory = new double[categories.length];
            for (int cat = 0; cat < categories.length; ++cat) {
                totalCountsPerCategory[cat] = in.readDouble();
            }

            double[] caseCounts = new double[categories.length];
            for (int cat = 0; cat < categories.length; ++cat) {
                caseCounts[cat] = in.readDouble();
            }

            double totalCaseCount = in.readDouble();
            double lengthNorm = in.readDouble();
            return new TradNaiveBayesClassifier(categories, tokenizerFactory, catPrior, tokenInCatPrior, tokenToCountsMap, totalCountsPerCategory, caseCounts, totalCaseCount, lengthNorm);
        }

        public void writeExternal(ObjectOutput objOut) throws IOException {
            int i;
            objOut.writeInt(this.mClassifier.mCategories.length);
            for (String category : this.mClassifier.mCategories) {
                objOut.writeUTF(category);
            }
            objOut.writeObject(this.mClassifier.mTokenizerFactory);
            objOut.writeDouble(this.mClassifier.mCategoryPrior);
            objOut.writeDouble(this.mClassifier.mTokenInCategoryPrior);
            objOut.writeInt(this.mClassifier.mTokenToCountsMap.size());
            for (Map.Entry entry : this.mClassifier.mTokenToCountsMap.entrySet()) {
                objOut.writeUTF((String)entry.getKey());
                double[] vals = (double[])entry.getValue();
                for (int i2 = 0; i2 < this.mClassifier.mCategories.length; ++i2) {
                    objOut.writeDouble(vals[i2]);
                }
            }
            for (i = 0; i < this.mClassifier.mCategories.length; ++i) {
                objOut.writeDouble(this.mClassifier.mTotalCountsPerCategory[i]);
            }
            for (i = 0; i < this.mClassifier.mCategories.length; ++i) {
                objOut.writeDouble(this.mClassifier.mCaseCounts[i]);
            }
            objOut.writeDouble(this.mClassifier.mTotalCaseCount);
            objOut.writeDouble(this.mClassifier.mLengthNorm);
        }
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    /**
     * Buffered iterator that yields a newly trained classifier on every step.
     *
     * <p>Construction trains the supplied initial classifier on the labeled
     * corpus and compiles it. Each call to {@link #bufferNext()} then creates
     * a fresh classifier from the factory, trains it on the labeled corpus,
     * trains it on the unlabeled corpus using the classifications produced by
     * the most recently compiled classifier, compiles the result for use by
     * the following step, and returns the uncompiled classifier.
     */
    static class EmIterator extends Iterators.Buffered<TradNaiveBayesClassifier> {
        private final Factory<TradNaiveBayesClassifier> mClassifierFactory;
        private final Corpus<ObjectHandler<Classified<CharSequence>>> mLabeledData;
        private final Corpus<ObjectHandler<CharSequence>> mUnlabeledData;
        private final double mMinTokenCount;

        /** Compiled form of the classifier from the previous step; null after a failed compile. */
        private JointClassifier<CharSequence> mLastClassifier;

        /**
         * @param initialClassifier classifier trained on the labeled data and compiled
         *     to label the unlabeled data in the first step
         * @param classifierFactory source of the fresh classifier trained in each step
         * @param labeledData corpus of classified training cases
         * @param unlabeledData corpus of unclassified training cases
         * @param minTokenCount value forwarded to {@code trainConditional} for unlabeled cases
         */
        EmIterator(TradNaiveBayesClassifier initialClassifier, Factory<TradNaiveBayesClassifier> classifierFactory, Corpus<ObjectHandler<Classified<CharSequence>>> labeledData, Corpus<ObjectHandler<CharSequence>> unlabeledData, double minTokenCount) {
            mClassifierFactory = classifierFactory;
            mLabeledData = labeledData;
            mUnlabeledData = unlabeledData;
            mMinTokenCount = minTokenCount;
            trainSup(labeledData, initialClassifier);
            compile(initialClassifier);
        }

        /**
         * Builds, trains and returns the next classifier, and stores its
         * compiled form for the following step.
         *
         * @return the newly trained classifier
         */
        @Override
        public TradNaiveBayesClassifier bufferNext() {
            TradNaiveBayesClassifier next = mClassifierFactory.create();
            trainSup(mLabeledData, next);
            trainUnsup(mUnlabeledData, next);
            compile(next);
            return next;
        }

        /**
         * Feeds every training case of the labeled corpus to the classifier.
         *
         * @throws IllegalStateException wrapping any I/O failure raised by the corpus
         */
        void trainSup(Corpus<ObjectHandler<Classified<CharSequence>>> labeledData, TradNaiveBayesClassifier classifier) {
            try {
                labeledData.visitTrain(classifier);
            } catch (IOException e) {
                throw new IllegalStateException("Error during labeled training", e);
            }
        }

        /**
         * Classifies every training case of the unlabeled corpus with the last
         * compiled classifier and trains the given classifier on the resulting
         * classification with a weight of 1.0.
         *
         * @throws IllegalStateException wrapping any I/O failure raised by the corpus
         */
        void trainUnsup(Corpus<ObjectHandler<CharSequence>> unlabeledData, final TradNaiveBayesClassifier classifier) {
            ObjectHandler<CharSequence> unlabeledTrainer = new ObjectHandler<CharSequence>() {
                @Override
                public void handle(CharSequence cSeq) {
                    JointClassification estimate = mLastClassifier.classify(cSeq);
                    classifier.trainConditional(cSeq, estimate, 1.0, mMinTokenCount);
                }
            };
            try {
                unlabeledData.visitTrain(unlabeledTrainer);
            } catch (IOException e) {
                throw new IllegalStateException("Error during unlabeled training", e);
            }
        }

        /**
         * Compiles the classifier and stores the result as the last classifier.
         * On failure the stored classifier is cleared before the exception propagates.
         *
         * @throws IllegalStateException wrapping an I/O or class-resolution failure
         */
        void compile(TradNaiveBayesClassifier classifier) {
            try {
                @SuppressWarnings("unchecked") // element type of the compiled object cannot be checked at run time
                JointClassifier<CharSequence> compiled
                    = (JointClassifier<CharSequence>) AbstractExternalizable.compile(classifier);
                mLastClassifier = compiled;
            } catch (IOException e) {
                throw compilationFailure(e);
            } catch (ClassNotFoundException e) {
                throw compilationFailure(e);
            }
        }

        /** Clears the stored classifier and builds the exception reported for a failed compile. */
        private IllegalStateException compilationFailure(Exception cause) {
            mLastClassifier = null;
            return new IllegalStateException("Error during compilation.", cause);
        }
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    /**
     * Handler that sums the value of {@code log2CaseProb} reported by a
     * classifier over every character sequence it is given.
     */
    static class CaseProbAccumulator implements ObjectHandler<CharSequence> {
        /** Running sum of the classifier's {@code log2CaseProb} values; starts at zero. */
        double mCaseProb = 0.0;

        /** Classifier queried for each handled sequence. */
        final TradNaiveBayesClassifier mClassifier;

        /**
         * @param classifier classifier whose {@code log2CaseProb} is accumulated
         */
        CaseProbAccumulator(TradNaiveBayesClassifier classifier) {
            mClassifier = classifier;
        }

        /**
         * Adds the classifier's {@code log2CaseProb} for the sequence to the running sum.
         *
         * @param cSeq sequence to evaluate
         */
        @Override
        public void handle(CharSequence cSeq) {
            mCaseProb += mClassifier.log2CaseProb(cSeq);
        }

        /**
         * Returns a handler for classified cases that unwraps each case's
         * object and passes it to this accumulator.
         *
         * @return a handler that delegates to {@link #handle(CharSequence)}
         */
        public ObjectHandler<Classified<CharSequence>> supHandler() {
            return new ObjectHandler<Classified<CharSequence>>() {
                @Override
                public void handle(Classified<CharSequence> classified) {
                    // Qualified this: an unqualified call would bind to this anonymous handler's own handle.
                    CaseProbAccumulator.this.handle(classified.getObject());
                }
            };
        }
    }
}

