/*
 * Decompiled with CFR 0.152.
 */
package tsg.LTSG;

import java.util.ArrayList;
import java.util.Hashtable;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import settings.Parameters;
import tsg.LTSG.LTSG;
import tsg.LTSG.LexTreeStructure;
import tsg.LTSG.NodeStructure;
import tsg.TSNode;
import tsg.corpora.ConstCorpus;
import util.FileUtil;
import util.PrintProgressStatic;
import util.Utility;

/**
 * Lexicalised tree-substitution grammar whose template probabilities and head
 * annotations are estimated by expectation-maximisation.
 *
 * <p>For each tree, {@link #getLikelihoodAndBestAnnotation} does three things:
 * <ul>
 *   <li>computes inside and outside sums over the lexicalised templates the tree
 *       can be decomposed into;</li>
 *   <li>accumulates the resulting expected template counts;</li>
 *   <li>re-marks the tree's heads following the lexical trees returned by
 *       {@code NodeStructure.getBestLexTree()}.</li>
 * </ul>
 * Template probabilities are normalised per root (see {@link #normalizeTemplateProb()}).
 */
public class LTSG_EM
extends LTSG {
    /** Weight of each template, keyed by its {@code toString(false, true)} form; normalised per root after each M step. */
    Hashtable<String, Double> template_prob;
    /** Sum of {@link #template_prob} values per root key ({@code TSNode.get_unique_root}); rebuilt by {@link #extractRootProb()}. */
    Hashtable<String, Double> root_prob;
    /** {@code Parameters.EM_initialization} value for uniform initial weights (any value other than {@link #initializeDOP} behaves the same). */
    public static final String initializeUNIFORM = "initializeUNIFORM";
    /** {@code Parameters.EM_initialization} value for initial weights proportional to template frequencies. */
    public static final String initializeDOP = "initializeDOP";

    /**
     * Rebuilds {@link #root_prob} as the sum of the {@link #template_prob}
     * values of all templates that share the same root key.
     */
    private void extractRootProb() {
        this.root_prob = new Hashtable<String, Double>();
        for (Map.Entry<String, Double> entry : this.template_prob.entrySet()) {
            Utility.increaseStringDouble(this.root_prob, TSNode.get_unique_root(entry.getKey()), entry.getValue());
        }
    }

    /** Initialises {@link #template_prob} from the whole training corpus. */
    private void initializeProb() {
        this.initializeProbExcluding(null);
    }

    /**
     * Re-extracts all lexical templates, then initialises and normalises
     * {@link #template_prob} from {@code template_freq}.
     *
     * <p>Initial weights depend on {@code Parameters.EM_initialization}:
     * <ul>
     *   <li>when it equals {@link #initializeDOP}, each template starts from its
     *       frequency;</li>
     *   <li>otherwise (including {@code null}), every template starts at 1.0.</li>
     * </ul>
     *
     * @param treeToExclude if non-null, its elementary-tree counts are subtracted
     *        from {@code template_freq} while the weights are computed and added
     *        back afterwards; {@code null} uses the whole corpus
     */
    private void initializeProbExcluding(TSNode treeToExclude) {
        this.extractAllLexTrees();
        if (treeToExclude != null) {
            this.decreaseElementayTreesFrom(treeToExclude);
        }
        // Constant-first comparison: an unset EM_initialization falls back to uniform instead of throwing.
        boolean frequencyWeights = initializeDOP.equals(Parameters.EM_initialization);
        this.template_prob = new Hashtable<String, Double>();
        for (Map.Entry<?, ?> entry : this.template_freq.entrySet()) {
            String tree = (String) entry.getKey();
            double initialWeight = frequencyWeights ? ((Integer) entry.getValue()).intValue() : 1.0;
            this.template_prob.put(tree, initialWeight);
        }
        this.normalizeTemplateProb();
        if (treeToExclude != null) {
            this.increaseElementaryTreeesFrom(treeToExclude);
        }
    }

    /**
     * Recomputes {@link #root_prob} and then divides every
     * {@link #template_prob} value, in place, by the total of its root.
     */
    private void normalizeTemplateProb() {
        this.extractRootProb();
        for (Map.Entry<String, Double> entry : this.template_prob.entrySet()) {
            double rootTotal = this.root_prob.get(TSNode.get_unique_root(entry.getKey()));
            entry.setValue(entry.getValue() / rootTotal);
        }
    }

    /**
     * Appends to {@code Parameters.logFile} the largest
     * {@code lexDerivations()} value over the training corpus (0 if the corpus is empty).
     */
    public void reportMaxLexicalDerivations() {
        long maxDerivation = 0L;
        for (TSNode inputTree : Parameters.trainingCorpus.treeBank) {
            maxDerivation = Math.max(maxDerivation, inputTree.lexDerivations());
        }
        FileUtil.appendReturn("Max number of derivation per tree: " + maxDerivation, Parameters.logFile);
    }

    /**
     * Runs a leave-one-out coverage check and prints the result to stdout.
     *
     * <p>For every training tree, its own elementary-tree counts are removed
     * before checking whether the remaining templates can rebuild it. The
     * counts are restored afterwards. POS-tag conversion is applied and undone
     * around the check when {@code Parameters.posTagConversion} is set.
     */
    public void checkEMCoverage() {
        ConstCorpus originalTrainingCorpus = null;
        if (Parameters.posTagConversion) {
            originalTrainingCorpus = Parameters.trainingCorpus.deepClone();
            Parameters.trainingCorpus.makePosTagsLexicon();
        }
        this.extractAllLexTrees();
        int covered = 0;
        for (TSNode inputTree : Parameters.trainingCorpus.treeBank) {
            this.decreaseElementayTreesFrom(inputTree);
            if (this.checkCoverageTree(inputTree)) {
                ++covered;
            }
            this.increaseElementaryTreeesFrom(inputTree);
        }
        if (Parameters.posTagConversion) {
            Parameters.trainingCorpus.unMakePosTagsLexicon(originalTrainingCorpus);
        }
        int treebankSize = Parameters.trainingCorpus.size();
        float ratio = (float) covered / (float) treebankSize;
        System.out.println("Covered tree in training corpus: " + covered + " / " + treebankSize + " (" + ratio + ")");
    }

    /**
     * Decides bottom-up whether {@code inputTree} can be rebuilt from the
     * templates currently present in {@code template_freq}.
     *
     * <p>Levels are visited in the order returned by {@code getNodesInDepthLevels()}.
     * A node is recorded as covered only when {@link #isNodeCovered} accepts it.
     *
     * @return true if the root of {@code inputTree} ends up covered
     */
    private boolean checkCoverageTree(TSNode inputTree) {
        ArrayList<ArrayList<TSNode>> levelsSubTrees = inputTree.getNodesInDepthLevels();
        // Identity map: structurally equal subtrees at different positions must stay distinct.
        IdentityHashMap<TSNode, Boolean> coveredSubTrees = new IdentityHashMap<TSNode, Boolean>();
        for (ArrayList<TSNode> level : levelsSubTrees) {
            for (TSNode node : level) {
                if (node.isLexical || node.isUniqueDaughter()) {
                    continue;
                }
                if (this.isNodeCovered(node, coveredSubTrees)) {
                    coveredSubTrees.put(node, true);
                }
            }
        }
        return coveredSubTrees.containsKey(inputTree);
    }

    /**
     * Checks a single node for coverage.
     *
     * <p>Prelexical or non-branching nodes are covered when their own string
     * form is a known template. Any other node is covered when, for some
     * lexical anchor, two conditions hold:
     * <ul>
     *   <li>the converted lexicalised template is known;</li>
     *   <li>all of that template's substitution sites are already covered.</li>
     * </ul>
     */
    private boolean isNodeCovered(TSNode node, IdentityHashMap<TSNode, Boolean> coveredSubTrees) {
        if (node.isPrelexical() || !node.hasMoreThanNBranching(1)) {
            return this.template_freq.containsKey(node.toString(false, true));
        }
        List<TSNode> lexicon = node.collectLexicalItems();
        for (TSNode anchor : lexicon) {
            node.markHeadPathToAnchor(anchor);
            TSNode lexTemplate = node.lexicalizedTreeCopy();
            lexTemplate.applyAllConversions();
            List<TSNode> subSites = node.collectSubstitutionSites();
            node.unmarkHeadPathToAnchor(anchor);
            if (this.template_freq.containsKey(lexTemplate.toString(false, true))
                    && coveredSubTrees.keySet().containsAll(subSites)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Leave-one-out EM over the training corpus.
     *
     * <p>Each training tree is processed independently by
     * {@link #estimateExcludingTree}, which leaves the resulting head
     * annotation on the tree. POS-tag conversion is applied and undone around
     * the run when {@code Parameters.posTagConversion} is set.
     */
    public void EMHeldOutAlgorithm() {
        ConstCorpus originalTrainingCorpus = null;
        if (Parameters.posTagConversion) {
            originalTrainingCorpus = Parameters.trainingCorpus.deepClone();
            Parameters.trainingCorpus.makePosTagsLexicon();
        }
        this.extractAllLexTrees();
        PrintProgressStatic.start("Estimating EM param. sentence:");
        for (TSNode observedTree : Parameters.trainingCorpus.treeBank) {
            PrintProgressStatic.next();
            this.estimateExcludingTree(observedTree);
        }
        PrintProgressStatic.end();
        if (Parameters.posTagConversion) {
            Parameters.trainingCorpus.unMakePosTagsLexicon(originalTrainingCorpus);
        }
    }

    /**
     * Runs EM on a single tree, starting from probabilities initialised with
     * that tree's own template counts subtracted.
     *
     * <p>Iteration stops when any of these holds:
     * <ul>
     *   <li>the likelihood gain is not positive;</li>
     *   <li>the gain is at most {@code Parameters.EM_deltaThreshold};</li>
     *   <li>{@code Parameters.EM_maxCycle} cycles have run.</li>
     * </ul>
     * If the tree has no derivation, it is logged and receives random heads.
     */
    private void estimateExcludingTree(TSNode observedTree) {
        this.initializeProbExcluding(observedTree);
        int cycle = 0;
        double previousLikelihood = -Double.MAX_VALUE;
        double delta;
        do {
            ++cycle;
            observedTree.removeHeadAnnotations();
            Hashtable<String, Double> newTemplateProb = new Hashtable<String, Double>();
            Double actualLikelihood = this.getLikelihoodAndBestAnnotation(observedTree, newTemplateProb);
            if (actualLikelihood == null) {
                FileUtil.appendReturn("No coverage for " + observedTree, Parameters.logFile);
                observedTree.assignRandomHeads();
                return;
            }
            delta = actualLikelihood - previousLikelihood;
            previousLikelihood = actualLikelihood;
            this.template_prob = newTemplateProb;
            this.normalizeTemplateProb();
        } while (delta > 0.0 && delta > Parameters.EM_deltaThreshold && cycle < Parameters.EM_maxCycle);
    }

    /**
     * EM whose E step runs on {@code heldOutCorpus}. Initial probabilities come
     * from the training corpus.
     *
     * <p>One line per cycle is logged to {@code Parameters.logFile}. Stopping
     * follows the same rule as {@link #EMalgorithm()}. POS-tag conversion is
     * applied to both corpora and undone afterwards when
     * {@code Parameters.posTagConversion} is set.
     *
     * @param heldOutCorpus corpus whose trees are re-annotated in place
     */
    public void EMHeldOutAlgorithm(ConstCorpus heldOutCorpus) {
        ConstCorpus originalTrainingCorpus = null;
        ConstCorpus originalHeldOutCorpus = null;
        if (Parameters.posTagConversion) {
            originalTrainingCorpus = Parameters.trainingCorpus.deepClone();
            originalHeldOutCorpus = heldOutCorpus.deepClone();
            Parameters.trainingCorpus.makePosTagsLexicon();
            heldOutCorpus.makePosTagsLexicon();
        }
        this.initializeProb();
        // emStepsHeldOutCorpus does not track head changes. The log keeps the same
        // columns as EMalgorithm, so this column is always reported as false.
        final boolean headVariation = false;
        int cycle = 0;
        double previousLikelihood = -Double.MAX_VALUE;
        double delta;
        do {
            double currentLikelihood = this.emStepsHeldOutCorpus(heldOutCorpus);
            delta = currentLikelihood - previousLikelihood;
            previousLikelihood = currentLikelihood;
            ++cycle;
            String line = "EM cycle: " + cycle + "\tCurrent LogLikeLihood: " + currentLikelihood + "\tDelta LogLikeLihood: " + delta + "\tHead Variation: " + headVariation;
            FileUtil.appendReturn(line, Parameters.logFile);
        } while (delta > 0.0 && delta > Parameters.EM_deltaThreshold && cycle < Parameters.EM_maxCycle);
        if (Parameters.posTagConversion) {
            Parameters.trainingCorpus.unMakePosTagsLexicon(originalTrainingCorpus);
            heldOutCorpus.unMakePosTagsLexicon(originalHeldOutCorpus);
        }
    }

    /**
     * Standard EM over the training corpus.
     *
     * <p>Each cycle logs the log-likelihood, its change, and whether any head
     * annotation changed. Iteration stops when any of these holds:
     * <ul>
     *   <li>the change is not positive;</li>
     *   <li>the change is at most {@code Parameters.EM_deltaThreshold};</li>
     *   <li>{@code Parameters.EM_maxCycle} cycles have run.</li>
     * </ul>
     */
    public void EMalgorithm() {
        ConstCorpus originalTrainingCorpus = null;
        if (Parameters.posTagConversion) {
            originalTrainingCorpus = Parameters.trainingCorpus.deepClone();
            Parameters.trainingCorpus.makePosTagsLexicon();
        }
        this.initializeProb();
        int cycle = 0;
        double previousLikelihood = -Double.MAX_VALUE;
        double delta;
        do {
            boolean[] headVariation = new boolean[1];
            double currentLikelihood = this.emSteps(headVariation);
            delta = currentLikelihood - previousLikelihood;
            previousLikelihood = currentLikelihood;
            ++cycle;
            String line = "EM cycle: " + cycle + "\tCurrent LogLikeLihood: " + currentLikelihood + "\tDelta LogLikeLihood: " + delta + "\tHead Variation: " + headVariation[0];
            FileUtil.appendReturn(line, Parameters.logFile);
        } while (delta > 0.0 && delta > Parameters.EM_deltaThreshold && cycle < Parameters.EM_maxCycle);
        if (Parameters.posTagConversion) {
            Parameters.trainingCorpus.unMakePosTagsLexicon(originalTrainingCorpus);
        }
    }

    /**
     * Not available in this build: the decompiled body referenced an undefined
     * {@code Parser(LTSG_EM)} constructor.
     *
     * @throws Error always
     */
    public void EMalgorithmIntermediateResults() {
        throw new Error("Unresolved compilation problem: \n\tThe constructor Parser(LTSG_EM) is undefined\n");
    }

    /**
     * Marks on {@code inputTree} the head path of the lexical tree chosen by
     * {@code getBestLexTree()} for that node, then recurses into the
     * substitution sites exposed by that marking.
     *
     * <p>Assumes every visited node has an entry in {@code bestDerivationTable};
     * a missing entry causes a NullPointerException.
     */
    private void assignHeadAnnotation(TSNode inputTree, IdentityHashMap<TSNode, NodeStructure> bestDerivationTable) {
        LexTreeStructure bestAnchorDerivation = bestDerivationTable.get(inputTree).getBestLexTree();
        inputTree.markHeadPathToAnchor(bestAnchorDerivation.anchor);
        for (TSNode subSite : inputTree.collectSubstitutionSites()) {
            this.assignHeadAnnotation(subSite, bestDerivationTable);
        }
    }

    /**
     * Fills {@code bestDerivationTable} with inside information for every
     * non-lexical, non-unique-daughter node that has at least one derivation.
     *
     * <p>Nodes are visited in {@code getNodesInDepthLevels()} order. A node's
     * substitution sites are looked up in the table, so that order is assumed
     * to put deeper nodes first.
     */
    private void calculateInsideProbNodes(TSNode inputTree, IdentityHashMap<TSNode, NodeStructure> bestDerivationTable) {
        ArrayList<ArrayList<TSNode>> levelsSubTrees = inputTree.getNodesInDepthLevels();
        for (ArrayList<TSNode> level : levelsSubTrees) {
            for (TSNode node : level) {
                if (node.isLexical || node.isUniqueDaughter()) {
                    continue;
                }
                NodeStructure nodeInfos = (node.isPrelexical() || !node.hasMoreThanNBranching(1))
                        ? this.insideSingleTemplate(node)
                        : this.insideLexicalizedTemplates(node, bestDerivationTable);
                if (nodeInfos != null) {
                    bestDerivationTable.put(node, nodeInfos);
                }
            }
        }
    }

    /**
     * Inside information for a prelexical or non-branching node. The node's own
     * string form is its only template.
     *
     * @return null if that template has no probability
     */
    private NodeStructure insideSingleTemplate(TSNode node) {
        Double weight = this.template_prob.get(node.toString(false, true));
        if (weight == null) {
            return null;
        }
        double logWeight = Math.log(weight);
        TSNode anchor = node.getAnchor();
        node.markHeadPathToAnchor(anchor);
        TSNode lexicalTree = node.lexicalizedTreeCopy();
        node.unmarkHeadPathToAnchor(anchor);
        LexTreeStructure derivation = new LexTreeStructure(anchor, lexicalTree, logWeight);
        derivation.bestDerivationLogProb = logWeight;
        NodeStructure nodeInfos = new NodeStructure();
        nodeInfos.insideLogProb = logWeight;
        nodeInfos.bestRootAnchorLexTree.put(node.getAnchor(), derivation);
        return nodeInfos;
    }

    /**
     * Inside information for a branching node, built by trying every terminal
     * as anchor.
     *
     * <p>Each anchor with a known converted template and a non-null tree inside
     * log-probability is stored. The node's inside log-probability is the log
     * of the sum of those trees' inside probabilities.
     *
     * @return null if no anchor yields a usable template
     */
    private NodeStructure insideLexicalizedTemplates(TSNode node, IdentityHashMap<TSNode, NodeStructure> bestDerivationTable) {
        NodeStructure nodeInfos = null;
        double insideProb = 0.0;
        // Snapshot the terminals: the loop body marks/unmarks head paths on this node.
        for (TSNode anchor : node.collectTerminals().toArray(new TSNode[0])) {
            node.markHeadPathToAnchor(anchor);
            List<TSNode> subSites = node.collectSubstitutionSites();
            TSNode lexicalTree = node.lexicalizedTreeCopy();
            node.unmarkHeadPathToAnchor(anchor);
            lexicalTree.applyAllConversions();
            Double weight = this.template_prob.get(lexicalTree.toString(false, true));
            if (weight == null) {
                continue;
            }
            double logWeight = Math.log(weight);
            LexTreeStructure derivation = new LexTreeStructure(anchor, lexicalTree, logWeight, subSites);
            derivation.bestDerivationLogProb = logWeight;
            Double insideLogProbTree = derivation.getInsideLogProbTree(bestDerivationTable);
            if (insideLogProbTree == null) {
                continue;
            }
            if (nodeInfos == null) {
                nodeInfos = new NodeStructure();
            }
            insideProb += Math.exp(insideLogProbTree);
            nodeInfos.bestRootAnchorLexTree.put(anchor, derivation);
        }
        if (nodeInfos != null) {
            nodeInfos.insideLogProb = Math.log(insideProb);
        }
        return nodeInfos;
    }

    /**
     * Accumulates outside probabilities (linear space) for the nodes that
     * already have inside information. The root gets 1.0.
     *
     * <p>For every other node, each ancestor with an entry and a non-zero
     * outside probability contributes its outside probability times the summed
     * probability of its lexical trees excluding this node. Only trees anchored
     * at lexical items that are under the nearest non-unique-daughter parent
     * but not under the node itself are counted. Assumes
     * {@code collectNonLexicalNodes()} lists ancestors before descendants.
     */
    private void calculateOutsideProbNodes(TSNode inputTree, IdentityHashMap<TSNode, NodeStructure> bestDerivationTable) {
        for (TSNode node : inputTree.collectNonLexicalNodes()) {
            if (node.isUniqueDaughter()) {
                continue;
            }
            NodeStructure nodeInfos = bestDerivationTable.get(node);
            if (nodeInfos == null) {
                continue;
            }
            if (node.isRoot()) {
                nodeInfos.outsideProb = 1.0;
                continue;
            }
            TSNode parentNode = node.parent;
            while (parentNode.isUniqueDaughter()) {
                parentNode = parentNode.parent;
            }
            ArrayList<TSNode> selectedLexicon = new ArrayList<TSNode>(parentNode.collectLexicalItems());
            selectedLexicon.removeAll(node.collectLexicalItems());
            for (TSNode ancestor : node.collectAncestorNodes()) {
                if (ancestor.isUniqueDaughter()) {
                    continue;
                }
                NodeStructure ancestorInfos = bestDerivationTable.get(ancestor);
                if (ancestorInfos == null || ancestorInfos.outsideProb == 0.0) {
                    continue;
                }
                double ancestorContribution = 0.0;
                for (TSNode lex : selectedLexicon) {
                    LexTreeStructure lexTreeInfo = ancestorInfos.bestRootAnchorLexTree.get(lex);
                    if (lexTreeInfo != null) {
                        ancestorContribution += Math.exp(lexTreeInfo.getInsideLogProbTreeExludingSubSite(bestDerivationTable, node));
                    }
                }
                nodeInfos.outsideProb += ancestorContribution * ancestorInfos.outsideProb;
            }
        }
    }

    /**
     * E step for one tree: adds each stored lexical tree's expected count
     * (exp(inside + log outside - logLikelihood)) to {@code new_template_prob}.
     * Nodes with zero outside probability are skipped.
     */
    private void updateNewProbTable(TSNode inputTree, double logLikelihood, IdentityHashMap<TSNode, NodeStructure> bestDerivationTable, Hashtable<String, Double> new_template_prob) {
        for (TSNode node : inputTree.collectNonLexicalNodes()) {
            if (node.isUniqueDaughter()) {
                continue;
            }
            NodeStructure nodeInfos = bestDerivationTable.get(node);
            if (nodeInfos == null || nodeInfos.outsideProb == 0.0) {
                continue;
            }
            double nodeOutsideLogProb = Math.log(nodeInfos.outsideProb);
            for (LexTreeStructure lexTree : nodeInfos.bestRootAnchorLexTree.values()) {
                double lexTreeInsideLogProb = lexTree.getInsideLogProbTree(bestDerivationTable);
                double lexTreeIncrementCount = Math.exp(lexTreeInsideLogProb + nodeOutsideLogProb - logLikelihood);
                Utility.increaseStringDouble(new_template_prob, lexTree.lexTreeCopy.toString(false, true), lexTreeIncrementCount);
            }
        }
    }

    /**
     * One EM iteration over the training corpus.
     *
     * <p>Each tree is re-annotated and its expected counts accumulated. The
     * template table is then replaced by the normalised counts. Trees with no
     * derivation are logged and given random heads, as in
     * {@link #EMHeldOutAlgorithm()}. They are not allowed to abort the
     * iteration.
     *
     * @param headVariation one-element flag; element 0 is set to true once any
     *        tree's head annotation differs from the one it had on entry
     * @return the summed log-likelihood of the covered trees
     */
    private double emSteps(boolean[] headVariation) {
        double likelihood = 0.0;
        Hashtable<String, Double> newTemplateProb = new Hashtable<String, Double>();
        for (TSNode inputTree : Parameters.trainingCorpus.treeBank) {
            // Copy only while no variation has been seen yet; afterwards the comparison is moot.
            TSNode previousAnnotation = headVariation[0] ? null : new TSNode(inputTree);
            inputTree.removeHeadAnnotations();
            Double treeLikelihood = this.getLikelihoodAndBestAnnotation(inputTree, newTemplateProb);
            if (treeLikelihood == null) {
                FileUtil.appendReturn("No coverage for " + inputTree, Parameters.logFile);
                inputTree.assignRandomHeads();
            } else {
                likelihood += treeLikelihood.doubleValue();
            }
            if (previousAnnotation != null && !inputTree.hasSameHeadAnnotation(previousAnnotation)) {
                headVariation[0] = true;
            }
            if (inputTree.hasWrongHeadAssignment()) {
                System.err.println("Wrong Head Assignment: " + inputTree.toString(true, true));
            }
        }
        this.template_prob = newTemplateProb;
        this.normalizeTemplateProb();
        return likelihood;
    }

    /**
     * One EM iteration over {@code heldOutCorpus}. Trees without a derivation
     * are logged and skipped (their heads stay removed).
     *
     * @return the summed log-likelihood of the covered trees
     */
    private double emStepsHeldOutCorpus(ConstCorpus heldOutCorpus) {
        double likelihood = 0.0;
        Hashtable<String, Double> newTemplateProb = new Hashtable<String, Double>();
        for (TSNode inputTree : heldOutCorpus.treeBank) {
            inputTree.removeHeadAnnotations();
            Double treeLikelihood = this.getLikelihoodAndBestAnnotation(inputTree, newTemplateProb);
            if (treeLikelihood == null) {
                FileUtil.appendReturn("No coverage for " + inputTree, Parameters.logFile);
                continue;
            }
            likelihood += treeLikelihood.doubleValue();
        }
        this.template_prob = newTemplateProb;
        this.normalizeTemplateProb();
        return likelihood;
    }

    /**
     * Runs the inside and outside passes on {@code inputTree}.
     *
     * <p>When the root has a derivation, this method also adds the tree's
     * expected template counts to {@code new_template_prob} and re-marks its
     * heads.
     *
     * @return the root's inside log-probability, or null if the root has no derivation
     */
    private Double getLikelihoodAndBestAnnotation(TSNode inputTree, Hashtable<String, Double> new_template_prob) {
        IdentityHashMap<TSNode, NodeStructure> bestDerivationTable = new IdentityHashMap<TSNode, NodeStructure>();
        this.calculateInsideProbNodes(inputTree, bestDerivationTable);
        this.calculateOutsideProbNodes(inputTree, bestDerivationTable);
        NodeStructure rootInfos = bestDerivationTable.get(inputTree);
        if (rootInfos == null) {
            return null;
        }
        Double logLikelihood = rootInfos.insideLogProb;
        if (logLikelihood != null) {
            this.updateNewProbTable(inputTree, logLikelihood, bestDerivationTable, new_template_prob);
            this.assignHeadAnnotation(inputTree, bestDerivationTable);
        }
        return logLikelihood;
    }

    /** Entry point: sets fixed parameters and runs {@link #checkEMCoverage()}. */
    public static void coverage() {
        Parameters.setDefaultParam();
        Parameters.lengthLimitTraining = 40;
        Parameters.posTagConversion = true;
        Parameters.spineConversion = true;
        Parameters.LTSGtype = "EM";
        Parameters.outputPath = "/home/fsangati/PROJECTS/TSG/RESULTS/LTSG/" + Parameters.LTSGtype + "/";
        LTSG_EM Grammar2 = new LTSG_EM();
        Grammar2.checkEMCoverage();
    }

    /**
     * Not available in this build: the decompiled body referenced an undefined
     * {@code Parser(LTSG_EM)} constructor.
     *
     * @throws Error always
     */
    public static void EmStandard(String[] args) {
        throw new Error("Unresolved compilation problem: \n\tThe constructor Parser(LTSG_EM) is undefined\n");
    }

    /**
     * Entry point: sets fixed parameters, then calls
     * {@link #EMalgorithmIntermediateResults()}, which currently always throws {@code Error}.
     */
    public static void EmIntermediateResults() {
        Parameters.setDefaultParam();
        Parameters.writeGlobalResults = false;
        Parameters.lengthLimitTraining = 10;
        Parameters.lengthLimitTest = 10;
        Parameters.smoothing = false;
        Parameters.LTSGtype = "EM";
        Parameters.outputPath = "/home/fsangati/PROJECTS/TSG/RESULTS/LTSG/" + Parameters.LTSGtype + "/";
        Parameters.EM_initialization = initializeUNIFORM;
        Parameters.EM_nBest = -1;
        Parameters.EM_deltaThreshold = 1.0E-4;
        Parameters.EM_maxCycle = 100;
        Parameters.parserName = "bitPar";
        Parameters.nBest = 1;
        Parameters.cachingActive = false;
        Parameters.posTagConversion = true;
        Parameters.spineConversion = true;
        LTSG_EM Grammar2 = new LTSG_EM();
        Grammar2.EMalgorithmIntermediateResults();
    }

    /**
     * Not available in this build: the decompiled body referenced an undefined
     * {@code Parser(LTSG_EM)} constructor.
     *
     * @throws Error always
     */
    public static void EMHeldOut(String[] args) {
        throw new Error("Unresolved compilation problem: \n\tThe constructor Parser(LTSG_EM) is undefined\n");
    }

    /** Delegates to {@link #EmStandard(String[])}, which currently always throws {@code Error}. */
    public static void main(String[] args) {
        LTSG_EM.EmStandard(args);
    }
}

