 package szte.csd.indicatorsel;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import szte.datamining.DataHandler;
import szte.datamining.DataMiningException;
import szte.datamining.Model;
import szte.datamining.mallet.MalletClassifier;
import szte.datamining.mallet.MalletDataHandler;
import szte.nlputils.MapUtils;
import cc.mallet.classify.MaxEnt;
import cc.mallet.types.AugmentableFeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.Label;

/**
 * MaxEntIterativeIndicatorSel selects, in each iteration, the feature with the highest MaxEnt
 * weight for the positive class and thereby tries to iteratively cover the positive class:
 * instances containing an already selected feature are removed before the next iteration.
 * For details see (Farkas and Szarvas, 2008).
 */
public class MaxEntIterativeIndicatorSel implements IndicatorSelector {

  /** Selection stops once fewer uncovered instances than this remain. */
  private static final int MIN_INSTANCES = 7;

  /** Selection stops once fewer uncovered positive instances than this remain. */
  private static final int MIN_POSITIVES = 3;

  /** Minimum positive-class weight the best feature must reach to be selected (0.0 unless set). */
  private double threshold;

  /**
   * Iteratively selects indicator features for the positive class.
   *
   * <p>Each iteration builds a dataset of the instances not yet covered by a selected feature,
   * trains a MaxEnt classifier on it and adds the feature with the highest positive-class weight.
   * The loop stops when too few (positive) instances remain, no feature reaches the threshold,
   * or no new feature can be selected.
   *
   * @param vsm the data to select from; must be a {@link MalletDataHandler}
   * @return the names of the selected features (empty if none were selected)
   * @throws DataMiningException propagated from dataset creation or classifier training
   */
  public Set<String> getIndicators(DataHandler vsm) throws DataMiningException {
    MalletDataHandler d = (MalletDataHandler) vsm;
    Set<Integer> dnf = new HashSet<Integer>();

    while (true) {
      MalletDataHandler dh = new MalletDataHandler();
      Map<String, Object> param = new HashMap<String, Object>();
      param.put("useFeatureSet", vsm);
      dh.createNewDataset(param);

      // Keep only the instances not yet covered by any selected feature.
      int tru = 0;
      for (int i = 0; i < d.data.size(); ++i) {
        Instance inst = d.data.get(i);
        if (isCovered((AugmentableFeatureVector) inst.getData(), dnf)) {
          continue;
        }
        dh.data.add(inst);
        if ((Boolean) ((Label) inst.getTarget()).getEntry()) {
          tru++;
        }
      }
      if (vsm.getFeatureCount() == 0) {
        break;
      }
      param = new HashMap<String, Object>();
      param.put("classifier", "MaxEnt");
      // NOTE(review): the classifier is configured on vsm while training below runs on dh;
      // confirm this is intended (e.g. shared classifier settings) rather than dh.initClassifier.
      vsm.initClassifier(param);
      if (dh.getInstanceCount() < MIN_INSTANCES || tru < MIN_POSITIVES) {
        break;
      }
      Model model = dh.trainClassifier();
      MaxEnt maxent = (MaxEnt) ((MalletClassifier) model).getClassifier();
      Map<Integer, Double> weights = positiveClassWeights(maxent);
      if (weights.isEmpty()) {
        break;
      }
      Integer bestFeature = MapUtils.sortMapByValue(weights, true).firstKey();
      if (weights.get(bestFeature) < threshold) {
        break;
      }
      // If the best feature was already selected, dnf (and hence the uncovered dataset) would not
      // change, so the loop could never make progress; stop instead of looping forever.
      if (!dnf.add(bestFeature)) {
        break;
      }
    }

    Set<String> indicators = new HashSet<String>();
    for (Integer i : dnf) {
      indicators.add((String) d.data.getAlphabet().lookupObject(i));
    }
    return indicators;
  }

  /**
   * Returns true if {@code fv} contains at least one of the feature indices in {@code dnf}.
   * Relies on the vector's indices being visited in ascending order (early break on larger index).
   */
  private static boolean isCovered(AugmentableFeatureVector fv, Set<Integer> dnf) {
    // numLocations() is the number of stored entries; getIndices().length may include spare capacity.
    int locations = fv.numLocations();
    for (int term : dnf) {
      for (int j = 0; j < locations; ++j) {
        int index = fv.indexAtLocation(j);
        if (index == term) {
          return true;
        }
        if (index > term) {
          break;
        }
      }
    }
    return false;
  }

  /**
   * Extracts the weight of every feature (up to the default feature index) for the label
   * {@code true} from the trained MaxEnt parameters.
   *
   * @throws IllegalStateException if {@code true} is not a label of the classifier
   */
  private static Map<Integer, Double> positiveClassWeights(MaxEnt maxent) {
    // Parameters are laid out as [labelIndex * (alphabetSize + 1) + featureIndex]; the extra slot
    // per label is presumably the default (bias) feature, which is excluded below.
    int stride = maxent.getAlphabet().size() + 1;
    // Non-adding lookup: do not mutate the label alphabet if the positive label is missing.
    int li = maxent.getLabelAlphabet().lookupIndex(Boolean.TRUE, false);
    if (li < 0) {
      throw new IllegalStateException("positive label (true) not found in the MaxEnt label alphabet");
    }
    double[] parameters = maxent.getParameters();
    Map<Integer, Double> weights = new HashMap<Integer, Double>();
    for (int i = 0; i < maxent.getDefaultFeatureIndex(); i++) {
      weights.put(i, parameters[li * stride + i]);
    }
    return weights;
  }

  /**
   * Sets the minimum positive-class weight a feature must have to be selected.
   *
   * @param t the new threshold
   */
  public void setThreshold(double t) {
    threshold = t;
  }
}
