package szte.csd.indicatorsel;

import java.util.*;

import szte.csd.ContentShiftDetector;
import szte.datamining.DataHandler;

/**
 * Tunes the threshold parameter T of an indicator selection method.
 *
 * <p>Each indicator selection method has a parameter T. TresholdOptimizer finds the best T value
 * on the training set of a given task by cross-validation. (The class name keeps its historical
 * spelling so that existing callers keep compiling.)
 */
public class TresholdOptimizer {
  /**
   * Weight used in the F-measure {@code F = (1+B)*P*R / (B*P + R)}. In that formula B plays the
   * role of beta squared in the usual F-beta definition, i.e. recall is weighted with beta =
   * sqrt(2). NOTE(review): this used to be described as "Beta"; if beta = 2 was intended, B should
   * be 4. It is left unchanged because the thresholds in getThreshold were presumably obtained with
   * this value.
   */
  static final double B = 2;

  /** Number of cross-validation folds used by crossvalidate / crossvalidateAll. */
  private static final int FOLDS = 5;

  /** Tasks processed by crossvalidateAll, in processing order. */
  private static final String[] ALL_TASKS = {"wiki", "reuters", "cmc", "obes"};

  /**
   * Indicator selection methods processed by crossvalidateAll, in processing order.
   * NOTE(review): getThreshold has no entry for "mi"; confirm it is a valid selector name.
   */
  private static final String[] ALL_SELECTORS = {"prob", "maxent", "condent", "mi", "mutualinfo"};

  /** Cross-validated thresholds, keyed by indicator selection method and then by task. */
  private static final Map<String, Map<String, Double>> THRESHOLDS =
      new HashMap<String, Map<String, Double>>();

  static {
    putThreshold("condent", "cmc", 0.36);
    putThreshold("condent", "obes", 0.2);
    putThreshold("condent", "wiki", 0.33);
    putThreshold("condent", "reuters", 0.14);

    putThreshold("mutualinfo", "cmc", 0.04);
    putThreshold("mutualinfo", "obes", 0.04);
    putThreshold("mutualinfo", "reuters", 0.025);
    putThreshold("mutualinfo", "wiki", 0.06);

    putThreshold("maxent", "cmc", 0.94);
    putThreshold("maxent", "obes", 0.16);
    putThreshold("maxent", "wiki", 0.73);
    putThreshold("maxent", "reuters", 0.61);

    putThreshold("prob", "cmc", 0.33);
    putThreshold("prob", "obes", 0.42);
    putThreshold("prob", "wiki", 0.33);
    putThreshold("prob", "reuters", 0.39);
    putThreshold("prob", "webps", 0.5);
  }

  /** Registers one (indicator selection method, task) -> threshold entry in THRESHOLDS. */
  private static void putThreshold(String indsel, String task, double t) {
    Map<String, Double> byTask = THRESHOLDS.get(indsel);
    if (byTask == null) {
      byTask = new HashMap<String, Double>();
      THRESHOLDS.put(indsel, byTask);
    }
    byTask.put(task, t);
  }

  /**
   * Returns the threshold T obtained by cross-validation for an indicator selection method on a
   * task.
   *
   * @param indsel indicator selection method: "condent", "mutualinfo", "maxent" or "prob"
   * @param task task name: "cmc", "obes", "wiki" or "reuters" ("webps" is known only for "prob")
   * @return the stored threshold; for an unknown combination (including {@code null} arguments) a
   *     warning is printed to stderr and 0.0 is returned
   */
  public static double getThreshold(String indsel, String task) {
    Map<String, Double> byTask = THRESHOLDS.get(indsel);
    Double t = (byTask == null) ? null : byTask.get(task);
    if (t != null) {
      return t;
    }
    System.err.println("unkonwn configuration: " + task + " "+ indsel);
    return 0.0;
  }

  protected ContentShiftDetector ltc;

  /**
   * Creates an optimizer working on the given detector.
   *
   * @param l detector whose indicator selector threshold is tuned; its state (threshold, terms,
   *     eval set, classifiers) is modified during optimization
   */
  public TresholdOptimizer(ContentShiftDetector l){
    ltc = l;
  }

  /**
   * Evaluates threshold {@code t} over all cross-validation folds.
   *
   * <p>For every fold and every label in {@code ltc.getTerms()}, the method does the following:
   * <ul>
   *   <li>relabels the fold's training part for that label;</li>
   *   <li>reselects the label's indicators on it;</li>
   *   <li>restricts the evaluation part to those indicators and classifies it.</li>
   * </ul>
   * tp/fp/fn are summed over all labels of a fold before the fold's F-measure is computed.
   *
   * <p>Side effects: the detector's indicator selector threshold is set to {@code t}, and the
   * entries of {@code ltc.getTerms()} are overwritten.
   *
   * @param traindata training part of each fold
   * @param evaldata evaluation part of each fold (indexed in parallel with {@code traindata})
   * @param t threshold to evaluate
   * @return the F-measure (weighted by {@code B}) averaged over folds; a fold without true
   *     positives contributes 0
   */
  protected double getF(List<DataHandler> traindata, List<DataHandler> evaldata, double t)
      throws Exception {
    ltc.getIndicatorselector().setThreshold(t);
    double fSum = 0.0;
    for (int fold = 0; fold < evaldata.size(); ++fold) {
      DataHandler foldTrain = traindata.get(fold);
      DataHandler foldEval = evaldata.get(fold);
      int tp = 0, fp = 0, fn = 0;
      for (String label : ltc.getTerms().keySet()) {
        ltc.relabel(foldTrain, ltc.getTrain(), label);
        DataHandler fstraindata = ltc.featuresel(foldTrain);
        // Overwrites the value of the key currently being iterated. This relies on the terms map
        // tolerating value replacement during keySet() iteration (true for HashMap/TreeMap).
        ltc.getTerms().put(label, ltc.getIndicatorselector().getIndicators(fstraindata));
        DataHandler fsevaldata =
            foldEval.createSubset(foldEval.getInstanceIds(), ltc.getTerms().get(label));
        ltc.relabel(fsevaldata, ltc.getEval(), label);
        Map<String, Set<String>> result = ltc.predictAllVSM(fsevaldata, label);
        for (Map.Entry<String, Set<String>> entry : result.entrySet()) {
          boolean pred = entry.getValue().contains(label);
          boolean gold = (Boolean) fsevaldata.getLabel(entry.getKey());
          if (pred) {
            if (gold) {
              tp++;
            } else {
              fp++;
            }
          } else if (gold) {
            fn++;
          }
        }
      }
      // With tp == 0, precision or recall is 0 or 0/0 = NaN. Count such a fold as F = 0 so that a
      // NaN cannot poison the average; a NaN would make every comparison in crossval false.
      if (tp > 0) {
        double prec = (double) tp / (tp + fp);
        double rec = (double) tp / (tp + fn);
        fSum += ((1 + B) * prec * rec) / (B * prec + rec);
      }
    }
    return fSum / evaldata.size();
  }

  /**
   * Searches [min, max] for the threshold maximising getF under {@code cv}-fold cross-validation
   * on {@code vsm}.
   *
   * <p>Each of the {@code precision} iterations evaluates the two quarter points of the current
   * interval and keeps the half on the better side, halving the interval. This assumes F is
   * unimodal in T on [min, max] (TODO confirm). Progress and timing are printed to standard
   * output.
   *
   * @param vsm training data to split into folds
   * @param cv number of folds (must be at least 1)
   * @param precision number of interval-halving iterations
   * @param min lower end of the search interval
   * @param max upper end of the search interval
   * @return the midpoint of the final interval
   * @throws IllegalArgumentException if {@code cv < 1}
   */
  protected double crossval(DataHandler vsm, int cv, int precision, double min, double max)
      throws Exception {
    if (cv < 1) {
      throw new IllegalArgumentException("cv must be at least 1: " + cv);
    }
    List<DataHandler> evaldata = new ArrayList<DataHandler>();
    List<DataHandler> traindata = new ArrayList<DataHandler>();
    ltc.initClassifiers(ltc.getTrain());
    // Folds are carved out of the training set, so its labels also serve as the eval labels.
    ltc.setEval(ltc.getTrain());
    buildFolds(vsm, cv, traindata, evaldata);

    for (int d = 0; d < precision; ++d) {
      double mid = (min + max) / 2.0;
      double low = (min + mid) / 2.0;
      double high = (max + mid) / 2.0;
      long time = System.currentTimeMillis();
      double l = getF(traindata, evaldata, low);
      double h = getF(traindata, evaldata, high);
      System.out.println((System.currentTimeMillis()-time)/1000 + "sec " + low+"="+l + " "+high+"="+h);
      if (h > l) {
        min = mid;
      } else {
        max = mid;
      }
    }
    return (min + max) / 2.0;
  }

  /**
   * Splits vsm's instances, in iteration order, into cv contiguous evaluation folds. For each
   * fold, the matching training set (all remaining instances) is appended to traindata and the
   * evaluation set to evaldata.
   */
  private static void buildFolds(DataHandler vsm, int cv, List<DataHandler> traindata,
      List<DataHandler> evaldata) throws Exception {
    long n = vsm.getInstanceCount();
    for (int f = 0; f < cv; ++f) {
      // Fold f evaluates positions [n*f/cv, n*(f+1)/cv). The folds therefore partition all n
      // instances, including the first of each fold and the n % cv remainder.
      long start = n * f / cv;
      long end = n * (f + 1) / cv;
      Set<String> eval = new HashSet<String>();
      int i = 0;
      for (String id : vsm.getInstanceIds()) {
        if (i >= start && i < end) {
          eval.add(id);
        }
        ++i;
      }
      evaldata.add(vsm.createSubset(eval, vsm.getFeatureNames()));
      Set<String> train = new HashSet<String>(vsm.getInstanceIds());
      train.removeAll(eval);
      traindata.add(vsm.createSubset(train, vsm.getFeatureNames()));
    }
  }

  /**
   * Sets indicator selector {@code is} on the detector and cross-validates its threshold on the
   * detector's training set. Prints "task\tis\tbest" to standard output.
   *
   * @param mutualInfoMax upper end of the search interval for "mutualinfo"; every other selector
   *     is searched on [0, 1]
   * @return the best threshold found
   */
  private static double optimizeAndReport(ContentShiftDetector ltc, String task, String is,
      int precision, double mutualInfoMax) throws Exception {
    ltc.setIndicatorSelector(is);
    double bestt = new TresholdOptimizer(ltc).crossval(ltc.buildVSM(ltc.getTrain(), null), FOLDS,
        precision, 0.0, is.equals("mutualinfo") ? mutualInfoMax : 1.0);
    System.out.println(task + "\t" + is + "\t" + bestt);
    return bestt;
  }

  /**
   * Determines the best T threshold on every task's training set for each indicator selection
   * method. This takes several hours. Results are printed to standard output, one
   * "task\tmethod\tthreshold" line per combination.
   */
  public static void crossvalidateAll() throws Exception {
    ContentShiftDetector ltc = new ContentShiftDetector();
    for (String t : ALL_TASKS) {
      ltc.setTask(t);
      ltc.readCorpora();
      for (String is : ALL_SELECTORS) {
        optimizeAndReport(ltc, t, is, 7, 0.3);
      }
    }
  }

  /**
   * Determines the best T threshold on the given task's training set for indicator selection
   * method {@code is}. The result is also printed to standard output.
   *
   * @param task task name
   * @param is indicator selection method
   * @return the best threshold found
   */
  public static double crossvalidate(String task, String is) throws Exception {
    ContentShiftDetector ltc = new ContentShiftDetector();
    ltc.setTask(task);
    ltc.readCorpora();
    return optimizeAndReport(ltc, task, is, 5, 0.1);
  }
}
