/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.ie.machinereading;

import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.classify.LinearClassifierFactory;
import edu.stanford.nlp.classify.RVFDataset;
import edu.stanford.nlp.classify.SVMLightClassifierFactory;
import edu.stanford.nlp.ie.machinereading.Extractor;
import edu.stanford.nlp.ie.machinereading.LabelValidator;
import edu.stanford.nlp.ie.machinereading.RelationFeatureFactory;
import edu.stanford.nlp.ie.machinereading.structure.AnnotationUtils;
import edu.stanford.nlp.ie.machinereading.structure.ExtractionObject;
import edu.stanford.nlp.ie.machinereading.structure.MachineReadingAnnotations;
import edu.stanford.nlp.ie.machinereading.structure.RelationMention;
import edu.stanford.nlp.ie.machinereading.structure.RelationMentionFactory;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.pipeline.Annotation;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Execution;
import edu.stanford.nlp.util.Pair;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Relation extractor that labels candidate {@link RelationMention}s with a
 * {@link LinearClassifier}. The classifier is trained either by
 * {@link LinearClassifierFactory} ({@code relationExtractorClassifierType = "linear"})
 * or by {@link SVMLightClassifierFactory} ({@code "svm"}).
 *
 * <p>The label {@code "_NR"} is used as the "no relation" class.
 *
 * <p>NOTE(review): {@link #logger} is static, so both the constructor (which resets
 * it to {@code INFO}) and {@link #setLoggerLevel} affect every instance.
 */
public class BasicRelationExtractor
implements Extractor {
    private static final long serialVersionUID = 2606577772115897869L;
    private static final Logger logger = Logger.getLogger(BasicRelationExtractor.class.getName());

    /** Label used for "no relation"; static, so it does not affect the serialized form. */
    private static final String NO_RELATION = "_NR";

    // NOTE: instance fields below are part of the serialized model format
    // (see load/save); do not rename, retype or reorder them.
    protected LinearClassifier<String, String> classifier;
    @Execution.Option(name="featureCountThreshold", gloss="feature count threshold to apply to dataset")
    public int featureCountThreshold = 2;
    @Execution.Option(name="featureFactory", gloss="Feature factory for the relation extractor")
    public RelationFeatureFactory featureFactory;
    @Execution.Option(name="sigma", gloss="strength of the prior on the linear classifier (passed to LinearClassifierFactory) or the C constant if relationExtractorClassifierType=svm")
    public double sigma = 1.0;
    /** Either {@code "linear"} or {@code "svm"} (compared case-insensitively). */
    public String relationExtractorClassifierType = "linear";
    protected boolean createUnrelatedRelations;
    private LabelValidator validator;
    protected RelationMentionFactory relationMentionFactory;

    /** Sets the validator used to reject labels incompatible with a relation's arguments. */
    public void setValidator(LabelValidator lv) {
        this.validator = lv;
    }

    /** Sets the classifier type: {@code "linear"} or {@code "svm"}. */
    public void setRelationExtractorClassifierType(String s) {
        this.relationExtractorClassifierType = s;
    }

    /** Sets the minimum feature count kept when building the training dataset. */
    public void setFeatureCountThreshold(int i) {
        this.featureCountThreshold = i;
    }

    /** Sets the prior strength (linear) or the C constant (svm). */
    public void setSigma(double d) {
        this.sigma = d;
    }

    /**
     * Creates an extractor.
     *
     * @param featureFac feature factory used to turn relation mentions into datums
     * @param createUnrelatedRelations whether to generate candidate relations between all
     *     argument pairs (rather than only the annotated relation mentions); must not be null
     * @param factory factory used to build the predicted relation mentions
     */
    public BasicRelationExtractor(RelationFeatureFactory featureFac, Boolean createUnrelatedRelations, RelationMentionFactory factory) {
        this.featureFactory = featureFac;
        this.createUnrelatedRelations = createUnrelatedRelations;
        this.relationMentionFactory = factory;
        logger.setLevel(Level.INFO);
    }

    /** Sets whether candidate relations are generated between all argument pairs. */
    public void setCreateUnrelatedRelations(boolean b) {
        this.createUnrelatedRelations = b;
    }

    /**
     * Loads a serialized extractor from a URL, the classpath or the file system.
     *
     * <p>NOTE(review): this uses Java native deserialization; only load trusted models.
     */
    public static BasicRelationExtractor load(String modelPath) throws IOException, ClassNotFoundException {
        return (BasicRelationExtractor)IOUtils.readObjectFromURLOrClasspathOrFileSystem(modelPath);
    }

    /**
     * Serializes this extractor to {@code modelpath}, creating missing parent directories.
     * Both streams are closed even if serialization fails.
     */
    @Override
    public void save(String modelpath) throws IOException {
        // getParentFile() accepts every separator the platform's File does, unlike
        // lastIndexOf(File.separator), which misses '/' on Windows.
        File parent = new File(modelpath).getParentFile();
        if (parent != null && !parent.exists()) {
            // If this fails, the FileOutputStream below reports the error.
            parent.mkdirs();
        }
        try (FileOutputStream fos = new FileOutputStream(modelpath);
             ObjectOutputStream out = new ObjectOutputStream(fos)) {
            out.writeObject(this);
        }
    }

    /** Builds a dataset from the annotated corpus and trains the multiclass classifier on it. */
    @Override
    public void train(Annotation sentences) {
        GeneralDataset<String, String> trainSet = this.createDataset(sentences);
        this.trainMulticlass(trainSet);
    }

    /**
     * Trains {@link #classifier} on {@code trainSet} using the configured classifier type.
     *
     * @throws RuntimeException if {@link #relationExtractorClassifierType} is neither
     *     {@code "linear"} nor {@code "svm"}
     */
    public void trainMulticlass(GeneralDataset<String, String> trainSet) {
        if (this.relationExtractorClassifierType.equalsIgnoreCase("linear")) {
            LinearClassifierFactory<String, String> lcFactory =
                new LinearClassifierFactory<String, String>(1.0E-4, false, this.sigma);
            lcFactory.setVerbose(false);
            this.classifier = lcFactory.trainClassifier(trainSet);
        } else if (this.relationExtractorClassifierType.equalsIgnoreCase("svm")) {
            SVMLightClassifierFactory<String, String> svmFactory = new SVMLightClassifierFactory<String, String>();
            // For SVMs, sigma is reused as the C constant.
            svmFactory.setC(this.sigma);
            this.classifier = svmFactory.trainClassifier(trainSet);
        } else {
            throw new RuntimeException("Invalid classifier type: " + this.relationExtractorClassifierType);
        }
        if (logger.isLoggable(Level.FINE)) {
            BasicRelationExtractor.reportWeights(this.classifier, null);
        }
    }

    /**
     * Logs, at FINE, every feature weight of {@code classifier}, grouped by label
     * (labels sorted alphabetically).
     */
    protected static void reportWeights(LinearClassifier<String, String> classifier, String classLabel) {
        if (classLabel != null) {
            logger.fine("CLASSIFIER WEIGHTS FOR LABEL " + classLabel);
        }
        Map<String, Counter<String>> labelsToFeatureWeights = classifier.weightsAsMapOfCounters();
        List<String> labels = new ArrayList<String>(labelsToFeatureWeights.keySet());
        Collections.sort(labels);
        for (String label : labels) {
            Counter<String> featWeights = labelsToFeatureWeights.get(label);
            List<Pair<String, Double>> sorted = Counters.toSortedListWithCounts(featWeights);
            StringBuilder bos = new StringBuilder();
            bos.append("WEIGHTS FOR LABEL ").append(label).append(':');
            for (Pair<String, Double> feat : sorted) {
                bos.append(' ').append(feat.first()).append(':').append(feat.second() + "\n");
            }
            logger.fine(bos.toString());
        }
    }

    /**
     * Picks a label for {@code datum}.
     *
     * <p>Labels are visited in descending order of probability magnitude. The result is
     * {@code "_NR"} as soon as {@code "_NR"} is reached, or as soon as a candidate's
     * probability does not exceed that of {@code "_NR"}. Otherwise it is the first label
     * accepted by the validator. If no label qualifies, {@code "_NR"} is returned.
     *
     * @param rel the relation being classified, or {@code null} to skip label validation
     * @return the chosen label; never {@code null}
     */
    protected String classOf(Datum<String, String> datum, ExtractionObject rel) {
        Counter<String> probs = this.classifier.probabilityOf(datum);
        List<Pair<String, Double>> sortedProbs = Counters.toDescendingMagnitudeSortedListWithCounts(probs);
        double nrProb = probs.getCount(NO_RELATION);
        for (Pair<String, Double> choice : sortedProbs) {
            String candidate = choice.first();
            if (candidate.equals(NO_RELATION)) {
                return candidate;
            }
            // Ties with _NR are resolved in favour of _NR.
            if (nrProb >= choice.second()) {
                return NO_RELATION;
            }
            if (this.compatibleLabel(candidate, rel)) {
                return candidate;
            }
        }
        return NO_RELATION;
    }

    /** A label is compatible unless a validator is set, {@code rel} is non-null and the validator rejects it. */
    private boolean compatibleLabel(String label, ExtractionObject rel) {
        return rel == null || this.validator == null || this.validator.validLabel(label, rel);
    }

    /** Returns the classifier's label distribution for {@code testDatum}. */
    protected Counter<String> probabilityOf(Datum<String, String> testDatum) {
        return this.classifier.probabilityOf(testDatum);
    }

    /** Writes the classifier's justification for {@code testDatum} to {@code pw}; {@code label} is unused here. */
    protected void justificationOf(Datum<String, String> testDatum, PrintWriter pw, String label) {
        this.classifier.justificationOf(testDatum, pw);
    }

    /** Renders {@link #justificationOf} into a string. */
    private String justificationText(Datum<String, String> testDatum, String label) {
        StringWriter sw = new StringWriter();
        PrintWriter pw = new PrintWriter(sw);
        this.justificationOf(testDatum, pw, label);
        pw.flush();
        return sw.toString();
    }

    /**
     * Classifies every candidate relation in {@code sentence}.
     *
     * <p>Candidates are either all unrelated argument pairs (when
     * {@link #createUnrelatedRelations} is set) or the sentence's existing
     * {@code RelationMentionsAnnotation} (an absent annotation is treated as empty).
     *
     * @return a new list containing one predicted relation mention per candidate
     */
    protected List<RelationMention> extractAllRelations(CoreMap sentence) {
        List<RelationMention> cands;
        if (this.createUnrelatedRelations) {
            cands = AnnotationUtils.getAllUnrelatedRelations(this.relationMentionFactory, sentence, false);
        } else {
            cands = sentence.get(MachineReadingAnnotations.RelationMentionsAnnotation.class);
            if (cands == null) {
                cands = Collections.emptyList();
            }
        }
        List<RelationMention> extractions = new ArrayList<RelationMention>(cands.size());
        // Computed once: the diagnostic strings below are costly, so build them only if they are logged.
        boolean verbose = logger.isLoggable(Level.INFO);
        for (RelationMention rel : cands) {
            Datum<String, String> testDatum = this.createDatum(rel);
            String label = this.classOf(testDatum, rel);
            Counter<String> probs = this.probabilityOf(testDatum);
            String sentenceText = null;
            if (verbose) {
                sentenceText = AnnotationUtils.tokensAndNELabelsToString(rel.getArg(0).getSentence());
                double prob = probs.getCount(label);
                logger.info("Current sentence: " + sentenceText + "\n" + "Classifying relation: " + rel + "\n" + "JUSTIFICATION for label GOLD:" + rel.getType() + " SYS:" + label + " (prob:" + prob + "):\n" + this.justificationText(testDatum, label));
                logger.info("Justification done.");
            }
            RelationMention relation = this.relationMentionFactory.constructRelationMention(rel.getObjectId(), sentence, rel.getExtent(), label, null, rel.getArgs(), probs);
            extractions.add(relation);
            if (verbose) {
                if (relation.getType().equals(rel.getType())) {
                    logger.info("Classification: found similar type " + relation.getType() + " for relation: " + rel);
                } else {
                    logger.info("Classification: found different type " + relation.getType() + " for relation: " + rel);
                }
                logger.info("The predicted relation is: " + relation);
                logger.info("Current sentence: " + sentenceText);
            }
        }
        return extractions;
    }

    /**
     * Predicts a label for each datum, in order. Datums without a gold label are
     * accepted; they are reported as a "different type" in the log.
     *
     * @return a new list of predicted labels, parallel to {@code testDatums}
     */
    public List<String> annotateMulticlass(List<Datum<String, String>> testDatums) {
        List<String> predictedLabels = new ArrayList<String>(testDatums.size());
        for (Datum<String, String> testDatum : testDatums) {
            String label = this.classOf(testDatum, null);
            String gold = testDatum.label();
            if (logger.isLoggable(Level.FINE)) {
                double prob = this.probabilityOf(testDatum).getCount(label);
                logger.fine("JUSTIFICATION for label GOLD:" + gold + " SYS:" + label + " (prob:" + prob + "):\n" + this.justificationText(testDatum, label) + "\nJustification done.");
            }
            predictedLabels.add(label);
            // classOf never returns null, so this is safe even when the gold label is absent.
            if (label.equals(gold)) {
                logger.info("Classification: found similar type " + label + " for relation: " + testDatum);
            } else {
                logger.info("Classification: found different type " + label + " for relation: " + testDatum);
            }
        }
        return predictedLabels;
    }

    /** Replaces the sentence's {@code RelationMentionsAnnotation} with the predicted relations. */
    public void annotateSentence(CoreMap sentence) {
        // Defensive copy: a subclass's extractAllRelations may return a shared or immutable list.
        List<RelationMention> relations = new ArrayList<RelationMention>(this.extractAllRelations(sentence));
        if (logger.isLoggable(Level.FINE)) {
            for (RelationMention r : relations) {
                if (!r.getType().equals(NO_RELATION)) {
                    logger.fine("Found positive relation in annotateSentence: " + r);
                }
            }
        }
        sentence.set(MachineReadingAnnotations.RelationMentionsAnnotation.class, relations);
    }

    /** Annotates every sentence of {@code dataset}; a document without sentences is left unchanged. */
    @Override
    public void annotate(Annotation dataset) {
        for (CoreMap sentence : sentencesOf(dataset)) {
            this.annotateSentence(sentence);
        }
    }

    /**
     * Builds an RVF training dataset from all relations of all sentences in {@code corpus},
     * then drops features seen fewer than {@link #featureCountThreshold} times.
     */
    protected GeneralDataset<String, String> createDataset(Annotation corpus) {
        RVFDataset<String, String> dataset = new RVFDataset<String, String>();
        for (CoreMap sentence : sentencesOf(corpus)) {
            for (RelationMention rel : AnnotationUtils.getAllRelations(this.relationMentionFactory, sentence, this.createUnrelatedRelations)) {
                dataset.add(this.createDatum(rel));
            }
        }
        dataset.applyFeatureCountThreshold(this.featureCountThreshold);
        return dataset;
    }

    /** Returns the document's sentences, or an empty list if it has no SentencesAnnotation. */
    private static List<CoreMap> sentencesOf(Annotation document) {
        List<CoreMap> sentences = document.get(CoreAnnotations.SentencesAnnotation.class);
        return sentences == null ? Collections.<CoreMap>emptyList() : sentences;
    }

    /** Converts {@code rel} into a datum using {@link #featureFactory}. */
    protected Datum<String, String> createDatum(RelationMention rel) {
        assert (this.featureFactory != null);
        return this.featureFactory.createDatum(rel);
    }

    /** Converts {@code rel} into a datum carrying {@code label}, using {@link #featureFactory}. */
    protected Datum<String, String> createDatum(RelationMention rel, String label) {
        assert (this.featureFactory != null);
        return this.featureFactory.createDatum(rel, label);
    }

    /** Sets the level of the shared (static) logger; this affects all instances. */
    @Override
    public void setLoggerLevel(Level level) {
        logger.setLevel(level);
    }
}

