/*
 * Decompiled with CFR 0.152.
 */
package main.lexinduct;

import babel.content.eqclasses.EquivalenceClass;
import babel.content.eqclasses.properties.context.Context;
import babel.ranking.EquivClassCandRanking;
import babel.ranking.MRRAggregator;
import babel.ranking.Ranker;
import babel.ranking.Reranker;
import babel.ranking.scorers.Scorer;
import babel.ranking.scorers.context.DictScorer;
import babel.ranking.scorers.context.FungS1Scorer;
import babel.ranking.scorers.edit.EditDistanceScorer;
import babel.ranking.scorers.timedistribution.TimeDistributionCosineScorer;
import babel.util.config.Configurator;
import babel.util.dict.Dictionary;
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.text.DecimalFormat;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;
import main.lexinduct.FreqBinInductPreparer;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Driver for frequency-binned lexicon induction experiments.
 *
 * <p>Reads its settings from {@link Configurator#CONFIG}, prepares source and target
 * equivalence classes with a {@link FreqBinInductPreparer}, and then, for every source bin
 * returned by {@link FreqBinInductPreparer#getBinnedSrcEqs()}, ranks target candidates with the
 * enabled scorers (time distribution, context, edit distance) and optionally an MRR aggregate of
 * those rankings. Every ranking is evaluated against the preparer's seed dictionary and written to
 * {@code <output.Path><name>.<bin>.eval} and {@code <output.Path><name>.<bin>.scored}.
 */
public class FreqBinInductor {
    public static final Log LOG = LogFactory.getLog(FreqBinInductor.class);

    /**
     * Cut-offs at which top-K accuracy is reported by {@link #evaluate}. The last entry also
     * bounds how many target candidates are kept per source entry in {@link #gogo()}.
     */
    protected static int[] K = new int[]{1, 5, 10, 20, 30, 40, 50, 60, 80, 100, 200, 300, 400, 500};

    /**
     * Logs the configuration descriptor and runs the experiment.
     *
     * @param args ignored
     * @throws Exception propagated from the experiment run
     */
    public static void main(String[] args) throws Exception {
        LOG.info("\n" + Configurator.getConfigDescriptor());
        new FreqBinInductor().gogo();
    }

    /**
     * Runs the full experiment: prepares the data, then ranks, evaluates and dumps candidates for
     * every source bin using the scorers enabled in the configuration.
     *
     * @throws Exception propagated from preparation, ranking, evaluation or dumping
     */
    protected void gogo() throws Exception {
        boolean slidingWindow = Configurator.CONFIG.getBoolean("experiments.time.SlidingWindow");
        int windowSize = Configurator.CONFIG.getInt("experiments.time.WindowSize");
        String outDir = Configurator.CONFIG.getString("output.Path");
        int numThreads = Configurator.CONFIG.getInt("experiments.NumRankingThreads");
        boolean doContext = Configurator.CONFIG.getBoolean("experiments.DoContext");
        boolean doTime = Configurator.CONFIG.getBoolean("experiments.DoTime");
        boolean doEditDist = Configurator.CONFIG.getBoolean("experiments.DoEditDistance");
        boolean doAggregate = Configurator.CONFIG.getBoolean("experiments.DoAggregate");
        // Keep enough candidates per source entry to evaluate the largest cut-off.
        int maxNumTrgPerSrc = K[K.length - 1];

        FreqBinInductPreparer preparer = new FreqBinInductPreparer();
        preparer.prepare();
        preparer.writeSelectedCandidates(outDir + "src.selected");

        Set<EquivalenceClass> srcSubset = preparer.getSrcEqsToInduct();
        Set<EquivalenceClass> trgSet = preparer.getTrgEqs();
        // The seed dictionary feeds the context scorer and is also the evaluation gold standard.
        Dictionary seedDict = preparer.getSeedDict();

        FungS1Scorer contextScorer = new FungS1Scorer(seedDict, preparer.getMaxSrcTokCount(), preparer.getMaxTrgTokCount());
        TimeDistributionCosineScorer timeScorer = new TimeDistributionCosineScorer(windowSize, slidingWindow);
        EditDistanceScorer editScorer = new EditDistanceScorer();

        preparer.collectContextAndTimeProps(srcSubset, trgSet);
        preparer.prepareContextAndTimeProps(true, srcSubset, contextScorer, timeScorer, false);
        preparer.prepareContextAndTimeProps(false, trgSet, contextScorer, timeScorer, false);

        // NOTE(review): a HashSet of collections deduplicates by Collection.equals/hashCode, so two
        // equal rankings would be aggregated as one; a List may be intended. Kept as-is because
        // MRRAggregator.aggregate's parameter type is not visible here.
        HashSet<Collection<EquivClassCandRanking>> allCands = new HashSet<Collection<EquivClassCandRanking>>();
        int binNum = 0;
        for (Set<EquivalenceClass> srcBin : preparer.getBinnedSrcEqs()) {
            LOG.info(" --- Ranking candidates from bin " + binNum + " ---");
            allCands.clear();
            if (doTime) {
                LOG.info(" - Ranking candidates using time...");
                Collection<EquivClassCandRanking> cands = this.rank(timeScorer, srcBin, trgSet, maxNumTrgPerSrc, numThreads);
                this.evaluateAndDump(cands, seedDict, outDir + "time." + binNum);
                allCands.add(cands);
            }
            if (doContext) {
                LOG.info("Ranking candidates using context...");
                Collection<EquivClassCandRanking> cands = this.rank(contextScorer, srcBin, trgSet, maxNumTrgPerSrc, 0.0, numThreads);
                this.evaluateAndDump(cands, seedDict, outDir + "context." + binNum);
                allCands.add(cands);
            }
            if (doEditDist) {
                LOG.info("Ranking candidates using edit distance...");
                Collection<EquivClassCandRanking> cands = this.rank(editScorer, srcBin, trgSet, maxNumTrgPerSrc, numThreads);
                this.evaluateAndDump(cands, seedDict, outDir + "edit." + binNum);
                allCands.add(cands);
            }
            if (doAggregate) {
                LOG.info("Aggregating (MRR) all rankings...");
                MRRAggregator aggregator = new MRRAggregator();
                Collection<EquivClassCandRanking> cands = aggregator.aggregate(allCands);
                this.evaluateAndDump(cands, seedDict, outDir + "aggmrr." + binNum);
            }
            ++binNum;
        }
        LOG.info("--- Done ---");
    }

    /**
     * Evaluates a ranking and dumps it to disk under a common path prefix.
     *
     * @param cands the candidate rankings to evaluate and dump
     * @param dict dictionary used as the evaluation gold standard and passed to
     *     {@link EquivClassCandRanking#dumpToFile}
     * @param outPrefix path prefix; {@code ".eval"} and {@code ".scored"} are appended to it
     * @throws Exception propagated from evaluation or dumping
     */
    private void evaluateAndDump(Collection<EquivClassCandRanking> cands, Dictionary dict, String outPrefix) throws Exception {
        this.evaluate(cands, dict, outPrefix + ".eval");
        EquivClassCandRanking.dumpToFile(dict, cands, outPrefix + ".scored");
    }

    /**
     * Ranks target candidates for each source class with a score threshold.
     *
     * @param scorer scorer used to compare source and target classes
     * @param srcSubset source classes to rank candidates for
     * @param trgSet candidate target classes
     * @param maxNumberPerSrc maximum number of candidates kept per source class
     * @param threshold score threshold handed to the {@link Ranker}
     * @param numThreads number of ranking threads handed to the {@link Ranker}
     * @return the rankings produced by {@link Ranker#getBestCandLists}
     * @throws Exception propagated from the ranker
     */
    protected Collection<EquivClassCandRanking> rank(Scorer scorer, Set<EquivalenceClass> srcSubset, Set<EquivalenceClass> trgSet, int maxNumberPerSrc, double threshold, int numThreads) throws Exception {
        Ranker ranker = new Ranker(scorer, maxNumberPerSrc, threshold, numThreads);
        return ranker.getBestCandLists(srcSubset, trgSet);
    }

    /**
     * Ranks target candidates for each source class without an explicit threshold.
     *
     * @param scorer scorer used to compare source and target classes
     * @param srcSubset source classes to rank candidates for
     * @param trgSet candidate target classes
     * @param maxNumberPerSrc maximum number of candidates kept per source class
     * @param numThreads number of ranking threads handed to the {@link Ranker}
     * @return the rankings produced by {@link Ranker#getBestCandLists}
     * @throws Exception propagated from the ranker
     */
    protected Collection<EquivClassCandRanking> rank(Scorer scorer, Set<EquivalenceClass> srcSubset, Set<EquivalenceClass> trgSet, int maxNumberPerSrc, int numThreads) throws Exception {
        Ranker ranker = new Ranker(scorer, maxNumberPerSrc, numThreads);
        return ranker.getBestCandLists(srcSubset, trgSet);
    }

    /**
     * Re-scores existing candidate rankings with another scorer.
     *
     * @param scorer scorer used for re-ranking
     * @param cands rankings to re-rank
     * @return the result of {@link Reranker#reRank}
     */
    protected Collection<EquivClassCandRanking> reRank(Scorer scorer, Collection<EquivClassCandRanking> cands) {
        Reranker reranker = new Reranker(scorer);
        return reranker.reRank(cands);
    }

    /**
     * Re-scores existing candidate rankings with another scorer and a score threshold.
     *
     * @param scorer scorer used for re-ranking
     * @param cands rankings to re-rank
     * @param threshold score threshold handed to the {@link Reranker}
     * @return the result of {@link Reranker#reRank}
     */
    protected Collection<EquivClassCandRanking> reRank(Scorer scorer, Collection<EquivClassCandRanking> cands, double threshold) {
        Reranker reranker = new Reranker(scorer, threshold);
        return reranker.reRank(cands);
    }

    /**
     * Prunes the {@link Context} property of every source and target class to the size configured
     * under {@code experiments.context.PruneContextToSize}, using a
     * {@link Context.ScoreComparator} built from {@code scorer} to order context entries.
     *
     * @param srcEqs source classes whose contexts are pruned
     * @param trgEqs target classes whose contexts are pruned
     * @param scorer scorer backing the comparator
     */
    protected void pruneContextsAccordingToScore(Set<EquivalenceClass> srcEqs, Set<EquivalenceClass> trgEqs, DictScorer scorer) {
        Context.ScoreComparator comparator = new Context.ScoreComparator(scorer);
        int pruneContextEqs = Configurator.CONFIG.getInt("experiments.context.PruneContextToSize");
        // Assumes every class carries a Context property; a missing one would throw here.
        for (EquivalenceClass ec : srcEqs) {
            ((Context)ec.getProperty(Context.class.getName())).pruneContext(pruneContextEqs, comparator);
        }
        for (EquivalenceClass ec : trgEqs) {
            ((Context)ec.getProperty(Context.class.getName())).pruneContext(pruneContextEqs, comparator);
        }
    }

    /**
     * Writes a top-K accuracy table for {@code cands} to {@code outFileName}.
     *
     * <p>After a header line, one line {@code K<TAB>Accuracy@TopK<TAB>NumInDict} is written per
     * cut-off in {@link #K}. An entry counts as a hit at K when
     * {@code ranking.numInTopK(gold, K) > 0}, where {@code gold} is
     * {@code testDict.translate(ranking.getEqClass())}. Entries for which {@code translate} returns
     * {@code null} are skipped and not counted in NumInDict. If no entry can be evaluated, accuracy
     * is reported as 0.00 instead of NaN.
     *
     * @param cands rankings to evaluate
     * @param testDict gold-standard dictionary
     * @param outFileName path of the file to (over)write
     * @throws IOException if the file cannot be written
     */
    protected void evaluate(Collection<EquivClassCandRanking> cands, Dictionary testDict, String outFileName) throws IOException {
        // Look each entry up in the dictionary once, then test it against every cut-off.
        int[] hitsAtK = new int[K.length];
        int total = 0;
        for (EquivClassCandRanking ranking : cands) {
            Set<EquivalenceClass> goldTrans = testDict.translate(ranking.getEqClass());
            if (goldTrans == null) {
                continue;
            }
            total++;
            for (int i = 0; i < K.length; i++) {
                if (ranking.numInTopK(goldTrans, K[i]) > 0) {
                    hitsAtK[i]++;
                }
            }
        }

        DecimalFormat df = new DecimalFormat("0.00");
        BufferedWriter writer = new BufferedWriter(new FileWriter(outFileName));
        try {
            writer.write("K\tAccuracy@TopK\tNumInDict");
            writer.newLine();
            for (int i = 0; i < K.length; i++) {
                // Guard against 0/0 when no entry was found in the dictionary.
                double accInTopK = total > 0 ? 100.0 * hitsAtK[i] / total : 0.0;
                // NumInDict is written as a double (e.g. "12.0") to keep the existing file format.
                writer.write(K[i] + "\t" + df.format(accInTopK) + "\t" + (double) total);
                writer.newLine();
            }
        } finally {
            writer.close();
        }
    }
}

