package szte.csd;

import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

import szte.csd.Sentence;
import szte.csd.SyntaxParser;
import szte.csd.Tokenizer;
import szte.datamining.DataHandler;
import szte.datamining.DataMiningException;
import szte.datamining.Model;
import szte.nlputils.Util;
import edu.stanford.nlp.trees.GrammaticalStructure;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreeGraphNode;
import edu.stanford.nlp.trees.TypedDependency;

/**
 * 
 * CSDModel stores everything for an indicator selection and local context detection model.
 * <p>
 * A model consists of a term dictionary (label -> indicator phrases), a trained
 * classifier that decides whether a term occurrence is altered by its local
 * context, the {@link DataHandler} that defines the classifier's feature set,
 * and the tokenizer / optional syntactic parser used to process documents.
 *
 */
public class CSDModel {
  /** Term dictionary: label -> set of indicator phrases (tokens joined by a single space). */
  protected Map<String, Set<String>> terms = null;
  /** Classifier for the local context of a term; {@code null} until trained or restored. */
  protected Model modifier = null;
  protected Tokenizer tokenizer = null;
  /** Syntactic parser; only created when {@link #fe_depparse} is positive and the task is not "obes". */
  protected SyntaxParser parser = null;
  /** Training data handler; reused at prediction time as the source of the feature set. */
  protected DataHandler datahandler = null;
  /**
   * Switch for the syntax-based features: they are extracted only when positive;
   * even values additionally enable the SUBJ features and multiples of 3 make
   * the features use the sentence's own tokens instead of the parse node values.
   */
  protected int fe_depparse = 0;
  protected String task = "";
  /** Maximum number of tokens of a phrase looked up by {@link #matchTerm}. */
  public static int PHRASELENGTH = 1;


  /** Creates an empty, untrained model with the default tokenizer. */
  public CSDModel(){
    terms = new TreeMap<String, Set<String>>();
    modifier = null;
    tokenizer = new Tokenizer();
  }

  /**
   * Creates an empty, untrained model for a task.
   *
   * @param t   task name; "obes" selects a tokenizer without document processor
   * @param dep value of the syntax feature switch, see {@link #fe_depparse}
   */
  public CSDModel(String t, int dep){
    terms = new TreeMap<String, Set<String>>();
    modifier = null;
    this.task = t;
    fe_depparse = dep;
    // constant-first comparison: a null task must not fail here
    if("obes".equals(t))
      tokenizer = new Tokenizer(null);
    else if(dep>0){
      parser = new SyntaxParser(t);
      tokenizer = new Tokenizer(parser.getDocProcessor());
    } else
      tokenizer = new Tokenizer();
  }

  /**
   * 
   * Trains the CSD model and stores the feature set for that.
   * <p>
   * The training set is also saved under the name "mody". If training fails
   * the stack trace is printed and the classifier is left as it was.
   *
   * @param train the training data; kept as the feature set used by {@link #isAltered}
   */
  public void trainClassifier(DataHandler train){
    datahandler = train;
    train.saveDataset("mody");
    try {
      modifier = train.trainClassifier();
    } catch (DataMiningException e) {
      e.printStackTrace();
    }
  }

  /**
   * 
   * The prediction mechanism of the CSD model.
   * <p>
   * Builds a one-instance dataset over the training feature set, extracts the
   * features of the token at {@code pos} and classifies it. Requires a model
   * that was trained or restored.
   *
   * @param sentence the sentence containing the token
   * @param pos      index of the token within {@code sentence.tokens}
   * @return the predicted label of the instance; {@code false} if classification fails
   */
  public boolean isAltered(Sentence sentence, int pos){
    DataHandler inst = datahandler.createEmptyDataHandler();
    Map<String,Object> param = new HashMap<String, Object>();
    param.put("useFeatureSet", datahandler);
    inst.createNewDataset(param);
    featureExtract(inst,"X",sentence,pos);
    try {
      return inst.classifyDataset(modifier).getPredictedLabel("X");
    } catch (Exception e1) {
      e1.printStackTrace();
      return false;
    }
  }

  /**
   * Looks up the phrase ending at token {@code pos} in a term dictionary.
   * <p>
   * Phrases are tried from shortest to longest: first the token itself, then
   * it is extended to the left one token at a time (joined by a single space)
   * up to {@link #PHRASELENGTH} tokens or the start of the token list. The
   * single token is always tried, even if {@code PHRASELENGTH} is below 1.
   *
   * @param tokens   the tokens of a sentence
   * @param pos      index of the last token of the candidate phrase
   * @param termsset label -> set of phrases
   * @return the label of the first set containing a candidate phrase, or {@code null}
   */
  public static String matchTerm(List<String> tokens, int pos, Map<String, Set<String>> termsset){
    String phrase = tokens.get(pos);
    int length = 1; // number of tokens currently in phrase
    while(true){
      for(Map.Entry<String, Set<String>> entry : termsset.entrySet())
        if(entry.getValue().contains(phrase))
          return entry.getKey();
      // stop before building a phrase that would never be looked up
      if(length>=PHRASELENGTH || pos-length<0)
        return null;
      phrase = tokens.get(pos-length) + " " + phrase;
      ++length;
    }
  }

  public Model getModifier() {
    return modifier;
  }

  public Map<String, Set<String>> getTerms() {
    return terms;
  }

  /**
   * Looks up the phrase ending at token {@code pos} in this model's terms.
   *
   * @return the matching label or {@code null}, see {@link #matchTerm}
   */
  public String isTerm(List<String> tokens, int pos){
    return matchTerm(tokens, pos, terms);
  }

  /**
   * 
   * Document-level prediction.
   * @return the predicted labels of the document
   * 
   */
  public Set<String> predict(String doc){
    return predict(doc, false);
  }

  /**
   * Document-level prediction with optional local context detection.
   * <p>
   * Tokens of length 0 or 1 are skipped. Every other token that ends a known
   * term contributes the term's label, unless context detection is on and the
   * occurrence is classified as altered.
   *
   * @param doc    the raw document text
   * @param no_csd if {@code true}, the classifier is not consulted
   * @return the predicted labels of the document
   */
  public Set<String> predict(String doc, boolean no_csd){
    List<Sentence> sentences = tokenizer.tokenise(doc);
    Set<String> pred = new HashSet<String>();
    for(Sentence sent : sentences)
    {
      List<String> tokens = sent.tokens;
      for(int i=0; i<tokens.size(); ++i){
        if(tokens.get(i).length()<=1)
          continue;
        // cheap dictionary lookup first: the classifier only matters for terms
        String label = isTerm(tokens,i);
        if(label == null)
          continue;
        if(no_csd || !isAltered(sent, i))
          pred.add(label);
      }
    }
    return pred;
  }

  /**
   * 
   * Feature extraction for the local context classification.
   * <p>
   * Every token preceding {@code pos} becomes a binary feature. If the syntax
   * features are enabled, a parser is available and the token ends a known
   * term, the sentence is parsed and the following features are added: the
   * values of the target node's siblings (BROT#), the chain of governors
   * reachable from the target (DEP*) and, for even {@link #fe_depparse}, the
   * dependent chain of every node on that path that governs an nsubj (SUBJ*).
   * Errors during the syntax feature extraction are reported on stderr and end
   * the extraction; features set before the error are kept.
   *
   * @param dh       dataset receiving the features
   * @param id       instance identifier within {@code dh}
   * @param sentence the sentence containing the token
   * @param pos      index of the token within {@code sentence.tokens}
   */
  public void featureExtract(DataHandler dh,String id, Sentence sentence, int pos){
    if(pos==0)
      dh.setLabel(id, false);
    for(int i=0; i<pos; ++i) //ignores the token itself
      dh.setBinaryValue(id, sentence.tokens.get(i), true);

    //syntax-based features
    // parser is null for the "obes" task even if fe_depparse is positive
    if(fe_depparse<=0 || parser==null || isTerm(sentence.tokens,pos)==null)
      return;
    GrammaticalStructure gstruct = parser.parseSentence(sentence.original_sentence);
    if(gstruct==null) return;
    try{
      TreeGraphNode targetNode = gstruct.getNodeByIndex(pos+1);
      for(Tree b : targetNode.parent().children())
        dh.setBinaryValue(id, "BROT#"+b.value(), true);

      // computed once; the walks below only read it
      Iterable<TypedDependency> dependencies = gstruct.typedDependenciesCollapsed();
      // collapsed dependencies may contain cycles: stop when a node comes up
      // again, a second visit could only repeat features that are already set
      Set<TreeGraphNode> visited = new HashSet<TreeGraphNode>();
      while(targetNode!=null && visited.add(targetNode)){
        if(fe_depparse%2==0)//whether to use the SUBJ and SUBJD features
          addSubjectChainFeatures(dh, id, sentence, dependencies, targetNode);
        targetNode = addGovernorFeatures(dh, id, sentence, dependencies, targetNode);
      }
    }catch(Exception e){
      e.printStackTrace();
      System.err.println(sentence.original_sentence);
      System.err.println(sentence.tokens);
      return;
    }
  }

  /** Returns the feature string of a parse node: the sentence token at its index if fe_depparse is a multiple of 3, else the node value. */
  private String featureToken(Sentence sentence, TreeGraphNode node){
    return fe_depparse%3==0 ? sentence.tokens.get(node.index()-1) : node.value();
  }

  /** If {@code node} governs an nsubj, adds the SUBJ* features of the chain of first dependents starting at {@code node}. */
  private void addSubjectChainFeatures(DataHandler dh, String id, Sentence sentence,
      Iterable<TypedDependency> dependencies, TreeGraphNode node){
    boolean governsSubject = false;
    for(TypedDependency d : dependencies)
      if(d.reln().toString().equals("nsubj") && d.gov().equals(node)){
        governsSubject = true;
        break;
      }
    if(!governsSubject)
      return;

    Set<TreeGraphNode> seen = new HashSet<TreeGraphNode>();
    seen.add(node);
    TreeGraphNode q = node;
    boolean hasdep;
    do{
      hasdep = false;
      for(TypedDependency dd : dependencies)
        if(dd.gov().equals(q)){
          q=dd.dep();
          String token = featureToken(sentence, q);
          dh.setBinaryValue(id, "SUBJgov#"+token, true);
          dh.setBinaryValue(id, "SUBJrln#"+dd.reln().toString(), true);
          dh.setBinaryValue(id, "SUBJgovrln#"+dd.reln().toString()+"#"+token, true);
          // continue only from nodes not seen yet, otherwise the chain is a cycle
          hasdep = seen.add(q);
          break;
        }
    }while(hasdep);
  }

  /** Adds the DEP* features of the first dependency whose dependent is {@code node}; returns its governor, or null if there is none. */
  private TreeGraphNode addGovernorFeatures(DataHandler dh, String id, Sentence sentence,
      Iterable<TypedDependency> dependencies, TreeGraphNode node){
    for(TypedDependency d : dependencies)
    {
      if(d.dep().equals(node)){
        TreeGraphNode governor = d.gov();
        String token = featureToken(sentence, governor);
        dh.setBinaryValue(id, "DEPgov#"+token, true);
        dh.setBinaryValue(id, "DEPrln#"+d.reln().toString(), true);
        dh.setBinaryValue(id, "DEPgovrln#"+d.reln().toString()+"#"+token, true);
        return governor;
      }
    }
    return null;
  }

  /**
   * Writes the term dictionary and, if present, the classifier to a log.
   *
   * @param log the writer to print to
   */
  public void log(PrintWriter log) {
    log.println(" Terms:");
    for(Map.Entry<String, Set<String>> entry : terms.entrySet())
      log.println("  "+entry.getKey()+": "+entry.getValue());
    if(modifier != null)
    {
      log.println(" Modifier:\n"+modifier);
      modifier.printModel(log);
    }
    log.println();
  }

  public Tokenizer getTokenizer() {
    return tokenizer;
  }

  /**
   * Saves the model to a gzip-compressed object stream.
   * <p>
   * Written in this order: terms, classifier, {@link #fe_depparse},
   * {@link #PHRASELENGTH}, task, data handler. If a parser is present it is
   * asked to serialize its parsed sentences first. I/O errors are reported on
   * stderr; the file may then be incomplete.
   *
   * @param filename path of the file to write
   */
  public void serialize(String filename){
    if(parser != null)
      parser.serailizeParsedSentences();
    try {
      FileOutputStream fos = new FileOutputStream(filename);
      ObjectOutputStream oos = null;
      try {
        oos = new ObjectOutputStream(new GZIPOutputStream(fos));

        oos.writeObject(terms);
        oos.writeObject(modifier);
        oos.writeObject(fe_depparse);
        oos.writeObject(PHRASELENGTH);
        oos.writeObject(task);
        oos.writeObject(datahandler);

        oos.flush();
      } finally {
        // closing the outermost stream finishes the gzip data and closes the
        // file; if the wrappers could not be created only the file is open
        if(oos != null)
          oos.close();
        else
          fos.close();
      }
    } catch (IOException e) {
      e.printStackTrace();
    }
  }

  /**
   * Loads a model written by {@link #serialize}.
   * <p>
   * Note that this also overwrites the static {@link #PHRASELENGTH} with the
   * stored value. The tokenizer and parser are re-created from the stored task
   * and syntax feature switch.
   * <p>
   * NOTE(review): this uses Java native deserialization; only model files from
   * a trusted source should be passed in.
   *
   * @param filename path of the model file
   * @return the restored model, or {@code null} if it could not be read
   *         (the error is reported on stderr)
   */
  public static CSDModel restore(String filename){
    try {
      FileInputStream fis = new FileInputStream(filename);
      ObjectInputStream ois = null;
      try {
        ois = new ObjectInputStream(new GZIPInputStream(fis));

        CSDModel csd = new CSDModel();
        @SuppressWarnings("unchecked") // serialize() writes the term map first
        Map<String, Set<String>> storedTerms = (Map<String, Set<String>>)ois.readObject();
        csd.terms = storedTerms;
        csd.modifier = (Model)ois.readObject();
        csd.fe_depparse = (Integer)ois.readObject();
        CSDModel.PHRASELENGTH = (Integer)ois.readObject();
        csd.task = (String)ois.readObject();
        // otherwise the default tokenizer created by the constructor is kept
        if("obes".equals(csd.task))
          csd.tokenizer = new Tokenizer(null);
        else if(csd.fe_depparse>0){
          csd.parser = new SyntaxParser(csd.task);
          csd.tokenizer = new Tokenizer(csd.parser.getDocProcessor());
        }
        csd.datahandler = (DataHandler)ois.readObject();

        return csd;
      } finally {
        // a failing close must not discard a model that was read successfully
        try {
          if(ois != null)
            ois.close();
          else
            fis.close();
        } catch (IOException e) {
          e.printStackTrace();
        }
      }
    } catch (Exception e) {
      e.printStackTrace();
    }
    return null;
  }


  /**
   * Command line entry point: restores a model and prints the labels predicted
   * for a document.
   *
   * @param args model file and input file
   */
  public static void main(String[] args){
    if(args.length != 2)
    {
      System.err.println("usage: CSDModel modelfile inputfile");
      System.exit(1);
    }
    CSDModel csd = CSDModel.restore(args[0]);
    if(csd == null)
    {
      // restore() has already reported the cause
      System.err.println("could not restore model from "+args[0]);
      System.exit(1);
    }
    System.out.println(csd.predict(Util.readFileToString(args[1])));
  }
}
