/*
 * Decompiled with CFR 0.152.
 */
package edu.pku.coli.dualdecomp;

import edu.pku.coli.io.CCGPARGWriter;
import edu.pku.coli.io.DAGSentenceReader;
import edu.pku.coli.pear.dag.Evaluator;
import edu.pku.coli.pear.dag.PredicateArgumentAdjunctDAG;
import edu.pku.coli.pear.dag.SentenceForDAGParsing;
import edu.pku.coli.treeapprox.DAG2TreeBFS;
import edu.pku.coli.treeapprox.DAG2TreeDFS;
import edu.pku.coli.treeapprox.DAG2TreeDFS_ForwardEdgeFirst;
import edu.pku.coli.treeapprox.DAG2TreeNonProjectiveSimple;
import edu.pku.coli.treeapprox.DAG2TreeProjectiveSimple;
import edu.pku.coli.treeapprox.DAG2TreeTransformer;
import edu.pku.coli.treeapprox.TreeApproxParser;
import edu.pku.coli.treeapprox.TreeApproxPipe;
import fig.basic.Pair;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import matetools.is2.data.Cluster;
import matetools.is2.data.DataFES;
import matetools.is2.data.FV;
import matetools.is2.data.Instances;
import matetools.is2.data.Long2Int;
import matetools.is2.data.Parse;
import matetools.is2.data.PipeGen;
import matetools.is2.data.SentenceData09;
import matetools.is2.io.CONLLReader09;
import matetools.is2.modification.morefeatures.MoreFeaturesInterface;
import matetools.is2.parser.Decoder;
import matetools.is2.parser.Extractor;
import matetools.is2.parser.MFO;
import matetools.is2.parser.Options;
import matetools.is2.parser.ParametersFloat;
import matetools.is2.parser.Parser;
import matetools.is2.parser.Pipe;
import matetools.is2.util.DB;
import matetools.is2.util.OptionsSuper;

/**
 * Tree-approximation DAG parser driver built on the mate-tools graph parser.
 *
 * <p>Gold DAGs are converted to trees with a configurable {@link DAG2TreeTransformer}, a mate-tools
 * style model is trained/applied on those trees, and predicted trees are converted back to DAGs for
 * output. The class also exposes a single-sentence decoding API ({@link #decode},
 * {@link #scoreFeats}, {@link #decodeAfterScoring}) that accepts additional arc weights.
 *
 * <p>NOTE(review): configuration lives in mutable static fields, so this class is not safe for
 * concurrent use with different configurations.
 */
public class TreeApproxDecoder extends Parser {
    /** Printed when the command line or the transformer configuration is invalid. */
    private static final String USAGE = "Usage: <data format> <dag2tree algo> \n\t<dag2tree arg1: changeNullEdge> <arg2: attach strategy> <arg3: importance strategy>\n\t<mate args...>";

    /** DAG<->tree converter; set by {@link #initializeTransformer()}. */
    private static DAG2TreeTransformer dag2tree;
    /** Data format name taken from args[0] ("conll08", "sdp" or "sdp15" are handled). */
    private static String dataFormat;
    /** DAG-to-tree algorithm name taken from args[1]. */
    private static String transformer;
    /** args[2..4]: changeNullEdge flag, attach strategy, importance strategy. */
    public static String[] transformerArgs;
    /** Options stored by {@link #initialize}; read by the single-sentence decoding path. */
    private static OptionsSuper statOps;
    /**
     * Label strings indexed by label id, used by {@link #decodeAfterScoring}.
     * NOTE(review): never assigned in this class; presumably populated by other code in this
     * package. decodeAfterScoring fails with a NullPointerException if it is still null.
     */
    static String[] types;
    /** Sentence prepared by the most recent {@link #decode(SentenceForDAGParsing)} call. */
    private SentenceData09 curInstance;

    public TreeApproxDecoder() {
    }

    public TreeApproxDecoder(OptionsSuper options) {
        super(options);
    }

    /**
     * Instantiates {@link #dag2tree} from {@link #transformer} and {@link #transformerArgs}.
     *
     * @throws IllegalArgumentException if the algorithm name is not recognized (the usage text is
     *     printed to stderr first); previously this left {@code dag2tree} null or stale
     */
    public static void initializeTransformer() {
        boolean changeNullEdge = transformerArgs[0].equalsIgnoreCase("true");
        if (transformer.equalsIgnoreCase("dfs")) {
            dag2tree = new DAG2TreeDFS(changeNullEdge, transformerArgs[1], transformerArgs[2]);
        } else if (transformer.equalsIgnoreCase("iteration")) {
            dag2tree = new DAG2TreeDFS_ForwardEdgeFirst(changeNullEdge, transformerArgs[1], transformerArgs[2]);
        } else if (transformer.equalsIgnoreCase("bfs")) {
            dag2tree = new DAG2TreeBFS(changeNullEdge, transformerArgs[1], transformerArgs[2]);
        } else if (transformer.equalsIgnoreCase("projective")) {
            dag2tree = new DAG2TreeProjectiveSimple(changeNullEdge);
        } else if (transformer.equalsIgnoreCase("nonprojective")) {
            dag2tree = new DAG2TreeNonProjectiveSimple(changeNullEdge);
        } else {
            System.err.println(USAGE);
            throw new IllegalArgumentException("Unknown dag2tree algorithm: " + transformer);
        }
    }

    /**
     * Parses the command line and prepares this class for the single-sentence decoding API.
     *
     * @param args {@code <data format> <dag2tree algo> <arg1> <arg2> <arg3> <mate args...>}
     * @param usePath whether to initialize {@link MoreFeaturesInterface} with
     *     {@code options.moreFeatures} (otherwise it is initialized with {@code null})
     * @return the parsed parser options (also stored for later decoding)
     */
    public static OptionsSuper initialize(String[] args, boolean usePath) throws Exception {
        String[] parserArgs = configureFromCommandLine(args);
        Options options = new Options(parserArgs);
        statOps = options;
        configureThreads(options);
        if (usePath) {
            MoreFeaturesInterface.initialize(options.moreFeatures);
        } else {
            MoreFeaturesInterface.initialize(null);
        }
        return options;
    }

    /**
     * Echoes {@code args}, stores the data format and transformer settings, and builds the
     * transformer.
     *
     * @return the remaining arguments (index 5 onward), which are passed to mate-tools
     *     {@link Options}
     * @throws IllegalArgumentException if fewer than five arguments are given
     */
    private static String[] configureFromCommandLine(String[] args) {
        System.out.println(Arrays.toString(args));
        if (args.length < 5) {
            System.err.println(USAGE);
            throw new IllegalArgumentException("Expected at least 5 arguments, got " + args.length);
        }
        dataFormat = args[0];
        transformer = args[1];
        transformerArgs = new String[]{args[2], args[3], args[4]};
        initializeTransformer();
        return Arrays.copyOfRange(args, 5, args.length);
    }

    /** Sets {@code THREADS} to the processor count, capped by a positive {@code options.cores}. */
    private static void configureThreads(Options options) {
        Runtime runtime = Runtime.getRuntime();
        THREADS = runtime.availableProcessors();
        if (options.cores < THREADS && options.cores > 0) {
            THREADS = options.cores;
        }
        DB.println("Found " + runtime.availableProcessors() + " cores use " + THREADS);
    }

    /** Runs {@link Evaluator} comparing {@code options.goldfile} against {@code options.outfile}. */
    private static void runEvaluation(OptionsSuper options) throws Exception {
        Evaluator evaluator = new Evaluator();
        evaluator.dataFormat = dataFormat;
        evaluator.goldCoNLL = options.goldfile;
        evaluator.sysCoNLL = options.outfile;
        evaluator.run();
    }

    /** Formats the decoder/rearrange/extract timers (nanoseconds divided by 1e6) for progress output. */
    private static String timingInfo() {
        return " td " + (float)Decoder.timeDecotder / 1000000.0f + " tr " + (float)Decoder.timeRearrange / 1000000.0f + " te " + (float)Pipe.timeExtract / 1000000.0f;
    }

    /**
     * Command-line entry point: optionally trains, tests and evaluates, depending on the mate-tools
     * options.
     */
    public static void main(String[] args) throws Exception {
        String[] parserArgs = configureFromCommandLine(args);
        long start = System.currentTimeMillis();
        Options options = new Options(parserArgs);
        configureThreads(options);
        MoreFeaturesInterface.initialize(options.moreFeatures);
        if (options.train) {
            TreeApproxParser p = new TreeApproxParser();
            p.options = options;
            MoreFeaturesInterface.l2i = p.l2i = new Long2Int(options.hsize);
            p.pipe = new TreeApproxPipe(options);
            Instances is = new Instances();
            Extractor.initFeatures();
            p.pipe.extractor = new Extractor[THREADS];
            DB.println("hsize " + options.hsize);
            DB.println("Use " + (options.featureCreation == 1 ? "multiplication" : "shift") + "-based feature creation function");
            for (int t = 0; t < THREADS; ++t) {
                p.pipe.extractor[t] = new Extractor(p.l2i, options.stack, options.featureCreation);
            }
            DB.println("Stacking " + options.stack);
            List<SentenceData09> sents = readSentences(options.trainfile);
            ((TreeApproxPipe)p.pipe).createInstances(sents, is);
            p.params = new ParametersFloat(p.l2i.size());
            ((Parser)p).train(options, p.pipe, p.params, is, p.pipe.cl);
            p.writeModell(options, p.params, null, p.pipe.cl);
        }
        if (options.test) {
            TreeApproxParser p = new TreeApproxParser(options);
            MoreFeaturesInterface.l2i = p.l2i;
            DB.println("label only? " + options.label);
            ((Parser)p).out(options, p.pipe, p.params, false, options.label);
        }
        System.out.println();
        if (options.eval) {
            System.out.println("\nEVALUATION PERFORMANCE:");
            runEvaluation(options);
        }
        long end = System.currentTimeMillis();
        System.out.println("used time " + (float)((end - start) / 100L) / 10.0f);
        Decoder.executerService.shutdown();
        Pipe.executerService.shutdown();
        System.out.println("end.");
    }

    /**
     * Online training loop (mate-tools style) with per-iteration evaluation when test and gold
     * files are configured. Training stops early when an iteration has zero error.
     */
    @Override
    public void train(OptionsSuper options, Pipe pipe, ParametersFloat params, Instances is, Cluster cluster) throws IOException, InterruptedException, ClassNotFoundException {
        DB.println("\nTraining Information ");
        DB.println("-------------------- ");
        Decoder.NON_PROJECTIVITY_THRESHOLD = (float)options.decodeTH;
        if (options.decodeProjective) {
            System.out.println("Decoding: projective");
        } else {
            System.out.println(Decoder.getInfo());
        }
        int numInstances = is.size();
        int maxLenInstances = 0;
        for (int i = 0; i < numInstances; ++i) {
            maxLenInstances = Math.max(maxLenInstances, is.length(i));
        }
        DataFES data = new DataFES(maxLenInstances, pipe.mf.getFeatureCounter().get("REL").shortValue());
        int del = 0;
        FV pred = new FV();
        FV act = new FV();
        // Widen before multiplying so large corpora cannot overflow int.
        double upd = (double)numInstances * options.numIters + 1.0;
        // Number of iterations actually run; differs from the loop index after an early break.
        int completedIters = 0;
        for (int iter = 0; iter < options.numIters; ++iter) {
            System.out.print("Iteration " + iter + ": ");
            long start = System.currentTimeMillis();
            long last = System.currentTimeMillis();
            float error = 0.0f;
            float f1 = 0.0f;
            for (int n = 0; n < numInstances; ++n) {
                upd -= 1.0;
                if (is.labels[n].length > options.maxLen) {
                    continue;
                }
                if ((n + 1) % 500 == 0) {
                    del = PipeGen.outValueErr(n + 1, error, f1 / (float)n, del, last, upd, timingInfo());
                }
                short[] pos = is.pposs[n];
                data = pipe.fillVector(params.getFV(), is, n, data, cluster);
                Parse d = Decoder.decode(pos, data, options.decodeProjective, true);
                double e = pipe.errors(is, n, d);
                if (d.f1 > 0.0) {
                    f1 += d.f1;
                }
                // Written as "<= 0" so that a NaN error still triggers an update, as before.
                if (e <= 0.0) {
                    continue;
                }
                pred.clear();
                pipe.extractor[0].encodeCat(is, n, pos, is.forms[n], is.plemmas[n], d.heads, d.labels, is.feats[n], pipe.cl, pred);
                error += e;
                // NOTE(review): result unused; kept in case getFV() has side effects - verify.
                params.getFV();
                act.clear();
                pipe.extractor[0].encodeCat(is, n, pos, is.forms[n], is.plemmas[n], is.heads[n], is.labels[n], is.feats[n], pipe.cl, act);
                params.update(act, pred, is, n, d, upd, e);
            }
            String info = timingInfo() + " nz " + params.countNZ();
            PipeGen.outValueErr(numInstances, error, f1 / (float)numInstances, del, last, upd, info);
            del = 0;
            long end = System.currentTimeMillis();
            System.out.println(" time:" + (end - start));
            ParametersFloat pf = params.average2((iter + 1) * is.size());
            try {
                if (options.testfile != null && options.goldfile != null) {
                    this.out(options, pipe, pf, false, false);
                    runEvaluation(options);
                }
            }
            catch (Exception e) {
                e.printStackTrace();
            }
            completedIters = iter + 1;
            if (error == 0.0f) {
                DB.println("stopped because learned all lessons");
                break;
            }
            Decoder.timeDecotder = 0L;
            Decoder.timeRearrange = 0L;
            Pipe.timeExtract = 0L;
        }
        if (options.average) {
            // Previously iter * size: off by one after an early stop (zero if stopped in iteration 0).
            // This now matches the divisor passed to average2 above.
            params.average(completedIters * is.size());
        }
    }

    /**
     * Parses every sentence of {@code options.testfile} and writes the predicted DAGs to
     * {@code options.outfile}.
     */
    @Override
    public void out(OptionsSuper options, Pipe pipe, ParametersFloat params, boolean maxInfo, boolean labelOnly) throws Exception {
        long start = System.currentTimeMillis();
        List<SentenceData09> testSents = readSentences(options.testfile);
        List<SentenceData09> outSents = new ArrayList<SentenceData09>(testSents.size());
        int cnt = 0;
        int del = 0;
        long last = System.currentTimeMillis();
        if (maxInfo) {
            System.out.println("\nParsing Information ");
            System.out.println("------------------- ");
            if (!options.decodeProjective) {
                System.out.println(Decoder.getInfo());
            }
        }
        System.out.print("Processing Sentence: ");
        MoreFeaturesInterface.startReadingAllTest();
        for (SentenceData09 instance : testSents) {
            outSents.add(this.parse(instance, params, labelOnly, options));
            del = PipeGen.outValue(++cnt, del, last);
        }
        this.writeAllSentences(outSents, options.outfile);
        long end = System.currentTimeMillis();
        if (maxInfo) {
            System.out.println("Used time " + (end - start));
            System.out.println("forms count " + Instances.m_count + " unkown " + Instances.m_unkown);
        }
    }

    /**
     * Converts each predicted tree back to a DAG and writes it. A sentence whose conversion throws
     * is counted as failed and written with an empty DAG instead. The writer is closed even if
     * writing aborts.
     */
    private void writeAllSentences(List<SentenceData09> outSents, String outfile) {
        CCGPARGWriter writer = new CCGPARGWriter(outfile);
        int failed = 0;
        try {
            for (SentenceData09 sent : outSents) {
                try {
                    PredicateArgumentAdjunctDAG tree = PredicateArgumentAdjunctDAG.buildTreeFromHeadsAndLabels(sent.pheads, sent.plabels);
                    PredicateArgumentAdjunctDAG dag = dag2tree.toDAG(tree);
                    this.writeOneSentence(dag, sent, writer);
                }
                catch (Exception e) {
                    e.printStackTrace();
                    ++failed;
                    this.writeOneSentence(new PredicateArgumentAdjunctDAG(sent.forms.length), sent, writer);
                }
            }
        }
        finally {
            writer.close();
        }
        System.out.println(String.valueOf(failed) + "/" + outSents.size() + " failed.");
    }

    /**
     * Writes one DAG in the configured {@link #dataFormat}. Nothing is written for formats other
     * than sdp, conll08 and sdp15.
     */
    private void writeOneSentence(PredicateArgumentAdjunctDAG dag, SentenceData09 sent, CCGPARGWriter writer) {
        if (dataFormat.equalsIgnoreCase("sdp")) {
            writer.printOneSentenceSDP(dag, Arrays.asList(sent.forms), Arrays.asList(sent.lemmas), Arrays.asList(sent.gpos), "#00000000");
        } else if (dataFormat.equalsIgnoreCase("conll08")) {
            writer.printOneSentenceCoNLL08(dag, Arrays.asList(sent.forms), Arrays.asList(sent.gpos));
        } else if (dataFormat.equalsIgnoreCase("sdp15")) {
            writer.printOneSentenceSDP15(dag, Arrays.asList(sent.forms), Arrays.asList(sent.lemmas), Arrays.asList(sent.gpos), null, "#00000000");
        }
    }

    /**
     * Reads gold DAGs from {@code path} in the configured {@link #dataFormat} and converts each one
     * to a tree-shaped {@link SentenceData09}.
     *
     * @throws IllegalArgumentException if {@link #dataFormat} is not conll08, sdp or sdp15
     */
    private static List<SentenceData09> readSentences(String path) {
        DAGSentenceReader reader;
        if (dataFormat.equalsIgnoreCase("conll08")) {
            reader = DAGSentenceReader.dagsReaderFromCoNLL08(path, false);
        } else if (dataFormat.equalsIgnoreCase("sdp")) {
            reader = DAGSentenceReader.dagsReaderFromSDP(path);
        } else if (dataFormat.equalsIgnoreCase("sdp15")) {
            reader = DAGSentenceReader.dagsReaderFromSDP15(path);
        } else {
            throw new IllegalArgumentException("Unsupported data format: " + dataFormat);
        }
        List<SentenceData09> sents = new ArrayList<SentenceData09>();
        for (SentenceForDAGParsing sent : reader) {
            sents.add(dagToTreeToSentenceData09(sent));
        }
        return sents;
    }

    /**
     * Converts a sentence's gold DAG to a tree with {@link #dag2tree}, renders it as CoNLL-09
     * columns, and builds a {@link SentenceData09}. Lemmas are omitted when the sentence has none.
     */
    private static SentenceData09 dagToTreeToSentenceData09(SentenceForDAGParsing sent) {
        PredicateArgumentAdjunctDAG tree = dag2tree.toTree(sent.getGoldDAG());
        List<String> lemmas = sent.lemmas() == null ? null : Arrays.asList(sent.lemmas());
        String[][] conll09 = CCGPARGWriter.toCoNLL(null, tree, Arrays.asList(sent.words()), lemmas, Arrays.asList(sent.tags()), false, CCGPARGWriter.CoNLLFormat.CONLL09);
        return buildSentenceData09FromCoNLL09(conll09);
    }

    /**
     * Builds a {@link SentenceData09} from column-major CoNLL-09 data.
     *
     * <p>Index 0 of every array is an artificial root. Token {@code i} (1-based) comes from
     * {@code conll09[column][i - 1]}.
     *
     * @param conll09 {@code conll09[column][token]}; may be {@code null}
     * @return the sentence, or {@code null} if {@code conll09} is {@code null}
     */
    private static SentenceData09 buildSentenceData09FromCoNLL09(String[][] conll09) {
        if (conll09 == null) {
            return null;
        }
        int numColumns = conll09.length;
        int length = conll09[0].length;
        SentenceData09 it = new SentenceData09();
        it.forms = new String[length + 1];
        it.plemmas = new String[length + 1];
        it.gpos = new String[length + 1];
        it.labels = new String[length + 1];
        it.heads = new int[length + 1];
        it.pheads = new int[length + 1];
        it.plabels = new String[length + 1];
        it.ppos = new String[length + 1];
        it.lemmas = new String[length + 1];
        it.fillp = new String[length + 1];
        it.feats = new String[length + 1][];
        it.ofeats = new String[length + 1];
        it.pfeats = new String[length + 1];
        it.id = new String[length + 1];
        it.forms[0] = "<root>";
        it.plemmas[0] = "<root-LEMMA>";
        it.fillp[0] = "N";
        it.lemmas[0] = "<root-LEMMA>";
        it.gpos[0] = "<root-POS>";
        it.ppos[0] = "<root-POS>";
        it.labels[0] = "<no-type>";
        it.heads[0] = -1;
        it.plabels[0] = "<no-type>";
        it.pheads[0] = -1;
        it.ofeats[0] = "<no-type>";
        it.id[0] = "0";
        for (int i = 1; i <= length; ++i) {
            String[] info = new String[numColumns];
            for (int j = 0; j < numColumns; ++j) {
                info[j] = conll09[j][i - 1];
            }
            it.id[i] = info[0];
            it.forms[i] = info[1];
            // Columns 2-4 (lemma, plemma, gpos) need at least 5 columns; the old ">= 3" guard could overflow.
            if (numColumns < 5) {
                continue;
            }
            it.lemmas[i] = info[2];
            it.plemmas[i] = info[3];
            it.gpos[i] = info[4];
            // Columns 5-12 need at least 13 columns; the old ">= 5" guard could overflow.
            if (numColumns < 13) {
                continue;
            }
            it.ppos[i] = info[5];
            it.ofeats[i] = info[6];
            if (info[7].equals("_")) {
                it.feats[i] = null;
            } else {
                it.feats[i] = info[7].split("\\|");
                it.pfeats[i] = info[7];
            }
            if (info[8].equals("_")) {
                // A token without a gold head means the transformer did not produce a tree.
                // NOTE(review): terminates the JVM even when reached from the decode() API.
                it.heads[i] = -1;
                System.err.println("Not a tree!");
                System.out.println(Arrays.toString(conll09[1]));
                System.exit(-1);
            } else {
                it.heads[i] = Integer.parseInt(info[8]);
            }
            it.pheads[i] = info[9].equals("_") ? -1 : Integer.parseInt(info[9]);
            it.labels[i] = info[10];
            it.plabels[i] = info[11];
            it.fillp[i] = info[12];
            if (numColumns > 13) {
                if (!info[13].equals("_")) {
                    it.addPredicate(i, info[13]);
                }
                for (int k = 14; k < numColumns; ++k) {
                    it.addArgument(i, k - 14, info[k]);
                }
            }
        }
        return it;
    }

    /**
     * Builds an unlabeled predicted DAG from a 0/1 arc matrix and stores it on {@code sent}.
     *
     * @param var {@code var[i][j] == 1} adds arc i -> j with label "X"; must be at least
     *     (numOfWords + 1) x (numOfWords + 1)
     * @return {@code sent}, now carrying the predicted DAG
     */
    public SentenceForDAGParsing toDAGSentence(SentenceForDAGParsing sent, int[][] var) {
        int sentLen = sent.numOfWords();
        PredicateArgumentAdjunctDAG sys = new PredicateArgumentAdjunctDAG(sentLen);
        for (int i = 0; i <= sentLen; ++i) {
            for (int j = 0; j <= sentLen; ++j) {
                if (var[i][j] == 1) {
                    sys.addArc(i, j, "X");
                }
            }
        }
        sent.setPredictedDAG(sys);
        return sent;
    }

    /** Returns the (numOfWords + 1)-square 0/1 arc matrix of the sentence's gold DAG. */
    public int[][] toVariable(SentenceForDAGParsing sent) {
        return toAdjacencyMatrix(sent.getGoldDAG(), sent.numOfWords());
    }

    /** Returns a (sentLen + 1)-square matrix holding 1 at [head][dependent] for each unlabeled arc. */
    private static int[][] toAdjacencyMatrix(PredicateArgumentAdjunctDAG dag, int sentLen) {
        int[][] matrix = new int[sentLen + 1][sentLen + 1];
        for (Pair<Integer, Integer> arc : dag.toUnlabeledPairs()) {
            matrix[arc.getFirst().intValue()][arc.getSecond().intValue()] = 1;
        }
        return matrix;
    }

    /**
     * Computes feature scores for the single instance in {@code this.is} into {@code this.d2}.
     *
     * @throws RuntimeException wrapping an {@link InterruptedException}; the interrupt flag is
     *     restored first
     */
    public void scoreFeats() {
        try {
            this.d2 = this.pipe.fillVector(this.params.getFV(), this.is, 0, null, this.pipe.cl);
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /**
     * Decodes the scored instance with extra arc weights, maps the resulting tree back to a DAG,
     * and returns its 0/1 arc matrix. {@link #scoreFeats()} must have been called first.
     *
     * @param additionalWeight extra weights passed to {@link Decoder#decode}
     * @throws RuntimeException wrapping an {@link InterruptedException}; the interrupt flag is
     *     restored first
     */
    public int[][] decodeAfterScoring(double[][] additionalWeight) {
        try {
            this.d = Decoder.decode(this.is.pposs[0], this.d2, statOps.decodeProjective, false, additionalWeight);
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
        SentenceData09 i09 = new SentenceData09(this.curInstance);
        // The decoder's arrays include the root at index 0, so they are shifted by one.
        for (int j = 0; j < this.curInstance.forms.length - 1; ++j) {
            i09.plabels[j] = types[this.d.labels[j + 1]];
            i09.pheads[j] = this.d.heads[j + 1];
        }
        PredicateArgumentAdjunctDAG tree = PredicateArgumentAdjunctDAG.buildTreeFromHeadsAndLabels(i09.pheads, i09.plabels);
        PredicateArgumentAdjunctDAG dag = dag2tree.toDAG(tree);
        return toAdjacencyMatrix(dag, dag.sentenceLength());
    }

    /**
     * Decodes one sentence with no additional arc weights and stores the predicted DAG on it.
     * Requires {@link #initialize} to have been called.
     */
    public SentenceForDAGParsing decode(SentenceForDAGParsing s) {
        // NOTE(review): input is built from s's gold DAG (via dag2tree); presumably only the
        // word/tag columns matter at decode time - confirm.
        this.curInstance = dagToTreeToSentenceData09(s);
        this.is = new Instances();
        this.is.init(1, new MFO(), statOps.formatTask);
        new CONLLReader09().insert(this.is, this.curInstance);
        MoreFeaturesInterface.createMoreFeatures(this.is, this.curInstance);
        this.scoreFeats();
        int n = s.numOfWords();
        int[][] var = this.decodeAfterScoring(new double[n + 1][n + 1]);
        return this.toDAGSentence(s, var);
    }
}

