package szte.csd.indicatorsel;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.Map.Entry;

import szte.csd.ContentShiftDetector;
import szte.datamining.DataHandler;
import szte.datamining.DataMiningException;
import szte.datamining.Model;
import szte.datamining.mallet.MalletClassifier;
import szte.datamining.mallet.MalletDataHandler;
import szte.io.CMCDataHolder;
import szte.io.DocumentSet;
import szte.nlputils.MapUtils;
import cc.mallet.classify.MaxEnt;

/**
 * MaxEntSelector trains a MaxEnt classifier on the dataset and returns the
 * features that received the highest weights for the positive class during
 * training.
 */
public class MaxEntSelector implements IndicatorSelector {
  /**
   * A feature is selected only if the positive-class probability of an
   * instance containing just that feature is strictly greater than this value.
   */
  protected double threshold=0.5;

  /**
   * Trains a MaxEnt classifier on {@code vsm} and selects indicator features.
   *
   * <p>The weight of every feature for the {@code true} label is read from the
   * trained model and the features are ordered with
   * {@link MapUtils#sortMapByValue}. In that order, a dataset holding a single
   * instance with only the current feature switched on is classified; the
   * feature is kept while the probability of {@code true} exceeds
   * {@link #threshold}, and the scan stops at the first feature that does not.
   *
   * <p>Note that {@code initClassifier} and {@code trainClassifier} are invoked
   * on the passed-in {@code vsm}.
   *
   * @param vsm the dataset to train on
   * @return the selected feature names; empty if {@code vsm} has no features or
   *         the trained model has no {@code true} label
   * @throws DataMiningException propagated from the underlying data handler
   *         calls
   */
  public Set<String> getIndicators(DataHandler vsm)
      throws DataMiningException {
    Set<String> res = new HashSet<String>();
    if (vsm.getFeatureCount() == 0) {
      return res;
    }

    Map<String,Object> param = new HashMap<String,Object>();
    param.put("classifier", "MaxEnt");
    vsm.initClassifier(param);
    Model model = vsm.trainClassifier();
    MaxEnt maxent = (MaxEnt)((MalletClassifier)model).getClassifier();

    // Look the positive label up WITHOUT adding it. The one-argument
    // lookupIndex(...) inserts a missing entry, which would both mutate the
    // trained classifier's label alphabet and produce an index that points
    // past the parameter rows the model was trained with.
    int li = maxent.getLabelAlphabet().lookupIndex(Boolean.TRUE, false);
    if (li < 0) {
      // No positive label in the model: nothing can be an indicator.
      return res;
    }

    // Each label owns a row of (alphabet size + 1) parameters; the extra slot
    // sits at the default feature index, so it is excluded from the loop.
    int numFeatures = maxent.getAlphabet().size()+1;
    int defaultFeatureIndex = maxent.getDefaultFeatureIndex();
    double[] parameters = maxent.getParameters();
    Map<String,Double> features = new HashMap<String,Double>();
    for (int i = 0; i < defaultFeatureIndex; i++) {
      Object name = maxent.getAlphabet().lookupObject(i);
      features.put(name.toString(), parameters[li*numFeatures + i]);
    }

    // NOTE(review): the early break below presumes sortMapByValue yields the
    // highest weights first - confirm against MapUtils.
    SortedMap<String, Double> sorted = MapUtils.sortMapByValue(features);
    for (Entry<String, Double> e : sorted.entrySet()) {
      // Build a one-instance dataset ("X") sharing vsm's feature set, with
      // only the candidate feature switched on.
      DataHandler inst = new MalletDataHandler();
      Map<String,Object> par = new HashMap<String, Object>();
      par.put("useFeatureSet", vsm);
      inst.createNewDataset(par);
      inst.setBinaryValue("X", e.getKey(), true);
      double prob = inst.classifyDataset(model).getPredictionProbabilities("X").get(true);
      if (prob > threshold) {
        res.add(e.getKey());
      } else {
        break;
      }
    }
    return res;
  }

  /**
   * Sets the probability threshold used by {@link #getIndicators}.
   *
   * @param t the new threshold
   */
  public void setThreshold(double t) {
    threshold = t;
  }

  /**
   * Ad-hoc driver: reads {@code 2007ChallengeTrainData.xml}, builds a vector
   * space model from it, relabels it for the label {@code "786.2"} and prints
   * the indicators selected with the default threshold.
   *
   * @param a command-line arguments (unused)
   * @throws Exception on any failure while reading or processing the data
   */
  public static void main(String[] a) throws Exception{
    DocumentSet train = new CMCDataHolder();
    train.readDocumentSet("2007ChallengeTrainData.xml");
    ContentShiftDetector ltc = new ContentShiftDetector();
    DataHandler traindata = ltc.buildVSM(train,null);
    String label = "786.2";
    ltc.relabel(traindata, train, label);
    System.out.println(new MaxEntSelector().getIndicators(traindata));
  }
}
