/*
 * Decompiled with CFR 0.152.
 */
package edu.pku.coli.dualdecomp;

import edu.pku.coli.pear.dag.Evaluator;
import edu.pku.coli.pear.dag.SentenceForDAGParsing;
import fig.basic.Indexer;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

/**
 * Base class for DAG decoders trained with a perceptron-style update over hashed features.
 *
 * <p>Weights live in two flat arrays of length {@code paramLength}. {@code curParam} is
 * updated on every training mistake. {@code avgParam} accumulates snapshots of
 * {@code curParam} (see {@link #avarageParam()}) and is the vector used by
 * {@link #decode(SentenceForDAGParsing)}. Subclasses provide:
 * <ul>
 *   <li>feature scoring ({@link #scoreFeats(float[])}),</li>
 *   <li>the search ({@link #decodeAfterScoring(double[][])}),</li>
 *   <li>conversion between sentences and decision variables,</li>
 *   <li>a {@link FeatureExtractor}.</li>
 * </ul>
 */
public abstract class AbstractDecoder implements Serializable {
    private static final long serialVersionUID = -3268422412938065337L;

    /** Shared RNG with a fixed seed; used to shuffle the training set every epoch. */
    public static Random rnd = new Random(42L);

    /** Not read in this class; presumably consumed by subclasses — TODO confirm. */
    public int threadNum = 8;

    /** Not read in this class; presumably consumed by subclasses — TODO confirm. */
    public boolean prune = true;

    /**
     * {@link #train} skips any training sentence whose gold DAG has a node with more
     * outgoing arcs than this.
     */
    public int _maxSibling = Integer.MAX_VALUE;

    /** Not read in this class; presumably consumed by subclasses — TODO confirm. */
    public boolean usePathFeat = true;

    /** Length of both weight vectors; also the modulus that the {@code hashInts} overloads map into. */
    int paramLength = 179669557;

    /**
     * Weights updated during training.
     *
     * <p>Transient, so it is NOT serialized: after {@link #load(File)} it is {@code null}
     * until {@link #train} or {@link #decodeCur} reallocates it (zero-filled).
     */
    protected transient float[] curParam = new float[this.paramLength];

    /**
     * Running sum of {@code curParam} snapshots. Used for decoding and serialized with the model.
     *
     * <p>NOTE(review): this is a sum, never divided by a count. It is presumably fine because
     * decoding only compares scores, which positive scaling preserves — TODO confirm.
     */
    protected float[] avgParam = new float[this.paramLength];

    /** Maximum number of training epochs. */
    int maxIter = 5;

    /** Produces the feature-index lists that training compares and updates. */
    protected FeatureExtractor featureExtractor;

    /** Sentence most recently passed to {@link #setSentAndScoreFeats}. */
    protected transient SentenceForDAGParsing curSent;

    /**
     * Trains the weights with perceptron-style updates.
     *
     * <p>Each epoch does the following:
     * <ol>
     *   <li>Shuffles the (filtered) training set.</li>
     *   <li>Decodes every sentence with {@code curParam}.</li>
     *   <li>When the gold and system feature lists differ, adds 1 to every gold feature
     *       weight and subtracts 1 from every system feature weight.</li>
     *   <li>Folds {@code curParam} into {@code avgParam} every 5000 sentences and at the end
     *       of the epoch.</li>
     * </ol>
     *
     * <p>Training stops early when every training sentence matches. After each other epoch,
     * the dev set is decoded with {@code avgParam} and evaluated.
     *
     * @param trainset training sentences with gold DAGs; sentences exceeding
     *                 {@code _maxSibling} are skipped
     * @param devset   development sentences; may be empty or {@code null} to skip evaluation
     */
    public void train(List<SentenceForDAGParsing> trainset, List<SentenceForDAGParsing> devset) {
        System.out.println("Using decoder: " + this.getClass().getSimpleName());
        // curParam is transient, so a decoder obtained via load() has none.
        this.ensureCurParam();

        List<SentenceForDAGParsing> train = new ArrayList<SentenceForDAGParsing>();
        for (SentenceForDAGParsing s : trainset) {
            if (!this.exceedsSiblingLimit(s)) {
                train.add(s);
            }
        }

        for (int iter = 0; iter < this.maxIter; ++iter) {
            Collections.shuffle(train, rnd);
            int completeMatch = 0;
            for (int i = 0; i < train.size(); ++i) {
                SentenceForDAGParsing sent = train.get(i);
                this.decodeCur(sent);
                // false/true select gold vs. system features, per the usage here.
                List<Integer> gold = this.featureExtractor.extract(sent, false);
                List<Integer> sys = this.featureExtractor.extract(sent, true);
                if (gold.equals(sys)) {
                    ++completeMatch;
                } else {
                    this.updateParam(gold, sys, 1.0);
                }
                if (i != 0 && i % 500 == 0) {
                    System.out.print(".");
                }
                if (i != 0 && i % 5000 == 0) {
                    this.avarageParam();
                    System.out.print("+");
                }
            }
            this.avarageParam();
            System.out.print("+");
            System.out.println("\nIteration " + iter + ", complete match in trainset:" + completeMatch + "/" + train.size());
            if (completeMatch == train.size()) {
                break;
            }
            this.evaluateOn(devset);
            System.out.println("Param length: " + this.paramLength);
            System.out.println("None zero param: " + this.countNonZeroParams());
        }
    }

    /** Returns true if some node in the gold DAG of {@code s} has more than {@code _maxSibling} out-arcs. */
    private boolean exceedsSiblingLimit(SentenceForDAGParsing s) {
        for (int i = 0; i < s.numOfWords(); ++i) {
            if (s.getGoldDAG().getAdjacencyLists()[i].getOutArcs().size() > this._maxSibling) {
                return true;
            }
        }
        return false;
    }

    /** Decodes every dev sentence with {@code avgParam} and prints the evaluator; no-op for an empty or null set. */
    private void evaluateOn(List<SentenceForDAGParsing> devset) {
        if (devset == null || devset.isEmpty()) {
            return;
        }
        Evaluator e = new Evaluator();
        for (SentenceForDAGParsing s : devset) {
            this.decode(s);
            e.registry(s.getGoldDAG(), s.getPredictedDAG());
        }
        System.out.println(e);
    }

    /** Counts indices where either {@code curParam} or {@code avgParam} is non-zero. */
    private int countNonZeroParams() {
        int count = 0;
        for (int k = 0; k < this.paramLength; ++k) {
            if (this.curParam[k] != 0.0f || this.avgParam[k] != 0.0f) {
                ++count;
            }
        }
        return count;
    }

    /** Allocates a zero-filled {@code curParam} if it is missing (e.g. after deserialization). */
    private void ensureCurParam() {
        if (this.curParam == null) {
            this.curParam = new float[this.paramLength];
        }
    }

    /**
     * Adds {@code curParam} element-wise into {@code avgParam}.
     *
     * <p>The name keeps its original spelling because subclasses may call or override it.
     */
    protected void avarageParam() {
        for (int k = 0; k < this.paramLength; ++k) {
            this.avgParam[k] += this.curParam[k];
        }
    }

    /**
     * Sums the weights of the given feature indices.
     *
     * @param feats feature indices (duplicates are counted once per occurrence)
     * @param param weight vector to read from
     * @return the summed score, accumulated in double precision
     */
    protected double score(List<Integer> feats, float[] param) {
        double total = 0.0;
        for (int feat : feats) {
            total += param[feat];
        }
        return total;
    }

    /** Maps a combined hash into {@code [0, paramLength)}; negative remainders are wrapped. */
    private int toParamIndex(int hash) {
        int index = hash % this.paramLength;
        return index < 0 ? index + this.paramLength : index;
    }

    // The hashInts overloads combine values with the same scheme as
    // java.util.Arrays.hashCode(int[]): seed 1, multiplier 31, each element passed through
    // hashInt. The result is then mapped into [0, paramLength). Int overflow wraps
    // intentionally. The fixed-arity overloads avoid allocating a varargs array.

    /** Hashes one value into a parameter index. */
    protected int hashInts(int x) {
        return this.toParamIndex(this.hashInt(x));
    }

    /** Hashes two values into a parameter index. */
    protected int hashInts(int x1, int x2) {
        int result = 31 + this.hashInt(x1);
        result = 31 * result + this.hashInt(x2);
        return this.toParamIndex(result);
    }

    /** Hashes three values into a parameter index. */
    protected int hashInts(int x1, int x2, int x3) {
        int result = 31 + this.hashInt(x1);
        result = 31 * result + this.hashInt(x2);
        result = 31 * result + this.hashInt(x3);
        return this.toParamIndex(result);
    }

    /** Hashes four values into a parameter index. */
    protected int hashInts(int x1, int x2, int x3, int x4) {
        int result = 31 + this.hashInt(x1);
        result = 31 * result + this.hashInt(x2);
        result = 31 * result + this.hashInt(x3);
        result = 31 * result + this.hashInt(x4);
        return this.toParamIndex(result);
    }

    /** Hashes five values into a parameter index. */
    protected int hashInts(int x1, int x2, int x3, int x4, int x5) {
        int result = 31 + this.hashInt(x1);
        result = 31 * result + this.hashInt(x2);
        result = 31 * result + this.hashInt(x3);
        result = 31 * result + this.hashInt(x4);
        result = 31 * result + this.hashInt(x5);
        return this.toParamIndex(result);
    }

    /** Hashes six values into a parameter index. */
    protected int hashInts(int x1, int x2, int x3, int x4, int x5, int x6) {
        int result = 31 + this.hashInt(x1);
        result = 31 * result + this.hashInt(x2);
        result = 31 * result + this.hashInt(x3);
        result = 31 * result + this.hashInt(x4);
        result = 31 * result + this.hashInt(x5);
        result = 31 * result + this.hashInt(x6);
        return this.toParamIndex(result);
    }

    /** Hashes seven values into a parameter index. */
    protected int hashInts(int x1, int x2, int x3, int x4, int x5, int x6, int x7) {
        int result = 31 + this.hashInt(x1);
        result = 31 * result + this.hashInt(x2);
        result = 31 * result + this.hashInt(x3);
        result = 31 * result + this.hashInt(x4);
        result = 31 * result + this.hashInt(x5);
        result = 31 * result + this.hashInt(x6);
        result = 31 * result + this.hashInt(x7);
        return this.toParamIndex(result);
    }

    /** Hashes eight values into a parameter index. */
    protected int hashInts(int x1, int x2, int x3, int x4, int x5, int x6, int x7, int x8) {
        int result = 31 + this.hashInt(x1);
        result = 31 * result + this.hashInt(x2);
        result = 31 * result + this.hashInt(x3);
        result = 31 * result + this.hashInt(x4);
        result = 31 * result + this.hashInt(x5);
        result = 31 * result + this.hashInt(x6);
        result = 31 * result + this.hashInt(x7);
        result = 31 * result + this.hashInt(x8);
        return this.toParamIndex(result);
    }

    /**
     * Hashes any number of values into a parameter index.
     *
     * @param ints values to combine
     * @return the index; 0 for a {@code null} array and {@code 1 % paramLength} for an empty one
     */
    protected int hashInts(int ... ints) {
        if (ints == null) {
            return 0;
        }
        int result = 1;
        for (int element : ints) {
            result = 31 * result + this.hashInt(element);
        }
        return this.toParamIndex(result);
    }

    /** Per-value hash used by {@code hashInts}; identity by default, overridable by subclasses. */
    protected int hashInt(int i) {
        return i;
    }

    /**
     * Buckets a signed distance: values in {@code [-5, 5]} are returned unchanged, larger
     * values map to 6, and smaller values map to -6.
     */
    protected int distMap(int dist) {
        if (dist > 5) {
            return 6;
        }
        if (dist < -5) {
            return -6;
        }
        return dist;
    }

    /**
     * Applies a perceptron-style update to {@code curParam}.
     *
     * <p>{@code weight} is added for every gold feature and subtracted for every system
     * feature. A feature appearing k times is updated k times.
     *
     * @param gold   gold feature indices
     * @param sys    system (predicted) feature indices
     * @param weight update step
     */
    protected void updateParam(List<Integer> gold, List<Integer> sys, double weight) {
        for (int feat : gold) {
            this.curParam[feat] = (float) (this.curParam[feat] + weight);
        }
        for (int feat : sys) {
            this.curParam[feat] = (float) (this.curParam[feat] - weight);
        }
    }

    /** Builds a sentence carrying the DAG encoded by the decision variables {@code var}. */
    public abstract SentenceForDAGParsing toDAGSentence(SentenceForDAGParsing sent, int[][] var);

    /** Encodes a sentence's DAG as decision variables (inverse of {@link #toDAGSentence}). */
    public abstract int[][] toVariable(SentenceForDAGParsing sent);

    /**
     * Runs the search over features already scored by {@link #scoreFeats}.
     *
     * @param additionalWeight extra per-variable weights; {@code decode}/{@code decodeCur}
     *                         pass an all-zero {@code (n+1) x (n+1)} matrix for an n-word sentence
     */
    public abstract int[][] decodeAfterScoring(double[][] additionalWeight);

    /** Scores the features of {@link #curSent} under the weight vector {@code param}. */
    abstract void scoreFeats(float[] param);

    /** Stores {@code sent} as {@link #curSent}, then scores its features with {@code param}. */
    public void setSentAndScoreFeats(SentenceForDAGParsing sent, float[] param) {
        this.curSent = sent;
        this.scoreFeats(param);
    }

    /** Scores {@code sent} with {@code param} and returns the decision variables found by the search. */
    public int[][] decodeVariable(SentenceForDAGParsing sent, float[] param, double[][] additionalWeight) {
        this.setSentAndScoreFeats(sent, param);
        return this.decodeAfterScoring(additionalWeight);
    }

    /** Decodes {@code sent} with {@code param} and converts the result via {@link #toDAGSentence}. */
    public SentenceForDAGParsing decode(SentenceForDAGParsing sent, float[] param, double[][] additionalWeight) {
        int[][] var = this.decodeVariable(sent, param, additionalWeight);
        return this.toDAGSentence(sent, var);
    }

    /**
     * Decodes with the in-training weights {@code curParam} and no additional weight.
     *
     * <p>Allocates {@code curParam} first if it is missing, e.g. after deserialization.
     */
    public SentenceForDAGParsing decodeCur(SentenceForDAGParsing sent) {
        this.ensureCurParam();
        int sentLen = sent.numOfWords();
        return this.decode(sent, this.curParam, new double[sentLen + 1][sentLen + 1]);
    }

    /** Decodes with the accumulated weights {@code avgParam} and no additional weight. */
    public SentenceForDAGParsing decode(SentenceForDAGParsing sent) {
        int sentLen = sent.numOfWords();
        return this.decode(sent, this.avgParam, new double[sentLen + 1][sentLen + 1]);
    }

    /**
     * Serializes this decoder to {@code file}; see {@link #dump(File)}.
     *
     * @throws RuntimeException wrapping any {@link IOException}
     */
    public void dump(String file) {
        try {
            this.dump(new File(file));
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Loads a decoder from {@code file}; see {@link #load(File)}.
     *
     * @throws RuntimeException wrapping any {@link IOException}
     */
    public static AbstractDecoder load(String file) {
        try {
            return AbstractDecoder.load(new File(file));
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Writes a gzip-compressed Java-serialization stream to {@code file}.
     *
     * <p>The stream contains, in order: this decoder, {@code SentenceForDAGParsing.wordIndexer},
     * then {@code SentenceForDAGParsing.posIndexer}. Transient fields ({@code curParam},
     * {@code curSent}) are not written. The file is closed even if a write fails.
     *
     * @throws IOException if the file cannot be created or written
     */
    public void dump(File file) throws IOException {
        try (FileOutputStream fos = new FileOutputStream(file);
             GZIPOutputStream gos = new GZIPOutputStream(fos);
             ObjectOutputStream oos = new ObjectOutputStream(gos)) {
            oos.writeObject(this);
            oos.writeObject(SentenceForDAGParsing.wordIndexer);
            oos.writeObject(SentenceForDAGParsing.posIndexer);
        }
    }

    /**
     * Reads a file written by {@link #dump(File)}.
     *
     * <p>Overwrites the static {@code SentenceForDAGParsing.wordIndexer} and {@code posIndexer}
     * with the stored indexers. The file is closed on every path.
     *
     * <p>SECURITY: this uses native Java deserialization. Only load model files from trusted
     * sources.
     *
     * @return the deserialized decoder; its {@code curParam} is {@code null} until
     *         {@link #train} or {@link #decodeCur} allocates it
     * @throws IOException if the file cannot be read or is not a valid stream
     * @throws RuntimeException wrapping {@link ClassNotFoundException} if a stored class is missing
     */
    public static AbstractDecoder load(File file) throws IOException {
        try (FileInputStream fis = new FileInputStream(file);
             GZIPInputStream gis = new GZIPInputStream(fis);
             ObjectInputStream ois = new ObjectInputStream(gis)) {
            Object decoder = ois.readObject();
            Object words = ois.readObject();
            Object pos = ois.readObject();
            AbstractDecoder ret = (AbstractDecoder) decoder;
            SentenceForDAGParsing.wordIndexer = (Indexer) words;
            SentenceForDAGParsing.posIndexer = (Indexer) pos;
            return ret;
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    /** Extracts feature-index lists from a sentence; implemented per decoder. */
    protected static abstract class FeatureExtractor implements Serializable {
        private static final long serialVersionUID = 3309455104556150453L;

        protected FeatureExtractor() {
        }

        /**
         * Returns the feature indices for {@code sent}.
         *
         * @param predicted {@link AbstractDecoder#train} passes {@code false} for gold features
         *                  and {@code true} for system features
         */
        abstract List<Integer> extract(SentenceForDAGParsing sent, boolean predicted);
    }
}

