
/*MA****************************************************************/
/*                                                                 */
/*     File: /home/users2/schmid/src/BitPar/retagging.C            */
/*   Author: Helmut Schmid                                         */
/*  Purpose: Retag preterminals by lexical x rule probability      */
/*  Created: Mon Mar  3 16:53:31 2003                              */
/* Modified: Wed Sep 24 09:12:03 2003 (schmid)                     */
/*                                                                 */
/*ME****************************************************************/

#include <iostream>
using std::cerr;

#include "retagging.h"

/* Grammar and lexicon used by process_node(); both are set by     */
/* init_tagging() and are not assigned anywhere else in this file. */
static Grammar *grammar;
static Lexicon *lexicon;

/* Table that maps grammar rules to probabilities.                  */
/*                                                                  */
/* The keys are pointers to Rule objects, but hashing and equality  */
/* only look at the symbol sequence of the rule (length() and       */
/* symbol(i)). A rule built elsewhere with the same symbols         */
/* therefore finds the entry of the rule that was inserted.         */
class RuleProb {

  /* Equality on the pointed-to rules: same length and the same     */
  /* symbol at every position.                                      */
  struct eqf {
    bool operator()(const Rule *r1, const Rule *r2) const {
      if (r1->length() != r2->length())
        return false;
      for( int i=0; i<r1->length(); i++ )
        if (r1->symbol(i) != r2->symbol(i))
          return false;
      return true;
    }
  };

  /* Hash over the rule length and all of its symbols.              */
  struct hashf {
    size_t operator()(const Rule *r) const {
      size_t key = (size_t)r->length();
      for( int i=0; i<r->length(); i++ ) {
        /* Convert to size_t BEFORE shifting: shifting a negative or  */
        /* overflowing signed value is undefined behaviour, whereas   */
        /* the unsigned shift is always well defined.                 */
        key ^= ((size_t)r->symbol(i)) << (i % 16);
      }
      return key;
    }
  };

  typedef hash_map<const Rule*, Prob, hashf, eqf> RC;

 private:
  RC rc;

 public:
  typedef RC::iterator iterator;

  /* Returns a reference to the probability stored for r. If no     */
  /* structurally equal rule is in the table yet, an entry with     */
  /* probability 0.0 is created first.                              */
  /* The table stores the address of r, not a copy, so r has to     */
  /* stay alive for as long as the table is used.                   */
  Prob &operator[]( Rule &r ) {
    iterator pos = rc.find(&r);
    if (pos == rc.end())
      pos = rc.insert(RC::value_type(&r, 0.0)).first;
    return pos->second;
  }

  /* Pure lookup: returns the stored probability of r, or 0.0 if    */
  /* there is no entry. Never inserts, so it is safe to call with   */
  /* a temporary rule.                                              */
  Prob operator[]( const Rule &r ) {
    iterator pos = rc.find(&r);
    if (pos == rc.end())
      return (Prob)0.0;
    return pos->second;
  }
};

/* Rule probability table of this file: filled by init_tagging()   */
/* and read by process_node().                                      */
RuleProb rule_prob;


/*FA****************************************************************/
/*                                                                 */
/*  init_tagging                                                   */
/*                                                                 */
/*FE****************************************************************/

/* Remembers the addresses of g and l in the file-level grammar /  */
/* lexicon pointers and enters the probability of every rule of g  */
/* into rule_prob. The table keeps pointers to the rules of g, so  */
/* g must stay alive while process() is being used.                */
void init_tagging( Grammar &g, Lexicon &l )

{
  grammar = &g;
  lexicon = &l;

  const size_t rule_count = g.rules.size();
  for( size_t n=0; n<rule_count; n++ )
    rule_prob[g.rules[n]] = g.ruleprob[n];
}


/*FA****************************************************************/
/*                                                                 */
/*  process_node                                                   */
/*                                                                 */
/*FE****************************************************************/

/* Re-tags the preterminal daughters of tree and, recursively, of   */
/* all nodes below it (daughters are processed before tree itself). */
/*                                                                  */
/* For each preterminal daughter whose word is found in the         */
/* lexicon, every tag listed for the word is tried in place of the  */
/* current one. The tag with the largest product of tag             */
/* probability and probability of the resulting rule                */
/* (tree label -> daughter labels) wins. If no candidate scores     */
/* above 0, the current tag is kept. A changed tag is reported on   */
/* cerr and replaces the daughter's label.                          */
/*                                                                  */
/* Nothing is changed at this node if it has no daughters, is a     */
/* preterminal itself, or if its label or one of its daughters'     */
/* labels is unknown to the grammar (symbol_number() returns -1).   */
static void process_node( Tree *tree )

{
  const size_t n = tree->daughters.size();
  if (n == 0 || tree->preterminal())
    return;

  for( size_t i=0; i<n; i++ )
    process_node( tree->daughters[i] );

  /* cat[0] is the mother category, cat[1..n] the daughter categories */
  vector<int> cat(n + 1);
  if ((cat[0] = grammar->symbol_number(tree->label)) == -1)
    return;
  for( size_t i=0; i<n; i++ )
    if ((cat[i+1] = grammar->symbol_number(tree->daughters[i]->label)) == -1)
      return;

  for( size_t i=0; i<n; i++ ) {
    Tree *d = tree->daughters[i];
    if (!d->preterminal())
      continue;

    const char *word = d->daughters[0]->label;
    /* NOTE(review): meaning of the second lookup() argument is not */
    /* visible in this file -- confirm against Lexicon.             */
    Tags *tags = lexicon->lookup(word, false);
    if (!tags)
      continue;

    const SymNum old_tag = cat[i+1];
    SymNum best_tag = old_tag;
    Prob best_prob = 0;
    for( size_t k=0; k<tags->size(); k++ ) {
      cat[i+1] = (*tags)[k];
      /* r is const, so rule_prob[r] is a pure lookup and does not  */
      /* store a pointer to this temporary rule in the table.       */
      const Rule r(cat);
      Prob p = (Prob)tags->prob(k) * rule_prob[r];
      if (p > best_prob) {
        best_prob = p;
        best_tag = cat[i+1];
      }
    }

    /* Undo the trial assignments; cat must mirror the tree labels  */
    /* because it is reused for the remaining daughters.            */
    cat[i+1] = old_tag;
    if (best_tag == old_tag)
      continue;

    /* Allocate the new label BEFORE freeing the old one, so that a */
    /* failed allocation cannot leave the node without a label.     */
    char *new_label = strdup(grammar->symbol_name(best_tag));
    if (!new_label) {
      cerr << "warning: out of memory; tag of \"" << word
           << "\" left unchanged\n";
      continue;
    }

    cerr << word << " " << d->label;
    cerr << " -> " << grammar->symbol_name(best_tag) << "\n";
    free(d->label);
    d->label = new_label;
    cat[i+1] = best_tag;
  }
}

/*FA****************************************************************/
/*                                                                 */
/*  print_node                                                     */
/*                                                                 */
/*FE****************************************************************/

/* Prints the subtree rooted at tree to stdout in bracketed form.   */
/*   - node without daughters:  " label"                            */
/*   - preterminal node:        "(label word)"                      */
/*   - any other node:          "(label", then every daughter on a  */
/*     line of its own, indented by depth times two blanks and      */
/*     printed with depth+1, and finally ")".                       */
static void print_node( Tree *tree, int depth )

{
  const size_t n = tree->daughters.size();

  if (n == 0) {
    printf(" %s", tree->label);
    return;
  }

  if (tree->preterminal()) {
    printf("(%s %s)", tree->label, tree->daughters[0]->label);
    return;
  }

  printf("(%s", tree->label);
  for( size_t j=0; j<n; j++ ) {
    printf("\n");
    for( int indent=depth; indent>0; indent-- )
      printf("  ");
    print_node( tree->daughters[j], depth+1 );
  }
  printf(")");
}


/*FA****************************************************************/
/*                                                                 */
/*  process                                                        */
/*                                                                 */
/*FE****************************************************************/

/* Re-tags the preterminals of tree in place (tag changes are       */
/* reported on cerr) and then prints the tree, followed by a        */
/* newline, to stdout.                                              */
void process( Tree *tree )

{
  /* depth with which the daughters of the root get indented */
  const int top_level_depth = 1;

  process_node( tree );
  print_node( tree, top_level_depth );
  printf("\n");
}
