/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.crf;

import com.aliasi.corpus.Corpus;
import com.aliasi.corpus.ObjectHandler;
import com.aliasi.crf.ChainCrfFeatureExtractor;
import com.aliasi.crf.ChainCrfFeatures;
import com.aliasi.crf.ForwardBackwardTagLattice;
import com.aliasi.features.Features;
import com.aliasi.io.Reporter;
import com.aliasi.io.Reporters;
import com.aliasi.matrix.DenseVector;
import com.aliasi.matrix.Matrices;
import com.aliasi.matrix.Vector;
import com.aliasi.stats.AnnealingSchedule;
import com.aliasi.stats.RegressionPrior;
import com.aliasi.symbol.MapSymbolTable;
import com.aliasi.symbol.SymbolTable;
import com.aliasi.tag.MarginalTagger;
import com.aliasi.tag.NBestTagger;
import com.aliasi.tag.ScoredTagging;
import com.aliasi.tag.TagLattice;
import com.aliasi.tag.Tagger;
import com.aliasi.tag.Tagging;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.BoundedPriorityQueue;
import com.aliasi.util.Iterators;
import com.aliasi.util.ObjectToCounterMap;
import com.aliasi.util.Scored;
import com.aliasi.util.ScoredObject;
import com.aliasi.util.Strings;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Formatter;
import java.util.IllegalFormatException;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class ChainCrf<E>
implements Tagger<E>,
NBestTagger<E>,
MarginalTagger<E>,
Serializable {
    static final long serialVersionUID = -4868542587460878290L;
    // Tags in index order; tag index k selects mCoefficients[k] and the k-th legality flags.
    private final List<String> mTagList;
    // mLegalTagStarts[k] is true if tag k may label the first token.
    private final boolean[] mLegalTagStarts;
    // mLegalTagEnds[k] is true if tag k may label the last token. Note: tag() and
    // forwardBackward() only consult this for inputs of two or more tokens.
    private final boolean[] mLegalTagEnds;
    // mLegalTagTransitions[kFrom][kTo] is true if tag kTo may directly follow tag kFrom.
    private final boolean[][] mLegalTagTransitions;
    // One coefficient vector per tag; both node and edge feature vectors are scored
    // against the coefficients of the tag being assigned.
    private final Vector[] mCoefficients;
    // Maps feature names to vector dimensions when extracted features are vectorized.
    private final SymbolTable mFeatureSymbolTable;
    private final ChainCrfFeatureExtractor<E> mFeatureExtractor;
    // Flag passed to Features.toVector when vectorizing features (intercept feature on/off).
    private final boolean mAddInterceptFeature;
    // Dimensionality of the coefficient vectors (taken from coefficients[0]).
    private final int mNumDimensions;
    // Reserved feature name for the intercept; it is the first symbol added to the
    // feature symbol table during estimation when an intercept is requested.
    static final String INTERCEPT_FEATURE_NAME = "*&^INTERCEPT%$^&**";
    // Shared empty arrays used to build the marginal lattice for empty inputs.
    static final double[][] EMPTY_DOUBLE_2D_ARRAY = new double[0][];
    static final double[][][] EMPTY_DOUBLE_3D_ARRAY = new double[0][][];

    /**
     * Construct a chain CRF in which every tag may start or end a sequence
     * and every tag may follow every other tag.
     *
     * @param tags tags, in index order
     * @param coefficients one coefficient vector per tag
     * @param featureSymbolTable symbol table mapping features to dimensions
     * @param featureExtractor extractor for node and edge features
     * @param addInterceptFeature whether to add an intercept feature
     */
    public ChainCrf(String[] tags, Vector[] coefficients, SymbolTable featureSymbolTable, ChainCrfFeatureExtractor<E> featureExtractor, boolean addInterceptFeature) {
        this(tags,
             trueArray(tags.length),
             trueArray(tags.length),
             trueArray(tags.length, tags.length),
             coefficients,
             featureSymbolTable,
             featureExtractor,
             addInterceptFeature);
    }

    /**
     * Construct a chain CRF with explicit legality constraints on starting
     * tags, ending tags, and tag-to-tag transitions.
     *
     * <p>The arrays are stored by reference, not copied.
     *
     * @param tags tags, in index order
     * @param legalTagStarts {@code legalTagStarts[k]} is true if tag k may start a sequence
     * @param legalTagEnds {@code legalTagEnds[k]} is true if tag k may end a sequence
     * @param legalTagTransitions {@code legalTagTransitions[kFrom][kTo]} is true if
     *     tag kTo may follow tag kFrom
     * @param coefficients one coefficient vector per tag, all of the same dimensionality
     * @param featureSymbolTable symbol table mapping features to dimensions
     * @param featureExtractor extractor for node and edge features
     * @param addInterceptFeature whether to add an intercept feature
     * @throws IllegalArgumentException if there are no tags, if any of the per-tag
     *     arrays or transition rows does not match the number of tags, or if the
     *     coefficient vectors differ in dimensionality
     */
    public ChainCrf(String[] tags, boolean[] legalTagStarts, boolean[] legalTagEnds, boolean[][] legalTagTransitions, Vector[] coefficients, SymbolTable featureSymbolTable, ChainCrfFeatureExtractor<E> featureExtractor, boolean addInterceptFeature) {
        if (tags.length < 1) {
            // Previously the message was built but never thrown, so construction
            // failed later with an ArrayIndexOutOfBoundsException on coefficients[0].
            String msg = "Require at least one tag.";
            throw new IllegalArgumentException(msg);
        }
        if (tags.length != coefficients.length) {
            String msg = "Require tags and coefficients to be same length. Found tags.length=" + tags.length + " coefficients.length=" + coefficients.length;
            throw new IllegalArgumentException(msg);
        }
        if (tags.length != legalTagStarts.length) {
            String msg = "Tags and starts must be same length. Found tags.length=" + tags.length + " legalTagStarts.length=" + legalTagStarts.length;
            throw new IllegalArgumentException(msg);
        }
        if (tags.length != legalTagEnds.length) {
            // Report the array that actually mismatched (was mislabeled as "starts").
            String msg = "Tags and ends must be same length. Found tags.length=" + tags.length + " legalTagEnds.length=" + legalTagEnds.length;
            throw new IllegalArgumentException(msg);
        }
        if (tags.length != legalTagTransitions.length) {
            String msg = "Tags and transitions must be same length. Found tags.length=" + tags.length + " legalTagTransitions.length=" + legalTagTransitions.length;
            throw new IllegalArgumentException(msg);
        }
        for (int i = 0; i < legalTagTransitions.length; ++i) {
            if (tags.length == legalTagTransitions[i].length) {
                continue;
            }
            String msg = "Tags and transition rows must be same length. Found tags.length=" + tags.length + " legalTagTransitions[" + i + "].length=" + legalTagTransitions[i].length;
            throw new IllegalArgumentException(msg);
        }
        for (int k = 1; k < coefficients.length; ++k) {
            if (coefficients[0].numDimensions() == coefficients[k].numDimensions()) {
                continue;
            }
            String msg = "All coefficients must be same length. Found coefficents[0].numDimensions()=" + coefficients[0].numDimensions() + " coefficients[" + k + "].numDimensions()=" + coefficients[k].numDimensions();
            throw new IllegalArgumentException(msg);
        }
        this.mTagList = Arrays.asList(tags);
        this.mLegalTagStarts = legalTagStarts;
        this.mLegalTagEnds = legalTagEnds;
        this.mLegalTagTransitions = legalTagTransitions;
        this.mCoefficients = coefficients;
        this.mNumDimensions = coefficients[0].numDimensions();
        this.mFeatureSymbolTable = featureSymbolTable;
        this.mFeatureExtractor = featureExtractor;
        this.mAddInterceptFeature = addInterceptFeature;
    }

    /**
     * Returns a read-only view of this CRF's tags, in index order.
     *
     * @return unmodifiable list of tags
     */
    public List<String> tags() {
        List<String> readOnlyTags = Collections.unmodifiableList(mTagList);
        return readOnlyTags;
    }

    /**
     * Returns the tag with the specified index.
     *
     * @param k tag index
     * @return the tag at index {@code k}
     */
    public String tag(int k) {
        final List<String> tagList = mTagList;
        return tagList.get(k);
    }

    /**
     * Returns a fresh array holding unmodifiable views of the per-tag
     * coefficient vectors, indexed by tag.
     *
     * @return unmodifiable coefficient vectors, one per tag
     */
    public Vector[] coefficients() {
        Vector[] views = new Vector[mCoefficients.length];
        int k = 0;
        for (Vector coefficientVector : mCoefficients) {
            views[k++] = Matrices.unmodifiableVector(coefficientVector);
        }
        return views;
    }

    /**
     * Returns an unmodifiable view of the symbol table that maps feature
     * names to coefficient dimensions.
     *
     * @return read-only feature symbol table
     */
    public SymbolTable featureSymbolTable() {
        final SymbolTable underlying = mFeatureSymbolTable;
        return MapSymbolTable.unmodifiableView(underlying);
    }

    /**
     * Returns the feature extractor this CRF uses for node and edge features.
     *
     * @return the feature extractor (not copied)
     */
    public ChainCrfFeatureExtractor<E> featureExtractor() {
        final ChainCrfFeatureExtractor<E> extractor = mFeatureExtractor;
        return extractor;
    }

    /**
     * Returns whether an intercept feature is added when features are vectorized.
     *
     * @return {@code true} if the intercept feature is added
     */
    public boolean addInterceptFeature() {
        final boolean intercept = mAddInterceptFeature;
        return intercept;
    }

    /**
     * Returns the highest-scoring tagging of the tokens, found with the Viterbi
     * algorithm subject to this CRF's legal start, end, and transition constraints.
     *
     * @param tokens tokens to tag
     * @return the best-scoring tagging; an empty tagging for empty input
     * @throws IllegalStateException if no tag sequence satisfying the legality
     *     constraints has a score above negative infinity
     */
    @Override
    public Tagging<E> tag(List<E> tokens) {
        int numTokens = tokens.size();
        if (numTokens == 0) {
            return new Tagging<E>(tokens, Collections.<String>emptyList());
        }
        int numTags = this.mTagList.size();
        // bestScores[n][k]: best score of any prefix ending with tag k at position n.
        double[][] bestScores = new double[numTokens][numTags];
        // backPointers[n - 1][k]: best previous tag for tag k at position n (-1 if none).
        int[][] backPointers = new int[numTokens - 1][numTags];
        ChainCrfFeatures<E> features = this.mFeatureExtractor.extract(tokens, this.mTagList);
        Vector nodeVector0 = this.nodeFeatures(0, features);
        for (int k = 0; k < numTags; ++k) {
            bestScores[0][k] = this.mLegalTagStarts[k]
                ? nodeVector0.dotProduct(this.mCoefficients[k])
                : Double.NEGATIVE_INFINITY;
        }
        Vector[] edgeVectors = new Vector[numTags];
        for (int n = 1; n < numTokens; ++n) {
            Vector nodeVector = this.nodeFeatures(n, features);
            for (int kMinus1 = 0; kMinus1 < numTags; ++kMinus1) {
                edgeVectors[kMinus1] = this.edgeFeatures(n, kMinus1, features);
            }
            boolean lastPosition = (n == numTokens - 1);
            for (int k = 0; k < numTags; ++k) {
                if (lastPosition && !this.mLegalTagEnds[k]) {
                    bestScores[n][k] = Double.NEGATIVE_INFINITY;
                    backPointers[n - 1][k] = -1;
                    continue;
                }
                double bestScore = Double.NEGATIVE_INFINITY;
                int backPtr = -1;
                double nodeScore = nodeVector.dotProduct(this.mCoefficients[k]);
                for (int kMinus1 = 0; kMinus1 < numTags; ++kMinus1) {
                    if (!this.mLegalTagTransitions[kMinus1][k]) {
                        continue;
                    }
                    double score = nodeScore + edgeVectors[kMinus1].dotProduct(this.mCoefficients[k]) + bestScores[n - 1][kMinus1];
                    if (score > bestScore) {
                        bestScore = score;
                        backPtr = kMinus1;
                    }
                }
                bestScores[n][k] = bestScore;
                backPointers[n - 1][k] = backPtr;
            }
        }
        double bestScore = Double.NEGATIVE_INFINITY;
        int bestFinalTag = -1;
        for (int k = 0; k < numTags; ++k) {
            if (bestScores[numTokens - 1][k] > bestScore) {
                bestScore = bestScores[numTokens - 1][k];
                bestFinalTag = k;
            }
        }
        if (bestFinalTag < 0) {
            // Every final tag scored -infinity: the legality constraints admit no
            // tagging of this length. Fail clearly instead of indexing with -1.
            throw new IllegalStateException("No legal tagging found for input of length=" + numTokens);
        }
        ArrayList<String> tags = new ArrayList<String>(numTokens);
        int bestPreviousTag = bestFinalTag;
        tags.add(this.mTagList.get(bestFinalTag));
        for (int n = numTokens - 2; n >= 0; --n) {
            bestPreviousTag = backPointers[n][bestPreviousTag];
            tags.add(this.mTagList.get(bestPreviousTag));
        }
        Collections.reverse(tags);
        return new Tagging<E>(tokens, tags);
    }

    /**
     * Returns an iterator over the best taggings of the tokens, in decreasing
     * score order, up to {@code maxResults} of them. Delegates to
     * {@code NBestIterator} with its boolean flag set to {@code false}.
     *
     * @param tokens tokens to tag
     * @param maxResults maximum number of taggings to return
     * @return iterator over scored taggings; a single empty tagging with score
     *     0.0 for empty input
     */
    @Override
    public Iterator<ScoredTagging<E>> tagNBest(List<E> tokens, int maxResults) {
        if (!tokens.isEmpty()) {
            return new NBestIterator(tokens, false, maxResults);
        }
        ScoredTagging<E> emptyTagging = new ScoredTagging<E>(tokens, Collections.<String>emptyList(), 0.0);
        return Iterators.singleton(emptyTagging);
    }

    /**
     * Returns an iterator over the best taggings of the tokens with conditional
     * scores, up to {@code maxResults} of them. Delegates to
     * {@code NBestIterator} with its boolean flag set to {@code true}.
     *
     * @param tokens tokens to tag
     * @param maxResults maximum number of taggings to return
     * @return iterator over scored taggings; a single empty tagging with score
     *     0.0 for empty input
     */
    @Override
    public Iterator<ScoredTagging<E>> tagNBestConditional(List<E> tokens, int maxResults) {
        if (!tokens.isEmpty()) {
            return new NBestIterator(tokens, true, maxResults);
        }
        ScoredTagging<E> emptyTagging = new ScoredTagging<E>(tokens, Collections.<String>emptyList(), 0.0);
        return Iterators.singleton(emptyTagging);
    }

    /**
     * Returns the marginal tag lattice for the tokens, computed by the
     * forward-backward algorithm.
     *
     * @param tokens tokens to tag
     * @return forward-backward lattice; a lattice over empty arrays for empty input
     */
    @Override
    public TagLattice<E> tagMarginal(List<E> tokens) {
        if (tokens.size() != 0) {
            FeatureVectors vectors = this.features(tokens);
            return this.forwardBackward(tokens, vectors);
        }
        return new ForwardBackwardTagLattice<E>(tokens, this.mTagList, EMPTY_DOUBLE_2D_ARRAY, EMPTY_DOUBLE_2D_ARRAY, EMPTY_DOUBLE_3D_ARRAY, 0.0);
    }

    /**
     * Returns a multi-line description of the feature extractor, intercept
     * setting, tags, and each tag's non-zero coefficients.
     *
     * @return human-readable description of this CRF
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("Feature Extractor=").append(this.featureExtractor()).append('\n');
        sb.append("Add intercept=").append(this.addInterceptFeature()).append('\n');
        List<String> tagList = this.tags();
        sb.append("Tags=").append(tagList).append('\n');
        Vector[] coefficientVectors = this.coefficients();
        SymbolTable symbols = this.featureSymbolTable();
        sb.append("Coefficients=\n");
        for (int tagIndex = 0; tagIndex < coefficientVectors.length; ++tagIndex) {
            Vector coefficientVector = coefficientVectors[tagIndex];
            sb.append(tagList.get(tagIndex)).append("  ");
            String separator = "";
            for (int dim : coefficientVector.nonZeroDimensions()) {
                sb.append(separator)
                  .append(symbols.idToSymbol(dim))
                  .append('=')
                  .append(coefficientVector.value(dim));
                separator = ", ";
            }
            sb.append('\n');
        }
        return sb.toString();
    }

    /**
     * Serialization hook: substitutes a {@code Serializer} proxy for this CRF
     * when it is written to an object stream.
     *
     * @return serialization proxy wrapping this CRF
     */
    Object writeReplace() {
        Serializer<E> proxy = new Serializer<E>(this);
        return proxy;
    }

    /**
     * Vectorizes the node features extracted for one token position.
     *
     * @param position token position
     * @param features extracted features for the whole input
     * @return node feature vector over this CRF's feature dimensions
     */
    private Vector nodeFeatures(int position, ChainCrfFeatures<E> features) {
        return Features.toVector(features.nodeFeatures(position),
                                 mFeatureSymbolTable,
                                 mNumDimensions,
                                 mAddInterceptFeature);
    }

    /**
     * Vectorizes the edge features into a token position given the index of
     * the previous tag.
     *
     * @param position token position being entered
     * @param lastTagIndex index of the tag at the previous position
     * @param features extracted features for the whole input
     * @return edge feature vector over this CRF's feature dimensions
     */
    private Vector edgeFeatures(int position, int lastTagIndex, ChainCrfFeatures<E> features) {
        return Features.toVector(features.edgeFeatures(position, lastTagIndex),
                                 mFeatureSymbolTable,
                                 mNumDimensions,
                                 mAddInterceptFeature);
    }

    /**
     * Extracts and vectorizes all node and edge features for the tokens.
     *
     * @param tokens tokens to featurize
     * @return node vectors indexed by position, and edge vectors indexed by
     *     {@code [position - 1][previousTag]}; {@code null} if {@code tokens} is
     *     empty (callers handle empty input before calling)
     */
    private FeatureVectors features(List<E> tokens) {
        int numTokens = tokens.size();
        if (numTokens == 0) {
            return null;
        }
        int numTags = this.mTagList.size();
        ChainCrfFeatures<E> features = this.mFeatureExtractor.extract(tokens, this.mTagList);
        Vector[] nodeFeatureVectors = new Vector[numTokens];
        for (int n = 0; n < numTokens; ++n) {
            nodeFeatureVectors[n] = this.nodeFeatures(n, features);
        }
        // Row n - 1 holds the edge vectors into position n, one per possible previous tag.
        Vector[][] edgeFeatureVectorss = new Vector[numTokens - 1][numTags];
        for (int n = 1; n < numTokens; ++n) {
            for (int k = 0; k < numTags; ++k) {
                edgeFeatureVectorss[n - 1][k] = this.edgeFeatures(n, k, features);
            }
        }
        return new FeatureVectors(nodeFeatureVectors, edgeFeatureVectorss);
    }

    /**
     * Runs the forward-backward algorithm in log space over precomputed
     * feature vectors, honoring the legal start, end, and transition flags
     * (illegal choices receive potential -infinity).
     *
     * @param tokens tokens being tagged (must be non-empty)
     * @param features node and edge feature vectors for {@code tokens}
     * @return lattice holding log forward, log backward, and log transition
     *     potentials, plus the log normalizer (log-sum-exp of the final forward scores)
     */
    TagLattice<E> forwardBackward(List<E> tokens, FeatureVectors features) {
        final int numTokens = tokens.size();
        final int numTags = this.mTagList.size();

        // Position 0: forward scores are the start potentials.
        double[][] logForwards = new double[numTokens][numTags];
        for (int kTo = 0; kTo < numTags; ++kTo) {
            if (this.mLegalTagStarts[kTo]) {
                logForwards[0][kTo] = features.mNodeFeatureVectors[0].dotProduct(this.mCoefficients[kTo]);
            } else {
                logForwards[0][kTo] = Double.NEGATIVE_INFINITY;
            }
        }

        // logPotentials[nTo - 1][kFrom][kTo]: edge + node score of moving from kFrom
        // to kTo at position nTo; stays -infinity where the move is illegal.
        double[][][] logPotentials = new double[numTokens - 1][numTags][numTags];
        for (double[][] matrix : logPotentials) {
            for (double[] row : matrix) {
                Arrays.fill(row, Double.NEGATIVE_INFINITY);
            }
        }
        for (int nTo = 1; nTo < numTokens; ++nTo) {
            boolean isLast = (nTo == numTokens - 1);
            Vector nodeVector = features.mNodeFeatureVectors[nTo];
            Vector[] edgeVectors = features.mEdgeFeatureVectorss[nTo - 1];
            double[][] transition = logPotentials[nTo - 1];
            for (int kTo = 0; kTo < numTags; ++kTo) {
                if (isLast && !this.mLegalTagEnds[kTo]) {
                    continue;
                }
                Vector coefficientsKTo = this.mCoefficients[kTo];
                double nodePotential = nodeVector.dotProduct(coefficientsKTo);
                for (int kFrom = 0; kFrom < numTags; ++kFrom) {
                    if (this.mLegalTagTransitions[kFrom][kTo]) {
                        transition[kFrom][kTo] = edgeVectors[kFrom].dotProduct(coefficientsKTo) + nodePotential;
                    }
                }
            }
        }

        double[] buf = new double[numTags];

        // Forward pass.
        for (int nTo = 1; nTo < numTokens; ++nTo) {
            double[] previous = logForwards[nTo - 1];
            double[][] transition = logPotentials[nTo - 1];
            for (int kTo = 0; kTo < numTags; ++kTo) {
                for (int kFrom = 0; kFrom < numTags; ++kFrom) {
                    buf[kFrom] = previous[kFrom] + transition[kFrom][kTo];
                }
                logForwards[nTo][kTo] = com.aliasi.util.Math.logSumOfExponentials(buf);
            }
        }

        // Backward pass; the final row stays 0.0.
        double[][] logBackwards = new double[numTokens][numTags];
        for (int nFrom = numTokens - 2; nFrom >= 0; --nFrom) {
            double[] next = logBackwards[nFrom + 1];
            double[][] transition = logPotentials[nFrom];
            for (int kFrom = 0; kFrom < numTags; ++kFrom) {
                for (int kTo = 0; kTo < numTags; ++kTo) {
                    buf[kTo] = next[kTo] + transition[kFrom][kTo];
                }
                logBackwards[nFrom][kFrom] = com.aliasi.util.Math.logSumOfExponentials(buf);
            }
        }

        double logZ = com.aliasi.util.Math.logSumOfExponentials(logForwards[numTokens - 1]);
        return new ForwardBackwardTagLattice<E>(tokens, this.mTagList, logForwards, logBackwards, logPotentials, logZ, false);
    }

    /**
     * Marks the tags that begin at least one non-empty tag-id sequence.
     *
     * @param tagIdss tag-id sequences, one per training instance
     * @param numTags total number of tags
     * @return flags indexed by tag id, true if that tag is ever first
     */
    static boolean[] legalStarts(int[][] tagIdss, int numTags) {
        boolean[] starts = new boolean[numTags];
        for (int j = 0; j < tagIdss.length; ++j) {
            int[] tagIds = tagIdss[j];
            if (tagIds.length > 0) {
                starts[tagIds[0]] = true;
            }
        }
        return starts;
    }

    /**
     * Marks the tags that end at least one non-empty tag-id sequence.
     *
     * @param tagIdss tag-id sequences, one per training instance
     * @param numTags total number of tags
     * @return flags indexed by tag id, true if that tag is ever last
     */
    static boolean[] legalEnds(int[][] tagIdss, int numTags) {
        boolean[] ends = new boolean[numTags];
        for (int j = 0; j < tagIdss.length; ++j) {
            int[] tagIds = tagIdss[j];
            int lastIndex = tagIds.length - 1;
            if (lastIndex >= 0) {
                ends[tagIds[lastIndex]] = true;
            }
        }
        return ends;
    }

    /**
     * Marks every adjacent tag pair observed in the tag-id sequences.
     *
     * @param tagIdss tag-id sequences, one per training instance
     * @param numTags total number of tags
     * @return {@code result[from][to]} is true if tag {@code to} directly
     *     follows tag {@code from} somewhere in the input
     */
    static boolean[][] legalTransitions(int[][] tagIdss, int numTags) {
        boolean[][] transitions = new boolean[numTags][numTags];
        for (int j = 0; j < tagIdss.length; ++j) {
            int[] tagIds = tagIdss[j];
            for (int n = 1; n < tagIds.length; ++n) {
                int fromTag = tagIds[n - 1];
                int toTag = tagIds[n];
                transitions[fromTag][toTag] = true;
            }
        }
        return transitions;
    }

    /**
     * Returns a new boolean array of length {@code m} with every entry true.
     *
     * @param m array length
     * @return all-true array
     */
    static boolean[] trueArray(int m) {
        boolean[] flags = new boolean[m];
        for (int i = 0; i < flags.length; ++i) {
            flags[i] = true;
        }
        return flags;
    }

    /**
     * Returns a new {@code m} by {@code n} boolean matrix with every entry true.
     *
     * @param m number of rows
     * @param n number of columns
     * @return all-true matrix
     */
    static boolean[][] trueArray(int m, int n) {
        boolean[][] grid = new boolean[m][n];
        for (int row = 0; row < m; ++row) {
            for (int col = 0; col < n; ++col) {
                grid[row][col] = true;
            }
        }
        return grid;
    }

    /**
     * Estimates a chain CRF from the training section of a corpus by
     * stochastic gradient descent on the conditional log likelihood.
     *
     * <p>Each epoch visits every non-empty training instance and applies the
     * likelihood gradient. The prior gradient is applied after every
     * {@code priorBlockSize} instances and once more at the end of the epoch.
     * Training stops after {@code maxEpochs} epochs, or earlier once at least
     * {@code minEpochs} epochs have completed and the rolling average relative
     * change in log2 likelihood plus log2 prior falls below {@code minImprovement}.
     *
     * @param corpus corpus whose training section supplies the taggings
     * @param featureExtractor extractor for node and edge features
     * @param addInterceptFeature whether to add an intercept feature
     * @param minFeatureCount minimum count passed to feature pruning
     * @param cacheFeatureVectors whether to precompute and keep the feature
     *     vectors of every training instance in memory
     * @param allowUnseenTransitions if true, all starts, ends, and transitions are
     *     legal; otherwise only those observed in the training data
     * @param prior prior on the coefficients; {@code null} means no prior
     * @param priorBlockSize number of instances between prior updates
     * @param annealingSchedule learning rate per epoch
     * @param minImprovement convergence threshold on the rolling relative difference
     * @param minEpochs minimum number of epochs before early stopping is allowed
     * @param maxEpochs maximum number of epochs
     * @param reporter progress reporter, or {@code null} for silent
     * @return the estimated CRF
     * @throws IOException if visiting the corpus throws
     */
    public static <F> ChainCrf<F> estimate(Corpus<ObjectHandler<Tagging<F>>> corpus, ChainCrfFeatureExtractor<F> featureExtractor, boolean addInterceptFeature, int minFeatureCount, boolean cacheFeatureVectors, boolean allowUnseenTransitions, RegressionPrior prior, int priorBlockSize, AnnealingSchedule annealingSchedule, double minImprovement, int minEpochs, int maxEpochs, Reporter reporter) throws IOException {
        if (reporter == null) {
            reporter = Reporters.silent();
        }
        reporter.info("ChainCrf.estimate Parameters");
        reporter.info("featureExtractor=" + featureExtractor);
        reporter.info("addInterceptFeature=" + addInterceptFeature);
        reporter.info("minFeatureCount=" + minFeatureCount);
        reporter.info("cacheFeatureVectors=" + cacheFeatureVectors);
        reporter.info("allowUnseenTransitions=" + allowUnseenTransitions);
        reporter.info("prior=" + prior);
        reporter.info("annealingSchedule=" + annealingSchedule);
        reporter.info("minImprovement=" + minImprovement);
        reporter.info("minEpochs=" + minEpochs);
        reporter.info("maxEpochs=" + maxEpochs);
        reporter.info("priorBlockSize=" + priorBlockSize);

        reporter.info("Computing corpus tokens and features");
        List<List<F>> tokenss = ChainCrf.corpusTokens(corpus);
        String[][] tagss = ChainCrf.corpusTags(corpus);
        int numTrainingInstances = tagss.length;
        long numTrainingTokens = 0L;
        for (String[] tags : tagss) {
            numTrainingTokens += tags.length;
        }
        int[][] tagIdss = new int[numTrainingInstances][];
        MapSymbolTable tagSymbolTable = ChainCrf.tagSymbolTable(tagss, tagIdss);
        MapSymbolTable featureSymbolTable = ChainCrf.featureSymbolTable(tagss, tokenss, addInterceptFeature, featureExtractor, minFeatureCount);
        int numTags = tagSymbolTable.numSymbols();
        String[] allTags = new String[numTags];
        for (int k = 0; k < numTags; ++k) {
            allTags[k] = tagSymbolTable.idToSymbol(k);
        }
        boolean[] legalTagStarts = allowUnseenTransitions ? ChainCrf.trueArray(numTags) : ChainCrf.legalStarts(tagIdss, numTags);
        boolean[] legalTagEnds = allowUnseenTransitions ? ChainCrf.trueArray(numTags) : ChainCrf.legalEnds(tagIdss, numTags);
        boolean[][] legalTagTransitions = allowUnseenTransitions ? ChainCrf.trueArray(numTags, numTags) : ChainCrf.legalTransitions(tagIdss, numTags);
        int numDimensions = featureSymbolTable.numSymbols();
        DenseVector[] weightVectors = new DenseVector[numTags];
        for (int k = 0; k < numTags; ++k) {
            weightVectors[k] = new DenseVector(numDimensions);
        }
        reporter.info("Corpus Statistics");
        reporter.info("Num Training Instances=" + numTrainingInstances);
        reporter.info("Num Training Tokens=" + numTrainingTokens);
        reporter.info("Num Dimensions After Pruning=" + numDimensions);
        reporter.info("Tags=" + tagSymbolTable);

        // The constructor stores weightVectors by reference, so the in-place
        // updates below are immediately visible to crf's scoring.
        ChainCrf<F> crf = new ChainCrf<F>(allTags, legalTagStarts, legalTagEnds, legalTagTransitions, weightVectors, featureSymbolTable, featureExtractor, addInterceptFeature);

        FeatureVectors[] featureVectorsCache = null;
        if (cacheFeatureVectors) {
            reporter.info("Caching Feature Vectors");
            featureVectorsCache = new FeatureVectors[numTrainingInstances];
            for (int j = 0; j < numTrainingInstances; ++j) {
                featureVectorsCache[j] = crf.features(tokenss.get(j));
            }
        }

        // Large finite sentinel (-Double.MAX_VALUE / 2) for the "previous" objective.
        double lastLog2LikelihoodAndPrior = -8.988465674311579E307;
        double rollingAverageRelativeDiff = 1.0;
        double bestLog2LikelihoodAndPrior = Double.NEGATIVE_INFINITY;
        long cumFeatureExtractionMs = 0L;
        long cumForwardBackwardMs = 0L;
        long cumUpdateMs = 0L;
        long cumLossMs = 0L;
        long cumPriorUpdateMs = 0L;
        for (int epoch = 0; epoch < maxEpochs; ++epoch) {
            int instancesSinceLastPriorUpdate = 0;
            double learningRate = annealingSchedule.learningRate(epoch);
            double learningRatePerTrainingInstance = learningRate / (double) numTrainingInstances;
            for (int j = 0; j < numTrainingInstances; ++j) {
                List<F> tokens = tokenss.get(j);
                int numTokens = tokens.size();
                if (numTokens < 1) {
                    continue;
                }
                long startMs = System.currentTimeMillis();
                FeatureVectors features = cacheFeatureVectors ? featureVectorsCache[j] : crf.features(tokens);
                long featsMs = System.currentTimeMillis();
                cumFeatureExtractionMs += featsMs - startMs;
                TagLattice<F> lattice = crf.forwardBackward(tokens, features);
                long fwdBkMs = System.currentTimeMillis();
                cumForwardBackwardMs += fwdBkMs - featsMs;
                ChainCrf.updateWeightsWithGradient(weightVectors, tagIdss[j], numTokens, numTags, features, lattice, learningRate);
                long updateMs = System.currentTimeMillis();
                cumUpdateMs += updateMs - fwdBkMs;
                if (++instancesSinceLastPriorUpdate == priorBlockSize) {
                    if (prior != null) {
                        ChainCrf.adjustWeightsWithPrior(weightVectors, prior, instancesSinceLastPriorUpdate * learningRatePerTrainingInstance);
                    }
                    instancesSinceLastPriorUpdate = 0;
                }
                long priorMs = System.currentTimeMillis();
                cumPriorUpdateMs += priorMs - updateMs;
            }
            // Apply the prior for instances since the last block update.
            long finalPriorStartMs = System.currentTimeMillis();
            if (prior != null) {
                ChainCrf.adjustWeightsWithPrior(weightVectors, prior, instancesSinceLastPriorUpdate * learningRatePerTrainingInstance);
            }
            cumPriorUpdateMs += System.currentTimeMillis() - finalPriorStartMs;

            long lossStartMs = System.currentTimeMillis();
            double log2Likelihood = ChainCrf.corpusLog2Likelihood(crf, tokenss, tagIdss, featureVectorsCache);
            double log2Prior = prior == null ? 0.0 : prior.log2Prior(weightVectors);
            double log2LikelihoodAndPrior = log2Likelihood + log2Prior;
            double relativeDiff = com.aliasi.util.Math.relativeAbsoluteDifference(lastLog2LikelihoodAndPrior, log2LikelihoodAndPrior);
            rollingAverageRelativeDiff = (9.0 * rollingAverageRelativeDiff + relativeDiff) / 10.0;
            lastLog2LikelihoodAndPrior = log2LikelihoodAndPrior;
            if (log2LikelihoodAndPrior > bestLog2LikelihoodAndPrior) {
                bestLog2LikelihoodAndPrior = log2LikelihoodAndPrior;
            }
            cumLossMs += System.currentTimeMillis() - lossStartMs;

            if (reporter.isDebugEnabled()) {
                Formatter formatter = null;
                try {
                    formatter = new Formatter(Locale.ENGLISH);
                    formatter.format("epoch=%5d lr=%11.9f ll=%11.4f lp=%11.4f llp=%11.4f llp*=%11.4f", epoch, learningRate, log2Likelihood, log2Prior, log2LikelihoodAndPrior, bestLog2LikelihoodAndPrior);
                    reporter.debug(formatter.toString());
                } catch (IllegalFormatException e) {
                    reporter.warn("Illegal format in ChainCrf");
                } finally {
                    if (formatter != null) {
                        formatter.close();
                    }
                }
            }
            // minEpochs was previously logged but ignored; early stopping now
            // waits until at least minEpochs epochs have completed.
            if (epoch + 1 >= minEpochs && rollingAverageRelativeDiff < minImprovement) {
                reporter.info("Converged with rollingAverageRelativeDiff=" + rollingAverageRelativeDiff);
                break;
            }
        }
        reporter.info("Feat Extraction Time=" + Strings.msToString(cumFeatureExtractionMs));
        reporter.info("Forward Backward Time=" + Strings.msToString(cumForwardBackwardMs));
        reporter.info("Update Time=" + Strings.msToString(cumUpdateMs));
        reporter.info("Prior Update Time=" + Strings.msToString(cumPriorUpdateMs));
        reporter.info("Loss Time=" + Strings.msToString(cumLossMs));
        return crf;
    }

    /**
     * Applies one stochastic likelihood-gradient step for a single training
     * instance. It adds the gold node and edge feature vectors under their
     * gold tags. It then subtracts every node and edge feature vector weighted
     * by its marginal probability in the lattice. Marginals with log
     * probability below -400 are skipped as negligible.
     */
    private static <F> void updateWeightsWithGradient(DenseVector[] weightVectors, int[] tagIds, int numTokens, int numTags, FeatureVectors features, TagLattice<F> lattice, double learningRate) {
        for (int nTo = 0; nTo < numTokens; ++nTo) {
            weightVectors[tagIds[nTo]].increment(learningRate, features.mNodeFeatureVectors[nTo]);
        }
        for (int nTo = 1; nTo < numTokens; ++nTo) {
            weightVectors[tagIds[nTo]].increment(learningRate, features.mEdgeFeatureVectorss[nTo - 1][tagIds[nTo - 1]]);
        }
        for (int nTo = 0; nTo < numTokens; ++nTo) {
            for (int kTo = 0; kTo < numTags; ++kTo) {
                double logP = lattice.logProbability(nTo, kTo);
                if (logP < -400.0) {
                    continue;
                }
                double p = Math.exp(logP);
                weightVectors[kTo].increment(-p * learningRate, features.mNodeFeatureVectors[nTo]);
            }
        }
        for (int nTo = 1; nTo < numTokens; ++nTo) {
            for (int kFrom = 0; kFrom < numTags; ++kFrom) {
                for (int kTo = 0; kTo < numTags; ++kTo) {
                    double logP = lattice.logProbability(nTo, kFrom, kTo);
                    if (logP < -400.0) {
                        continue;
                    }
                    double p = Math.exp(logP);
                    weightVectors[kTo].increment(-p * learningRate, features.mEdgeFeatureVectorss[nTo - 1][kFrom]);
                }
            }
        }
    }

    /**
     * Returns the base-2 log likelihood of the gold taggings of all non-empty
     * training instances under the current model.
     */
    private static <F> double corpusLog2Likelihood(ChainCrf<F> crf, List<List<F>> tokenss, int[][] tagIdss, FeatureVectors[] featureVectorsCache) {
        double logLikelihood = 0.0;
        for (int j = 0; j < tagIdss.length; ++j) {
            List<F> tokens = tokenss.get(j);
            if (tokens.size() < 1) {
                continue;
            }
            FeatureVectors features = featureVectorsCache != null ? featureVectorsCache[j] : crf.features(tokens);
            TagLattice<F> lattice = crf.forwardBackward(tokens, features);
            logLikelihood += lattice.logProbability(0, tagIdss[j]);
        }
        // The lattice's log probabilities are natural logs (the gradient step
        // exponentiates them with Math.exp); convert to base 2 to match log2Prior.
        return logLikelihood / Math.log(2.0);
    }

    /**
     * Applies one prior-gradient step to every non-zero weight. A step never
     * moves a weight across zero; weights that would cross are clipped to 0.
     *
     * @param weightVectors per-tag weight vectors, updated in place
     * @param prior prior supplying the gradient; {@code null} or a uniform
     *     prior leaves the weights unchanged
     * @param priorLearningRate learning rate scaling the prior gradient
     */
    static void adjustWeightsWithPrior(DenseVector[] weightVectors, RegressionPrior prior, double priorLearningRate) {
        // estimate() treats a null prior as "no prior" (log prior 0), so do the same
        // here instead of throwing a NullPointerException.
        if (prior == null || prior.isUniform()) {
            return;
        }
        for (DenseVector weightVector : weightVectors) {
            int numDims = weightVector.numDimensions();
            for (int dim = 0; dim < numDims; ++dim) {
                double weight = weightVector.value(dim);
                if (weight == 0.0) {
                    continue;
                }
                double delta = prior.gradient(weight, dim) * priorLearningRate;
                double newVal = weight > 0.0
                    ? Math.max(0.0, weight - delta)
                    : Math.min(0.0, weight - delta);
                weightVector.setValue(dim, newVal);
            }
        }
    }

    /**
     * Builds a tag symbol table from the training tags. Each instance's tags
     * are converted to ids and written to {@code tagIdss} as a side effect.
     *
     * @param tagss tags per training instance
     * @param tagIdss output array; {@code tagIdss[j]} receives instance j's tag ids
     * @return symbol table containing every tag seen
     */
    static MapSymbolTable tagSymbolTable(String[][] tagss, int[][] tagIdss) {
        MapSymbolTable symbols = new MapSymbolTable();
        for (int j = 0; j < tagss.length; ++j) {
            String[] tags = tagss[j];
            int[] ids = new int[tags.length];
            tagIdss[j] = ids;
            for (int n = 0; n < ids.length; ++n) {
                ids[n] = symbols.getOrAddSymbol(tags[n]);
            }
        }
        return symbols;
    }

    /**
     * Counts node and gold-edge features over the training data, prunes by
     * count, and builds the feature symbol table. When requested, the
     * intercept feature is added first.
     *
     * @param tagss tags per training instance
     * @param tokenss tokens per training instance
     * @param addInterceptFeature whether to include the intercept feature
     * @param featureExtractor extractor for node and edge features
     * @param minFeatureCount minimum count passed to pruning
     * @return symbol table of the retained features
     */
    static <F> MapSymbolTable featureSymbolTable(String[][] tagss, List<List<F>> tokenss, boolean addInterceptFeature, ChainCrfFeatureExtractor<F> featureExtractor, int minFeatureCount) {
        ObjectToCounterMap<String> counts = new ObjectToCounterMap<String>();
        for (int j = 0; j < tagss.length; ++j) {
            String[] tags = tagss[j];
            int length = tags.length;
            // The instance's own tag sequence is passed as the tag list, so tag
            // index n - 1 below is the gold tag at position n - 1.
            List<String> goldTags = Arrays.asList(tags);
            ChainCrfFeatures<F> features = featureExtractor.extract(tokenss.get(j), goldTags);
            for (int n = 0; n < length; ++n) {
                for (String feature : features.nodeFeatures(n).keySet()) {
                    counts.increment(feature);
                }
            }
            for (int n = 1; n < length; ++n) {
                for (String feature : features.edgeFeatures(n, n - 1).keySet()) {
                    counts.increment(feature);
                }
            }
        }
        counts.prune(minFeatureCount);
        MapSymbolTable symbols = new MapSymbolTable();
        if (addInterceptFeature) {
            symbols.getOrAddSymbol(INTERCEPT_FEATURE_NAME);
        }
        for (String feature : counts.keySet()) {
            symbols.getOrAddSymbol(feature);
        }
        return symbols;
    }

    /**
     * Collects the token lists of every training tagging in the corpus, in
     * visit order.
     *
     * @param corpus corpus to visit
     * @return token list per training instance
     * @throws IOException if visiting the corpus throws
     */
    static <F> List<List<F>> corpusTokens(Corpus<ObjectHandler<Tagging<F>>> corpus) throws IOException {
        final List<List<F>> tokenLists = new ArrayList<List<F>>();
        ObjectHandler<Tagging<F>> collector = new ObjectHandler<Tagging<F>>() {
            @Override
            public void handle(Tagging<F> tagging) {
                tokenLists.add(tagging.tokens());
            }
        };
        corpus.visitTrain(collector);
        return tokenLists;
    }

    /**
     * Collects the tag sequence of every training tagging in the corpus, in
     * visit order.
     *
     * @param corpus corpus to visit
     * @return tags per training instance
     * @throws IOException if visiting the corpus throws
     */
    static <F> String[][] corpusTags(Corpus<ObjectHandler<Tagging<F>>> corpus) throws IOException {
        // Typed list and typed toArray replace the raw ArrayList and the
        // undefined (T[]) cast left by the decompiler.
        final List<String[]> corpusTagList = new ArrayList<String[]>(1024);
        corpus.visitTrain(new ObjectHandler<Tagging<F>>() {
            @Override
            public void handle(Tagging<F> tagging) {
                corpusTagList.add(tagging.tags().toArray(Strings.EMPTY_STRING_ARRAY));
            }
        });
        return corpusTagList.toArray(new String[corpusTagList.size()][]);
    }

    /**
     * Returns a new array of copies of the given dense vectors.
     *
     * @param xs vectors to copy
     * @return new vectors copied from {@code xs}, in the same order
     */
    static DenseVector[] copy(DenseVector[] xs) {
        DenseVector[] copies = new DenseVector[xs.length];
        int k = 0;
        for (DenseVector x : xs) {
            copies[k++] = new DenseVector(x);
        }
        return copies;
    }

    /**
     * Returns the length of the longest tag sequence, or 0 if there are none.
     *
     * @param tagss tag sequences
     * @return maximum sequence length
     */
    static int longestInput(String[][] tagss) {
        int max = 0;
        for (int j = 0; j < tagss.length; ++j) {
            max = Math.max(max, tagss[j].length);
        }
        return max;
    }

    /** Precomputed feature vectors for one input sequence. */
    static class FeatureVectors {
        // mNodeFeatureVectors[n]: node feature vector for token position n.
        final Vector[] mNodeFeatureVectors;
        // mEdgeFeatureVectorss[n - 1][k]: edge feature vector into position n given previous tag k.
        final Vector[][] mEdgeFeatureVectorss;

        FeatureVectors(Vector[] nodeVectors, Vector[][] edgeVectors) {
            mNodeFeatureVectors = nodeVectors;
            mEdgeFeatureVectorss = edgeVectors;
        }
    }

    /**
     * A partial n-best search state: a prefix ending in tag index {@code mK}
     * at position {@code mN} with prefix score {@code mScore}, optionally
     * followed by a chain of forward pointers for the remaining positions.
     */
    static class NBestState
    implements Scored {
        final double mScore;
        final ForwardPointer mForwardPointer;
        final int mN;
        final int mK;

        NBestState(double score, ForwardPointer forwardPointer, int n, int k) {
            mN = n;
            mK = k;
            mForwardPointer = forwardPointer;
            mScore = score;
        }

        /**
         * Returns the prefix score, plus the forward pointer's score when a
         * forward pointer is present.
         */
        public double score() {
            if (mForwardPointer == null) {
                return mScore;
            }
            return mScore + mForwardPointer.mScore;
        }
    }

    /**
     * One link in a singly linked chain of tag indices.  Each link stores a
     * tag index, the next link ({@code null} at the end of the chain) and a
     * score, all in final fields.
     */
    static class ForwardPointer {
        final int mK;
        final ForwardPointer mPointer;
        final double mScore;

        ForwardPointer(int k, ForwardPointer pointer, double score) {
            mScore = score;
            mPointer = pointer;
            mK = k;
        }
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    static class Serializer<F>
    extends AbstractExternalizable {
        static final long serialVersionUID = -4140295941325870709L;
        final ChainCrf<F> mCrf;

        public Serializer(ChainCrf<F> crf) {
            this.mCrf = crf;
        }

        public Serializer() {
            this(null);
        }

        /**
         * Writes the wrapped CRF in this order:
         * <ol>
         * <li>the tag count, then each tag as UTF;</li>
         * <li>the legal-start flags, then the legal-end flags (one per tag);</li>
         * <li>the tag-by-tag legal-transition matrix, row by row;</li>
         * <li>each coefficient vector;</li>
         * <li>the feature symbol table and the feature extractor;</li>
         * <li>the add-intercept flag.</li>
         * </ol>
         * {@link #read(ObjectInput)} consumes the fields in the same order.
         *
         * @param out Stream to write to.
         * @throws IOException If the underlying stream throws one.
         */
        @Override
        public void writeExternal(ObjectOutput out) throws IOException {
            // Use the typed reference.  The decompiled raw (ChainCrf) casts erased
            // mTagList to a raw List, so the String for-each below did not compile.
            List<String> tagList = mCrf.mTagList;
            int numTags = tagList.size();
            out.writeInt(numTags);
            for (String tag : tagList) {
                out.writeUTF(tag);
            }
            for (int i = 0; i < numTags; ++i) {
                out.writeBoolean(mCrf.mLegalTagStarts[i]);
            }
            for (int i = 0; i < numTags; ++i) {
                out.writeBoolean(mCrf.mLegalTagEnds[i]);
            }
            for (int i = 0; i < numTags; ++i) {
                for (int j = 0; j < numTags; ++j) {
                    out.writeBoolean(mCrf.mLegalTagTransitions[i][j]);
                }
            }
            for (Vector v : mCrf.mCoefficients) {
                out.writeObject(v);
            }
            out.writeObject(mCrf.mFeatureSymbolTable);
            out.writeObject(mCrf.mFeatureExtractor);
            out.writeBoolean(mCrf.mAddInterceptFeature);
        }

        /**
         * Reads a CRF written by {@link #writeExternal(ObjectOutput)}.
         * One coefficient vector is read per tag.
         *
         * @param in Stream to read from.
         * @return The reconstructed CRF.
         * @throws IOException If the stream is corrupt (negative tag count) or
         * the underlying stream throws one.
         * @throws ClassNotFoundException If a serialized class cannot be found.
         */
        @Override
        public Object read(ObjectInput in) throws ClassNotFoundException, IOException {
            int numTags = in.readInt();
            if (numTags < 0) {
                // Report a corrupt stream as IOException rather than
                // NegativeArraySizeException from the allocation below.
                throw new IOException("Corrupt stream: negative number of tags=" + numTags);
            }
            String[] tags = new String[numTags];
            for (int i = 0; i < numTags; ++i) {
                tags[i] = in.readUTF();
            }
            boolean[] legalTagStarts = new boolean[numTags];
            for (int i = 0; i < numTags; ++i) {
                legalTagStarts[i] = in.readBoolean();
            }
            boolean[] legalTagEnds = new boolean[numTags];
            for (int i = 0; i < numTags; ++i) {
                legalTagEnds[i] = in.readBoolean();
            }
            boolean[][] legalTagTransitions = new boolean[numTags][numTags];
            for (int i = 0; i < numTags; ++i) {
                for (int j = 0; j < numTags; ++j) {
                    legalTagTransitions[i][j] = in.readBoolean();
                }
            }
            Vector[] coefficients = new Vector[numTags];
            for (int i = 0; i < numTags; ++i) {
                coefficients[i] = (Vector) in.readObject();
            }
            SymbolTable featureSymbolTable = (SymbolTable) in.readObject();
            // Generic type cannot be checked at runtime; it matches the writer's CRF.
            @SuppressWarnings("unchecked")
            ChainCrfFeatureExtractor<F> featureExtractor
                = (ChainCrfFeatureExtractor<F>) in.readObject();
            boolean addInterceptFeature = in.readBoolean();
            return new ChainCrf<F>(tags, legalTagStarts, legalTagEnds, legalTagTransitions,
                                   coefficients, featureSymbolTable, featureExtractor,
                                   addInterceptFeature);
        }
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    /**
     * Lazily enumerates scored taggings of a fixed token list for n-best decoding.
     * The constructor runs a Viterbi pass recording, for every position and tag, the
     * best prefix score and its back pointer.  It then seeds a priority queue, bounded
     * to maxResults entries, with one state per tag at the final position.  Each call
     * to {@link #bufferNext()} polls the queue, turns that state into a tagging, and
     * offers the one-step deviations from its path back to the queue.
     */
    class NBestIterator
    extends Iterators.Buffered<ScoredTagging<E>> {
        // Tokens being tagged.
        final List<E> mTokens;
        // Value subtracted from every emitted score; 0.0 unless normToConditional was true.
        final double mLogZ;
        // [n-1][kMinus1][k]: score for tag kMinus1 at position n-1 followed by tag k at
        // position n (node score of k at n plus edge score); -Infinity when not allowed.
        final double[][][] mTransitionScores;
        // [n][k]: best score of any tag prefix over positions 0..n ending in tag k.
        final double[][] mViterbiScores;
        // [n-1][k]: predecessor tag on the best prefix ending in tag k at position n;
        // -1 when no predecessor yields a finite score.
        final int[][] mBackPointers;
        // Pending partial search states.
        final BoundedPriorityQueue<NBestState> mPriorityQueue;

        /**
         * Precomputes transition and Viterbi scores for the tokens and seeds the queue
         * with every tag at the final position.
         *
         * @param tokens Tokens to tag.
         * @param normToConditional If true, every emitted score has logZ() subtracted.
         * @param maxResults Bound passed to the priority queue.
         */
        NBestIterator(List<E> tokens, boolean normToConditional, int maxResults) {
            int k;
            this.mPriorityQueue = new BoundedPriorityQueue(ScoredObject.comparator(), maxResults);
            this.mTokens = tokens;
            int numTokens = tokens.size();
            int numTags = ChainCrf.this.mTagList.size();
            // NOTE(review): an empty token list makes numTokens - 1 negative and the
            // allocation below throws NegativeArraySizeException; presumably callers
            // reject empty input first -- confirm.
            // All score tables start at -Infinity (back pointers at -1), so anything
            // skipped below stays excluded.
            double[][][] dArray = this.mTransitionScores = new double[numTokens - 1][numTags][numTags];
            int len$ = dArray.length;
            for (int i$ = 0; i$ < len$; ++i$) {
                double[][] xss;
                for (double[] xs : xss = dArray[i$]) {
                    Arrays.fill(xs, Double.NEGATIVE_INFINITY);
                }
            }
            for (double[] xs : this.mViterbiScores = new double[numTokens][numTags]) {
                Arrays.fill(xs, Double.NEGATIVE_INFINITY);
            }
            for (int[] ptrs : this.mBackPointers = new int[numTokens - 1][numTags]) {
                Arrays.fill(ptrs, -1);
            }
            // Edge feature vectors for the current position, indexed by previous tag;
            // the array is reused across positions.
            Vector[] vectorArray = new Vector[numTags];
            ChainCrfFeatures features = ChainCrf.this.mFeatureExtractor.extract(tokens, ChainCrf.this.mTagList);
            // Transition scores for positions 1..numTokens-1.  Illegal start tags (when
            // n == 1), illegal end tags (at the last position) and illegal transitions
            // are skipped and keep -Infinity.
            for (int n = 1; n < numTokens; ++n) {
                Vector nodeVector = ChainCrf.this.nodeFeatures(n, features);
                for (int kMinus1 = 0; kMinus1 < numTags; ++kMinus1) {
                    if (n == 1 && !ChainCrf.this.mLegalTagStarts[kMinus1]) continue;
                    vectorArray[kMinus1] = ChainCrf.this.edgeFeatures(n, kMinus1, features);
                }
                for (int k2 = 0; k2 < numTags; ++k2) {
                    if (n == numTokens - 1 && !ChainCrf.this.mLegalTagEnds[k2]) continue;
                    double nodeScore = nodeVector.dotProduct(ChainCrf.this.mCoefficients[k2]);
                    for (int kMinus1 = 0; kMinus1 < numTags; ++kMinus1) {
                        if (!ChainCrf.this.mLegalTagTransitions[kMinus1][k2] || n == 1 && !ChainCrf.this.mLegalTagStarts[kMinus1]) continue;
                        this.mTransitionScores[n - 1][kMinus1][k2] = nodeScore + vectorArray[kMinus1].dotProduct(ChainCrf.this.mCoefficients[k2]);
                    }
                }
            }
            // Position 0: only legal start tags receive a finite score.
            Vector nodeVector0 = ChainCrf.this.nodeFeatures(0, features);
            for (k = 0; k < numTags; ++k) {
                if (!ChainCrf.this.mLegalTagStarts[k]) continue;
                this.mViterbiScores[0][k] = nodeVector0.dotProduct(ChainCrf.this.mCoefficients[k]);
            }
            // Viterbi recursion: for each tag, keep the best-scoring legal predecessor.
            // Only a strictly greater score replaces the best, so a tag with no finite
            // predecessor keeps -Infinity and back pointer -1.
            for (int n = 1; n < numTokens; ++n) {
                for (int k3 = 0; k3 < numTags; ++k3) {
                    if (n == numTokens - 1 && !ChainCrf.this.mLegalTagEnds[k3]) continue;
                    double bestScore = Double.NEGATIVE_INFINITY;
                    int backPtr = -1;
                    for (int kMinus1 = 0; kMinus1 < numTags; ++kMinus1) {
                        double score;
                        if (!ChainCrf.this.mLegalTagTransitions[kMinus1][k3] || !((score = this.mViterbiScores[n - 1][kMinus1] + this.mTransitionScores[n - 1][kMinus1][k3]) > bestScore)) continue;
                        bestScore = score;
                        backPtr = kMinus1;
                    }
                    this.mViterbiScores[n][k3] = bestScore;
                    this.mBackPointers[n - 1][k3] = backPtr;
                }
            }
            this.mLogZ = normToConditional ? this.logZ() : 0.0;
            // Seed the queue with every tag at the final position; offer() drops
            // -Infinity scores.
            // NOTE(review): with a single token the legal-end check above never runs,
            // so illegal end tags can be seeded here -- confirm this is intended.
            for (k = 0; k < numTags; ++k) {
                this.offer(this.mViterbiScores[numTokens - 1][k], null, numTokens - 1, k);
            }
        }

        /**
         * Returns the log of the sum of exponentiated scores of all tag paths.  It runs
         * a forward pass in log space over mTransitionScores, starting from the
         * position-0 scores in mViterbiScores[0].
         *
         * @return Log normalizer over all paths.
         */
        double logZ() {
            double[] forwards = (double[])this.mViterbiScores[0].clone();
            int numTags = forwards.length;
            double[] previousForwards = new double[numTags];
            double[] exps = new double[numTags];
            for (int n = 0; n < this.mTransitionScores.length; ++n) {
                // Swap buffers: previousForwards now holds position n and forwards
                // is overwritten with position n + 1.
                double[] temp = previousForwards;
                previousForwards = forwards;
                forwards = temp;
                for (int k = 0; k < numTags; ++k) {
                    for (int kMinus1 = 0; kMinus1 < numTags; ++kMinus1) {
                        exps[kMinus1] = previousForwards[kMinus1] + this.mTransitionScores[n][kMinus1][k];
                    }
                    forwards[k] = com.aliasi.util.Math.logSumOfExponentials(exps);
                }
            }
            double logZ = com.aliasi.util.Math.logSumOfExponentials(forwards);
            return logZ;
        }

        /**
         * Enqueues a state ending in tag k at position n, followed by the chain
         * starting at pointer.  The state is skipped if its prefix score or the
         * pointer's score is -Infinity.
         */
        void offer(double score, ForwardPointer pointer, int n, int k) {
            if (score == Double.NEGATIVE_INFINITY) {
                return;
            }
            if (pointer != null && pointer.mScore == Double.NEGATIVE_INFINITY) {
                return;
            }
            NBestState state = new NBestState(score, pointer, n, k);
            this.mPriorityQueue.offer(state);
        }

        /**
         * Polls the next state from the queue and returns it as a tagging, or returns
         * null when the queue is empty (presumably signalling exhaustion to
         * Iterators.Buffered).  Before returning, it walks the state's back-pointer
         * path from position mN - 1 down to 0.  At each step it offers the alternative
         * predecessors and extends the forward-pointer chain by one link.
         */
        @Override
        public ScoredTagging<E> bufferNext() {
            NBestState resultState = this.mPriorityQueue.poll();
            if (resultState == null) {
                return null;
            }
            int k = resultState.mK;
            ForwardPointer fwdPointer = resultState.mForwardPointer;
            for (int n = resultState.mN - 1; n >= 0; --n) {
                // k is the tag at position n + 1 on the polled state's path.
                this.addAlternatives(n, k, fwdPointer);
                int kMinus1 = this.mBackPointers[n][k];
                double fwdScore = this.mTransitionScores[n][kMinus1][k];
                if (fwdPointer != null) {
                    fwdScore += fwdPointer.mScore;
                }
                fwdPointer = new ForwardPointer(k, fwdPointer, fwdScore);
                k = kMinus1;
            }
            ScoredTagging scoredTagging = this.toScoredTagging(resultState);
            return scoredTagging;
        }

        /**
         * Offers one state for every tag kMinus1 at position n except the back pointer
         * of tag k at position n + 1.  Each state's prefix ends in kMinus1 at position n
         * and its forward pointer continues with tag k, then fwdPointer's chain.
         */
        void addAlternatives(int n, int k, ForwardPointer fwdPointer) {
            int numTags = ChainCrf.this.mTagList.size();
            for (int kMinus1 = 0; kMinus1 < numTags; ++kMinus1) {
                if (kMinus1 == this.mBackPointers[n][k]) continue;
                double score = this.mViterbiScores[n][kMinus1];
                double fwdScore = this.mTransitionScores[n][kMinus1][k];
                if (fwdPointer != null) {
                    fwdScore += fwdPointer.mScore;
                }
                ForwardPointer pointer = new ForwardPointer(k, fwdPointer, fwdScore);
                this.offer(score, pointer, n, kMinus1);
            }
        }

        /**
         * Converts a state into a tagging of mTokens.  Tags for positions 0..mN come
         * from the back pointers (collected right to left, then reversed); the
         * remaining tags come from the forward-pointer chain.  The score is
         * state.score() minus mLogZ.
         */
        public ScoredTagging<E> toScoredTagging(NBestState state) {
            ArrayList<String> tags = new ArrayList<String>(this.mTokens.size());
            int k = state.mK;
            tags.add((String)ChainCrf.this.mTagList.get(k));
            for (int n = state.mN; n > 0; --n) {
                k = this.mBackPointers[n - 1][k];
                tags.add((String)ChainCrf.this.mTagList.get(k));
            }
            Collections.reverse(tags);
            ForwardPointer pointer = state.mForwardPointer;
            while (pointer != null) {
                tags.add((String)ChainCrf.this.mTagList.get(pointer.mK));
                pointer = pointer.mPointer;
            }
            return new ScoredTagging(this.mTokens, tags, state.score() - this.mLogZ);
        }
    }
}

