package szte.datamining.mallet;

import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Level;

import szte.datamining.ClassificationResult;
import szte.datamining.DataHandler;
import szte.datamining.DataMiningException;
import szte.datamining.Model;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.classify.MaxEntOptimizableByLabelLikelihood;
import cc.mallet.classify.MaxEntTrainer;
import cc.mallet.types.Alphabet;
import cc.mallet.types.AugmentableFeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.MalletProgressMessageLogger;

/**
 * {@link DataHandler} implementation that stores its instances in a MALLET
 * {@link InstanceList}. Every feature value is kept as a double inside an
 * {@link AugmentableFeatureVector}; nominal values are encoded as positions in a
 * per-feature value list.
 *
 * Only the two alphabets are non-transient, so Java serialization of a handler
 * keeps the feature and label alphabets but not the instances themselves.
 */
public class MalletDataHandler extends DataHandler implements Serializable{
  private static final long serialVersionUID = -2360382190864255796L;

  /** The MALLET instances of this dataset (transient: not serialized). */
  transient public    InstanceList data;
  /** Maps an instance id to the position of that instance inside {@link #data} (transient). */
  transient public    Map<String, Integer> instanceIds;
  /** Alphabet of the class labels; shared with {@link #data}. */
  protected LabelAlphabet labelAlphabet;
  /** Alphabet of the feature names; shared with {@link #data}. */
  protected Alphabet featureAlphabet;
  /** For each nominal feature, the list of its known values; a value is stored as its list position (transient). */
  transient protected Map<String,List<String>> nominalValues;
  /** Trainer used by trainClassifier; stays null until initClassifier is called (transient). */
  transient protected ClassifierTrainer trainer = null;
  
  // On class load: fetch the logger named after MaxEntOptimizableByLabelLikelihood and set the level of its root logger to WARNING.
  static{((MalletLogger)MalletProgressMessageLogger.getLogger(MaxEntOptimizableByLabelLikelihood.class.getName())).getRootLogger().setLevel(Level.WARNING);}

  /**
   * Returns the instance registered under the given id, creating and registering it
   * on first access. A new instance starts with an empty feature vector over
   * featureAlphabet, the label looked up for Boolean false as target, and the id as
   * both its name and its source.
   *
   * @param instanceId identifier of the wanted instance
   * @return the instance stored in data for that id
   */
  protected Instance getInstance(String instanceId){
    if(instanceIds.containsKey(instanceId))
      return data.get(instanceIds.get(instanceId));

    AugmentableFeatureVector emptyVector = new AugmentableFeatureVector(featureAlphabet,0,false);
    Instance created = new Instance(emptyVector,labelAlphabet.lookupLabel(false),instanceId,instanceId);
    data.add(created);
    // unlocked right after insertion; presumably so that target and data can still be changed later
    created.unLock();
    int position = data.size()-1;
    instanceIds.put(instanceId,position);
    return data.get(position);
  }
  
  /**
   * Returns the feature vector of an instance, creating the instance if it does not exist yet.
   *
   * @param instanceId identifier of the instance
   * @return the data object of the instance, cast to AugmentableFeatureVector
   */
  protected AugmentableFeatureVector getInstanceData(String instanceId) {
    Object payload = getInstance(instanceId).getData();
    return (AugmentableFeatureVector)payload;
  }

  /**
   * Reads the raw double stored for a feature of an instance.
   *
   * @param instanceId  identifier of the instance (created if missing)
   * @param featureName name of the feature to read
   * @return the value of the feature in the instance's feature vector
   * @throws DataMiningException if the feature alphabet does not contain featureName
   */
  protected double getDoubleValue(String instanceId, String featureName) throws DataMiningException {
    if(featureAlphabet.contains(featureName))
    {
      AugmentableFeatureVector vector = getInstanceData(instanceId);
      int featureIndex = featureAlphabet.lookupIndex(featureName);
      return vector.value(featureIndex);
    }
    throw new DataMiningException("getter for unexisting feature: "+featureName); 
  }

  /**
   * Stores a raw double for a feature of an instance, overwriting an existing entry
   * or appending a new one to the instance's feature vector.
   *
   * @param instanceId  identifier of the instance (created if missing)
   * @param featureName name of the feature to write
   * @param value       value to store
   */
  protected void setDoubleValue(String instanceId, String featureName, double value) {
    AugmentableFeatureVector vector = getInstanceData(instanceId);
    int featureIndex = featureAlphabet.lookupIndex(featureName);
    // A negative index can occur when the growth of featureAlphabet is stopped and the
    // feature set does not contain the feature; such values are silently dropped.
    if(featureIndex >= 0)
    {
      int slot = vector.location(featureIndex);
      if(slot >= 0)
        vector.setValueAtLocation(slot, value);
      else
        vector.add(featureIndex, value);
    }
  }

  /**
   * Discards the current content and starts an empty dataset.
   * If parameters contains the key "useFeatureSet", its value must be a
   * MalletDataHandler: its feature and label alphabets are adopted (the very same
   * objects, not copies) and their growth is stopped. Otherwise fresh, growing
   * alphabets are created.
   *
   * @param parameters optional settings, may be null
   */
  public void createNewDataset(Map<String, Object> parameters) {
    boolean reuseFeatureSet = parameters!=null && parameters.containsKey("useFeatureSet");
    if(reuseFeatureSet)
    {
      MalletDataHandler template = (MalletDataHandler)parameters.get("useFeatureSet");
      featureAlphabet = template.featureAlphabet;
      labelAlphabet = template.labelAlphabet;
      // the alphabets are shared with the template, so stopping growth affects it as well
      featureAlphabet.stopGrowth();
      labelAlphabet.stopGrowth();
    }
    else
    {
      featureAlphabet = new Alphabet();
      featureAlphabet.startGrowth();
      labelAlphabet = new LabelAlphabet();
      labelAlphabet.startGrowth();
    }
    instanceIds = new HashMap<String, Integer>();
    nominalValues = new HashMap<String,List<String>>();
    data = new InstanceList(featureAlphabet, labelAlphabet);
  }

  /**
   * Builds a new handler that holds only the selected instances restricted to the
   * selected features. The subset shares this handler's label alphabet but gets a
   * fresh, growing feature alphabet. Labels are copied; nominal value lists are not.
   *
   * NOTE(review): the subset is initialised through createNewDataset with
   * "useFeatureSet", which stops the growth of the adopted alphabets. These are this
   * handler's own objects, so after this call this handler's feature and label
   * alphabets are growth-stopped too. Confirm whether that side effect is intended.
   *
   * @param instancesSelected ids of the instances to copy; an unknown id is created
   *                          in this handler as a side effect of the lookup
   * @param featuresSelected  names of the features to keep
   * @return the newly built subset handler
   * @throws DataMiningException declared by the overridden contract
   */
  public DataHandler createSubset(Set<String> instancesSelected,
      Set<String> featuresSelected) throws DataMiningException{
    Map<String,Object> setup = new HashMap<String, Object>();
    setup.put("useFeatureSet", this);
    MalletDataHandler subset = new MalletDataHandler();
    subset.createNewDataset(setup);
    // replace the adopted feature alphabet by an own one, keep the shared label alphabet
    subset.featureAlphabet = new Alphabet();
    subset.featureAlphabet.startGrowth();
    subset.data = new InstanceList(subset.featureAlphabet, subset.labelAlphabet);
    for(String instanceId : instancesSelected)
    {
      AugmentableFeatureVector source = this.getInstanceData(instanceId);
      for(int loc=0; loc<source.numLocations(); ++loc)
      {
        int featureIndex = source.getIndices()[loc];
        String featureName = (String)source.getAlphabet().lookupObject(featureIndex);
        if(!featuresSelected.contains(featureName))
          continue;
        subset.setDoubleValue(instanceId, featureName, source.getValues()[loc]);
      }
      subset.setLabel(instanceId, this.getLabel(instanceId));
    }
    return subset;
  }
  
  /**
   * Copies every instance of another MalletDataHandler into this one: all stored
   * feature values and the label. Values of instances with an id already present
   * here are overwritten feature by feature. Nominal value lists are not merged.
   *
   * @param dh the handler whose content is added; must be a MalletDataHandler
   * @throws DataMiningException if dh is not a MalletDataHandler
   */
  public void addDataHandler(DataHandler dh) throws DataMiningException{
    if(dh instanceof MalletDataHandler)
    {
      MalletDataHandler other = (MalletDataHandler)dh;
      for(String instanceId : dh.getInstanceIds())
      {
        AugmentableFeatureVector source = other.getInstanceData(instanceId);
        for(int loc=0; loc<source.numLocations(); ++loc)
        {
          String featureName = (String)source.getAlphabet().lookupObject(source.getIndices()[loc]);
          this.setDoubleValue(instanceId, featureName, source.getValues()[loc]);
        }
        this.setLabel(instanceId, dh.getLabel(instanceId));
      }
      return;
    }
    throw new DataMiningException("MalletDataHandler can add just MalletDataHandlers");
  }

  /**
   * Reads a feature as a binary value.
   *
   * @param instanceId  identifier of the instance
   * @param featureName name of the feature
   * @return true if the stored double is greater than 0.0, false otherwise
   * @throws DataMiningException if the feature is unknown
   */
  public Boolean getBinaryValue(String instanceId, String featureName) throws DataMiningException{
    double raw = getDoubleValue(instanceId,featureName);
    return Boolean.valueOf(raw > 0.0);
  }

  /**
   * @return the number of entries in the alphabet of the instance list
   */
  public int getFeatureCount() {
    Alphabet features = data.getAlphabet();
    return features.size();
  }

  /**
   * Collects the names of all features known to the data alphabet.
   *
   * @return a new, independent set holding every entry of the data alphabet as a String
   */
  public Set<String> getFeatureNames() {
    Object[] entries = data.getDataAlphabet().toArray();
    Set<String> names = new HashSet<String>();
    for(int i=0; i<entries.length; ++i)
      names.add((String)entries[i]);
    return names;
  }

  /**
   * Returns the value list of a nominal feature.
   *
   * @param featureName name of the feature
   * @return the internal list of values registered for the feature, or null if the
   *         feature has no nominal values registered
   */
  public List<String> getFeatureValues(String featureName) {
    // Map.get already yields null for an unknown key
    return nominalValues.get(featureName);
  }

  /**
   * @return the number of instances currently held in the instance list
   */
  public int getInstanceCount() {
    int count = data.size();
    return count;
  }

  /**
   * @return the key set of the id-to-position map; this is a live view of the
   *         internal map, not a copy
   */
  public Set<String> getInstanceIds() {
    Set<String> ids = instanceIds.keySet();
    return ids;
  }

  /**
   * Returns the label entry of an instance (the instance is created if missing).
   * The cast to T is unchecked; the caller is responsible for asking for the type
   * that was actually stored through setLabel.
   *
   * @param instanceId identifier of the instance
   * @return the entry of the instance's target label
   */
  @SuppressWarnings("unchecked")
  public <T extends Comparable<?>> T  getLabel(String instanceId){
    Label target = (Label)getInstance(instanceId).getTarget();
    return (T)target.getEntry();
  }

  /**
   * Reads a nominal feature: the stored double is truncated to an int and used as a
   * position in the feature's value list.
   *
   * @param instanceId  identifier of the instance
   * @param featureName name of the nominal feature
   * @return the nominal value at the stored position
   * @throws DataMiningException if no value list is registered for the feature, or
   *                             if the feature is unknown to the feature alphabet
   */
  public String getNominalValue(String instanceId, String featureName) throws DataMiningException{
    if(nominalValues.containsKey(featureName))
    {
      int position = (int)getDoubleValue(instanceId,featureName);
      return nominalValues.get(featureName).get(position);
    }
    throw new DataMiningException(featureName+" is not a nominal feature");
  }

  /**
   * Reads a feature as a number.
   *
   * @param instanceId  identifier of the instance
   * @param featureName name of the feature
   * @return the stored double, boxed
   * @throws DataMiningException if the feature is unknown
   */
  public Double getNumericValue(String instanceId, String featureName) throws DataMiningException{
    double raw = getDoubleValue(instanceId,featureName);
    return Double.valueOf(raw);
  }

  /**
   * Generic getter; always delivers the stored double as a Double, whatever the
   * feature type is. The cast to T is unchecked.
   *
   * @param instanceId  identifier of the instance
   * @param featureName name of the feature
   * @return the stored double, boxed and cast to T
   * @throws DataMiningException if the feature is unknown
   */
  @SuppressWarnings("unchecked")
  public <T extends Comparable<?>> T getValue(String instanceId,String featureName) throws DataMiningException{
    Double boxed = getDoubleValue(instanceId,featureName);
    return (T)boxed;
  }

  /**
   * Instantiates the trainer used by trainClassifier. The trainer class is
   * "cc.mallet.classify." + name + "Trainer", loaded reflectively and created with
   * its no-argument constructor. The name comes from the "classifier" parameter and
   * defaults to "MaxEnt".
   *
   * @param parameters optional settings, may be null
   * @throws DataMiningException if the trainer class cannot be loaded or instantiated
   */
  public void initClassifier(Map<String, Object> parameters) throws DataMiningException{
    boolean nameGiven = parameters!=null && parameters.containsKey("classifier");
    // alternative that was tried here: "C45"
    String classifierName = nameGiven ? (String)parameters.get("classifier") : "MaxEnt";
    String trainerClassName = "cc.mallet.classify."+classifierName+"Trainer";
    try{
      Class<?> trainerClass = Class.forName(trainerClassName);
      trainer = (ClassifierTrainer)trainerClass.newInstance();
      // Trainer specific tuning is left disabled:
      //((MaxEntTrainer)trainer).setGaussianPriorVariance(1.0);
      /*((C45Trainer)trainer).setMinNumInsts(3);
      ((C45Trainer)trainer).setDepthLimited(true);
      ((C45Trainer)trainer).setMaxDepth(2);
      ((C45Trainer)trainer).setDoPruning(true);*/
    }catch(Exception e){
      throw new DataMiningException("unknown classifier: "+classifierName,e);
    }
  }
  
  /**
   * Trains a classifier on the current instance list.
   *
   * @return the trained classifier wrapped in a MalletClassifier
   * @throws DataMiningException if the default trainer cannot be created
   */
  public Model trainClassifier() throws DataMiningException{
    if(trainer==null)
    {
      // no trainer configured yet: fall back to the defaults of initClassifier
      initClassifier(null);
    }
    return new MalletClassifier(trainer.train(data));
  }

  /**
   * Classifies every instance of this dataset with the given model.
   *
   * @param model the model to apply; must be a MalletClassifier
   * @return the classifications wrapped together with this handler
   * @throws DataMiningException if model is not a MalletClassifier
   */
  public ClassificationResult classifyDataset(Model model) throws DataMiningException {
    if(model instanceof MalletClassifier)
    {
      MalletClassifier malletModel = (MalletClassifier)model;
      return new MalletClassificationResult(malletModel.getClassifier().classify(data), this);
    }
    throw new DataMiningException("MalletDataHandler can be used only by MALLET classifiers");
  }

  /**
   * Not supported by this handler.
   *
   * @param featureName ignored
   * @throws DataMiningException always
   */
  public void removeFeature(String featureName) throws DataMiningException{
    final String message = "removeFeature is not implemented yet in MalletDataHandler";
    throw new DataMiningException(message);
  }

  /**
   * Not supported by this handler.
   *
   * @param instanceId ignored
   * @throws DataMiningException always
   */
  public void removeInstance(String instanceId) throws DataMiningException{
    // the message used to name removeFeature, which pointed callers at the wrong operation
    throw new DataMiningException("removeInstance is not implemented yet in MalletDataHandler");
  }

  /**
   * Replaces the current dataset with the content of a text file.
   * Each non-blank line must look like <code>id@label name:value name:value ...</code>
   * with whitespace separated tokens. The label is parsed with Boolean.parseBoolean
   * and each value with Double.parseDouble. Blank lines are skipped.
   *
   * The current dataset is reset only after the file has been opened, so a missing
   * file leaves the existing data untouched.
   *
   * @param source path of the file to read
   * @throws DataMiningException if the file cannot be read or a line is malformed
   */
  public void loadDataset(String source) throws DataMiningException {
    BufferedReader file = null;
    try{
      file = new BufferedReader(new FileReader(source));
      createNewDataset(null);
      String line;
      while ((line = file.readLine()) != null)
      {
        line = line.trim();
        if(line.length() == 0)
          continue; // tolerate empty lines, e.g. a trailing newline at the end of the file
        String[] tokens = line.split("\\s+");
        if(!tokens[0].contains("@"))
          throw new DataMiningException("Corrput input format. First token should contain @");
        String[] head = tokens[0].split("@");
        if(head.length < 2)
          throw new DataMiningException("Malformed first token, expected id@label: "+tokens[0]);
        String id = head[0];
        setLabel(id,Boolean.parseBoolean(head[1]));
        for(int i=1;i<tokens.length;++i)
        {
          String[] pair = tokens[i].split(":");
          if(pair.length < 2)
            throw new DataMiningException("Malformed feature token, expected name:value: "+tokens[i]);
          try{
            setDoubleValue(id, pair[0], Double.parseDouble(pair[1]));
          }catch(NumberFormatException e){
            throw new DataMiningException("Malformed feature value in token: "+tokens[i], e);
          }
        }
      }
    } catch (IOException e) {
      // report the failure to the caller instead of continuing with a missing or half-read dataset
      throw new DataMiningException("cannot read dataset from "+source, e);
    } finally {
      if(file != null)
      {
        try{
          file.close();
        }catch(IOException ignored){
          // nothing useful can be done if closing a reader fails
        }
      }
    }
  }

  /**
   * Writes the dataset to a file. The target has the form <code>path</code> or
   * <code>path|format</code>, where format is one of "mallet" (also the default when
   * no format is given), "svm" or "weka". Any other format, including an empty one as
   * in <code>path|</code>, is reported on System.err and nothing is written.
   *
   * @param target output path, optionally followed by '|' and the format name
   */
  public void saveDataset(String target) {
    // limit -1 keeps trailing empty parts, so "path|" gives an empty format
    // instead of a one-element array and an ArrayIndexOutOfBoundsException
    String[] parts = target.split("\\|", -1);
    String path = parts[0];
    if(parts.length == 1 || parts[1].equals("mallet"))
      saveDatasetMallet(path);
    else if(parts[1].equals("svm"))
      saveDatasetSVM(path);
    else if(parts[1].equals("weka"))
      saveDatasetWeka(path);
    else
      System.err.println("unknow output format "+parts[1]);
  }
  
  /**
   * Writes the dataset in the line format read by loadDataset: one line per
   * instance, <code>id@label</code> followed by tab separated
   * <code>name:value</code> pairs for every stored feature value.
   * A file that cannot be opened is reported with a stack trace only.
   *
   * @param target path of the file to write
   */
  public void saveDatasetMallet(String target) {
    PrintWriter out = null;
    try{
      out = new PrintWriter(target);
      for(String id : instanceIds.keySet())
      {
        out.print(id+"@"+getLabel(id));
        AugmentableFeatureVector fv = getInstanceData(id);
        int[] indices = fv.getIndices();
        double[] values = fv.getValues();
        int count = fv.numLocations();
        for(int i=0; i<count; ++i)
          out.print("\t"+featureAlphabet.lookupObject(indices[i])+":"+values[i]);
        out.println();
      }
    } catch (FileNotFoundException e) {
      e.printStackTrace();
    } finally {
      // close in finally so the file handle is released even if writing an instance throws
      if(out != null)
        out.close();
    }
  }

  /**
   * Writes the dataset with one line per instance: "1" or "-1" for a true or false
   * label, followed by space separated <code>index:value</code> pairs, where index
   * is the feature's alphabet index plus one. Labels must be Booleans.
   * A file that cannot be opened is reported with a stack trace only.
   *
   * @param target path of the file to write
   */
  public void saveDatasetSVM(String target) {
    PrintWriter out = null;
    try{
      out = new PrintWriter(target);
      for(String id : instanceIds.keySet())
      {
        // explicit Boolean target type instead of relying on inference inside a condition
        Boolean positive = getLabel(id);
        out.print(positive ? "1" : "-1");
        AugmentableFeatureVector fv = getInstanceData(id);
        int[] indices = fv.getIndices();
        double[] values = fv.getValues();
        int count = fv.numLocations();
        for(int i=0; i<count; ++i)
          out.print(" "+(indices[i]+1)+":"+values[i]);
        out.println();
      }
    } catch (FileNotFoundException e) {
      e.printStackTrace();
    } finally {
      // close in finally so the file handle is released even if writing an instance throws
      if(out != null)
        out.close();
    }
  }

  /**
   * Writes the dataset as a sparse ARFF file. Every feature becomes a numeric
   * attribute (apostrophes are removed from its name), followed by the attribute
   * "classlabel" with the values {0,1}. Each instance is written as
   * <code>{index value,...}</code>; the class entry is written, with the index equal
   * to the number of features, only when the label is true. Labels must be Booleans.
   * ".arff" is appended to the target if it is missing.
   * A file that cannot be opened is reported with a stack trace only.
   *
   * @param target path of the file to write
   */
  public void saveDatasetWeka(String target) {
    String path = target.endsWith(".arff") ? target : target + ".arff";
    PrintWriter out = null;
    try{
      out = new PrintWriter(path);
      out.println("@relation MalletData");
      for(Object f : data.getAlphabet().toArray())
      {
        String name = f.toString().replaceAll("'", "");
        out.println("@attribute '"+name+"' numeric");
      }
      out.println("@attribute classlabel {0,1}");
      out.println("@data");
      for(String id : instanceIds.keySet())
      {
        out.print("{");
        AugmentableFeatureVector fv = getInstanceData(id);
        int[] indices = fv.getIndices();
        double[] values = fv.getValues();
        int count = fv.numLocations();
        for(int i=0; i<count; ++i)
        {
          if(i>0)
            out.print(",");
          out.print(indices[i]+" "+values[i]);
        }
        Boolean positive = getLabel(id);
        if(positive)
        {
          // separator only when something precedes the class entry;
          // an instance without features used to be written as "{,N 1}"
          if(count>0)
            out.print(",");
          out.print(data.getAlphabet().size() + " 1");
        }
        out.println("}");
      }
    } catch (FileNotFoundException e) {
      e.printStackTrace();
    } finally {
      // close in finally so the file handle is released even if writing an instance throws
      if(out != null)
        out.close();
    }
  }
  
  /**
   * Stores a binary feature as 1.0 (true) or 0.0 (false).
   *
   * @param instanceId  identifier of the instance
   * @param featureName name of the feature
   * @param value       the value to store; must not be null
   */
  public void setBinaryValue(String instanceId, String featureName, Boolean value) {
    double encoded = value ? 1.0 : 0.0;
    setDoubleValue(instanceId,featureName,encoded);
  }

  /**
   * Ternal binary values are not supported: nothing is stored, only a message is
   * printed to System.err.
   *
   * @param instanceId  ignored
   * @param featureName ignored
   * @param value       ignored
   * @param ternal      ignored
   */
  public void setBinaryValue(String instanceId, String featureName, Boolean value, boolean ternal) {
    final String notice = "ternal isn't implemented in MalletDataHandler";
    System.err.println(notice);
  }

  /**
   * Moves (or inserts) a value to the front of a nominal feature's value list, i.e.
   * to position 0. Values already stored for instances are not re-encoded.
   *
   * @param featureName name of the nominal feature
   * @param value       the value that should take position 0
   * @throws DataMiningException if no value list is registered for the feature
   */
  public void setDefaultFeatureValue(String featureName, String value) throws DataMiningException{
    if(!nominalValues.containsKey(featureName))
      throw new DataMiningException("setDefaultFeatureValue is called for a feature which is not nominal");
    List<String> values = nominalValues.get(featureName);
    boolean alreadyKnown = values.contains(value);
    if(alreadyKnown)
      values.remove(value);
    values.add(0, value);
  }

  /**
   * Sets the target label of an instance, creating the instance if needed.
   *
   * @param instanceId identifier of the instance
   * @param label      the label entry; looked up in the label alphabet
   */
  public <T extends Comparable<?>> void  setLabel(String instanceId, T label) {
    // fetch the instance first: creating it looks up the default label, and the order
    // of lookups decides the order in which entries reach the label alphabet
    Instance instance = getInstance(instanceId);
    instance.setTarget(labelAlphabet.lookupLabel(label));
  }

  /**
   * Stores a nominal feature value as the position of the value in the feature's
   * value list. The list is created on first use with "MISSINGVALUE" at position 0,
   * and unseen values are appended to it.
   *
   * @param instanceId  identifier of the instance
   * @param featureName name of the nominal feature
   * @param value       the nominal value to store
   */
  public void setNominalValue(String instanceId, String featureName,String value) {
    if(!nominalValues.containsKey(featureName))
    {
      List<String> fresh = new LinkedList<String>();
      fresh.add("MISSINGVALUE");
      nominalValues.put(featureName, fresh);
    }
    List<String> values = nominalValues.get(featureName);
    int position = values.indexOf(value);
    if(position < 0)
    {
      values.add(value);
      position = values.size()-1;
    }
    setDoubleValue(instanceId,featureName,(double)position);
  }

  /**
   * Stores a numeric feature value unchanged.
   *
   * @param instanceId  identifier of the instance
   * @param featureName name of the feature
   * @param value       the value to store
   */
  public void setNumericValue(String instanceId, String featureName,double value) {
    final double raw = value;
    setDoubleValue(instanceId,featureName,raw);
  }
  
  /**
   * Creates an independent handler with fresh alphabets and copies this handler's
   * content into it through addDataHandler, i.e. feature values and labels; nominal
   * value lists are not copied. A DataMiningException raised while copying is only
   * printed, and the (possibly incomplete) copy is still returned.
   *
   * @return the new handler
   */
  protected MalletDataHandler clone(){
    MalletDataHandler copy = new MalletDataHandler();
    copy.createNewDataset(null);
    try{
      copy.addDataHandler(this);
    }
    catch(DataMiningException copyFailure){
      copyFailure.printStackTrace();
    }
    return copy;
  }


  /**
   * Generic setter that picks the typed setter from the prefix of the feature name:
   * "b_" binary (Boolean), "t_" ternal binary (Boolean), "m_" nominal (String) and
   * "n_" numeric (Double).
   *
   * NOTE(review): the numeric branch used to test "m_" a second time and therefore
   * could never run. "n_" is assumed to be the intended prefix; confirm it against
   * the naming convention of the other DataHandler implementations.
   *
   * @param instanceId  identifier of the instance
   * @param featureName prefixed name of the feature
   * @param value       the value; its runtime type must match the prefix
   * @throws DataMiningException if the prefix is none of the known ones
   */
  public <T extends Comparable<?>> void setValue(String instanceId,
      String featureName, T value) throws DataMiningException {
    if(featureName.startsWith("b_"))
      setBinaryValue(instanceId, featureName, (Boolean)value);
    else  if(featureName.startsWith("t_"))
      setBinaryValue(instanceId, featureName, (Boolean)value, true);
    else if(featureName.startsWith("m_"))
      setNominalValue(instanceId, featureName, (String)value);
    else if(featureName.startsWith("n_"))
      setNumericValue(instanceId, featureName, (Double)value);
    else
      throw new DataMiningException("unknown featuretype "+featureName);
  }

  /**
   * Manual smoke test: loads "data.in", writes it back to "data.out", adds a few
   * values and labels, prints some getters, trains and applies a classifier, and
   * finally shows that a clone grows independently of the original.
   *
   * @param args unused
   * @throws Exception whatever the exercised methods throw
   */
  public static void main(String[] args) throws Exception {
    MalletDataHandler handler = new MalletDataHandler();

    // round trip through the default (mallet) file format
    handler.loadDataset("data.in");
    handler.saveDataset("data.out");

    // one value of each kind, then labels for the three instances
    handler.setBinaryValue("A", "Aladr", true);
    handler.setNominalValue("A", "Bla", "pcs");
    handler.setNumericValue("B", "Cecil", 5.55);
    handler.setNominalValue("C", "Bla", "tcs");
    handler.setLabel("A", "P");
    handler.setLabel("B", "P");
    handler.setLabel("C", "P");

    System.out.println(handler.getBinaryValue("A", "Aladr"));
    System.out.println(handler.getNumericValue("A", "Aladr"));
    System.out.println(handler.getNominalValue("A", "Bla"));
    System.out.println(handler.getNominalValue("B", "Bla"));
    
    System.out.println("size:"+handler.data.size());
    Model trained = handler.trainClassifier();
    System.out.println(handler.classifyDataset(trained).getPredictionProbabilities("B"));

    // the clone must not share its instance list with the original
    MalletDataHandler copy = handler.clone();
    System.out.println("size:"+handler.data.size());
    System.out.println("size:"+copy.data.size());
    copy.setNumericValue("D", "Cecil", 5.55);
    System.out.println("size:"+handler.data.size());
    System.out.println("size:"+copy.data.size());
  }
}

