/*
 * Decompiled with CFR 0.152.
 */
package jigsaw.metrics;

import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import jigsaw.metrics.TaggingEval;
import jigsaw.syntax.Tree;
import jigsaw.util.Pair;

/**
 * Tagging evaluator that compares the preterminal labels of a guessed tree against
 * those of a gold tree, position by position.
 *
 * <p>In addition to overall (weighted) accuracy, it accumulates per-gold-tag counts
 * and (gold tag, guessed tag) confusion counts, which {@link #display} can report.
 */
public class TaggingCatEval
extends TaggingEval {
    /** Weighted number of gold tokens seen across all {@code evaluate} calls. */
    private double totalToks = 0.0;
    /** Weighted number of correctly tagged gold tokens across all {@code evaluate} calls. */
    private double totalCorrect = 0.0;
    /** Gold tag -> weighted count of tokens with that gold tag whose guessed tag matched. */
    private final Map<String, Double> catCorrect = new HashMap<String, Double>();
    /** Gold tag -> weighted count of tokens with that gold tag. */
    private final Map<String, Double> catCount = new HashMap<String, Double>();
    /**
     * (gold tag, guessed tag) -> weighted count of that mismatch.
     * NOTE(review): relies on Pair implementing value-based equals/hashCode — confirm.
     */
    private final Map<Pair<String, String>, Double> errorCount = new HashMap<Pair<String, String>, Double>();

    /**
     * @param str label passed through to {@link TaggingEval}; used as the prefix of printed lines
     * @param runningAverages passed through to {@link TaggingEval}; when set, {@link #evaluate}
     *        also prints the running average accuracy
     */
    public TaggingCatEval(String str, boolean runningAverages) {
        super(str, runningAverages);
    }

    /**
     * Compares the preterminal labels of {@code guess} and {@code gold}, updates the
     * accumulated statistics, and prints the accuracy for this pair of trees.
     *
     * <p>Accuracy is measured over gold tokens. If the guess yield is shorter than the
     * gold yield, gold tokens with no corresponding guess count as incorrect (no
     * confusion pair is recorded for them). Extra guess tokens beyond the gold yield
     * are ignored. A length mismatch is reported as a warning on {@code pw}.
     *
     * @param guess the predicted tree
     * @param gold the reference tree
     * @param pw where the per-sentence result line (and any warning) is written
     * @param weight weight applied to every count contributed by this pair of trees
     */
    @Override
    public void evaluate(Tree<String> guess, Tree<String> gold, PrintWriter pw, double weight) {
        List<Tree<String>> guesspts = guess.getPreTerminals();
        List<Tree<String>> goldpts = gold.getPreTerminals();
        int guessLen = guesspts.size();
        int goldLen = goldpts.size();
        if (guessLen != goldLen) {
            pw.println("Warning: yield length differs: Guess " + guessLen + " / Gold " + goldLen);
        }
        double currCorrect = 0.0;
        for (int i = 0; i < goldLen; ++i) {
            String goldpos = goldpts.get(i).getLabel();
            addWeight(this.catCount, goldpos, weight);
            if (i >= guessLen) {
                // Guess yield is shorter: this gold token has no guessed tag, so it is
                // counted as incorrect rather than indexing past the end of guesspts.
                continue;
            }
            String guesspos = guesspts.get(i).getLabel();
            if (guesspos.equals(goldpos)) {
                currCorrect += 1.0;
                addWeight(this.catCorrect, goldpos, weight);
            } else {
                addWeight(this.errorCount, new Pair<String, String>(goldpos, guesspos), weight);
            }
        }
        // Per-sentence accuracy is unweighted; an empty gold yield yields NaN (0/0).
        double currAcc = currCorrect / (double) goldLen;
        this.totalCorrect += currCorrect * weight;
        this.totalToks += (double) goldLen * weight;
        pw.format("%s [current] Acc: %.2f", this.str, currAcc * 100.0);
        if (this.runningAverages) {
            pw.format(" - [average] Acc: %.2f", this.totalCorrect / this.totalToks * 100.0);
        }
        pw.println();
    }

    /**
     * Prints the accumulated accuracy summary.
     *
     * <p>When {@code verbose} is set, it also prints up to ten gold tags with the lowest
     * accuracy (including tags that were never tagged correctly) and up to ten
     * (gold -> guess) confusions with the highest weighted counts.
     *
     * @param verbose whether to print the per-tag and confusion breakdowns
     * @param pw destination for the report
     */
    @Override
    public void display(boolean verbose, PrintWriter pw) {
        pw.format("%s [summary] Acc: %.2f T#: %d%n", this.str, this.totalCorrect / this.totalToks * 100.0, (int) this.totalToks);
        if (!verbose) {
            return;
        }
        // Iterate over every gold tag seen, not only those with a correct match, so
        // tags with zero accuracy are not silently dropped from the low-accuracy list.
        Map<String, Double> catAccuracy = new HashMap<String, Double>();
        for (Map.Entry<String, Double> entry : this.catCount.entrySet()) {
            String pos = entry.getKey();
            catAccuracy.put(pos, getWeight(this.catCorrect, pos) / entry.getValue());
        }
        List<String> catacclist = TaggingCatEval.SortByValue(catAccuracy);
        pw.println("Top-10 low accuracy tags : ");
        for (int i = 0; i < 10 && i < catacclist.size(); ++i) {
            String pos = catacclist.get(i);
            pw.format(" %s : %.0f / %.0f = %.2f %n", pos, getWeight(this.catCorrect, pos), this.catCount.get(pos), catAccuracy.get(pos) * 100.0);
        }
        pw.println("Top-10 tagging errors : ");
        // SortByValue is ascending, so walk from the end to list the largest counts first.
        List<Pair<String, String>> caterrlist = TaggingCatEval.SortByValue(this.errorCount);
        for (int i = 0; i < 10 && i < caterrlist.size(); ++i) {
            Pair<String, String> pair = caterrlist.get(caterrlist.size() - i - 1);
            pw.format(" %s -> %s %.0f%n", pair.getFirst(), pair.getSecond(), this.errorCount.get(pair));
        }
    }

    /**
     * Returns the keys of {@code m} sorted by their associated values in ascending order.
     *
     * <p>Keys mapped to {@code null} sort last. Values are compared via
     * {@link Comparable#compareTo} when the left-hand value is {@link Comparable};
     * otherwise the two keys are treated as equal. The sort is stable, so equal values
     * keep the map's key-iteration order.
     *
     * @param m the map whose keys are sorted
     * @return a new mutable list containing every key of {@code m}
     */
    public static <K, V> List<K> SortByValue(final Map<K, V> m) {
        List<K> keys = new ArrayList<K>(m.keySet());
        Collections.sort(keys, new Comparator<K>() {

            @Override
            @SuppressWarnings("unchecked") // values are only compared when v1 is Comparable
            public int compare(K k1, K k2) {
                V v1 = m.get(k1);
                V v2 = m.get(k2);
                if (v1 == null) {
                    return v2 == null ? 0 : 1;
                }
                if (v2 == null) {
                    // Symmetric with the branch above; avoids compareTo(null) throwing NPE.
                    return -1;
                }
                if (v1 instanceof Comparable) {
                    return ((Comparable<Object>) v1).compareTo(v2);
                }
                return 0;
            }
        });
        return keys;
    }

    /** Adds {@code weight} to the count stored under {@code key}, treating a missing key as 0. */
    private static <T> void addWeight(Map<T, Double> counts, T key, double weight) {
        counts.put(key, weight + getWeight(counts, key));
    }

    /** Returns the count stored under {@code key}, or 0 when the key is absent. */
    private static <T> double getWeight(Map<T, Double> counts, T key) {
        Double value = counts.get(key);
        return value == null ? 0.0 : value;
    }
}

