/*
 * Decompiled with CFR 0.152.
 */
package sdp.tools;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import sdp.graph.Edge;
import sdp.graph.Graph;
import sdp.graph.Node;
import sdp.io.GraphReader;

/**
 * Scores semantic dependency graphs produced by a system against gold-standard graphs.
 *
 * <p>Every call to {@link #update(Graph, Graph)} adds the edges of one gold/system graph
 * pair to two corpus-wide edge sets. Each edge is tagged with the index of the graph it
 * came from, so identical edges in different sentences stay distinct. Precision, recall
 * and F1 are computed over these pooled sets, i.e. they are micro-averaged over the
 * corpus. Exact match is the fraction of graph pairs whose edge sets are identical.
 *
 * <p>Options control whether edge labels are compared, whether a virtual edge from node 0
 * to every top node is scored, whether edges touching punctuation are kept, and whether
 * edges are compared without regard to direction.
 *
 * <p>Every ratio whose denominator is zero is reported as 0.0 instead of NaN. Examples are
 * precision when the system produced no edges, or F1 when both precision and recall are 0.
 */
public class Scorer {
    /** Label given to every non-virtual edge when labels are not compared. */
    private static final String UNLABELED = "-UNLABELED-";
    /** Label of the virtual edges from node 0 that mark top nodes. */
    private static final String VIRTUAL = "-VIRTUAL-";
    /** Part-of-speech tags treated as punctuation when punctuation is excluded. */
    private static final Set<String> PUNCTUATION_TAGS = Scorer.createPunctuationTags();

    private final boolean includeLabels;
    private final boolean includeTopNodes;
    private final boolean includePunctuation;
    private final boolean treatEdgesAsUndirected;
    /** Number of graph pairs seen so far; doubles as the id of the graph currently being added. */
    private int nGraphs;
    private final Set<ScorerEdge> edgesInGoldStandard;
    private final Set<ScorerEdge> edgesInSystemOutput;
    /** Number of graph pairs whose gold and system edge sets were identical. */
    private int nExactMatches;

    /**
     * Creates a scorer with the given evaluation options.
     *
     * @param includeLabels          compare edge labels; if false every edge is labeled {@code -UNLABELED-}
     * @param includeTopNodes        also score a virtual edge from node 0 to each top node
     * @param includePunctuation     keep edges whose source or target is tagged as punctuation
     * @param treatEdgesAsUndirected consider edges (a, b) and (b, a) equal
     */
    public Scorer(boolean includeLabels, boolean includeTopNodes, boolean includePunctuation, boolean treatEdgesAsUndirected) {
        this.includeLabels = includeLabels;
        this.includeTopNodes = includeTopNodes;
        this.includePunctuation = includePunctuation;
        this.treatEdgesAsUndirected = treatEdgesAsUndirected;
        this.edgesInGoldStandard = new HashSet<ScorerEdge>();
        this.edgesInSystemOutput = new HashSet<ScorerEdge>();
    }

    /** Creates a labeled, directed scorer that includes top nodes and punctuation. */
    public Scorer() {
        this(true, true, true, false);
    }

    /**
     * Adds one gold-standard/system-output graph pair to the running totals.
     *
     * <p>Both graphs are expected to have the same number of nodes. This is only checked
     * when assertions are enabled.
     *
     * @param goldStandard the reference graph
     * @param systemOutput the graph predicted by the system for the same sentence
     */
    public void update(Graph goldStandard, Graph systemOutput) {
        assert goldStandard.getNNodes() == systemOutput.getNNodes();
        // getEdges tags edges with the current nGraphs, so it must run before the increment.
        Set<ScorerEdge> edgesG = this.getEdges(goldStandard);
        Set<ScorerEdge> edgesS = this.getEdges(systemOutput);
        ++this.nGraphs;
        if (edgesG.equals(edgesS)) {
            ++this.nExactMatches;
        }
        this.edgesInGoldStandard.addAll(edgesG);
        this.edgesInSystemOutput.addAll(edgesS);
    }

    /** Builds the immutable set of part-of-speech tags treated as punctuation. */
    private static Set<String> createPunctuationTags() {
        Set<String> tags = new HashSet<String>();
        Collections.addAll(tags, ".", ",", ":", "(", ")");
        return Collections.unmodifiableSet(tags);
    }

    /** Returns true if the node's part-of-speech tag is one of the punctuation tags. */
    private boolean isPunctuation(Node node) {
        return PUNCTUATION_TAGS.contains(node.pos);
    }

    /**
     * Decides whether the edge between nodes {@code src} and {@code tgt} is scored.
     * When punctuation is excluded, edges touching a punctuation node are rejected.
     */
    private boolean edgeIsAdmissible(Graph graph, int src, int tgt) {
        if (this.includePunctuation) {
            return true;
        }
        return !this.isPunctuation(graph.getNode(src)) && !this.isPunctuation(graph.getNode(tgt));
    }

    /**
     * Collects the scorable edges of {@code graph}, tagged with the current graph id.
     * When top nodes are included, adds one {@code -VIRTUAL-} edge from node 0 to each
     * top node. Those virtual edges pass through the same punctuation filter.
     */
    private Set<ScorerEdge> getEdges(Graph graph) {
        Set<ScorerEdge> edges = new HashSet<ScorerEdge>();
        for (Edge edge : graph.getEdges()) {
            if (!this.edgeIsAdmissible(graph, edge.source, edge.target)) {
                continue;
            }
            String label = this.includeLabels ? edge.label : UNLABELED;
            edges.add(this.makeEdge(this.nGraphs, edge.source, edge.target, label));
        }
        if (this.includeTopNodes) {
            for (Node node : graph.getNodes()) {
                if (node.isTop && this.edgeIsAdmissible(graph, 0, node.id)) {
                    edges.add(this.makeEdge(this.nGraphs, 0, node.id, VIRTUAL));
                }
            }
        }
        return edges;
    }

    /** Returns the number of distinct gold-standard edges seen so far. */
    public int getNEdgesInGoldStandard() {
        return this.edgesInGoldStandard.size();
    }

    /** Returns the number of distinct system-output edges seen so far. */
    public int getNEdgesInSystemOutput() {
        return this.edgesInSystemOutput.size();
    }

    /** Returns the fraction of system edges that are also gold edges; 0.0 if there are no system edges. */
    public double getPrecision() {
        return Scorer.ratio(this.getNEdgesInCommon(), this.getNEdgesInSystemOutput());
    }

    /** Returns the fraction of gold edges that are also system edges; 0.0 if there are no gold edges. */
    public double getRecall() {
        return Scorer.ratio(this.getNEdgesInCommon(), this.getNEdgesInGoldStandard());
    }

    /** Returns a fresh set holding the edges present in both the gold standard and the system output. */
    private Set<ScorerEdge> getEdgesInCommon() {
        Set<ScorerEdge> intersection = new HashSet<ScorerEdge>(this.edgesInGoldStandard);
        intersection.retainAll(this.edgesInSystemOutput);
        return intersection;
    }

    /** Returns the number of edges present in both the gold standard and the system output. */
    public int getNEdgesInCommon() {
        return this.getEdgesInCommon().size();
    }

    /** Returns the F1 score restricted to edges with the given label; 0.0 when precision and recall are both 0. */
    public double getF1PerLabel(String label) {
        return Scorer.harmonicMean(this.getPrecisionPerLabel(label), this.getRecallPerLabel(label));
    }

    /** Returns the harmonic mean of precision and recall; 0.0 when both are 0. */
    public double getF1() {
        return Scorer.harmonicMean(this.getPrecision(), this.getRecall());
    }

    /** Returns the fraction of graph pairs whose edge sets matched exactly; 0.0 if no graphs were scored. */
    public double getExactMatch() {
        return Scorer.ratio(this.nExactMatches, this.nGraphs);
    }

    /** Divides two counts, returning 0.0 instead of NaN when the denominator is zero. */
    private static double ratio(int numerator, int denominator) {
        return denominator == 0 ? 0.0 : (double) numerator / (double) denominator;
    }

    /** Harmonic mean of {@code p} and {@code r}; 0.0 (instead of NaN) when both are zero. */
    private static double harmonicMean(double p, double r) {
        return p + r == 0.0 ? 0.0 : 2.0 * p * r / (p + r);
    }

    /**
     * Reads the gold-standard and system-output files in parallel and pairs their graphs.
     *
     * @throws IllegalArgumentException if the two files contain different numbers of graphs
     * @throws Exception                whatever {@code GraphReader} throws while opening or reading
     */
    private static List<GraphPair> readGraphs(String goldStandardFile, String systemOutputFile) throws Exception {
        List<GraphPair> graphPairs = new ArrayList<GraphPair>();
        GraphReader goldStandardReader = new GraphReader(goldStandardFile);
        try {
            GraphReader systemOutputReader = new GraphReader(systemOutputFile);
            try {
                Graph goldStandard;
                while ((goldStandard = goldStandardReader.readGraph()) != null) {
                    Graph systemOutput = systemOutputReader.readGraph();
                    if (systemOutput == null) {
                        throw new IllegalArgumentException(String.format(
                                "System output file %s contains fewer graphs than gold standard file %s (ran out after %d graphs)",
                                systemOutputFile, goldStandardFile, graphPairs.size()));
                    }
                    graphPairs.add(new GraphPair(goldStandard, systemOutput));
                }
                if (systemOutputReader.readGraph() != null) {
                    throw new IllegalArgumentException(String.format(
                            "System output file %s contains more graphs than gold standard file %s (%d graphs)",
                            systemOutputFile, goldStandardFile, graphPairs.size()));
                }
            } finally {
                systemOutputReader.close();
            }
        } finally {
            goldStandardReader.close();
        }
        return graphPairs;
    }

    /** Feeds every graph pair to {@code scorer}. */
    private static void score(Scorer scorer, List<GraphPair> graphPairs) {
        for (GraphPair pair : graphPairs) {
            scorer.update(pair.goldStandard, pair.systemOutput);
        }
    }

    /**
     * Scores {@code graphPairs} with a labeled and an unlabeled scorer and prints a report to stderr.
     *
     * @param latexOutput if true, additionally print a LaTeX table of per-label labeled scores
     */
    private static void score(boolean includeTopNodes, boolean includePunctuation, boolean treatEdgesAsUndirected, List<GraphPair> graphPairs, boolean latexOutput) {
        Scorer scorerL = new Scorer(true, includeTopNodes, includePunctuation, treatEdgesAsUndirected);
        Scorer scorerU = new Scorer(false, includeTopNodes, includePunctuation, treatEdgesAsUndirected);
        Scorer.score(scorerL, graphPairs);
        Scorer.score(scorerU, graphPairs);
        List<String> labels = new ArrayList<String>(scorerL.getLabels());
        Collections.sort(labels);
        Scorer.printEdgeCounts(scorerL, scorerU);
        Scorer.printLabeledScores(scorerL);
        Scorer.printLabelBreakdown(scorerL, labels);
        Scorer.printLengthBreakdown(scorerL);
        Scorer.printUnlabeledScores(scorerU);
        if (latexOutput) {
            Scorer.printLatexTable(scorerL, labels);
        }
    }

    /** Prints the edge counts of the gold standard, the system output and their overlaps. */
    private static void printEdgeCounts(Scorer scorerL, Scorer scorerU) {
        System.err.format("Number of edges in gold standard: %d%n", scorerL.getNEdgesInGoldStandard());
        System.err.format("Number of edges in system output: %d%n", scorerL.getNEdgesInSystemOutput());
        System.err.format("Number of edges in common, labeled: %d%n", scorerL.getNEdgesInCommon());
        System.err.format("Number of edges in common, unlabeled: %d%n", scorerU.getNEdgesInCommon());
        System.err.println();
    }

    /** Prints labeled precision, recall, F1 and exact match as percentages. */
    private static void printLabeledScores(Scorer scorerL) {
        System.err.println("### Labeled scores");
        System.err.println();
        System.err.format("LP: %3.2f%n", scorerL.getPrecision() * 100.0);
        System.err.format("LR: %3.2f%n", scorerL.getRecall() * 100.0);
        System.err.format("LF: %3.2f%n", scorerL.getF1() * 100.0);
        System.err.format("LM: %3.2f%n", scorerL.getExactMatch() * 100.0);
        System.err.println();
    }

    /** Prints edge counts and labeled scores for each label in {@code labels}. */
    private static void printLabelBreakdown(Scorer scorerL, List<String> labels) {
        System.err.println("### Breakdown by label type");
        System.err.println();
        // Header spelling ("Precison") is kept as-is so the output format is unchanged.
        System.err.format("%20s %10s %10s %10s %10s %10s %10s%n", "Label", "#Gold", "#System", "#Correct", "Precison", "Recall", "F1");
        for (String label : labels) {
            System.err.format("%20s %10d %10d %10d %10.2f %10.2f %10.2f%n", label, scorerL.getNEdgesInGoldStandardByLabel(label), scorerL.getNEdgesInSystemOutputByLabel(label), scorerL.getNCorrectEdgesByLabel(label), scorerL.getPrecisionPerLabel(label) * 100.0, scorerL.getRecallPerLabel(label) * 100.0, scorerL.getF1PerLabel(label) * 100.0);
        }
        System.err.println();
    }

    /** Prints edge counts, precision and recall per edge-length bucket. */
    private static void printLengthBreakdown(Scorer scorerL) {
        System.err.println("### Breakdown by edge length");
        System.err.println();
        // Distinct buckets for lengths 1..99 in ascending order: "1".."4", "5-9", "10-".
        // NOTE(review): zero-length edges (bucket "0") are never listed here.
        List<String> quantizedLengths = new ArrayList<String>();
        for (int i = 1; i < 100; ++i) {
            String quantizedLength = Scorer.getQuantizedLength(i);
            if (!quantizedLengths.contains(quantizedLength)) {
                quantizedLengths.add(quantizedLength);
            }
        }
        System.err.format("%10s %10s %10s %10s %10s%n", "Length", "#Gold", "#System", "Precison", "Recall");
        for (String quantizedLength : quantizedLengths) {
            System.err.format("%10s:%10d %10d %10.2f %10.2f%n", quantizedLength, scorerL.getNEdgesInGoldStandardByQuantizedLength(quantizedLength), scorerL.getNEdgesInSystemOutputByQuantizedLength(quantizedLength), scorerL.getPrecisionPerQuantizedLength(quantizedLength) * 100.0, scorerL.getRecallPerQuantizedLength(quantizedLength) * 100.0);
        }
        System.err.println();
    }

    /** Prints unlabeled precision, recall, F1 and exact match as percentages. */
    private static void printUnlabeledScores(Scorer scorerU) {
        System.err.println("### Unlabeled scores");
        System.err.println();
        System.err.format("UP: %3.2f%n", scorerU.getPrecision() * 100.0);
        System.err.format("UR: %3.2f%n", scorerU.getRecall() * 100.0);
        System.err.format("UF: %3.2f%n", scorerU.getF1() * 100.0);
        System.err.format("UM: %3.2f%n", scorerU.getExactMatch() * 100.0);
    }

    /** Prints a LaTeX tabular with labeled precision/recall/F1 per label and a labeled TOTAL row. */
    private static void printLatexTable(Scorer scorerL, List<String> labels) {
        System.err.format("%20s%n", "\\begin{tabular}{cccc}");
        System.err.format("%20s & %10s & %10s & %10s  \\\\%n", "", "LP", "LR", "LF");
        System.err.format("%20s%n", "\\hline");
        for (String label : labels) {
            System.err.format("%20s & %10.2f & %10.2f & %10.2f  \\\\%n", label, scorerL.getPrecisionPerLabel(label) * 100.0, scorerL.getRecallPerLabel(label) * 100.0, scorerL.getF1PerLabel(label) * 100.0);
        }
        System.err.format("%20s%n", "\\hline");
        // The columns are labeled scores (LP/LR/LF), so the total must come from the labeled scorer too.
        System.err.format("%20s & %10.2f & %10.2f & %10.2f  \\\\%n", "TOTAL", scorerL.getPrecision() * 100.0, scorerL.getRecall() * 100.0, scorerL.getF1() * 100.0);
        System.err.format("%20s%n", "\\end{tabular}");
        System.err.println();
    }

    /**
     * Command-line entry point.
     *
     * <p>Usage: {@code Scorer GOLD_FILE SYSTEM_FILE [excludePunctuation] [treatEdgesAsUndirected] [incl] [latex]}
     *
     * @throws IllegalArgumentException if fewer than two arguments are given, or the files hold different numbers of graphs
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 2) {
            throw new IllegalArgumentException("Usage: Scorer GOLD_FILE SYSTEM_FILE [excludePunctuation] [treatEdgesAsUndirected] [incl] [latex]");
        }
        boolean includePunctuation = true;
        boolean treatEdgesAsUndirected = false;
        boolean inclusive = false;
        boolean latex = false;
        for (String arg : args) {
            if (arg.equals("excludePunctuation")) {
                System.err.println("Will exclude punctuation.");
                includePunctuation = false;
            } else if (arg.equals("treatEdgesAsUndirected")) {
                System.err.println("Will treat edges as undirected.");
                treatEdgesAsUndirected = true;
            } else if (arg.equals("incl")) {
                System.err.println("print results including ROOT as well.");
                inclusive = true;
            } else if (arg.equals("latex")) {
                System.err.println("print results as a LaTeX table as well.");
                latex = true;
            }
        }
        System.err.println("# Evaluation");
        System.err.println();
        System.err.format("Gold standard file: %s%n", args[0]);
        System.err.format("System output file: %s%n", args[1]);
        System.err.println();
        List<GraphPair> graphPairs = Scorer.readGraphs(args[0], args[1]);
        if (inclusive) {
            System.err.println("## Scores including virtual dependencies to top nodes");
            System.err.println();
            Scorer.score(true, includePunctuation, treatEdgesAsUndirected, graphPairs, latex);
        } else {
            System.err.println("## Scores excluding virtual dependencies to top nodes");
            System.err.println();
            Scorer.score(false, includePunctuation, treatEdgesAsUndirected, graphPairs, latex);
        }
    }

    /** Creates a directed or undirected edge depending on the scorer's configuration. */
    private ScorerEdge makeEdge(int graphId, int src, int tgt, String label) {
        if (this.treatEdgesAsUndirected) {
            return new UndirectedScorerEdge(graphId, src, tgt, label);
        }
        return new ScorerEdge(graphId, src, tgt, label);
    }

    /** Returns every label occurring in the gold standard or the system output. */
    private Set<String> getLabels() {
        Set<String> labels = new HashSet<String>();
        for (ScorerEdge edge : this.edgesInGoldStandard) {
            labels.add(edge.label);
        }
        for (ScorerEdge edge : this.edgesInSystemOutput) {
            labels.add(edge.label);
        }
        return labels;
    }

    /**
     * Counts the edges in {@code edges} that carry {@code label}. If {@code reference} is
     * non-null, only edges also contained in {@code reference} are counted.
     */
    private static int countByLabel(Set<ScorerEdge> edges, String label, Set<ScorerEdge> reference) {
        int n = 0;
        for (ScorerEdge edge : edges) {
            if (edge.label.equals(label) && (reference == null || reference.contains(edge))) {
                ++n;
            }
        }
        return n;
    }

    private int getNEdgesInGoldStandardByLabel(String label) {
        return Scorer.countByLabel(this.edgesInGoldStandard, label, null);
    }

    private int getNEdgesInSystemOutputByLabel(String label) {
        return Scorer.countByLabel(this.edgesInSystemOutput, label, null);
    }

    /** Returns the number of system edges with {@code label} that are also gold edges. */
    private int getNCorrectEdgesByLabel(String label) {
        return Scorer.countByLabel(this.edgesInSystemOutput, label, this.edgesInGoldStandard);
    }

    /** Returns the precision restricted to edges with {@code label}; 0.0 if the system produced none. */
    private double getPrecisionPerLabel(String label) {
        return Scorer.ratio(this.getNCorrectEdgesByLabel(label), this.getNEdgesInSystemOutputByLabel(label));
    }

    /** Returns the recall restricted to edges with {@code label}; 0.0 if the gold standard has none. */
    private double getRecallPerLabel(String label) {
        int nFound = Scorer.countByLabel(this.edgesInGoldStandard, label, this.edgesInSystemOutput);
        return Scorer.ratio(nFound, this.getNEdgesInGoldStandardByLabel(label));
    }

    /** Maps an edge length to a bucket name: "0".."4" individually, "5-9", or "10-" for 10 and above. */
    private static String getQuantizedLength(int length) {
        if (length <= 4) {
            return Integer.toString(length);
        }
        if (length < 10) {
            return "5-9";
        }
        return "10-";
    }

    /** Returns the length bucket of {@code edge} (distance between its endpoints). */
    private static String getQuantizedLength(ScorerEdge edge) {
        return Scorer.getQuantizedLength(edge.getLength());
    }

    /**
     * Counts the edges in {@code edges} whose length falls into {@code quantizedLength}.
     * If {@code reference} is non-null, only edges also contained in {@code reference} are counted.
     */
    private static int countByQuantizedLength(Set<ScorerEdge> edges, String quantizedLength, Set<ScorerEdge> reference) {
        int n = 0;
        for (ScorerEdge edge : edges) {
            if (Scorer.getQuantizedLength(edge).equals(quantizedLength) && (reference == null || reference.contains(edge))) {
                ++n;
            }
        }
        return n;
    }

    private int getNEdgesInGoldStandardByQuantizedLength(String quantizedLength) {
        return Scorer.countByQuantizedLength(this.edgesInGoldStandard, quantizedLength, null);
    }

    private int getNEdgesInSystemOutputByQuantizedLength(String quantizedLength) {
        return Scorer.countByQuantizedLength(this.edgesInSystemOutput, quantizedLength, null);
    }

    /** Returns the precision restricted to one length bucket; 0.0 if the system produced no such edges. */
    private double getPrecisionPerQuantizedLength(String quantizedLength) {
        int nCorrect = Scorer.countByQuantizedLength(this.edgesInSystemOutput, quantizedLength, this.edgesInGoldStandard);
        return Scorer.ratio(nCorrect, this.getNEdgesInSystemOutputByQuantizedLength(quantizedLength));
    }

    /** Returns the recall restricted to one length bucket; 0.0 if the gold standard has no such edges. */
    private double getRecallPerQuantizedLength(String quantizedLength) {
        int nFound = Scorer.countByQuantizedLength(this.edgesInGoldStandard, quantizedLength, this.edgesInSystemOutput);
        return Scorer.ratio(nFound, this.getNEdgesInGoldStandardByQuantizedLength(quantizedLength));
    }

    /**
     * An edge whose identity ignores direction: (graphId, a, b, label) equals (graphId, b, a, label).
     * It never equals a plain directed {@link ScorerEdge}, because {@code equals} compares runtime classes.
     */
    private static class UndirectedScorerEdge
    extends ScorerEdge {
        public UndirectedScorerEdge(int graphId, int src, int tgt, String label) {
            super(graphId, src, tgt, label);
        }

        @Override
        public int hashCode() {
            // Hash the endpoints in sorted order so both directions hash identically.
            int hash = 3;
            hash = 53 * hash + this.graphId;
            hash = 53 * hash + Math.min(this.src, this.tgt);
            hash = 53 * hash + Math.max(this.src, this.tgt);
            hash = 53 * hash + (this.label != null ? this.label.hashCode() : 0);
            return hash;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == null) {
                return false;
            }
            if (this.getClass() != obj.getClass()) {
                return false;
            }
            ScorerEdge other = (ScorerEdge) obj;
            if (this.graphId != other.graphId) {
                return false;
            }
            if (Math.min(this.src, this.tgt) != Math.min(other.src, other.tgt)) {
                return false;
            }
            if (Math.max(this.src, this.tgt) != Math.max(other.src, other.tgt)) {
                return false;
            }
            return this.label == null ? other.label == null : this.label.equals(other.label);
        }
    }

    /** A directed, labeled edge tagged with the id of the graph it belongs to. */
    private static class ScorerEdge {
        final int graphId;
        final int src;
        final int tgt;
        final String label;

        public ScorerEdge(int graphId, int src, int tgt, String label) {
            this.graphId = graphId;
            this.src = src;
            this.tgt = tgt;
            this.label = label;
        }

        /** Returns the absolute distance between the two endpoint ids. */
        public int getLength() {
            return Math.max(this.src, this.tgt) - Math.min(this.src, this.tgt);
        }

        @Override
        public int hashCode() {
            int hash = 3;
            hash = 53 * hash + this.graphId;
            hash = 53 * hash + this.src;
            hash = 53 * hash + this.tgt;
            hash = 53 * hash + (this.label != null ? this.label.hashCode() : 0);
            return hash;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == null) {
                return false;
            }
            if (this.getClass() != obj.getClass()) {
                return false;
            }
            ScorerEdge other = (ScorerEdge) obj;
            if (this.graphId != other.graphId) {
                return false;
            }
            if (this.src != other.src) {
                return false;
            }
            if (this.tgt != other.tgt) {
                return false;
            }
            return this.label == null ? other.label == null : this.label.equals(other.label);
        }
    }

    /** A gold-standard graph together with the system's graph for the same sentence. */
    private static class GraphPair {
        public final Graph goldStandard;
        public final Graph systemOutput;

        public GraphPair(Graph goldStandard, Graph systemOutput) {
            this.goldStandard = goldStandard;
            this.systemOutput = systemOutput;
        }
    }
}

