/*
 * Decompiled with CFR 0.152.
 */
package tsg.parser.petrov;

import java.io.File;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.Map;
import java.util.Scanner;
import tsg.TSNodeLabel;
import util.FileUtil;
import util.Utility;

/**
 * Compares two parsed treebanks ({@link #petrovParsedFile} and {@link #bitparParsedFile})
 * tree by tree, scoring each tree with the rule probabilities read from
 * {@link #petrovGrammarFile} and {@link #petrovLexiconFile}.
 *
 * <p>{@link #run()} performs the whole pipeline: load the grammar and lexicon into
 * rule-to-probability tables, turn the probabilities into natural logarithms (optionally
 * renormalising per left-hand side, see {@link #normalize}), and finally write a comparison
 * report to {@code outputPath + "compareReport.log"} and a summary to standard output.
 *
 * <p>NOTE(review): the class extends {@link Thread}, but {@link #main(String[])} invokes
 * {@code run()} directly, so the work executes on the calling thread.
 */
public class ComparePetroBitPar extends Thread {
    /**
     * Probability threshold. Grammar rules with a probability strictly below it and lexicon
     * entries with a probability not strictly above it are skipped while loading.
     */
    static double minProbRule;
    /**
     * When {@code true}, {@link #run()} renormalises rule probabilities per left-hand side
     * before taking logarithms. {@link #main(String[])} leaves it at its default ({@code false}).
     */
    static boolean normalize;
    /** Prefix that is concatenated with {@code "compareReport.log"} to name the report file. */
    static String outputPath;
    /** Grammar file read by {@link #convertGrammar()}. */
    static File petrovGrammarFile;
    /** Lexicon file read by {@link #convertLexicon()}. */
    static File petrovLexiconFile;
    /** First treebank of the comparison. */
    static File petrovParsedFile;
    /** Second treebank of the comparison. */
    static File bitparParsedFile;
    /**
     * Grammar rule (labels joined by single spaces, left-hand side first) mapped to its
     * probability; after {@link #makeLog()} / {@link #normalizeAndMakeLog()} the values are
     * natural logarithms.
     */
    Hashtable<String, Double> grammarRulesLogProb;
    /**
     * Lexical rule ({@code "<pos>-<index> <word>"}) mapped to its probability; after
     * {@link #makeLog()} / {@link #normalizeAndMakeLog()} the values are natural logarithms.
     */
    Hashtable<String, Double> lexRulesLogProb;
    /** Regex whose first match is removed from every grammar label. */
    static String regexConvertLabel = "\\^g";
    /** Regex matching the characters stripped from lexicon probabilities ('[', ']' and ','). */
    static String regexCleanProb = "[\\[\\]\\,]";

    /**
     * Loads grammar and lexicon, converts the probabilities to log space and runs the
     * comparison. Any exception raised by the comparison is printed, not propagated.
     */
    @Override
    public void run() {
        this.convertGrammar();
        this.convertLexicon();
        if (normalize) {
            this.normalizeAndMakeLog();
        } else {
            this.makeLog();
        }
        try {
            this.makeComparison();
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Replaces both rule tables with tables holding {@code log(prob / total)}, where
     * {@code total} is the value accumulated for the rule's left-hand side over the same table.
     */
    private void normalizeAndMakeLog() {
        this.grammarRulesLogProb = ComparePetroBitPar.normalizedLogTable(this.grammarRulesLogProb);
        this.lexRulesLogProb = ComparePetroBitPar.normalizedLogTable(this.lexRulesLogProb);
    }

    /** Replaces both rule tables with tables holding {@code log(prob)} for every rule. */
    private void makeLog() {
        this.grammarRulesLogProb = ComparePetroBitPar.logTable(this.grammarRulesLogProb);
        this.lexRulesLogProb = ComparePetroBitPar.logTable(this.lexRulesLogProb);
    }

    /** Returns the first whitespace-delimited token of {@code rule}, i.e. its left-hand side. */
    private static String leftHandSide(String rule) {
        return rule.split("\\s")[0];
    }

    /**
     * Builds a new table mapping every rule of {@code rulesProb} to the natural logarithm of
     * its probability divided by the per-left-hand-side total.
     */
    private static Hashtable<String, Double> normalizedLogTable(Hashtable<String, Double> rulesProb) {
        // First pass: accumulate the probabilities per left-hand side. The accumulated value
        // is read back from slot 0 of the array stored by the Utility helper.
        Hashtable<String, double[]> lhsTotals = new Hashtable<String, double[]>();
        for (Map.Entry<String, Double> e : rulesProb.entrySet()) {
            double prob = e.getValue();
            Utility.increaseInTableDoubleArray(lhsTotals, ComparePetroBitPar.leftHandSide(e.getKey()), prob);
        }
        // Second pass: divide each rule by its left-hand-side total and move to log space.
        Hashtable<String, Double> newRulesLogProb = new Hashtable<String, Double>();
        for (Map.Entry<String, Double> e : rulesProb.entrySet()) {
            String rule = e.getKey();
            double prob = e.getValue();
            double lhsProb = lhsTotals.get(ComparePetroBitPar.leftHandSide(rule))[0];
            newRulesLogProb.put(rule, Math.log(prob / lhsProb));
        }
        return newRulesLogProb;
    }

    /** Builds a new table mapping every rule of {@code rulesProb} to {@code Math.log(prob)}. */
    private static Hashtable<String, Double> logTable(Hashtable<String, Double> rulesProb) {
        Hashtable<String, Double> newRulesLogProb = new Hashtable<String, Double>();
        for (Map.Entry<String, Double> e : rulesProb.entrySet()) {
            double prob = e.getValue();
            newRulesLogProb.put(e.getKey(), Math.log(prob));
        }
        return newRulesLogProb;
    }

    /**
     * Reads {@link #petrovGrammarFile} into {@link #grammarRulesLogProb} (raw probabilities).
     *
     * <p>Each non-blank line is split on single whitespace characters: token 0 is the
     * left-hand side, token 1 is ignored, tokens 2..n-2 are the right-hand side and the last
     * token is the probability. Labels are passed through {@link #convertLable(String)}.
     * Rules with a probability below {@link #minProbRule} are skipped. Blank lines are
     * ignored. The numbers of skipped and accepted rules are printed to standard output.
     */
    private void convertGrammar() {
        this.grammarRulesLogProb = new Hashtable<String, Double>();
        int smallProbRuleSkipped = 0;
        int totalAcceptedRules = 0;
        Scanner grammarScan = FileUtil.getScanner(petrovGrammarFile);
        try {
            while (grammarScan.hasNextLine()) {
                String line = grammarScan.nextLine();
                if (line.trim().isEmpty()) {
                    // A blank line has no probability token to parse.
                    continue;
                }
                String[] lineSplit = line.split("\\s");
                int length = lineSplit.length;
                double prob = Double.parseDouble(lineSplit[length - 1]);
                // NOTE(review): the grammar skips on "prob < min" while the lexicon accepts on
                // "prob > min", so a rule exactly at the threshold is treated differently.
                if (prob < minProbRule) {
                    ++smallProbRuleSkipped;
                    continue;
                }
                ++totalAcceptedRules;
                StringBuilder rule = new StringBuilder(ComparePetroBitPar.convertLable(lineSplit[0]));
                // Token 1 is skipped; the last token is the probability.
                for (int i = 2; i < length - 1; ++i) {
                    rule.append(' ').append(ComparePetroBitPar.convertLable(lineSplit[i]));
                }
                this.grammarRulesLogProb.put(rule.toString(), prob);
            }
        }
        finally {
            grammarScan.close();
        }
        System.out.println("Skipped small prob internal rules: " + smallProbRuleSkipped);
        System.out.println("Total accepted internal rules: " + totalAcceptedRules);
    }

    /**
     * Converts a grammar label: every {@code '_'} becomes {@code '-'}, then the first match
     * of {@link #regexConvertLabel} is removed.
     */
    private static String convertLable(String label) {
        return label.replace('_', '-').replaceFirst(regexConvertLabel, "");
    }

    /**
     * Reads {@link #petrovLexiconFile} into {@link #lexRulesLogProb} (raw probabilities).
     *
     * <p>Backslashes are removed from each non-blank line, which is then split on single
     * whitespace characters: token 0 is the part of speech, token 1 the word, and every
     * following token a probability (cleaned by {@link #cleanProb(String)}). The k-th
     * probability (counting from 0) is stored under the rule {@code "<pos>-<k> <word>"} when
     * it is strictly greater than {@link #minProbRule}. Blank lines are ignored. The numbers
     * of skipped and accepted entries are printed to standard output.
     */
    private void convertLexicon() {
        this.lexRulesLogProb = new Hashtable<String, Double>();
        int smallProbRuleSkipped = 0;
        int totalAcceptedRules = 0;
        Scanner lexiconScan = FileUtil.getScanner(petrovLexiconFile);
        try {
            while (lexiconScan.hasNextLine()) {
                String line = lexiconScan.nextLine().replaceAll("\\\\", "");
                if (line.trim().isEmpty()) {
                    // A blank line has neither a part of speech nor a word.
                    continue;
                }
                String[] lineSplit = line.split("\\s");
                String pos = lineSplit[0];
                String lex = lineSplit[1];
                for (int i = 2; i < lineSplit.length; ++i) {
                    // Position of this probability within the line's probability list.
                    int index = i - 2;
                    double prob = ComparePetroBitPar.cleanProb(lineSplit[i]);
                    if (prob > minProbRule) {
                        ++totalAcceptedRules;
                        this.lexRulesLogProb.put(pos + "-" + index + " " + lex, prob);
                    } else {
                        ++smallProbRuleSkipped;
                    }
                }
            }
        }
        finally {
            lexiconScan.close();
        }
        System.out.println("Skipped small prob lex rules: " + smallProbRuleSkipped);
        System.out.println("Total accepted lex rules: " + totalAcceptedRules);
    }

    /**
     * Removes every match of {@link #regexCleanProb} from {@code p} and parses the remainder
     * as a double.
     *
     * @throws NumberFormatException if the remainder is not a parsable double
     */
    private static double cleanProb(String p) {
        return Double.parseDouble(p.replaceAll(regexCleanProb, ""));
    }

    /**
     * Walks both treebanks in parallel and compares the trees pairwise.
     *
     * <p>Pairs whose lexical labels differ are only counted; pairs that are equal are only
     * counted; every other pair is scored by
     * {@link #compareProbTrees(TSNodeLabel, TSNodeLabel, PrintWriter, int)}. Details go to the
     * report file, totals to standard output. If the treebanks differ in size a warning is
     * printed and only the pairs available in both are compared.
     *
     * @throws Exception whatever the treebank loading or the report writer creation throws
     */
    private void makeComparison() throws Exception {
        ArrayList<TSNodeLabel> petrovTreebank = TSNodeLabel.getTreebank(petrovParsedFile);
        ArrayList<TSNodeLabel> bitparTreebank = TSNodeLabel.getTreebank(bitparParsedFile);
        if (petrovTreebank.size() != bitparTreebank.size()) {
            System.err.println("Sizes differ");
        }
        File logFile = new File(outputPath + "compareReport.log");
        Iterator<TSNodeLabel> petrovIter = petrovTreebank.iterator();
        Iterator<TSNodeLabel> bitparIter = bitparTreebank.iterator();
        int differInLex = 0;
        int equalCounter = 0;
        int totalEqualLex = 0;
        // Slot 0: first tree scored higher, slot 1: second tree scored higher, slot 2: tie.
        int[] petrovBitparEqual = new int[3];
        int index = 0;
        PrintWriter pw = FileUtil.getPrintWriter(logFile);
        try {
            // Stop at the shorter treebank instead of running off the end of the other one.
            while (petrovIter.hasNext() && bitparIter.hasNext()) {
                ++index;
                TSNodeLabel petrovTree = petrovIter.next();
                TSNodeLabel bitparTree = bitparIter.next();
                if (!petrovTree.sameLexLabels(bitparTree)) {
                    ++differInLex;
                    continue;
                }
                ++totalEqualLex;
                if (petrovTree.equals(bitparTree)) {
                    ++equalCounter;
                    continue;
                }
                int winner = this.compareProbTrees(petrovTree, bitparTree, pw, index);
                petrovBitparEqual[winner]++;
            }
        }
        finally {
            pw.close();
        }
        System.out.println("Total differ in lex: " + differInLex);
        System.out.println("Total equal lex: " + totalEqualLex);
        System.out.println("Equal trees: " + equalCounter);
        System.out.println("Petrov wins: " + petrovBitparEqual[0]);
        System.out.println("Bitpar wins: " + petrovBitparEqual[1]);
        System.out.println("Equal wins: " + petrovBitparEqual[2]);
    }

    /**
     * Scores both trees with {@link #getLogProb(TSNodeLabel, PrintWriter, boolean)}, writes a
     * per-pair section to {@code pw} and a one-line report to both {@code pw} and standard
     * output.
     *
     * @param petrovTree tree from the first treebank
     * @param bitparTree tree from the second treebank
     * @param pw         report writer
     * @param index      1-based position of the pair in the treebanks
     * @return {@code 0} if the first tree has the higher log-probability, {@code 1} if the
     *         second has, {@code 2} otherwise
     */
    private int compareProbTrees(TSNodeLabel petrovTree, TSNodeLabel bitparTree, PrintWriter pw, int index) {
        pw.println("Index: " + index);
        int length = petrovTree.countLexicalNodes();
        pw.println("Length: " + length);
        pw.println(petrovTree.toFlatSentence());
        pw.println("Petrov: ");
        double petrovLogProb = this.getLogProb(petrovTree, pw, true);
        pw.println("BitPar: ");
        double bitparLogProb = this.getLogProb(bitparTree, pw, false);
        int result;
        if (petrovLogProb > bitparLogProb) {
            result = 0;
        } else if (petrovLogProb < bitparLogProb) {
            result = 1;
        } else {
            result = 2;
        }
        double diff = petrovLogProb - bitparLogProb;
        String lineReport = index + " " + length + " Petrov log prob: " + petrovLogProb
                + " BitPar log prob: " + bitparLogProb + " diff: " + diff;
        System.out.println(lineReport);
        pw.println(lineReport);
        pw.println();
        return result;
    }

    /**
     * Sums the log-probabilities of the rules of all non-lexical nodes of {@code tree},
     * looking pre-lexical nodes up in {@link #lexRulesLogProb} and all others in
     * {@link #grammarRulesLogProb}. Each rule and the total are written to {@code pw}.
     *
     * @param tree   tree to score
     * @param pw     report writer
     * @param petrov only selects the parser name used in the "not found" message
     * @return the summed log-probability, or {@code 0.0} as soon as a rule is missing from
     *         the tables
     */
    private double getLogProb(TSNodeLabel tree, PrintWriter pw, boolean petrov) {
        double result = 0.0;
        ArrayList<TSNodeLabel> nodes = tree.collectAllNodes();
        for (TSNodeLabel n : nodes) {
            if (n.isLexical) {
                continue;
            }
            String rule = n.cfgRuleNoQuotes();
            Double prob = n.isPreLexical() ? this.lexRulesLogProb.get(rule) : this.grammarRulesLogProb.get(rule);
            if (prob == null) {
                System.out.println("\tNot found rule: (" + (petrov ? "petrov" : "bitpar") + ") " + rule);
                // NOTE(review): 0.0 is the largest possible log-probability, so a tree with
                // an unknown rule outranks any fully scored tree - confirm this is intended.
                return 0.0;
            }
            result += prob.doubleValue();
            pw.println("\t" + rule + " " + prob);
        }
        pw.println("\tTotalProb: " + result);
        return result;
    }

    /**
     * Configures the static settings with fixed paths under {@code tmp/compare/} and runs the
     * comparison on the calling thread.
     *
     * @param args ignored
     */
    public static void main(String[] args) throws Exception {
        minProbRule = 0.0;
        outputPath = "tmp/compare/";
        petrovGrammarFile = new File("tmp/compare/eng_sm6_readable.gr.grammar");
        petrovLexiconFile = new File("tmp/compare/eng_sm6_readable.gr.lexicon");
        petrovParsedFile = new File("tmp/compare/wsj-24_eng_sm6_viterbi_sub_40.mrg");
        bitparParsedFile = new File("tmp/compare/BITPAR_MPD_RAW.mrg");
        new ComparePetroBitPar().run();
    }
}

