/*
 * Decompiled with CFR 0.152.
 */
package tsg.parsingExp;

import java.io.File;
import java.io.PrintWriter;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.MathContext;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.Map;
import java.util.Scanner;
import java.util.Vector;
import settings.Parameters;
import tsg.TSNodeLabel;
import tsg.TSNodeLabelIndex;
import tsg.TSNodeLabelStructure;
import util.ArgumentReader;
import util.FileUtil;
import util.PrintProgress;
import util.Utility;

/**
 * Accumulates, for every fragment of a fragment list, a real-valued count over a treebank.
 *
 * <p>For each tree and each occurrence of a fragment in it, the value added to the fragment's
 * count is a derivation count computed for that occurrence divided by the total number of
 * derivations found for the tree (see {@code updateTableWithFragmCounts}). Fragments are either
 * read from {@link #fragmentsFile} or supplied directly; the resulting table can be written to
 * {@link #outputFile}.
 *
 * <p>The class extends {@link Thread}, but {@code main} in this file calls {@link #run()}
 * directly on the calling thread.
 */
public class RetrieveBonnemaCounts
extends Thread {
    /** Maximum number of trees handed out by one call to {@code getNextTreeLoad()}. */
    public static int numberOfTreesPerThreads = 10;
    /** Set to {@code true} by {@code main2}; not read anywhere in this file. */
    static boolean debug = false;
    /** Tab-separated input file; the first column of each non-empty line is a fragment. */
    File fragmentsFile;
    /** Destination of the "fragment TAB count" lines written by {@code writeFragmentsToFile()}. */
    File outputFile;
    /** Number of worker threads; the value 1 selects the single-threaded code path. */
    int threads = 1;
    // The three counters below are declared but not referenced by any method in this file.
    long fragmentReadCounter;
    long fragmentWrittenCounter;
    long currentIndex;
    /** Progress reporter, re-created for each phase (fragment loading, counting). */
    PrintProgress progress;
    /** Trees over which the counts are accumulated. */
    ArrayList<TSNodeLabelIndex> treebank;
    /** Fragments whose counts are computed. */
    ArrayList<TSNodeLabel> fragmentList;
    /** Result table: fragment -> array whose element 0 holds the accumulated count. */
    Hashtable<TSNodeLabel, double[]> finalFragmentsCount;
    /** Shared cursor over {@link #treebank}, consumed in batches by {@code getNextTreeLoad()}. */
    Iterator<TSNodeLabelIndex> treeIterator;
    /** Smallest per-tree derivation total seen so far; initialised to 9999999999 by the constructors. */
    BigInteger minDerivations;
    /** Largest per-tree derivation total seen so far; initialised to -1 by the constructors. */
    BigInteger maxDerivations;

    /**
     * Creates a counter whose fragments are loaded from a file when {@link #run()} is invoked.
     *
     * @param treebank      trees over which the counts are accumulated
     * @param fragmentsFile file holding one fragment per line (first tab-separated column)
     * @param outputFile    file that receives the fragments with their counts
     * @param threads       number of worker threads to use
     */
    public RetrieveBonnemaCounts(ArrayList<TSNodeLabelIndex> treebank, File fragmentsFile, File outputFile, int threads) {
        // Input/output configuration.
        this.fragmentsFile = fragmentsFile;
        this.outputFile = outputFile;
        this.threads = threads;

        // Treebank and the shared cursor over it.
        this.treebank = treebank;
        this.treeIterator = treebank.iterator();

        // Sentinels for the running minimum/maximum of per-tree derivation totals.
        this.minDerivations = BigInteger.valueOf(9999999999L);
        this.maxDerivations = BigInteger.valueOf(-1L);
    }

    /**
     * Creates a counter over an in-memory fragment list. No fragment or output file is set, and
     * the thread count keeps its field default.
     *
     * @param treebank  trees over which the counts are accumulated
     * @param fragments fragments whose counts are computed
     */
    public RetrieveBonnemaCounts(ArrayList<TSNodeLabelIndex> treebank, ArrayList<TSNodeLabel> fragments) {
        this.fragmentList = fragments;

        this.treebank = treebank;
        this.treeIterator = treebank.iterator();

        // Sentinels for the running minimum/maximum of per-tree derivation totals.
        this.maxDerivations = BigInteger.valueOf(-1L);
        this.minDerivations = BigInteger.valueOf(9999999999L);
    }

    /**
     * Runs the whole pipeline: load the fragments, compute their counts, write them out.
     * If loading the fragments fails, the stack trace is printed and nothing else is done.
     */
    @Override
    public void run() {
        boolean fragmentsLoaded;
        try {
            getFragmentList();
            fragmentsLoaded = true;
        }
        catch (Exception e) {
            e.printStackTrace();
            fragmentsLoaded = false;
        }
        if (!fragmentsLoaded) {
            return;
        }
        retriveBonnemaCounts();
        writeFragmentsToFile();
    }

    /**
     * Reads the fragments from {@link #fragmentsFile} into {@link #fragmentList}.
     *
     * <p>Each non-empty line is split on tabs and its first column is parsed as a fragment;
     * any further columns are ignored. The scanner is always closed, including when a line
     * fails to parse.
     *
     * @throws Exception if the file cannot be opened or a fragment cannot be parsed
     */
    private void getFragmentList() throws Exception {
        Parameters.reportLine("Extracting fragments from file: " + this.fragmentsFile);
        this.progress = new PrintProgress("Progress:", 10000, 0);
        Scanner fragmentsScanner = FileUtil.getScanner(this.fragmentsFile);
        try {
            this.fragmentList = new ArrayList<TSNodeLabel>();
            while (fragmentsScanner.hasNextLine()) {
                this.progress.next();
                String line = fragmentsScanner.nextLine();
                if (line.equals("")) {
                    continue;
                }
                String[] lineSplit = line.split("\t");
                TSNodeLabel fragment = new TSNodeLabel(lineSplit[0], false);
                this.fragmentList.add(fragment);
            }
        }
        finally {
            // Release the underlying file handle even if parsing aborted half-way.
            fragmentsScanner.close();
        }
        this.progress.end();
        Parameters.reportLine("Extracted fragments: " + this.fragmentList.size());
    }

    /**
     * Computes the counts of all fragments over the whole treebank and stores them in a fresh
     * {@link #finalFragmentsCount} table.
     *
     * <p>A thread count of 1 or less is processed on the calling thread; larger values are
     * delegated to {@code startMultiThreads()}. Any exception raised while counting is printed
     * and ends the method early (the progress reporter is then not closed).
     */
    public void retriveBonnemaCounts() {
        this.finalFragmentsCount = new Hashtable<TSNodeLabel, double[]>();
        // NOTE(review): the message says "Goodman" although this class computes Bonnema counts;
        // left untouched because it is runtime output - confirm whether it should be renamed.
        Parameters.reportLineFlush("Retrieving Goodman Counts");
        this.progress = new PrintProgress("Extracting from tree:", 100, 0);
        try {
            // A non-positive thread count would otherwise start no worker at all (0) or fail
            // while allocating the worker array (negative), so it is treated as single-threaded.
            if (this.threads <= 1) {
                this.updateTableWithFragmCounts(this.finalFragmentsCount, this.treebank);
            } else {
                this.startMultiThreads();
            }
        }
        catch (Exception e) {
            e.printStackTrace();
            return;
        }
        this.progress.end();
    }

    /**
     * Adds to {@code tableToUpdate}, for every tree in {@code trees} and every occurrence of a
     * fragment of {@link #fragmentList} in that tree, the ratio
     * {@code derivationsWithFragment / totalDerivationsInTree}.
     *
     * <p>Per tree the method
     * <ol>
     *   <li>records, for each node, the fragments rooted there together with their frontier
     *       node indexes;</li>
     *   <li>fills the per-node "starting" and "arriving" derivation counts through
     *       {@code updateDerivationsStarting} / {@code updateDerivationsArriving};</li>
     *   <li>updates {@link #minDerivations} / {@link #maxDerivations} with the tree's total
     *       (the starting count of node 0);</li>
     *   <li>for each fragment occurrence multiplies the arriving count of its root by the
     *       starting counts of its non-lexical frontier nodes and divides by the total.</li>
     * </ol>
     *
     * <p>Trees whose total is zero are reported on {@code System.err} and contribute nothing;
     * dividing by their total would otherwise throw {@link ArithmeticException}.
     *
     * <p>This method is called concurrently by the worker threads, so the min/max update is
     * guarded by this object's monitor.
     *
     * @param tableToUpdate table receiving the accumulated ratios (fragment -> count array)
     * @param trees         trees to process; the "tree index" printed in warnings is the
     *                      1-based position within this list, not within the whole treebank
     */
    private void updateTableWithFragmCounts(Hashtable<TSNodeLabel, double[]> tableToUpdate, ArrayList<TSNodeLabelIndex> trees) {
        int treeindex = 0;
        for (TSNodeLabelIndex tree : trees) {
            ++treeindex;
            TSNodeLabelStructure treeStructure = new TSNodeLabelStructure(tree);
            int length = treeStructure.length();

            // One table per node: fragment rooted at that node -> indexes of its frontier nodes.
            Vector<Hashtable<TSNodeLabel, BitSet>> nodeStartingFragments = new Vector<Hashtable<TSNodeLabel, BitSet>>(length);
            for (int i = 0; i < length; ++i) {
                nodeStartingFragments.add(new Hashtable<TSNodeLabel, BitSet>());
            }
            for (TSNodeLabel fragment : this.fragmentList) {
                ArrayList<BitSet> rootFrontiers = this.getRootFrontierIndexes(treeStructure, fragment);
                if (rootFrontiers == null) {
                    continue;
                }
                for (BitSet rf : rootFrontiers) {
                    // The lowest set bit is taken as the root - assumes a node's index is
                    // smaller than the indexes of all nodes below it (TODO confirm against
                    // TSNodeLabelStructure's numbering).
                    int root = rf.nextSetBit(0);
                    rf.clear(root);
                    nodeStartingFragments.get(root).put(fragment, rf);
                }
            }

            BigInteger[] nodeStartingDerivations = new BigInteger[length];
            BigInteger[] nodeArrivingDerivations = new BigInteger[length];
            Arrays.fill(nodeStartingDerivations, BigInteger.ZERO);
            Arrays.fill(nodeArrivingDerivations, BigInteger.ZERO);
            nodeArrivingDerivations[0] = BigInteger.ONE;

            BitSet lexIndex = new BitSet();
            ArrayList<TSNodeLabel> lexArray = tree.collectLexicalItems();
            for (TSNodeLabel l : lexArray) {
                lexIndex.set(((TSNodeLabelIndex)l).index);
            }

            int[] nodesFromLeavesToTop = Utility.countDownArray(length);
            int[] nodesFromTopToLeaves = Utility.countUpArray(length);
            RetrieveBonnemaCounts.updateDerivationsStarting(nodeStartingFragments, nodeStartingDerivations, lexIndex, nodesFromLeavesToTop);
            RetrieveBonnemaCounts.updateDerivationsArriving(nodeStartingFragments, nodeArrivingDerivations, lexIndex, nodesFromTopToLeaves);

            BigInteger totalDerivationsInTree = nodeStartingDerivations[0];
            boolean hasDerivations = totalDerivationsInTree.signum() > 0;
            if (!hasDerivations) {
                System.err.println("Tot der = " + totalDerivationsInTree + ", lex length = " + lexIndex.cardinality() + " tree index: " + treeindex);
            }

            // Check-then-set on shared fields: must be atomic with respect to other workers.
            synchronized (this) {
                if (totalDerivationsInTree.compareTo(this.minDerivations) < 0) {
                    this.minDerivations = totalDerivationsInTree;
                }
                if (totalDerivationsInTree.compareTo(this.maxDerivations) > 0) {
                    this.maxDerivations = totalDerivationsInTree;
                }
            }

            if (!hasDerivations) {
                // No ratio can be formed for this tree (the divisor would be zero).
                continue;
            }

            // The divisor is the same for every fragment of this tree: build it once.
            BigDecimal totalAsDecimal = new BigDecimal(totalDerivationsInTree);
            for (int rootIndex = 0; rootIndex < length; ++rootIndex) {
                BigInteger arrivingAtRoot = nodeArrivingDerivations[rootIndex];
                for (Map.Entry<TSNodeLabel, BitSet> e : nodeStartingFragments.get(rootIndex).entrySet()) {
                    BitSet frontiers = e.getValue();
                    BigInteger derivationsWithFragment = arrivingAtRoot;
                    for (int frontierIndex = frontiers.nextSetBit(0); frontierIndex >= 0; frontierIndex = frontiers.nextSetBit(frontierIndex + 1)) {
                        // Lexical frontier nodes contribute no factor.
                        if (!lexIndex.get(frontierIndex)) {
                            derivationsWithFragment = derivationsWithFragment.multiply(nodeStartingDerivations[frontierIndex]);
                        }
                    }
                    double ratio = new BigDecimal(derivationsWithFragment).divide(totalAsDecimal, MathContext.DECIMAL128).doubleValue();
                    Utility.increaseInTableDoubleArray(tableToUpdate, e.getKey(), ratio);
                }
            }
        }
    }

    /**
     * Accumulates, for every non-lexical node, the number of derivations starting at it.
     *
     * <p>Nodes are visited in the order given by {@code nodesFromLeavesToTop}. For each fragment
     * rooted at the visited node, the product of the current {@code nodeStartingDerivations}
     * values of its non-lexical frontier nodes is added to the node's own entry (a fragment
     * with only lexical frontier nodes adds 1).
     *
     * <p>Every frontier set is expected to contain at least one bit; an empty set makes the
     * first {@code lexIndex.get} call fail with an {@link IndexOutOfBoundsException}.
     *
     * @param nodeStartingFragments   per node: fragment rooted there -> frontier node indexes
     * @param nodeStartingDerivations per node accumulator, updated in place
     * @param lexIndex                indexes of the lexical nodes
     * @param nodesFromLeavesToTop    visiting order of the node indexes
     */
    private static void updateDerivationsStarting(Vector<Hashtable<TSNodeLabel, BitSet>> nodeStartingFragments, BigInteger[] nodeStartingDerivations, BitSet lexIndex, int[] nodesFromLeavesToTop) {
        for (int node : nodesFromLeavesToTop) {
            if (lexIndex.get(node)) {
                continue;
            }
            for (BitSet frontierSet : nodeStartingFragments.get(node).values()) {
                BigInteger product = BigInteger.ONE;
                int frontier = frontierSet.nextSetBit(0);
                do {
                    boolean lexical = lexIndex.get(frontier);
                    if (!lexical) {
                        product = product.multiply(nodeStartingDerivations[frontier]);
                    }
                    frontier = frontierSet.nextSetBit(frontier + 1);
                } while (frontier != -1);
                nodeStartingDerivations[node] = nodeStartingDerivations[node].add(product);
            }
        }
    }

    /**
     * Propagates the "arriving" derivation counts from fragment roots to their frontier nodes.
     *
     * <p>Nodes are visited in the order given by {@code nodesFromTopToLeaves}. For each fragment
     * rooted at the visited non-lexical node, the node's arriving count (read once, before its
     * fragments are examined) is added to the arriving count of every non-lexical frontier node
     * of that fragment.
     *
     * <p>NOTE(review): the starting-derivation counts of a fragment's other frontier nodes are
     * not factored into the value passed down - confirm this matches the intended definition
     * of the arriving count.
     *
     * <p>Every frontier set is expected to contain at least one bit; an empty set makes the
     * first {@code lexIndex.get} call fail with an {@link IndexOutOfBoundsException}.
     *
     * @param nodeStartingFragments   per node: fragment rooted there -> frontier node indexes
     * @param nodeArrivingDerivations per node accumulator, updated in place
     * @param lexIndex                indexes of the lexical nodes
     * @param nodesFromTopToLeaves    visiting order of the node indexes
     */
    private static void updateDerivationsArriving(Vector<Hashtable<TSNodeLabel, BitSet>> nodeStartingFragments, BigInteger[] nodeArrivingDerivations, BitSet lexIndex, int[] nodesFromTopToLeaves) {
        for (int node : nodesFromTopToLeaves) {
            if (lexIndex.get(node)) {
                continue;
            }
            final BigInteger arrivingAtRoot = nodeArrivingDerivations[node];
            for (BitSet frontierSet : nodeStartingFragments.get(node).values()) {
                int frontier = frontierSet.nextSetBit(0);
                do {
                    boolean lexical = lexIndex.get(frontier);
                    if (!lexical) {
                        nodeArrivingDerivations[frontier] = nodeArrivingDerivations[frontier].add(arrivingAtRoot);
                    }
                    frontier = frontierSet.nextSetBit(frontier + 1);
                } while (frontier != -1);
            }
        }
    }

    /**
     * Finds every occurrence of {@code fragment} in the tree described by {@code treeStructure}.
     *
     * <p>Each node of the structure is tried as the root of the fragment. For a successful
     * match the returned bit set holds the index of the root node plus the indexes collected
     * by {@code addFrontierIndexes} for the fragment's terminal nodes. Matches whose bit set
     * holds a single bit are discarded.
     *
     * @param treeStructure indexed nodes of the tree
     * @param fragment      fragment to look for
     * @return one bit set per occurrence; empty (never {@code null}) when there is none
     */
    private ArrayList<BitSet> getRootFrontierIndexes(TSNodeLabelStructure treeStructure, TSNodeLabel fragment) {
        ArrayList<BitSet> occurrences = new ArrayList<BitSet>();
        for (TSNodeLabelIndex candidateRoot : treeStructure.structure) {
            if (!fragment.sameLabelAndDaughersLabels(candidateRoot)) {
                continue;
            }
            BitSet rootAndFrontiers = new BitSet();
            rootAndFrontiers.set(candidateRoot.index);

            // Descend in parallel through tree and fragment, stopping at the first mismatch.
            boolean allDaughtersMatch = true;
            for (int d = 0; allDaughtersMatch && d < candidateRoot.prole(); ++d) {
                allDaughtersMatch = this.addFrontierIndexes((TSNodeLabelIndex)candidateRoot.daughters[d], fragment.daughters[d], rootAndFrontiers);
            }

            if (allDaughtersMatch && rootAndFrontiers.cardinality() > 1) {
                occurrences.add(rootAndFrontiers);
            }
        }
        return occurrences;
    }

    /**
     * Matches the fragment node {@code fragment} against the tree node {@code treeIndex} and
     * records in {@code rf} the tree indexes that correspond to the fragment's terminal nodes.
     *
     * <ul>
     *   <li>A terminal fragment node matches when it has the same label as the tree node; the
     *       tree node's index is then set in {@code rf}.</li>
     *   <li>A non-terminal fragment node matches when it has the same label and daughter
     *       labels as the tree node and all daughters match recursively.</li>
     * </ul>
     *
     * <p>A terminal fragment node whose label differs is reported as a mismatch. Previously
     * that case fell through to the daughter loop, which either indexed the daughters of a
     * terminal fragment node or returned {@code true} without recording any index.
     *
     * @param treeIndex tree node to match against
     * @param fragment  fragment node to match
     * @param rf        bit set collecting the matched indexes (may be partially filled when
     *                  the method returns {@code false})
     * @return {@code true} if the fragment matches the tree at this position
     */
    private boolean addFrontierIndexes(TSNodeLabelIndex treeIndex, TSNodeLabel fragment, BitSet rf) {
        if (fragment.isTerminal()) {
            if (!fragment.sameLabel(treeIndex)) {
                return false;
            }
            rf.set(treeIndex.index);
            return true;
        }
        if (!fragment.sameLabelAndDaughersLabels(treeIndex)) {
            return false;
        }
        for (int i = 0; i < treeIndex.prole(); ++i) {
            if (!this.addFrontierIndexes((TSNodeLabelIndex)treeIndex.daughters[i], fragment.daughters[i], rf)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Processes the treebank with {@link #threads} workers and returns once all of them are done.
     *
     * <p>All workers but the last are started as background threads; the last one is executed
     * on the calling thread, which therefore takes part in the work instead of idling.
     *
     * <p>The background workers are always joined, even if the inline worker throws, so the
     * shared result table is no longer being modified when this method returns. An interrupt
     * received while waiting does not abort the wait; the interrupt status is restored before
     * returning.
     *
     * @throws Exception declared for callers; runtime failures of the inline worker propagate
     */
    private void startMultiThreads() throws Exception {
        // Guard against a non-positive thread count: at least the inline worker must run.
        int workerCount = Math.max(1, this.threads);
        CountFragmentsThread[] threadsArray = new CountFragmentsThread[workerCount];
        int lastThreadIndex = workerCount - 1;
        try {
            for (int t = 0; t < workerCount; ++t) {
                CountFragmentsThread newCounterThread = new CountFragmentsThread();
                threadsArray[t] = newCounterThread;
                if (t == lastThreadIndex) {
                    // Run the last worker on the current thread.
                    newCounterThread.run();
                } else {
                    newCounterThread.start();
                }
            }
        }
        finally {
            boolean interrupted = false;
            for (int i = 0; i < lastThreadIndex; ++i) {
                CountFragmentsThread startedWorker = threadsArray[i];
                if (startedWorker == null) {
                    // Creation was cut short by an exception; nothing to wait for.
                    continue;
                }
                while (true) {
                    try {
                        startedWorker.join();
                        break;
                    }
                    catch (InterruptedException e) {
                        // Keep waiting: results are only complete once the worker has ended.
                        interrupted = true;
                    }
                }
            }
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Writes every fragment of {@link #finalFragmentsCount} to {@link #outputFile}, one per
     * line, as the fragment's string form, a tab, and its accumulated count.
     *
     * <p>The writer is closed even if formatting a fragment fails. Because {@link PrintWriter}
     * does not throw on I/O errors, its error flag is checked and a failure is reported on
     * {@code System.err}.
     */
    private void writeFragmentsToFile() {
        Parameters.reportLine("Printing fragments with new counts to file: " + this.outputFile);
        PrintWriter pw = FileUtil.getPrintWriter(this.outputFile);
        try {
            for (Map.Entry<TSNodeLabel, double[]> e : this.finalFragmentsCount.entrySet()) {
                TSNodeLabel fragment = e.getKey();
                double count = e.getValue()[0];
                pw.println(fragment.toString(false, true) + "\t" + count);
            }
            // checkError() flushes and reports any write error swallowed by the writer.
            if (pw.checkError()) {
                System.err.println("Error while writing fragments to file: " + this.outputFile);
            }
        }
        finally {
            pw.close();
        }
    }

    /**
     * Hands out the next batch of trees from the shared {@link #treeIterator} and advances the
     * progress reporter by the size of the batch.
     *
     * <p>The batch size is {@link #numberOfTreesPerThreads}, read once per call and clamped to
     * at least 1: a value of 0 would otherwise produce empty batches forever (callers loop
     * until {@code null} is returned), and a negative value would fail when sizing the list.
     *
     * @return up to one batch of trees, or {@code null} when the treebank is exhausted
     */
    private synchronized ArrayList<TSNodeLabelIndex> getNextTreeLoad() {
        if (!this.treeIterator.hasNext()) {
            return null;
        }
        int batchSize = Math.max(1, numberOfTreesPerThreads);
        ArrayList<TSNodeLabelIndex> treesForThread = new ArrayList<TSNodeLabelIndex>(batchSize);
        while (treesForThread.size() < batchSize && this.treeIterator.hasNext()) {
            treesForThread.add(this.treeIterator.next());
        }
        this.progress.next(treesForThread.size());
        return treesForThread;
    }

    /**
     * Merges the counts gathered by one worker into {@link #finalFragmentsCount}; element 0 of
     * each count array is added to the corresponding fragment's total.
     *
     * @param threadFragmentCount counts collected for one batch of trees
     */
    private synchronized void addFragmentsToFinalTable(Hashtable<TSNodeLabel, double[]> threadFragmentCount) {
        for (TSNodeLabel fragment : threadFragmentCount.keySet()) {
            double[] partialCount = threadFragmentCount.get(fragment);
            Utility.increaseInTableDoubleArray(this.finalFragmentsCount, fragment, partialCount[0]);
        }
    }

    /**
     * Small built-in example: enables the debug flag and computes the counts of nine
     * hand-written fragments over a single hard-coded tree. Nothing is printed or written
     * by this method itself.
     *
     * @param args ignored
     * @throws Exception if the tree or a fragment cannot be built
     */
    public static void main2(String[] args) throws Exception {
        debug = true;

        TSNodeLabelIndex tree = new TSNodeLabelIndex("(A ( B (D h i) (E l m)) (C (F n o) (G p q)))");
        ArrayList<TSNodeLabelIndex> treebank = new ArrayList<TSNodeLabelIndex>();
        treebank.add(tree);

        String[] fragmentStrings = {
            "(A B C)",
            "(B D E)",
            "(C F G)",
            "(D \"h\" \"i\")",
            "(E \"l\" \"m\")",
            "(F \"n\" \"o\")",
            "(G \"p\" \"q\")",
            "(A B (C F G))",
            "(B (D \"h\" \"i\") (E \"l\" \"m\"))"
        };
        ArrayList<TSNodeLabel> fragmentsList = new ArrayList<TSNodeLabel>();
        for (String fragmentString : fragmentStrings) {
            fragmentsList.add(new TSNodeLabel(fragmentString, false));
        }

        RetrieveBonnemaCounts counter = new RetrieveBonnemaCounts(treebank, fragmentsList);
        counter.retriveBonnemaCounts();
    }

    /**
     * Convenience entry point that runs {@link #main(String[])} on a fixed set of files under
     * {@code tmp/Bonnema/}.
     *
     * <p>It previously invoked itself instead of {@code main}, which recursed until a
     * {@link StackOverflowError}.
     *
     * @param args ignored
     * @throws Exception propagated from {@code main}
     */
    public static void main1(String[] args) throws Exception {
        RetrieveBonnemaCounts.main(new String[]{"tmp/Bonnema/trainingTreebank_UK_first100.mrg", "tmp/Bonnema/fragmentsAndCfgRules.txt", "tmp/Bonnema/fragmentsAndCfgRules_bonnema.txt"});
    }

    /**
     * Command-line entry point.
     *
     * <p>Expects exactly three positional arguments, in this order: treebank file, fragments
     * file, output file. An optional {@code -threads:N} argument may appear anywhere among
     * them. The treebank and fragments files must exist. After the run, the minimum and
     * maximum per-tree derivation totals and the elapsed time are printed.
     *
     * <p>Invalid invocations (wrong number of positional arguments, thread count below 1,
     * missing input files) print a message and the usage line on {@code System.err} and
     * return without doing any work.
     *
     * @param args command-line arguments
     * @throws Exception if reading the treebank fails
     */
    public static void main(String[] args) throws Exception {
        long time = System.currentTimeMillis();
        String usage = "USAGE: java RetrieveBonnemaCounts [-threads:1] treebankFile fragmentsFile outputFile";
        String threadsOption = "-threads:";
        int threads = 1;
        int length = args.length;
        if (length < 3 || length > 4) {
            System.err.println("Incorrect number of arguments");
            System.err.println(usage);
            return;
        }
        File treebankFile = null;
        File inputFile = null;
        File outputFile = null;
        int positionalCount = 0;
        for (String option : args) {
            if (option.startsWith(threadsOption)) {
                threads = ArgumentReader.readIntOption(option);
                continue;
            }
            if (positionalCount == 0) {
                treebankFile = new File(option);
            } else if (positionalCount == 1) {
                inputFile = new File(option);
            } else if (positionalCount == 2) {
                outputFile = new File(option);
            }
            ++positionalCount;
        }
        // Exactly three file arguments are required; counting them rejects both a missing
        // output file (e.g. "-threads:2 a b") and a surplus one that would silently replace it.
        if (positionalCount != 3) {
            System.err.println("Incorrect number of arguments");
            System.err.println(usage);
            return;
        }
        if (threads < 1) {
            System.err.println("Number of threads must be at least 1");
            System.err.println(usage);
            return;
        }
        if (!treebankFile.exists()) {
            System.err.println("Treebank File does not exist");
            System.err.println(usage);
            return;
        }
        if (!inputFile.exists()) {
            System.err.println("Input File does not exist");
            System.err.println(usage);
            return;
        }
        ArrayList<TSNodeLabel> treebank = TSNodeLabel.getTreebank(treebankFile);
        ArrayList<TSNodeLabelIndex> treebankIndex = new ArrayList<TSNodeLabelIndex>();
        for (TSNodeLabel t : treebank) {
            treebankIndex.add(new TSNodeLabelIndex(t));
        }
        RetrieveBonnemaCounts RCF = new RetrieveBonnemaCounts(treebankIndex, inputFile, outputFile, threads);
        // Deliberately run() rather than start(): the work is done on the calling thread so the
        // statistics below are printed only after it has finished.
        RCF.run();
        System.out.println("Min derivations: " + RCF.minDerivations);
        System.out.println("Max derivations: " + RCF.maxDerivations);
        System.out.println("Took: " + (System.currentTimeMillis() - time) / 1000L + "seconds.");
    }

    /**
     * Worker that repeatedly takes a batch of trees from the enclosing instance, computes the
     * fragment counts of that batch in a private table, and merges the table into the shared
     * result. It stops when no batch is left.
     */
    private class CountFragmentsThread
    extends Thread {
        private CountFragmentsThread() {
        }

        @Override
        public void run() {
            for (ArrayList<TSNodeLabelIndex> batch = RetrieveBonnemaCounts.this.getNextTreeLoad();
                    batch != null;
                    batch = RetrieveBonnemaCounts.this.getNextTreeLoad()) {
                // A table private to this batch keeps the shared table's lock out of the hot loop.
                Hashtable<TSNodeLabel, double[]> batchCounts = new Hashtable<TSNodeLabel, double[]>();
                RetrieveBonnemaCounts.this.updateTableWithFragmCounts(batchCounts, batch);
                RetrieveBonnemaCounts.this.addFragmentsToFinalTable(batchCounts);
            }
        }
    }
}

