/*
 * Decompiled with CFR 0.152.
 */
package tsg;

import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.Map;
import java.util.Scanner;
import kernels.NodeSetCollector;
import kernels.NodeSetCollectorSimple;
import tsg.Label;
import tsg.TSNodeLabel;
import tsg.TSNodeLabelIndex;
import tsg.TSNodeLabelStructure;
import util.FileUtil;
import util.Utility;

/**
 * DOP1-style scorer/reranker for n-best parse lists.
 *
 * <p>A fragment table ({@link #fragmentTableFreq}) maps tree fragments to frequencies and
 * {@link #rootTableFreq} holds per-root-label frequency totals. A parse tree is scored by
 * {@link #getParseTreeProb}: for every node, the probability is the sum, over fragments rooted
 * there, of {@code freq(fragment) / rootFreq(label)} times the probabilities of the fragment's
 * substitution sites (computed by {@link ProbChart}).
 *
 * <p>Expected call order: {@link #readFragmentsFile}, optionally {@link #addCFGfragments}, then
 * {@link #getRootFreq} (so that root totals include every fragment), then scoring.
 *
 * <p>All state lives in mutable static fields with no synchronization.
 */
public class DOP1_reranker {

    /** Fragment to frequency; filled by {@link #readFragmentsFile} and {@link #addCFGfragments}. */
    public static HashMap<TSNodeLabel, Double> fragmentTableFreq;

    /**
     * Root label to accumulated fragment frequency, built by {@link #getRootFreq}. Element
     * {@code [0]} of each array is what the probability chart reads as the label's total.
     */
    public static Hashtable<Label, double[]> rootTableFreq;

    /**
     * When true, a non-lexical node covered by no fragment gets a one-node fragment of
     * frequency 1.0. When false, uncovered pre-lexical nodes are skipped and any other uncovered
     * node makes {@link #getParseTreeProb} return -1.0.
     */
    public static boolean allowUnknownCFG = true;

    /**
     * Reads the next block of up to {@code nBest} parse trees (one per line) from {@code s}.
     * Blocks are separated by empty lines. If a block holds more than {@code nBest} trees, the
     * rest of the block (through its separating empty line) is consumed and discarded, so the
     * next call starts at the following block.
     *
     * @param nBest maximum number of trees to return
     * @param s scanner positioned at the start of a block
     * @return the trees read; fewer than {@code nBest} if the block is shorter, empty at EOF
     * @throws Exception propagated from the {@code TSNodeLabelIndex} string constructor
     */
    public static ArrayList<TSNodeLabelIndex> nextNBest(int nBest, Scanner s) throws Exception {
        ArrayList<TSNodeLabelIndex> result = new ArrayList<TSNodeLabelIndex>(nBest);
        while (s.hasNextLine() && result.size() < nBest) {
            String line = s.nextLine();
            if (line.equals("")) {
                // End of block reached before nBest trees: separator already consumed.
                return result;
            }
            result.add(new TSNodeLabelIndex(line));
        }
        // Discard the remainder of the current block, including its empty separator line.
        while (s.hasNextLine() && !s.nextLine().equals("")) {
            // intentionally empty
        }
        return result;
    }

    /**
     * Replaces {@link #fragmentTableFreq} with the fragments read from {@code fragmentFile}.
     * Each non-empty line must be {@code <fragment>\t<frequency>}. Backslashes are stripped
     * from the fragment string before parsing.
     *
     * @param fragmentFile the tab-separated fragment file
     * @throws IllegalArgumentException if a non-empty line has no tab-separated frequency
     * @throws Exception propagated from file opening or fragment parsing
     */
    public static void readFragmentsFile(File fragmentFile) throws Exception {
        System.out.println("Reading fragments from: " + fragmentFile.toString());
        fragmentTableFreq = new HashMap<TSNodeLabel, Double>();
        Scanner scan = FileUtil.getScanner(fragmentFile);
        int countFragments = 0;
        try {
            while (scan.hasNextLine()) {
                String line = scan.nextLine();
                if (line.equals("")) {
                    continue;
                }
                ++countFragments;
                String[] fragmentFreq = line.split("\t");
                if (fragmentFreq.length < 2) {
                    throw new IllegalArgumentException("Malformed fragment #" + countFragments
                            + " in " + fragmentFile + " (expected <fragment>\\t<freq>): " + line);
                }
                Double freq = Double.parseDouble(fragmentFreq[1]);
                String fragmentString = fragmentFreq[0].replaceAll("\\\\", "");
                TSNodeLabel fragment = new TSNodeLabel(fragmentString, false);
                fragmentTableFreq.put(fragment, freq);
            }
            System.out.println("Read " + countFragments + " fragments");
        } finally {
            // Release the file even if a line fails to parse.
            scan.close();
        }
    }

    /**
     * Adds every CFG rule of the training corpus (after {@code addTop()}) as a one-level
     * fragment, weighted by its corpus count, unless that fragment is already present in
     * {@link #fragmentTableFreq}. Creates the table if no fragment file was read first.
     *
     * @param trainingCorpus treebank file to extract rules from
     * @throws Exception propagated from treebank reading or fragment parsing
     */
    public static void addCFGfragments(File trainingCorpus) throws Exception {
        ArrayList<TSNodeLabel> corpus = TSNodeLabel.getTreebank(trainingCorpus);
        Hashtable<String, int[]> ruleTable = new Hashtable<String, int[]>();
        for (TSNodeLabel tree : corpus) {
            tree.addTop();
            for (TSNodeLabel node : tree.collectAllNodes()) {
                if (node.isLexical) {
                    continue;
                }
                Utility.increaseInTableInt(ruleTable, node.cfgRule());
            }
        }
        System.out.println("Read " + ruleTable.size() + " CFG fragments");
        if (fragmentTableFreq == null) {
            // Allows a pure-CFG model when readFragmentsFile was never called.
            fragmentTableFreq = new HashMap<TSNodeLabel, Double>();
        }
        int kept = 0;
        for (Map.Entry<String, int[]> e : ruleTable.entrySet()) {
            TSNodeLabel ruleFragment = new TSNodeLabel("( " + e.getKey() + ")", false);
            if (fragmentTableFreq.containsKey(ruleFragment)) {
                continue;
            }
            double freq = e.getValue()[0];
            fragmentTableFreq.put(ruleFragment, freq);
            ++kept;
        }
        System.out.println("Added " + kept + " CFG fragments");
    }

    /**
     * Rebuilds {@link #rootTableFreq} by accumulating each fragment's frequency under its root
     * label via {@code Utility.increaseInTableDoubleArray} (presumably summing into element
     * {@code [0]} — verify against Utility).
     */
    public static void getRootFreq() {
        rootTableFreq = new Hashtable<Label, double[]>();
        for (Map.Entry<TSNodeLabel, Double> e : fragmentTableFreq.entrySet()) {
            Label rootLabel = e.getKey().label;
            double freq = e.getValue();
            Utility.increaseInTableDoubleArray(rootTableFreq, rootLabel, freq);
        }
        System.out.println("Built root freq. table: " + rootTableFreq.size() + " entries.");
    }

    /**
     * Computes the DOP1 probability of {@code t} under the current fragment tables.
     *
     * @param t an indexed parse tree
     * @return the summed derivation probability, or -1.0 when {@link #allowUnknownCFG} is false
     *     and some non-pre-lexical node is not covered by any fragment
     */
    private static double getParseTreeProb(TSNodeLabelIndex t) {
        NodeSetCollectorSimple setCollector = new NodeSetCollectorSimple();
        HashMap<BitSet, Double> bitSetFreqTable = new HashMap<BitSet, Double>();
        for (Map.Entry<TSNodeLabel, Double> e : fragmentTableFreq.entrySet()) {
            getCFGSetCoveringFragment(t, e.getKey(), e.getValue(), setCollector, bitSetFreqTable);
        }
        BitSet union = setCollector.uniteSubGraphs();
        BitSet preLexNonCovered = new BitSet();
        boolean allowUnknown = allowUnknownCFG;
        for (TSNodeLabel node : t.collectNonLexicalNodes()) {
            TSNodeLabelIndex indexed = (TSNodeLabelIndex) node;
            int index = indexed.index;
            if (union.get(index)) {
                continue;
            }
            if (allowUnknown) {
                // Uncovered node: give it a one-node fragment of frequency 1.0. Note that this
                // frequency is not added to rootTableFreq.
                BitSet singleton = new BitSet();
                singleton.set(index);
                setCollector.add(singleton);
                bitSetFreqTable.put(singleton, 1.0);
            } else if (indexed.isPreLexical()) {
                // Uncovered pre-lexical nodes are skipped when collecting substitution sites.
                preLexNonCovered.set(index);
            } else {
                return -1.0;
            }
        }
        TSNodeLabelStructure tStructure = new TSNodeLabelStructure(t);
        ProbChart pc = new ProbChart(setCollector, tStructure, preLexNonCovered, bitSetFreqTable);
        return pc.getProb();
    }

    /**
     * Recursively finds every node of {@code t} at which {@code fragment} matches. For each
     * match it records the set of covered node indexes in {@code setCollector} and maps that
     * set to {@code fragmentFreq} in {@code bitSetFreqTable}.
     */
    private static void getCFGSetCoveringFragment(TSNodeLabelIndex t, TSNodeLabel fragment,
            double fragmentFreq, NodeSetCollector setCollector,
            HashMap<BitSet, Double> bitSetFreqTable) {
        if (t.isLexical) {
            return;
        }
        if (t.sameLabel(fragment)) {
            BitSet covered = new BitSet();
            if (getCFGSetCoveringFragmentNonRecursive(t, fragment, covered) && !covered.isEmpty()) {
                setCollector.add(covered);
                bitSetFreqTable.put(covered, fragmentFreq);
            }
        }
        for (TSNodeLabel daughter : t.daughters) {
            getCFGSetCoveringFragment((TSNodeLabelIndex) daughter, fragment, fragmentFreq,
                    setCollector, bitSetFreqTable);
        }
    }

    /**
     * Checks whether {@code fragment} matches {@code t} top-down (root labels are assumed
     * already equal). On success, sets in {@code set} the index of every tree node matched by an
     * internal (non-terminal-leaf) fragment node.
     *
     * @return true if the fragment matches at {@code t}
     */
    private static boolean getCFGSetCoveringFragmentNonRecursive(TSNodeLabelIndex t,
            TSNodeLabel fragment, BitSet set) {
        if (t.isLexical || fragment.isTerminal()) {
            // Fragment frontier reached: nothing more to match or mark here.
            return true;
        }
        if (!t.sameDaughtersLabel(fragment)) {
            return false;
        }
        int prole = t.prole();
        for (int i = 0; i < prole; i++) {
            TSNodeLabelIndex treeDaughter = (TSNodeLabelIndex) t.daughters[i];
            if (!getCFGSetCoveringFragmentNonRecursive(treeDaughter, fragment.daughters[i], set)) {
                return false;
            }
        }
        set.set(t.index);
        return true;
    }

    /**
     * The original implementation failed to compile (see the message), so calling this always
     * throws.
     *
     * @throws Error always
     */
    public static void rerank(int nBest) throws Exception {
        throw new Error("Unresolved compilation problem: \n\tThe method staticEvalF(File, File, File, boolean) is undefined for the type EvalC\n");
    }

    /**
     * The original implementation failed to compile (see the message), so calling this always
     * throws.
     *
     * @throws Error always
     */
    public static void rerankEM(int nBest, int cycle) throws Exception {
        throw new Error("Unresolved compilation problem: \n\tThe method staticEvalF(File, File, File, boolean) is undefined for the type EvalC\n");
    }

    /** Runs {@link #rerank} for n-best sizes 5, 10 and 100. */
    public static void main1(String[] args) throws Exception {
        int[] nBestValues = {5, 10, 100};
        for (int nBest : nBestValues) {
            rerank(nBest);
        }
    }

    /** Enables unknown CFG rules and runs one EM rerank cycle for n-best sizes 10 and 100. */
    public static void main(String[] args) throws Exception {
        allowUnknownCFG = true;
        DOP1_reranker.rerankEM(10, 1);
        DOP1_reranker.rerankEM(100, 1);
    }

    /**
     * Memoized chart computing the DOP1 probability of a tree from the fragment occurrences
     * (as node-index BitSets) found in it.
     */
    static class ProbChart {
        NodeSetCollectorSimple setCollector;
        TSNodeLabelStructure t;
        int totalNodes;
        BitSet preLexNonCovered;
        /** Memoized probability per node index; -1.0 means "not yet computed". */
        double[] probNodes;
        /** Fragment occurrences grouped by the index of their root (lowest set bit). */
        NodeSetCollectorSimple[] nodesCollector;
        HashMap<BitSet, Double> bitSetFreqTable;

        public ProbChart(NodeSetCollectorSimple setCollector, TSNodeLabelStructure t,
                BitSet preLexNonCovered, HashMap<BitSet, Double> bitSetFreqTable) {
            this.setCollector = setCollector;
            this.t = t;
            this.preLexNonCovered = preLexNonCovered;
            this.totalNodes = t.length;
            this.probNodes = new double[this.totalNodes];
            Arrays.fill(this.probNodes, -1.0);
            this.nodesCollector = new NodeSetCollectorSimple[this.totalNodes];
            this.bitSetFreqTable = bitSetFreqTable;
        }

        /**
         * Groups fragment occurrences by their first (lowest) set index and returns the
         * probability of the node at index 0.
         */
        public double getProb() {
            for (BitSet bs : this.setCollector.bitSetSet) {
                int firstIndex = bs.nextSetBit(0);
                if (this.nodesCollector[firstIndex] == null) {
                    this.nodesCollector[firstIndex] = new NodeSetCollectorSimple();
                }
                this.nodesCollector[firstIndex].add(bs);
            }
            return this.getProbRecursive(0);
        }

        /**
         * Sum over fragments rooted at {@code index} of
         * {@code freq / rootFreq(label) * product(prob(substitution sites))}, memoized.
         *
         * @throws IllegalStateException if root frequencies are missing for this node's label
         */
        private double getProbRecursive(int index) {
            if (this.probNodes[index] != -1.0) {
                return this.probNodes[index];
            }
            NodeSetCollectorSimple rootedHere = this.nodesCollector[index];
            if (rootedHere == null) {
                this.probNodes[index] = 0.0;
                return 0.0;
            }
            TSNodeLabelIndex root = this.t.structure[index];
            double[] rootEntry = rootTableFreq == null ? null : rootTableFreq.get(root.label);
            if (rootEntry == null) {
                // Previously an opaque NPE; reachable e.g. for unknown-CFG singleton fragments
                // whose label never occurs as a fragment root.
                throw new IllegalStateException("No root frequency for label " + root.label
                        + " (was getRootFreq() called after all fragments were loaded?)");
            }
            double rootFreq = rootEntry[0];
            double prob = 0.0;
            for (BitSet initialSubTree : rootedHere.bitSetSet) {
                ArrayList<Integer> subSitesIndexes = new ArrayList<Integer>();
                this.collectSubSites(root, initialSubTree, subSitesIndexes);
                double partialProb = 1.0;
                for (int subSiteIndex : subSitesIndexes) {
                    double subSiteProb = this.getProbRecursive(subSiteIndex);
                    if (subSiteProb == 0.0) {
                        partialProb = 0.0;
                        break;
                    }
                    partialProb *= subSiteProb;
                }
                if (partialProb == 0.0) {
                    continue;
                }
                double initialSubTreeFreq = this.bitSetFreqTable.get(initialSubTree);
                prob += partialProb * (initialSubTreeFreq / rootFreq);
            }
            this.probNodes[index] = prob;
            return prob;
        }

        /**
         * Collects the indexes of the substitution sites of {@code initialSubTree} below
         * {@code root}: daughters not in the subtree (and not uncovered pre-lexical nodes).
         * Stops at the first lexical daughter, as the original implementation did.
         */
        private void collectSubSites(TSNodeLabelIndex root, BitSet initialSubTree,
                ArrayList<Integer> subSitesIndexes) {
            for (TSNodeLabel daughter : root.daughters) {
                if (daughter.isLexical) {
                    return;
                }
                TSNodeLabelIndex indexed = (TSNodeLabelIndex) daughter;
                int index = indexed.index;
                if (this.preLexNonCovered.get(index)) {
                    continue;
                }
                if (initialSubTree.get(index)) {
                    this.collectSubSites(indexed, initialSubTree, subSitesIndexes);
                } else {
                    subSitesIndexes.add(index);
                }
            }
        }
    }
}

