///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// This file is part of ModelBlocks. Copyright 2009, ModelBlocks developers. //
//                                                                           //
//    ModelBlocks is free software: you can redistribute it and/or modify    //
//    it under the terms of the GNU General Public License as published by   //
//    the Free Software Foundation, either version 3 of the License, or      //
//    (at your option) any later version.                                    //
//                                                                           //
//    ModelBlocks is distributed in the hope that it will be useful,         //
//    but WITHOUT ANY WARRANTY; without even the implied warranty of         //
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the          //
//    GNU General Public License for more details.                           //
//                                                                           //
//    You should have received a copy of the GNU General Public License      //
//    along with ModelBlocks.  If not, see <http://www.gnu.org/licenses/>.   //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////

#include "nl-cpt.h"

// Delimiter strings for reading/writing joint random variables.  Several of them are
// passed as template arguments below (e.g. DelimitedJointArrayRV<4,psSemi,F>,
// DelimitedJoint2DRV<psX,...,psSemi,G,psX>), so keep them as namespace-scope char
// arrays with exactly this type -- do not turn them into pointers or std::strings.
char psX[]="";
char psSlash[]="/";
char psComma[]=",";
char psSemi[]=";";
char psSemiSemi[]=";;";
char psDashDiamondDash[]="-<>-";
char psTilde[]="~";
//char psBar[]="|";
//char psOpenBrace[]="{";
//char psCloseBrace[]="}";
char psLangle[]="<";
char psRangle[]=">";
char psLbrack[]="[";
char psRbrack[]="]";

// Serialized boundary states (used elsewhere, not in this section).
// NOTE(review): BEG_STATE and END_STATE are identical; presumably intentional -- confirm.
// NOTE(review): S below appears to serialize as four ';'-separated Q values followed by
// ';' and a G, but these strings contain only four fields -- verify they parse as
// intended at their point of use.
const char* BEG_STATE = "S/end;-/-;-/-;-/-";
const char* END_STATE = "S/end;-/-;-/-;-/-";


////////////////////////////////////////////////////////////////////////////////
//
//  Random Variables
//
////////////////////////////////////////////////////////////////////////////////

//////////////////////////////////////// Simple Variables

//// D: depth (input only, to HHMM models)...
DiscreteDomain<char> domD;
// Depth random variable.  The named depths D_0..D_5 are defined right after the class.
class D : public DiscreteDomainRV<char,domD> {
 private:
  typedef DiscreteDomainRV<char,domD> ParentRV;
 public:
  D ( )                : ParentRV ( )    { }
  // Construct from an integer (presumably a domain index -- see DiscreteDomainRV).
  D ( int i )          : ParentRV ( i )  { }
  // Construct from the value's string form, e.g. "3".
  D ( const char* ps ) : ParentRV ( ps ) { }
};
const D D_0 ( "0" );
const D D_1 ( "1" );
const D D_2 ( "2" );
const D D_3 ( "3" );
const D D_4 ( "4" );
const D D_5 ( "5" );


//// B: boolean
DiscreteDomain<char> domB;
// Two-valued flag; B_0 and B_1 (below) are its named values.
class B : public DiscreteDomainRV<char,domB> {
 private:
  typedef DiscreteDomainRV<char,domB> ParentRV;
 public:
  B ( )                : ParentRV ( )    { }
  B ( const char* ps ) : ParentRV ( ps ) { }
};
const B B_0 ( "0" );
const B B_1 ( "1" );


//// G: constituent category...
DiscreteDomain<int> domG;
// Constituent category.  The first time a category is built from (or read as) a string,
// it is classified as B_0 if the string starts with an ASCII uppercase letter and B_1
// otherwise; getTerm() returns that flag (B_1 is treated as "terminal" by the models below).
class G : public DiscreteDomainRV<int,domG> {
 private:
  typedef DiscreteDomainRV<int,domG> ParentRV;
  static SimpleHash<G,B> hToTerm;   // category -> terminal flag (B_0 / B_1)
  // Record the terminal flag for this category, unless it was recorded already.
  void calcDetModels ( string s ) {
    if ( hToTerm.contains(*this) ) return;
    const bool bUpperInitial = ( s[0]>='A' && s[0]<='Z' );
    hToTerm.set(*this) = bUpperInitial ? B_0 : B_1;
  }
 public:
  G ( )                     : ParentRV ( )    { }
  G ( const ParentRV& rv )  : ParentRV ( rv ) { }
  G ( const char* ps )      : ParentRV ( ps ) { calcDetModels(ps); }
  B getTerm ( ) const { return hToTerm.get(*this); }
  // Reading a G: parse the underlying value, then classify the resulting category.
  friend pair<StringInput,G*> operator>> ( StringInput si, G& m ) { return pair<StringInput,G*>(si,&m); }
  friend StringInput operator>> ( pair<StringInput,G*> si_m, const char* psD ) {
    if ( si_m.first == NULL ) return NULL;
    G& gTarget = *si_m.second;
    StringInput si = si_m.first >> static_cast<ParentRV&>(gTarget) >> psD;
    gTarget.calcDetModels(gTarget.getString());
    return si;
  }
};
SimpleHash<G,B> G::hToTerm;
const G G_BOT ( "-" );
const G G_TOP ( "ROOT" );
const G G_RST ( "REST" );


//// F: final-state...
DiscreteDomain<int> domF;
// Final-state (reduce) variable.  The first time an F is built from (or read as) a string,
// it is linked to the G category with the same string: getF()/F(const G&) look up that
// link, and getG() returns the category (or G_BOT when toInt() <= 1).
class F : public DiscreteDomainRV<int,domF> {
 private:
  typedef DiscreteDomainRV<int,domF> ParentRV;
  static SimpleHash<F,G> hToG;     // F -> same-named category, or G_BOT when toInt()<=1
  static SimpleHash<G,F> hFromG;   // category -> F with the same string
  // Fill both lookup tables for this F, unless already done.  The string argument is
  // not consulted; the value's own getString() is used.
  void calcDetModels ( string s ) {
    if ( hToG.contains(*this) ) return;
    const G gSameName ( this->getString().c_str() );
    hToG.set(*this)       = ( this->toInt() > 1 ) ? gSameName : G_BOT;
    hFromG.set(gSameName) = *this;
  }
 public:
  F ( )                     : ParentRV ( )    { }
  F ( const ParentRV& rv )  : ParentRV ( rv ) { }
  F ( const char* ps )      : ParentRV ( ps ) { calcDetModels(ps); }
  F ( const G& g )                            { *this = hFromG.get(g); }
  G        getG ( )     const { return hToG.get(*this); }
  static F getF ( G g )       { return hFromG.get(g); }
  // Reading an F: parse the underlying value, then fill the lookup tables.
  friend pair<StringInput,F*> operator>> ( StringInput si, F& m ) { return pair<StringInput,F*>(si,&m); }
  friend StringInput operator>> ( pair<StringInput,F*> si_m, const char* psD ) {
    if ( si_m.first == NULL ) return NULL;
    F& fTarget = *si_m.second;
    StringInput si = si_m.first >> static_cast<ParentRV&>(fTarget) >> psD;
    fTarget.calcDetModels(fTarget.getString());
    return si;
  }
};
SimpleHash<F,G> F::hToG;
SimpleHash<G,F> F::hFromG;
const F F_0 ( "0" );
const F F_1 ( "1" );
const F F_BOT ( "-" );


/* //// Q: syntactic state... */
/* typedef DelimitedJoint2DRV<psX,G,psSlash,G,psX> Q; */
/* //class Q : public DelimitedJoint2DRV<psX,G,psSlash,G,psX> { */
/* // public: */
/* //  Q ( )                                                   : DelimitedJoint2DRV<psX,G,psSlash,G,psX> ( )    { } */
/* //  Q ( char* ps )                                          : DelimitedJoint2DRV<psX,G,psSlash,G,psX> ( ps ) { } */
/* //  Q ( const G& g1, const G& g2 )                          : DelimitedJoint2DRV<psX,G,psSlash,G,psX> ( g1, g2 ) { } */
/* //}; */
/* #define Q_BOT Q(G_BOT,G_BOT) */
/* #define Q_TOP Q(G_TOP,G_RST) */


//// Q: syntactic state, written "ACT/AWA" (active category / awaited category)...
DiscreteDomain<int> domQ;
class Q : public DiscreteDomainRV<int,domQ> {
 private:
  typedef DiscreteDomainRV<int,domQ> ParentRV;
  static SimpleHash<Q,G> hToAct;   // Q -> category before the first '/'
  static SimpleHash<Q,G> hToAwa;   // Q -> category after the first '/'
  // Split s at its first '/' and cache each half as a G, unless already cached.
  // Asserts that s contains a '/'.
  void calcDetModels ( string s ) {
    const size_t iSlash = s.find('/');
    if ( !hToAct.contains(*this) ) {
      assert ( iSlash != string::npos );
      hToAct.set(*this) = G ( s.substr(0,iSlash).c_str() );
    }
    if ( !hToAwa.contains(*this) ) {
      assert ( iSlash != string::npos );
      hToAwa.set(*this) = G ( s.substr(iSlash+1).c_str() );
    }
  }
 public:
  Q ( )                     : ParentRV ( )    { }
  Q ( const ParentRV& rv )  : ParentRV ( rv ) { }
  Q ( const char* ps )      : ParentRV ( ps ) { calcDetModels(ps); }
  G getAct ( ) const { return hToAct.get(*this); }
  G getAwa ( ) const { return hToAwa.get(*this); }
  // Reading a Q: parse the underlying value, then split it into its two categories.
  friend pair<StringInput,Q*> operator>> ( StringInput si, Q& m ) { return pair<StringInput,Q*>(si,&m); }
  friend StringInput operator>> ( pair<StringInput,Q*> si_m, const char* psD ) {
    if ( si_m.first == NULL ) return NULL;
    Q& qTarget = *si_m.second;
    StringInput si = si_m.first >> static_cast<ParentRV&>(qTarget) >> psD;
    qTarget.calcDetModels(qTarget.getString());
    return si;
  }
};
SimpleHash<Q,G> Q::hToAct;
SimpleHash<Q,G> Q::hToAwa;
#define Q_BOT Q("-/-")
#define Q_TOP Q("ROOT/REST")


//////////////////////////////////////// Joint Variables

//// R: collection of syntactic variables at all depths in each `reduce' phase...
// Four F values, element i holding depth i+1 (see RModel); presumably serialized with
// psSemi (';') between elements.
typedef DelimitedJointArrayRV<4,psSemi,F> R;


//// S: collection of syntactic variables at all depths in each `shift' phase...
//// (first: four Q values, element i holding depth i+1; second: the depth-5 G)
class S : public DelimitedJoint2DRV<psX,DelimitedJointArrayRV<4,psSemi,Q>,psSemi,G,psX> {
 public:
  // Lowest active category: the depth-5 G if it is not G_BOT; otherwise the awaited
  // category of the deepest Q (depths 4..2) that is not Q_BOT; otherwise depth 1's.
  operator G() const {
    if ( second != G_BOT ) return second;
    for ( int iDepthIdx=3; iDepthIdx>0; iDepthIdx-- )
      if ( first.get(iDepthIdx) != Q_BOT ) return first.get(iDepthIdx).getAwa();
    return first.get(0).getAwa();
  }
  bool compareFinal ( const S& s ) const { return ( *this == s ); }
};


//// H: the set of all (marginalized) reduce and (modeled) shift variables in the HHMM...
class H : public DelimitedJoint2DRV<psX,R,psDashDiamondDash,S,psX> {
 public:
  operator R() const { return first; }
  operator S() const { return second; }
  // Depth of the deepest level whose reduce variable is F_0 (D_0 if none is).
  operator D() const {
    if ( first.get(3) == F_0 ) return D_4;
    if ( first.get(2) == F_0 ) return D_3;
    if ( first.get(1) == F_0 ) return D_2;
    if ( first.get(0) == F_0 ) return D_1;
    return D_0;
  }
};


////////////////////////////////////////////////////////////////////////////////
//
//  Models
//
////////////////////////////////////////////////////////////////////////////////

//////////////////////////////////////// "Wrapper" models for individual RVs...

//// Model of F given D, F (from the level below), Q (from previous, this depth) and Q (from previous, depth above)
class FModel {
 private:
  HidVarCPT4DModel<F,D,Q,Q,LogProb> mFr;            // Reduction model: F given D, Q (previous state at this depth), Q (previous state at depth above) (assume prev awaited = reduction)
  static const HidVarCPT1DModel<F,LogProb> mF_0;    // Fixed F_0 model.
  static const HidVarCPT1DModel<F,LogProb> mF_BOT;  // Fixed F_BOT model.
 public:
  // Root-observation constraint: when a>=-2, a depth-1 F must satisfy (F > F_1) == F_ROOT_OBS,
  // otherwise its probability is forced to LogProb().
  static bool F_ROOT_OBS;
  // Iterates over values of the F at depth d (written into f) and returns the probability
  // of the current value.
  //   fD - F from the level below (for depth 4, derived from the previous depth-5 G)
  //   qP - Q at this depth in the previous shift state
  //   qU - Q at the depth above in the previous shift state (Q_TOP for depth 1)
  //   a  - iteration state shared with the underlying CPT models (semantics in nl-cpt.h)
  LogProb setIterProb ( F::ArrayIterator<LogProb>& f, const D& d, const F& fD, const Q& qP, const Q& qU, int& a ) const {
    LogProb pr;
    // NOTE: this G_BOT test must precede the terminal test below, because G_BOT ("-")
    // is itself classified as B_1 by G::calcDetModels.
    if ( fD==F_BOT && (qP.getAwa()==G_BOT) ) {
      // >1 (bottom) case...
      //cerr<<"a\n";
      pr = mF_BOT.setIterProb(f,a);
    }
    else if ( fD==F_BOT && qP.getAwa().getTerm()==B_1 ) {
      // >1 (middle) case: nothing reduced below and this level awaits a terminal -> reduction model...
      //cerr<<"b\n";
      pr = mFr.setIterProb(f,d,qP,qU,a);
    }
    else {
      // 0 (top) case: everything else is fixed to F_0...
      //cerr<<"c\n";
      pr = mF_0.setIterProb(f,a);      
    }
    // Report error (missing conditioning context in the reduction model)...
    if ( a<-2 && pr==LogProb() ) cerr<<"\nERROR: no condition Fr "<<d<<" "<<qP<<" "<<qU<<" "<<qP.getAwa()<<" "<<qP.getAwa().getTerm()<<" "<<fD<<"\n\n";
    // Iterate again if result doesn't match root observation...
    if ( a>=-2 && d==D_1 && F_ROOT_OBS!=(F(f)>F_1) ) pr=LogProb();
    ////cerr<<"  F "<<d<<" "<<fD<<" "<<qP<<" "<<qU<<" : "<<f<<" = "<<pr<<" ("<<a<<")\n";
    return pr;
  }
  // Model file reader: only lines prefixed "Fr " are accepted (into mFr).
  friend pair<StringInput,FModel*> operator>> ( StringInput si, FModel& m ) { return pair<StringInput,FModel*>(si,&m); }
  friend StringInput operator>> ( pair<StringInput,FModel*> si_m, const char* psD ) {
    StringInput si;
    return ( (si=si_m.first>>"Fr " >>si_m.second->mFr >>psD)!=NULL ) ? si : StringInput(NULL);
  }
};
const HidVarCPT1DModel<F,LogProb> FModel::mF_0(F_0);
const HidVarCPT1DModel<F,LogProb> FModel::mF_BOT(F_BOT);
bool FModel::F_ROOT_OBS = false;


//// Model of Q given D and F and F and Q(from prev) and Q(from above)
class QModel {
 private:
  HidVarCPT3DModel<Q,D,Q,LogProb> mQe;       // Expansion model of Q given D, Q (from above)
  HidVarCPT5DModel<Q,D,G,Q,Q,LogProb> mQtm;  // Awaited transition model of Q given D, G (from reduction), Q (from previous), Q (from above)
  HidVarCPT4DModel<Q,D,Q,Q,LogProb> mQtp;    // Active Transition model of Q given D, Q (from previous), Q (from above)
  static const HidVarCPT1DModel<Q,LogProb>   mQ_BOT;   // Fixed Q_BOT model.
  static       HidVarCPT2DModel<Q,Q,LogProb> mQ_COPY;  // Cached Q_COPY (identity: Q given previous Q) model, filled lazily in setIterProb  --  WARNING: STATIC NON-CONST is not thread safe!
 public:
  // Iterates over values of the Q at depth d (written into q) and returns the probability
  // of the current value.  The sub-model is chosen from the two reduce variables:
  //   fD - F from the level below (for depth 4, derived from the previous depth-5 G)
  //   f  - F at this depth
  //   qP - Q at this depth in the previous shift state
  //   qU - Q just chosen for the depth above in this shift phase (Q_TOP for depth 1)
  //   a  - iteration state shared with the underlying CPT models (semantics in nl-cpt.h);
  //        "no condition" diagnostics are printed when a==-2 and a model returns LogProb()
  LogProb setIterProb ( Q::ArrayIterator<LogProb>& q, const D& d, const F& fD, const F& f, const Q& qP, const Q& qU, int& a ) const {
    LogProb pr,p;
    if (fD>F_1) {
      if (f>=F_1) {
        if (f>F_1) {
          if (qU.getAwa().getTerm()==B_1 || qU==Q_BOT) {
            ////cerr<<"a\n";
            // >1 >1 (expansion to null) case: above is bottom or awaits a terminal:
            pr  = mQ_BOT.setIterProb(q,a);
          }
          else {
            ////cerr<<"b\n";
            // >1 >1 (expansion) case:
            pr = p = mQe.setIterProb(q,d,qU,a);
            if ( a==-2 && p==LogProb() ) cerr<<"\nERROR: no condition Qe "<<d<<" "<<qU<<"\n\n";
          }
        }
        else {
          ////cerr<<"c\n";
          // >1 1 ('plus' transition following reduction) case:
          pr = p = mQtp.setIterProb(q,d,qP,qU,a);
          if ( a==-2 && p==LogProb() ) cerr<<"\nERROR: no condition Qtp "<<d<<" "<<qP<<" "<<qU<<"\n\n";
        }
      }
      else {
        ////cerr<<"d\n";
        // >1 0 ('minus' transition without reduction) case, conditioned on the category of fD (F::getG):
        pr = p = mQtm.setIterProb(q,d,fD.getG(),qP,qU,a);
        if ( a==-2 && p==LogProb() ) cerr<<"\nERROR: no condition Qtm "<<d<<" "<<fD.getG()<<" "<<qP<<" "<<qU<<"\n\n";
      }
    }
    else {
      ////cerr<<"e\n";
      // <=1 0 (copy) case -- taken whenever fD<=F_1, regardless of f:
      // lazily add an identity entry qP -> qP with probability 1 (mutates the static cache).
      if ( !mQ_COPY.contains(qP) )  mQ_COPY.setProb(qP,qP)=1.0;
      pr = p = mQ_COPY.setIterProb(q,qP,a);
    }
    ////cerr<<"  Q "<<d<<" "<<fD<<" "<<f<<" "<<qP<<" "<<qU<<" : "<<q<<" = "<<pr<<" ("<<a<<")\n";
    return pr;
  }
  // Model file reader: lines prefixed "Qtm ", "Qtp " or "Qe " go to the matching sub-model.
  friend pair<StringInput,QModel*> operator>> ( StringInput si, QModel& m ) { return pair<StringInput,QModel*>(si,&m); }
  friend StringInput operator>> ( pair<StringInput,QModel*> si_m, const char* psD ) {
    StringInput si;
    return ( (si=si_m.first>>"Qtm ">>si_m.second->mQtm>>psD)!=NULL ||
             (si=si_m.first>>"Qtp ">>si_m.second->mQtp>>psD)!=NULL ||
             (si=si_m.first>>"Qe " >>si_m.second->mQe>>psD)!=NULL ) ? si : StringInput(NULL);
  }
};
const HidVarCPT1DModel<Q,LogProb> QModel::mQ_BOT(Q_BOT);
HidVarCPT2DModel<Q,Q,LogProb> QModel::mQ_COPY;


//////////////////////////////////////// Joint models...

//////////////////// Reduce phase...

//// Model of R given S
class RModel : public SingleFactoredModel<FModel> {
 public:
  // Iterates the four reduce variables bottom-up (depth 4 first), because each depth's F
  // is conditioned on the F just chosen one level below.  Returns the product of the
  // per-depth probabilities.
  LogProb setIterProb ( R::ArrayIterator<LogProb>& r, const S& sP, int& a ) const {
    const FModel& modelF = getM1();
    // Depth 4: the "F from below" comes from the previous shift state's depth-5 category.
    LogProb prProduct = modelF.setIterProb ( r.set(3), 4, F(sP.second), sP.first.get(3), sP.first.get(2), a );
    // Depths 3 and 2: conditioned on the F just set one level deeper.
    for ( int iDepth=3; iDepth>=2; iDepth-- )
      prProduct *= modelF.setIterProb ( r.set(iDepth-1), iDepth, F(r.get(iDepth)), sP.first.get(iDepth-1), sP.first.get(iDepth-2), a );
    // Depth 1: nothing lies above it, so Q_TOP stands in for the state from above.
    prProduct *= modelF.setIterProb ( r.set(0), 1, F(r.get(1)), sP.first.get(0), Q_TOP, a );
    return prProduct;
  }
};


//////////////////// Shift phase...

//// Model of S given R and S
class SModel : public SingleFactoredModel<QModel> {
 private:
  HidVarCPT2DModel<G,Q,LogProb> mGe5;     // Expansion model of G given 5, Q (from above)
  static const HidVarCPT1DModel<G,LogProb> mG_BOT;
 public:
  // Iterates the shift-phase variables top-down: the Q at each depth is conditioned on the
  // Q just chosen above it (Q_TOP for depth 1), and finally the depth-5 G is conditioned on
  // the new depth-4 Q.  r holds this step's reduce variables; sP is the previous shift state.
  // Returns the product of the per-variable probabilities.
  LogProb setIterProb ( S::ArrayIterator<LogProb>& s, const R::ArrayIterator<LogProb>& r, const S& sP, int& a ) const {
    const QModel& mQ = getM1();
    LogProb pr;
    pr  = mQ.setIterProb ( s.first.set(1-1), 1, F(r.get(2-1)), F(r.get(1-1)), sP.first.get(1-1), Q_TOP               ,a );
    pr *= mQ.setIterProb ( s.first.set(2-1), 2, F(r.get(3-1)), F(r.get(2-1)), sP.first.get(2-1), Q(s.first.set(1-1)) ,a );
    pr *= mQ.setIterProb ( s.first.set(3-1), 3, F(r.get(4-1)), F(r.get(3-1)), sP.first.get(3-1), Q(s.first.set(2-1)) ,a );
    pr *= mQ.setIterProb ( s.first.set(4-1), 4, sP.second,     F(r.get(4-1)), sP.first.get(4-1), Q(s.first.set(3-1)) ,a );
    // Depth 5: expand a G only when the new depth-4 Q awaits a real nonterminal category.
    // The sentinel to test against is G_BOT (getAwa() returns a G); the previous code
    // compared this G against Q_BOT, a value of a different variable type.
    const Q q4 = Q(s.first.set(4-1));
    const G gAwa4 = q4.getAwa();
    if ( gAwa4!=G_BOT && gAwa4.getTerm()!=B_1 )
      pr *= mGe5.setIterProb ( s.second, q4, a );
    else
      pr *= mG_BOT.setIterProb ( s.second, a );
    ////cerr<<"  G "<<5<<" "<<gAwa4<<" : "<<s.second<<" = "<<pr<<" ("<<a<<")\n";
    return pr;
  }
  // Model file reader: lines prefixed "Ge5 " go to mGe5; anything else is offered to the Q model.
  friend pair<StringInput,SModel*> operator>> ( StringInput si, SModel& m ) { return pair<StringInput,SModel*>(si,&m); }
  friend StringInput operator>> ( pair<StringInput,SModel*> si_m, const char* psD ) {
    StringInput si;
    return ( (si=si_m.first>>"Ge5 ">>si_m.second->mGe5   >>psD)!=NULL ||
             (si=si_m.first>>        si_m.second->setM1()>>psD)!=NULL ) ? si : StringInput(NULL);
  }
};
const HidVarCPT1DModel<G,LogProb> SModel::mG_BOT(G_BOT);


//////////////////// Overall...

//// Model of H=R,S given S
class HModel : public DoubleFactoredModel<RModel,SModel> {
 public:
  typedef H::ArrayIterator<LogProb> IterVal;
  // Store the shift-phase part of h into the trellis datum s and return s.
  S& setTrellDat ( S& s, const H::ArrayIterator<LogProb>& h ) const {
    s.setVal(h.second);
    return s;
  }
  // Copy the four reduce-phase F values of h into a fresh R (the backpointer datum).
  R setBackDat ( const H::ArrayIterator<LogProb>& h ) const {
    R rBack;
    for ( int iDepthIdx=0; iDepthIdx<4; ++iDepthIdx )
      rBack.set(iDepthIdx) = F(h.first.get(iDepthIdx));
    return rBack;
  }
  // Probability of the reduce phase times that of the shift phase.  The shift model is not
  // consulted when the reduce phase yields LogProb() (presumably probability zero).
  LogProb setIterProb ( H::ArrayIterator<LogProb>& h, const S& sP, int& a ) const {
    LogProb pr = getM1().setIterProb ( h.first, sP, a );
    if ( pr == LogProb() ) return pr;
    pr *= getM2().setIterProb ( h.second, h.first, sP, a );
    return pr;
  }
};
