///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// This file is part of ModelBlocks. Copyright 2009, ModelBlocks developers. //
//                                                                           //
//    ModelBlocks is free software: you can redistribute it and/or modify    //
//    it under the terms of the GNU General Public License as published by   //
//    the Free Software Foundation, either version 3 of the License, or      //
//    (at your option) any later version.                                    //
//                                                                           //
//    ModelBlocks is distributed in the hope that it will be useful,         //
//    but WITHOUT ANY WARRANTY; without even the implied warranty of         //
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the          //
//    GNU General Public License for more details.                           //
//                                                                           //
//    You should have received a copy of the GNU General Public License      //
//    along with ModelBlocks.  If not, see <http://www.gnu.org/licenses/>.   //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////

#include "nl-cpt.h"

// Delimiter strings used as template arguments to the Delimited*RV types defined
// below (psX is the empty delimiter). They are non-const char arrays at namespace
// scope, presumably so they keep external linkage, which pre-C++11 compilers
// require for pointer template arguments. Several of them (e.g. psComma,
// psSemiSemi, psTilde, psLangle/psRangle, psLbrack/psRbrack) are not referenced by
// the definitions here; they may be used by code that includes this file.
char psX[]="";
char psSlash[]="/";
char psComma[]=",";
char psSemi[]=";";
char psSemiSemi[]=";;";
char psDashDiamondDash[]="-<>-";
char psTilde[]="~";
char psBar[]="|";
//char psOpenBrace[]="{";
//char psCloseBrace[]="}";
char psLangle[]="<";
char psRangle[]=">";
char psLbrack[]="[";
char psRbrack[]="]";

// Serialized S values: four depth-level Q's written "G|G/G" and separated by ';',
// followed by a final G. Used as the begin and end states; note that the two
// strings are currently identical.
const char* BEG_STATE = "S|S/end;-|-/-;-|-/-;-|-/-;-";
const char* END_STATE = "S|S/end;-|-/-;-|-/-;-|-/-;-";


////////////////////////////////////////////////////////////////////////////////
//
//  Random Variables
//
////////////////////////////////////////////////////////////////////////////////

//////////////////////////////////////// Simple Variables

//// D: depth (input only, to HHMM models)...
DiscreteDomain<char> domD;
// Depth value drawn from domD. The int constructor is intentionally left
// non-explicit: model code below passes plain integer literals where a
// `const D&` is expected.
class D : public DiscreteDomainRV<char,domD> {
  typedef DiscreteDomainRV<char,domD> Parent;
 public:
  D ( )                : Parent ( )    { }
  D ( int i )          : Parent ( i )  { }
  D ( const char* ps ) : Parent ( ps ) { }
};
// Named depth values "0" through "5".
const D D_0 ( "0" );
const D D_1 ( "1" );
const D D_2 ( "2" );
const D D_3 ( "3" );
const D D_4 ( "4" );
const D D_5 ( "5" );


//// B: boolean
DiscreteDomain<char> domB;
// Two-valued random variable over domB; only constructible from a string label.
class B : public DiscreteDomainRV<char,domB> {
  typedef DiscreteDomainRV<char,domB> Parent;
 public:
  B ( )                : Parent ( )    { }
  B ( const char* ps ) : Parent ( ps ) { }
};
// The two values, labelled "0" and "1".
const B B_0 ( "0" );
const B B_1 ( "1" );


//// G: constituent category...
DiscreteDomain<int> domG;
class G : public DiscreteDomainRV<int,domG> {
 private:
  // Per-category flag returned by getTerm(); filled in lazily by calcDetModels().
  static SimpleHash<G,B> hToTerm;
  // Record getTerm() for this category the first time it is seen: B_0 if the
  // category name starts with an ASCII uppercase letter 'A'..'Z', otherwise B_1.
  // The name is taken by const reference: this avoids a string copy per call,
  // and the const operator[] makes s[0] well-defined (the terminating '\0')
  // even for an empty name.
  void calcDetModels ( const string& s ) {
    if (!hToTerm.contains(*this)) {
      hToTerm.set(*this) = ('A'<=s[0] && s[0]<='Z') ? B_0 : B_1;
    }
  }
 public:
  G ( )                : DiscreteDomainRV<int,domG> ( )    { }
  // Wraps an existing domain value; does not register anything in hToTerm.
  G ( const DiscreteDomainRV<int,domG>& rv ) : DiscreteDomainRV<int,domG>(rv) { }
  G ( const char* ps ) : DiscreteDomainRV<int,domG> ( ps ) { calcDetModels(ps); }
  // B_0 / B_1 as computed by calcDetModels() (presumably B_1 marks terminal-like
  // labels such as "-"; confirm against the model files).
  B getTerm ( ) const { return hToTerm.get(*this); }
  // Stream-reading support: `si >> g >> delim` reads g via the base class,
  // then registers its getTerm() flag.
  friend pair<StringInput,G*> operator>> ( StringInput si, G& m ) { return pair<StringInput,G*>(si,&m); }
  friend StringInput operator>> ( pair<StringInput,G*> si_m, const char* psD ) {
    if ( si_m.first == NULL ) return NULL;
    StringInput si=si_m.first>>static_cast<DiscreteDomainRV<int,domG>&>(*si_m.second)>>psD;
    si_m.second->calcDetModels(si_m.second->getString()); return si; }
};
SimpleHash<G,B> G::hToTerm;
// Bottom (empty) category, top category, and "REST" category.
const G G_BOT("-");
const G G_TOP("ROOT");
const G G_RST("REST");


//// F: final-state...
DiscreteDomain<int> domF;
//typedef DiscreteDomainRV<char,domF> F;
class F : public DiscreteDomainRV<int,domF> {
 private:
  static SimpleHash<F,G> hToG;    // F -> G lookup used by getG()
  static SimpleHash<G,F> hFromG;  // G -> F lookup used by getF() and F(const G&)
  // Populate both lookup tables the first time this F value is seen.
  // A G is built from this value's own string label:
  //  - getG() yields that G only when toInt()>1, and G_BOT otherwise;
  //  - hFromG always maps that G back to this F.
  // The G is constructed once and reused for both tables. The previous version
  // built it twice and took an unused string parameter.
  void calcDetModels ( ) {
    if (!hToG.contains(*this)) {
      const G g ( this->getString().c_str() );
      hToG.set(*this) = (this->toInt()>1) ? g : G_BOT;
      hFromG.set(g)   = *this;
    }
  }
 public:
  F ( )                : DiscreteDomainRV<int,domF> ( )    { }
  // Wraps an existing domain value; does not register anything in the tables.
  F ( const DiscreteDomainRV<int,domF>& rv ) : DiscreteDomainRV<int,domF>(rv) { }
  F ( const char* ps ) : DiscreteDomainRV<int,domF> ( ps ) { calcDetModels(); }
  // Look up the F previously registered for category g (via hFromG).
  F ( const G& g )                                         { *this = hFromG.get(g); }
  G        getG ( )     const { return hToG.get(*this); }
  static F getF ( G g )       { return hFromG.get(g); }
  // Stream-reading support: `si >> f >> delim` reads f via the base class,
  // then registers it in the F<->G tables.
  friend pair<StringInput,F*> operator>> ( StringInput si, F& m ) { return pair<StringInput,F*>(si,&m); }
  friend StringInput operator>> ( pair<StringInput,F*> si_m, const char* psD ) {
    if ( si_m.first == NULL ) return NULL;
    StringInput si=si_m.first>>static_cast<DiscreteDomainRV<int,domF>&>(*si_m.second)>>psD;
    si_m.second->calcDetModels(); return si; }
};
SimpleHash<F,G> F::hToG;
SimpleHash<G,F> F::hFromG;
const F F_0("0");
const F F_1("1");
const F F_BOT("-");


//// Q: syntactic state...
// Three categories, written "first|second/third" (delimiters psBar and psSlash).
typedef DelimitedJoint3DRV<psX,G,psBar,G,psSlash,G,psX> Q;
// All-bottom state: every component is G_BOT. (Macro: builds a new temporary at each use.)
#define Q_BOT Q(G_BOT,G_BOT,G_BOT)
// Fixed "above" state for depth 1: ROOT|ROOT/REST. (Macro: builds a new temporary at each use.)
#define Q_TOP Q(G_TOP,G_TOP,G_RST)


//////////////////////////////////////// Joint Variables

//// R: collection of syntactic variables at all depths in each `reduce' phase...
// One F per depth (4 depths), written ';'-separated.
typedef DelimitedJointArrayRV<4,psSemi,F> R;


//// S: collection of syntactic variables at all depths in each `shift' phase...
// `first' holds one Q per depth (4 depths, ';'-separated); `second' is a final G.
class S : public DelimitedJoint2DRV<psX,DelimitedJointArrayRV<4,psSemi,Q>,psSemi,G,psX> {
 public:
  // Category summarizing this state:
  //  - `second' if it is not G_BOT;
  //  - otherwise the third component of the deepest non-Q_BOT level,
  //    scanning depths 4, 3, 2;
  //  - otherwise depth 1's third component.
  operator G()  const {
    if ( second!=G_BOT ) return second;
    for ( int i=3; i>=1; i-- ) {
      if ( first.get(i)!=Q_BOT ) return first.get(i).third;
    }
    return first.get(0).third;
  }
  // Whole-state equality.
  bool compareFinal ( const S& s ) const { return *this==s; }
};


//// H: the set of all (marginalized) reduce and (modeled) shift variables in the HHMM...
// `first' is the reduce-phase R, `second' the shift-phase S; written "R-<>-S".
class H : public DelimitedJoint2DRV<psX,R,psDashDiamondDash,S,psX> {
 public:
  operator R() const { return first; }
  operator S() const { return second; }
  // Depth = (index of the deepest level whose F equals F_0) + 1, checked from
  // depth 4 upward; D_0 when no level has F_0.
  operator D() const {
    if ( first.get(3)==F_0 ) return D_4;
    if ( first.get(2)==F_0 ) return D_3;
    if ( first.get(1)==F_0 ) return D_2;
    if ( first.get(0)==F_0 ) return D_1;
    return D_0;
  }
};


////////////////////////////////////////////////////////////////////////////////
//
//  Models
//
////////////////////////////////////////////////////////////////////////////////

//////////////////////////////////////// "Wrapper" models for individual RVs...

//// Model of F given D, F (from the next-deeper level), Q (from previous), and Q (from above)
class FModel {
 private:
  // Called in setIterProb with (d, qP.first, qP.second).
  HidVarCPT4DModel<F,D,G,G,LogProb> mFr;            // Reduction model: F given D, G (active cat from prev), G (awaited cat from above) (assume prev awaited = reduction)
  static const HidVarCPT1DModel<F,LogProb> mF_0;    // Fixed F_ZERO model.
  static const HidVarCPT1DModel<F,LogProb> mF_BOT;  // Fixed F_BOT model.
 public:
  // At depth D_1 (when a>=-2), setIterProb keeps a result only if
  // F_ROOT_OBS == (F(f)>F_1); otherwise it returns zero probability.
  static bool F_ROOT_OBS;
  // Sets iterator f at depth d and returns its probability.
  //  - fD==F_BOT and qP.third==G_BOT        -> fixed mF_BOT model;
  //  - fD==F_BOT and qP.third.getTerm()==B_1 -> reduction model mFr(d, qP.first, qP.second);
  //  - otherwise                            -> fixed mF_0 model.
  // qU is currently unused by the computation; it is kept for the caller's interface.
  LogProb setIterProb ( F::ArrayIterator<LogProb>& f, const D& d, const F& fD, const Q& qP, const Q& qU, int& a ) const {
    LogProb pr;
    if ( fD==F_BOT && (qP.third==G_BOT) ) {
      // >1 (bottom) case...
      pr = mF_BOT.setIterProb(f,a);
    }
    // Earlier (disabled) formulation:
    //else if ( fD>F_1 && (qP.third==F(fD).getG() || (qP.third.getTerm()==B_1 && fD==F_BOT)) ) {
    else if ( fD==F_BOT && qP.third.getTerm()==B_1 ) {
      // >1 (middle) case...
      pr = mFr.setIterProb(f,d,qP.first,qP.second,a);
    }
    //if ( fD==F_0 || (qP.third!=F(fD).getG() && !(qP.third.getTerm()==B_1 && F(fD).getG()==G_NONE)) ) ) {
    else {
      // 0 (top) case...
      pr = mF_0.setIterProb(f,a);
    }
    // Report error, printing the conditioning context actually passed to mFr
    // (d, qP.first, qP.second) so the missing "Fr" entry can be located.
    // NOTE(review): QModel reports missing conditions when a==-2; confirm that
    // a<-2 is the intended trigger here.
    if ( a<-2 && pr==LogProb() ) cerr<<"\nERROR: no condition Fr "<<d<<" "<<qP.first<<" "<<qP.second<<"\n\n";
    // Iterate again if result doesn't match root observation...
    if ( a>=-2 && d==D_1 && F_ROOT_OBS!=(F(f)>F_1) ) pr=LogProb();
    ////cerr<<"  F "<<d<<" "<<fD<<" "<<qP<<" "<<qU<<" : "<<f<<" = "<<pr<<" ("<<a<<")\n";
    return pr;
  }
  // Model-file reading: accepts only lines prefixed "Fr "; returns NULL otherwise.
  friend pair<StringInput,FModel*> operator>> ( StringInput si, FModel& m ) { return pair<StringInput,FModel*>(si,&m); }
  friend StringInput operator>> ( pair<StringInput,FModel*> si_m, const char* psD ) {
    StringInput si;
    return ( (si=si_m.first>>"Fr " >>si_m.second->mFr >>psD)!=NULL ) ? si : StringInput(NULL);
  }
};
const HidVarCPT1DModel<F,LogProb> FModel::mF_0(F_0);
const HidVarCPT1DModel<F,LogProb> FModel::mF_BOT(F_BOT);
bool FModel::F_ROOT_OBS = false;


//// Model of Q given D, F (from the next-deeper level), F (at this depth), Q (from previous), and Q (from above)
class QModel {
 public:
  // Public so that SModel can use it for the final (depth-5) G.
  HidVarCPT3DModel<G,D,G,LogProb> mGe;     // Expansion model of G given D, G (from above)
 private:
  HidVarCPT3DModel<G,D,G,LogProb> mGa;     // Expansion completion of G given D, G (goal cat from current)
  HidVarCPT4DModel<G,D,G,G,LogProb> mGtm;  // Awaited transition model of G (active) given D, G (awaited cat from previous), G (from reduction)
  HidVarCPT4DModel<G,D,G,G,LogProb> mGtp;  // Active Transition model of G (active) given D, G (active cat from previous), G (awaited cat from above)
  HidVarCPT4DModel<G,D,G,G,LogProb> mGtq;  // Active Transition completion of G (awaited) given D, G (active cat from current), G (active cat from previous)
  static const HidVarCPT1DModel<G,LogProb>   mG_BOT;   // Fixed G_BOT model.
  static       HidVarCPT2DModel<G,G,LogProb> mG_COPY;  // Cached G_COPY (identity) model  --  WARNING: STATIC NON-CONST is not thread safe!
 public:
  // Sets the three components of iterator q at depth d and returns the product
  // of their probabilities. Cases, by the F below (fD) and the F at this depth (f):
  //  - fD>F_1, f>F_1:  if qU.third.getTerm()==B_1 or qU==Q_BOT, all three
  //                    components are G_BOT; otherwise expand: q.first~mGe(d,qU.third),
  //                    q.second~mGa(d,q.first), q.third copies q.second.
  //  - fD>F_1, f>=F_1 but not f>F_1: q.first copies qP.first,
  //                    q.second~mGtp(d,q.first,qP.second), q.third~mGtq(d,q.second,qP.second).
  //  - fD>F_1, f<F_1:  q.first/q.second copy qP's; q.third~mGtm(d,qP.third,fD.getG()).
  //  - fD<=F_1:        all three components copy qP's.
  // Copies go through mG_COPY. An identity entry for each copied G is added on
  // first use, so this const method mutates a shared static.
  // When a==-2 and a learned factor is zero, an "ERROR: no condition" line is written to cerr.
  LogProb setIterProb ( Q::ArrayIterator<LogProb>& q, const D& d, const F& fD, const F& f, const Q& qP, const Q& qU, int& a ) const {
    LogProb pr,p;
    if (fD>F_1) {
      if (f>=F_1) {
        if (f>F_1) {
          if (qU.third.getTerm()==B_1 || qU==Q_BOT) {
            ////cerr<<"a\n";
            // >1 >1 (expansion to null) case:
            pr  = mG_BOT.setIterProb(q.first, a);
            pr *= mG_BOT.setIterProb(q.second,a);
            pr *= mG_BOT.setIterProb(q.third, a);
          }
          else {
            ////cerr<<"b\n";
            // >1 >1 (expansion) case:
            pr  = p = mGe.setIterProb(q.first,d,qU.third,a);
            if ( a==-2 && p==LogProb() ) cerr<<"\nERROR: no condition Ge "<<d<<" "<<qU.third<<"\n\n";
            pr *= p = mGa.setIterProb(q.second,d,G(q.first),a);
            if ( a==-2 && p==LogProb() ) cerr<<"\nERROR: no condition Ga "<<d<<" "<<G(q.first)<<"\n\n";
            // Lazily add the identity entry so q.third can copy q.second.
            if ( !mG_COPY.contains(G(q.second)) ) mG_COPY.setProb(G(q.second),G(q.second))=1.0;
            pr *= p = mG_COPY.setIterProb(q.third,G(q.second),a);
          }
        }
        else {
          ////cerr<<"c\n";
          // >1 1 ('plus' transition following reduction) case:
          if ( !mG_COPY.contains(qP.first) ) mG_COPY.setProb(qP.first,qP.first)=1.0;
          pr  = p = mG_COPY.setIterProb(q.first,G(qP.first),a);
          pr *= p = mGtp.setIterProb(q.second,d,G(q.first),qP.second,a);
          if ( a==-2 && p==LogProb() ) cerr<<"\nERROR: no condition Gtp "<<d<<" "<<G(q.first)<<" "<<qP.second<<"\n\n";
          pr *= p = mGtq.setIterProb(q.third,d,G(q.second),qP.second,a);
          ////if ( a==-2 ) cerr<<"applying "<<d<<" "<<G(q.second)<<" "<<qP.second<<" : "<<G(q.third)<<"\n";
          if ( a==-2 && p==LogProb() ) cerr<<"\nERROR: no condition Gtq "<<d<<" "<<G(q.second)<<" "<<qP.second<<"\n\n";
        }
      }
      else {
        ////cerr<<"d\n";
        // >1 0 ('minus' transition without reduction) case:
        if ( !mG_COPY.contains(qP.first ) ) mG_COPY.setProb(qP.first ,qP.first )=1.0;
        pr  = p = mG_COPY.setIterProb(q.first,qP.first,a);
        if ( !mG_COPY.contains(qP.second) ) mG_COPY.setProb(qP.second,qP.second)=1.0;
        pr *= p = mG_COPY.setIterProb(q.second,qP.second,a);
        pr *= p = mGtm.setIterProb(q.third,d,qP.third,fD.getG(),a);
        if ( a==-2 && p==LogProb() ) cerr<<"\nERROR: no condition Gtm "<<d<<" "<<qP.third<<" "<<fD.getG()<<"\n\n";
      }
    }
    else {
      ////cerr<<"e\n";
      // <=1 0 (copy) case: reached whenever fD<=F_1, regardless of f.
      if ( !mG_COPY.contains(qP.first ) ) mG_COPY.setProb(qP.first, qP.first )=1.0;
      pr  = p = mG_COPY.setIterProb(q.first, qP.first, a);
      if ( !mG_COPY.contains(qP.second) ) mG_COPY.setProb(qP.second,qP.second)=1.0;
      pr *= p = mG_COPY.setIterProb(q.second,qP.second,a);
      if ( !mG_COPY.contains(qP.third ) ) mG_COPY.setProb(qP.third, qP.third )=1.0;
      pr *= p = mG_COPY.setIterProb(q.third, qP.third, a);
    }
    ////cerr<<"  Q "<<d<<" "<<fD<<" "<<f<<" "<<qP<<" "<<qU<<" : "<<q<<" = "<<pr<<" ("<<a<<")\n";
    return pr;
  }
  // Model-file reading: tries the prefixes "Ge ", "Ga ", "Gtm ", "Gtp ", "Gtq " in
  // turn, returns the remaining input for the first that parses, or NULL if none does.
  friend pair<StringInput,QModel*> operator>> ( StringInput si, QModel& m ) { return pair<StringInput,QModel*>(si,&m); }
  friend StringInput operator>> ( pair<StringInput,QModel*> si_m, const char* psD ) {
    StringInput si;
    return ( (si=si_m.first>>"Ge ">>si_m.second->mGe>>psD)!=NULL ||
             (si=si_m.first>>"Ga ">>si_m.second->mGa>>psD)!=NULL ||
             (si=si_m.first>>"Gtm ">>si_m.second->mGtm>>psD)!=NULL ||
             (si=si_m.first>>"Gtp ">>si_m.second->mGtp>>psD)!=NULL ||
             (si=si_m.first>>"Gtq ">>si_m.second->mGtq>>psD)!=NULL ) ? si : StringInput(NULL);
  }
};
const HidVarCPT1DModel<G,LogProb> QModel::mG_BOT(G_BOT);
HidVarCPT2DModel<G,G,LogProb> QModel::mG_COPY;


//////////////////////////////////////// Joint models...

//////////////////// Reduce phase...

//// Model of R given S
class RModel : public SingleFactoredModel<FModel> {
 public:
  // Fills the reduce-phase iterator r (one F per depth) from the previous shift
  // state sP, and returns the product of the four per-depth F probabilities.
  // Depths are visited bottom-up (4 down to 1) because each depth conditions on
  // the F just chosen at the depth below it:
  //  - depth 4 uses F(sP.second) as that F instead;
  //  - depth 1 uses Q_TOP as its "above" state.
  LogProb setIterProb ( R::ArrayIterator<LogProb>& r, const S& sP, int& a ) const {
    const FModel& fModel = getM1();
    LogProb prod;
    for ( int i=4-1; i>=0; --i ) {
      const F fBelow = ( i==4-1 ) ? F(sP.second) : F(r.get(i+1));
      const LogProb pDepth =
        fModel.setIterProb ( r.set(i), i+1, fBelow, sP.first.get(i), (i>0) ? sP.first.get(i-1) : Q_TOP, a );
      if ( i==4-1 ) prod  = pDepth;
      else          prod *= pDepth;
    }
    return prod;
  }
};


//////////////////// Shift phase...

//// Model of S given R and S
class SModel : public SingleFactoredModel<QModel> {
 private:
  // Fixed model built from G_BOT. It is used for the final G when the depth-4
  // third category is G_BOT or has getTerm()==B_1.
  static const HidVarCPT1DModel<G,LogProb>   mG_BOT;
 public:
  // Fills the shift-phase iterator s from the reduce-phase iterator r and the
  // previous shift state sP, then returns the product of all factor probabilities.
  // Depths are visited top-down (1 to 4): each depth's "above" state is the Q
  // just set at the shallower depth (Q_TOP for depth 1). Finally, s.second is
  // expanded from depth 4's third category at depth 5.
  LogProb setIterProb ( S::ArrayIterator<LogProb>& s, const R::ArrayIterator<LogProb>& r, const S& sP, int& a ) const {
    const QModel& qModel = getM1();
    LogProb prod;
    prod = qModel.setIterProb ( s.first.set(0), 1, F(r.get(1)), F(r.get(0)), sP.first.get(0), Q_TOP, a );
    for ( int i=1; i<4; i++ ) {
      // Deeper-level F: taken from r below depth 4, derived from sP's final G at depth 4.
      const F fBelow = ( i<4-1 ) ? F(r.get(i+1)) : F(sP.second);
      prod *= qModel.setIterProb ( s.first.set(i), i+1, fBelow, F(r.get(i)), sP.first.get(i), Q().setVal(s.first.set(i-1)), a );
    }
    const G gBottom = G(s.first.set(4-1).third);
    if ( gBottom!=G_BOT && gBottom.getTerm()!=B_1 )
      prod *= qModel.mGe.setIterProb ( s.second, 5, gBottom, a );
    else
      prod *= mG_BOT.setIterProb ( s.second, a );
    return prod;
  }
};
const HidVarCPT1DModel<G,LogProb> SModel::mG_BOT(G_BOT);


//////////////////// Overall...

//// Model of H=R,S given S
class HModel : public DoubleFactoredModel<RModel,SModel> {
 public:
  typedef H::ArrayIterator<LogProb> IterVal;
  // Stores the shift-phase part of h into the trellis datum s and returns s.
  S& setTrellDat ( S& s, const H::ArrayIterator<LogProb>& h ) const {
    s.setVal ( h.second );
    return s;
  }
  // Builds a reduce-phase value (backpointer datum) from h's reduce-phase part.
  R setBackDat ( const H::ArrayIterator<LogProb>& h ) const {
    R rOut;
    rOut.setVal ( h.first );
    return rOut;
  }
  // Reduce phase first. When it already yields zero probability, the shift phase
  // is not evaluated and that zero is returned. Otherwise the shift-phase
  // probability is multiplied in.
  LogProb setIterProb ( H::ArrayIterator<LogProb>& h, const S& sP, int& a ) const {
    LogProb prob;
    prob = getM1().setIterProb ( h.first, sP, a );
    if ( LogProb()==prob ) return prob;
    prob *= getM2().setIterProb ( h.second, h.first, sP, a );
    return prob;
  }
};
