///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// This file is part of ModelBlocks. Copyright 2009, ModelBlocks developers. //
//                                                                           //
//    ModelBlocks is free software: you can redistribute it and/or modify    //
//    it under the terms of the GNU General Public License as published by   //
//    the Free Software Foundation, either version 3 of the License, or      //
//    (at your option) any later version.                                    //
//                                                                           //
//    ModelBlocks is distributed in the hope that it will be useful,         //
//    but WITHOUT ANY WARRANTY; without even the implied warranty of         //
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the          //
//    GNU General Public License for more details.                           //
//                                                                           //
//    You should have received a copy of the GNU General Public License      //
//    along with ModelBlocks.  If not, see <http://www.gnu.org/licenses/>.   //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////

#include <iostream>
#include <fstream>
#include <sstream>
#include <getopt.h>

#include "nl-randvar.h"
#include "nl-dtree.h"
#include "nl-iomacros.h"
#include "nl-refrv.h"
//#include "nl-modelfile.h"

//char psX[]="";
//char psUscUscUsc[]="___";

//#define NOISY 1

#define THRESHOLD 0.0

////////////////////////////////////////////////////////////////////////////////
//
//  Random Variables
//
////////////////////////////////////////////////////////////////////////////////

#include "TextObsVars.h"
#include "PCFGLangModel-ml.h"
#include "TextObsModel-svs.h"

////////////////////////////////////////////////////////////////////////////////
//
//  helper classes, functions
//
////////////////////////////////////////////////////////////////////////////////

// Global verbosity switch.  main() turns it on for -v/--verbose; when it is
// true the parser emits trace messages about chart matches.
bool NOISY(false);

class ChartCell {
 private:
  static const ChartCell xDummy;
  bool             bRooted;
  G                g;
  int              k;
  const ChartCell* px1;
  const ChartCell* px2;
  LogProb          pr;
 public:
  ChartCell ( )                                                                      : bRooted(false), g(),   k(0),  px1(NULL), px2(NULL), pr()    { }
  ChartCell ( G& gA, LogProb prA )                                                   : bRooted(false), g(gA), k(0),  px1(NULL), px2(NULL), pr(prA) { }
  ChartCell ( G& gA, const ChartCell& x1, LogProb prA )                              : bRooted(false), g(gA), k(0),  px1(&x1),  px2(NULL), pr(prA) { }
  ChartCell ( G& gA, int kA, const ChartCell& x1, const ChartCell& x2, LogProb prA ) : bRooted(false), g(gA), k(kA), px1(&x1),  px2(&x2),  pr(prA) { }
  void setRooted ( )       { bRooted=true; }
  bool getRooted ( ) const { return bRooted; }
  bool operator!= ( const ChartCell& x ) const { return (g!=x.g || k!=x.k || px1!=x.px1 || px2!=x.px2 || pr!=x.pr); }
  const G&         getG    ( ) const { return g;  }
  int              getK    ( ) const { return k;  }
  const ChartCell& getX1   ( ) const { return (px1!=NULL) ? *px1 : xDummy; }
  const ChartCell& getX2   ( ) const { return (px2!=NULL) ? *px2 : xDummy; }
  LogProb          getProb ( ) const { return pr; }
  void prune      ( ) { /*???*/ }
  void writeBest  ( FILE* pf, List<W>& lw ) const { 
    fprintf(pf," (%s{%s}",g.first.getString().c_str(),g.second.getString().c_str());
    if(px1)px1->writeBest(pf,lw);
    if(px2)px2->writeBest(pf,lw);
    if(!px1 && !px2){fprintf(pf," ");fprintf(pf,"%s",lw.getFirst()->getString().c_str());lw.pop();}
    fprintf(pf,")"); }
};
const ChartCell ChartCell::xDummy;

////////////////////////////////////////////////////////////////////////////////

// Contents of one chart span (i,j): maps each category G to the best ChartCell
// found for it in that span.
// NOTE(review): ChartCell stores raw pointers to cells held in these maps, so
// this relies on SimpleHash keeping element addresses stable across
// insertions. Confirm this.
// The commented-out lines below are alternative, unused definitions.
typedef SimpleHash<G,ChartCell> RefCell;
//std::map<G2Cell,G2Cell> domCell;
//typedef RefRV<G2Cell,domCell> RefCell;
//typedef SafePtr<G2Cell> RefCell;
//typedef pair<G,ChartCell> GCell;
//std::map<GCell,GCell> domCell;
//typedef RefRV<GCell,domCell> RefCell;


////////////////////////////////////////////////////////////////////////////////

/// Applies unary rules above cell xG within span (i,j), recursively.
/// For every parent category gP listed in modGgivG, it scores deriving gP from
/// xG's category. If that score beats the cell currently stored for gP at
/// (i,j), the stored cell is replaced by one pointing back to xG, and unary
/// rules are then tried above the new cell.
/// Each recursive step requires a strict score improvement.
void tryUnary ( SafeArray2D<Id<int>,Id<int>,RefCell>& axChart, int i, int j, const ChartCell& xG, const GuModel& modGgivG ) {
  // xG's category cannot change inside the loop. A cell assigned below only
  // aliases xG when gP equals this category, and then g is rewritten to the
  // same value. So the category is fetched once.
  G gg = xG.getG();
  for ( GuModel::const_iterator iter = modGgivG.begin(); iter!=modGgivG.end(); iter++ ) {
    G          gP = iter->first.getX1();
    ChartCell& xP = axChart.set(i,j).set(gP);
    LogProb    pr = xG.getProb()*modGgivG.getProb(gg,gP);
    if ( pr>xP.getProb() ) {
      xP = ChartCell(gP,xG,pr);
      // Diagnostics go to stderr, like the rest of the parser, so they cannot
      // corrupt the parses written to stdout.
      if ( NOISY )
        cerr << "unary match! " << i << " " << j << " " << gP << " -> " << gg << " -> " << pr.toInt() << endl;
      tryUnary ( axChart, i, j, xP, modGgivG );
    }
  }
}


////////////////////////////////////////////////////////////////////////////////
//
//  Main Function (pipe data input)
//
////////////////////////////////////////////////////////////////////////////////

// Long-option table for getopt_long().
// Each entry maps to the same short option letter in main's "fhv" string.
// "forest" is declared optional_argument, so both "--forest" and
// "--forest=VALUE" are accepted; main ignores VALUE.
// The all-zero entry terminates the table.
static struct option long_options[] = {
  { "forest",  optional_argument, NULL, 'f' },
  { "help",    no_argument,       NULL, 'h' },
  { "verbose", no_argument,       NULL, 'v' },
  { NULL,      0,                 NULL, 0   }
};

/// Writes command-line help to stderr.
/// @param progname  program name shown in the usage line (main passes argv[0]).
/// Each option line matches its entry in long_options:
///  - --verbose is no_argument, so "--verbose=TRUE" (as previously advertised)
///    would be rejected by getopt_long.
///  - --forest takes an optional argument that is ignored.
void printUsage(const char* progname) {
  fprintf(stderr,"Usage: cat <RAW_FILE> | %s [OPTIONS]... [MODEL_FILE1] [MODEL_FILE2] ...\n",progname);
  fprintf(stderr,"Parses each line of RAW_FILE (one sentence per line) using the models in MODEL_FILE1, 2, etc.\n");
  fprintf(stderr," System variable Settings:\n");
  fprintf(stderr,"  -f, --forest\t\tPrint the parse forest instead of the best parse\n");
  fprintf(stderr,"  -h, --help\t\tPrint this message\n");
  fprintf(stderr,"  -v, --verbose\t\tPrint diagnostic trace messages\n");
}

int main (int nArgs, char* argv[]) {

  //modGgivG.NOISY_READ = true;

  bool FOREST = false;

  GGModel modGGgivG;
  MModel  modM;
  LModel  modL;
  GuModel modGgivG;
  GModel  modGr;
  CModel  modC;
  HWModel modHW;
  OModel  mO;

  // Parse command-line options
  char* progname = strdup(argv[0]);
  int option_index = 0;
  char opt;
  while( (opt = getopt_long(nArgs, argv, "fhv", long_options, &option_index)) != -1 ) {
    switch(opt) {
    case 'f': FOREST=true; break;
    case 'h': printUsage(progname); exit(0);
    case 'v': NOISY=true; break;
    default: break;
    }
  }

  for ( int a=optind; a<nArgs; a++ ) {
    FILE* pf = fopen(argv[a],"r"); assert(pf);                                // READ MODEL FILE
    cerr << "Loading model \'" << argv[a] << "\'...\n";
    int c=' '; int i=0; int line=1; Array<char*> aps(100); String sBuff(1000);   // Lookahead/ctrs/buffers
    CONSUME_ALL ( pf, c, WHITESPACE(c), line);                                   // Get to first record
    while ( c!=-1 && c!='\0' && c!='\5' ) {                                      // For each record
      CONSUME_STR ( pf, c, (c!='\n' && c!='\0' && c!='\5'), sBuff, i, line );    //   Consume line
      String sBuff2(sBuff); StringInput si(sBuff2.c_array());
      //sBuff.split(aps," :=");                                                  //   Split into fields
      //cerr<<si.c_str()<<endl;
      if ( !( sBuff[0]=='#' ||                                                   //   Accept comments/fields
              //si>>mH>>"\0"!=NULL ||
              si>>mO>>"\0"!=NULL ||
	      si>>"C ">>modC>>"\0"!=NULL || 
	      si>>"Gr ">>modGr>>"\0"!=NULL || 
	      si>>"Gu ">>modGgivG>>"\0"!=NULL || 
	      //si>>"GG ">>modGGgivG>>"\0"!=NULL ) )
	      si>>"HW ">>modHW>>"\0"!=NULL || 
	      si>>"M ">>modM>>"\0"!=NULL || 
	      si>>"L ">>modL>>"\0"!=NULL ) )
	      //si>>"Pc ">>modPgivCvar>>"\0"!=NULL ||
	      //si>>"Pw ">>modPgivWs>>"\0"!=NULL ||
	      //si>>"PwDT ">>modPgivWdt>>"\0"!=NULL ||
	      //si>>"P ">>modP>>"\0"!=NULL ) )
	      ////mH.readFields(aps) || mO.readFields(aps)*/ ) )
	//	cerr<<"\nERROR: "<<aps.size()<<"-arg "<<((aps.size()>0)?aps[0]:"??")<<" in line "<<line<<"\n\n";
	cerr<<"\nERROR: can't parse \'"<<sBuff<<"\' in line "<<line<<"\n\n";
      CONSUME_ALL ( pf, c, WHITESPACE(c), line);                                 //   Consume whitespace
      if ( line%100000==0 ) cerr<<"  "<<line<<" lines read...\n";                //   Progress for big models
    }
    cerr << "Model \'" << argv[a] << "\' loaded.\n";
  }
//   for ( int i=optind; i<nArgs; i++ ) {
//     cerr << "Loading model \'" << argv[i] << "\'\n";
//     FILE* pf = fopen(argv[i],"r"); 
//     assert(pf);
//     processModelFilePtr ( pf, readFields );
//     cerr << "Model \'" << argv[i] << "\' loaded.\n";
//   }

  ifstream fin;
  List<W>  lw;
  string   sLine;
  int ctr=0;

//  if (nArgs>2) fin.open(argv[2]);

  // Read in each line...
//  while ( (nArgs>2) ? getline(fin,sLine) : getline(cin,sLine) ) {
  while ( getline(cin,sLine) ) {
    ctr++;
    cerr << ctr ;

    // Read in each word...
    stringstream ss(stringstream::in|stringstream::out);
    lw=List<W>();
    ss<<sLine;
    for (string s; ss>>s; ) {
      cerr << " " << s;
      lw.add() = W(s.c_str()); //.read(s);
    }

    cerr << endl;

    // Allocate chart...
    SafeArray2D<Id<int>,Id<int>,RefCell> axChart;
    int n = lw.getCard();
    axChart.init ( n+1, n+2 );

    // Initialize chart with terminals (POSs)...
    int i=0;
    Listed(W)* pw;
    foreach(pw,lw) {

      //OModel::RandVarType w; w = *static_cast<W*>(pw);
      OModel::RandVarType w; w.set(*pw,mO);
      if ( NOISY )
        cerr << "word " << pw->getString() << endl;
      //for ( CPT2DModel<P,C,LogProb>::const_iterator iter = modPgivCvar.begin(); iter!=modPgivCvar.end(); iter++ ) {
      //  C c = iter->first.getX1();

      //G g;
      //for ( bool bG=g.setFirst(); bG; bG=g.setNext() ) {
      CModel::IterVal c;
      for ( bool bc=modC.setFirst(c); bc; bc=modC.setNext(c) ) {

	G g; // composed of separate C and E models to account for unk and E-matching
	E e = E(pw->getString().c_str());
	// Set unknown words to "unk".  
	if ( modHW.getProb(HW(pw->getString().c_str()))==LogProb() )
	  g = G(c,E_UNK);
	else
	  g = G(c,e);

        if ( w.getProb(g)!=LogProb() ) {
          ChartCell* px = &axChart.set(i,i+1).set(g);
          *px = ChartCell ( g, w.getProb(g) );
          //tryUnary ( axChart, i, i+1, *px, modGgivG );
          if ( NOISY )
            cerr << "pos match! " << i << " " << i+1 << " " << g << " -> " << pw->getString()
                 << " " << w.getProb(g).toInt() << "=" << px->getProb().toInt() << endl;
          if ( FOREST )
            fprintf ( stdout, "%d %s{%s} %d : %d %s %d = %d\n",
                      i, g.first.getString().c_str(), g.second.getString().c_str(), i+1,
                      i, pw->getString().c_str(), i+1, w.getProb(g).toInt() );
        }
      }
      
      /*
      ModeledWgivP w = *pw;
      #ifdef NOISY /////////////////////////
      cout << "word " << w.getString() << endl;
      #endif ///////////////////////////////
      for ( CPT2DModel<P,C,LogProb>::const_iterator iter = modPgivC.getConds().begin(); iter!=modPgivC.getConds().end(); iter++ ) {
        C c = iter->first.getX1();
        ModeledPgivC p;
        for ( bool bp=p.setFirst(c); bp; bp=p.setNext(c) ) {
          if ( w.getProb(p)!=LogProb() ) {
            ChartCell* px = &axChart.set(i,i+1,c);
            *px = ChartCell ( c, w.getProb(p)*p.getProb(c) );
            tryUnary ( axChart, i, i+1, *px );
            #ifdef NOISY /////////////////////////
            cout << "pos match! " << i << " " << i+1 << " " << c.getString() << " -> " << p.getString() << " -> " << w.getString()
                 << " " //<< w.getProb(p).toInt() 
              //<< modPgivW.getProb(p,w).toInt() << "/" << modP.getProb(p).toInt()
                 << "*" << p.getProb(c).toInt() << "=" << px->getProb().toInt() << endl;
              //(w.getProb(p)*p.getProb(c)).toInt() << endl;
            #endif ///////////////////////////////
          }
        }
      }
      */

      i++;
    }

    // Fill chart with nonterminals...
    for ( int di=1; di<=n; di++ ) {
      cerr << "  " << di << endl;
      for ( int i=0; i<=n-di; i++ ) {
        int j=i+di;
        for ( int k=i+1; k<=j-1; k++ ) {
          for ( MModel::const_iterator iter = modM.begin(); iter!=modM.end(); iter++ ) {
            G                gP = iter->first.getX1();

	    // Calculate the factored probabilities
	    MModel::IterVal ilclc;
	    LModel::IterVal ieLeft;
	    LModel::IterVal ieRight;
	    if ( !modGGgivG.contains(gP) ) {
	      //cerr<<" entering loop w/ gParent = "<<gP<<"\n";
	      for ( bool bg=modM.setFirst(ilclc,gP); bg; bg=modM.setNext(ilclc,gP) ) {
		L lLeft  = ilclc.first.first;
		L lRight = ilclc.second.first;
		//cerr<<"  iter over lc lc = "<<lLeft<<":"<<ilclc.first.second<<" "<<lRight<<":"<<ilclc.second.second<<endl;
		E eP   = gP.second;
		for ( bool ble=modL.setFirst(ieLeft,lLeft,eP); ble; ble=modL.setNext(ieLeft,lLeft,eP) ) {
		  //cerr<<"   iter over eLeft = "<<ieLeft<<endl;
		  for ( bool ble=modL.setFirst(ieRight,lRight,eP); ble; ble=modL.setNext(ieRight,lRight,eP) ) {
		    //cerr<<"   iter over eRight = "<<ieRight<<endl;
		    GG gg = GG( G(ilclc.first.second,ieLeft), G(ilclc.second.second,ieRight) );
		    Prob pOld = modGGgivG.getProb(gg,gP);
		    modGGgivG.setProb(gg,gP) = 
		      pOld //modGGgivG.getProb(gg,gP).toDouble()
		      + modM.getProb(ilclc,gP)*modL.getProb(ieLeft,lLeft,eP)*modL.getProb(ieRight,lRight,eP);
		    //cerr<<" trying to calc "<<gP<<" -> "<<ilclc<<" ("<<modM.getProb(ilclc,gP)<<")"<<",    "
		    //<<lLeft<<" "<<eP<<" -> "<<ieLeft<<" ("<<modL.getProb(ieLeft,lLeft,eP)<<")"<<",    "
		    //<<lRight<<" "<<eP<<" -> "<<ieRight<<" ("<<modL.getProb(ieRight,lRight,eP)<<")     "
		    //<<pOld.toDouble()<<" = "<<modGGgivG.getProb(gg,gP).toDouble()<<" = "
		    //<<LogProb(modGGgivG.getProb(gg,gP)).toInt()<<endl;
		  }
		}
	      }
	    }

	    // Check/fill the chart
            GGModel::IterVal gg;
            for ( bool bgg=modGGgivG.setFirst(gg,gP); bgg; bgg=modGGgivG.setNext(gg,gP) ) {
              if ( LogProb(modGGgivG.getProb(gg,gP))>LogProb(THRESHOLD) ) {
		//cerr << " found GG rule "<< gP << " -> " << gg<< endl;
                const G& g1 = gg.first;  const ChartCell& x1 = axChart.get(i,k).get(g1);
                const G& g2 = gg.second; const ChartCell& x2 = axChart.get(k,j).get(g2);
                if ( x1!=ChartCell() && x2!=ChartCell() ) {
                  ChartCell& x = axChart.set(i,j).set(gP);
                  LogProb pr = x1.getProb()*x2.getProb()*LogProb(modGGgivG.getProb(gg,gP));
                  if ( NOISY )
                    cerr << "      binary match! " << i << " " << k << " " << j << " " << gP
                         << " -> " << g1 << " " << g2 << " "
                         << x1.getProb().toInt() << "*" << x2.getProb().toInt() << "*" << LogProb(modGGgivG.getProb(gg,gP)).toInt()
                         << "=" << pr.toInt() << endl;
                  //if ( FOREST )
                  //  fprintf ( stdout, "%d %s %d : %d %s %d %s %d = %d\n",
                  //            i, gP.getString().c_str(), j,
                  //            i, g1.getString().c_str(), k, g2.getString().c_str(), j, pr.toInt() );
                  if ( pr>x.getProb() ) {
                    x = ChartCell(gP,k,x1,x2,pr);
		    //tryUnary ( axChart, i, j, x, modGgivG );
                  }
                }
              }
            }
          }
        }
      }
    }

    if ( FOREST ) {

      //ModeledC c;
      GModel::IterVal g;
      for ( bool bg=modGr.setFirst(g); bg; bg=modGr.setNext(g) )
        if ( axChart.get(0,n).get(g).getProb()*modGr.getProb(g) != LogProb() ) {
          axChart.set(0,n).set(g).setRooted();
        }

      for ( int di=n; di>=1; di-- ) {
        cerr << "  " << di << endl;
        for ( int i=0; i<=n-di; i++ ) {
          int j=i+di;
          for ( GGModel::const_iterator iter = modGGgivG.begin(); iter!=modGGgivG.end(); iter++ ) {
            G                gP = iter->first.getX1();
            const ChartCell& x  = axChart.get(i,j).get(gP);
            if ( x.getRooted() ) {
              for ( int k=i+1; k<=j-1; k++ ) {
                GGModel::IterVal gg;
                for ( bool bgg=modGGgivG.setFirst(gg,gP); bgg; bgg=modGGgivG.setNext(gg,gP) ) {
                  if ( LogProb(modGGgivG.getProb(gg,gP))>LogProb(THRESHOLD) ) {
                    const G& g1 = gg.first;  ChartCell& x1 = axChart.set(i,k).set(g1);
                    const G& g2 = gg.second; ChartCell& x2 = axChart.set(k,j).set(g2);
                    if ( x1!=ChartCell() && x2!=ChartCell() ) {
                      x1.setRooted();
                      x2.setRooted();
                      LogProb pr = x1.getProb()*x2.getProb()*LogProb(modGGgivG.getProb(gg,gP));
                      fprintf ( stdout, "%d %s{%s} %d : %d %s{%s} %d %s{%s} %d = %d\n",
                                i, gP.first.getString().c_str(), gP.second.getString().c_str(), j,
                                i, g1.first.getString().c_str(), g1.second.getString().c_str(), 
				k, g2.first.getString().c_str(), g2.second.getString().c_str(), j, pr.toInt() );
                    }
                  }
                }
              }
            }
          }
        }
      }
      cout << "-" << endl;

    } else {
      // Print best...
      LogProb         prBest=LogProb();
      G               gBest;
      GModel::IterVal g;
      for ( bool bg=modGr.setFirst(g); bg; bg=modGr.setNext(g) )
        if ( LogProb()==prBest || axChart.get(0,n).get(g).getProb()*modGr.getProb(g)>prBest ) {
		  gBest  = g;
		  prBest = axChart.get(0,n).get(g).getProb()*modGr.getProb(g);
        }
      if ( prBest>LogProb() ) {
        axChart.get(0,n).get(gBest).writeBest(stdout,lw);
	//if (NOISY)
	//cout << "      Probability="<<prBest<<endl;
      }
      else
        cout << "[ERROR failed]";
      cout << endl;
    }

  }

  //if (nArgs>2) fin.close();
  //if (NOISY) {
  //  cerr << " trying to dump modGGgivG\n";
  //  modGGgivG.dump(cerr,"GG");
  //}


}

