/*******************************************************************/
/*      File: lexicon.C                                            */
/*    Author: Helmut Schmid                                        */
/*   Purpose:                                                      */
/*   Created: Tue Nov  5 12:03:06 2002                             */
/*  Modified: Thu Feb 12 08:59:06 2009 (schmid)                    */
/* Copyright: Institut fuer maschinelle Sprachverarbeitung         */
/*            Universitaet Stuttgart                               */
/*******************************************************************/

#include <iostream>
using std::cerr;

#include "lexicon.h"

// Size in bytes of the line buffers used when reading the lexicon file
// and the score file.
#define BUFFER_SIZE    100000
#define PROB_THRESHOLD 0.01   // POS tags are deleted if less probable than 
                              // threshold * highest tag probability

double SmoothingWeight = 1.0; // weight of the equivalence class probability
// If set, every tag in a lexicon entry is followed by a lemma (after the
// frequency when WithProbs is set) and the lemmas are stored in the Tags.
bool WithLemmas;

// Right-hand side marker in the score file: such lines are stored under
// the negated symbol number of the left-hand side instead of a rule number.
static const char *DefaultString = "<DEFAULT>";

// Indexed by symbol number and filled by Lexicon::read_wc: 1 for the tags
// listed in the open class file, or for every tag if that list is empty.
static vector<char> is_oc_tag;


/*******************************************************************/
/*                                                                 */
/*  Tags::init                                                     */
/*                                                                 */
/*******************************************************************/

// Replaces the contents of this tag set with the given symbols.
// The previous tag array is released; every new entry starts with
// frequency 0 and without a lemma.
void Tags::init( vector<SymNum> &symbols )

{
  const size_t count = symbols.size();

  l = (unsigned int)count;
  delete[] tag;
  tag = new Tag[count];

  for( size_t n=0; n<count; n++ ) {
    Tag &entry = tag[n];
    entry.symbol = (SymNum)symbols[n];
    entry.freq = 0.0;
    entry.lemma = NULL;
  }
}


/*******************************************************************/
/*                                                                 */
/*  Tags::init                                                     */
/*                                                                 */
/*******************************************************************/

// Replaces the contents of this tag set with the given symbols and
// frequencies. The previous tag array is released.
//
// symbols  symbol numbers of the tags
// freqs    frequency of each tag; only read when WithProbs is set
//
// Every field is given a defined value: the frequency is 0 when WithProbs
// is off and the lemma is always NULL. Other functions in this file read
// these fields regardless of the flags (e.g. store() prints freq, and
// smooth_with_wordclass() copies lemma), so they must not stay
// uninitialized.
void Tags::init( vector<SymNum> &symbols, vector<float> &freqs )

{
  l = (unsigned int)symbols.size();
  delete[] tag;
  tag = new Tag[symbols.size()];
  for( size_t i=0; i<symbols.size(); i++ ) {
    tag[i].symbol = (SymNum)symbols[i];
    tag[i].freq = WithProbs ? freqs[i] : 0.0f;
    tag[i].lemma = NULL;
  }
}


/*******************************************************************/
/*                                                                 */
/*  Tags::init                                                     */
/*                                                                 */
/*******************************************************************/

// Replaces the contents of this tag set with the given symbols,
// frequencies and lemmas. The previous tag array is released.
//
// symbols  symbol numbers of the tags
// freqs    frequency of each tag; only read when WithProbs is set
// lemmas   lemma of each tag; only read when WithLemmas is set
//          (the pointers are stored, not copied)
//
// Fields whose source vector is not read get a defined default
// (frequency 0, lemma NULL) instead of being left uninitialized, because
// other functions in this file read them regardless of the flags.
void Tags::init( vector<SymNum> &symbols, vector<float> &freqs, 
                 vector<const char*> &lemmas )
{
  l = (unsigned int)symbols.size();
  delete[] tag;
  tag = new Tag[symbols.size()];
  for( size_t i=0; i<symbols.size(); i++ ) {
    tag[i].symbol = (SymNum)symbols[i];
    tag[i].freq = WithProbs ? freqs[i] : 0.0f;
    tag[i].lemma = WithLemmas ? lemmas[i] : (const char*)NULL;
  }
}


/*******************************************************************/
/*                                                                 */
/*  Lexicon::add_entry                                             */
/*                                                                 */
/*******************************************************************/

// Adds a new lexicon entry for "word" with the given tags, frequencies
// and lemmas. If the word already has an entry, a warning is printed
// and the new data is discarded.
void Lexicon::add_entry(char *word, vector<SymNum> &syms, vector<float> &freqs, 
                        vector<const char*> &lemmas)

{
  // first occurrence of the word: create the entry and fill in its tags
  if (lex.find(word) == lex.end()) {
    lex[RefString(word)].init( syms, freqs, lemmas );
    return;
  }

  cerr << "Warning: two lexicon entries for " << word
       << " (second ignored)!\n";
}


/*******************************************************************/
/*                                                                 */
/*  Lexicon::read_oc                                               */
/*                                                                 */
/*******************************************************************/

// Reads the open class file and stores its tags in "tags".
//
// The file is a whitespace-separated list of symbol names; when WithProbs
// is set each name is followed by a non-negative frequency. Names unknown
// to the grammar are reported on stderr and skipped together with their
// frequency.
//
// Throws a malloc'ed char* message if the file cannot be opened, if a
// frequency is missing or negative, or if reading stops before the end of
// the file. The file is closed before any of these exceptions is thrown.
void Lexicon::read_oc( char *filename, Tags &tags )

{
  const size_t MessageSize = 1000;
  vector<SymNum> tag;
  vector<float> freq;

  FILE *file=fopen(filename,"rt");
  if (file == NULL) {
    char *message=(char*)malloc(MessageSize);
    snprintf(message, MessageSize, "unable to open file \"%s\"", filename);
    throw message;
  }

  char buffer[1000];
  int r;
  // the field width keeps an overlong token from overflowing the buffer
  while ((r = fscanf(file, "%999s", buffer)) == 1) {
    int sym=grammar.symbol_number(buffer);
    if (sym == -1)
      cerr << "\nWarning: unknown symbol " << buffer<<" in open class file!\n";
    else
      tag.push_back(sym);
    if (WithProbs) {
      float f;
      if ((r = fscanf(file, "%f", &f)) != 1) {
        char *message=(char*)malloc(MessageSize);
        snprintf(message, MessageSize,
                 "missing frequency in open class file \"%s\"", filename);
        fclose(file);
        throw message;
      }
      if (f < 0.0) {
        char *message=(char*)malloc(MessageSize);
        snprintf(message, MessageSize,
                 "negative frequency in open class file \"%s\" at: %s %f\n",
                 filename, buffer, f);
        fclose(file);
        throw message;
      }
      // the frequency of an unknown symbol is read but not stored, so
      // that "tag" and "freq" stay aligned
      if (sym != -1)
        freq.push_back(f);
    }
  }

  // anything other than end-of-file here is a read error
  if (r != EOF) {
    char *message=(char*)malloc(MessageSize);
    snprintf(message, MessageSize, "in open class file \"%s\"", filename);
    fclose(file);
    throw message;
  }
  fclose(file);

  tags.init(tag, freq);
}



/*******************************************************************/
/*                                                                 */
/*  Lexicon::read_wc                                               */
/*                                                                 */
/*******************************************************************/

void Lexicon::read_wc( char *filename )

{
  FILE *file=fopen(filename,"rt");
  if (file == NULL) {
    char *message=(char*)malloc(1000);
    sprintf(message,"unable to open file \"%s\"", filename);
    throw message;
  }

  // read the automaton from the file
  WCAutomaton = new Automaton(file);
  if (number_of_wordclasses() == 0)
    throw "error in wordclass automaton!";
  fclose(file);

  // build "is_oc_tag" array
  if (OCTags.size() == 0)
    // all tags allowed
    is_oc_tag.resize(grammar.number_of_symbols(), 1);
  else {
    // only tags from the open class file allowed
    is_oc_tag.resize(grammar.number_of_symbols(), 0);
    for( size_t i=0; i<OCTags.size(); i++ )
      is_oc_tag[OCTags[i]] = 1;
  }

  // initialize the frequency table
  double **freq = new double*[number_of_wordclasses()];
  for( int wc=0; wc<number_of_wordclasses(); wc++ ) {
    freq[wc] = new double[grammar.number_of_symbols()];
    for( size_t k=0; k<grammar.number_of_symbols(); k++ )
      freq[wc][k] = 0.0;
  }

  vector<double> tagfreq(grammar.number_of_symbols(), 0.0);
  vector<int> tagfreq1(grammar.number_of_symbols(), 0);

  // compute the tag frequencies for each word class
  for( iterator it=begin(); it!=end(); it++ ) {  // for all words
    int wc=wordclass(it->first);  // word class of the current word
    Tags &tags=it->second;        // tags of the current word
    if (WithProbs) {
      double sum=0.0;
      for( size_t i=0; i<tags.size(); i++ )
	sum += tags.freq(i);
      // add the tag probabilities to the POS counts of the word class
      for( size_t i=0; i<tags.size(); i++ )
	if (sum == 0.0)
	  freq[wc][tags[i]] += 1.0 / (double)tags.size();
	else {
	  freq[wc][tags[i]] += tags.freq(i) / sum;
	  if (sum > 0.5 && sum < 1.5 &&  // allow fractional counts
	      is_oc_tag[i] && tags.freq(i) > sum * 0.5) 
	    {
	      tagfreq1[tags[i]]++;
	    }
	}
    }
    else
      for( size_t i=0; i<tags.size(); i++ )
	freq[wc][tags[i]] += 1.0;
  }    
    
  // initialize the tagset array
  WCTags = new Tags[number_of_wordclasses()];

  // compute the POS probabilities for each word class
  for( int wc=0; wc<number_of_wordclasses(); wc++ ) {
    vector<SymNum> symbols;
    double max=0.0;

    // compute the max of the probabilities of the current wordclass
    for( size_t i=0; i<grammar.number_of_symbols(); i++ )
      if (freq[wc][i] > 0.0 && is_oc_tag[i])
	if (max < freq[wc][i])
	  max = freq[wc][i];

    // compute the set of symbols with a count above the threshold
    double threshold = max * PROB_THRESHOLD;
    double sum = 0.0;
    for( size_t i=0; i<grammar.number_of_symbols(); i++ )
      if (freq[wc][i] > threshold && is_oc_tag[i]) {
	symbols.push_back((SymNum)i);
	sum += freq[wc][i];
      }
    if (symbols.size() == 0) {
      if (wc > 0)
	cerr << "Lexicon contains no words from word class " << wc <<"!\n";
      else
	cerr << "Lexicon contains no words from default word class 0!\n";
    }
    WCTags[wc].init(symbols);
    for( size_t i=0; i<symbols.size(); i++ ) {
      WCTags[wc].getprob(i) = (float)(freq[wc][symbols[i]] / sum);
      WCTags[wc].getfreq(i) = freq[wc][symbols[i]];
      tagfreq[symbols[i]] += freq[wc][symbols[i]];
    }
  }

  for( int i=0; i<number_of_wordclasses(); i++ )
    delete[] freq[i];
  delete[] freq;
  
  // correct the POS frequencies 
#if 0
  for( int wc=0; wc<number_of_wordclasses(); wc++ ) {
    Tags &tags=WCTags[wc];
    for( size_t i=0; i<tags.size(); i++ )
      tags.getfreq(i) *= (tagfreq1[tags[i]] + 0.5) / tagfreq[tags[i]];
  }
#endif
}


/*******************************************************************/
/*                                                                 */
/*  Lexicon::smooth_with_wordclass                                 */
/*                                                                 */
/*******************************************************************/

// Smooths the tag frequencies of the lexicon words with the tag
// distribution of their word class.
//
// For every word, the tag frequencies are first smoothed with the prior
// distribution. If at least one of the word's tags is an open class tag
// and all of its tags were observed (frequency >= 1), the word class
// probabilities weighted with SmoothingWeight are added (possibly adding
// new tags, which get the word itself as lemma), tags at or below
// PROB_THRESHOLD times the largest smoothed frequency are removed, and
// the remaining frequencies are rescaled so that the word frequency is
// unchanged. Other words are left untouched.
void Lexicon::smooth_with_wordclass( vector<double> &prior_prob )

{
  for( iterator entry=begin(); entry!=end(); entry++ ) {
    const char *word = entry->first;
    const int wc = wordclass(word);
    Tags &tags = entry->second;

    // number of tags that were observed with this word
    unsigned observed = 0;
    for( size_t i=0; i<tags.size(); i++ )
      if (tags.freq(i) >= 1.0)
        observed++;

    // weight of the prior distribution (Witten-Bell style)
    const double prior_weight = (observed == 0) ? 1.0 : (double)observed;

    // working copies of the word's tags with prior-smoothed frequencies
    vector<SymNum> tagnum;
    vector<float> tagfreq;
    vector<const char*> lemma;
    bool is_oc = false;
    for( size_t i=0; i<tags.size(); i++ ) {
      const SymNum sym = tags[i];
      tagnum.push_back(sym);
      tagfreq.push_back((float)(tags.freq(i) + prior_weight * prior_prob[sym]));
      lemma.push_back(tags.lemma(i));
      is_oc = is_oc || is_oc_tag[sym];
    }

    // only open class words with no unobserved tags are smoothed further
    if (!is_oc || observed != tags.size())
      continue;

    // word frequency and largest tag frequency before the class
    // probabilities are added
    double wordfreq = 0.0;
    double maxfreq = 0.0;
    for( size_t i=0; i<tagnum.size(); i++ ) {
      wordfreq += tagfreq[i];
      if (maxfreq < tagfreq[i])
        maxfreq = tagfreq[i];
    }

    // add the word class probabilities weighted with SmoothingWeight
    Tags &classtags = WCTags[wc];
    for( size_t i=0; i<classtags.size(); i++ ) {
      const SymNum cat = classtags[i];

      // position of this tag in the word's tag list
      size_t pos = 0;
      while (pos < tagnum.size() && tagnum[pos] != cat)
        pos++;
      if (pos == tagnum.size()) {
        // the word did not have this tag yet
        tagnum.push_back(cat);
        tagfreq.push_back(0.0);
        lemma.push_back(word);
      }

      tagfreq[pos] = (float)(tagfreq[pos] + classtags.prob(i) * SmoothingWeight);
    }

    // drop the tags that are not above the threshold, keeping the
    // order of the remaining ones
    const float minfreq = (float)(maxfreq * PROB_THRESHOLD);
    double newwordfreq = 0.0;
    size_t kept = 0;
    for( size_t i=0; i<tagnum.size(); i++ ) {
      if (tagfreq[i] > minfreq) {
        newwordfreq += tagfreq[i];
        tagnum[kept] = tagnum[i];
        tagfreq[kept] = tagfreq[i];
        lemma[kept] = lemma[i];
        kept++;
      }
    }
    tagnum.resize(kept);
    tagfreq.resize(kept);
    lemma.resize(kept);

    // renormalize the tag-word frequencies such that
    // the word frequency is unchanged
    const double scale = wordfreq / newwordfreq;
    for( size_t i=0; i<tagnum.size(); i++ )
      tagfreq[i] = (float)(tagfreq[i] * scale);

    tags.init(tagnum, tagfreq, lemma);
  }
}


/*******************************************************************/
/*                                                                 */
/*  Lexicon::compute_priors                                        */
/*                                                                 */
/*******************************************************************/

// Computes the prior probability of each tag from the lexicon.
//
// The frequencies of each tag are summed over all words, starting from
// an initial count of 0.1 per symbol, and normalized by the total.
// prior_prob is resized to cover all symbol numbers up to the largest
// one occurring in the lexicon.
//
// Throws a string literal if a resulting probability is zero.
void Lexicon::compute_priors( vector<double> &prior_prob )

{
  vector<double> counts;

  // accumulate the tag frequencies, growing the table on demand
  for( iterator entry=begin(); entry!=end(); ++entry ) {
    Tags &tags = entry->second;
    for( size_t t=0; t<tags.size(); ++t ) {
      if (tags[t] >= (SymNum)counts.size())
        counts.resize(tags[t]+1, 0.1);
      counts[tags[t]] += tags.freq(t);
    }
  }

  double total = 0.0;
  for( size_t s=0; s<counts.size(); ++s )
    total += counts[s];

  // normalize
  prior_prob.resize(counts.size());
  for( size_t s=0; s<counts.size(); ++s ) {
    const double p = counts[s] / total;
    prior_prob[s] = p;
    if (p == 0.0)
      throw "Error in function compute_priors: zero prior probability!\n";
  }
}



/*******************************************************************/
/*                                                                 */
/*  Lexicon::smooth_lexical_frequencies                            */
/*                                                                 */
/*******************************************************************/

// Smooths the tag frequencies of words that have at least one tag with
// a frequency below 1, so that no zero frequencies remain.
//
// For such a word, each tag frequency is increased by the prior
// probability of the tag multiplied by the number of observed tags
// (frequency >= 1) of the word, or by 1 if there are none.
//
// Throws a string literal if a frequency is still zero afterwards.
void Lexicon::smooth_lexical_frequencies( vector<double> &prior_prob )

{
  for( iterator entry=begin(); entry!=end(); ++entry ) {
    Tags &tags = entry->second;

    // count the tags that were observed with this word
    unsigned seen = 0;
    for( size_t t=0; t<tags.size(); ++t )
      if (tags.freq(t) >= 1.0)
        seen++;

    // nothing to do if every tag was observed
    if (seen >= tags.size())
      continue;

    const unsigned weight = (seen == 0) ? 1u : seen;
    for( size_t t=0; t<tags.size(); ++t ) {
      tags.getfreq(t) += weight * prior_prob[tags[t]];
      if (tags.freq(t) == 0.0)
        throw "Error in function estimate_probs!\n";
    }
  }
}


/*******************************************************************/
/*                                                                 */
/*  Lexicon::estimate_probs                                        */
/*                                                                 */
/*******************************************************************/

void Lexicon::estimate_probs()

{
  vector<double> prior_prob;
  compute_priors( prior_prob );

  // replace zero frequencies in the lexicon
  smooth_lexical_frequencies( prior_prob );
  // smooth frequencies with the class-based probability distribution
  // adding new POS tags
  if (WCAutomaton)
    smooth_with_wordclass( prior_prob );
  // OVERALL TAG FREQUENCIES

  vector<double> freqsum;
  freqsum.resize(grammar.number_of_symbols(), 0.0);

  // grammar
  for( size_t i=0; i<grammar.rules.size(); i++ )
    freqsum[grammar.rules[i].symbol(0)] += grammar.rulefreq[i];

  // lexicon
  for( iterator it=begin(); it!=end(); it++ ) {
    Tags &tags = it->second;
    for( size_t i=0; i<tags.size(); i++ )
      freqsum[tags[i]] += tags.freq(i);
  }

  // word classes
  for( int wc=0; wc<number_of_wordclasses(); wc++ )
    for( size_t i=0; i<WCTags[wc].size(); i++ )
      freqsum[WCTags[wc][i]] += WCTags[wc].freq(i);

  // unknown words
  for( size_t i=0; i<OCTags.size(); i++ )
    freqsum[OCTags[i]] += OCTags.freq(i);


  // PARAMETER ESTIMATION

  // rule probabilities

  grammar.ruleprob.resize(grammar.rulefreq.size());
  for( size_t i=0; i<grammar.rules.size(); i++ ) {
    grammar.ruleprob[i] = grammar.rulefreq[i] /
      freqsum[grammar.rules[i].symbol(0)];
    grammar.rulefreq[i] = 0.0;
  }
  
  // unknown words
  for( size_t i=0; i<OCTags.size(); i++ ) {
    OCTags.getprob(i) = (float)(OCTags.freq(i) / freqsum[OCTags[i]]);
    OCTags.tag[i].freq = 0.0;
  }
  // remove tags with probability 0
  unsigned k=0;
  for( size_t i=0; i<OCTags.size(); i++ ) {
    if (OCTags.getprob(i) > 0.0)
      OCTags.tag[k++] = OCTags.tag[i];
  }
  OCTags.l = k;

  // lexicon
  for( iterator it=begin(); it!=end(); it++ ) {
    Tags &tags = it->second;

    for( size_t i=0; i<tags.size(); i++ ) {
      tags.getprob(i) = (float)(tags.freq(i) / freqsum[tags[i]]);
      tags.tag[i].freq = 0.0;
      if (tags.getprob(i) == 0.0)
	throw "Error in function estimate_probs: entry with probability 0!\n";
    }
  }
  
  // word class-based probabilities
  // p(word | tag) = p(tag | word) * p(word) / p(tag)
  //               ~ p(tag | [word]) / f(tag)
#if 0
  for( int wc=0; wc<number_of_wordclasses(); wc++ ) {
    Tags &tags = WCTags[wc];
    for( size_t i=0; i<tags.size(); i++ ) {
      tags.getprob(i) = tags.getfreq(i) / freqsum[tags[i]];
      tags.getfreq(i) = 0.0;
    }
  }
#endif
}


/*******************************************************************/
/*                                                                 */
/*  Lexicon::Lexicon                                               */
/*                                                                 */
/*******************************************************************/

// Builds the lexicon from an open lexicon file.
//
// gram  grammar whose symbol table is used to resolve the tag names
// file  lexicon file; one entry per line: the word, terminated by a tab,
//       followed by blank- or tab-separated tags. Each tag is followed by
//       a frequency if WithProbs is set and then by a lemma if WithLemmas
//       is set. Lines without a word are skipped.
// ocf   name of the open class file, or NULL
// wcf   name of the word class automaton file, or NULL
//
// Tags unknown to the grammar are reported on stderr and skipped. If
// WithProbs is set, the probabilities are estimated at the end.
//
// Throws a malloc'ed char* message if a frequency or a lemma is missing;
// exceptions of read_oc, read_wc and estimate_probs are passed on.
Lexicon::Lexicon( Grammar &gram, FILE *file, char *ocf, char *wcf )
  : grammar(gram)

{
  const size_t MessageSize = 1000;
  char buffer[BUFFER_SIZE];

  if (!Quiet)
    cerr << "reading the lexicon...";

  // Read the next entry
  for( unsigned line=0; fgets(buffer, BUFFER_SIZE, file) != NULL; line++ ) {
    // the word ends at the first tab, so it may contain blanks
    char *word = strtok(buffer, "\t\n");
    if (word == NULL)
      continue;

    // scanning of the POS tags
    vector<SymNum> symbols;
    vector<float> probs;
    vector<const char*> lemmas;
    char *s;
    while ((s = strtok(NULL, " \t\n")) != NULL) {
      int n = grammar.symbol_number(s);
      if (n == -1)
        cerr << "\nWarning: unknown category " << s << " of word \"" << word << "\" in line " << line << " of lexicon\n";
      else
        symbols.push_back(n);

      // the frequency and lemma of an unknown category are still
      // consumed, but not stored, so that the vectors stay aligned

      if (WithProbs) {
        char *s2, *e;
        float d;
        if ((s2 = strtok(NULL, " \t\n")) == NULL ||
            ((d = (float)strtod(s2, &e)) == 0.0 && s2 == e)) {
          // word and tag come from a BUFFER_SIZE line buffer and can
          // be much longer than the message, hence the bounded write
          char *message=(char*)malloc(MessageSize);
          snprintf(message, MessageSize, "in line %u of lexicon: missing frequency at tag %s of word \'%s\'\n", line, s, word);
          throw message;
        }
        if (n != -1)
          probs.push_back(d);
      }

      if (WithLemmas) {
        char *s2;
        if ((s2 = strtok(NULL, " \t\n")) == NULL) {
          char *message=(char*)malloc(MessageSize);
          snprintf(message, MessageSize, "in line %u of lexicon: missing lemma at tag %s of word \'%s\'\n", line, s, word);
          throw message;
        }
        if (n != -1)
          lemmas.push_back(RefString(s2));
      }
    }

    if (symbols.size() > 0)
      add_entry( word, symbols, probs, lemmas );
    else
      cerr << "\nWarning: lexicon contains no parts of speech for \"" << word << "\"\n";

  }

  if (!Quiet)
    cerr << "finished\n";

  if (ocf)
    read_oc(ocf, OCTags);

  if (wcf) {
    if (!Quiet)
      cerr << "reading the word class guesser...";
    read_wc(wcf);
    if (!Quiet)
      cerr << "finished\n";
  }
  else
    WCAutomaton = NULL;

  if (WithProbs) {
    if (!Quiet)
      cerr << "parameter estimation...";
    estimate_probs();
    if (!Quiet)
      cerr << "finished\n";
  }
}



/*******************************************************************/
/*                                                                 */
/*  is_upper                                                       */
/*                                                                 */
/*******************************************************************/

// Returns true if c is an upper-case letter: ASCII 'A'..'Z' or a
// character code in the range 0xC0..0xDE.
//
// The bounds of the second range are given as numeric codes rather than
// as 8-bit character literals, whose bytes depend on the encoding of the
// source file and do not survive re-encoding.
// NOTE(review): the range 0xC0..0xDE (upper-case letters of ISO-8859-1)
// is assumed; it is consistent with decap() adding 32 to obtain the
// lower-case letter - confirm against the original file.
static bool is_upper( char c )

{
  // compare as unsigned so that the result does not depend on whether
  // plain char is signed
  const unsigned char u = (unsigned char)c;
  return (u >= 'A' && u <= 'Z') || (u >= 0xC0 && u <= 0xDE);
}



/*******************************************************************/
/*                                                                 */
/*  decap                                                          */
/*                                                                 */
/*******************************************************************/

// Returns a copy of "word" with the first letter converted to lower
// case, or NULL unless the first character is an upper-case letter and
// the second one is not.
//
// The result lives in a static buffer that is overwritten by the next
// call; words are truncated to 999 characters.
const char *decap( const char *word )

{
  static char buffer[1000];

  // only words of the form "Xyz..." are decapitalized
  const bool capitalized = is_upper(word[0]) && !is_upper(word[1]);
  if (!capitalized)
    return NULL;

  // lower-case letters are 32 positions above the upper-case ones
  buffer[0] = (char)(word[0] + 32);

  int n = 1;
  while (word[n] && n < 999) {
    buffer[n] = word[n];
    n++;
  }
  buffer[n] = 0;
  return (const char*)buffer;
}



/*******************************************************************/
/*                                                                 */
/*  Lexicon::lookup                                                */
/*                                                                 */
/*******************************************************************/

// Returns the tags of "word".
//
// The word is looked up in the lexicon as is; if that fails and sstart
// is set, its decapitalized form (see decap) is tried as well. Words
// found in neither way get the tags of their word class if a word class
// automaton is available, otherwise the open class tags if there are
// any, otherwise NULL.
Tags *Lexicon::lookup( const char *word, bool sstart )

{
  iterator hit = lex.find(word);

  // retry with a lower-case first letter
  if (hit == lex.end() && sstart) {
    const char *lower = decap(word);
    if (lower != NULL)
      hit = lex.find(lower);
  }

  if (hit != lex.end())
    return &hit->second;

  // unknown word
  if (WCAutomaton)
    return &WCTags[wordclass(word)];

  return (OCTags.size() > 0) ? &OCTags : NULL;
}


/*******************************************************************/
/*                                                                 */
/*  Lexicon::store                                                 */
/*                                                                 */
/*******************************************************************/

void Lexicon::store( FILE *file )

{
  for( iterator it=begin(); it != end(); it++ ) {
    fputs( it->first, file);
    Tags &tags=it->second;
    for( size_t i=0; i<tags.size(); i++ )
      fprintf( file, "\t%s %.2f", grammar.symbol_name(tags[i]), tags.freq(i));
    fputc( '\n', file );
  }
}


/*******************************************************************/
/*                                                                 */
/*  Lexicon::store_oc                                              */
/*                                                                 */
/*******************************************************************/

// Writes the open class tags to "file", one per line: the tag name
// followed by a blank and the tag frequency with two decimals.
void Lexicon::store_oc( FILE *file )

{
  const size_t count = OCTags.size();
  for( size_t t=0; t<count; ++t ) {
    const char *name = grammar.symbol_name(OCTags[t]);
    fprintf( file, "%s %.2f\n", name, OCTags.freq(t) );
  }
}


/*******************************************************************/
/*                                                                 */
/*  append                                                         */
/*                                                                 */
/*******************************************************************/

// Copies the characters of s (without the terminating 0) to p and
// returns the position following the last character written.
// The caller must provide enough room at p.
static char *append( char *p, const char *s )

{
  for( ; *s != '\0'; ++s, ++p )
    *p = *s;
  return p;
}


/*******************************************************************/
/*                                                                 */
/*  add_traces                                                     */
/*                                                                 */
/*******************************************************************/

// Appends the names of all traces at position i to the string under
// construction at p, each followed by a blank.
//
// k is the index of the next unprocessed trace; it is advanced past the
// traces that were written. Returns the new end of the string.
static char *add_traces( char *p, vector<Trace> &trace, size_t &k, size_t i,
                         Grammar &grammar)
{
  for( ; k < trace.size(); k++ ) {
    if (trace[k].pos != (int)i)
      break;
    p = append( p, grammar.traces.symbol_name( trace[k].sn ) );
    *(p++) = ' ';
  }
  return p;
}


/*******************************************************************/
/*                                                                 */
/*  Lexicon::read_scores                                           */
/*                                                                 */
/*******************************************************************/

// Reads word-specific scores from "filename" and stores them in Score.
//
// Each line has four tab-separated fields: score, word, left-hand side
// and right-hand side. If the right-hand side is DefaultString, the score
// is stored under the negated number -(symbol+1) of the left-hand side
// symbol; otherwise under the number of the grammar rule "lhs<tab>rhs".
// Lines that refer to an unknown rule or symbol are reported on stderr
// and skipped, and empty lines are ignored.
//
// Throws a malloc'ed char* message if the file cannot be opened or a
// line has too few fields, and a string literal if the rule table is
// inconsistent. The file is closed before an exception is thrown.
void Lexicon::read_scores( char *filename )

{
  const size_t MessageSize = 1000;

  FILE *file=fopen(filename,"rt");
  if (file == NULL) {
    char *message=(char*)malloc(MessageSize);
    snprintf(message, MessageSize, "unable to open file \"%s\"", filename);
    throw message;
  }

  // store the grammar rules in a hash table
  // (key: "lhs<tab>rhs symbols and traces separated by blanks")
  SymbolTable ruletab;
  for( RuleNumber rn=0; rn<(RuleNumber)grammar.rules.size(); rn++ ) {
    // NOTE(review): the rule string is assembled without bounds
    // checking; this assumes that every rule fits into the buffer
    char buffer[10000];
    char *p = buffer;
    size_t k=0;
    vector<Trace> &trace=grammar.traces.get( rn );
    Rule &rule = grammar.rules[rn];

    p = append( p, grammar.symbol_name( rule.symbol(0) ) );
    *(p++) = '\t';
    p = add_traces( p, trace, k, 0, grammar );
    for( size_t i=1; i<rule.length(); i++ ) {
      p = append( p, grammar.symbol_name( rule.symbol(i) ) );
      *(p++) = ' ';
      p = add_traces( p, trace, k, i, grammar );
    }
    // overwrite the trailing separator with the string terminator
    *(--p) = 0;
    // the table must assign the rule numbers in order
    if ((int)ruletab.number( buffer ) != rn) {
      fprintf(stderr,"%d %d\n", rn, (int)ruletab.number( buffer ));
      fclose(file);
      throw "error in function read_scores!";
    }
  }

  if (!Quiet)
    cerr << "reading the scores...";

  char buffer[BUFFER_SIZE];
  for( unsigned line=0; fgets(buffer, BUFFER_SIZE, file) != NULL; line++ ) {
    char *field = strtok(buffer, "\t\n");
    if (field == NULL)
      continue;  // empty line

    float score = (float)atof(field);
    char *word  = strtok(NULL, "\t");
    char *lhs  = strtok(NULL, "\t");
    char *rhs  = strtok(NULL, "\n");
    if (word == NULL || lhs == NULL || rhs == NULL) {
      char *message=(char*)malloc(MessageSize);
      snprintf(message, MessageSize,
               "in line %u of score file \"%s\": missing field\n",
               line, filename);
      fclose(file);
      throw message;
    }

    if (strcmp(rhs, DefaultString) == 0) {
      int sym = grammar.symbol_number(lhs);
      if (sym == -1)
        // an unknown symbol would otherwise be stored as rule 0
        fprintf(stderr,"%f\t%s\t%s -> unknown symbol\n", score, word, lhs);
      else
        Score.add( word, -(sym+1), score );
    }
    else {
      *(rhs-1) = '\t'; // undo strtok operation
      // number() returns the current table size for a string that
      // was not in the table, i.e. for an unknown rule
      size_t n=ruletab.size();
      int rn=ruletab.number( lhs );
      if (rn == (int)n)
        fprintf(stderr,"%f\t%s\t%s -> %d\n", score, word, lhs, rn);
      else
        Score.add( word, rn, score );
    }
  }

  fclose(file);
  if (!Quiet)
    cerr << "finished\n";
}
