/*
 * Decompiled with CFR 0.152.
 */
package edu.pku.coli.dualdecomp;

import edu.pku.coli.dualdecomp.AbstractDecoder;
import edu.pku.coli.dualdecomp.ArcClassificationDecoder;
import edu.pku.coli.dualdecomp.ArgumentCentricDecoder;
import edu.pku.coli.dualdecomp.GeneralizedCoordination;
import edu.pku.coli.dualdecomp.GeneralizedCoordinationDecoder;
import edu.pku.coli.dualdecomp.GeneralizedCoordinationReader;
import edu.pku.coli.dualdecomp.PredicateCentricDecoder;
import edu.pku.coli.dualdecomp.ThirdOrderArgumentCentricDecoder;
import edu.pku.coli.io.CCGPARGWriter;
import edu.pku.coli.io.DAGSentenceReader;
import edu.pku.coli.pear.dag.Evaluator;
import edu.pku.coli.pear.dag.SentenceForDAGParsing;
import fig.basic.Option;
import fig.exec.Execution;
import java.util.ArrayList;
import java.util.List;

/**
 * Command-line driver that trains and/or evaluates a dual-decomposition DAG decoder.
 *
 * <p>Fields annotated with {@code @Option} are filled from the command-line arguments passed to
 * {@code Execution.run} in {@link #main}. The flows implemented by {@link #run} are:
 * <ul>
 *   <li>{@code train} set: read the training data (and {@code test}, if given, as the held-out
 *       set handed to the decoder's {@code train}), train the decoder selected by
 *       {@code usemod}, and dump it to {@code model} if that is set;</li>
 *   <li>{@code train} unset and {@code model} set: load a previously dumped decoder;</li>
 *   <li>{@code test} set: re-read the test data and evaluate the current decoder on it.</li>
 * </ul>
 */
public class DualDecompositionTrainer implements Runnable {
    /** Training data file, in the format named by {@link #dataFormat}. */
    @Option
    public String train;
    /** Optional generalized-coordination file for {@link #train}; must have one entry per sentence. */
    @Option
    public String trainCoords;
    /** Test data file, in the format named by {@link #dataFormat}. */
    @Option
    public String test;
    /** Optional generalized-coordination file for {@link #test}; must have one entry per sentence. */
    @Option
    public String testCoords;
    /** Decoder to train: "pcd", "acd", "gcd", "arc" or "acd3" (case-insensitive). */
    @Option
    public String usemod = "acd";
    /** Model path: loaded when {@link #train} is unset, written after training otherwise. */
    @Option
    public String model;
    /** Optional output path; when set, {@link #test()} writes each decoded sentence there. */
    @Option
    public String out;
    /** Input format: "conll08", "sdp" or "sdp15" (case-insensitive). */
    @Option
    public String dataFormat = "conll08";
    /** Copied to the decoder's {@code threadNum} before training. */
    @Option
    public int threadNum = 8;
    /** Copied to the decoder's {@code prune} flag before training. */
    @Option
    public boolean prune = true;
    /** Copied to the decoder's {@code usePathFeat} flag before training. */
    @Option
    public boolean usePathFeat = true;
    /** Extra tree files merged into the training reader via {@code DAGSentenceReaderWithMoreTrees}. */
    @Option
    public String moreTreeFilesForTrain;
    /** Extra tree files merged into the test reader via {@code DAGSentenceReaderWithMoreTrees}. */
    @Option
    public String moreTreeFilesForTest;
    List<SentenceForDAGParsing> trainset = new ArrayList<SentenceForDAGParsing>();
    List<SentenceForDAGParsing> testset = new ArrayList<SentenceForDAGParsing>();
    /** Decoder being trained or evaluated; null until {@link #train()} succeeds or a model is loaded. */
    AbstractDecoder abd;

    public static void main(String[] args) {
        Execution.run(args, new DualDecompositionTrainer());
    }

    /**
     * Runs the load / train / test flow selected by the option fields (see the class comment).
     *
     * @throws IllegalStateException if training produced no decoder (unknown {@code usemod}) or
     *     testing is requested without a trained or loaded decoder
     * @throws RuntimeException if a coordination file does not match its data set in length, or
     *     {@code dataFormat} is unknown
     */
    @Override
    public void run() {
        if (this.train == null && this.model != null) {
            this.abd = AbstractDecoder.load(this.model);
            this.printMaxParam();
        }
        if (this.train != null) {
            this.trainset = this.readSentences(this.train, this.moreTreeFilesForTrain);
            if (this.test != null) {
                this.testset = this.readSentences(this.test, this.moreTreeFilesForTest);
                // Attach test coordinations here as well, so the held-out set given to the decoder
                // during training carries the same information as the set evaluated by test().
                if (this.testCoords != null) {
                    attachCoordinations(this.testset, this.testCoords, "Testset");
                }
            }
            if (this.trainCoords != null) {
                List<List<GeneralizedCoordination>> coords =
                        attachCoordinations(this.trainset, this.trainCoords, "Trainset");
                printCoordinationBalance(coords);
            }
            System.out.println("train loaded");
            this.train();
            if (this.abd == null) {
                // train() only prints a message for an unrecognized usemod; fail clearly here
                // instead of with a NullPointerException on dump()/decode().
                throw new IllegalStateException("No decoder was trained for usemod: " + this.usemod);
            }
            if (this.model != null) {
                this.abd.dump(this.model);
            }
        }
        if (this.test != null) {
            this.requireDecoder();
            // Re-read to get fresh sentence objects for the final evaluation.
            // NOTE(review): presumably because training-time decoding may leave predictions on the
            // earlier copies -- confirm.
            this.testset = this.readSentences(this.test, this.moreTreeFilesForTest);
            if (this.testCoords != null) {
                attachCoordinations(this.testset, this.testCoords, "Testset");
            }
            System.out.println("test loaded");
            this.test();
        }
    }

    /**
     * Instantiates the decoder named by {@link #usemod}, copies the tuning options onto it and
     * trains it on {@link #trainset}, passing {@link #testset} along as well.
     *
     * <p>For an unrecognized {@code usemod} this prints a message and returns without setting
     * {@link #abd}.
     */
    public void train() {
        if (this.usemod.equalsIgnoreCase("pcd")) {
            this.abd = new PredicateCentricDecoder();
        } else if (this.usemod.equalsIgnoreCase("acd")) {
            this.abd = new ArgumentCentricDecoder();
        } else if (this.usemod.equalsIgnoreCase("gcd")) {
            this.abd = new GeneralizedCoordinationDecoder();
        } else if (this.usemod.equalsIgnoreCase("arc")) {
            this.abd = new ArcClassificationDecoder();
        } else if (this.usemod.equalsIgnoreCase("acd3")) {
            this.abd = new ThirdOrderArgumentCentricDecoder();
        } else {
            System.out.println("Unkown model:" + this.usemod);
            return;
        }
        this.abd.usePathFeat = this.usePathFeat;
        this.abd.threadNum = this.threadNum;
        this.abd.prune = this.prune;
        this.abd.train(this.trainset, this.testset);
    }

    /**
     * Reads every sentence from {@code file} using the reader for {@link #dataFormat}.
     *
     * @param file data file to read
     * @param moreTreeFiles optional extra tree files merged in by
     *     {@code DAGSentenceReaderWithMoreTrees}; may be null
     * @return the sentences in file order
     * @throws RuntimeException if {@link #dataFormat} is not "sdp", "conll08" or "sdp15"
     */
    private List<SentenceForDAGParsing> readSentences(String file, String moreTreeFiles) {
        DAGSentenceReader reader;
        if (this.dataFormat.equalsIgnoreCase("sdp")) {
            reader = DAGSentenceReader.dagsReaderFromSDP(file);
        } else if (this.dataFormat.equalsIgnoreCase("conll08")) {
            reader = DAGSentenceReader.dagsReaderFromCoNLL08(file, true);
        } else if (this.dataFormat.equalsIgnoreCase("sdp15")) {
            reader = DAGSentenceReader.dagsReaderFromSDP15(file);
        } else {
            throw new RuntimeException("Unknown data format: " + this.dataFormat);
        }
        if (moreTreeFiles != null) {
            reader = DAGSentenceReader.DAGSentenceReaderWithMoreTrees(reader, moreTreeFiles);
        }
        List<SentenceForDAGParsing> ret = new ArrayList<SentenceForDAGParsing>();
        for (SentenceForDAGParsing s : reader) {
            ret.add(s);
        }
        return ret;
    }

    /**
     * Decodes every sentence in {@link #testset}, prints a progress dot every 500 sentences, and
     * prints the evaluator's summary. If {@link #out} is set, each decoded sentence is also
     * written there; the writer is closed even if decoding fails.
     *
     * @throws IllegalStateException if no decoder has been trained or loaded
     */
    public void test() {
        this.requireDecoder();
        CCGPARGWriter writer = this.out == null ? null : new CCGPARGWriter(this.out);
        try {
            Evaluator e = new Evaluator();
            int num = 0;
            for (SentenceForDAGParsing s : this.testset) {
                this.abd.decode(s);
                if (num % 500 == 0) {
                    System.out.print(".");
                }
                e.registry(s.getGoldDAG(), s.getPredictedDAG());
                if (writer != null) {
                    writer.printOneSentence(s, this.dataFormat, true);
                }
                ++num;
            }
            System.out.println("\n" + e);
        } finally {
            if (writer != null) {
                writer.close();
            }
        }
    }

    /** Throws if there is no decoder to evaluate with. */
    private void requireDecoder() {
        if (this.abd == null) {
            throw new IllegalStateException(
                    "No decoder available: specify 'train' to train one or 'model' to load one");
        }
    }

    /** Prints the largest averaged parameter of the loaded decoder. */
    private void printMaxParam() {
        // Seed with the first weight (not 0) so an all-negative vector reports its true maximum.
        float max = this.abd.paramLength > 0 ? this.abd.avgParam[0] : 0.0f;
        for (int i = 1; i < this.abd.paramLength; ++i) {
            if (max < this.abd.avgParam[i]) {
                max = this.abd.avgParam[i];
            }
        }
        System.out.println("max param: " + max);
    }

    /**
     * Reads per-sentence generalized coordinations from {@code coordsFile} and attaches them, in
     * order, to {@code sentences}.
     *
     * @param sentences sentences to annotate
     * @param coordsFile file read with {@link GeneralizedCoordinationReader}
     * @param label error-message prefix ("Trainset" or "Testset")
     * @return the coordinations read, one list per sentence
     * @throws RuntimeException if the file does not have exactly one entry per sentence
     */
    private static List<List<GeneralizedCoordination>> attachCoordinations(
            List<SentenceForDAGParsing> sentences, String coordsFile, String label) {
        List<List<GeneralizedCoordination>> coords =
                new GeneralizedCoordinationReader(coordsFile).readAll();
        if (coords.size() != sentences.size()) {
            throw new RuntimeException(label + " Generalized Coordinations inconsistent: "
                    + sentences.size() + ":" + coords.size());
        }
        for (int i = 0; i < sentences.size(); ++i) {
            sentences.get(i).setGeneralCoordinations(coords.get(i));
        }
        return coords;
    }

    /** Prints how many training coordinations satisfy {@code isAllArc()} versus how many do not. */
    private static void printCoordinationBalance(List<List<GeneralizedCoordination>> coords) {
        int p = 0;
        int n = 0;
        for (List<GeneralizedCoordination> perSentence : coords) {
            for (GeneralizedCoordination c : perSentence) {
                if (c.isAllArc()) {
                    ++p;
                } else {
                    ++n;
                }
            }
        }
        System.out.println("Coordinations in trainset positive : negative = " + p + ":" + n);
    }
}

