///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// This file is part of ModelBlocks. Copyright 2009, ModelBlocks developers. //
//                                                                           //
//    ModelBlocks is free software: you can redistribute it and/or modify    //
//    it under the terms of the GNU General Public License as published by   //
//    the Free Software Foundation, either version 3 of the License, or      //
//    (at your option) any later version.                                    //
//                                                                           //
//    ModelBlocks is distributed in the hope that it will be useful,         //
//    but WITHOUT ANY WARRANTY; without even the implied warranty of         //
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the          //
//    GNU General Public License for more details.                           //
//                                                                           //
//    You should have received a copy of the GNU General Public License      //
//    along with ModelBlocks.  If not, see <http://www.gnu.org/licenses/>.   //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////

#include <iostream>
#include <fstream>
#include <sstream>
#include <getopt.h>

#include "nl-randvar.h"
#include "nl-dtree.h"
#include "nl-iomacros.h"
//#include "nl-modelfile.h"

//char psX[]="";
//char psUscUscUsc[]="___";

//#define NOISY 1

// Binary rules are tried only when their score from modGGgivG exceeds
// LogProb(THRESHOLD) (see the chart-filling and forest loops in main()).
// NOTE(review): presumably LogProb(0.0) denotes probability zero, so every rule
// with a nonzero score passes -- confirm against LogProb's double constructor.
#define THRESHOLD 0.0

////////////////////////////////////////////////////////////////////////////////
//
//  Random Variables
//
////////////////////////////////////////////////////////////////////////////////

#include "TextObsVars.h"
#include "PCFGLangModel.h"
typedef G C;
#include "TextObsModel.h"

////////////////////////////////////////////////////////////////////////////////
//
//  helper classes, functions
//
////////////////////////////////////////////////////////////////////////////////

// Verbose diagnostics switch: set to true by -v/--verbose in main() and checked
// before every chart-filling trace message.
bool NOISY = false;

class ChartCell {
 private:
  static const ChartCell ccDummy;
  bool             bRooted;
  G                g;
  int              k;
  const ChartCell* pcc1;
  const ChartCell* pcc2;
  LogProb          pr;
 public:
  ChartCell ( )                                                                      : bRooted(false), g(),   k(0),  pcc1(NULL), pcc2(NULL), pr()    { }
  ChartCell ( G& gA, LogProb prA )                                                   : bRooted(false), g(gA), k(0),  pcc1(NULL), pcc2(NULL), pr(prA) { }
  ChartCell ( G& gA, const ChartCell& x1, LogProb prA )                              : bRooted(false), g(gA), k(0),  pcc1(&x1),  pcc2(NULL), pr(prA) { }
  ChartCell ( G& gA, int kA, const ChartCell& x1, const ChartCell& x2, LogProb prA ) : bRooted(false), g(gA), k(kA), pcc1(&x1),  pcc2(&x2),  pr(prA) { }
  void setRooted ( )       { bRooted=true; }
  bool getRooted ( ) const { return bRooted; }
  bool operator!= ( const ChartCell& x ) const { return (g!=x.g || k!=x.k || pcc1!=x.pcc1 || pcc2!=x.pcc2 || pr!=x.pr); }
  const G&         getG    ( ) const { return g;  }
  int              getK    ( ) const { return k;  }
  const ChartCell& getX1   ( ) const { return (pcc1!=NULL) ? *pcc1 : ccDummy; }
  const ChartCell& getX2   ( ) const { return (pcc2!=NULL) ? *pcc2 : ccDummy; }
  LogProb          getProb ( ) const { return pr; }
  void prune      ( ) { /*???*/ }
  void writeBest  ( FILE* pf, List<X>& lx ) const { 
    fprintf(pf," (%s",g.getString().c_str());
    if(pcc1)pcc1->writeBest(pf,lx);
    if(pcc2)pcc2->writeBest(pf,lx);
    if(!pcc1 && !pcc2){fprintf(pf," ");fprintf(pf,"%s",lx.getFirst()->getString().c_str());lx.pop();}
    fprintf(pf,")"); }
};
const ChartCell ChartCell::ccDummy;


////////////////////////////////////////////////////////////////////////////////

// Unary closure for one chart span (i,j).
// For every unary-rule parent gP in modGgivG, scores building gP over (i,j)
// from ccG as ccG.getProb() * modGgivG.getProb(child,gP).  If that beats the
// score already stored at (i,j,gP), the parent cell is replaced (with a
// back-pointer to ccG) and the closure is re-run from the improved parent.
//   accChart : chart being filled, cells addressed by (start, end, category)
//   i, j     : span of ccG
//   ccG      : cell to extend upward; parents keep a pointer to it, so it must
//              be a cell stored in accChart
//   modGgivG : unary rule model
void tryUnary ( SafeArray3D<Id<int>,Id<int>,G,ChartCell>& accChart, int i, int j, const ChartCell& ccG, const GuModel& modGgivG ) {
  // A chart cell's category never changes, so the child category is read once.
  G gChild = ccG.getG();
  for ( GuModel::const_iterator iter = modGgivG.begin(); iter!=modGgivG.end(); iter++ ) {
    G          gP  = iter->first.getX1();
    ChartCell& ccP = accChart.set(i,j,gP);
    LogProb    pr  = ccG.getProb()*modGgivG.getProb(gChild,gP);
    if ( pr>ccP.getProb() ) {
      ccP = ChartCell(gP,ccG,pr);
      // Diagnostics go to stderr so they never interleave with parses/forest
      // edges written to stdout.
      if ( NOISY )
        cerr << "unary match! " << i << " " << j << " " << gP.getString() << " -> " << gChild.getString() << " -> " << pr.toInt() << endl;
      tryUnary ( accChart, i, j, ccP, modGgivG );
    }
  }
}


////////////////////////////////////////////////////////////////////////////////
//
//  Main Function (pipe data input)
//
////////////////////////////////////////////////////////////////////////////////

// Long command-line options for getopt_long in main().
// --forest and --verbose accept an optional value that main() ignores, so both
// the bare flag and spellings like "--verbose=TRUE" switch the mode on.
// (With no_argument, "--verbose=TRUE" made getopt_long return '?', and verbose
// mode silently stayed off.)
static struct option long_options[] = {
  {"forest",  optional_argument, 0, 'f'},
  {"help",    no_argument,       0, 'h'},
  {"verbose", optional_argument, 0, 'v'},
  {0, 0, 0, 0} };

// Prints command-line help to stderr.
//   progname: program name shown in the usage line (main() passes argv[0]).
// The parameter is const: the name is only printed.
void printUsage(const char* progname) {
  fprintf(stderr,"Usage: cat <RAW_FILE> | %s [OPTIONS]... [MODEL_FILE1] [MODEL_FILE2] ...\n",progname);
  fprintf(stderr,"Parses each line of RAW_FILE (whitespace-separated words) with the PCFG models in MODEL_FILE1, 2, etc.\n");
  fprintf(stderr," System variable Settings:\n");
  fprintf(stderr,"  -f, --forest\t\tPrint the chart forest instead of the single best parse\n");
  fprintf(stderr,"  -h, --help\t\tPrint this message\n");
  fprintf(stderr,"  -v, --verbose\t\tPrint chart-filling diagnostics\n");
}

int main (int nArgs, char* argv[]) {

  //modGgivG.NOISY_READ = true;

  bool FOREST = false;

  GGModel modGGgivG;
  GuModel modGgivG;
  GModel  modGr;
  GModel  modG;
  XModel  modX;

  // Parse command-line options
  char* progname = strdup(argv[0]);
  int option_index = 0;
  char opt;
  while( (opt = getopt_long(nArgs, argv, "fhv", long_options, &option_index)) != -1 ) {
    switch(opt) {
    case 'f': FOREST=true; break;
    case 'h': printUsage(progname); exit(0);
    case 'v': NOISY=true; break;
    default: break;
    }
  }

  for ( int a=optind; a<nArgs; a++ ) {
    FILE* pf = fopen(argv[a],"r"); assert(pf);                                // READ MODEL FILE
    cerr << "Loading model \'" << argv[a] << "\'...\n";
    int c=' '; int i=0; int line=1; Array<char*> aps(100); String sBuff(1000);   // Lookahead/ctrs/buffers
    CONSUME_ALL ( pf, c, WHITESPACE(c), line);                                   // Get to first record
    while ( c!=-1 && c!='\0' && c!='\5' ) {                                      // For each record
      CONSUME_STR ( pf, c, (c!='\n' && c!='\0' && c!='\5'), sBuff, i, line );    //   Consume line
      String sBuff2(sBuff); StringInput si(sBuff2.c_array());
      //sBuff.split(aps," :=");                                                  //   Split into fields
      //cerr<<si.c_str()<<endl;
      if ( !( sBuff[0]=='#' ||                                                   //   Accept comments/fields
              //si>>mH>>"\0"!=NULL ||
              si>>modX>>"\0"!=NULL ||
	      si>>"G ">>modG>>"\0"!=NULL || 
	      si>>"Gr ">>modGr>>"\0"!=NULL || 
	      si>>"Gu ">>modGgivG>>"\0"!=NULL || 
	      si>>"GG ">>modGGgivG>>"\0"!=NULL ) )
	      //si>>"Pc ">>modPgivCvar>>"\0"!=NULL ||
	      //si>>"Pw ">>modPgivWs>>"\0"!=NULL ||
	      //si>>"PwDT ">>modPgivWdt>>"\0"!=NULL ||
	      //si>>"P ">>modP>>"\0"!=NULL ) )
	      ////mH.readFields(aps) || mX.readFields(aps)*/ ) )
	//	cerr<<"\nERROR: "<<aps.size()<<"-arg "<<((aps.size()>0)?aps[0]:"??")<<" in line "<<line<<"\n\n";
	cerr<<"\nERROR: can't parse \'"<<sBuff<<"\' in line "<<line<<"\n\n";
      CONSUME_ALL ( pf, c, WHITESPACE(c), line);                                 //   Consume whitespace
      if ( line%100000==0 ) cerr<<"  "<<line<<" lines read...\n";                //   Progress for big models
    }
    cerr << "Model \'" << argv[a] << "\' loaded.\n";
  }
//   for ( int i=optind; i<nArgs; i++ ) {
//     cerr << "Loading model \'" << argv[i] << "\'\n";
//     FILE* pf = fopen(argv[i],"r"); 
//     assert(pf);
//     processModelFilePtr ( pf, readFields );
//     cerr << "Model \'" << argv[i] << "\' loaded.\n";
//   }

  ifstream fin;
  List<X>  lx;
  string   sLine;
  int ctr=0;

  //  if (nArgs>2) fin.open(argv[2]);

  // Read in each line...
  //  while ( (nArgs>2) ? getline(fin,sLine) : getline(cin,sLine) ) {
  while ( getline(cin,sLine) ) {
    ctr++;
    cerr << ctr ;

    // Read in each word...
    stringstream ss(stringstream::in|stringstream::out);
    lx=List<X>();
    ss<<sLine;
    for (string s; ss>>s; ) {
      cerr << " " << s;
      lx.add() = X(s.c_str()); //.read(s);
    }

    cerr << endl;

    // Allocate chart...
    SafeArray3D<Id<int>,Id<int>,G,ChartCell> accChart;
    int n = lx.getCard();
    accChart.init ( n+1, n+2, G::getDomain().getSize() );

    // Initialize chart with terminals (POSs)...
    int i=0;
    Listed(X)* px;
    foreach(px,lx) {

      //XModel::RandVarType x; x = *static_cast<X*>(px);
      //XModel::RandVarType x; x.set(*px,mX);
      X& x = *px;
      if ( NOISY )
        cerr << "word " << px->getString() << endl;
      //for ( CPT2DModel<P,C,LogProb>::const_iterator iter = modPgivCvar.begin(); iter!=modPgivCvar.end(); iter++ ) {
      //  C c = iter->first.getX1();
      G g;

      for ( bool bG=g.setFirst(); bG; bG=g.setNext() ) {
        if ( modX.getProb(x,g)!=LogProb() ) {
          ChartCell* pcc = &accChart.set(i,i+1,g);
          *pcc = ChartCell ( g, modX.getProb(x,g) );
          tryUnary ( accChart, i, i+1, *pcc, modGgivG );
          if ( NOISY )
            cerr << "pos match! " << i << " " << i+1 << " " << g.getString() << " -> " << px->getString()
                 << " " << modX.getProb(x,g).toInt() << "=" << pcc->getProb().toInt() << endl;
          if ( FOREST )
            fprintf ( stdout, "%d %s %d : %d %s %d = %d\n",
                      i, g.getString().c_str(), i+1,
                      i, px->getString().c_str(), i+1, modX.getProb(x,g).toInt() );
        }
      }
      
      /*
      ModeledWgivP w = *pw;
      #ifdef NOISY /////////////////////////
      cout << "word " << w.getString() << endl;
      #endif ///////////////////////////////
      for ( CPT2DModel<P,C,LogProb>::const_iterator iter = modPgivC.getConds().begin(); iter!=modPgivC.getConds().end(); iter++ ) {
        C c = iter->first.getX1();
        ModeledPgivC p;
        for ( bool bp=p.setFirst(c); bp; bp=p.setNext(c) ) {
          if ( w.getProb(p)!=LogProb() ) {
            ChartCell* pcc = &accChart.set(i,i+1,c);
            *pcc = ChartCell ( c, w.getProb(p)*p.getProb(c) );
            tryUnary ( accChart, i, i+1, *pcc );
            #ifdef NOISY /////////////////////////
            cout << "pos match! " << i << " " << i+1 << " " << c.getString() << " -> " << p.getString() << " -> " << w.getString()
                 << " " //<< w.getProb(p).toInt() 
              //<< modPgivW.getProb(p,w).toInt() << "/" << modP.getProb(p).toInt()
                 << "*" << p.getProb(c).toInt() << "=" << pcc->getProb().toInt() << endl;
              //(w.getProb(p)*p.getProb(c)).toInt() << endl;
            #endif ///////////////////////////////
          }
        }
      }
      */

      i++;
    }

    // Fill chart with nonterminals...
    for ( int di=1; di<=n; di++ ) {
      cerr << "  " << di << endl;
      for ( int i=0; i<=n-di; i++ ) {
        int j=i+di;
        for ( int k=i+1; k<=j-1; k++ ) {
          for ( GGModel::const_iterator iter = modGGgivG.begin(); iter!=modGGgivG.end(); iter++ ) {
            G                gP = iter->first.getX1();
            GGModel::IterVal gg;
            for ( bool bgg=modGGgivG.setFirst(gg,gP); bgg; bgg=modGGgivG.setNext(gg,gP) ) {
              if ( modGGgivG.getProb(gg,gP)>LogProb(THRESHOLD) ) {
                const G& g1 = gg.first;  const ChartCell& cc1 = accChart.get(i,k,g1);
                const G& g2 = gg.second; const ChartCell& cc2 = accChart.get(k,j,g2);
                if ( cc1!=ChartCell() && cc2!=ChartCell() ) {
                  ChartCell& cc = accChart.set(i,j,gP);
                  LogProb pr = cc1.getProb()*cc2.getProb()*modGGgivG.getProb(gg,gP);
                  if ( NOISY )
                    cerr << "      binary match! " << i << " " << k << " " << j << " " << gP.getString()
                         << " -> " << g1.getString() << " " << g2.getString() << " "
                         << cc1.getProb().toInt() << "*" << cc2.getProb().toInt() << "*" << modGGgivG.getProb(gg,gP).toInt()
                         << "=" << pr.toInt() << endl;
                  //if ( FOREST )
                  //  fprintf ( stdout, "%d %s %d : %d %s %d %s %d = %d\n",
                  //            i, gP.getString().c_str(), j,
                  //            i, g1.getString().c_str(), k, g2.getString().c_str(), j, pr.toInt() );
                  if ( pr>cc.getProb() ) {
                    cc = ChartCell(gP,k,cc1,cc2,pr);
                  tryUnary ( accChart, i, j, cc, modGgivG );
                  }
                }
              }
            }
          }
        }
      }
    }


    if ( FOREST ) {

      //ModeledC c;
      GModel::IterVal g;
      for ( bool bg=modGr.setFirst(g); bg; bg=modGr.setNext(g) )
        if ( accChart.get(0,n,g).getProb()*modGr.getProb(g) != LogProb() ) {
          accChart.set(0,n,g).setRooted();
        }

      for ( int di=n; di>=1; di-- ) {
        cerr << "  " << di << endl;
        for ( int i=0; i<=n-di; i++ ) {
          int j=i+di;
          for ( GGModel::const_iterator iter = modGGgivG.begin(); iter!=modGGgivG.end(); iter++ ) {
            G                gP = iter->first.getX1();
            const ChartCell& cc  = accChart.get(i,j,gP);
            if ( cc.getRooted() ) {
              for ( int k=i+1; k<=j-1; k++ ) {
                GGModel::IterVal gg;
                for ( bool bgg=modGGgivG.setFirst(gg,gP); bgg; bgg=modGGgivG.setNext(gg,gP) ) {
                  if ( modGGgivG.getProb(gg,gP)>LogProb(THRESHOLD) ) {
                    const G& g1 = gg.first;  ChartCell& cc1 = accChart.set(i,k,g1);
                    const G& g2 = gg.second; ChartCell& cc2 = accChart.set(k,j,g2);
                    if ( cc1!=ChartCell() && cc2!=ChartCell() ) {
                      cc1.setRooted();
                      cc2.setRooted();
                      LogProb pr = cc1.getProb()*cc2.getProb()*modGGgivG.getProb(gg,gP);
                      fprintf ( stdout, "%d %s %d : %d %s %d %s %d = %d\n",
                                i, gP.getString().c_str(), j,
                                i, g1.getString().c_str(), k, g2.getString().c_str(), j, pr.toInt() );
                    }
                  }
                }
              }
            }
          }
        }
      }
      cout << "-" << endl;

    } else {
      // Print best...
      LogProb         prBest=LogProb();
      G               gBest;
      GModel::IterVal g;
      for ( bool bg=modGr.setFirst(g); bg; bg=modGr.setNext(g) )
        if ( LogProb()==prBest || accChart.get(0,n,g).getProb()*modGr.getProb(g)>prBest ) {
		  gBest  = g;
		  prBest = accChart.get(0,n,g).getProb()*modGr.getProb(g);
        }
      if ( prBest>LogProb() )
        accChart.get(0,n,gBest).writeBest(stdout,lx);
      else
        cout << "[ERROR failed]";
      cout << endl;
    }
  }

  //if (nArgs>2) fin.close();
}

