///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// This file is part of ModelBlocks. Copyright 2009, ModelBlocks developers. //
//                                                                           //
//    ModelBlocks is free software: you can redistribute it and/or modify    //
//    it under the terms of the GNU General Public License as published by   //
//    the Free Software Foundation, either version 3 of the License, or      //
//    (at your option) any later version.                                    //
//                                                                           //
//    ModelBlocks is distributed in the hope that it will be useful,         //
//    but WITHOUT ANY WARRANTY; without even the implied warranty of         //
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the          //
//    GNU General Public License for more details.                           //
//                                                                           //
//    You should have received a copy of the GNU General Public License      //
//    along with ModelBlocks.  If not, see <http://www.gnu.org/licenses/>.   //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////

#include <getopt.h>
#include <sstream>

#include "nl-iomacros.h"
#include "nl-cpt.h"
#include "nl-hmm.h"

////////////////////////////////////////////////////////////////////////////////

// Long-option table handed to getopt_long(). Each row maps a long name to the
// short option character that the constructor's switch statement handles.
// The short-option string used at the call site additionally accepts
// -l, -v and -V, which have no long form here.
static struct option long_options[] = {
  // name          argument?          flag  short
  { "beam-width",  required_argument, 0,    'b' },   // takes WIDTH
  { "perplexity",  required_argument, 0,    'p' },   // takes LOG_BASE
  { "help",        no_argument,       0,    'h' },
  { "sum-probs",   no_argument,       0,    's' },
  { "entropy",     no_argument,       0,    'e' },
  { "avg-depth",   no_argument,       0,    'd' },
  { 0,             0,                 0,    0   }    // terminator required by getopt_long
};

// Print the command-line usage summary to stderr.
//
//   progname  program name shown in the usage line (normally argv[0]);
//             a placeholder is printed if it is NULL.
//
// Every option accepted by the option parser is listed, including the
// short-only options -l, -v and -V.
void printUsage(char* progname) {
  const char* shownName = progname ? progname : "<program>";   // never hand NULL to %s
  fprintf(stderr,"Usage: cat <TXT_FILE> | %s [OPTIONS]... [MODEL_FILE1] [MODEL_FILE2] ...\n",shownName);
  fprintf(stderr," System variable Settings:\n");
  fprintf(stderr,"  -b, --beam-width=WIDTH\n");
  fprintf(stderr,"  -h, --help\t\tPrint this message\n");
  fprintf(stderr,"  -s, --sum-probs\t\tSum and display probabilities on beam\n");
  fprintf(stderr,"  -p, --perplexity=LOG_BASE\t\tCalculate and display perplexity using specified log base\n");
  fprintf(stderr,"  -e, --entropy\t\tCalculate the entropy of beam's distribution\n");
  fprintf(stderr,"  -d, --avg-depth\t\tShow probability-weighted average of stack depth on beam\n");
  fprintf(stderr,"  -l MAX_WORDS\t\tMax sentence length allowed\n");
  fprintf(stderr,"  -v\t\tNoisy output\n");
  fprintf(stderr,"  -V\t\tVery noisy output (implies -v)\n");
}

////////////////////////////////////////////////////////////////////////////////

// HHMMParser: a stateless driver class whose constructors do all of the work.
//
// Template parameters:
//   MY  model type; one instance is paired with an MX instance inside an HMM
//   MX  model type; its RandVarType is used for the observed words
//   YS  state type used for the begin/end states (defaults to MY::RandVarType)
//   B   back-data type carried by the trellis nodes
//       (defaults to NullBackDat<MY::RandVarType>)
//
// The class declares no data members and no methods other than constructors.
template < typename MY,
           typename MX,
           typename YS = typename MY::RandVarType,
           typename B  = NullBackDat<typename MY::RandVarType> >
class HHMMParser {
 public:
  // (argc, argv)
  HHMMParser ( int, char*[] ) ;
  // (argc, argv, default beam width)
  HHMMParser ( int, char*[], int );
  // (argc, argv, default beam width, number of fake words)
  HHMMParser ( int, char*[], int, int );
};

////////////////////
// Two-argument constructor: runs the parser with a beam width of 2000 and
// numFakeWords of 0.
//
// This does not delegate to the four-argument constructor; it builds a second,
// short-lived HHMMParser whose four-argument constructor does the work. That
// is equivalent only because the class holds no state.
template <typename MY, typename MX, typename YS, typename B>
HHMMParser<MY,MX,YS,B>::HHMMParser ( int nArgs, char* argv[] ) {
  const int defaultBeamWidth = 2000;
  HHMMParser<MY,MX,YS,B> worker(nArgs, argv, defaultBeamWidth, 0);
}
// Three-argument constructor: runs the parser with the caller's beam width and
// numFakeWords of 0.
//
// As with the two-argument form, the work is done by constructing a separate
// short-lived HHMMParser with the four-argument constructor; the class is
// stateless, so nothing needs to be carried back into *this.
template <typename MY, typename MX, typename YS, typename B>
HHMMParser<MY,MX,YS,B>::HHMMParser ( int nArgs, char* argv[], int beam ) {
  HHMMParser<MY,MX,YS,B> worker(nArgs, argv, beam, 0);
}
// Four-argument constructor: runs the complete parser.
//
//   nArgs, argv    the program's command line; options are parsed with
//                  getopt_long and the remaining arguments are model files
//   beam           initial beam width (overridden by -b / --beam-width)
//   numFakeWords   if >= 1, the first token of every input line is consumed
//                  and used to build that line's begin/end states; otherwise
//                  the begin/end states are read once from BEG_STATE/END_STATE
//
// Phases:
//   0.   parse command-line options
//   I.   read every model file into mY / mX
//   II.  set up the begin/end states
//   III. for each line on stdin: feed the words (followed by "eos") to the HMM
//        and print one "HYPOTH" line per node returned by getMLSnodes
//
// Exits the process (status 0) after printing usage for -h or when no model
// file is given. Returns early if a model file cannot be opened.
template <class MY, class MX, class YS, class B>
HHMMParser<MY,MX,YS,B>::HHMMParser ( int nArgs, char* argv[], int beam, int numFakeWords ) {

  int MAX_WORDS  = 1000;
  int BEAM_WIDTH = beam;
  bool SUM_PROBS = false;
  bool ENTROPY = false;
  bool STACK_DEPTH = false;      // set by -d; its only use below is commented out
  bool PERPLEXITY = false;
  double PERPLEXITY_LOG_BASE = exp(1.0);

  MY mY;
  MX mX;
  HMM<MY,MX,YS,B> hmm(mY,mX);
  //hmm.OUTPUT_NOISY=1; //hmm.OUTPUT_VERYNOISY=1;

  ////////// PART 0: COMMAND-LINE OPTIONS

  // Parse command-line options
  char* progname = argv[0];      // used only for the usage message; no copy needed
  int option_index = 0;
  // getopt_long returns int; storing it in a char breaks the -1 comparison
  // on platforms where plain char is unsigned.
  int opt;
  while( (opt = getopt_long(nArgs, argv, "b:p:hl:sdevV",
                            long_options, &option_index)) != -1 ) {
    switch(opt) {
    case 'h': printUsage(progname);
              exit(0);
    case 'b': BEAM_WIDTH = atoi(optarg); if(BEAM_WIDTH<=0)cerr<<"\nERROR: beam width set to zero!\n\n"; break;
    case 'l': MAX_WORDS = atoi(optarg); break;
    case 's': SUM_PROBS   = true; break;
    case 'e': ENTROPY     = true; break;
    case 'd': STACK_DEPTH = true; break;
    case 'p': {
      PERPLEXITY=true;
      // An argument starting with 'e' selects the natural log base.
      if (optarg[0]=='e') {
        PERPLEXITY_LOG_BASE = exp(1.0);
      } else {
        PERPLEXITY_LOG_BASE = atof(optarg);
      }
      cerr << "Perplexity will be calculated using log base\t" << PERPLEXITY_LOG_BASE << endl;
      break;
    }
    case 'v': hmm.OUTPUT_NOISY=1; break;
    case 'V': hmm.OUTPUT_NOISY=1; hmm.OUTPUT_VERYNOISY=1; break;
    default: break;
    }
  }

  // Complain if too few args...
  if (optind >= nArgs ) {
    printUsage(progname);
    exit(0);
  }

  //modCgivC.NOISY_READ = true;

//  clock_t start, end;
#ifdef TIMER_ON
  struct timeval start;
  struct timeval end;
  double elapsed_s;
#endif
  int numWords=0;
  int numWordsInSentence=0;
  int numOOV=0;
  int numOOVInSentence=0;
  double sentenceOOVLogProb=0.0;

  double corpusLogProb = log(1.0);

  //// I. LOAD MODELS...

  //'\0' is the null character used to terminate strings in C/C++.
  //"\0" is an empty string.
  // For each model file...
  for ( int a=optind; a<nArgs; a++ ) {                                           // read models
    FILE* pf = fopen(argv[a],"r");                                               // Read model file
    if(!pf){
      cerr << "Error loading model file " << argv[a] << endl;
      assert(pf);                // debug builds still abort, but only after the message is printed
      return;
    }
    cerr << "Loading model \'" << argv[a] << "\'...\n";
    int c=' '; int i=0; int line=1; String sBuff(1000);                          // Lookahead/ctrs/buffers
    CONSUME_ALL ( pf, c, WHITESPACE(c), line);                                   // Get to first record
    while ( c!=-1 && c!='\0' && c!='\5' ) {                                      // For each record
      CONSUME_STR ( pf, c, (c!='\n' && c!='\0' && c!='\5'), sBuff, i, line );    //   Consume line
      StringInput si(sBuff.c_array());
      if ( !( sBuff[0]=='#' ||                                                   //   Accept comments/fields
              si>>mY>>"\0"!=NULL ||
              si>>mX>>"\0"!=NULL ) )
        cerr<<"\nERROR: can't parse \'"<<sBuff<<"\' in line "<<line<<"\n\n";
      CONSUME_ALL ( pf, c, WHITESPACE(c), line);                                 //   Consume whitespace
      if ( line%100000==0 ) cerr<<"  "<<line<<" lines read...\n";                //   Progress for big models
    }
    fclose(pf);                                                                  // Release the file handle
    cerr << "Model \'" << argv[a] << "\' loaded.\n";
  }


  //// II. INIT HMM...
  YS ysBEG;
  YS ysEND;
  // Without fake words the begin/end states are fixed for the whole run.
  if (numFakeWords < 1) {
    StringInput(String(BEG_STATE).c_array())>>ysBEG>>"\0";
    StringInput(String(END_STATE).c_array())>>ysEND>>"\0";
    cerr<<ysBEG<<"\n";
    cerr<<ysEND<<"\n";
  }

  //typename MX::RandVarType ov;                                                   // define a single, running observed variable


  List<typename MX::RandVarType> lx;      // observed words of the current line, in order
  string                         sLine;
  int                            ctr=0;   // number of input lines read so far


  //// III. RUN PARSER...

#ifdef TIMER_ON
  gettimeofday(&start, NULL);
#endif
  while ( getline(cin,sLine) ) {
    ctr++;
    cerr <<"line number: "<< ctr <<endl;

    #ifdef USE_HDWDVECTSEMLANGMODEL
    mY.init();  // comment in for headwords
    #endif


    stringstream ss(stringstream::in|stringstream::out);
    lx=List<typename MX::RandVarType>();
    ss<<sLine;

    // With fake words, the first token of the line is consumed here and
    // spliced into the begin/end state strings for this line.
    if (numFakeWords >= 1) {
      string rel;
      ss >> rel;
      const string ARG_STATE_BEG = ";-";
      //const string ARG_STATE_END = ";NULL";
      const string ARG_STATE_END = ";END";
      StringInput(String((BEG_STATE + rel + ARG_STATE_BEG).c_str()).c_array())>>ysBEG>>"\0";
      StringInput(String((END_STATE + rel + ARG_STATE_END).c_str()).c_array())>>ysEND>>"\0";
      cerr<<ysBEG<<"\n";
      cerr<<ysEND<<"\n";
    }
    hmm.init(MAX_WORDS,BEAM_WIDTH,ysBEG);                               // hmm init
    //    hmm.debugPrint();
    // Read in each word...
    //lx=List<X>();

    // b1 is true for the first word of the line and for the final "eos" update.
    // NOTE(review): b1 is only declared when USE_COREFSMODEL is not defined,
    // yet it is passed to updateRanked unconditionally below -- confirm where
    // b1 comes from in a USE_COREFSMODEL build.
#ifndef USE_COREFSMODEL
    bool b1 = true;
#endif
    numWordsInSentence = 0;
    numOOVInSentence = 0;
    sentenceOOVLogProb = 0.0;
    for (string s; ss>>s; ) {
      numWords++;
      numWordsInSentence++;
      cerr << " " << s;
      typename MX::RandVarType x(s.c_str());                            // set observed variable value
      lx.add()=x;
      //ov.set(x,mX);

      //MY::WORD = x;
      //ov.getW();
      if (hmm.unknown(x)) {
        numOOV++;
        numOOVInSentence++;
        //	sentenceOOVLogProb += mx.getProb(
      }
      hmm.updateRanked(x,b1);                                           // update hmm
      //      hmm.debugPrint();
      //(ov);
      #ifdef WRITE_CACHES
      mY.writeCaches();
      #endif

#ifndef USE_COREFSMODEL
      b1 = false;
#endif

      cerr << " (" << hmm.getBeamUsed() << ")";
      if ( hmm.OUTPUT_NOISY ) hmm.writeCurr(cout);
      if ( SUM_PROBS ) hmm.writeCurrSum(stdout);
      if ( ENTROPY ) hmm.writeCurrEntropy(stdout);
      //if ( STACK_DEPTH ) hmm.writeCurrDepths(stdout);
    }
    cerr << "\n";

    // Perplexity is reported before the "eos" update, over non-OOV words only.
    // NOTE(review): both divisions below have a zero denominator when every
    // word counted so far is OOV (or the line is empty) -- confirm whether
    // that case needs explicit handling.
    if( PERPLEXITY ) {
      double prob = hmm.getCurrSum();
      if (prob > 0) {
        double logProb = log(prob) / log(PERPLEXITY_LOG_BASE);
        corpusLogProb += logProb;
        cout << "Prior probability of sentence:\t" << prob << endl;
        cout << "Number of words in sentence:\t" << numWordsInSentence << endl;
        cout << "Number of OOV words in sentence:\t" << numOOVInSentence << endl;
        double ppl = pow(PERPLEXITY_LOG_BASE,( -logProb / (numWordsInSentence-numOOVInSentence)));
        cout << "Perplexity:\t" << ppl << endl;
        cout << "Corpus logProb so far:\t" << corpusLogProb << endl;
        cout << "Corpus word count so far:\t" << numWords << endl;
        cout << "Corpus OOV count so far:\t" << numOOV << endl;
        double corpusPPL = pow(PERPLEXITY_LOG_BASE,( -corpusLogProb / (numWords-numOOV) ));
        cout << "Corpus perplexity so far:\t" << corpusPPL << endl;
      } else {
        cout << "Prior probability of sentence:\t" << 0.0 << endl;
        cout << "Sentence has zero probability; not including it in perplexity calculations." << endl;
        // Roll the corpus counters back so this sentence is excluded.
        numWords -= numWordsInSentence;
        numOOV -= numOOVInSentence;
      }
    }

    // Close the sentence with an explicit "eos" observation.
    typename MX::RandVarType x("eos");
    lx.add()=x;
    //ov.set(x,mX);
#ifndef USE_COREFSMODEL
    b1 = true;
#endif
    //MY::WORD = x;
    //ov.getW();
    hmm.updateRanked(x,b1);
    //(ov);
    #ifdef WRITE_CACHES
    mY.writeCaches();
    #endif
    if ( hmm.OUTPUT_NOISY ) hmm.writeCurr(cout);
    if ( SUM_PROBS ) hmm.writeCurrSum(stdout);
    if ( ENTROPY ) hmm.writeCurrEntropy(stdout);
    //if ( STACK_DEPTH ) hmm.writeCurrDepths(stdout);


    // Print one HYPOTH line per node, pairing each node with the word at the
    // front of lx and then dropping that word.
    int wnum=1;
    list<TrellNode<YS,B> > lys = hmm.getMLSnodes(ysEND);                          // get mls list
    for ( typename list<TrellNode<YS,B> >::iterator i=lys.begin(); i!=lys.end(); i++ ) {   // for each frame
      cout << "HYPOTH " << wnum++
           << " " << i->getBackData()
           << " " << *lx.getFirst()
           << " " << i->getId();

      if ( hmm.OUTPUT_NOISY || hmm.OUTPUT_VERYNOISY || SUM_PROBS )
        cout << " (" << i->getLogProb() << ")"; // SWU: re-included for perplexity calculation

      cout << endl;                                                            // print RV val

      lx.pop();
    }

    //#ifdef USE_HDWDVECTSEMLANGMODEL
    //mY.clear();  // comment in for headwords
    //#endif

    cout << "--------------------\n";
  }
#ifdef TIMER_ON
  gettimeofday(&end,NULL);
  double beg_time_s = (double) start.tv_sec + (double) ((double)start.tv_usec / 1000000.0);
  double end_time_s = (double) end.tv_sec + (double) ((double)end.tv_usec / 1000000.0);
  elapsed_s = (end_time_s - beg_time_s);
  fprintf(stderr, "Realtime elapsed %f s elapsed...\n", elapsed_s);
  fprintf(stderr, "%f seconds per word\n", elapsed_s / (double)numWords);
  fprintf(stderr, "%f seconds per sentence\n", elapsed_s / (double)ctr);
#endif
}
