/*
 * Decompiled with CFR 0.152.
 */
package tsg;

import java.io.File;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.Map;
import java.util.Scanner;
import java.util.TreeSet;
import kernels.NodeSetCollector;
import kernels.NodeSetCollectorSimple;
import kernels.NodeSetCollectorStandard;
import tsg.Label;
import tsg.TSNodeLabel;
import tsg.TSNodeLabelIndex;
import tsg.TSNodeLabelStructure;
import tsg.corpora.Wsj;
import util.FileUtil;
import util.PrintProgressStatic;
import util.Utility;

public class DOP_EM {
    public static Hashtable<TSNodeLabel, double[]> fragmentTableFreq;
    public static Hashtable<Label, double[]> rootTableFreq;
    public static ArrayList<TSNodeLabelIndex> trainingCorpus;
    public static int endCycle;
    public static String workingDir;

    static {
        endCycle = 10;
    }

    /**
     * Loads the fragment table from a text file, replacing whatever
     * {@link #fragmentTableFreq} held before.
     * <p>
     * Every non-empty line is split on tabs: the first field is the fragment in
     * the string form understood by the {@code TSNodeLabel} constructor, the
     * second its frequency. The frequency may be integral or fractional, so
     * values in the format written by {@link #printFragmentFreq} (for example
     * {@code 3.0}) are accepted.
     *
     * @param fragmentFile file to read
     * @throws Exception on any failure while obtaining the scanner or building a
     *         fragment; a line without a tab or with a non-numeric frequency
     *         surfaces as an unchecked exception
     */
    public static void readFragmentsFile(File fragmentFile) throws Exception {
        fragmentTableFreq = new Hashtable<TSNodeLabel, double[]>();
        int countFragments = 0;
        Scanner scan = FileUtil.getScanner(fragmentFile);
        try {
            while (scan.hasNextLine()) {
                String line = scan.nextLine();
                if (line.equals("")) {
                    continue;
                }
                ++countFragments;
                String[] fragmentFreq = line.split("\t");
                // parseDouble reads both "12" and "12.0".
                double freq = Double.parseDouble(fragmentFreq[1]);
                TSNodeLabel fragment = new TSNodeLabel(fragmentFreq[0], false);
                fragmentTableFreq.put(fragment, new double[]{freq});
            }
        } finally {
            // Release the underlying file even when a line fails to parse.
            scan.close();
        }
        System.out.println("Read " + countFragments + " fragments");
    }

    /**
     * Rebuilds {@link #trainingCorpus} from scratch, wrapping every tree of the
     * given treebank in a {@code TSNodeLabelIndex}, in the same order.
     *
     * @param treebank trees to wrap
     * @throws Exception if wrapping a tree fails
     */
    private static void readTreeBank(ArrayList<TSNodeLabel> treebank) throws Exception {
        trainingCorpus = new ArrayList<TSNodeLabelIndex>();
        Iterator<TSNodeLabel> trees = treebank.iterator();
        while (trees.hasNext()) {
            TSNodeLabelIndex indexedTree = new TSNodeLabelIndex(trees.next());
            trainingCorpus.add(indexedTree);
        }
    }

    /**
     * Writes the current fragment table to a file, one line per fragment: the
     * fragment rendered with {@code toString(false, true)}, a tab, and the
     * frequency as a {@code double}.
     *
     * @param outputFile destination file
     */
    public static void printFragmentFreq(File outputFile) {
        PrintWriter pw = FileUtil.getPrintWriter(outputFile);
        try {
            for (Map.Entry<TSNodeLabel, double[]> e : fragmentTableFreq.entrySet()) {
                String fragmentString = e.getKey().toString(false, true);
                double freq = e.getValue()[0];
                pw.println(fragmentString + "\t" + freq);
            }
        } finally {
            // Flush and release the file even if rendering a fragment throws.
            pw.close();
        }
    }

    /*
     * WARNING - void declaration
     */
    public static void addCFGfragments() throws Exception {
        void var1_4;
        Hashtable ruleTable = new Hashtable();
        for (TSNodeLabel tSNodeLabel : trainingCorpus) {
            ArrayList<TSNodeLabel> nodes = tSNodeLabel.collectAllNodes();
            for (TSNodeLabel n : nodes) {
                if (n.isLexical) continue;
                String rule = n.cfgRule();
                Utility.increaseInTableInt(ruleTable, rule);
            }
        }
        System.out.println("Read " + ruleTable.size() + " CFG fragments");
        boolean bl = false;
        for (Map.Entry e : ruleTable.entrySet()) {
            TSNodeLabel ruleFragment = new TSNodeLabel("( " + (String)e.getKey() + ")", false);
            if (fragmentTableFreq.containsKey(ruleFragment)) continue;
            double freq = ((int[])e.getValue())[0];
            fragmentTableFreq.put(ruleFragment, new double[]{freq});
            ++var1_4;
        }
        System.out.println("Added " + (int)var1_4 + " CFG fragments");
    }

    /**
     * Rebuilds {@link #rootTableFreq} from the current fragment table by feeding
     * every fragment's frequency to {@code Utility.increaseInTableDoubleArray}
     * under the label of the fragment's root, then prints the number of labels.
     */
    public static void getRootFreq() {
        rootTableFreq = new Hashtable<Label, double[]>();
        Iterator<Map.Entry<TSNodeLabel, double[]>> entries = fragmentTableFreq.entrySet().iterator();
        while (entries.hasNext()) {
            Map.Entry<TSNodeLabel, double[]> entry = entries.next();
            Utility.increaseInTableDoubleArray(rootTableFreq, entry.getKey().label, entry.getValue()[0]);
        }
        System.out.println("Built root freq. table: " + rootTableFreq.size() + " entries.");
    }

    /**
     * Runs the EM re-estimation loop over {@link #trainingCorpus}.
     * <p>
     * The starting frequencies are first written to the cycle-0 file. Each cycle
     * then builds a fresh frequency table from every tree still in play and sums
     * the logarithms of the tree probabilities. A tree whose probability comes
     * out as zero is reported and skipped in all later cycles. The loop ends,
     * without adopting the new table, as soon as the log-likelihood falls below
     * that of the previous cycle. Otherwise the new table replaces
     * {@link #fragmentTableFreq}, the root totals are rebuilt, the table is
     * written to that cycle's file, and iteration continues until the cycle
     * counter equals {@link #endCycle}.
     */
    public static void runEM() {
        int cycle = 0;
        double previousLogLikelihood = -Double.MAX_VALUE;
        BitSet structuresRemoved = new BitSet();
        DOP_EM.printFragmentFreq(new File(workingDir + "kernelsMUB_CFG_freq_EM_cycle_" + cycle + ".txt"));
        do {
            Hashtable<TSNodeLabel, double[]> newFragmentTableFreq = new Hashtable<TSNodeLabel, double[]>();
            // The log-likelihood is a sum of logs, so it accumulates from zero.
            double currentLogLikelihood = 0.0;
            PrintProgressStatic.start("Iterating Training Corpus:");
            int index = -1;
            for (TSNodeLabelIndex t : trainingCorpus) {
                PrintProgressStatic.next();
                ++index;
                if (structuresRemoved.get(index)) {
                    continue;
                }
                double prob = DOP_EM.updateNewFragmentTableFreq(t, newFragmentTableFreq);
                if (prob == 0.0) {
                    System.out.println("Zero prob. in sentence index: " + index + " (ignoring it from now on).");
                    structuresRemoved.set(index);
                    continue;
                }
                currentLogLikelihood += Math.log(prob);
            }
            PrintProgressStatic.end();
            ++cycle;
            System.out.println("EM cyle " + cycle + ". Log-Likelihood: " + currentLogLikelihood);
            if (currentLogLikelihood < previousLogLikelihood) {
                // Likelihood got worse: keep the previous table and stop.
                break;
            }
            previousLogLikelihood = currentLogLikelihood;
            fragmentTableFreq = newFragmentTableFreq;
            DOP_EM.getRootFreq();
            DOP_EM.printFragmentFreq(new File(workingDir + "kernelsMUB_CFG_freq_EM_cycle_" + cycle + ".txt"));
        } while (cycle != endCycle);
        System.out.println("Index sentences removed: " + structuresRemoved.toString());
    }

    /**
     * Processes one training tree: matches every fragment of the current table
     * against it, computes the tree probability with a {@code ProbChart}, and,
     * when that probability is not zero, lets the chart add its contribution to
     * the new frequency table.
     *
     * @param t training tree
     * @param newFragmentTableFreq table receiving the re-estimated frequencies
     * @return the probability reported by the chart for this tree
     */
    private static double updateNewFragmentTableFreq(TSNodeLabelIndex t, Hashtable<TSNodeLabel, double[]> newFragmentTableFreq) {
        NodeSetCollectorSimple coveringSets = new NodeSetCollectorSimple();
        HashMap<BitSet, TSNodeLabel_Double> fragmentBySet = new HashMap<BitSet, TSNodeLabel_Double>();
        Iterator<Map.Entry<TSNodeLabel, double[]>> entries = fragmentTableFreq.entrySet().iterator();
        while (entries.hasNext()) {
            Map.Entry<TSNodeLabel, double[]> entry = entries.next();
            TSNodeLabel fragment = entry.getKey();
            double fragmentFreq = entry.getValue()[0];
            DOP_EM.getCFGSetCoveringFragment(t, fragment, fragmentFreq, (NodeSetCollector)coveringSets, fragmentBySet);
        }
        ProbChart chart = new ProbChart(coveringSets, new TSNodeLabelStructure(t), fragmentBySet, newFragmentTableFreq);
        double prob = chart.getProb();
        if (prob != 0.0) {
            chart.extractNewFragmentFrequencies();
        }
        return prob;
    }

    /**
     * Looks for occurrences of a fragment anywhere in the subtree rooted at
     * {@code t}.
     * <p>
     * At every non-lexical node whose label equals the fragment's root label a
     * match is attempted. A successful match that covers at least one node is
     * stored as a set of node indices in {@code setCollector} and associated
     * with the fragment and its frequency in {@code bitSetFreqTable}. The search
     * then continues in all daughters.
     *
     * @param t node of the training tree to start from
     * @param fragment fragment to look for
     * @param fragmentFreq frequency recorded together with each match
     * @param setCollector receives the node-index set of each match
     * @param bitSetFreqTable maps each node-index set to its fragment and frequency
     */
    private static void getCFGSetCoveringFragment(TSNodeLabelIndex t, TSNodeLabel fragment, double fragmentFreq, NodeSetCollector setCollector, HashMap<BitSet, TSNodeLabel_Double> bitSetFreqTable) {
        if (t.isLexical) {
            return;
        }
        if (t.sameLabel(fragment)) {
            BitSet covered = new BitSet();
            boolean matched = DOP_EM.getCFGSetCoveringFragmentNonRecursive(t, fragment, covered);
            if (matched && !covered.isEmpty()) {
                setCollector.add(covered);
                bitSetFreqTable.put(covered, new TSNodeLabel_Double(fragment, fragmentFreq));
            }
        }
        for (TSNodeLabel daughter : t.daughters) {
            DOP_EM.getCFGSetCoveringFragment((TSNodeLabelIndex)daughter, fragment, fragmentFreq, setCollector, bitSetFreqTable);
        }
    }

    /**
     * Checks whether {@code fragment} matches the training tree when its root is
     * placed on {@code t}, recording the indices of the matched tree nodes.
     * <p>
     * A lexical tree node or a terminal fragment node ends the comparison
     * successfully without recording anything. Otherwise the daughters' labels
     * must agree and every daughter pair must match in turn; only then is
     * {@code t.index} added to {@code set}. When the result is {@code false},
     * {@code set} may already contain indices from daughters matched earlier.
     *
     * @param t tree node aligned with the fragment node
     * @param fragment fragment node being compared
     * @param set accumulator for the indices of matched nodes
     * @return {@code true} if the fragment matches at this position
     */
    private static boolean getCFGSetCoveringFragmentNonRecursive(TSNodeLabelIndex t, TSNodeLabel fragment, BitSet set) {
        if (t.isLexical || fragment.isTerminal()) {
            return true;
        }
        if (!t.sameDaughtersLabel(fragment)) {
            return false;
        }
        int daughterCount = t.prole();
        for (int pos = 0; pos < daughterCount; ++pos) {
            TSNodeLabelIndex treeDaughter = (TSNodeLabelIndex)t.daughters[pos];
            TSNodeLabel fragmentDaughter = fragment.daughters[pos];
            boolean daughterMatches = DOP_EM.getCFGSetCoveringFragmentNonRecursive(treeDaughter, fragmentDaughter, set);
            if (!daughterMatches) {
                return false;
            }
        }
        set.set(t.index);
        return true;
    }

    /**
     * Runs the whole experiment with hard-coded locations: loads the treebank,
     * strips its semantic tags, reads the fragment file, adds the CFG rules,
     * builds the root totals and performs EM with {@link #endCycle} set to 50.
     *
     * @param args ignored
     * @throws Exception if loading the treebank or the fragments fails
     */
    public static void main(String[] args) throws Exception {
        workingDir = "/scratch/fsangati/RESULTS/TSG/DOP_EM/";
        System.out.println("Working Dir: " + workingDir);
        String fragmentFileDir = "/scratch/fsangati/RESULTS/TSG/TSGkernels/Wsj/KenelFragments/SemTagOff_Top/all/";
        File fragmentFile = new File(fragmentFileDir + "fragments_MUB_freq_all.txt");
        File corpusFile = new File(String.valueOf(Wsj.WsjOriginalCleanedTop) + "wsj-02-21.mrg");
        ArrayList<TSNodeLabel> treebank = TSNodeLabel.getTreebank(corpusFile);
        TSNodeLabel.removeSemanticTagsInTreebank(treebank);
        endCycle = 50;
        DOP_EM.readTreeBank(treebank);
        DOP_EM.readFragmentsFile(fragmentFile);
        DOP_EM.addCFGfragments();
        DOP_EM.getRootFreq();
        DOP_EM.runEM();
    }

    /**
     * Chart cell for one tree node: the derivations recorded for the node, the
     * sum of their probabilities, and a separate accumulator of probability
     * mass.
     */
    static class DerivationsNode {
        /** Derivations recorded through {@link #addDerivation}. */
        ArrayList<PartialDerivation> partialDerivations = new ArrayList<PartialDerivation>();
        /** Sum of the probabilities of all recorded derivations. */
        double totalProb = 0.0;
        /** Mass accumulated through {@link #addProbMass}; also assigned directly by the chart. */
        double newProbMass = 0.0;

        /**
         * Records one more derivation and adds its probability to the total.
         *
         * @param intialFragment fragment used at this node
         * @param subSites indices of the nodes the fragment leaves open
         * @param derivationProb probability of the derivation
         */
        public void addDerivation(TSNodeLabel intialFragment, ArrayList<Integer> subSites, double derivationProb) {
            PartialDerivation recorded = new PartialDerivation(intialFragment, subSites, derivationProb);
            this.partialDerivations.add(recorded);
            this.totalProb = this.totalProb + derivationProb;
        }

        /**
         * Adds to the accumulated probability mass of this node.
         *
         * @param probMass amount to add
         */
        public void addProbMass(double probMass) {
            this.newProbMass = this.newProbMass + probMass;
        }
    }

    /**
     * One recorded derivation of a node: the fragment used at the node, the
     * indices of the nodes that fragment leaves open, and the probability that
     * was passed in for the derivation.
     */
    static class PartialDerivation {
        /** Fragment used at the node. */
        TSNodeLabel intialFragment;
        /** Indices of the nodes left open by the fragment. */
        ArrayList<Integer> subSites;
        /** Probability stored for this derivation. */
        double partialDerivProb;

        /**
         * Stores the three values as given; the list is not copied.
         *
         * @param intialFragment fragment used at the node
         * @param subSites indices of the nodes left open
         * @param partialDerivProb probability of the derivation
         */
        public PartialDerivation(TSNodeLabel intialFragment, ArrayList<Integer> subSites, double partialDerivProb) {
            this.partialDerivProb = partialDerivProb;
            this.subSites = subSites;
            this.intialFragment = intialFragment;
        }
    }

    /**
     * Chart for a single training tree. It computes the probability of the tree
     * from the fragments matched in it ({@link #getProb}) and then redistributes
     * a total mass of 1.0 over those fragments, adding the shares to the new
     * frequency table ({@link #extractNewFragmentFrequencies}).
     * <p>
     * NOTE(review): the chart treats index 0 as the tree root and the lowest
     * index of a matched node set as the root of that match, and the
     * redistribution step relies on every node having a higher index than its
     * ancestors — confirm against the numbering done by
     * {@code TSNodeLabelStructure}.
     */
    static class ProbChart {
        /** All matched node-index sets of the tree. */
        NodeSetCollectorSimple setCollector;
        /** The tree, addressable by node index. */
        TSNodeLabelStructure t;
        /** Number of slots in the per-node arrays, taken from {@code t.length}. */
        int totalNodes;
        /** Per-node chart cells, filled lazily by {@code getProbRecursive}. */
        DerivationsNode[] derivationsNodes;
        /** Matched sets grouped by the lowest node index they contain. */
        NodeSetCollectorStandard[] nodesCollector;
        /** Fragment and frequency behind each matched set. */
        HashMap<BitSet, TSNodeLabel_Double> bitSetFreqTable;
        /** Table receiving the redistributed mass. */
        Hashtable<TSNodeLabel, double[]> newFragmentTableFreq;

        /**
         * @param setCollector matched node-index sets of the tree
         * @param t the tree, addressable by node index
         * @param bitSetFreqTable fragment and frequency for each matched set
         * @param newFragmentTableFreq table to add the redistributed mass to
         */
        public ProbChart(NodeSetCollectorSimple setCollector, TSNodeLabelStructure t, HashMap<BitSet, TSNodeLabel_Double> bitSetFreqTable, Hashtable<TSNodeLabel, double[]> newFragmentTableFreq) {
            this.setCollector = setCollector;
            this.t = t;
            this.newFragmentTableFreq = newFragmentTableFreq;
            this.totalNodes = t.length;
            this.derivationsNodes = new DerivationsNode[this.totalNodes];
            this.nodesCollector = new NodeSetCollectorStandard[this.totalNodes];
            this.bitSetFreqTable = bitSetFreqTable;
        }

        /**
         * Groups the matched sets by their lowest node index and computes the
         * probability of the node with index 0.
         *
         * @return the summed probability of all derivations found for node 0
         */
        public double getProb() {
            for (BitSet bs : this.setCollector.bitSetSet) {
                int firstIndex = bs.nextSetBit(0);
                if (this.nodesCollector[firstIndex] == null) {
                    this.nodesCollector[firstIndex] = new NodeSetCollectorStandard();
                }
                this.nodesCollector[firstIndex].add(bs);
            }
            return this.getProbRecursive(0);
        }

        /**
         * Returns the summed probability of all derivations of the node with
         * the given index, computing and caching the chart cell on first use.
         * A derivation's probability is the fragment frequency divided by the
         * total for the node's label, times the probabilities of all its
         * substitution sites; derivations with a zero-probability site are
         * dropped. A node with no matched set gets probability zero.
         */
        private double getProbRecursive(int index) {
            if (this.derivationsNodes[index] != null) {
                return this.derivationsNodes[index].totalProb;
            }
            NodeSetCollectorStandard rootedHere = this.nodesCollector[index];
            if (rootedHere == null) {
                this.derivationsNodes[index] = new DerivationsNode();
                return 0.0;
            }
            TSNodeLabelIndex root = this.t.structure[index];
            double rootFreq = rootTableFreq.get(root.label)[0];
            DerivationsNode derivation = new DerivationsNode();
            for (BitSet initialSubTree : rootedHere.bitSetArray) {
                ArrayList<Integer> subSitesIndexes = new ArrayList<Integer>();
                this.collectSubSites(root, initialSubTree, subSitesIndexes);
                double partialProb = 1.0;
                for (int subSiteIndex : subSitesIndexes) {
                    double subSiteProb = this.getProbRecursive(subSiteIndex);
                    if (subSiteProb == 0.0) {
                        partialProb = 0.0;
                        break;
                    }
                    partialProb *= subSiteProb;
                }
                if (partialProb == 0.0) {
                    continue;
                }
                TSNodeLabel_Double treeDouble = this.bitSetFreqTable.get(initialSubTree);
                double derivationProb = partialProb * (treeDouble.d / rootFreq);
                derivation.addDerivation(treeDouble.tree, subSitesIndexes, derivationProb);
            }
            this.derivationsNodes[index] = derivation;
            return derivation.totalProb;
        }

        /**
         * Collects, below {@code root}, the indices of the non-lexical nodes
         * that are not part of {@code initialSubTree} but whose parent is (or is
         * {@code root} itself). Lexical daughters are skipped; daughters inside
         * the set are descended into.
         */
        private void collectSubSites(TSNodeLabelIndex root, BitSet initialSubTree, ArrayList<Integer> subSitesIndexes) {
            for (TSNodeLabel d : root.daughters) {
                if (d.isLexical) {
                    // Skip only this daughter; later siblings still have to be examined.
                    continue;
                }
                TSNodeLabelIndex di = (TSNodeLabelIndex)d;
                if (initialSubTree.get(di.index)) {
                    this.collectSubSites(di, initialSubTree, subSitesIndexes);
                } else {
                    subSitesIndexes.add(di.index);
                }
            }
        }

        /**
         * Gives node 0 a mass of 1.0 and pushes mass down the chart, adding each
         * derivation's share to the new frequency table.
         * <p>
         * Nodes are handled once each, in ascending index order, so that a node
         * has received the shares of all its ancestors before its own mass is
         * split. Handling a node once per parent instead would spend the same
         * mass several times. Must be called after {@link #getProb}.
         */
        public void extractNewFragmentFrequencies() {
            this.derivationsNodes[0].newProbMass = 1.0;
            for (int index = 0; index < this.totalNodes; ++index) {
                DerivationsNode node = this.derivationsNodes[index];
                if (node == null || node.newProbMass == 0.0) {
                    continue;
                }
                this.distributeProbMass(node);
            }
        }

        /**
         * Splits the mass of one node over its derivations in proportion to
         * their probabilities, credits each share to the derivation's fragment
         * and forwards it to every substitution site of that derivation.
         */
        private void distributeProbMass(DerivationsNode node) {
            double totalMass = node.newProbMass;
            double derivationsTotProb = node.totalProb;
            for (PartialDerivation pd : node.partialDerivations) {
                double partialMass = pd.partialDerivProb / derivationsTotProb * totalMass;
                Utility.increaseInTableDoubleArray(this.newFragmentTableFreq, pd.intialFragment, partialMass);
                for (int subSite : pd.subSites) {
                    this.derivationsNodes[subSite].addProbMass(partialMass);
                }
            }
        }
    }

    /** Simple pair of a tree and a {@code double} value. */
    static class TSNodeLabel_Double {
        /** The tree half of the pair. */
        TSNodeLabel tree;
        /** The numeric half of the pair. */
        double d;

        /**
         * @param tree tree to store
         * @param d value to store
         */
        public TSNodeLabel_Double(TSNodeLabel tree, double d) {
            this.d = d;
            this.tree = tree;
        }
    }
}

