///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// This file is part of ModelBlocks. Copyright 2009, ModelBlocks developers. //
//                                                                           //
//    ModelBlocks is free software: you can redistribute it and/or modify    //
//    it under the terms of the GNU General Public License as published by   //
//    the Free Software Foundation, either version 3 of the License, or      //
//    (at your option) any later version.                                    //
//                                                                           //
//    ModelBlocks is distributed in the hope that it will be useful,         //
//    but WITHOUT ANY WARRANTY; without even the implied warranty of         //
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the          //
//    GNU General Public License for more details.                           //
//                                                                           //
//    You should have received a copy of the GNU General Public License      //
//    along with ModelBlocks.  If not, see <http://www.gnu.org/licenses/>.   //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////

#include "nl-cpt.h"
#include "TextObsVars.h"

// Delimiter strings for the Delimited*RV templates used below (Q, R, S, Y).
// They are passed as non-type template arguments, so they are kept as
// namespace-scope char arrays (objects with linkage) rather than literals.
//   psX      : empty prefix/suffix
//   psSlash  : separates the two G's of a Q
//   psSemi   : separates the per-depth elements of R and S (and S's trailing G)
//   psDashDiamondDash : separates the R and S halves of a Y
char psX[]="";
char psSlash[]="/";
char psComma[]=",";
char psSemi[]=";";
char psSemiSemi[]=";;";
char psDashDiamondDash[]="-<>-";
char psTilde[]="~";
//char psBar[]="|";
//char psOpenBrace[]="{";
//char psCloseBrace[]="}";
char psLangle[]="<";
char psRangle[]=">";
char psLbrack[]="[";
char psRbrack[]="]";

// Begin/end states, presumably in the text form of S (declared below).
// NOTE(review): S is declared as four ';'-separated Q's ("G/G" each) followed by
// ';' and a G, but these strings contain only the four Q's and no trailing G --
// confirm how the S parser handles the missing component.  Note the two strings
// are currently identical.
const char* BEG_STATE = "S/end;-/-;-/-;-/-";
const char* END_STATE = "S/end;-/-;-/-;-/-";


////////////////////////////////////////////////////////////////////////////////
//
//  Random Variables
//
////////////////////////////////////////////////////////////////////////////////

//////////////////////////////////////// Simple Variables

//// D: depth (input only, to HHMM models)...
DiscreteDomain<char> domD;
//typedef DiscreteDomainRV<char,domD> D;
// Depth of an HHMM level.  The models below pass plain ints (1..5) where a D is
// expected, relying on the implicit D(int) constructor.
class D : public DiscreteDomainRV<char,domD> {
  typedef DiscreteDomainRV<char,domD> Base;   // shorthand for the parent RV type
 public:
  D ( )                : Base ( )    { }      // default value of the base RV
  D ( int i )          : Base ( i )  { }      // forwards the integer to the base RV
  D ( const char* ps ) : Base ( ps ) { }      // value from its string label
};
// Named depth constants, built from their labels.
const D D_0("0");
const D D_1("1");
const D D_2("2");
const D D_3("3");
const D D_4("4");
const D D_5("5");


//// B: boolean
DiscreteDomain<char> domB;
//typedef DiscreteDomainRV<char,domB> B;
// Two-valued flag; B_0 and B_1 are its values (G::getTerm() returns one of them).
class B : public DiscreteDomainRV<char,domB> {
  typedef DiscreteDomainRV<char,domB> Base;   // shorthand for the parent RV type
 public:
  B ( )                : Base ( )    { }      // default value of the base RV
  B ( const char* ps ) : Base ( ps ) { }      // value from its string label
};
// The two flag values.
const B B_0 ("0");
const B B_1 ("1");


//// G: constituent category...
DiscreteDomain<int> domG;
//typedef DiscreteDomainRV<int,domG> G;
class G : public DiscreteDomainRV<int,domG> {
 private:
  // Per-category cache backing getTerm(); filled by calcDetModels.
  static SimpleHash<G,B> hToTerm;
  // Classify label s once per category and store the result in hToTerm:
  //   B_0 if s starts with an ASCII uppercase letter or contains '_',
  //   B_1 otherwise (this includes "-", i.e. G_BOT, and the empty label).
  // The emptiness check keeps s[0] from being read on an empty string.
  void calcDetModels ( const string& s ) {
    if (!hToTerm.contains(*this)) {
      const bool bUpperOrCompound = ( !s.empty() && 'A'<=s[0] && s[0]<='Z' ) || s.find('_')!=string::npos;
      hToTerm.set(*this) = bUpperOrCompound ? B_0 : B_1;
    }
  }
 public:
  G ( )                : DiscreteDomainRV<int,domG> ( )    { }
  // Wraps an existing RV value; does NOT classify it (hToTerm is left untouched).
  G ( const DiscreteDomainRV<int,domG>& rv ) : DiscreteDomainRV<int,domG>(rv) { }
  // Builds the category from its label and classifies it.
  G ( const char* ps ) : DiscreteDomainRV<int,domG> ( ps ) { calcDetModels(ps); }
  //G ( string s ) : DiscreteDomainRV<int,domG> ( s )  { calcDetModels(s); }
  // Cached classification from calcDetModels (B_1 = terminal-like label).
  // NOTE(review): for a category never classified, this returns whatever
  // SimpleHash::get yields for a missing key -- confirm.
  B getTerm ( ) const { return hToTerm.get(*this); }
  // Two-step reader: `si >> g >> delim` parses the RV, then classifies the new value.
  friend pair<StringInput,G*> operator>> ( StringInput si, G& m ) { return pair<StringInput,G*>(si,&m); }
  friend StringInput operator>> ( pair<StringInput,G*> si_m, const char* psD ) {
    if ( si_m.first == NULL ) return NULL;
    StringInput si=si_m.first>>(DiscreteDomainRV<int,domG>&)*si_m.second>>psD;
    si_m.second->calcDetModels(si_m.second->getString()); return si; }
};
SimpleHash<G,B> G::hToTerm;
const G G_BOT("-");
const G G_TOP("ROOT");
const G G_RST("REST");


//// C: same as G, since no role labels...
// Pure alias: a C is a G, so it shares domG and G's terminal classification.
typedef G C;


//// F: final-state...
DiscreteDomain<int> domF;
//typedef DiscreteDomainRV<char,domF> F;
// Final-state variable of one HHMM level.  Values with toInt()>1 are linked to
// the G that has the same label (hToG); every F read from a label is also
// reachable from the G of that label (hFromG).
class F : public DiscreteDomainRV<int,domF> {
  typedef DiscreteDomainRV<int,domF> Base;   // shorthand for the parent RV type
 private:
  static SimpleHash<F,G> hToG;    // F -> G of the same label (G_BOT when toInt()<=1)
  static SimpleHash<G,F> hFromG;  // G of the same label -> F
  // Fill hToG/hFromG the first time this F is seen.  The argument is unused;
  // the label is re-read through getString().
  void calcDetModels ( string s ) {
    if ( hToG.contains(*this) ) return;
    hToG.set(*this) = (this->toInt()>1) ? G(this->getString().c_str()) : G_BOT;
    hFromG.set(G(this->getString().c_str())) = *this;
  }
 public:
  F ( )                : Base ( )    { }
  F ( const Base& rv )               : Base ( rv ) { }
  F ( const char* ps ) : Base ( ps ) { calcDetModels(ps); }
  // Look up the F whose label matches g (via hFromG).
  F ( const G& g )                                         { *this = hFromG.get(g); }
  G        getG ( )     const { return hToG.get(*this); }
  static F getF ( G g )       { return hFromG.get(g); }
  // Two-step reader: `si >> f >> delim` parses the RV, then registers the new value.
  friend pair<StringInput,F*> operator>> ( StringInput si, F& m ) { return pair<StringInput,F*>(si,&m); }
  friend StringInput operator>> ( pair<StringInput,F*> si_m, const char* psD ) {
    if ( si_m.first == NULL ) return NULL;
    StringInput siRest = si_m.first >> static_cast<Base&>(*si_m.second) >> psD;
    si_m.second->calcDetModels(si_m.second->getString());
    return siRest;
  }
};
SimpleHash<F,G> F::hToG;
SimpleHash<G,F> F::hFromG;
const F F_0("0");
const F F_1("1");
const F F_BOT("-");

//// Q: syntactic state...
// A Q is a pair of G's, apparently written "first/second" (psSlash separator,
// empty prefix/suffix).  The models below treat Q.first as the active category
// and Q.second as the awaited category.
typedef DelimitedJoint2DRV<psX,G,psSlash,G,psX> Q;
//class Q : public DelimitedJoint2DRV<psX,G,psSlash,G,psX> {
// public:
//  Q ( )                          : DelimitedJoint2DRV<psX,G,psSlash,G,psX> ( )        { }
//  Q ( char* ps )                 : DelimitedJoint2DRV<psX,G,psSlash,G,psX> ( ps )     { }
//  Q ( const G& g1, const G& g2 ) : DelimitedJoint2DRV<psX,G,psSlash,G,psX> ( g1, g2 ) { }
//};
// Q_BOT: both categories bottom (G_BOT/G_BOT).  Q_TOP: root state (G_TOP/G_RST).
// These are macros, so every use constructs a fresh temporary Q.
#define Q_BOT Q(G_BOT,G_BOT)
#define Q_TOP Q(G_TOP,G_RST)

/* //// Q: syntactic state... */
/* DiscreteDomain<int> domQ; */
/* class Q : public DiscreteDomainRV<int,domQ> { */
/*  private: */
/*   static SimpleHash<Q,G> hToAct; */
/*   static SimpleHash<Q,G> hToAwa; */
/*   void calcDetModels ( string s ) { */
/*     if (!hToAct.contains(*this)) { */
/*       size_t i=s.find('='); i=(string::npos==i)?0:i+1; */
/*       size_t j=s.find('/'); j=(string::npos==i)?0:j; */
/*       hToAct.set(*this) = G(s.substr(i,j-i).c_str()); */
/*     } */
/*     if (!hToAwa.contains(*this)) { */
/*       size_t i=s.find('/'); */
/*       size_t j=s.find('|'); */
/*       assert(i!=string::npos && j!=string::npos);  */
/*       hToAwa.set(*this) = G(s.substr(i+1,j-1-i).c_str()); */
/*     } */
/*   } */
/*  public: */
/*   Q ( )                : DiscreteDomainRV<int,domQ> ( )    { } */
/*   Q ( const char* ps ) : DiscreteDomainRV<int,domQ> ( ps ) { calcDetModels(ps); } */
/*   G getAct ( ) const { return hToAct.get(*this); } */
/*   G getAwa ( ) const { return hToAwa.get(*this); } */
/*   friend pair<StringInput,Q*> operator>> ( StringInput si, Q& m ) { return pair<StringInput,Q*>(si,&m); } */
/*   friend StringInput operator>> ( pair<StringInput,Q*> si_m, const char* psD ) { */
/*     if ( si_m.first == NULL ) return NULL; */
/*     StringInput si=si_m.first>>(DiscreteDomainRV<int,domQ>&)*si_m.second>>psD; */
/*     si_m.second->calcDetModels(si_m.second->getString()); return si; } */
/* }; */
/* SimpleHash<Q,G> Q::hToAct; */
/* SimpleHash<Q,G> Q::hToAwa; */
/* #define Q_BOT Q("-/-") */
/* #define Q_TOP Q("ROOT/REST") */


//////////////////////////////////////// Joint Variables

//// R: collection of syntactic variables at all depths in each `reduce' phase...
// Four F's separated by ';'; index i holds depth i+1 (see RModel).
typedef DelimitedJointArrayRV<4,psSemi,F> R;

//// S: collection of syntactic variables at all depths in each `shift' phase...
// `first' holds four Q's (index i = depth i+1); `second' is the extra G that
// SModel scores at depth 5, below depth 4.
class S : public DelimitedJoint2DRV<psX,DelimitedJointArrayRV<4,psSemi,Q>,psSemi,G,psX> {
 public:
  // Deepest available awaited category: `second' if it is not G_BOT, otherwise
  // the awaited G of the deepest non-Q_BOT Q among depths 4..2, else depth 1's.
  operator G()  const {
    if ( second!=G_BOT ) return second;
    for ( int i=3; i>0; --i )
      if ( first.get(i)!=Q_BOT ) return first.get(i).second;
    return first.get(0).second;
  }
  // Whole-state equality.
  bool compareFinal ( const S& s ) const { return(*this==s); }
};

//// Y: the set of all (marginalized) reduce and (modeled) shift variables in the HHMM...
// Reduce half (R) and shift half (S), joined by psDashDiamondDash ("-<>-").
class Y : public DelimitedJoint2DRV<psX,R,psDashDiamondDash,S,psX> {
 public:
  // Reduce-phase half.
  operator R() const { return this->first;  }
  // Shift-phase half.
  operator S() const { return this->second; }
//  operator D() const {return ( (first.get(3)==F_0) ? D_4 :
//                               (first.get(2)==F_0) ? D_3 :
//                               (first.get(1)==F_0) ? D_2 :
//                               (first.get(0)==F_0) ? D_1 : D_0 );}
};


////////////////////////////////////////////////////////////////////////////////
//
//  Models
//
////////////////////////////////////////////////////////////////////////////////

//////////////////////////////////////// "Wrapper" models for individual RVs...

//// Model of F given D and F and Q (from above) and Q (from previous)
class FModel {
 private:
  HidVarCPT4DModel<F,D,G,G,LogProb> mFr;            // Reduction model: F giv D, G (active cat from prev), G (awaited cat from above) (assume prev awa = reduc)
  static const HidVarCPT1DModel<F,LogProb> mF_0;    // Fixed F_ZERO model.
  static const HidVarCPT1DModel<F,LogProb> mF_BOT;  // Fixed F_BOT model.
 public:
  //static bool F_ROOT_OBS;
  // Score (and advance the iterator for) the F at depth d.
  //   f  : array iterator being set for this depth's F
  //   fD : F already chosen at the depth below
  //   qP : previous Q at this depth;  qU : previous Q one level above
  //   b1 : root observation; at depth 1 (when a>=-1) a candidate whose (F>F_1)
  //        status disagrees with b1 is given LogProb()
  //   a  : iteration control passed through to the CPT models; "no condition"
  //        diagnostics are printed only when a<-1
  LogProb setIterProb ( F::ArrayIterator<LogProb>& f, const D& d, const F& fD, const Q& qP, const Q& qU, bool b1, int& a ) const {
    const bool bBelowIsBot = ( fD==F_BOT );
    LogProb prResult;
    if ( bBelowIsBot && qP.second==G_BOT ) {
      // >1 (bottom) case: fixed model built from F_BOT.
      prResult = mF_BOT.setIterProb(f,a);
    } else if ( bBelowIsBot && qP.second.getTerm()==B_1 ) {
      // >1 (middle) case: awaited category is terminal-like, so the reduction is modeled.
      prResult = mFr.setIterProb(f,d,qP.first,qU.second,a);
      if ( a<-1 && prResult==LogProb() ) cerr<<"\nERROR: no condition Fr "<<d<<" "<<qP.first<<" "<<qU.second<<"\n\n";
    } else {
      // 0 (top) case: fixed model built from F_0.
      prResult = mF_0.setIterProb(f,a);
    }
    // Force another iteration when the result contradicts the root observation.
    const bool bRootMismatch = ( a>=-1 && d==D_1 && b1!=(F(f)>F_1) );
    if ( bRootMismatch ) prResult = LogProb();
    return prResult;
  }
  // Reader: loads the single "Fr " reduction table.
  friend pair<StringInput,FModel*> operator>> ( StringInput si, FModel& m ) { return pair<StringInput,FModel*>(si,&m); }
  friend StringInput operator>> ( pair<StringInput,FModel*> si_m, const char* psD ) {
    StringInput siRest = si_m.first>>"Fr " >>si_m.second->mFr >>psD;
    if ( siRest!=NULL ) return siRest;
    return StringInput(NULL);
  }
};
const HidVarCPT1DModel<F,LogProb> FModel::mF_0(F_0);
const HidVarCPT1DModel<F,LogProb> FModel::mF_BOT(F_BOT);
//bool FModel::F_ROOT_OBS = false;


//// Model of Q given D and F and F and Q(from prev) and Q(from above)
class QModel {
 public:
  HidVarCPT3DModel<G,D,G,LogProb> mGe;     // Expansion model of G given D, G (from above)
 private:
  HidVarCPT4DModel<G,D,G,G,LogProb> mGtm;  // Awaited transition model of G (active) given D, G (awaited cat from prev), G (from reduction)
  HidVarCPT4DModel<G,D,G,G,LogProb> mGtp;  // Active Transition model of G (active) given D, G (active cat from prev), G (awaited cat from above)
  HidVarCPT4DModel<G,D,G,G,LogProb> mGtq;  // Active Transition completion of G (awaited) given D, G (active cat from curr), G (active cat from prev)
  static const HidVarCPT1DModel<G,LogProb>   mG_BOT;   // Fixed G_BOT model.
  static       HidVarCPT2DModel<G,G,LogProb> mG_COPY;  // Cached F_COPY model  --  WARNING: STATIC NON-CONST is not thread safe!
 public:
  LogProb setIterProb ( Q::ArrayIterator<LogProb>& q, const D& d, const F& fD, const F& f, const Q& qP, const Q& qU, int& a ) const {
    LogProb pr,p;
    if (fD>F_1) {
      if (f>=F_1) {
        if (f>F_1) {
          if (qU.second.getTerm()==B_1 || qU==Q_BOT) {
            ////cerr<<"a\n";
            // >1 >1 (expansion to null) case:
            pr  = mG_BOT.setIterProb(q.first, a);
            pr *= mG_BOT.setIterProb(q.second,a);
          }
          else {
            ////cerr<<"b\n";
            // >1 >1 (expansion) case:
            pr  = p = mGe.setIterProb(q.first,d,qU.second,a);
            if ( a<-1 && p==LogProb() ) cerr<<"\nERROR: no condition Ge "<<d<<" "<<qU.second<<"\n\n";
            if ( !mG_COPY.contains(G(q.first)) ) mG_COPY.setProb(G(q.first),G(q.first))=1.0;
            pr *= p = mG_COPY.setIterProb(q.second,G(q.first),a);
          }
        }
        else {
          ////cerr<<"c\n";
          // >1 1 ('plus' transition following reduction) case:
          pr  = p = mGtp.setIterProb(q.first,d,qP.first,qU.second,a);
          if ( a<-1 && p==LogProb() ) cerr<<"\nERROR: no condition Gtp "<<d<<" "<<qP.first<<" "<<qU.second<<"\n\n";
          pr *= p = mGtq.setIterProb(q.second,d,G(q.first),qP.first,a);
          if ( a<-1 && p==LogProb() ) cerr<<"\nERROR: no condition Gtq "<<d<<" "<<G(q.first)<<" "<<qP.first<<"\n\n";
        }
      }
      else {
        ////cerr<<"d\n";
        // >1 0 ('minus' transition without reduction) case:
        if ( !mG_COPY.contains(qP.first) ) mG_COPY.setProb(qP.first,qP.first)=1.0;
        pr  = p = mG_COPY.setIterProb(q.first,qP.first,a);
        pr *= p = mGtm.setIterProb(q.second,d,fD.getG(),qP.second,a);
        if ( a<-1 && p==LogProb() ) cerr<<"\nERROR: no condition Gtm "<<d<<" "<<fD.getG()<<" "<<qP.second<<"\n\n";
      }
    }
    else {
      ////cerr<<"e\n";
      // <=1 0 (copy) case:
      if ( !mG_COPY.contains(qP.first) )  mG_COPY.setProb(qP.first, qP.first )=1.0;
      pr  = p = mG_COPY.setIterProb(q.first, qP.first, a);
      if ( !mG_COPY.contains(qP.second) ) mG_COPY.setProb(qP.second,qP.second)=1.0;
      pr *= p = mG_COPY.setIterProb(q.second,qP.second,a);
    }
    ////cerr<<"    Q "<<d<<" "<<fD<<" "<<f<<" "<<qP<<" "<<qU<<" : "<<q<<" = "<<pr<<" ("<<a<<")\n";
    return pr;
  }
  friend pair<StringInput,QModel*> operator>> ( StringInput si, QModel& m ) { return pair<StringInput,QModel*>(si,&m); }
  friend StringInput operator>> ( pair<StringInput,QModel*> si_m, const char* psD ) {
    StringInput si;
    return ( (si=si_m.first>>"Gtm ">>si_m.second->mGtm>>psD)!=NULL ||
             (si=si_m.first>>"Gtp ">>si_m.second->mGtp>>psD)!=NULL ||
             (si=si_m.first>>"Gtq ">>si_m.second->mGtq>>psD)!=NULL ||
             (si=si_m.first>>"Ge ">>si_m.second->mGe>>psD)!=NULL ) ? si : StringInput(NULL);
  }
};
const HidVarCPT1DModel<G,LogProb> QModel::mG_BOT(G_BOT);
HidVarCPT2DModel<G,G,LogProb> QModel::mG_COPY;


//////////////////////////////////////// Joint models...

//////////////////// Reduce phase...

//// Model of R given S
// Scores the four reduce-phase F's bottom-up (depth 4 first).  Each depth is
// conditioned on the F chosen just below it, the previous Q at that depth, and
// the previous Q one level up (Q_TOP above depth 1).
class RModel : public SingleFactoredModel<FModel> {
 public:
  LogProb setIterProb ( R::ArrayIterator<LogProb>& r, const S& sP, bool b1, int& a ) const {
    const FModel& mF = getM1();
    // Depth 4: the "F below" is derived from the previous depth-5 G.
    LogProb pr = mF.setIterProb ( r.set(3), 4, F(sP.second), sP.first.get(3), sP.first.get(2), b1, a );
    // Depths 3, 2, 1 (array index i holds depth i+1).
    for ( int i=2; i>=0; --i ) {
      const Q qAbove = ( i>0 ) ? Q(sP.first.get(i-1)) : Q_TOP;
      pr *= mF.setIterProb ( r.set(i), i+1, F(r.get(i+1)), sP.first.get(i), qAbove, b1, a );
    }
    return pr;
  }
};


//////////////////// Shift phase...

//// Model of S given R and S
// Scores the shift-phase Q's top-down (depth 1 first), each conditioned on the
// Q just chosen one level above, then the extra G at depth 5.
class SModel : public SingleFactoredModel<QModel> {
 private:
  static const HidVarCPT1DModel<G,LogProb>   mG_BOT;   // Fixed G_BOT model.
 public:
  LogProb setIterProb ( S::ArrayIterator<LogProb>& s, const R::ArrayIterator<LogProb>& r, const S& sP, int& a ) const {
    const QModel& mQ = getM1();
    LogProb pr;
    pr  = mQ.setIterProb ( s.first.set(1-1), 1, F(r.get(2-1)), F(r.get(1-1)), sP.first.get(1-1), Q_TOP                        ,a );
    pr *= mQ.setIterProb ( s.first.set(2-1), 2, F(r.get(3-1)), F(r.get(2-1)), sP.first.get(2-1), Q().setVal(s.first.set(1-1)) ,a );
    pr *= mQ.setIterProb ( s.first.set(3-1), 3, F(r.get(4-1)), F(r.get(3-1)), sP.first.get(3-1), Q().setVal(s.first.set(2-1)) ,a );
    // Depth 4 has no F below it in r; use the previous depth-5 G, converted
    // explicitly to F exactly as RModel does for the same depth.
    pr *= mQ.setIterProb ( s.first.set(4-1), 4, F(sP.second),  F(r.get(4-1)), sP.first.get(4-1), Q().setVal(s.first.set(3-1)) ,a );
    // Depth 5: expand depth 4's awaited category unless it is bottom or terminal-like.
    const G gAwa4 ( s.first.set(4-1).second );
    pr *= ( gAwa4!=G_BOT && gAwa4.getTerm()!=B_1 )
      ? mQ.mGe.setIterProb ( s.second, 5, gAwa4, a )
      : mG_BOT.setIterProb ( s.second, a );
    return pr;
  }
};
const HidVarCPT1DModel<G,LogProb> SModel::mG_BOT(G_BOT);


//////////////////// Overall...

//// Model of Y=R,S given S
// Joint reduce+shift model: scores R with RModel, then (unless that already
// gave LogProb()) scores S with SModel and multiplies the two.
class YModel : public DoubleFactoredModel<RModel,SModel> {
 public:
  //static X WORD;
  //static bool& F_ROOT_OBS;
  typedef Y::ArrayIterator<LogProb> IterVal;
  // Store the shift-phase half of y into s; returns s.
  S& setTrellDat ( S& s, const Y::ArrayIterator<LogProb>& y ) const {
    s.setVal ( y.second );
    return s;
  }
  // Copy the four reduce-phase F's of y into a fresh R.
  R setBackDat ( const Y::ArrayIterator<LogProb>& y ) const {
    R rOut;
    for ( int depth=0; depth<4; ++depth )
      rOut.set(depth) = F ( y.first.get(depth) );
    return rOut;
  }
  // Score y given the previous shift state sP.  The observation x is not used here.
  LogProb setIterProb ( Y::ArrayIterator<LogProb>& y, const S& sP, const X& x, bool b1, int& a ) const {
    const RModel& reduceModel = getM1();
    const SModel& shiftModel  = getM2();
    LogProb pr = reduceModel.setIterProb ( y.first, sP, b1, a );
    // Skip the shift phase when the reduce phase already returned LogProb().
    if ( LogProb()==pr ) return pr;
    pr *= shiftModel.setIterProb ( y.second, y.first, sP, a );
    return pr;
  }
  void update ( ) const { }
};
//X     YModel::WORD;
//bool& YModel::F_ROOT_OBS = FModel::F_ROOT_OBS;
