/*
 * Decompiled with CFR 0.152.
 */
package edu.pku.coli.dualdecomp;

import edu.pku.coli.dualdecomp.SimpleClassificationDecoder;
import edu.pku.coli.lexer.BeamClassifier;
import edu.pku.coli.lexer.IndexedTokenSequence;
import edu.pku.coli.pear.dag.PredicateArgumentAdjunctDAG;
import edu.pku.coli.pear.dag.SentenceForDAGParsing;
import fig.basic.Indexer;
import java.util.ArrayList;
import java.util.List;

/**
 * Predicts a predicate-argument DAG over a sentence by reducing arc prediction to independent
 * binary sequence-labeling problems solved with a {@link BeamClassifier}.
 *
 * <p>Two decompositions are supported, selected by {@code parentsSequence}:
 * <ul>
 *   <li>children mode (default): one sequence per head position {@code 0..n}; sequence
 *       position {@code k} is tagged 1 iff the gold DAG contains arc (head, k + 1);</li>
 *   <li>parents mode: one sequence per dependent position {@code 1..n}; sequence position
 *       {@code k} is tagged 1 iff the gold DAG contains arc (k, dependent).</li>
 * </ul>
 * Here {@code n} is {@code numOfWords()}. Position 0 may head arcs but is never used as a
 * dependent in either mode.
 */
public class SequenceLabelingDecoder {
    /** When true, label the parents of each dependent instead of the children of each head. */
    static boolean parentsSequence = false;
    /** Interns emission feature tuples to the integer feature ids stored in each sample. */
    Indexer<SimpleClassificationDecoder.IntsWrapper> featsIndexer =
            new Indexer<SimpleClassificationDecoder.IntsWrapper>();
    // NOTE(review): not referenced inside this class; kept because it is package-visible.
    Indexer<SimpleClassificationDecoder.IntsWrapper> tokenIndexer =
            new Indexer<SimpleClassificationDecoder.IntsWrapper>();
    /** Classifier built by {@code train}; null until training has run. */
    BeamClassifier bc;
    private List<IndexedTokenSequence> trainSequences;
    // Hyper-parameters handed to BeamClassifier.Builder in train().
    private int beam = 4;
    private boolean noDP = true; // passed to the builder as DP(!noDP)
    private int maxIter = 20;
    private int order = 2;
    private List<SentenceForDAGParsing> trainset;

    /**
     * Builds one labeled sequence per head (or per dependent, in parents mode) of every
     * training sentence and trains the beam classifier on them.
     *
     * @param trainset sentences whose gold DAGs ({@code getGoldDAG()}) supply the tags; the
     *                 list is retained by this decoder
     */
    public void train(List<SentenceForDAGParsing> trainset) {
        this.trainset = trainset;
        this.extractSequenceLabelingSamples();
        this.bc = new BeamClassifier.Builder()
                .trainingData(this.trainSequences)
                .beam(this.beam)
                .DP(!this.noDP)
                .maxIter(this.maxIter)
                .order(this.order)
                .build();
        this.bc.train();
    }

    /**
     * Predicts arcs for {@code s} and stores them with {@code s.setPredictedDAG}.
     * Every predicted arc carries the placeholder label {@code "X"}.
     *
     * @param s sentence to decode; it is mutated and also returned
     * @return {@code s}, with its predicted DAG set
     * @throws IllegalStateException if no classifier is available ({@code bc} is null,
     *         e.g. {@code train} was never called)
     */
    public SentenceForDAGParsing decode(SentenceForDAGParsing s) {
        if (this.bc == null) {
            throw new IllegalStateException("decode() called before train()");
        }
        int sentLen = s.numOfWords();
        PredicateArgumentAdjunctDAG dag = new PredicateArgumentAdjunctDAG(sentLen);
        for (int i = 0; i <= sentLen; i++) {
            if (!parentsSequence) {
                // Sequence position j stands for word j + 1 as a candidate child of head i.
                int[] tags = this.bc.predict(this.extractChildrenSequenceLabelingSample(s, i, true));
                for (int j = 0; j < tags.length; j++) {
                    if (tags[j] == 1) {
                        dag.addArc(i, j + 1, "X");
                    }
                }
            } else if (i != 0) {
                // Sequence position j stands for word j as a candidate parent of dependent i.
                int[] tags = this.bc.predict(this.extractParentsSequenceLabelingSample(s, i, true));
                for (int j = 0; j < tags.length; j++) {
                    if (tags[j] == 1) {
                        dag.addArc(j, i, "X");
                    }
                }
            }
        }
        s.setPredictedDAG(dag);
        return s;
    }

    /**
     * Converts {@code trainset} into {@code trainSequences}. It builds one children sequence
     * per head position 0..n, or, in parents mode, one parents sequence per dependent
     * position 1..n. As progress output it prints a '.' every 500 sentences and a '+' every
     * 5000 sentences.
     */
    private void extractSequenceLabelingSamples() {
        this.trainSequences = new ArrayList<IndexedTokenSequence>();
        int num = 0;
        for (SentenceForDAGParsing s : this.trainset) {
            if (++num % 500 == 0) {
                System.out.print(num % 5000 == 0 ? "+" : ".");
            }
            for (int i = 0; i <= s.numOfWords(); i++) {
                if (!parentsSequence) {
                    this.trainSequences.add(this.extractChildrenSequenceLabelingSample(s, i, false));
                } else if (i != 0) {
                    this.trainSequences.add(this.extractParentsSequenceLabelingSample(s, i, false));
                }
            }
        }
    }

    /**
     * Builds the candidate-parent sequence for {@code dependent}. Position k (0..n) holds
     * word k and is tagged 1 when the gold DAG contains arc (k, dependent).
     *
     * @param s         source sentence
     * @param dependent word position whose parents are being labeled
     * @param forTest   when true, the gold DAG is not read and every tag stays 0
     * @return the token ids, tags and per-position feature ids
     */
    private IndexedTokenSequence extractParentsSequenceLabelingSample(SentenceForDAGParsing s, int dependent, boolean forTest) {
        int sentLen = s.numOfWords();
        int[] tokens = new int[sentLen + 1];
        int[] tags = new int[sentLen + 1];
        int[][] feats = new int[sentLen + 1][];
        for (int head = 0; head <= sentLen; head++) {
            tokens[head] = s.getKthWordIndex(head);
            if (!forTest && s.getGoldDAG().containsArc(head, dependent)) {
                tags[head] = 1;
            }
            feats[head] = this.indexFeats(this.extractEmissionFeats(s, head, dependent));
        }
        return new IndexedTokenSequence(tokens, tags, feats);
    }

    /**
     * Builds the candidate-child sequence for {@code head}. Position k (0..n-1) holds word
     * k + 1 and is tagged 1 when the gold DAG contains arc (head, k + 1). Position 0 is
     * never a candidate child.
     *
     * @param s       source sentence
     * @param head    word position whose children are being labeled
     * @param forTest when true, the gold DAG is not read and every tag stays 0
     * @return the token ids, tags and per-position feature ids
     */
    private IndexedTokenSequence extractChildrenSequenceLabelingSample(SentenceForDAGParsing s, int head, boolean forTest) {
        int sentLen = s.numOfWords();
        int[] tokens = new int[sentLen];
        int[] tags = new int[sentLen];
        int[][] feats = new int[sentLen][];
        for (int k = 0; k < sentLen; k++) {
            int child = k + 1;
            tokens[k] = s.getKthWordIndex(child);
            if (!forTest && s.getGoldDAG().containsArc(head, child)) {
                tags[k] = 1;
            }
            feats[k] = this.indexFeats(this.extractEmissionFeats(s, head, child));
        }
        return new IndexedTokenSequence(tokens, tags, feats);
    }

    /**
     * Maps each feature tuple to its id in {@code featsIndexer}, preserving list order.
     *
     * <p>NOTE(review): this also runs while decoding. If {@code Indexer.getIndex} inserts
     * unseen keys, as its use here suggests, the feature space can grow after training.
     * Confirm that BeamClassifier tolerates feature ids it never saw during training.
     */
    private int[] indexFeats(List<SimpleClassificationDecoder.IntsWrapper> featList) {
        int[] ids = new int[featList.size()];
        for (int k = 0; k < ids.length; k++) {
            ids[k] = this.featsIndexer.getIndex(featList.get(k));
        }
        return ids;
    }

    /** Wraps one feature tuple; by convention its first element is the template id. */
    private static SimpleClassificationDecoder.IntsWrapper feat(int... parts) {
        return new SimpleClassificationDecoder.IntsWrapper(parts);
    }

    /**
     * Builds the emission features for the ordered position pair ({@code from}, {@code to}).
     * The features cover the word and POS ids of both positions and of their immediate
     * neighbours (positions -1 and +1), in various conjunctions. Every feature is also
     * conjoined with the signed distance {@code from - to}.
     *
     * <p>NOTE(review): {@code from - 1} and {@code to + 1} may fall outside
     * {@code 0..numOfWords()}. This relies on {@code getKthWordIndex}/{@code getKthPosIndex}
     * accepting such positions. Confirm in SentenceForDAGParsing.
     */
    private List<SimpleClassificationDecoder.IntsWrapper> extractEmissionFeats(SentenceForDAGParsing s, int from, int to) {
        int wordFrom = s.getKthWordIndex(from);
        int posFrom = s.getKthPosIndex(from);
        int preWordFrom = s.getKthWordIndex(from - 1);
        int prePosFrom = s.getKthPosIndex(from - 1);
        int postWordFrom = s.getKthWordIndex(from + 1);
        int postPosFrom = s.getKthPosIndex(from + 1);
        int wordTo = s.getKthWordIndex(to);
        int posTo = s.getKthPosIndex(to);
        int preWordTo = s.getKthWordIndex(to - 1);
        int prePosTo = s.getKthPosIndex(to - 1);
        int postWordTo = s.getKthWordIndex(to + 1);
        int postPosTo = s.getKthPosIndex(to + 1);
        int dist = from - to;

        List<SimpleClassificationDecoder.IntsWrapper> feats = new ArrayList<SimpleClassificationDecoder.IntsWrapper>();
        // Each template gets its own leading id (n++), so equal value tuples produced by
        // different templates remain distinct features. Template order and the start value
        // determine every feature key, so do not reorder these lines.
        int n = 42;

        // Unigram features of each endpoint and of its neighbours.
        feats.add(feat(n++, wordFrom, dist));
        feats.add(feat(n++, posFrom, dist));
        feats.add(feat(n++, wordFrom, posFrom, dist));
        feats.add(feat(n++, wordTo, dist));
        feats.add(feat(n++, posTo, dist));
        feats.add(feat(n++, wordTo, posTo, dist));
        feats.add(feat(n++, preWordFrom, dist));
        feats.add(feat(n++, prePosFrom, dist));
        feats.add(feat(n++, preWordFrom, prePosFrom, dist));
        feats.add(feat(n++, postWordFrom, dist));
        feats.add(feat(n++, postPosFrom, dist));
        feats.add(feat(n++, postWordFrom, postPosFrom, dist));
        feats.add(feat(n++, preWordTo, dist));
        feats.add(feat(n++, prePosTo, dist));
        feats.add(feat(n++, preWordTo, prePosTo, dist));
        feats.add(feat(n++, postWordTo, dist));
        feats.add(feat(n++, postPosTo, dist));
        feats.add(feat(n++, postWordTo, postPosTo, dist));

        // Conjunctions of the two endpoints.
        feats.add(feat(n++, wordFrom, posFrom, wordTo, posTo, dist));
        feats.add(feat(n++, wordFrom, posFrom, wordTo, dist));
        feats.add(feat(n++, wordFrom, posFrom, posTo, dist));
        feats.add(feat(n++, wordFrom, wordTo, posTo, dist));
        feats.add(feat(n++, posFrom, wordTo, posTo, dist));
        feats.add(feat(n++, wordFrom, wordTo, dist));
        feats.add(feat(n++, posFrom, posTo, dist));

        // Each endpoint conjoined with its own neighbours.
        feats.add(feat(n++, wordFrom, posFrom, prePosFrom, dist));
        feats.add(feat(n++, posFrom, preWordFrom, prePosFrom, dist));
        feats.add(feat(n++, posFrom, prePosFrom, dist));
        feats.add(feat(n++, wordFrom, posFrom, postPosFrom, dist));
        feats.add(feat(n++, posFrom, postWordFrom, postPosFrom, dist));
        feats.add(feat(n++, posFrom, postPosFrom, dist));
        feats.add(feat(n++, wordTo, posTo, prePosTo, dist));
        feats.add(feat(n++, posTo, preWordTo, prePosTo, dist));
        feats.add(feat(n++, posTo, prePosTo, dist));
        feats.add(feat(n++, wordTo, posTo, postPosTo, dist));
        feats.add(feat(n++, posTo, postWordTo, postPosTo, dist));
        feats.add(feat(n++, posTo, postPosTo, dist));

        // Each endpoint conjoined with the other endpoint's neighbours.
        feats.add(feat(n++, wordTo, posTo, prePosFrom, dist));
        feats.add(feat(n++, posTo, preWordFrom, prePosFrom, dist));
        feats.add(feat(n++, posTo, prePosFrom, dist));
        // NOTE(review): the sibling templates use a neighbour *POS* in this slot
        // (prePosFrom, postPosTo, prePosTo). This one uses postWordFrom and may be a typo
        // for postPosFrom. It is kept unchanged to preserve the existing feature set.
        feats.add(feat(n++, wordTo, posTo, postWordFrom, dist));
        feats.add(feat(n++, posTo, postWordFrom, postPosFrom, dist));
        feats.add(feat(n++, posTo, postPosFrom, dist));
        feats.add(feat(n++, wordFrom, posFrom, prePosTo, dist));
        feats.add(feat(n++, posFrom, preWordTo, prePosTo, dist));
        feats.add(feat(n++, posFrom, prePosTo, dist));
        feats.add(feat(n++, wordFrom, posFrom, postPosTo, dist));
        feats.add(feat(n++, posFrom, postWordTo, postPosTo, dist));
        feats.add(feat(n++, posFrom, postPosTo, dist));
        return feats;
    }
}

