/*
 * Decompiled with CFR 0.152.
 */
package tsg.parser.petrov;

import java.io.File;
import java.util.ArrayList;
import java.util.Hashtable;
import java.util.Map;
import java.util.Scanner;
import tsg.TSNodeLabel;
import util.FileUtil;
import util.Utility;

/**
 * Loads a rule grammar and a lexicon from two text files, turns the rule
 * probabilities into log-probabilities (optionally renormalising them per
 * left-hand side), and prints a rule-by-rule comparison of the log-probability
 * of two parse trees of the same sentence.
 *
 * <p>Configuration is held in the static fields below and is set by
 * {@link #main(String[])} before an instance is created; the constructor reads
 * those fields.
 */
public class CompareMJCykBitPar {
    /** Probability threshold used when loading rules (see the two convert methods). */
    static double minProbRule;
    /** When true, rule probabilities are renormalised per left-hand side before taking logs. */
    static boolean normalize;
    /** Set by {@code main}; not read anywhere in this class. */
    static String outputPath;
    /** Grammar file read by {@link #convertGrammar()}. */
    static File petrovGrammarFile;
    /** Lexicon file read by {@link #convertLexicon()}. */
    static File petrovLexiconFile;
    /** Set by {@code main}; not read anywhere in this class. */
    static File petrovParsedFile;
    /** Set by {@code main}; not read anywhere in this class. */
    static File bitparParsedFile;

    /**
     * Internal rules keyed by {@code "LHS RHS1 RHS2 ..."}. Holds plain
     * probabilities while loading and log-probabilities once the constructor
     * has finished.
     */
    Hashtable<String, Double> grammarRulesLogProb;
    /**
     * Lexical rules keyed by {@code "POS word"}. Holds plain probabilities
     * while loading and log-probabilities once the constructor has finished.
     */
    Hashtable<String, Double> lexRulesLogProb;

    /**
     * Loads grammar and lexicon from the statically configured files and
     * converts both tables to log-probabilities, normalising first when
     * {@link #normalize} is set.
     */
    public CompareMJCykBitPar() {
        this.convertGrammar();
        this.convertLexicon();
        if (normalize) {
            this.normalizeAndMakeLog();
        } else {
            this.makeLog();
        }
    }

    /**
     * Replaces both rule tables with their per-left-hand-side normalised
     * log-probability versions.
     */
    private void normalizeAndMakeLog() {
        this.grammarRulesLogProb = normalizedLogTable(this.grammarRulesLogProb);
        this.lexRulesLogProb = normalizedLogTable(this.lexRulesLogProb);
    }

    /** Replaces both rule tables with their log-probability versions. */
    private void makeLog() {
        this.grammarRulesLogProb = logTable(this.grammarRulesLogProb);
        this.lexRulesLogProb = logTable(this.lexRulesLogProb);
    }

    /**
     * Builds a table mapping every rule to {@code log(p / total)}, where
     * {@code total} is the sum of the probabilities of all rules in the same
     * table that share the rule's left-hand side (the first whitespace-separated
     * token of the key).
     *
     * @param rulesProb rule -> probability
     * @return a new table, rule -> normalised log-probability
     */
    private static Hashtable<String, Double> normalizedLogTable(Hashtable<String, Double> rulesProb) {
        // First pass: total probability mass per left-hand side.
        Hashtable<String, Double> lhsTotals = new Hashtable<String, Double>();
        for (Map.Entry<String, Double> entry : rulesProb.entrySet()) {
            String lhs = leftHandSide(entry.getKey());
            double prob = entry.getValue();
            Double soFar = lhsTotals.get(lhs);
            lhsTotals.put(lhs, soFar == null ? prob : soFar + prob);
        }
        // Second pass: divide by the total and take the log.
        Hashtable<String, Double> logProbs = new Hashtable<String, Double>();
        for (Map.Entry<String, Double> entry : rulesProb.entrySet()) {
            String rule = entry.getKey();
            double prob = entry.getValue();
            double lhsProb = lhsTotals.get(leftHandSide(rule));
            logProbs.put(rule, Math.log(prob / lhsProb));
        }
        return logProbs;
    }

    /**
     * Builds a table mapping every rule to the natural log of its probability.
     *
     * @param rulesProb rule -> probability
     * @return a new table, rule -> log-probability
     */
    private static Hashtable<String, Double> logTable(Hashtable<String, Double> rulesProb) {
        Hashtable<String, Double> logProbs = new Hashtable<String, Double>();
        for (Map.Entry<String, Double> entry : rulesProb.entrySet()) {
            double prob = entry.getValue();
            logProbs.put(entry.getKey(), Math.log(prob));
        }
        return logProbs;
    }

    /** Returns the first whitespace-separated token of a rule key. */
    private static String leftHandSide(String rule) {
        return rule.split("\\s")[0];
    }

    /**
     * Reads {@link #petrovGrammarFile} into {@link #grammarRulesLogProb}.
     *
     * <p>Each non-blank line is split on single whitespace characters: token 0
     * is parsed as the probability, token 1 is the left-hand side and the
     * remaining tokens are the right-hand side. Lines whose probability is
     * strictly below {@link #minProbRule} are reported and skipped. Blank
     * lines are ignored. The scanner is always closed.
     */
    private void convertGrammar() {
        this.grammarRulesLogProb = new Hashtable<String, Double>();
        int smallProbRuleSkipped = 0;
        int totalAcceptedRules = 0;
        Scanner grammarScan = FileUtil.getScanner(petrovGrammarFile);
        try {
            while (grammarScan.hasNextLine()) {
                String line = grammarScan.nextLine();
                if (line.trim().length() == 0) {
                    // A blank line has no probability token to parse; skip it
                    // instead of failing on lineSplit[0].
                    continue;
                }
                String[] lineSplit = line.split("\\s");
                double prob = Double.parseDouble(lineSplit[0]);
                // NOTE(review): a probability equal to the threshold is kept
                // here but dropped by convertLexicon -- confirm which is intended.
                if (prob < minProbRule) {
                    System.out.println("Skipped rule: " + line);
                    ++smallProbRuleSkipped;
                    continue;
                }
                ++totalAcceptedRules;
                // Key is the LHS followed by every RHS symbol, space-separated.
                StringBuilder rule = new StringBuilder(lineSplit[1]);
                for (int i = 2; i < lineSplit.length; i++) {
                    rule.append(' ').append(lineSplit[i]);
                }
                this.grammarRulesLogProb.put(rule.toString(), prob);
            }
        } finally {
            grammarScan.close();
        }
        System.out.println("Skipped small prob internal rules: " + smallProbRuleSkipped);
        System.out.println("Total accepted internal rules: " + totalAcceptedRules);
    }

    /**
     * Reads {@link #petrovLexiconFile} into {@link #lexRulesLogProb}.
     *
     * <p>Each line is split on tabs: field 0 is the word and every following
     * field is split on whitespace into a tag (token 0) and a probability
     * (token 1). Entries whose probability is strictly greater than
     * {@link #minProbRule} are stored under the key {@code "tag word"}; the
     * rest are only counted. The scanner is always closed.
     */
    private void convertLexicon() {
        this.lexRulesLogProb = new Hashtable<String, Double>();
        int smallProbRuleSkipped = 0;
        int totalAcceptedRules = 0;
        Scanner lexiconScan = FileUtil.getScanner(petrovLexiconFile);
        try {
            while (lexiconScan.hasNextLine()) {
                String line = lexiconScan.nextLine();
                String[] lineSplit = line.split("\t");
                String lex = lineSplit[0];
                for (int i = 1; i < lineSplit.length; i++) {
                    String[] posProbSplit = lineSplit[i].split("\\s");
                    String pos = posProbSplit[0];
                    double prob = Double.parseDouble(posProbSplit[1]);
                    if (prob > minProbRule) {
                        ++totalAcceptedRules;
                        this.lexRulesLogProb.put(pos + " " + lex, prob);
                    } else {
                        ++smallProbRuleSkipped;
                    }
                }
            }
        } finally {
            lexiconScan.close();
        }
        System.out.println("Skipped small prob lex rules: " + smallProbRuleSkipped);
        System.out.println("Total accepted lex rules: " + totalAcceptedRules);
    }

    /**
     * Prints the comparison of the two trees, using index 0.
     *
     * <p>The {@code cyk} tree is passed in the position that
     * {@link #compareProbTrees} labels "Petrov" in its output.
     *
     * @param bitparTree tree reported under the "BitPar" label
     * @param cyk        tree reported under the "Petrov" label
     */
    private void makeComparison(TSNodeLabel bitparTree, TSNodeLabel cyk) {
        this.compareProbTrees(cyk, bitparTree, 0);
    }

    /**
     * Prints the rule-by-rule log-probability of both trees plus a summary
     * line with the two totals and their difference.
     *
     * @param petrovTree tree reported under the "Petrov" label; also the source
     *                   of the printed sentence and length
     * @param bitparTree tree reported under the "BitPar" label
     * @param index      sentence index, only used in the printed output
     * @return 0 if the first tree has the higher log-probability, 1 if the
     *         second has, 2 otherwise
     */
    private int compareProbTrees(TSNodeLabel petrovTree, TSNodeLabel bitparTree, int index) {
        System.out.println("Index: " + index);
        int length = petrovTree.countLexicalNodes();
        System.out.println("Length: " + length);
        System.out.println(petrovTree.toFlatSentence());
        System.out.println("Petrov: ");
        double petrovLogProb = this.getLogProb(petrovTree, true);
        System.out.println("BitPar: ");
        double bitparLogProb = this.getLogProb(bitparTree, false);
        int result;
        if (petrovLogProb > bitparLogProb) {
            result = 0;
        } else if (petrovLogProb < bitparLogProb) {
            result = 1;
        } else {
            result = 2;
        }
        double diff = petrovLogProb - bitparLogProb;
        String lineReport = index + " " + length + " Petrov log prob: " + petrovLogProb
                + " BitPar log prob: " + bitparLogProb + " diff: " + diff;
        System.out.println(lineReport);
        System.out.println();
        return result;
    }

    /**
     * Sums the log-probabilities of the rules used at every non-lexical node of
     * {@code tree}, printing each rule as it goes. Pre-lexical nodes are looked
     * up in the lexicon table, all others in the grammar table.
     *
     * @param tree   tree to score
     * @param petrov only selects the label ("petrov" / "bitpar") printed when a
     *               rule is missing
     * @return the summed log-probability, or 0.0 as soon as a rule is not found
     */
    private double getLogProb(TSNodeLabel tree, boolean petrov) {
        double result = 0.0;
        ArrayList<TSNodeLabel> nodes = tree.collectAllNodes();
        for (TSNodeLabel n : nodes) {
            if (n.isLexical) {
                continue;
            }
            String rule = n.cfgRuleNoQuotes();
            Double prob = n.isPreLexical()
                    ? this.lexRulesLogProb.get(rule)
                    : this.grammarRulesLogProb.get(rule);
            if (prob == null) {
                System.out.println("\tNot found rule: (" + (petrov ? "petrov" : "bitpar") + ") " + rule);
                // NOTE(review): 0.0 is log(1), so a tree with a missing rule
                // compares as more probable than any fully scored tree --
                // confirm this sentinel is what callers expect.
                return 0.0;
            }
            result += prob.doubleValue();
            System.out.println("\t" + rule + " " + prob);
        }
        System.out.println("\tTotalProb: " + result);
        return result;
    }

    /**
     * Configures the static settings with hard-coded paths, loads the grammar
     * and lexicon, and compares two hard-coded parses of one sentence.
     *
     * @param args ignored
     * @throws Exception declared for callers; propagates whatever loading or
     *                   tree construction throws
     */
    public static void main(String[] args) throws Exception {
        minProbRule = 0.0;
        outputPath = "tmp/compare/newTest/";
        normalize = true;
        // NOTE(review): the petrov* fields are pointed at bitpar_* files, and
        // bitparParsedFile at cyk_parses.mrg -- confirm the naming is intended.
        petrovGrammarFile = new File("tmp/compare/newTest/bitpar_grammar.txt");
        petrovLexiconFile = new File("tmp/compare/newTest/bitpar_lexicon.txt");
        petrovParsedFile = new File("tmp/compare/newTest/bitpar_parses.mrg");
        bitparParsedFile = new File("tmp/compare/newTest/cyk_parses.mrg");
        CompareMJCykBitPar comparer = new CompareMJCykBitPar();
        TSNodeLabel bitparTree44 = new TSNodeLabel("(ROOT-0 (FRAG-0 (@FRAG-2 (@FRAG-0 (INTJ-0 (UH-0 Ah)) (,-0 ,)) (PP-25 (IN-22 UNK-LC) (NP-53 (NNP-37 Columbia)))) (.-2 !)))");
        TSNodeLabel cyk44 = new TSNodeLabel("(ROOT-0 (FRAG-0 (@FRAG-2 (@FRAG-0 (INTJ-0 (UH-0 Ah)) (,-0 ,)) (NP-14 (JJ-10 UNK-LC) (NNP-41 Columbia))) (.-2 !)))");
        comparer.makeComparison(bitparTree44, cyk44);
    }
}

