package szte.csd;

import java.io.*;
import java.util.*;
import java.util.Map.Entry;

import szte.nlputils.MapUtils;
import szte.csd.CSDModel;
import szte.csd.Sentence;
import szte.csd.Tokenizer;
import szte.csd.baseline.*;
import szte.csd.indicatorsel.*;
import szte.datamining.*;
import szte.datamining.mallet.MalletDataHandler;
import szte.io.*;

import cc.mallet.classify.FeatureConstraintUtil;
import cc.mallet.types.AugmentableFeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.Label;

/**
 * 
 * The main entry point of the project. ContentShiftDetector performs training and evaluation of the original ContentShiftDetection procedure along with several baselines (the article's results can be replicated).
 *
 */
public class ContentShiftDetector {

  protected CSDModel csd = null; //it will contain a trained model
  // counter of the current iteration; incremented in fullprocess() and used in log and prediction file names
  protected int iteration = 0;
  protected PrintWriter log  = null; // the log file gets a name consisting of the main parameters of the run
  // used to create the empty data handlers that hold the vector space models
  protected DataHandler datahandler = null;
  // selects the indicator phrases of a label from a (feature-selected) training dataset
  protected IndicatorSelector indicatorselector = null;
  // vector space model of the training documents, rebuilt in every iteration of fullprocess()
  protected DataHandler traindata = null;
  protected Tokenizer tokenizer = null;
  // training and evaluation corpora, loaded by readCorpora()
  protected DocumentSet train = null;
  protected DocumentSet eval = null;
  // task name; readCorpora() accepts "cmc", "obes", "wiki" and "reuters"
  protected String task = "cmc";
  // while true, isAltered() reports nothing as altered and predictions are requested without the CSD
  protected boolean do_not_use_csd = false;
  // number of iterations after which termination() returns true
  protected int iterationNum = 3;
  // name passed to csd.serialize() at the end of fullprocess()
  protected String outputModelName;
  // whether fmeasure() writes the per-document predictions to a file
  protected boolean logprediction = false;
  
  protected boolean baseline_featuresel = false; //whether the Vector Space Model baseline should use Information Gain-based feature selection
  protected RuleBasedCSD rulebasedCSD = null; //if not null, it does not learn CSD but use hand-crafted rules
  protected Map<String, Set<String>> negative_terms = null; //for baseline indicator selection 
  protected Map<String, Set<String>> overlapping_terms = null; //for baseline indicator selection

  /**
   * Creates a detector configured from a parameter set: reads the corpora of the task,
   * sets up the indicator selector, opens the log file and creates the CSD model.
   *
   * Properties read: task, indicatorselector, lemmatize, baselineFeaturesel, syntaxFE,
   * phraselength, iterationNum, model, logPrediction and RBcsd.
   * The on/off flags (lemmatize, baselineFeaturesel, logPrediction) and RBcsd are optional;
   * a missing value is treated as "off". The numeric properties (syntaxFE, phraselength,
   * iterationNum) are required and a missing or malformed value raises NumberFormatException.
   *
   * @param prop the parameters of the run
   */
  public ContentShiftDetector(Properties prop){
    task = prop.getProperty("task");
    readCorpora();

    setIndicatorSelector(prop.getProperty("indicatorselector"));

    // the log file is named after the main parameters of the run
    String logfile = "log/ltc_"+task+"_"+prop.getProperty("indicatorselector")+"_"+prop.getProperty("lemmatize");
    try {
      File logdir = new File("log");
      if(!logdir.exists())
        logdir.mkdir();
      log = new PrintWriter(logfile);
    } catch (FileNotFoundException e) {
      System.err.println("Cannot create "+logfile);
      // keep the run alive: without a writer every later log call would fail with a NullPointerException
      log = new PrintWriter(System.err, true);
    }
    datahandler = new MalletDataHandler();

    // constant-first comparisons: a missing flag means "off" instead of a NullPointerException
    if("1".equals(prop.getProperty("baselineFeaturesel")))
      baseline_featuresel = true;

    csd = new CSDModel(task, Integer.parseInt(prop.getProperty("syntaxFE")));
    tokenizer = csd.getTokenizer();

    if("1".equals(prop.getProperty("lemmatize")))
      tokenizer.lemmatiseOn();
    CSDModel.PHRASELENGTH = Integer.parseInt(prop.getProperty("phraselength"));
    iterationNum = Integer.parseInt(prop.getProperty("iterationNum"));

    outputModelName = prop.getProperty("model");
    logprediction = "1".equals(prop.getProperty("logPrediction"));

    // optional hand-crafted content shift detector used instead of the learnt one
    String ruleBased = prop.getProperty("RBcsd");
    rulebasedCSD = null;
    if("sent".equals(ruleBased))
      rulebasedCSD = new SentenceRemover(task);
    else if("in-sent".equals(ruleBased) && task.equals("obes"))
      rulebasedCSD = new ObesRuleBased();
    else if("in-sent".equals(ruleBased) && task.equals("cmc"))
      rulebasedCSD = new CMCRuleBased();
    else if("bioscope".equals(ruleBased))
    {
      if(!task.equals("cmc"))
      {
        System.err.println("Bioscope baseline can be applied just on CMC");
        System.exit(2);
      }
      rulebasedCSD = new Bioscope();
    }
  }

  public ContentShiftDetector() {
    indicatorselector = new ProbIndSel();
    try {
      log = new PrintWriter("log_noname");
    } catch (FileNotFoundException e) {
      e.printStackTrace();
    }
    datahandler = new MalletDataHandler();
    tokenizer = new Tokenizer();
  }
  
  /**
   * Instantiates the indicator selector identified by its name and sets its threshold to the
   * value TresholdOptimizer provides for the selector name and the current task.
   *
   * Recognised names: "condent", "mi", "mutualinfo", "maxent" and "prob".
   * Any other name, including null (e.g. a missing property), is reported on the standard
   * error and terminates the JVM with exit code 3.
   *
   * @param indsel name of the indicator selector
   */
  public void setIndicatorSelector(String indsel) {
    // constant-first comparisons so that a null name reaches the diagnostic below instead of a NullPointerException
    if("condent".equals(indsel))
      indicatorselector = new CondEntropyIndicatorSelector();
    else if("mi".equals(indsel))
      indicatorselector = new MaxEntIterativeIndicatorSel();
    else if("mutualinfo".equals(indsel))
      indicatorselector = new MutualInfoIndicatorSelector();
    else if("maxent".equals(indsel))
      indicatorselector = new MaxEntSelector();
    else if("prob".equals(indsel))
      indicatorselector = new ProbIndSel();
    else{
      System.err.println("unknown indicatorselector "+ indsel);
      System.exit(3);
    }
    indicatorselector.setThreshold(TresholdOptimizer.getThreshold(indsel, task));
  }

  /**
   * Loads the training and the evaluation corpus of the current task into the
   * train and eval fields. An unknown task is reported on the standard error and
   * terminates the JVM with exit code 2.
   */
  public void readCorpora(){
    if(task.equals("cmc")){
      train = new CMCDataHolder();
      train.readDocumentSet("corpus/2007ChallengeTrainData.xml");
      eval = new CMCDataHolder();
      eval.readDocumentSet("corpus/2007ChallengeTestDataCodes.xml");
      return;
    }
    if(task.equals("obes")){
      train = new ObesDocumentSet();
      train.readDocumentSet("corpus/obes/|corpus/obesity_standoff_annotations_training_all.xml");
      eval = new ObesDocumentSet();
      eval.readDocumentSet("corpus/obes_test/|corpus/obesity_standoff_annotations_test.xml");
      return;
    }
    if(task.equals("wiki")){
      train = new WikiDocSet();
      train.readDocumentSet("corpus/wiki.txt|corpus/wiki_train_labels");
      eval = new WikiDocSet();
      eval.readDocumentSet("corpus/wiki.txt|corpus/wiki_eval_labels");
      return;
    }
    if(task.equals("reuters")){
      train = new ReutersDocSet();
      train.readDocumentSet("corpus/reuters.xml|TRAIN");
      eval = new ReutersDocSet();
      eval.readDocumentSet("corpus/reuters.xml|TEST");
      return;
    }
    // none of the supported tasks matched
    System.err.println("Unknown task "+task);
    System.exit(2);
  }

  /**
   * Termination criterion of the iterative training; override it to define a fancier one.
   *
   * @return true once the iteration counter has reached iterationNum
   */
  protected boolean termination(){
    final boolean budgetExhausted = iteration >= iterationNum;
    return budgetExhausted;
  }

  /**
   * Evaluates the current CSD model and prints the results to the log file:
   * the document classification scores without and with the CSD, followed by
   * the scores of the CSD itself.
   *
   * @param eval the document set to evaluate on
   * @throws Exception if one of the evaluation steps fails
   */
  protected void evaluate(DocumentSet eval) throws Exception{
    log.println("Iteration "+iteration+".");
    csd.log(log);
    log.println("DOCCLASS WITHOUT CSD:");
    do_not_use_csd = true;
    try {
      log.println(fmeasure(eval));
    } finally {
      // switch the CSD back on even when the evaluation fails, otherwise every later call would silently run without it
      do_not_use_csd = false;
    }
    log.println("DOCCLASS WITH CSD:");
    log.println(fmeasure(eval));
    log.println("CSD ITSEF:\n" + csd_fmeasure(eval));
    log.println("--------------------------\n");
    log.flush();
  }

  /**
   * Measures the performance of the local content shift detector.
   * For every document the labels predicted with the boolean argument of csd.predict set to
   * true are compared with the labels predicted with it set to false (judging by fmeasure(),
   * the argument switches the CSD off). A label of the first set counts as
   * <ul>
   * <li>true positive if it is not in the etalon labels and it is missing from the second set,</li>
   * <li>false positive if it is in the etalon labels but it is missing from the second set,</li>
   * <li>false negative if it is not in the etalon labels but it is still in the second set.</li>
   * </ul>
   * Documents with an empty first set are skipped. The scores are NaN when a denominator is zero.
   *
   * @param eval the document set to evaluate on
   * @return precision, recall, F-measure and the "true positives / examined labels" ratio as text
   */
  protected String csd_fmeasure(DocumentSet eval) throws Exception{
    int truePos = 0, falsePos = 0, falseNeg = 0, examined = 0;
    for(Document doc : eval)
    {
      Set<String> firstPrediction = csd.predict(doc.getText(),true);
      if(firstPrediction.isEmpty())
        continue;
      Set<String> secondPrediction = csd.predict(doc.getText(),false);
      for(String label : firstPrediction)
      {
        examined++;
        boolean inEtalon = doc.getLabels().contains(label);
        boolean kept = secondPrediction.contains(label);
        if(kept)
        {
          // a kept label that is in the etalon is a true negative and it is not counted
          if(!inEtalon)
            falseNeg++;
        }
        else if(inEtalon)
          falsePos++;
        else
          truePos++;
      }
    }
    double precision = (double)truePos / (truePos+falsePos);
    double recall = (double)truePos / (truePos+falseNeg);
    double fscore = (2*precision*recall)/(precision+recall);
    return "P:" + precision + " R:" + recall + " F:" + fscore + " TP/N:" + truePos + "/" + examined;
  }

/**
 *
 * Measures the performance of indicator selector and CSD on the document multi-labeling task.
 * It is the micro-averaged F-measure of the positive (binary labeling) class: the true positive,
 * false positive and false negative counts are summed over every (document, label) pair,
 * where the labels are the keys of the term map of the model.
 * If logprediction is set, every (document, label, prediction) triplet is also written
 * to the file "pred_TASK_itrITERATION".
 * The scores are NaN when a denominator is zero.
 *
 * @param eval the document set to evaluate on
 * @return precision, recall and F-measure as text
 * @throws Exception if the prediction fails or the prediction file cannot be created
 */
  protected String fmeasure(DocumentSet eval) throws Exception{
    int tp=0, fp=0, fn=0;
    PrintWriter predfile = null;
    if(logprediction)
      predfile = new PrintWriter("pred_"+task+"_itr"+iteration);
    try {
      for(Document doc : eval)
      {
        Set<String> predicted_labels = csd.predict(doc.getText(), do_not_use_csd);
        for(String label : csd.getTerms().keySet())
        {
          boolean etal = doc.getLabels().contains(label);
          boolean pred = predicted_labels.contains(label);
          if(pred && etal)
            tp++;
          else if(pred && !etal)
            fp++;
          else if(!pred && etal)
            fn++;
          if(predfile != null)
            predfile.println(doc.getDocID()+" "+label+" "+pred);
        }
      }
    } finally {
      // the prediction file exists only when logprediction is set; close it on failure as well
      if(predfile != null)
        predfile.close();
    }
    double prec = (double)tp / (tp+fp);
    double rec  = (double)tp / (tp+fn);
    double f    = (2*prec*rec)/(prec+rec);
    return "P:" + prec + " R:" + rec + " F:" + f;
  }

  /**
   * Tells whether the token at the given position of the sentence is considered altered.
   * Nothing is altered while do_not_use_csd is set. Otherwise the rule-based detector
   * decides when there is one; failing that the learnt model decides, provided
   * its modifier is available.
   *
   * @param sentence the sentence containing the token
   * @param pos index of the token in the sentence
   * @return true if the token is judged to be altered
   */
  public boolean isAltered(Sentence sentence, int pos){
    if(!do_not_use_csd)
    {
      if(rulebasedCSD != null)
        return rulebasedCSD.isAltered(sentence, pos);
      return csd.getModifier() != null && csd.isAltered(sentence, pos);
    }
    return false;
  }

/**
 * Constructs the Vector Space Model of a document set.
 * The features are the token n-grams ending at each token that is longer than one character
 * and is not judged altered by isAltered(); the n-grams are at most CSDModel.PHRASELENGTH tokens
 * long and do not cross sentence boundaries. Feature values are the within-document frequencies.
 * Every instance gets the label false; relabel() assigns the real labels.
 *
 * @param docset the documents to convert
 * @param train if not null, it is passed to the new dataset as the "useFeatureSet" parameter
 * @return the data handler holding one instance per document
 */
  public DataHandler buildVSM(DocumentSet docset, DataHandler train){
    System.out.print("Building VSM... ");
    long startedAt = System.currentTimeMillis();
    DataHandler vsm = datahandler.createEmptyDataHandler();
    if(train == null)
      vsm.createNewDataset(null);
    else
    {
      Map<String,Object> settings = new HashMap<String, Object>();
      settings.put("useFeatureSet", train);
      vsm.createNewDataset(settings);
    }
    for(Document doc : docset){
      List<Sentence> sentences = tokenizer.tokenise(doc.getText());
      Map<String, Integer> counts = new HashMap<String, Integer>();
      for(Sentence sent : sentences)
      {
        List<String> tokens = sent.tokens;
        for(int end=0; end<tokens.size(); ++end){
          String phrase = tokens.get(end);
          if(phrase.length()<=1 || isAltered(sent, end))
            continue;
          MapUtils.addToMap(counts, phrase, 1);
          // grow the phrase to the left, one token at a time
          for(int begin=end-1; begin>=0 && end-begin<CSDModel.PHRASELENGTH; --begin){
            phrase = tokens.get(begin) + " " + phrase;
            MapUtils.addToMap(counts, phrase, 1);
          }
        }
      }
      String docid = doc.getDocID();
      for(Entry<String, Integer> count : counts.entrySet())
        vsm.setNumericValue(docid, count.getKey(), count.getValue().doubleValue());
      vsm.setLabel(docid, false);
    }
    long elapsedSec = (System.currentTimeMillis()-startedAt)/1000;
    System.out.println(elapsedSec+" sec");
    return vsm;
  }

  /**
   * Sets the binary labeling of a Vector Space Model according to one target label.
   * The same Vector Space Model is used throughout an iteration, just its binary labeling
   * varies according to the target labels of the original document multi-labeling task.
   * Documents that have no instance in the data handler are ignored.
   *
   * @param dh the Vector Space Model to relabel
   * @param docset the documents providing the gold-standard labels
   * @param label the target label; an instance becomes positive iff its document has this label
   */
  public void relabel(DataHandler dh, DocumentSet docset, String label) throws Exception
  {
    for(Document doc : docset)
    {
      String docid = doc.getDocID();
      if(dh.getInstanceIds().contains(docid))
      {
        boolean positive = doc.getLabels().contains(label);
        dh.setLabel(docid, positive);
      }
    }
  }

  /**
   * Resets the negative and overlapping term maps and registers an empty indicator set
   * in the model for every label that occurs in the document set.
   *
   * @param docset the documents whose labels are collected
   */
  public void initClassifiers(DocumentSet docset){
    negative_terms = new HashMap<String, Set<String>>();
    overlapping_terms = new HashMap<String, Set<String>>();

    Set<String> allLabels = new HashSet<String>();
    for(Document doc : docset)
      allLabels.addAll(doc.getLabels());

    for(String label : allLabels)
    {
      Set<String> noIndicatorsYet = new HashSet<String>();
      csd.getTerms().put(label, noIndicatorsYet);
    }
  }

  /**
   * Predicts one label for every instance of a Vector Space Model: the label is assigned
   * to an instance if at least one indicator of the label is a feature of the model and
   * its binary value is set for the instance.
   *
   * @param vsm the Vector Space Model to predict on
   * @param label the label whose indicators are looked up
   * @return for each instance id the set of predicted labels (empty or containing just the label)
   */
  public Map<String, Set<String>> predictAllVSM(DataHandler vsm, String label) throws Exception{
    Map<String, Set<String>> predictions = new HashMap<String, Set<String>>();
    Set<String> featureNames = vsm.getFeatureNames();
    for(String id : vsm.getInstanceIds())
    {
      Set<String> assigned = predictions.get(id);
      if(assigned == null)
      {
        assigned = new HashSet<String>();
        predictions.put(id, assigned);
      }
      for(String indicator : csd.getTerms().get(label))
      {
        if(!featureNames.contains(indicator))
          continue;
        if(vsm.getBinaryValue(id, indicator))
        {
          // one matching indicator is enough
          assigned.add(label);
          break;
        }
      }
    }
    return predictions;
  }

  /**
   * Looks up the label indicated by the term at the given token position.
   * A match of the model is discarded when, starting at the same position or at one of
   * the next two tokens, an overlapping term of the same label matches.
   *
   * @param tokens the tokens of a sentence
   * @param pos the position to check
   * @return the indicated label, or null if there is no (accepted) match
   */
  protected String isTerm(List<String> tokens, int pos){
    String label = csd.isTerm(tokens, pos);
    if(label == null || overlapping_terms == null)
      return label;
    // positions pos, pos+1 and pos+2, as far as the sentence lasts
    int limit = Math.min(pos+3, tokens.size());
    for(int start=pos; start<limit; ++start)
    {
      String overlapping = CSDModel.matchTerm(tokens, start, overlapping_terms);
      if(label.equals(overlapping))
        return null;
    }
    return label;
  }

  /**
   * Constructs the training dataset for the CSD.
   * The instances are the local contexts of each occurrence of the indicator phrases
   * in the training document set. An instance of this content shifter training dataset is labeled as
   * non-altered (false) when the indicated label is among the gold-standard labels of the document
   * in question, and as altered (true) otherwise.
   * Every occurrence accepted by isTerm() becomes an instance; no per-document cap is applied.
   *
   * @param docset the training documents
   * @return the dataset after feature selection by featuresel()
   */
  protected DataHandler buildOccurances(DocumentSet docset) throws Exception {
    DataHandler dh = datahandler.createEmptyDataHandler();
    dh.createNewDataset(null);
    for(Document doc : docset){
      List<Sentence> sentences = tokenizer.tokenise(doc.getText());
      int n=0; // running number of the indicator occurrences within the document
      for(Sentence sent : sentences){
        List<String> tokens = sent.tokens;
        for(int i=0; i<tokens.size(); ++i)
        {
          String label = isTerm(tokens,i);
          if(label == null)
            continue;
          // NOTE(review): document id and counter are concatenated without a separator, so ids
          // may collide if document ids can end in digits (e.g. "1"+12 and "11"+2) - confirm
          String id = doc.getDocID()+(++n);
          csd.featureExtract(dh,id, sent, i);
          dh.setLabel(id, !doc.getLabels().contains(label));
        }
      }
    }
    return featuresel(dh);
  }

  /**
   * Runs the whole training procedure. Until termination() holds, each iteration
   * builds the Vector Space Model of the training set, selects the indicators of every label
   * separately, trains the CSD on the occurrences of the indicators and evaluates the model.
   * The final model is serialized under outputModelName.
   */
  public void fullprocess() throws Exception{
    iteration = 0;
    initClassifiers(train);
    while(!termination())
    {
      iteration++;
      System.out.println(iteration+". iteration");
      traindata = buildVSM(train,null);
      System.out.println("#allFeature: "+traindata.getFeatureCount());
      //indicator selection separately for each label
      for(String label : csd.getTerms().keySet())
      {
        System.out.println("label: "+label);
        relabel(traindata, train, label);
        DataHandler selected = featuresel(traindata);
        System.out.println("#FSfeature: "+selected.getFeatureCount());
        Set<String> indicators = indicatorselector.getIndicators(selected);
        csd.getTerms().put(label, indicators);
        System.out.println(indicators);
      }
      //CSD training on the gathered occurrences of each label's indicators
      DataHandler occurrences = buildOccurances(train);
      csd.trainClassifier(occurrences);
      evaluate(eval);
    }
    csd.serialize(outputModelName);
  }

  /**
   * Command line entry point.
   * Usage: ContentShiftDetector train|doccleaner|baseline_bow|baseline_csd paramfile
   *
   * Exits with code 1 when the arguments are wrong, the parameter file cannot be read
   * or the selected procedure fails.
   *
   * @param args the mode and the path of the parameter file
   */
  public static void main(String[] args){
    String usage = "usage: ContentShiftDetector train|doccleaner|baseline_bow|baseline_csd paramfile";
    if(args.length != 2)
    {
      System.err.println(usage);
      System.exit(1);
    }
    String mode = args[0];
    // check the mode before the (costly) corpus reading done by the constructor
    boolean knownMode = mode.equals("train") || mode.equals("baseline_bow")
        || mode.equals("baseline_csd") || mode.equals("doccleaner");
    if(!knownMode)
    {
      System.err.println("unknown mode "+mode);
      System.err.println(usage);
      System.exit(1);
    }

    Properties prop = new Properties();
    Reader paramReader = null;
    try {
      paramReader = new FileReader(args[1]);
      prop.load(paramReader);
    } catch (IOException e) {
      // going on with empty parameters would only fail later with a less helpful error
      System.err.println("Cannot read parameter file "+args[1]+": "+e.getMessage());
      System.exit(1);
    } finally {
      if(paramReader != null)
      {
        try {
          paramReader.close();
        } catch (IOException ignored) {
          // the parameters are already loaded, a failing close is harmless
        }
      }
    }

    ContentShiftDetector ltc = new ContentShiftDetector(prop);
    try {
      if(mode.equals("train"))
        ltc.fullprocess();
      else if(mode.equals("baseline_bow"))
        ltc.BoWbaseline();
      else if(mode.equals("baseline_csd"))
        ltc.justCSDLearning();
      else if(mode.equals("doccleaner"))
      {
        ltc.fullprocess();
        ltc.BoWbaseline();
      }
    } catch (Exception e) {
      e.printStackTrace();
      System.exit(1);
    }
  }

  /** Returns the term map of the model (label -> indicator phrases), as provided by csd.getTerms(). */
  public Map<String, Set<String>> getTerms() {
    return csd.getTerms();
  }

  /** Returns the indicator selector currently in use. */
  public IndicatorSelector getIndicatorselector() {
    return indicatorselector;
  }

  /**
   * Replaces the indicator selector with the given instance.
   * Unlike setIndicatorSelector(String), it does not touch the threshold of the selector.
   */
  public void setIndicatorselector(IndicatorSelector indicatorselector) {
    this.indicatorselector = indicatorselector;
  }

  /** Returns the training document set. */
  public DocumentSet getTrain() {
    return train;
  }

  /** Returns the evaluation document set. */
  public DocumentSet getEval() {
    return eval;
  }

  /** Replaces the evaluation document set. */
  public void setEval(DocumentSet eval) {
    this.eval = eval;
  }

  /** Returns the name of the current task. */
  public String getTask() {
    return task;
  }

  /** Sets the name of the task; the corpora are not re-read by this call. */
  public void setTask(String task) {
    this.task = task;
  }

///////////////////////////////
/// Baseline methods for experiments with hand-crafted indicator phrases (Section 4.2.1 of the paper)
/// They are used only for replicating the baseline methods.
///////////////////////////////  
  /**
   * Baseline with hand-crafted indicator phrases: reads the terms of the task
   * (from a directory for "obes" and "genia", from a single file for "cmc" and "wiki";
   * nothing is read for any other task), then trains the CSD on their occurrences
   * in the training set and evaluates it.
   */
  protected void justCSDLearning() throws Exception {
    String termDir = null;
    String termFile = null;
    if(task.equals("obes"))
      termDir = "corpus/obes_terms";
    else if(task.equals("genia"))
      termDir = "corpus/genia_terms";
    else if(task.equals("cmc"))
      termFile = "corpus/cmc_terms";
    else if(task.equals("wiki"))
      termFile = "corpus/wiki_terms";

    if(termDir != null)
      readTermsFromDir(termDir);
    else if(termFile != null)
      readTermsFromFile(termFile);

    DataHandler occurrences = buildOccurances(train);
    csd.trainClassifier(occurrences);
    evaluate(eval);
  }

  /**
   * Reads the hand-crafted terms of every label from a single file.
   * Format, as parsed below: a line starting with "-CLASS-" opens a label, the label being the
   * second tab-separated field (a header without it raises ArrayIndexOutOfBoundsException);
   * a following line without "#" is an indicator term of the label; for a line containing "#"
   * the text between the first and the second "#" is an overlapping term, which is also
   * a negative term if the line starts with "N". Lines shorter than two characters are ignored.
   * Term lines that cannot be assigned to a label or have nothing after the "#" are reported
   * and skipped. An I/O problem is reported on the standard error.
   *
   * @param fn path of the term file
   */
  private void readTermsFromFile(String fn) {
    negative_terms = new HashMap<String, Set<String>>();
    overlapping_terms = new HashMap<String, Set<String>>();
    BufferedReader file = null;
    try{
      file = new BufferedReader(new FileReader(fn));
      String line, label ="";
      while ((line = file.readLine()) != null)
      {
        if(line.length()<2) continue;
        if(line.startsWith("-CLASS-"))
        {
          label = line.split("\t")[1];
          csd.getTerms().put(label, new HashSet<String>());
          negative_terms.put(label, new HashSet<String>());
          overlapping_terms.put(label, new HashSet<String>());
          continue;
        }
        if(!line.contains("#"))
        {
          Set<String> terms = csd.getTerms().get(label);
          if(terms == null)
          {
            // a term line in front of the first -CLASS- header
            System.err.println("Skipping term without class in " + fn + ": " + line);
            continue;
          }
          terms.add(line.trim());
        }
        else{
          Set<String> overlapping = overlapping_terms.get(label);
          String[] parts = line.trim().split("#");
          if(overlapping == null || parts.length < 2)
          {
            System.err.println("Skipping malformed line in " + fn + ": " + line);
            continue;
          }
          overlapping.add(parts[1]);
          if(line.startsWith("N"))
            negative_terms.get(label).add(parts[1]);
        }
      }
    } catch (IOException e) {
      System.err.println("Problem with file: " + fn);
    } finally {
      // release the file even if the parsing fails
      if(file != null)
      {
        try {
          file.close();
        } catch (IOException ignored) {
          // nothing more can be done for a reader that was only read from
        }
      }
    }
  }

  /**
   * Reads the hand-crafted terms from a directory that contains one file per label,
   * the file name being the label.
   * A line without "#" is an indicator term of the label; for a line containing "#" the text
   * between the first and the second "#" is an overlapping term, which is also a negative term
   * if the line starts with "N". Lines shorter than two characters are ignored and lines with
   * nothing after the "#" are reported and skipped.
   * An unreadable directory or file is reported on the standard error; the terms of an
   * unreadable file are not registered.
   *
   * @param dir path of the directory
   */
  private void readTermsFromDir(String dir) {
    negative_terms = new HashMap<String, Set<String>>();
    overlapping_terms = new HashMap<String, Set<String>>();
    File[] termFiles = new File(dir).listFiles();
    if(termFiles == null)
    {
      // listFiles() gives null when the path is not a directory or it cannot be listed
      System.err.println("Problem with directory: " + dir);
      return;
    }
    for(File f : termFiles){
      BufferedReader file = null;
      try {
        file = new BufferedReader(new FileReader(f.getAbsoluteFile()));
        Set<String> pset = new HashSet<String>();
        Set<String> nset = new HashSet<String>();
        Set<String> oset = new HashSet<String>();
        String line;
        while ((line = file.readLine()) != null)
        {
          if(line.length()<2) continue;
          if(!line.contains("#"))
            pset.add(line.trim());
          else{
            String[] parts = line.trim().split("#");
            if(parts.length < 2)
            {
              System.err.println("Skipping malformed line in " + f.getAbsoluteFile() + ": " + line);
              continue;
            }
            oset.add(parts[1]);
            if(line.startsWith("N"))
              nset.add(parts[1]);
          }
        }
        csd.getTerms().put(f.getName(), pset);
        negative_terms.put(f.getName(), nset);
        overlapping_terms.put(f.getName(), oset);
      } catch (IOException e) {
        System.err.println("Problem with file: " + f.getAbsoluteFile());
      } finally {
        // release the file even if the parsing fails
        if(file != null)
        {
          try {
            file.close();
          } catch (IOException ignored) {
            // nothing more can be done for a reader that was only read from
          }
        }
      }
    }
  }

///////////////////////////////
/// Baseline methods for experiments with bag-of-words document classification (Rows 2 and 9 of Table 4 in the paper)
/// They are used only for replicating the baseline methods.
///////////////////////////////  
  /**
   * Bag-of-words document classification baseline.
   * For every label of the training set a binary classifier is trained on the Vector Space Model
   * of the training documents (after feature selection) and applied on the evaluation documents.
   * Labels with at most 3 positive training instances or without any feature are not learnt:
   * each of their evaluation documents is predicted negative.
   * The predictions go to the file "baseline.pred", the scores to the log.
   * The true positive, false positive and false negative counters are never reset, so the scores
   * logged after each label are cumulative and the final "BaselineTOTAL" line is the micro-average.
   *
   * @throws Exception if building the datasets, training or classification fails
   */
  protected void BoWbaseline() throws Exception{
    DataHandler tdata = buildVSM(train,null);
    Set<String> labels = new TreeSet<String>();
    for(Document doc : train)
      labels.addAll(doc.getLabels());
    int tp=0, fp=0, fn=0;
    PrintWriter predfile = new PrintWriter("baseline.pred");
    try {
      for(String label : labels)
      {
        System.out.println("baseline for "+label);
        relabel(tdata, train,label);
        DataHandler fs = null;
        //HACK for speed up feature selection (using directly the MALLET interface instead of DataHandler)
        if(baseline_featuresel)
        {
          List<Integer> ff = FeatureConstraintUtil.selectFeaturesByInfoGain(((MalletDataHandler)tdata).data,500);
          Set<String> set = new HashSet<String>();
          for(Integer e : ff)
            set.add(((MalletDataHandler)tdata).data.getAlphabet().lookupObject(e).toString());
          fs = tdata.createSubset(tdata.getInstanceIds(), set);
          System.out.println(fs.getFeatureCount()+"|"+set.size());
        }
        else
          fs = featuresel(tdata);
        // the evaluation documents are represented in the feature space of the training data
        DataHandler edata = buildVSM(eval,fs);
        relabel(edata, eval, label);
        // number of positive training instances
        int t=0;
        for(String id : fs.getInstanceIds())
          if((Boolean)fs.getLabel(id))
            t++;
        log.println(label+" "+t);log.flush();
        System.out.println(label+" "+fs.getInstanceCount()+" "+t+" "+fs.getFeatureCount());
        if(t <= 3 || fs.getFeatureCount() < 1)
        {
          // too little to learn from: predict negative for every evaluation document
          for(String id : edata.getInstanceIds())
          {
            if((Boolean)edata.getLabel(id))
              fn++;
            predfile.println(id+" "+label+" false");
          }
          continue;
        }

  ////////   You can extract here the datasets for WEKA and SVMLigth     
  //    fs.saveDataset("svm/"+task+"/train_"+label+"|svm");
  //    edata.saveDataset("svm/"+task+"/eval_"+label+"|svm");
  //    fs.saveDataset("weka/"+task+"/train_"+label+"|weka");
  //    edata.saveDataset("weka/"+task+"/eval_"+label+"|weka");

        fs.initClassifier(null);
        Model m = fs.trainClassifier();

        ClassificationResult res = edata.classifyDataset(m);
        for(String id : res.getInstanceIds())
        {
          boolean pred = (Boolean)res.getPredictedLabel(id);
          boolean etal = res.getLabel(id);
          if(pred && etal) tp++;
          else if(pred && !etal) fp++;
          else if(!pred && etal)  fn++;
          predfile.println(id+" "+label+" "+pred);
        }
        log.println(tp+" "+fp+" "+fn);
        double prec = (double)tp / (tp+fp);
        double rec  = (double)tp / (tp+fn);
        double f    = (2*prec*rec)/(prec+rec);
        m.printModel(log);
        log.println("Baseline\tP:" + prec + " R:" + rec + " F:" + f);log.flush();
      }
      double prec = (double)tp / (tp+fp);
      double rec  = (double)tp / (tp+fn);
      double f    = (2*prec*rec)/(prec+rec);
      log.println("BaselineTOTAL\tP:" + prec + " R:" + rec + " F:" + f);log.flush();
    } finally {
      // keep the predictions written so far and release the file even if a label fails
      predfile.close();
    }
  }

  /*  protected DataHandler featuresel(DataHandler vsm) throws DataMiningException{
    Set<String> set = new HashSet<String>();
    for(String feature : vsm.getFeatureNames())
    {
      for(String id : vsm.getInstanceIds())
        if(vsm.getBinaryValue(id, feature) && (Boolean)vsm.getLabel(id))
        {
          set.add(feature);
          continue;
        }
    }
    return vsm.createSubset(vsm.getInstanceIds(), set);
  }
   */
  /**
   * Frequency-based feature selection: keeps the features that occur in more than
   * two positive instances. Works directly on the MALLET representation,
   * hence the argument has to be a MalletDataHandler.
   *
   * @param vsm the labeled dataset
   * @return the subset of the dataset with all of its instances and only the selected features
   */
  public DataHandler featuresel(DataHandler vsm) throws Exception{
    MalletDataHandler mallet = (MalletDataHandler)vsm;
    // feature name -> number of positive instances the feature occurs in
    Map<String,Integer> positiveCount = new HashMap<String,Integer>(vsm.getFeatureCount());
    for(int idx=0; idx<mallet.data.size(); ++idx)
    {
      Instance inst = mallet.data.get(idx);
      Boolean positive = (Boolean)((Label)inst.getTarget()).getEntry();
      if(!positive)
        continue;
      AugmentableFeatureVector vector = (AugmentableFeatureVector)inst.getData();
      for(int loc=0; loc<vector.numLocations(); ++loc)
      {
        String feature = (String)vector.getAlphabet().lookupObject(vector.indexAtLocation(loc));
        MapUtils.addToMap(positiveCount, feature, 1);
      }
    }
    Set<String> kept = new HashSet<String>();
    for(Entry<String,Integer> count : positiveCount.entrySet())
    {
      if(count.getValue() >= 3)
        kept.add(count.getKey());
    }
    return vsm.createSubset(vsm.getInstanceIds(), kept);
  }
}
