///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// This file is part of ModelBlocks. Copyright 2009, ModelBlocks developers. //
//                                                                           //
//    ModelBlocks is free software: you can redistribute it and/or modify    //
//    it under the terms of the GNU General Public License as published by   //
//    the Free Software Foundation, either version 3 of the License, or      //
//    (at your option) any later version.                                    //
//                                                                           //
//    ModelBlocks is distributed in the hope that it will be useful,         //
//    but WITHOUT ANY WARRANTY; without even the implied warranty of         //
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the          //
//    GNU General Public License for more details.                           //
//                                                                           //
//    You should have received a copy of the GNU General Public License      //
//    along with ModelBlocks.  If not, see <http://www.gnu.org/licenses/>.   //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////

#include <sstream>
#include "nl-cpt.h"
#include "nl-dtree.h"

////////////////////////////////////////////////////////////////////////////////
//
//  Random Variables
//
////////////////////////////////////////////////////////////////////////////////


//// P: part of speech category...
// domainP is the global value domain that backs the P random-variable type;
// P values are stored as short.
DiscreteDomain<short> domainP;
typedef DiscreteDomainRV<short,domainP> P;

// W: word (full word)
// domainW is the global value domain that backs the W random-variable type;
// W values are stored as int.
DiscreteDomain<int> domainW;
typedef DiscreteDomainRV<int,domainW> W;

////////////////////////////////////////////////////////////////////////////////
//
//  Models
//
////////////////////////////////////////////////////////////////////////////////

//// Preterminal (POS) given constituent category models...
// NOTE: G (the constituent category type) is not declared in this file; it
// must already be in scope when this file is compiled.
typedef HidVarCPT2DModel<P,G,LogProb> PgivGModel;

// UNK: unknown-word signature.  Values are constructed from the strings
// produced by Unk2DModel::buildSignature and are stored as int.
DiscreteDomain<int> domainUnk;
typedef DiscreteDomainRV<int,domainUnk> UNK;

class WModel; // forward declaration

//// Unknown-word model: a table over (UNK signature, POS) pairs that is
//// queried with a raw word by first mapping the word to its signature.
class Unk2DModel : public RandAccCPT2DModel<UNK,P,LogProb>{
  public:
  // Word-to-signature table (not written by any live code in this class).
  static  SimpleHash<W,UNK> wordsToUnks;

  // Single-entry cache maintained by buildSignature: the most recently
  // computed signature, and the word / location it was computed for.
  static const char* lastUnk;
  static W lastW;
  static int lastLoc;

  // Looks up the probability of word `w`, reduced to its signature, with
  // POS `p`.
  //   w    : the word to score
  //   p    : POS iterator value handed straight to the base-class lookup
  //   loc  : word position, forwarded to buildSignature
  //   modW : known-word model, forwarded to buildSignature
  // The base-class result is returned unchanged; a result equal to LogProb()
  // means the table had no match for this signature/POS pair, and no
  // diagnostic is emitted in that case.
  LogProb getProb(const W& w, const P::ArrayIterator<LogProb>& p, int loc, const RandAccCPT1DModel<W,LogProb>& modW) const {
    const UNK signature(buildSignature(w,loc,modW));
    return RandAccCPT2DModel<UNK,P,LogProb>::getProb(signature, p);
  }

  // Computes (or fetches from the single-entry cache) the signature string
  // for `w` at position `loc`.  Defined out of line.
  const char * buildSignature(const W& w, int loc, const RandAccCPT1DModel<W,LogProb>& modW) const ;
};

// Storage for Unk2DModel's static members.
// wordsToUnks: word-to-signature table (no live code in this file writes it).
SimpleHash<W,UNK> Unk2DModel::wordsToUnks;
// lastUnk / lastW / lastLoc: single-entry cache read and written by
// Unk2DModel::buildSignature.  No explicit initializers are given, so the
// pointer starts out null and the location starts out 0.
const char * Unk2DModel::lastUnk;
W Unk2DModel::lastW;
int Unk2DModel::lastLoc;

//// Word-given-tag model: scores a word under a POS tag, using stored
//// statistics for known words and the UNK-signature model for all others.
class WModel {
 private:
  // Signature-based model consulted when a word is absent from modW.
  Unk2DModel modWunkgivP;

  // Tables filled from the "Pw ", "P " and "W " input records respectively.
  RandAccCPT2DModel<P,W,LogProb> modPgivWs;
  RandAccCPT1DModel<P,LogProb> modP;
  RandAccCPT1DModel<W,LogProb> modW;

 public:
  // Returns the score of word `w` under POS `p`; `loc` is the word position
  // and is only used on the unknown-word path.
  // Asserts that modP holds a non-default probability for `p`.
  LogProb getProb ( const W& w, const P::ArrayIterator<LogProb>& p, int loc) const {
    assert(modP.getProb(p)!=LogProb());
    // Unknown word: defer to the signature model.
    if ( !modW.contains(w) ) {
      return modWunkgivP.getProb(w,p,loc,modW);
    }
    // Known word: P(p|w) scaled by a constant and divided by P(p).
    // NOTE(review): LogProb(-1000) sits where Bayes' rule would use the word
    // prior -- confirm this constant is intentional.
    return modPgivWs.getProb(p,w) * LogProb(-1000) / modP.getProb(p);
  }

  // Currently writes nothing; both parameters are unused.
  void writeFields ( FILE* pf, string sPref ) { /*modWunkgivP.writeFields(pf,sPref);*/ }

  // First half of the two-step read idiom: pairs the input with the target
  // model so that the following `>> delimiter` can perform the actual read.
  friend pair<StringInput,WModel*> operator>> ( StringInput si, WModel& m ) {
    return pair<StringInput,WModel*>(si,&m);
  }

  // Second half: tries each record prefix in turn ("W ", "Pw ", "UNK ", "P ")
  // and returns the input position after the first one that reads
  // successfully, or StringInput(NULL) when none matches.
  friend StringInput operator>> ( pair<StringInput,WModel*> delimbuff, const char* psD ) {
    StringInput& src = delimbuff.first;
    WModel&      m   = *delimbuff.second;
    StringInput  si;
    if ( (si = src>>"W "  >>m.modW       >>psD) != NULL ) return si;
    if ( (si = src>>"Pw " >>m.modPgivWs  >>psD) != NULL ) return si;
    if ( (si = src>>"UNK ">>m.modWunkgivP>>psD) != NULL ) return si;
    if ( (si = src>>"P "  >>m.modP       >>psD) != NULL ) return si;
    return StringInput(NULL);
  }
};


const char* Unk2DModel::buildSignature(const W& w, int loc, const RandAccCPT1DModel<W,LogProb>& modW) const
{
    // check cache...
    if(w.getString()==lastW.getString() && loc==lastLoc){
      return lastUnk;
    }
    // Create the signature
    stringstream buf(stringstream::in | stringstream::out);
    buf << "UNK";
      
    const char *str = w.getString().c_str();
    int wlen = strlen(str);
    int numCaps = 0;
    bool hasDigit = false;
    bool hasDash = false;
    bool hasLower = false;
    stringstream lowered(stringstream::in|stringstream::out);
    for(int i = 0; i < wlen; i++){
      if(str[i] >= 48 && str[i] <= 57){
        hasDigit = true;
        lowered << str[i];
      }else if(str[i] == '-'){
        hasDash = true;
        lowered << str[i];
      }else if(str[i] >= 65 && str[i] <= 90){
        numCaps++;
        lowered << (str[i]+26);
      }else if(str[i] >= 91 && str[i] <= 122){
        hasLower = true;
        lowered << str[i];
      }
    }
    char c = str[0];
   
    if(c >= 65 && c <= 90){
      if(loc == 0 && numCaps == 1){
        buf << "-INITC";
        if(modW.contains(lowered.str().c_str())){
          buf << "-KNOWNLC";
        }
      }else{
        buf << "-CAPS";
      }
    }else if(numCaps > 0){
      buf << "-CAPS";
    }else if(hasLower){
      buf << "-LC";
    }
    if(hasDigit){
      buf << "-NUM";
    }
    if(hasDash){
      buf << "-DASH";
    }
    const char *l = lowered.str().c_str();
    if(l[wlen-1] == 's'){
      char ch2 = l[wlen-2];
      if(ch2 != 's' && ch2 != 'i' && ch2 != 'u'){
        buf << "-s";
      }
    }else if(wlen >= 5 && !hasDash && !(hasDigit && numCaps > 0)){
      if(l[wlen-1]=='d' && l[wlen-2]){ 
        buf << "-ed";
      }else if(l[wlen-1]=='g' && l[wlen-2]=='n' && l[wlen-3]=='i'){
        buf << "-ing";
      }else if(l[wlen-1]=='n' && l[wlen-2]=='o' && l[wlen-3]=='i'){
        buf << "-ion";
      }else if(l[wlen-1]=='r' && l[wlen-2]=='e'){
        buf << "-er";
      }else if(l[wlen-1]=='t' && l[wlen-2]=='s' && l[wlen-3]=='e'){
        buf << "-est";
      }else if(l[wlen-1]=='y' && l[wlen-2]=='l'){
        buf << "-ly";
      }else if(l[wlen-1]=='y' && l[wlen-2]=='t' && l[wlen-3]=='i'){
        buf << "-ity";
      }else if(l[wlen-1]=='y'){
        buf << "-y";
      }else if(l[wlen-1]=='l' && l[wlen-2]=='a'){
        buf << "-al";
      }
    }
    // get UNK signature for this word and add to my hash
    //wordsToUnks.set(w) = UNK(buf.str().c_str());
//    UNK ret = UNK(buf.str().c_str());
    lastUnk = buf.str().c_str();
    lastLoc = loc;
    lastW = w;
    cerr << "Found signature for word: " << w << " : " << lastUnk << endl;
    return lastUnk;
}

//// Wrapper class for model
class OModel {

 private:

  PgivGModel modPgivG;
  WModel     modWgivP;
  static int word_num;
  
 public:
  
  class DistribModeledWordgivG {
   private:
    SimpleHash<G,Prob> hcpCache;
   public:
    DistribModeledWordgivG& set ( const W& w, const OModel& m ) { 
      word_num++;
      m.calcProb(*this,w); 
      return *this; 
    }
    /*
      DistribModeledWgivG& operator= ( const W& w ) {
      hcpCache.clear();
      for ( PgivGModel::const_iterator iter = modPgivG.begin(); iter!=modPgivG.end(); iter++ ) {
      G c = iter->first.getX1();
      PgivGModel::IterVal p;
      for ( bool bp=modPgivG.setFirst(p,c); bp; bp=modPgivG.setNext(p,c) ) {
      hcpCache.set(c) += modPgivG.getProb(p,c).toProb() * modWgivP.getProb(w,p).toProb();
      }
      }
      return *this;
      }
    */
    void    clear   ( )                     { hcpCache.clear(); }
    Prob&   setProb ( const G& g )       { return hcpCache.set(g); }
    LogProb getProb ( const G& g ) const { return LogProb(hcpCache.get(g)); }
  };

  typedef DistribModeledWordgivG RandVarType;

//  void getWordSig(const W& w) const {
//    if(!modWgivP.hasWord(w)){
//      modWgivP.calcUnkSignature(w, word_num);
//    }  
//  }
  
  void calcProb ( OModel::RandVarType& o, const W& w ) const {
    o.clear();
    for ( PgivGModel::const_iterator iter = modPgivG.begin(); iter!=modPgivG.end(); iter++ ) {
      G g = iter->first.getX1();
      //PgivGModel::IterVal p;
      P::ArrayIterator<LogProb> p;
      int aCtr=-1;
      //for ( bool bp=modPgivG.setIterProb(p,g,aCtr); bp; bp=modPgivG.setIterProb(p,g,aCtr=0) ) {
      for (LogProb pr=modPgivG.setIterProb(p,g,aCtr); pr!=LogProb(); pr = modPgivG.setIterProb(p,g,aCtr=0) ){
        o.setProb(g) += modPgivG.getProb(p,g).toProb() * modWgivP.getProb(w,p,word_num).toProb();
      }
     
    }
  }

  LogProb getProb ( const OModel::RandVarType& o, const G& g ) const { return o.getProb(g); }

  friend pair<StringInput,OModel*> operator>> ( StringInput si, OModel& m ) { return pair<StringInput,OModel*>(si,&m); }
  friend StringInput operator>> ( pair<StringInput,OModel*> delimbuff, const char* psD ) {
    StringInput si;
    return ( (si=delimbuff.first>>"Pg ">>delimbuff.second->modPgivG>>psD)!=NULL ||
             (si=delimbuff.first>>       delimbuff.second->modWgivP>>psD)!=NULL ) ? si : StringInput(NULL);
  }

  void writeFields ( FILE* pf, string sPref ) { modWgivP.writeFields(pf,sPref); }
};

// Starts at -1 because DistribModeledWordgivG::set increments it before the
// word is scored, so the first word is scored with word_num == 0.
int OModel::word_num = -1;
