/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.parser.shiftreduce;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.HasTag;
import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.ling.Label;
import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.ling.Word;
import edu.stanford.nlp.parser.common.ArgUtils;
import edu.stanford.nlp.parser.common.ParserConstraint;
import edu.stanford.nlp.parser.common.ParserGrammar;
import edu.stanford.nlp.parser.common.ParserQuery;
import edu.stanford.nlp.parser.common.ParserUtils;
import edu.stanford.nlp.parser.lexparser.BinaryHeadFinder;
import edu.stanford.nlp.parser.lexparser.EvaluateTreebank;
import edu.stanford.nlp.parser.lexparser.Options;
import edu.stanford.nlp.parser.lexparser.TreeBinarizer;
import edu.stanford.nlp.parser.lexparser.TreebankLangParserParams;
import edu.stanford.nlp.parser.metrics.Eval;
import edu.stanford.nlp.parser.metrics.ParserQueryEval;
import edu.stanford.nlp.parser.shiftreduce.BinaryTransition;
import edu.stanford.nlp.parser.shiftreduce.CombinationFeatureFactory;
import edu.stanford.nlp.parser.shiftreduce.CompoundUnaryTransition;
import edu.stanford.nlp.parser.shiftreduce.CreateTransitionSequence;
import edu.stanford.nlp.parser.shiftreduce.FeatureFactory;
import edu.stanford.nlp.parser.shiftreduce.Oracle;
import edu.stanford.nlp.parser.shiftreduce.OracleTransition;
import edu.stanford.nlp.parser.shiftreduce.ReorderingOracle;
import edu.stanford.nlp.parser.shiftreduce.ShiftReduceOptions;
import edu.stanford.nlp.parser.shiftreduce.ShiftReduceParserQuery;
import edu.stanford.nlp.parser.shiftreduce.ShiftReduceTrainOptions;
import edu.stanford.nlp.parser.shiftreduce.ShiftReduceUtils;
import edu.stanford.nlp.parser.shiftreduce.State;
import edu.stanford.nlp.parser.shiftreduce.Transition;
import edu.stanford.nlp.parser.shiftreduce.TreeRecorder;
import edu.stanford.nlp.parser.shiftreduce.UnaryTransition;
import edu.stanford.nlp.parser.shiftreduce.Weight;
import edu.stanford.nlp.stats.IntCounter;
import edu.stanford.nlp.tagger.common.Tagger;
import edu.stanford.nlp.trees.BasicCategoryTreeTransformer;
import edu.stanford.nlp.trees.CompositeTreeTransformer;
import edu.stanford.nlp.trees.LabeledScoredTreeNode;
import edu.stanford.nlp.trees.MemoryTreebank;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreeCoreAnnotations;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.trees.TreebankLanguagePack;
import edu.stanford.nlp.trees.Trees;
import edu.stanford.nlp.util.ArrayUtils;
import edu.stanford.nlp.util.CollectionUtils;
import edu.stanford.nlp.util.Function;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.ReflectionLoading;
import edu.stanford.nlp.util.Scored;
import edu.stanford.nlp.util.ScoredComparator;
import edu.stanford.nlp.util.ScoredObject;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.Timing;
import edu.stanford.nlp.util.Triple;
import edu.stanford.nlp.util.concurrent.MulticoreWrapper;
import edu.stanford.nlp.util.concurrent.ThreadsafeProcessor;
import java.io.FileFilter;
import java.io.IOException;
import java.io.Serializable;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Random;
import java.util.Set;

public class ShiftReduceParser
extends ParserGrammar
implements Serializable {
    // Every transition the model can score; methods in this class look transitions up by position in this index.
    Index<Transition> transitionIndex = new HashIndex<Transition>();
    // Learned weights, keyed by feature string.
    Map<String, Weight> featureWeights = Generics.newHashMap();
    // Options this parser was constructed with; also returned by getOp().
    ShiftReduceOptions op;
    // Turns a parser State into the list of feature strings that are looked up in featureWeights.
    FeatureFactory featureFactory;
    // Candidate labels tried by findEmergencyTransition when it must satisfy a constraint.
    // NOTE(review): presumably filled from findKnownStates(...) during training - the assignment is not in this part of the file.
    Set<String> knownStates;
    // Labels accepted for the single tree left on the stack once the sentence is consumed (see findEmergencyTransition).
    Set<String> rootStates;
    // NOTE(review): presumably the root labels that never occur below the root (see findRootOnlyStates) - confirm at the assignment site.
    Set<String> rootOnlyStates;
    // Extra flags appended by defaultCoreNLPFlags() when the trained beam size is greater than 1.
    private static final String[] BEAM_FLAGS = new String[]{"-beamSize", "4"};
    // Two-decimal format used when printing model scores.
    private static final NumberFormat NF = new DecimalFormat("0.00");
    // Four-digit zero-padded format. NOTE(review): not referenced in this part of the file; name suggests it is used to build file names.
    private static final NumberFormat FILENAME = new DecimalFormat("0000");
    // NOTE(review): not referenced in this part of the file; holds the single "-forceTags" flag.
    static final String[] FORCE_TAGS = new String[]{"-forceTags"};
    private static final long serialVersionUID = 1L;

    /**
     * Creates an untrained parser whose feature factory is loaded by reflection
     * from {@code op.featureFactoryClass}.
     *
     * <p>The option is a {@code ;}-separated list of factory specs. Each spec is
     * either a class name, or {@code ClassName(arg)} to pass a single string
     * argument to the factory's constructor. One spec yields that factory
     * directly; several specs are wrapped in a {@link CombinationFeatureFactory}.
     *
     * <p>The {@code ClassName(arg)} form is honoured for a single spec as well;
     * previously it was only parsed when more than one spec was given, so a lone
     * spec with an argument was handed to reflection as a literal class name.
     *
     * @param op options for this parser; {@code op.featureFactoryClass} must be non-null
     */
    public ShiftReduceParser(ShiftReduceOptions op) {
        this.op = op;
        String[] classes = op.featureFactoryClass.split(";");
        if (classes.length == 1) {
            this.featureFactory = ShiftReduceParser.loadFeatureFactory(classes[0]);
        } else {
            FeatureFactory[] factories = new FeatureFactory[classes.length];
            for (int i = 0; i < classes.length; ++i) {
                factories[i] = ShiftReduceParser.loadFeatureFactory(classes[i]);
            }
            this.featureFactory = new CombinationFeatureFactory(factories);
        }
    }

    /**
     * Instantiates one feature factory from a spec of the form {@code ClassName}
     * or {@code ClassName(arg)}.
     */
    private static FeatureFactory loadFeatureFactory(String spec) {
        int paren = spec.indexOf("(");
        if (paren < 0) {
            return (FeatureFactory)ReflectionLoading.loadByReflection(spec, new Object[0]);
        }
        // Everything between the first '(' and the final character (expected to be ')') is the argument.
        String arg = spec.substring(paren + 1, spec.length() - 1);
        return (FeatureFactory)ReflectionLoading.loadByReflection(spec.substring(0, paren), arg);
    }

    /**
     * Internal constructor used when copying or averaging models: reuses an
     * already-built feature factory instead of loading one by reflection.
     *
     * @param op      options shared with the source model
     * @param factory feature factory shared with the source model
     */
    private ShiftReduceParser(ShiftReduceOptions op, FeatureFactory factory) {
        this.featureFactory = factory;
        this.op = op;
    }

    /**
     * Returns the options object this parser was constructed with.
     */
    @Override
    public Options getOp() {
        return op;
    }

    /**
     * Returns the language-specific parser parameters held by this parser's options.
     */
    @Override
    public TreebankLangParserParams getTLPParams() {
        return op.tlpParams;
    }

    /**
     * Returns the treebank language pack supplied by {@link #getTLPParams()}.
     */
    @Override
    public TreebankLanguagePack treebankLanguagePack() {
        TreebankLangParserParams params = getTLPParams();
        return params.treebankLanguagePack();
    }

    /**
     * Returns the language's default CoreNLP flags, with {@code BEAM_FLAGS}
     * appended when the trained beam size is greater than 1.
     */
    @Override
    public String[] defaultCoreNLPFlags() {
        boolean trainedWithBeam = op.trainOptions().beamSize > 1;
        String[] languageFlags = getTLPParams().defaultCoreNLPFlags();
        if (!trainedWithBeam) {
            return languageFlags;
        }
        return ArrayUtils.concatenate(languageFlags, BEAM_FLAGS);
    }

    /**
     * Always returns {@code true}.
     * NOTE(review): presumably signals to callers that input must be POS-tagged
     * before parsing - confirm against the ParserGrammar contract.
     */
    @Override
    public boolean requiresTags() {
        return true;
    }

    /**
     * Returns a new parser that shares this parser's options and feature
     * factory but holds its own copy of the transition index, state sets and
     * weights (see {@link #copyWeights(ShiftReduceParser)}).
     */
    private ShiftReduceParser deepCopy() {
        ShiftReduceParser duplicate = new ShiftReduceParser(op, featureFactory);
        duplicate.copyWeights(this);
        return duplicate;
    }

    /**
     * Replaces this parser's model state with a copy of {@code other}'s.
     *
     * <p>The transition index is cleared and refilled, the three state sets are
     * replaced by unmodifiable copies, and every weight is duplicated through
     * the {@code Weight} copy constructor.
     *
     * @param other the model to copy from
     */
    public void copyWeights(ShiftReduceParser other) {
        transitionIndex.clear();
        for (Transition t : other.transitionIndex) {
            transitionIndex.add(t);
        }

        knownStates = Collections.unmodifiableSet(Generics.newHashSet(other.knownStates));
        rootStates = Collections.unmodifiableSet(Generics.newHashSet(other.rootStates));
        rootOnlyStates = Collections.unmodifiableSet(Generics.newHashSet(other.rootOnlyStates));

        featureWeights.clear();
        for (Map.Entry<String, Weight> entry : other.featureWeights.entrySet()) {
            featureWeights.put(entry.getKey(), new Weight(entry.getValue()));
        }
    }

    /**
     * Averages a collection of scored models, logging the scores to stderr first.
     *
     * @param scoredModels the models to average, each paired with a score
     * @return a new parser produced by {@link #averageModels(Collection)}
     * @throws IllegalArgumentException if {@code scoredModels} is empty
     */
    public static ShiftReduceParser averageScoredModels(Collection<ScoredObject<ShiftReduceParser>> scoredModels) {
        int count = scoredModels.size();
        if (count == 0) {
            throw new IllegalArgumentException("Cannot average empty models");
        }

        System.err.print("Averaging " + count + " models with scores");
        for (ScoredObject<ShiftReduceParser> scored : scoredModels) {
            System.err.print(" " + NF.format(scored.score()));
        }
        System.err.println();

        // Strips the score off each entry, leaving just the model.
        Function<ScoredObject<ShiftReduceParser>, ShiftReduceParser> unwrap = new Function<ScoredObject<ShiftReduceParser>, ShiftReduceParser>(){

            @Override
            public ShiftReduceParser apply(ScoredObject<ShiftReduceParser> object) {
                return object.object();
            }
        };
        List<ShiftReduceParser> unwrapped = CollectionUtils.transformAsList(scoredModels, unwrap);
        return ShiftReduceParser.averageModels(unwrapped);
    }

    /**
     * Builds a new parser whose weights are the arithmetic mean of the given
     * models' weights.
     *
     * <p>Options, feature factory, transition index and state sets are taken
     * from the first model. A feature missing from some model contributes zero
     * for that model, i.e. the sum is still divided by the total number of
     * models.
     *
     * @param models the models to average; all must have equal transition indices
     * @return a new, independent parser holding the averaged weights
     * @throws IllegalArgumentException if {@code models} is empty, or if any
     *         model's transition index differs from the first model's
     */
    public static ShiftReduceParser averageModels(Collection<ShiftReduceParser> models) {
        // Fail with the same explicit error as averageScoredModels rather than a
        // bare NoSuchElementException from iterator().next().
        if (models.isEmpty()) {
            throw new IllegalArgumentException("Cannot average empty models");
        }
        ShiftReduceParser firstModel = models.iterator().next();
        ShiftReduceParser copy = new ShiftReduceParser(firstModel.op, firstModel.featureFactory);
        for (Transition transition : firstModel.transitionIndex) {
            copy.transitionIndex.add(transition);
        }
        copy.knownStates = Collections.unmodifiableSet(Generics.newHashSet(firstModel.knownStates));
        copy.rootStates = Collections.unmodifiableSet(Generics.newHashSet(firstModel.rootStates));
        copy.rootOnlyStates = Collections.unmodifiableSet(Generics.newHashSet(firstModel.rootOnlyStates));
        for (ShiftReduceParser model : models) {
            if (!model.transitionIndex.equals(copy.transitionIndex)) {
                throw new IllegalArgumentException("Can only average models with the same transition index");
            }
        }
        float scale = 1.0f / (float)models.size();
        // Single pass: models are visited in collection order, so each feature's
        // contributions are accumulated in the same order as before.
        for (ShiftReduceParser model : models) {
            for (Map.Entry<String, Weight> entry : model.featureWeights.entrySet()) {
                Weight averaged = copy.featureWeights.get(entry.getKey());
                if (averaged == null) {
                    averaged = new Weight();
                    copy.featureWeights.put(entry.getKey(), averaged);
                }
                averaged.addScaled(entry.getValue(), scale);
            }
        }
        return copy;
    }

    /**
     * Creates a fresh query object bound to this parser.
     */
    @Override
    public ParserQuery parserQuery() {
        ShiftReduceParserQuery query = new ShiftReduceParserQuery(this);
        return query;
    }

    /**
     * Parses one sentence.
     *
     * @param sentence the words to parse
     * @return the best parse if parsing succeeds, otherwise the fallback tree
     *         produced by {@code ParserUtils.xTree(sentence)}
     */
    @Override
    public Tree apply(List<? extends HasWord> sentence) {
        ShiftReduceParserQuery query = new ShiftReduceParserQuery(this);
        boolean parsed = query.parse(sentence);
        return parsed ? query.getBestParse() : ParserUtils.xTree(sentence);
    }

    /**
     * Calls {@code condense()} on every weight and drops the features whose
     * weight reports a size of zero afterwards.
     */
    public void condenseFeatures() {
        for (Iterator<Map.Entry<String, Weight>> it = featureWeights.entrySet().iterator(); it.hasNext(); ) {
            Weight weight = it.next().getValue();
            weight.condense();
            if (weight.size() == 0) {
                it.remove();
            }
        }
    }

    /**
     * Removes every feature whose name is not contained in {@code keep}.
     *
     * @param keep names of the features to retain
     */
    public void filterFeatures(Set<String> keep) {
        for (Iterator<String> it = featureWeights.keySet().iterator(); it.hasNext(); ) {
            String name = it.next();
            if (!keep.contains(name)) {
                it.remove();
            }
        }
    }

    /**
     * Prints model-size statistics to stderr: feature count, summed weight
     * sizes, summed feature-name length and transition count.
     */
    public void outputStats() {
        int totalWeights = 0;
        int totalNameLength = 0;
        for (Map.Entry<String, Weight> entry : featureWeights.entrySet()) {
            totalWeights += entry.getValue().size();
            totalNameLength += entry.getKey().length();
        }

        System.err.println("Number of known features: " + featureWeights.size());
        System.err.println("Number of non-zero weights: " + totalWeights);
        System.err.println("Total word length: " + totalNameLength);
        System.err.println("Number of transitions: " + transitionIndex.size());
    }

    /**
     * Returns an empty list: this parser contributes no extra evaluations.
     */
    @Override
    public List<Eval> getExtraEvals() {
        return Collections.<Eval>emptyList();
    }

    /**
     * Returns one {@link TreeRecorder} for each of the {@code recordBinarized}
     * and {@code recordDebinarized} test options that is set, binarized first.
     *
     * @return the recorders, or an empty list when neither option is set
     */
    @Override
    public List<ParserQueryEval> getParserQueryEvals() {
        boolean wantBinarized = op.testOptions().recordBinarized != null;
        boolean wantDebinarized = op.testOptions().recordDebinarized != null;
        if (!wantBinarized && !wantDebinarized) {
            return Collections.emptyList();
        }

        List<ParserQueryEval> recorders = Generics.newArrayList();
        if (wantBinarized) {
            recorders.add(new TreeRecorder(TreeRecorder.Mode.BINARIZED, op.testOptions().recordBinarized));
        }
        if (wantDebinarized) {
            recorders.add(new TreeRecorder(TreeRecorder.Mode.DEBINARIZED, op.testOptions().recordDebinarized));
        }
        return recorders;
    }

    /**
     * Picks a fallback transition for a state, built directly rather than
     * looked up in the transition index.
     *
     * <p>The checks run in this order and the first that applies wins:
     * <ol>
     *   <li>empty stack: no transition;</li>
     *   <li>a constraint spans exactly the top tree but is not matched by it:
     *       a unary transition to the first known state the constraint's
     *       pattern accepts;</li>
     *   <li>the top tree is temporary and is either alone or sits on another
     *       temporary tree: a unary transition to its label minus the first
     *       character;</li>
     *   <li>a single tree remains, all tokens are consumed and its label is not
     *       a root state: a unary transition to some root state;</li>
     *   <li>a single tree remains otherwise: no transition;</li>
     *   <li>the top, or else the second, tree is temporary: a right, or else
     *       left, binary transition to that label minus the first character.</li>
     * </ol>
     *
     * @param state       the current parser state
     * @param constraints constraints to honour, or {@code null} for none
     * @return the fallback transition, or {@code null} if none applies
     */
    public Transition findEmergencyTransition(State state, List<ParserConstraint> constraints) {
        if (state.stack.size() == 0) {
            return null;
        }

        if (constraints != null) {
            Tree top = state.stack.peek();
            for (ParserConstraint constraint : constraints) {
                boolean spansTop = ShiftReduceUtils.leftIndex(top) == constraint.start && ShiftReduceUtils.rightIndex(top) == constraint.end - 1;
                if (!spansTop || ShiftReduceUtils.constraintMatchesTreeTop(top, constraint)) {
                    continue;
                }
                for (String label : this.knownStates) {
                    if (constraint.state.matcher(label).matches()) {
                        return this.emergencyUnaryTransition(label);
                    }
                }
            }
        }

        if (ShiftReduceUtils.isTemporary(state.stack.peek()) && (state.stack.size() == 1 || ShiftReduceUtils.isTemporary(state.stack.pop().peek()))) {
            return this.emergencyUnaryTransition(state.stack.peek().value().substring(1));
        }

        boolean sentenceConsumed = state.stack.size() == 1 && state.tokenPosition >= state.sentence.size();
        if (sentenceConsumed && !this.rootStates.contains(state.stack.peek().value())) {
            String root = this.rootStates.iterator().next();
            return this.emergencyUnaryTransition(root);
        }

        if (state.stack.size() == 1) {
            return null;
        }

        if (ShiftReduceUtils.isTemporary(state.stack.peek())) {
            return new BinaryTransition(state.stack.peek().value().substring(1), BinaryTransition.Side.RIGHT);
        }
        if (ShiftReduceUtils.isTemporary(state.stack.pop().peek())) {
            return new BinaryTransition(state.stack.pop().peek().value().substring(1), BinaryTransition.Side.LEFT);
        }
        return null;
    }

    /**
     * Builds a unary transition to {@code label}, in compound form when the
     * {@code compoundUnaries} option is set.
     */
    private Transition emergencyUnaryTransition(String label) {
        if (this.op.compoundUnaries) {
            return new CompoundUnaryTransition(Collections.singletonList(label), false);
        }
        return new UnaryTransition(label, false);
    }

    /**
     * Convenience wrapper around
     * {@link #findHighestScoringTransitions(State, List, boolean, int, List)}
     * that asks for a single transition and passes no constraints.
     *
     * @param state        the current parser state
     * @param features     feature strings extracted from {@code state}
     * @param requireLegal whether transitions illegal in {@code state} are skipped
     * @return the transition's index paired with its score, or {@code null}
     *         when no transition qualifies
     */
    public ScoredObject<Integer> findHighestScoringTransition(State state, List<String> features, boolean requireLegal) {
        Collection<ScoredObject<Integer>> best = findHighestScoringTransitions(state, features, requireLegal, 1, null);
        return best.size() == 0 ? null : best.iterator().next();
    }

    /**
     * Scores every transition against the given features and returns the
     * {@code numTransitions} best.
     *
     * <p>Each known feature's weight adds its contribution to a per-transition
     * score array; features without a weight are ignored. A bounded min-queue
     * then retains the highest-scoring candidates.
     *
     * <p>The returned collection is the priority queue itself, so iterating it
     * does not yield the entries in sorted order.
     *
     * @param state          the current parser state
     * @param features       feature strings extracted from {@code state}
     * @param requireLegal   when {@code true}, transitions that are not legal in
     *                       {@code state} under {@code constraints} are skipped
     * @param numTransitions the maximum number of transitions to return
     * @param constraints    constraints passed to the legality check; may be {@code null}
     * @return up to {@code numTransitions} transition indices paired with their scores
     */
    public Collection<ScoredObject<Integer>> findHighestScoringTransitions(State state, List<String> features, boolean requireLegal, int numTransitions, List<ParserConstraint> constraints) {
        float[] scores = new float[this.transitionIndex.size()];
        for (String feature : features) {
            Weight weight = this.featureWeights.get(feature);
            if (weight != null) {
                weight.score(scores);
            }
        }
        // Element type must be ScoredObject<Integer>: a PriorityQueue<Scored> is not a
        // Collection<ScoredObject<Integer>>, since parameterized types are invariant.
        PriorityQueue<ScoredObject<Integer>> queue = new PriorityQueue<ScoredObject<Integer>>(numTransitions + 1, ScoredComparator.ASCENDING_COMPARATOR);
        for (int i = 0; i < scores.length; ++i) {
            if (requireLegal && !this.transitionIndex.get(i).isLegal(state, constraints)) {
                continue;
            }
            queue.add(new ScoredObject<Integer>(i, scores[i]));
            // Evict the current lowest scorer once the queue exceeds the requested size.
            if (queue.size() > numTransitions) {
                queue.poll();
            }
        }
        return queue;
    }

    /**
     * Builds the initial parser state from a tree's tagged yield.
     *
     * @param tree the tree whose words and tags seed the state
     * @return the state built by {@link #initialStateFromTaggedSentence(List)}
     */
    public static State initialStateFromGoldTagTree(Tree tree) {
        List<? extends HasWord> goldTagged = tree.taggedYield();
        return ShiftReduceParser.initialStateFromTaggedSentence(goldTagged);
    }

    /**
     * Builds the initial parser state from a tagged sentence.
     *
     * <p>Every word becomes a two-node preterminal (tag node over word node)
     * with 1-based indices, and both labels get head-word and head-tag
     * annotations pointing at those nodes. A word that already is a
     * {@link CoreLabel} is used as the word label directly, not copied, so its
     * index and head annotations are overwritten.
     *
     * @param words the sentence; each element must carry a tag
     * @return a state holding one preterminal per word
     * @throws IllegalArgumentException if a non-CoreLabel word does not
     *         implement {@link HasTag}, or if any word's tag is {@code null}
     */
    public static State initialStateFromTaggedSentence(List<? extends HasWord> words) {
        List<Tree> preterminals = Generics.newArrayList();
        int position = 0;
        for (HasWord token : words) {
            ++position;

            CoreLabel wordLabel;
            String tag;
            if (token instanceof CoreLabel) {
                wordLabel = (CoreLabel)token;
                tag = wordLabel.tag();
            } else {
                wordLabel = new CoreLabel();
                wordLabel.setValue(token.word());
                wordLabel.setWord(token.word());
                if (!(token instanceof HasTag)) {
                    throw new IllegalArgumentException("Expected tagged words");
                }
                tag = ((HasTag)((Object)token)).tag();
                wordLabel.setTag(tag);
            }
            if (tag == null) {
                throw new IllegalArgumentException("Input word not tagged");
            }

            CoreLabel tagLabel = new CoreLabel();
            tagLabel.setValue(tag);
            wordLabel.setIndex(position);
            tagLabel.setIndex(position);

            LabeledScoredTreeNode wordNode = new LabeledScoredTreeNode(wordLabel);
            LabeledScoredTreeNode tagNode = new LabeledScoredTreeNode(tagLabel);
            tagNode.addChild(wordNode);

            // Both labels point at the same head word and head tag nodes.
            wordLabel.set(TreeCoreAnnotations.HeadWordAnnotation.class, wordNode);
            wordLabel.set(TreeCoreAnnotations.HeadTagAnnotation.class, tagNode);
            tagLabel.set(TreeCoreAnnotations.HeadWordAnnotation.class, wordNode);
            tagLabel.set(TreeCoreAnnotations.HeadTagAnnotation.class, tagNode);

            preterminals.add(tagNode);
        }
        return new State(preterminals);
    }

    /**
     * Builds the options used for training.
     *
     * <p>Training defaults are applied first, then the optional language
     * parameters class is loaded, then {@code args} are applied on top. If the
     * random seed is still 0 afterwards a random one is generated and reported
     * on stderr.
     *
     * @param tlppClass name of the {@link TreebankLangParserParams} class to
     *                  load, or {@code null} to keep the default
     * @param args      command-line style option flags
     * @return the configured options
     */
    public static ShiftReduceOptions buildTrainingOptions(String tlppClass, String[] args) {
        ShiftReduceOptions options = new ShiftReduceOptions();
        options.setOptions("-forceTags", "-debugOutputFrequency", "1", "-quietEvaluation");
        if (tlppClass != null) {
            options.tlpParams = (TreebankLangParserParams)ReflectionLoading.loadByReflection(tlppClass, new Object[0]);
        }
        options.setOptions(args);

        if (options.trainOptions.randomSeed != 0L) {
            return options;
        }
        options.trainOptions.randomSeed = new Random().nextLong();
        System.err.println("Random seed not set by options, using " + options.trainOptions.randomSeed);
        return options;
    }

    /**
     * Loads the trees under a path into a memory treebank, reporting progress
     * on stderr.
     *
     * @param treebankPath   path to load trees from
     * @param treebankFilter filter selecting which files to read
     * @return the loaded treebank
     */
    public Treebank readTreebank(String treebankPath, FileFilter treebankFilter) {
        System.err.println("Loading trees from " + treebankPath);
        MemoryTreebank loaded = op.tlpParams.memoryTreebank();
        loaded.loadPath(treebankPath, treebankFilter);
        int treeCount = ((Treebank)loaded).size();
        System.err.println("Read in " + treeCount + " trees from " + treebankPath);
        return loaded;
    }

    /**
     * Reads a treebank and converts it with
     * {@link #binarizeTreebank(Treebank, Options)}.
     *
     * @param treebankPath   path to load trees from
     * @param treebankFilter filter selecting which files to read
     * @return the binarized trees
     */
    public List<Tree> readBinarizedTreebank(String treebankPath, FileFilter treebankFilter) {
        Treebank raw = readTreebank(treebankPath, treebankFilter);
        List<Tree> result = ShiftReduceParser.binarizeTreebank(raw, op);
        System.err.println("Converted trees to binarized format");
        return result;
    }

    /**
     * Converts a treebank into the form the parser trains on.
     *
     * <p>Every tree is run through a {@link TreeBinarizer} and then a
     * {@link BasicCategoryTreeTransformer}; the result is converted to
     * CoreLabels, given head annotations by a {@link BinaryHeadFinder}, and has
     * its leaves indexed starting at 1.
     *
     * @param treebank the trees to convert
     * @param op       options supplying the head finder and language pack
     * @return the converted trees, in treebank iteration order
     */
    public static List<Tree> binarizeTreebank(Treebank treebank, Options op) {
        CompositeTreeTransformer pipeline = new CompositeTreeTransformer();
        pipeline.addTransformer(new TreeBinarizer(op.tlpParams.headFinder(), op.tlpParams.treebankLanguagePack(), false, false, 0, false, false, 0.0, false, true, true));
        pipeline.addTransformer(new BasicCategoryTreeTransformer(op.langpack()));
        Treebank transformed = treebank.transform(pipeline);

        BinaryHeadFinder headFinder = new BinaryHeadFinder(op.tlpParams.headFinder());
        List<Tree> result = Generics.newArrayList();
        for (Tree binarized : transformed) {
            Trees.convertToCoreLabels(binarized);
            binarized.percolateHeadAnnotations(headFinder);
            binarized.indexLeaves(1, true);
            result.add(binarized);
        }
        return result;
    }

    /**
     * Collects the labels of every non-temporary phrasal node in the given trees.
     *
     * @param binarizedTrees the trees to scan
     * @return an unmodifiable set of the labels found
     */
    public static Set<String> findKnownStates(List<Tree> binarizedTrees) {
        Set<String> collected = Generics.newHashSet();
        for (Tree root : binarizedTrees) {
            ShiftReduceParser.findKnownStates(root, collected);
        }
        return Collections.unmodifiableSet(collected);
    }

    /**
     * Recursively adds the label of every non-temporary node in {@code tree}
     * to {@code knownStates}, skipping leaves and preterminals.
     *
     * @param tree        the subtree to scan
     * @param knownStates the set that receives the labels
     */
    public static void findKnownStates(Tree tree, Set<String> knownStates) {
        boolean phrasal = !tree.isLeaf() && !tree.isPreTerminal();
        if (!phrasal) {
            return;
        }
        if (!ShiftReduceUtils.isTemporary(tree)) {
            knownStates.add(tree.value());
        }
        Tree[] kids = tree.children();
        for (int i = 0; i < kids.length; ++i) {
            ShiftReduceParser.findKnownStates(kids[i], knownStates);
        }
    }

    /**
     * Overwrites the preterminal labels of a tree with the tags the tagger
     * assigns to its words.
     *
     * @param tree   the tree to retag in place
     * @param tagger the tagger to run over the tree's yield
     * @throws AssertionError if the tagger returns a different number of tags
     *         than the tree has preterminals
     */
    public static void redoTags(Tree tree, Tagger tagger) {
        ArrayList<Word> yield = tree.yieldWords();
        List<TaggedWord> retagged = tagger.apply((List<? extends HasWord>)yield);
        List<Label> preterminals = tree.preTerminalYield();
        int count = preterminals.size();
        if (count != retagged.size()) {
            throw new AssertionError((Object)"Tags are not the same size");
        }
        int i = 0;
        while (i < preterminals.size()) {
            preterminals.get(i).setValue(retagged.get(i).tag());
            ++i;
        }
    }

    /**
     * Retags every tree in the list, either inline or through a
     * {@link MulticoreWrapper}.
     *
     * @param trees    the trees to retag in place
     * @param tagger   the tagger to use
     * @param nThreads exactly 1 retags on the calling thread; any other value
     *                 is handed to the multicore wrapper
     */
    public static void redoTags(List<Tree> trees, Tagger tagger, int nThreads) {
        if (nThreads == 1) {
            for (Tree current : trees) {
                ShiftReduceParser.redoTags(current, tagger);
            }
            return;
        }

        MulticoreWrapper<Tree, Tree> pool = new MulticoreWrapper<Tree, Tree>(nThreads, new RetagProcessor(tagger));
        for (Tree current : trees) {
            pool.put(current);
        }
        pool.join();
    }

    /**
     * Tells whether some state on the agenda has the same transitions as
     * {@code state}, as judged by {@code areTransitionsEqual}.
     */
    private static boolean findStateOnAgenda(Collection<State> agenda, State state) {
        boolean found = false;
        Iterator<State> candidates = agenda.iterator();
        while (!found && candidates.hasNext()) {
            found = candidates.next().areTransitionsEqual(state);
        }
        return found;
    }

    /**
     * Runs the parser over one training tree and records the weight updates
     * called for by its mistakes. No weights are modified here; updates are
     * only appended to {@code updates}.
     *
     * <p>The strategy depends on {@code op.trainOptions().trainingMethod}:
     * <ul>
     *   <li>{@code ORACLE}: follow the model's own legal predictions to the
     *       end, asking the oracle at each step whether the prediction is
     *       acceptable.</li>
     *   <li>{@code BEAM} / {@code REORDER_BEAM}: beam search alongside the gold
     *       sequence. Training on this tree stops when the gold state falls off
     *       the beam, unless under {@code REORDER_BEAM} the gold sequence can
     *       be reordered around the gold state's best transition.</li>
     *   <li>{@code GOLD} / {@code EARLY_TERMINATION} / {@code REORDER_ORACLE}:
     *       walk the gold sequence with greedy predictions. On a mistake,
     *       respectively: continue with the gold transition, stop, or try to
     *       reorder the remaining gold sequence around the prediction.</li>
     * </ul>
     * Any other method records nothing and returns zero counts.
     *
     * <p>NOTE(review): in the {@code Update} constructor the second argument
     * appears to be the transition to reward and the third the transition to
     * penalize, with -1 meaning "none" - confirm against the Update class.
     *
     * @param index           position of the tree in {@code binarizedTrees}
     * @param binarizedTrees  all training trees
     * @param transitionLists gold transition sequence for each tree, same indexing
     * @param updates         receives the updates produced for this tree
     * @param oracle          consulted only by the {@code ORACLE} method
     * @return the number of correct and the number of wrong predictions
     * @throws IllegalArgumentException if a beam method is used with a beam
     *         size of 0 or less
     */
    private Pair<Integer, Integer> trainTree(int index, List<Tree> binarizedTrees, List<List<Transition>> transitionLists, List<Update> updates, Oracle oracle) {
        int numCorrect = 0;
        int numWrong = 0;
        Tree tree = binarizedTrees.get(index);
        ShiftReduceTrainOptions.TrainingMethod method = this.op.trainOptions().trainingMethod;

        ReorderingOracle reorderer = null;
        if (method == ShiftReduceTrainOptions.TrainingMethod.REORDER_ORACLE || method == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
            reorderer = new ReorderingOracle(this.op);
        }

        if (method == ShiftReduceTrainOptions.TrainingMethod.ORACLE) {
            State state = ShiftReduceParser.initialStateFromGoldTagTree(tree);
            while (!state.isFinished()) {
                List<String> features = this.featureFactory.featurize(state);
                ScoredObject<Integer> prediction = this.findHighestScoringTransition(state, features, true);
                if (prediction == null) {
                    throw new AssertionError((Object)"Did not find a legal transition");
                }
                int predictedNum = prediction.object();
                Transition predicted = this.transitionIndex.get(predictedNum);
                OracleTransition gold = oracle.goldTransition(index, state);
                if (gold.isCorrect(predicted)) {
                    ++numCorrect;
                    if (gold.transition != null && !gold.transition.equals(predicted)) {
                        int transitionNum = this.transitionIndex.indexOf(gold.transition);
                        // A gold transition missing from the index yields no update, but the
                        // state must still advance. The previous `continue` skipped the
                        // apply() below, so the loop re-evaluated the same state and could
                        // not make progress.
                        if (transitionNum >= 0) {
                            updates.add(new Update(features, transitionNum, -1, 1.0f));
                        }
                    }
                } else {
                    ++numWrong;
                    int transitionNum = -1;
                    if (gold.transition != null) {
                        transitionNum = this.transitionIndex.indexOf(gold.transition);
                    }
                    updates.add(new Update(features, transitionNum, predictedNum, 1.0f));
                }
                state = predicted.apply(state);
            }
        } else if (method == ShiftReduceTrainOptions.TrainingMethod.BEAM || method == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
            int beamSize = this.op.trainOptions().beamSize;
            if (beamSize <= 0) {
                throw new IllegalArgumentException("Illegal beam size " + beamSize);
            }
            LinkedList<Transition> transitions = Generics.newLinkedList(transitionLists.get(index));
            PriorityQueue<State> agenda = new PriorityQueue<State>(beamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
            State goldState = ShiftReduceParser.initialStateFromGoldTagTree(tree);
            agenda.add(goldState);
            while (transitions.size() > 0) {
                Transition goldTransition = transitions.get(0);
                Transition highestScoringTransitionFromGoldState = null;
                double highestScoreFromGoldState = 0.0;
                PriorityQueue<State> newAgenda = new PriorityQueue<State>(beamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
                State highestScoringState = null;
                State highestCurrentState = null;
                for (State state : agenda) {
                    boolean isGoldState = method == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM && goldState.areTransitionsEqual(state);
                    List<String> features = this.featureFactory.featurize(state);
                    Collection<ScoredObject<Integer>> stateTransitions = this.findHighestScoringTransitions(state, features, true, beamSize, null);
                    for (ScoredObject<Integer> transition : stateTransitions) {
                        State newState = this.transitionIndex.get(transition.object()).apply(state, transition.score());
                        newAgenda.add(newState);
                        // Keep only the beamSize best states by evicting the current worst.
                        if (newAgenda.size() > beamSize) {
                            newAgenda.poll();
                        }
                        if (highestScoringState == null || highestScoringState.score() < newState.score()) {
                            highestScoringState = newState;
                            highestCurrentState = state;
                        }
                        // Track the model's favourite move out of the gold state (REORDER_BEAM only).
                        if (isGoldState && (highestScoringTransitionFromGoldState == null || transition.score() > highestScoreFromGoldState)) {
                            highestScoringTransitionFromGoldState = this.transitionIndex.get(transition.object());
                            highestScoreFromGoldState = transition.score();
                        }
                    }
                }
                // The gold state is no longer on the agenda, so there is nothing to reorder around.
                if (method == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM && highestScoringTransitionFromGoldState == null) {
                    break;
                }
                State newGoldState = goldTransition.apply(goldState, 0.0);
                if (!newGoldState.areTransitionsEqual(highestScoringState)) {
                    ++numWrong;
                    List<String> goldFeatures = this.featureFactory.featurize(goldState);
                    int lastTransition = this.transitionIndex.indexOf(highestScoringState.transitions.peek());
                    updates.add(new Update(this.featureFactory.featurize(highestCurrentState), -1, lastTransition, 1.0f));
                    updates.add(new Update(goldFeatures, this.transitionIndex.indexOf(goldTransition), -1, 1.0f));
                    if (method == ShiftReduceTrainOptions.TrainingMethod.BEAM) {
                        if (!ShiftReduceParser.findStateOnAgenda(newAgenda, newGoldState)) {
                            break;
                        }
                        transitions.remove(0);
                    } else if (method == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
                        if (ShiftReduceParser.findStateOnAgenda(newAgenda, newGoldState)) {
                            transitions.remove(0);
                        } else {
                            // Gold fell off the beam: try to continue from the gold state's
                            // own best transition with a reordered gold sequence.
                            if (!reorderer.reorder(goldState, highestScoringTransitionFromGoldState, transitions)) {
                                break;
                            }
                            newGoldState = highestScoringTransitionFromGoldState.apply(goldState);
                            if (!ShiftReduceParser.findStateOnAgenda(newAgenda, newGoldState)) {
                                break;
                            }
                        }
                    }
                } else {
                    ++numCorrect;
                    transitions.remove(0);
                }
                goldState = newGoldState;
                agenda = newAgenda;
            }
        } else if (method == ShiftReduceTrainOptions.TrainingMethod.REORDER_ORACLE || method == ShiftReduceTrainOptions.TrainingMethod.EARLY_TERMINATION || method == ShiftReduceTrainOptions.TrainingMethod.GOLD) {
            State state = ShiftReduceParser.initialStateFromGoldTagTree(tree);
            List<Transition> transitions = Generics.newLinkedList(transitionLists.get(index));
            boolean keepGoing = true;
            while (transitions.size() > 0 && keepGoing) {
                Transition transition = transitions.get(0);
                int transitionNum = this.transitionIndex.indexOf(transition);
                List<String> features = this.featureFactory.featurize(state);
                int predictedNum = this.findHighestScoringTransition(state, features, false).object();
                Transition predicted = this.transitionIndex.get(predictedNum);
                if (transitionNum == predictedNum) {
                    transitions.remove(0);
                    state = transition.apply(state);
                    ++numCorrect;
                    continue;
                }
                ++numWrong;
                updates.add(new Update(features, transitionNum, predictedNum, 1.0f));
                switch (this.op.trainOptions().trainingMethod) {
                    case EARLY_TERMINATION: {
                        keepGoing = false;
                        break;
                    }
                    case GOLD: {
                        transitions.remove(0);
                        state = transition.apply(state);
                        break;
                    }
                    case REORDER_ORACLE: {
                        keepGoing = reorderer.reorder(state, predicted, transitions);
                        if (keepGoing) {
                            state = predicted.apply(state);
                        }
                        break;
                    }
                    default: {
                        throw new IllegalArgumentException("Unexpected method " + (Object)((Object)this.op.trainOptions().trainingMethod));
                    }
                }
            }
        }
        return Pair.makePair(numCorrect, numWrong);
    }

    /**
     * Trains on one batch of trees and tallies the transition predictions made along the way.
     *
     * <p>Each tree is handed to {@code trainTree}, which appends the perceptron updates it
     * generates to {@code updates}. In the multithreaded path the trees are fed through
     * {@code wrapper}; the processor behind it is expected to append to the same
     * {@code updates} list (this is how {@code train} wires it up).
     *
     * @param indices positions (into {@code binarizedTrees} / {@code transitionLists}) of the trees in this batch
     * @param binarizedTrees all training trees
     * @param transitionLists gold transition sequence for each training tree
     * @param updates list that collects the updates produced for this batch
     * @param oracle passed through to {@code trainTree}; may be {@code null}
     * @param wrapper multithreaded driver, or {@code null} to train on the calling thread
     * @return the {@code updates} list, the number of correct predictions, and the number of wrong predictions
     */
    private Triple<List<Update>, Integer, Integer> trainBatch(List<Integer> indices, List<Tree> binarizedTrees, List<List<Transition>> transitionLists, List<Update> updates, Oracle oracle, MulticoreWrapper<Integer, Pair<Integer, Integer>> wrapper) {
        int numCorrect = 0;
        int numWrong = 0;
        // Train on the calling thread whenever no wrapper is available. Checking only
        // trainingThreads == 1 is not enough: a non-positive setting resolves to the number of
        // available processors, which can be 1, in which case no wrapper is ever created and the
        // multithreaded branch would dereference null.
        if (wrapper == null || this.op.trainOptions.trainingThreads == 1) {
            for (Integer index : indices) {
                Pair<Integer, Integer> count = this.trainTree(index, binarizedTrees, transitionLists, updates, oracle);
                numCorrect += count.first();
                numWrong += count.second();
            }
        } else {
            for (Integer index : indices) {
                wrapper.put(index);
            }
            // NOTE(review): join(false) presumably waits for the queued work while keeping the
            // worker threads alive for the next batch -- confirm against MulticoreWrapper.
            wrapper.join(false);
            while (wrapper.peek()) {
                Pair<Integer, Integer> result = wrapper.poll();
                numCorrect += result.first();
                numWrong += result.second();
            }
        }
        return new Triple<List<Update>, Integer, Integer>(updates, numCorrect, numWrong);
    }

    /**
     * Gathers the label of the top node of every tree.
     *
     * @param trees trees whose root labels should be collected
     * @return an unmodifiable set holding each distinct root label
     */
    private static Set<String> findRootStates(List<Tree> trees) {
        Set<String> observedRootLabels = Generics.newHashSet();
        for (Tree root : trees) {
            String rootLabel = root.value();
            observedRootLabels.add(rootLabel);
        }
        return Collections.unmodifiableSet(observedRootLabels);
    }

    /**
     * Determines which root labels never appear anywhere below the root of any tree.
     *
     * @param trees trees to scan
     * @param rootStates labels known to occur at the root
     * @return an unmodifiable set of the labels in {@code rootStates} that were not seen on any non-root node
     */
    private static Set<String> findRootOnlyStates(List<Tree> trees, Set<String> rootStates) {
        // Assume every root label is root-only, then strike out each label met deeper in a tree.
        Set<String> candidates = Generics.newHashSet(rootStates);
        for (Tree root : trees) {
            Tree[] subtrees = root.children();
            for (int i = 0; i < subtrees.length; ++i) {
                ShiftReduceParser.findRootOnlyStatesHelper(subtrees[i], rootStates, candidates);
            }
        }
        return Collections.unmodifiableSet(candidates);
    }

    /**
     * Removes the label of {@code tree} and of every node beneath it from {@code rootOnlyStates}.
     *
     * @param tree non-root subtree to walk
     * @param rootStates labels known to occur at the root; only forwarded to the recursive calls
     * @param rootOnlyStates mutable candidate set that is pruned in place
     */
    private static void findRootOnlyStatesHelper(Tree tree, Set<String> rootStates, Set<String> rootOnlyStates) {
        String label = tree.value();
        rootOnlyStates.remove(label);
        Tree[] subtrees = tree.children();
        for (int i = 0; i < subtrees.length; ++i) {
            ShiftReduceParser.findRootOnlyStatesHelper(subtrees[i], rootStates, rootOnlyStates);
        }
    }

    /**
     * Trains this parser's feature weights from the given treebanks.
     *
     * <p>The method reads and binarizes the training trees, optionally retags them, records the
     * known / root / root-only states, converts the trees into gold transition sequences and then
     * runs up to {@code trainingIterations} passes of batched updates. After each pass it
     * optionally evaluates on the dev treebank, tracks the best score, stops early once the score
     * has stalled for {@code stalledIterationLimit} iterations, and optionally saves an
     * intermediate model. Finally it may replace the weights with an average of the best models,
     * drops features below the frequency cutoff, and condenses the feature table.
     *
     * @param trainTreebankPath (path, filter) pairs describing the training treebanks
     * @param devTreebankPath (path, filter) of the dev treebank, or {@code null} to skip dev evaluation
     * @param serializedPath base path for intermediate models, or {@code null} to not save them;
     *        the last 7 characters are stripped (assumed to be a {@code .ser.gz} suffix)
     * @param allowedFeatures if non-null, only features in this set receive weight updates
     */
    private void train(List<Pair<String, FileFilter>> trainTreebankPath, Pair<String, FileFilter> devTreebankPath, String serializedPath, Set<String> allowedFeatures) {
        System.err.println("Training method: " + this.op.trainOptions().trainingMethod);
        ArrayList<Tree> binarizedTrees = Generics.newArrayList();
        for (Pair<String, FileFilter> treebank : trainTreebankPath) {
            binarizedTrees.addAll(this.readBinarizedTreebank(treebank.first(), treebank.second()));
        }
        // A non-positive thread setting resolves to the number of available processors.
        int nThreads = this.op.trainOptions.trainingThreads;
        nThreads = nThreads <= 0 ? Runtime.getRuntime().availableProcessors() : nThreads;
        Tagger tagger = null;
        if (this.op.testOptions.preTag) {
            Timing retagTimer = new Timing();
            tagger = Tagger.loadModel(this.op.testOptions.taggerSerializedFile);
            ShiftReduceParser.redoTags(binarizedTrees, tagger, nThreads);
            retagTimer.done("Retagging");
        }
        this.knownStates = ShiftReduceParser.findKnownStates(binarizedTrees);
        this.rootStates = ShiftReduceParser.findRootStates(binarizedTrees);
        this.rootOnlyStates = ShiftReduceParser.findRootOnlyStates(binarizedTrees, this.rootStates);
        System.err.println("Known states: " + this.knownStates);
        System.err.println("States which occur at the root: " + this.rootStates);
        // Report the root-only set here (previously this line repeated rootStates).
        System.err.println("States which only occur at the root: " + this.rootOnlyStates);
        Timing transitionTimer = new Timing();
        List<List<Transition>> transitionLists = CreateTransitionSequence.createTransitionSequences(binarizedTrees, this.op.compoundUnaries, this.rootStates, this.rootOnlyStates);
        for (List<Transition> transitions : transitionLists) {
            this.transitionIndex.addAll(transitions);
        }
        transitionTimer.done("Converting trees into transition lists");
        System.err.println("Number of transitions: " + this.transitionIndex.size());
        Random random = new Random(this.op.trainOptions.randomSeed);
        Treebank devTreebank = null;
        if (devTreebankPath != null) {
            devTreebank = this.readTreebank(devTreebankPath.first(), devTreebankPath.second());
        }
        double bestScore = 0.0;
        int bestIteration = 0;
        // Ascending queue capped at averagedModels entries: polling evicts the lowest-scoring model.
        PriorityQueue<Scored> bestModels = null;
        if (this.op.trainOptions().averagedModels > 0) {
            bestModels = new PriorityQueue<Scored>(this.op.trainOptions().averagedModels + 1, ScoredComparator.ASCENDING_COMPARATOR);
        }
        ArrayList<Integer> indices = Generics.newArrayList();
        for (int i = 0; i < binarizedTrees.size(); ++i) {
            indices.add(i);
        }
        Oracle oracle = null;
        if (this.op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.ORACLE) {
            oracle = new Oracle(binarizedTrees, this.op.compoundUnaries, this.rootStates);
        }
        // Declared as List so the synchronized wrapper can be assigned when training is multithreaded.
        List<Update> updates = Generics.newArrayList();
        MulticoreWrapper<Integer, Pair<Integer, Integer>> wrapper = null;
        if (nThreads != 1) {
            updates = Collections.synchronizedList(updates);
            wrapper = new MulticoreWrapper<Integer, Pair<Integer, Integer>>(this.op.trainOptions.trainingThreads, new TrainTreeProcessor(binarizedTrees, transitionLists, updates, oracle));
        }
        IntCounter<String> featureFrequencies = null;
        if (this.op.trainOptions().featureFrequencyCutoff > 1) {
            featureFrequencies = new IntCounter<String>();
        }
        for (int iteration = 1; iteration <= this.op.trainOptions.trainingIterations; ++iteration) {
            Timing trainingTimer = new Timing();
            int numCorrect = 0;
            int numWrong = 0;
            Collections.shuffle(indices, random);
            for (int start = 0; start < indices.size(); start += this.op.trainOptions.batchSize) {
                int end = Math.min(start + this.op.trainOptions.batchSize, indices.size());
                Triple<List<Update>, Integer, Integer> result = this.trainBatch(indices.subList(start, end), binarizedTrees, transitionLists, updates, oracle, wrapper);
                numCorrect += result.second;
                numWrong += result.third;
                // Apply the batch: reward the gold transition and penalize the predicted one
                // for every (allowed) feature that fired.
                for (Update update : result.first) {
                    for (String feature : update.features) {
                        if (allowedFeatures != null && !allowedFeatures.contains(feature)) continue;
                        Weight weights = this.featureWeights.get(feature);
                        if (weights == null) {
                            weights = new Weight();
                            this.featureWeights.put(feature, weights);
                        }
                        weights.updateWeight(update.goldTransition, update.delta);
                        weights.updateWeight(update.predictedTransition, -update.delta);
                        if (featureFrequencies == null) continue;
                        featureFrequencies.incrementCount(feature, update.goldTransition >= 0 && update.predictedTransition >= 0 ? 2 : 1);
                    }
                }
                // The same list is reused for the next batch.
                updates.clear();
            }
            trainingTimer.done("Iteration " + iteration);
            System.err.println("While training, got " + numCorrect + " transitions correct and " + numWrong + " transitions wrong");
            this.outputStats();
            double labelF1 = 0.0;
            if (devTreebank != null) {
                EvaluateTreebank evaluator = new EvaluateTreebank(this.op, null, this, tagger);
                evaluator.testOnTreebank(devTreebank);
                labelF1 = evaluator.getLBScore();
                System.err.println("Label F1 after " + iteration + " iterations: " + labelF1);
                if (labelF1 > bestScore) {
                    System.err.println("New best dev score (previous best " + bestScore + ")");
                    bestScore = labelF1;
                    bestIteration = iteration;
                } else {
                    System.err.println("Failed to improve for " + (iteration - bestIteration) + " iteration(s) on previous best score of " + bestScore);
                    if (this.op.trainOptions.stalledIterationLimit > 0 && iteration - bestIteration >= this.op.trainOptions.stalledIterationLimit) {
                        System.err.println("Failed to improve for too long, stopping training");
                        break;
                    }
                }
                System.err.println();
                if (bestModels != null) {
                    bestModels.add(new ScoredObject<ShiftReduceParser>(this.deepCopy(), labelF1));
                    if (bestModels.size() > this.op.trainOptions().averagedModels) {
                        bestModels.poll();
                    }
                }
            }
            if (this.op.trainOptions().saveIntermediateModels && serializedPath != null && this.op.trainOptions.debugOutputFrequency > 0) {
                // Strips the last 7 characters of the path (assumed ".ser.gz") before appending the iteration tag.
                String tempName = serializedPath.substring(0, serializedPath.length() - 7) + "-" + FILENAME.format(iteration) + "-" + NF.format(labelF1) + ".ser.gz";
                this.saveModel(tempName);
            }
        }
        if (wrapper != null) {
            wrapper.join();
        }
        if (bestModels != null) {
            if (this.op.trainOptions().cvAveragedModels && devTreebank != null) {
                // Drain the queue (ascending) and reverse so the best model comes first, then
                // pick the prefix length whose average scores best on the dev set.
                ArrayList<Scored> models = Generics.newArrayList();
                while (bestModels.size() > 0) {
                    models.add(bestModels.poll());
                }
                Collections.reverse(models);
                double bestF1 = 0.0;
                int bestSize = 0;
                for (int i = 1; i <= models.size(); ++i) {
                    System.err.println("Testing with " + i + " models averaged together");
                    ShiftReduceParser parser = ShiftReduceParser.averageScoredModels(models.subList(0, i));
                    EvaluateTreebank evaluator = new EvaluateTreebank(parser.op, null, parser, tagger);
                    evaluator.testOnTreebank(devTreebank);
                    double labelF1 = evaluator.getLBScore();
                    System.err.println("Label F1 for " + i + " models: " + labelF1);
                    if (!(labelF1 > bestF1)) continue;
                    bestF1 = labelF1;
                    bestSize = i;
                }
                this.copyWeights(ShiftReduceParser.averageScoredModels(models.subList(0, bestSize)));
            } else {
                this.copyWeights(ShiftReduceParser.averageScoredModels(bestModels));
            }
        }
        if (featureFrequencies != null) {
            this.filterFeatures(featureFrequencies.keysAbove(this.op.trainOptions().featureFrequencyCutoff));
        }
        this.condenseFeatures();
    }

    /**
     * Forwards the given option flags to this parser's options object.
     *
     * @param flags option flags to apply
     */
    @Override
    public void setOptionFlags(String ... flags) {
        ShiftReduceOptions options = this.op;
        options.setOptions(flags);
    }

    /**
     * Deserializes a parser from {@code path} and applies any extra option flags to it.
     *
     * @param path location of the serialized parser, resolved by
     *        {@code IOUtils.readObjectFromURLOrClasspathOrFileSystem}
     * @param extraFlags option flags to set on the loaded parser; ignored when empty
     * @return the loaded parser
     * @throws RuntimeIOException if reading fails or a serialized class cannot be found
     */
    public static ShiftReduceParser loadModel(String path, String ... extraFlags) {
        ShiftReduceParser loaded;
        try {
            Timing loadTimer = new Timing();
            System.err.print("Loading parser from serialized file " + path + " ...");
            Object deserialized = IOUtils.readObjectFromURLOrClasspathOrFileSystem(path);
            loaded = (ShiftReduceParser)deserialized;
            loadTimer.done();
        }
        catch (IOException ioFailure) {
            throw new RuntimeIOException(ioFailure);
        }
        catch (ClassNotFoundException missingClass) {
            throw new RuntimeIOException(missingClass);
        }
        if (extraFlags.length != 0) {
            loaded.setOptionFlags(extraFlags);
        }
        return loaded;
    }

    /**
     * Serializes this parser to the file at {@code path}.
     *
     * @param path destination file name
     * @throws RuntimeIOException if the write fails
     */
    public void saveModel(String path) {
        try {
            IOUtils.writeObjectToFile(this, path);
        }
        catch (IOException ioFailure) {
            throw new RuntimeIOException(ioFailure);
        }
    }

    /**
     * Command-line entry point: trains a parser, loads one, and/or evaluates it on a test treebank.
     *
     * <p>Recognized flags (case-insensitive): {@code -trainTreebank}, {@code -testTreebank},
     * {@code -devTreebank}, {@code -serializedPath} / {@code -model}, {@code -tlpp} and
     * {@code -continueTraining}. Every other argument is passed on as a parser option.
     *
     * <p>When training without {@code -serializedPath}, the trained model is kept in memory (and
     * can still be evaluated with {@code -testTreebank}) but is not written to disk.
     *
     * @param args command-line arguments
     * @throws IllegalArgumentException if neither a training treebank nor a serialized model is
     *         given, or if a flag that needs a value is the last argument
     */
    public static void main(String[] args) {
        ArrayList<String> remainingArgs = Generics.newArrayList();
        ArrayList<Pair<String, FileFilter>> trainTreebankPath = null;
        Pair<String, FileFilter> testTreebankPath = null;
        Pair<String, FileFilter> devTreebankPath = null;
        String serializedPath = null;
        String tlppClass = null;
        String continueTraining = null;
        int argIndex = 0;
        while (argIndex < args.length) {
            String flag = args[argIndex];
            if (flag.equalsIgnoreCase("-trainTreebank")) {
                if (trainTreebankPath == null) {
                    trainTreebankPath = Generics.newArrayList();
                }
                trainTreebankPath.add(ArgUtils.getTreebankDescription(args, argIndex, "-trainTreebank"));
                argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
                continue;
            }
            if (flag.equalsIgnoreCase("-testTreebank")) {
                testTreebankPath = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
                argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
                continue;
            }
            if (flag.equalsIgnoreCase("-devTreebank")) {
                devTreebankPath = ArgUtils.getTreebankDescription(args, argIndex, "-devTreebank");
                argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
                continue;
            }
            boolean isModelFlag = flag.equalsIgnoreCase("-serializedPath") || flag.equalsIgnoreCase("-model");
            boolean isTlppFlag = flag.equalsIgnoreCase("-tlpp");
            boolean isContinueFlag = flag.equalsIgnoreCase("-continueTraining");
            if (isModelFlag || isTlppFlag || isContinueFlag) {
                // Fail with a clear message instead of an ArrayIndexOutOfBoundsException.
                if (argIndex + 1 >= args.length) {
                    throw new IllegalArgumentException("Missing value for option " + flag);
                }
                String value = args[argIndex + 1];
                if (isModelFlag) {
                    serializedPath = value;
                } else if (isTlppFlag) {
                    tlppClass = value;
                } else {
                    continueTraining = value;
                }
                argIndex += 2;
                continue;
            }
            remainingArgs.add(flag);
            ++argIndex;
        }
        String[] newArgs = remainingArgs.toArray(new String[remainingArgs.size()]);
        if (trainTreebankPath == null && serializedPath == null) {
            throw new IllegalArgumentException("Must specify a treebank to train from with -trainTreebank or a parser to load with -serializedPath");
        }
        ShiftReduceParser parser = null;
        if (trainTreebankPath != null) {
            System.err.println("Training ShiftReduceParser");
            System.err.println("Initial arguments:");
            System.err.println("   " + StringUtils.join(args));
            if (continueTraining != null) {
                parser = ShiftReduceParser.loadModel(continueTraining, ArrayUtils.concatenate(FORCE_TAGS, newArgs));
            } else {
                parser = new ShiftReduceParser(ShiftReduceParser.buildTrainingOptions(tlppClass, newArgs));
            }
            ShiftReduceOptions op = parser.op;
            if (op.trainOptions().retrainAfterCutoff && op.trainOptions().featureFrequencyCutoff > 0) {
                // First pass finds the features that survive the cutoff; second pass retrains a
                // fresh parser restricted to those features. The temp model is only written when
                // a serialized path was supplied (its name is derived from that path, with the
                // last 7 characters -- assumed ".ser.gz" -- stripped).
                String tempName = serializedPath == null ? null : serializedPath.substring(0, serializedPath.length() - 7) + "-" + "temp.ser.gz";
                parser.train(trainTreebankPath, devTreebankPath, tempName, null);
                if (tempName != null) {
                    parser.saveModel(tempName);
                }
                Set<String> features = parser.featureWeights.keySet();
                parser = new ShiftReduceParser(op);
                parser.train(trainTreebankPath, devTreebankPath, serializedPath, features);
            } else {
                parser.train(trainTreebankPath, devTreebankPath, serializedPath, null);
            }
            // Avoid crashing after a full training run just because no output path was given.
            if (serializedPath != null) {
                parser.saveModel(serializedPath);
            } else {
                System.err.println("No -serializedPath specified; trained model will not be saved");
            }
        }
        if (serializedPath != null && parser == null) {
            parser = ShiftReduceParser.loadModel(serializedPath, ArrayUtils.concatenate(FORCE_TAGS, newArgs));
        }
        if (testTreebankPath != null) {
            System.err.println("Loading test trees from " + testTreebankPath.first());
            MemoryTreebank testTreebank = parser.op.tlpParams.memoryTreebank();
            testTreebank.loadPath(testTreebankPath.first(), testTreebankPath.second());
            System.err.println("Loaded " + testTreebank.size() + " trees");
            EvaluateTreebank evaluator = new EvaluateTreebank(parser.op, null, parser);
            evaluator.testOnTreebank(testTreebank);
        }
    }

    /**
     * Adapter that lets a {@code MulticoreWrapper} train individual trees: each submitted index is
     * passed to the enclosing parser's {@code trainTree} together with the shared training data.
     * {@link #newInstance()} returns this same object, so all workers share the same lists.
     */
    private class TrainTreeProcessor
    implements ThreadsafeProcessor<Integer, Pair<Integer, Integer>> {
        /** All training trees, indexed by the value passed to {@link #process}. */
        List<Tree> binarizedTrees;
        /** Gold transition sequence for each training tree. */
        List<List<Transition>> transitionLists;
        /** Shared list passed to {@code trainTree} to collect updates. */
        List<Update> updates;
        /** Oracle forwarded to {@code trainTree}; may be {@code null}. */
        Oracle oracle;

        public TrainTreeProcessor(List<Tree> binarizedTrees, List<List<Transition>> transitionLists, List<Update> updates, Oracle oracle) {
            this.oracle = oracle;
            this.updates = updates;
            this.transitionLists = transitionLists;
            this.binarizedTrees = binarizedTrees;
        }

        /**
         * Trains on the tree at {@code index}.
         *
         * @return whatever {@code trainTree} returns for that tree
         */
        @Override
        public Pair<Integer, Integer> process(Integer index) {
            ShiftReduceParser owner = ShiftReduceParser.this;
            return owner.trainTree(index, this.binarizedTrees, this.transitionLists, this.updates, this.oracle);
        }

        /** Returns this same processor rather than a copy. */
        @Override
        public TrainTreeProcessor newInstance() {
            return this;
        }
    }

    /**
     * Immutable record of one pending weight adjustment: the features that fired, the index of
     * the gold transition, the index of the predicted transition, and the step size.
     */
    private static class Update {
        /** Features extracted from the state where the prediction was made. */
        final List<String> features;
        /** Index of the gold transition. */
        final int goldTransition;
        /** Index of the transition the model predicted. */
        final int predictedTransition;
        /** Amount added to the gold transition's weight and subtracted from the predicted one's. */
        final float delta;

        Update(List<String> features, int goldTransition, int predictedTransition, float delta) {
            this.delta = delta;
            this.predictedTransition = predictedTransition;
            this.goldTransition = goldTransition;
            this.features = features;
        }
    }

    /**
     * Processor that retags a tree in place with the supplied tagger and hands the same tree back.
     * {@link #newInstance()} returns this same object, so the tagger is shared by every worker.
     */
    private static class RetagProcessor
    implements ThreadsafeProcessor<Tree, Tree> {
        /** Tagger passed to {@code ShiftReduceParser.redoTags} for each tree. */
        Tagger tagger;

        public RetagProcessor(Tagger tagger) {
            this.tagger = tagger;
        }

        /**
         * Retags {@code tree} via {@code ShiftReduceParser.redoTags}.
         *
         * @return the same {@code tree} instance that was passed in
         */
        @Override
        public Tree process(Tree tree) {
            Tagger sharedTagger = this.tagger;
            ShiftReduceParser.redoTags(tree, sharedTagger);
            return tree;
        }

        /** Returns this same processor rather than a copy. */
        @Override
        public RetagProcessor newInstance() {
            return this;
        }
    }
}

