/*
 * Decompiled with CFR 0.152.
 */
package edu.pku.coli.dualdecomp;

import edu.pku.coli.dualdecomp.AbstractDecoder;
import edu.pku.coli.dualdecomp.ArgumentCentricDecoder;
import edu.pku.coli.dualdecomp.GeneralizedCoordination;
import edu.pku.coli.dualdecomp.GeneralizedCoordinationDecoder;
import edu.pku.coli.dualdecomp.GeneralizedCoordinationReader;
import edu.pku.coli.dualdecomp.PredicateCentricDecoder;
import edu.pku.coli.dualdecomp.TreeApproxDecoder;
import edu.pku.coli.io.CCGPARGWriter;
import edu.pku.coli.io.DAGSentenceReader;
import edu.pku.coli.pear.dag.Evaluator;
import edu.pku.coli.pear.dag.SentenceForDAGParsing;
import fig.basic.Option;
import fig.exec.Execution;
import java.io.FileNotFoundException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import matetools.is2.modification.morefeatures.MoreFeaturesInterface;
import matetools.is2.parser.Decoder;
import matetools.is2.parser.MFO;
import matetools.is2.util.OptionsSuper;

/**
 * Command-line driver that decodes a test set with up to four sub-decoders
 * (predicate-centric, argument-centric, generalized-coordination and
 * tree-approximation) and tries to make them agree on one arc matrix.
 *
 * <p>For every sentence, {@link #decode(SentenceForDAGParsing)} repeatedly asks each
 * configured decoder for its best arc matrix under a per-decoder penalty matrix and,
 * while the decoders disagree, shifts the penalties by the current step size times
 * the disagreement. {@link #run()} loads the models, reads the test data, evaluates
 * every individual decoder and the combined result, and prints running scores.
 *
 * <p>At least one of {@link #pcd} and {@link #acd} must be configured; the combined
 * result is always taken from one of those two.
 */
public class DualDecompositionTester
implements Runnable {
    /** Argument-centric decoder; loaded from {@link #acdmod} when that option is set. */
    ArgumentCentricDecoder acd;
    /** Predicate-centric decoder; loaded from {@link #pcdmod} when that option is set. */
    PredicateCentricDecoder pcd;
    /** Generalized-coordination decoder; loaded from {@link #gcdmod} when that option is set. */
    GeneralizedCoordinationDecoder gcd;
    /** Tree-approximation decoder; built from {@link #tadmod} and {@link #tadargs} when set. */
    TreeApproxDecoder tad;
    /** Divisor applied to the step size when updating the gcd penalty matrix. */
    @Option
    public double gcdWeight = 1.0;
    /** Divisor applied to the step size when updating the tad penalty matrix. */
    @Option
    public double tadWeight = 1.0;
    /** Model location handed to {@code AbstractDecoder.load} for the argument-centric decoder. */
    @Option
    public String acdmod;
    /** Model location handed to {@code AbstractDecoder.load} for the predicate-centric decoder. */
    @Option
    public String pcdmod;
    /** Model location handed to {@code AbstractDecoder.load} for the coordination decoder. */
    @Option
    public String gcdmod;
    /** Model location passed as the {@code -model} argument of the tree-approximation decoder. */
    @Option
    public String tadmod;
    /** Forwarded unchanged to {@code TreeApproxDecoder.initialize}. */
    @Option
    public boolean usePath = true;
    /** Whitespace-separated leading arguments for {@code TreeApproxDecoder.initialize}. */
    @Option
    public String tadargs = "conll08 projective X false X";
    /** Test file handed to the sentence reader (and to the tree decoder's argument list). */
    @Option
    public String test;
    /** Format of {@link #test}: {@code conll08}, {@code sdp} or {@code sdp15} (case-insensitive). */
    @Option
    public String testFormat = "conll08";
    /** Optional file read by {@code GeneralizedCoordinationReader}; one list per test sentence. */
    @Option
    public String testcoords;
    /** Optional argument for {@code DAGSentenceReader.DAGSentenceReaderWithMoreTrees}. */
    @Option
    public String moreTreeFilesForTest;
    /** Optional output location handed to {@code CCGPARGWriter}. */
    @Option
    public String out;
    /** Step size of the first iteration; later iterations shrink linearly towards zero. */
    @Option
    public double baseStep = 200.0;
    /** Maximum number of agreement iterations per sentence. */
    @Option
    public int maxIter = 200;
    /** Optional file that receives, per sentence, the number of iterations that were needed. */
    @Option
    public String iterlog;
    /** Step size per iteration; filled by {@link #run()} before any sentence is decoded. */
    double[] step;
    /** Number of sentences passed to {@link #decode(SentenceForDAGParsing)}. */
    int all = 0;
    /** Number of sentences for which all decoders agreed before {@link #maxIter} was reached. */
    int exact = 0;
    /** Iteration log; {@code null} when no log is being written. */
    PrintWriter pw = null;

    /**
     * Decodes one sentence by iterating the configured decoders until they agree or
     * {@link #maxIter} iterations have been spent.
     *
     * <p>Every matrix is indexed {@code [0..numOfWords][0..numOfWords]}. On agreement the
     * {@link #exact} counter is incremented and the iteration index is logged; otherwise
     * {@link #maxIter} is logged and the last solution is used. The returned sentence is
     * built by the predicate-centric decoder when present, else by the argument-centric one.
     *
     * <p>Reads {@link #step}, which {@link #run()} fills before calling this method.
     *
     * @param s the sentence to decode
     * @return the result of {@code toDAGSentence} for the final arc matrix
     * @throws IllegalStateException if neither {@link #pcd} nor {@link #acd} is configured
     */
    public SentenceForDAGParsing decode(SentenceForDAGParsing s) {
        if (this.pcd == null && this.acd == null) {
            // Previously this surfaced as a NullPointerException or a late RuntimeException.
            throw new IllegalStateException("One of pcd and acd is required to exist!!");
        }
        ++this.all;
        int sentLen = s.numOfWords();
        int dim = sentLen + 1;
        double[][] pcdu = new double[dim][dim];
        double[][] acdu = new double[dim][dim];
        double[][] gcdu = new double[dim][dim];
        double[][] tadu = new double[dim][dim];
        int[][] acdv = null;
        int[][] pcdv = null;
        int[][] gcdv = null;
        int[][] tadv = null;
        for (int iter = 0; iter < this.maxIter; ++iter) {
            if (this.pcd != null) {
                pcdv = this.pcd.decodeAfterScoring(pcdu);
            }
            if (this.acd != null) {
                acdv = this.acd.decodeAfterScoring(acdu);
            }
            if (this.gcd != null) {
                gcdv = this.gcd.decodeAfterScoring(gcdu);
            }
            if (this.tad != null) {
                tadv = this.tad.decodeAfterScoring(tadu);
            }
            if (this.decodersAgree(s, sentLen, pcdv, acdv, gcdv, tadv)) {
                ++this.exact;
                this.logIterationCount(iter);
                return this.toResult(s, pcdv, acdv);
            }
            // The three updates run in this fixed order so that the floating-point
            // accumulation into the shared penalty matrices stays reproducible.
            double stepSize = this.step[iter];
            if (this.acd != null && this.pcd != null) {
                DualDecompositionTester.updatePairPenalties(pcdu, acdu, pcdv, acdv, sentLen, stepSize);
            }
            if (this.gcd != null) {
                this.updateCoordinationPenalties(s, pcdu, acdu, gcdu, pcdv, acdv, gcdv, stepSize);
            }
            if (this.tad != null) {
                this.updateTreePenalties(pcdu, acdu, tadu, pcdv, acdv, tadv, sentLen, stepSize);
            }
        }
        this.logIterationCount(this.maxIter);
        return this.toResult(s, pcdv, acdv);
    }

    /**
     * Tells whether all configured decoders currently agree: pcd and acd on every cell,
     * gcd on the cells of every coordination, and tad on every arc it predicts.
     * The gcd and tad checks compare against pcd's matrix when pcd exists, else acd's.
     */
    private boolean decodersAgree(SentenceForDAGParsing s, int sentLen, int[][] pcdv, int[][] acdv, int[][] gcdv, int[][] tadv) {
        if (this.acd != null && this.pcd != null && !DualDecompositionTester.sameArcs(acdv, pcdv, sentLen)) {
            return false;
        }
        int[][] reference = this.pcd != null ? pcdv : acdv;
        if (this.gcd != null && !DualDecompositionTester.coordinationArcsAgree(s, reference, gcdv)) {
            return false;
        }
        return this.tad == null || DualDecompositionTester.treeArcsCovered(tadv, reference, sentLen);
    }

    /** Returns true when both matrices hold the same value in every cell up to {@code sentLen}. */
    private static boolean sameArcs(int[][] left, int[][] right, int sentLen) {
        for (int j = 0; j <= sentLen; ++j) {
            for (int k = 0; k <= sentLen; ++k) {
                if (left[j][k] != right[j][k]) {
                    return false;
                }
            }
        }
        return true;
    }

    /** Returns true when gcd and the reference matrix agree on every coordination cell. */
    private static boolean coordinationArcsAgree(SentenceForDAGParsing s, int[][] reference, int[][] gcdv) {
        for (GeneralizedCoordination gc : s.getGeneralCoordinations()) {
            boolean wordIsRow = gc.direction == GeneralizedCoordination.Direction.WORD2COORD;
            for (int other : gc.getCoordPositions()) {
                int row = wordIsRow ? gc.wordPosition : other;
                int col = wordIsRow ? other : gc.wordPosition;
                if (reference[row][col] != gcdv[row][col]) {
                    return false;
                }
            }
        }
        return true;
    }

    /** Returns true unless tad predicts an arc (value 1) where the reference matrix differs. */
    private static boolean treeArcsCovered(int[][] tadv, int[][] reference, int sentLen) {
        for (int j = 0; j <= sentLen; ++j) {
            for (int k = 0; k <= sentLen; ++k) {
                if (tadv[j][k] != reference[j][k] && tadv[j][k] == 1) {
                    return false;
                }
            }
        }
        return true;
    }

    /** Moves the pcd and acd penalties in opposite directions on every cell where they differ. */
    private static void updatePairPenalties(double[][] pcdu, double[][] acdu, int[][] pcdv, int[][] acdv, int sentLen, double stepSize) {
        for (int j = 0; j <= sentLen; ++j) {
            for (int k = 0; k <= sentLen; ++k) {
                double diff = pcdv[j][k] - acdv[j][k];
                pcdu[j][k] -= stepSize * diff;
                acdu[j][k] += stepSize * diff;
            }
        }
    }

    /**
     * Updates penalties on the cells of every coordination for which gcd predicts all arcs.
     * Only the decoders that are actually configured take part, mirroring the tad update.
     */
    private void updateCoordinationPenalties(SentenceForDAGParsing s, double[][] pcdu, double[][] acdu, double[][] gcdu, int[][] pcdv, int[][] acdv, int[][] gcdv, double stepSize) {
        double gcdStep = stepSize / this.gcdWeight;
        for (GeneralizedCoordination gc : s.getGeneralCoordinations()) {
            boolean wordIsRow = gc.direction == GeneralizedCoordination.Direction.WORD2COORD;
            if (!DualDecompositionTester.predictsWholeCoordination(gc, wordIsRow, gcdv)) {
                continue;
            }
            for (int other : gc.getCoordPositions()) {
                int row = wordIsRow ? gc.wordPosition : other;
                int col = wordIsRow ? other : gc.wordPosition;
                if (this.pcd != null) {
                    double diff = pcdv[row][col] - gcdv[row][col];
                    pcdu[row][col] -= stepSize * diff;
                    gcdu[row][col] += gcdStep * diff;
                }
                if (this.acd != null) {
                    double diff = acdv[row][col] - gcdv[row][col];
                    acdu[row][col] -= stepSize * diff;
                    gcdu[row][col] += gcdStep * diff;
                }
            }
        }
    }

    /** Returns true when gcd sets every cell of the given coordination to 1. */
    private static boolean predictsWholeCoordination(GeneralizedCoordination gc, boolean wordIsRow, int[][] gcdv) {
        for (int other : gc.getCoordPositions()) {
            int value = wordIsRow ? gcdv[gc.getWordPosition()][other] : gcdv[other][gc.getWordPosition()];
            if (value != 1) {
                return false;
            }
        }
        return true;
    }

    /** Updates penalties on every cell where tad predicts an arc, for each configured decoder. */
    private void updateTreePenalties(double[][] pcdu, double[][] acdu, double[][] tadu, int[][] pcdv, int[][] acdv, int[][] tadv, int sentLen, double stepSize) {
        double tadStep = stepSize / this.tadWeight;
        for (int j = 0; j <= sentLen; ++j) {
            for (int k = 0; k <= sentLen; ++k) {
                if (tadv[j][k] != 1) {
                    continue;
                }
                if (this.pcd != null) {
                    double diff = pcdv[j][k] - tadv[j][k];
                    pcdu[j][k] -= stepSize * diff;
                    tadu[j][k] += tadStep * diff;
                }
                if (this.acd != null) {
                    double diff = acdv[j][k] - tadv[j][k];
                    acdu[j][k] -= stepSize * diff;
                    tadu[j][k] += tadStep * diff;
                }
            }
        }
    }

    /** Writes one line to the iteration log, if a log is open, and flushes it. */
    private void logIterationCount(int iterations) {
        if (this.pw != null) {
            this.pw.println(iterations);
            this.pw.flush();
        }
    }

    /** Builds the result sentence from pcd's matrix when pcd exists, else from acd's. */
    private SentenceForDAGParsing toResult(SentenceForDAGParsing s, int[][] pcdv, int[][] acdv) {
        if (this.pcd != null) {
            return this.pcd.toDAGSentence(s, pcdv);
        }
        return this.acd.toDAGSentence(s, acdv);
    }

    /**
     * Loads the configured models, reads the test set, evaluates each decoder and the
     * combined decoding sentence by sentence, and prints running scores to standard output.
     *
     * <p>The output writer and the iteration log opened here are closed even when
     * processing fails part-way.
     *
     * @throws IllegalStateException if neither pcd nor acd is available after loading, or
     *         if {@link #testcoords} holds fewer coordination lists than there are sentences
     * @throws RuntimeException if the iteration log cannot be opened, the tree decoder
     *         fails to initialise, or {@link #testFormat} is not recognised
     */
    @Override
    public void run() {
        PrintWriter openedLog = null;
        CCGPARGWriter writer = null;
        try {
            if (this.iterlog != null) {
                try {
                    openedLog = new PrintWriter(this.iterlog);
                }
                catch (FileNotFoundException e) {
                    throw new RuntimeException(e);
                }
                this.pw = openedLog;
            }
            if (this.out != null) {
                writer = new CCGPARGWriter(this.out);
            }
            this.step = this.buildStepSchedule();
            this.loadDecoders();
            List<SentenceForDAGParsing> sents = this.readTestSentences();
            this.attachCoordinations(sents);
            this.evaluate(sents, writer);
            System.out.println("tadWeight:" + this.tadWeight);
            System.out.println("exact / all = " + this.exact + " / " + this.all);
        }
        finally {
            try {
                if (writer != null) {
                    writer.close();
                }
            }
            finally {
                // Only close the log this call opened; a writer installed by someone else is left alone.
                if (openedLog != null) {
                    openedLog.close();
                    if (this.pw == openedLog) {
                        this.pw = null;
                    }
                }
            }
        }
    }

    /** Returns the step sizes: {@code baseStep} at iteration 0, shrinking linearly with the iteration index. */
    private double[] buildStepSchedule() {
        double[] schedule = new double[this.maxIter];
        for (int i = 0; i < this.maxIter; ++i) {
            schedule[i] = (double)(this.maxIter - i) * this.baseStep / (double)this.maxIter;
        }
        return schedule;
    }

    /** Loads every decoder whose model option is set and checks that pcd or acd is available. */
    private void loadDecoders() {
        if (this.pcdmod != null) {
            this.pcd = (PredicateCentricDecoder)AbstractDecoder.load(this.pcdmod);
        }
        if (this.acdmod != null) {
            this.acd = (ArgumentCentricDecoder)AbstractDecoder.load(this.acdmod);
        }
        if (this.pcd == null && this.acd == null) {
            // Fail before the remaining models and the test data are loaded; decode() cannot work without one.
            throw new IllegalStateException("At least one of the pcdmod and acdmod options must be set");
        }
        if (this.gcdmod != null) {
            this.gcd = (GeneralizedCoordinationDecoder)AbstractDecoder.load(this.gcdmod);
        }
        if (this.tadmod != null) {
            this.initTreeApproxDecoder();
        }
    }

    /** Builds the tree-approximation decoder and fills the static label tables it relies on. */
    private void initTreeApproxDecoder() {
        // NOTE(review): this instance is replaced below; kept in case the no-arg
        // constructor has side effects that initialize() depends on - confirm.
        this.tad = new TreeApproxDecoder();
        String[] base = this.tadargs.split("\\s+");
        String[] args = new String[base.length + 6];
        System.arraycopy(base, 0, args, 0, base.length);
        int next = base.length;
        args[next++] = "-more_features";
        // NOTE(review): option name and path share a single array element; presumably
        // consumed as the value of -more_features - confirm against the option parser.
        args[next++] = "-testDepDir " + this.test;
        args[next++] = "-model";
        args[next++] = this.tadmod;
        args[next++] = "-decode";
        args[next] = "proj";
        try {
            OptionsSuper opt = TreeApproxDecoder.initialize(args, this.usePath);
            this.tad = new TreeApproxDecoder(opt);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        MoreFeaturesInterface.l2i = this.tad.l2i;
        MoreFeaturesInterface.startReadingAllTest();
        MFO mf = new MFO();
        TreeApproxDecoder.types = new String[mf.getFeatureCounter().get("REL").intValue()];
        Decoder.labReverseXindexs = new ArrayList<Integer>();
        Decoder.labXindexs = new ArrayList<Integer>();
        for (Map.Entry<String, Integer> e : MFO.getFeatureSet().get("REL").entrySet()) {
            String label = e.getKey();
            TreeApproxDecoder.types[e.getValue().intValue()] = label;
            // "X~R" must be tested first because such labels also start with "X".
            if (label.startsWith("X~R")) {
                Decoder.labReverseXindexs.add(e.getValue());
            } else if (label.startsWith("X")) {
                Decoder.labXindexs.add(e.getValue());
            }
        }
    }

    /** Opens the reader selected by {@link #testFormat} and collects all of its sentences. */
    private List<SentenceForDAGParsing> readTestSentences() {
        DAGSentenceReader reader;
        if ("conll08".equalsIgnoreCase(this.testFormat)) {
            reader = DAGSentenceReader.dagsReaderFromCoNLL08(this.test, true);
        } else if ("sdp".equalsIgnoreCase(this.testFormat)) {
            reader = DAGSentenceReader.dagsReaderFromSDP(this.test);
        } else if ("sdp15".equalsIgnoreCase(this.testFormat)) {
            reader = DAGSentenceReader.dagsReaderFromSDP15(this.test);
        } else {
            throw new RuntimeException("Unknown format");
        }
        if (this.moreTreeFilesForTest != null) {
            reader = DAGSentenceReader.DAGSentenceReaderWithMoreTrees(reader, this.moreTreeFilesForTest);
        }
        List<SentenceForDAGParsing> sents = new ArrayList<SentenceForDAGParsing>();
        for (SentenceForDAGParsing s : reader) {
            sents.add(s);
        }
        return sents;
    }

    /** Attaches the i-th coordination list from {@link #testcoords} to the i-th sentence. */
    private void attachCoordinations(List<SentenceForDAGParsing> sents) {
        if (this.testcoords == null) {
            return;
        }
        GeneralizedCoordinationReader gcr = new GeneralizedCoordinationReader(this.testcoords);
        List<List<GeneralizedCoordination>> coords = gcr.readAll();
        if (coords.size() < sents.size()) {
            // Report the mismatch up front instead of failing with an index error midway.
            throw new IllegalStateException("testcoords provides " + coords.size() + " coordination lists but the test set has " + sents.size() + " sentences");
        }
        for (int i = 0; i < sents.size(); ++i) {
            sents.get(i).setGeneralCoordinations(coords.get(i));
        }
    }

    /**
     * Runs every configured decoder and then the combined decoding on each sentence,
     * registers the predictions with one evaluator per decoder, optionally writes the
     * combined result, and prints all scores every fifth sentence and after the last one.
     */
    private void evaluate(List<SentenceForDAGParsing> sents, CCGPARGWriter writer) {
        Evaluator eacd = new Evaluator();
        Evaluator epcd = new Evaluator();
        Evaluator egcd = new Evaluator();
        Evaluator etad = new Evaluator();
        Evaluator edd = new Evaluator();
        int total = sents.size();
        for (int i = 0; i < total; ++i) {
            SentenceForDAGParsing s = sents.get(i);
            // Each decoder's own decode(s) runs before the combined decode(s), which
            // then calls decodeAfterScoring on the same decoders.
            if (this.pcd != null) {
                this.pcd.decode(s);
                epcd.registry(s.getGoldDAG(), s.getPredictedDAG());
            }
            if (this.acd != null) {
                this.acd.decode(s);
                eacd.registry(s.getGoldDAG(), s.getPredictedDAG());
            }
            if (this.gcd != null) {
                this.gcd.decode(s);
                egcd.registry(s.getGoldDAG(), s.getPredictedDAG());
            }
            if (this.tad != null) {
                this.tad.decode(s);
                etad.registry(s.getGoldDAG(), s.getPredictedDAG());
            }
            this.decode(s);
            edd.registry(s.getGoldDAG(), s.getPredictedDAG());
            if (writer != null) {
                writer.printOneSentence(s, "conll08", true);
            }
            if (i % 5 == 0 || i == total - 1) {
                System.out.println("##############" + (i + 1) + "/" + total + "############");
                System.out.println("=============pcd============\n" + epcd);
                System.out.println("=============acd============\n" + eacd);
                if (this.gcd != null) {
                    System.out.println("=============gcd============\n" + egcd);
                }
                if (this.tad != null) {
                    System.out.println("=============tad============\n" + etad);
                }
                System.out.println("=============dd============\n" + edd);
            }
        }
    }

    /**
     * Program entry point; hands the command-line arguments and a new tester to
     * {@code Execution.run}.
     *
     * @param args command-line arguments
     */
    public static void main(String[] args) {
        Execution.run(args, new DualDecompositionTester());
    }
}

