///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// This file is part of ModelBlocks. Copyright 2009, ModelBlocks developers. //
//                                                                           //
//    ModelBlocks is free software: you can redistribute it and/or modify    //
//    it under the terms of the GNU General Public License as published by   //
//    the Free Software Foundation, either version 3 of the License, or      //
//    (at your option) any later version.                                    //
//                                                                           //
//    ModelBlocks is distributed in the hope that it will be useful,         //
//    but WITHOUT ANY WARRANTY; without even the implied warranty of         //
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the          //
//    GNU General Public License for more details.                           //
//                                                                           //
//    You should have received a copy of the GNU General Public License      //
//    along with ModelBlocks.  If not, see <http://www.gnu.org/licenses/>.   //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////

#ifndef HHMMLM_UTT
#define HHMMLM_UTT


#include "nl-cpt.h"

// Delimiter strings.  Several of these (psX, psSemi, psDashDiamondDash, psLbrack,
// psRbrack) are passed below as non-type template arguments, e.g.
// DelimitedStaticSafeArray<4,psSemi,F>; presumably that is why they are declared as
// named char arrays with external linkage rather than used as string literals.
// NOTE(review): these are non-inline definitions in a header, so including this header
// from more than one translation unit would produce duplicate-symbol link errors.
char psX[]="";
char psComma[]=",";
char psSemi[]=";";
char psSemiSemi[]=";;";
char psDashDiamondDash[]="-<>-";
char psTilde[]="~";
//char psBar[]="|";
//char psOpenBrace[]="{";
//char psCloseBrace[]="}";
char psLangle[]="<";
char psRangle[]=">";
char psLbrack[]="[";
char psRbrack[]="]";

// Begin/end states in the text form of S: four ';'-separated Q labels (depth 1 first),
// each of the form act/awa|exp.  Both are currently the same string.
// NOTE(review): presumably consumed by the decoder that includes this header -- confirm.
const char* BEG_STATE = "UTT/end|end;-/-|-;-/-|-;-/-|-";
const char* END_STATE = "UTT/end|end;-/-|-;-/-|-;-/-|-";


////////////////////////////////////////////////////////////////////////////////
//
//  Random Variables
//
////////////////////////////////////////////////////////////////////////////////

//////////////////////////////////////// Simple Variables

//// D: depth (input only, to HHMM models)...
// Depth labels "0".."5".  H::operator D() below yields D_0..D_4 from a reduce vector.
// NOTE(review): RModel/SModel pass plain ints (1..4) where a const D& is expected; that
// presumably relies on domD numbering labels in registration order ("0"->0, "1"->1, ...),
// in which case the order of these definitions matters -- confirm against DiscreteDomain.
DiscreteDomain<char> domD;
typedef DiscreteDomainRV<char,domD> D;
const D D_0("0");
const D D_1("1");
const D D_2("2");
const D D_3("3");
const D D_4("4");
const D D_5("5");


//// F: final-state...
// Binary flag with labels "0" and "1".  Besides the reduce flags, F is reused in this file
// for G::getTerm() (F_0 when a label starts with an uppercase letter, F_1 otherwise) and
// C::getSimp() (F_1 when a label contains no '/', F_0 otherwise).
DiscreteDomain<char> domF;
typedef DiscreteDomainRV<char,domF> F;
const F F_0("0");
const F F_1("1");

//// G: constituent category...
// Each G caches one derived field, getTerm(): F_0 if its label starts with an ASCII
// uppercase letter, F_1 otherwise.  The cache (hToTerm) is filled when a G is built from a
// label or read with operator>>; a default-constructed G gets no entry.
DiscreteDomain<int> domG;
class G : public DiscreteDomainRV<int,domG> {
 private:
  typedef DiscreteDomainRV<int,domG> Base;
  static SimpleHash<G,F> hToTerm;
  // Record this G's getTerm() value from its label, unless it is already cached.
  void calcDetModels ( string sLabel ) {
    if ( hToTerm.contains(*this) ) return;
    const char cFirst = sLabel[0];
    const bool bUpper = ( cFirst>='A' && cFirst<='Z' );
    hToTerm.set(*this) = bUpper ? F_0 : F_1;
  }
 public:
  G ( )                : Base ( )    { }
  G ( const char* ps ) : Base ( ps ) { calcDetModels(ps); }
  F getTerm ( ) const { return hToTerm.get(*this); }
  // Two-step extraction: "si >> g" captures the target, then ">> delimiter" reads the
  // label into it and fills the cache entry for the value read if it is missing.
  friend pair<StringInput,G*> operator>> ( StringInput si, G& m ) {
    return pair<StringInput,G*>(si,&m);
  }
  friend StringInput operator>> ( pair<StringInput,G*> si_m, const char* psD ) {
    if ( si_m.first == NULL ) return NULL;
    G* pg = si_m.second;
    StringInput siRest = si_m.first >> static_cast<Base&>(*pg) >> psD;
    pg->calcDetModels(pg->getString());
    return siRest;
  }
};
SimpleHash<G,F> G::hToTerm;
const G G_NIL("-");


//// C: composition state resulting from right-corner reduction...
// Each C caches one derived field, getSimp(): F_1 if its label contains no '/', F_0
// otherwise.  The cache (hToSimp) is filled when a C is built from a label or read with
// operator>>; a default-constructed C gets no entry.
DiscreteDomain<int> domC;
class C : public DiscreteDomainRV<int,domC> {
 private:
  typedef DiscreteDomainRV<int,domC> Base;
  static SimpleHash<C,F> hToSimp;
  // Record this C's getSimp() value from its label, unless it is already cached.
  void calcDetModels ( string sLabel ) {
    if ( hToSimp.contains(*this) ) return;
    const bool bHasSlash = ( sLabel.find('/') != string::npos );
    hToSimp.set(*this) = bHasSlash ? F_0 : F_1;
  }
 public:
  C ( )                : Base ( )    { }
  C ( const char* ps ) : Base ( ps ) { calcDetModels(ps); }
  F getSimp ( ) const { return hToSimp.get(*this); }
  // Two-step extraction: "si >> c" captures the target, then ">> delimiter" reads the
  // label into it and fills the cache entry for the value read if it is missing.
  friend pair<StringInput,C*> operator>> ( StringInput si, C& m ) {
    return pair<StringInput,C*>(si,&m);
  }
  friend StringInput operator>> ( pair<StringInput,C*> si_m, const char* psD ) {
    if ( si_m.first == NULL ) return NULL;
    C* pc = si_m.second;
    StringInput siRest = si_m.first >> static_cast<Base&>(*pc) >> psD;
    pc->calcDetModels(pc->getString());
    return siRest;
  }
};
SimpleHash<C,F> C::hToSimp;
const C C_NIL("-");


//// I: intermediate incomplete category state (without lookahead)...
// Plain domain-backed label with no derived fields; I_NIL is the "-/-" null value.
// Not referenced by the models defined in this header.
DiscreteDomain<int> domI;
typedef DiscreteDomainRV<int,domI> I;
const I I_NIL ("-/-");


//// Q: syntactic state...
// A Q label is read as "[prefix=]act/awa|exp".  When a Q is built from a label
// (constructor or operator>>), calcDetModels() splits the label and caches the parts in
// static hash tables, so the getters below are plain lookups:
//   getAct(): text after the first '=' (or from the start) up to the first '/'
//   getAwa(): text between the first '/' and the first '|'
//   getExp(): text after the first '|'
//   getUp():  the label with everything up to and including the first '=' removed
// A default-constructed Q gets no cache entries.
DiscreteDomain<int> domQ;
class Q : public DiscreteDomainRV<int,domQ> {
 private:
  static SimpleHash<Q,G> hToAct;
  static SimpleHash<Q,G> hToAwa;
  static SimpleHash<Q,G> hToExp;
  static SimpleHash<Q,Q> hToUp;
  // Compute whichever of this Q's cached fields are still missing, from its label s.
  // The Awa/Exp computations assert that s contains both '/' and '|'.
  void calcDetModels ( string s ) {
    if (!hToAct.contains(*this)) {
      size_t i=s.find('='); i=(string::npos==i)?0:i+1;
      // Test j (not i) for npos: a missing '/' means "up to the end of the label".
      size_t j=s.find('/'); j=(string::npos==j)?s.size():j;
      hToAct.set(*this) = G(s.substr(i,j-i).c_str());
    }
    if (!hToAwa.contains(*this)) {
      size_t i=s.find('/');
      size_t j=s.find('|');
      assert(i!=string::npos && j!=string::npos);
      hToAwa.set(*this) = G(s.substr(i+1,j-1-i).c_str());
    }
    if (!hToExp.contains(*this)) {
      size_t i=s.find('|');
      assert(i!=string::npos);
      hToExp.set(*this) = G(s.substr(i+1).c_str());
    }
    if (!hToUp.contains(*this)) {
      size_t i=s.find('=');
      i=(string::npos==i)?0:i+1;
      // Assigned via the base-class operator from a C string, so the Up Q's own cache
      // entries are not computed here.
      hToUp.set(*this).DiscreteDomainRV<int,domQ>::operator=(s.substr(i).c_str());
    }
  }
 public:
  Q ( )                : DiscreteDomainRV<int,domQ> ( )    { }
  Q ( const char* ps ) : DiscreteDomainRV<int,domQ> ( ps ) { calcDetModels(ps); }
  G getAct ( ) const { return hToAct.get(*this); }
  G getAwa ( ) const { return hToAwa.get(*this); }
  G getExp ( ) const { return hToExp.get(*this); }
  Q getUp  ( ) const { return hToUp. get(*this); }
  // Two-step extraction: "si >> q" captures the target, then ">> delimiter" reads the
  // label into it and fills any missing cache entries for the value read.
  friend pair<StringInput,Q*> operator>> ( StringInput si, Q& m ) { return pair<StringInput,Q*>(si,&m); }
  friend StringInput operator>> ( pair<StringInput,Q*> si_m, const char* psD ) {
    if ( si_m.first == NULL ) return NULL;
    StringInput si=si_m.first>>static_cast<DiscreteDomainRV<int,domQ>&>(*si_m.second)>>psD;
    si_m.second->calcDetModels(si_m.second->getString());
    return si;
  }
};
SimpleHash<Q,G> Q::hToAct;
SimpleHash<Q,G> Q::hToAwa;
SimpleHash<Q,G> Q::hToExp;
SimpleHash<Q,Q> Q::hToUp;
// Null and top-of-stack states.  These are macros, so each use constructs a fresh Q
// (running the constructor's cache check).
#define Q_NIL Q("-/-|-")
#define Q_TOP Q("ROOT/UTT|UTT")


//////////////////////////////////////// Joint Variables

//// R: collection of syntactic variables at all depths in each `reduce' phase...
// Four F values, index 0 = depth 1 ... index 3 = depth 4, delimited by ';'.
typedef DelimitedStaticSafeArray<4,psSemi,F> R;

//// S: collection of syntactic variables at all depths in each `shift' phase...
// Four Q values, index 0 = depth 1 ... index 3 = depth 4, delimited by ';'.
class S : public DelimitedStaticSafeArray<4,psSemi,Q> {
 public:
  // Expected category of the deepest Q that is not Q_NIL, falling back to depth 1's Q.
  operator G() const {
    for ( int iDepth=3; iDepth>0; --iDepth )
      if ( get(iDepth)!=Q_NIL ) return get(iDepth).getExp();
    return get(0).getExp();
  }
  // Final-state comparison is plain equality.
  bool compareFinal ( const S& s ) const { return *this==s; }
};

//// H: the set of all (marginalized) reduce and (modeled) shift variables in the HHMM...
// Joint of a reduce vector (R) and a shift vector (S), delimited per its template
// arguments by "-<>-".  Converts to either part, or to a depth D.
class H : public DelimitedJoint2DRV<psX,R,psDashDiamondDash,S,psX> {
 public:
  operator R() const { return getSub1(); }
  operator S() const { return getSub2(); }
  // D_4..D_1 for the deepest R index (3..0) holding F_0; D_0 if none holds F_0.
  operator D() const {
    const R& rRed = getSub1();
    if ( rRed.get(3)==F_0 ) return D_4;
    if ( rRed.get(2)==F_0 ) return D_3;
    if ( rRed.get(1)==F_0 ) return D_2;
    if ( rRed.get(0)==F_0 ) return D_1;
    return D_0;
  }
};


////////////////////////////////////////////////////////////////////////////////
//
//  Models
//
////////////////////////////////////////////////////////////////////////////////

//// Model of F given D and Q (from previous); assume active cat from previous = awaited cat from above
// FModel conditions this on the depth and the previous Q's Up label; its entries are read
// from lines prefixed "Fr ".
typedef HidVarCPT3DModel<F,D,Q,LogProb> FcptModel;

//// Transition model of Q given D and G (from above) and Q (from previous)
// QModel conditions this on the depth, the expected category of the Q above, and the
// previous Q's Up label; its entries are read from lines prefixed "Qt ".
typedef HidVarCPT4DModel<Q,D,G,Q,LogProb> QtcptModel;

//// Expansion model of Q given D and Q (from above)
// QModel conditions this on the depth and the Up label of the Q above; its entries are
// read from lines prefixed "Qe ".
typedef HidVarCPT3DModel<Q,D,Q,LogProb> QecptModel;


//////////////////////////////////////// "Wrapper" models for individual RVs...

//// Model of F given D and F and Q (from above) and Q (from previous)
class FModel {
 private:
  FcptModel mFcpt;
  static const HidVarCPT1DModel<F,LogProb> mF_0;
  static const HidVarCPT1DModel<F,LogProb> mF_1;
 public:
  static bool F_ROOT_OBS;
  typedef FcptModel::IterVal IterVal;
  LogProb setNext ( FModel::IterVal& f, const D& d, const F& fD, const Q& qP, const Q& qU, int& a ) const {
    //if (d==D_1 && F_ROOT_OBS==false) return mF_0.setNext(f,a);
    //if (d==D_1 && F_ROOT_OBS==true ) return mF_1.setNext(f,a);
    LogProb pr;
    if (fD==F_1 && d==D_1 && F_ROOT_OBS && (qP.getAct().getString() == "UTT" || qP.getAct().getString() == "utt"))
      pr = mF_1.setNext(f,a);
    else if (fD==F_1 && (qU.getExp().getTerm()==F_1 || qU.getExp()==G_NIL))        pr = mF_1.setNext(f,a);
    else if (fD==F_1 && qU.getExp()==qP.getAct() && qP.getAwa()==qP.getExp()) pr = mFcpt.setNext(f,d,qP.getUp(),a);
    else                                                                      pr = mF_0.setNext(f,a);
    // Report error...
    if ( a<-1 && pr==LogProb() ) cerr<<"\nERROR: no condition Fr "<<d<<" "<<qP.getUp()<<"\n\n";
//    cerr<<"  d="<<d << " " << qU<<" "<<qU.getExp()<<" "<<qP.getAct()<<"   "<<qP<<" "<<qP.getAwa()<<" "<<qP.getExp()<<"   f="<<f<<" fobs=" << F_ROOT_OBS << " condition=" << (F_ROOT_OBS!=(F_1==f)) << " a=" << a << "\n";
    // Return fail if result doesn't match root observation...
    if ( a>=-1 && d==D_1 && F_ROOT_OBS!=(F_1==f) ) pr=LogProb();
//    cerr << "Returning: " << pr << endl;
    return pr;
  }
  friend pair<StringInput,FModel*> operator>> ( StringInput si, FModel& m ) { return pair<StringInput,FModel*>(si,&m); }
  friend StringInput operator>> ( pair<StringInput,FModel*> si_m, const char* psD ) { return si_m.first>>"Fr ">>si_m.second->mFcpt>>psD; }
};
const HidVarCPT1DModel<F,LogProb> FModel::mF_0(F_0);
const HidVarCPT1DModel<F,LogProb> FModel::mF_1(F_1);
bool FModel::F_ROOT_OBS = false;


//// Model of Q given D and F and F and Q(from prev) and Q(from above)
// Produces the shift-phase Q for depth d.  The case is selected by fD (flag from the depth
// below) and f (flag at this depth); qP is this depth's previous Q and qU the Q just chosen
// for the depth above.
class QModel {
 private:
  QtcptModel mQtcpt;
  QecptModel mQecpt;
  static const HidVarCPT1DModel<Q,LogProb> mQ_NIL;
  static       HidVarCPT2DModel<Q,Q,LogProb> mQ_COPY;
 public:
  typedef QtcptModel::IterVal IterVal;
  // Set q to its next value and return its probability; a is forwarded to the models.
  LogProb setNext ( QModel::IterVal& q, const D& d, const F& fD, const F& f, const Q& qP, const Q& qU, int& a ) const {
    const bool bBelowFinal = (fD==F_1);
    const bool bHereFinal  = (f==F_1);
    LogProb pr;
    if ( !bBelowFinal ) {
      // 0 0 (copy) case: keep qP; its identity entry in mQ_COPY is added on first use.
      if ( !mQ_COPY.contains(qP) ) mQ_COPY.setProb(qP,qP)=1.0;
      pr = mQ_COPY.setNext(q,qP,a);
    }
    else if ( !bHereFinal ) {
      // 0 1 (transition) case.
      pr = mQtcpt.setNext(q,d,qU.getExp(),qP.getUp(),a);
    }
    else if ( qU.getExp().getTerm()==F_1 || qU.getExp()==G_NIL ) {
      // 1 1 expansion to null.
      pr = mQ_NIL.setNext(q,a);
    }
    else {
      // 1 1 (expansion) case; report a missing condition when a<-1.
      pr = mQecpt.setNext(q,d,qU.getUp(),a);
      if ( a<-1 && pr==LogProb() ) cerr<<"\nERROR: no condition Qe "<<d<<" "<<qU.getUp()<<"\n\n";
    }
    return pr;
  }
  // Model text lines are prefixed "Qt " (transition) or "Qe " (expansion); the "Qt " form
  // is tried first.
  friend pair<StringInput,QModel*> operator>> ( StringInput si, QModel& m ) { return pair<StringInput,QModel*>(si,&m); }
  friend StringInput operator>> ( pair<StringInput,QModel*> si_m, const char* psD ) {
    StringInput si;
    si = si_m.first>>"Qt ">>si_m.second->mQtcpt>>psD;
    if ( si!=NULL ) return si;
    si = si_m.first>>"Qe ">>si_m.second->mQecpt>>psD;
    if ( si!=NULL ) return si;
    return StringInput(NULL);
  }
};
const HidVarCPT1DModel<Q,LogProb> QModel::mQ_NIL(Q_NIL);
HidVarCPT2DModel<Q,Q,LogProb> QModel::mQ_COPY;


//////////////////////////////////////// Joint models...

//////////////////// Reduce phase...

//// Model of R given S
// Reduce phase: chooses the four F flags bottom-up (depth 4 first, depth 1 last).  Each
// depth d is conditioned on the flag just chosen for depth d+1 and on the previous Q values
// at depths d and d-1 (from sP).
class RModel : public SingleFactoredModel<FModel> {
 private:
  // Source of the fixed flag used below depth 4.
  static const HidVarCPT1DModel<F,LogProb> mF_1;
 public:
  typedef ComplexArrayIteratedModeledRV<FModel::IterVal,psSemi,4> IterVal;
  // Fill r with the next reduce hypothesis given previous shift state sP; returns the
  // product of the four per-depth probabilities (no early exit on a failed factor).
  LogProb setNext ( RModel::IterVal& r, const S& sP, int& a ) const {
    const FModel& mF = getM1();
    // Q_TOP stands in as the state above depth 1.
    const Q  qP0 = Q_TOP;
    const Q& qP1 = sP.get(0);
    const Q& qP2 = sP.get(1);
    const Q& qP3 = sP.get(2);
    const Q& qP4 = sP.get(3);
    // Throwaway counter so setting the fixed flag below does not touch a.
    int aCtr;
    FModel::IterVal  f5; mF_1.setNext(f5,aCtr=-1); // = F_1;
    FModel::IterVal& f4 = r.iter_array.set(3);
    FModel::IterVal& f3 = r.iter_array.set(2);
    FModel::IterVal& f2 = r.iter_array.set(1);
    FModel::IterVal& f1 = r.iter_array.set(0);
    LogProb pr;
    // Order matters: each call reads the flag produced by the one before it.
    pr  = mF.setNext(f4,4,f5,qP4,qP3,a);
    pr *= mF.setNext(f3,3,f4,qP3,qP2,a);
    pr *= mF.setNext(f2,2,f3,qP2,qP1,a);
    pr *= mF.setNext(f1,1,f2,qP1,qP0,a);
    return pr;
  }
};
const HidVarCPT1DModel<F,LogProb> RModel::mF_1(F_1);


//////////////////// Shift phase...

//// Model of S given R and S
// Shift phase: chooses the four Q values top-down (depth 1 first, depth 4 last).  Each
// depth d is conditioned on the reduce flags at depths d+1 and d (from r), its previous Q
// (from sP), and the Q just chosen for depth d-1.
class SModel : public SingleFactoredModel<QModel> {
 private:
  // Sources of the fixed flag below depth 4 and the fixed state above depth 1.
  static const HidVarCPT1DModel<F,LogProb> mF_1;
  static const HidVarCPT1DModel<Q,LogProb> mQ_TOP;
 public:
  typedef ComplexArrayIteratedModeledRV<QModel::IterVal,psSemi,4> IterVal;
  // Fill s with the next shift hypothesis given reduce hypothesis r and previous shift
  // state sP; returns the product of the four per-depth probabilities.
  LogProb setNext ( SModel::IterVal& s, const RModel::IterVal& r, const S& sP, int& a ) const {
    const QModel& mQ = getM1();
    const Q&  qP1 = sP.get(0);
    const Q&  qP2 = sP.get(1);
    const Q&  qP3 = sP.get(2);
    const Q&  qP4 = sP.get(3);
    // Throwaway counter so setting the fixed values below does not touch a.
    int aCtr;
    FModel::IterVal   f5; mF_1.setNext(f5,aCtr=-1);  //= F_1;
    const FModel::IterVal&  f4  = r.iter_array.get(3);
    const FModel::IterVal&  f3  = r.iter_array.get(2);
    const FModel::IterVal&  f2  = r.iter_array.get(1);
    const FModel::IterVal&  f1  = r.iter_array.get(0);
    QModel::IterVal  q0; mQ_TOP.setNext(q0,aCtr=-1); //= Q_TOP;
    QModel::IterVal& q1 = s.iter_array.set(0);
    QModel::IterVal& q2 = s.iter_array.set(1);
    QModel::IterVal& q3 = s.iter_array.set(2);
    QModel::IterVal& q4 = s.iter_array.set(3);
    LogProb pr;
    // Order matters: each call reads the Q produced by the one before it.
    pr  = mQ.setNext(q1,1,f2,f1,qP1,Q(q0),a);
    pr *= mQ.setNext(q2,2,f3,f2,qP2,Q(q1),a);
    pr *= mQ.setNext(q3,3,f4,f3,qP3,Q(q2),a);
    pr *= mQ.setNext(q4,4,f5,f4,qP4,Q(q3),a);
    return pr;
  }
};
const HidVarCPT1DModel<F,LogProb> SModel::mF_1(F_1);
const HidVarCPT1DModel<Q,LogProb> SModel::mQ_TOP(Q_TOP);


//////////////////// Overall...

//// Model of H=R,S given S
// Combines the reduce model (RModel, first factor) and the shift model (SModel, second
// factor): given the previous shift state, it proposes a reduce vector and then a shift
// vector conditioned on it.
class HModel : public DoubleFactoredModel<RModel,SModel> {
 public:
  typedef ComplexDoubleIteratedModeledRV<psLbrack,RModel::IterVal,psRbrack,SModel::IterVal,psX> IterVal;
  // Copy the four shift-phase Q values of h into s; returns s.
  S& setTrellDat ( S& s, const HModel::IterVal& h ) const {
    for ( int iDepth=0; iDepth<4; ++iDepth )
      s.set(iDepth) = h.iter_second.iter_array.get(iDepth);
    return s;
  }
  // Return the four reduce-phase F values of h as an R.
  R setBackDat ( const HModel::IterVal& h ) const {
    R rOut;
    for ( int iDepth=0; iDepth<4; ++iDepth )
      rOut.set(iDepth) = h.iter_first.iter_array.get(iDepth);
    return rOut;
  }
  // Advance h to its next (reduce, shift) hypothesis given previous shift state sP and
  // return its joint probability; a is forwarded to both factors.
  LogProb setNext ( HModel::IterVal& h, const S& sP, int& a ) const {
    RModel::IterVal& rIter = h.iter_first;
    LogProb prJoint = getM1().setNext(rIter,sP,a);
    // A default LogProb() (this file's failure value) from the reduce phase ends the
    // hypothesis; the shift phase is skipped.
    if ( prJoint==LogProb() ) return prJoint;
    prJoint *= getM2().setNext(h.iter_second,rIter,sP,a);
    return prJoint;
  }
};
#endif
