###############################################################################
##                                                                           ##
## This file is part of ModelBlocks. Copyright 2009, ModelBlocks developers. ##
##                                                                           ##
##    ModelBlocks is free software: you can redistribute it and/or modify    ##
##    it under the terms of the GNU General Public License as published by   ##
##    the Free Software Foundation, either version 3 of the License, or      ##
##    (at your option) any later version.                                    ##
##                                                                           ##
##    ModelBlocks is distributed in the hope that it will be useful,         ##
##    but WITHOUT ANY WARRANTY; without even the implied warranty of         ##
##    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the          ##
##    GNU General Public License for more details.                           ##
##                                                                           ##
##    You should have received a copy of the GNU General Public License      ##
##    along with ModelBlocks.  If not, see <http://www.gnu.org/licenses/>.   ##
##                                                                           ##
###############################################################################

#!/usr/bin/ruby

#####################################################################
# calcHdwdCPT.rb
# 
# Calculates 2 CPTs from trees, 
#   1. MgH = P(MOD|HDWD) 
#   2. Hgm = P(HDWD|MOD)
#
# TO RUN: cat eraseme | ruby scripts/calcHdwdCPT.rb
#
######################################################################

require "scripts/umnlp.rb"
require 'optparse'

# Global run-time options, filled in from the command line.
#   :verbose - when true, hdwdbattle reports unresolved head choices on stderr.
$options = { :verbose => false }

# parse! strips the recognised switches out of ARGV, so after this point
# ARGV[0] (if present) is the first positional argument: the optional file
# of known head words that the main script loads into $hw.
option_parser = OptionParser.new do |opts|
  # Use the invoked script name so the banner cannot drift from the file name.
  opts.banner = "Usage: ruby #{File.basename($0)} [options] [known-headwords-file] < trees"

  opts.on("-v", "--verbose", "turns on extra stderr output") do |v|
    $options[:verbose] = v
  end
end
option_parser.parse!


# Known head words; when a positional file argument is given, any head word
# missing from this table is replaced by "unk" during percolation.
$hw = {} # for unks


# Head-word percolation for binarized parse trees.
#
# This class body adds head-word bookkeeping to Tree.  It relies on
# buildStructure, getNumRules and the head/children/parent accessors, none of
# which are defined in this file (presumably they come from scripts/umnlp.rb --
# confirm).
#
# Expected tree shape: every preterminal (POS) node has exactly one childless
# child whose head is "<something>#<word>"; internal nodes have one or two
# children.
class Tree

  # Head-percolation table keyed by bare category label.  Each value is a
  # direction flag followed by candidate child labels in priority order.
  # "0" indicates search R-to-L.  "1" indicates search L-to-R.
  HDWD_RULES = {
    "ADJP" => "0 NNS QP NN $ ADVP JJ VBN VBG ADJP JJR NP JJS DT FW RBR RBS SBAR RB",
    "ADVP" => "1 RB RBR RBS FW ADVP TO CD JJR JJ IN NP JJS NN",
    "CONJP" => "1 CC RB IN",
    "FRAG" => "1 ",
    "INTJ" => "0 ",
    "LST" => "1 LS :",
    "NAC" => "0 NN NNS NNP NNPS NP NAC EX $ CD QP PRP VBG JJ JJS JJR ADJP FW",
    # NP is dealt with separately
    "PP" => "1 IN TO VBG VBN RP RB FW PP", #added RB after RP, and PP at the end
    "PRN" => "0 ",
    "PRT" => "1 RP",
    "QP" => "0 $ IN NNS NN JJ RB DT CD NCD QP JJR JJS",
    "RRC" => "1 VP NP ADVP ADJP PP",
    "S" => "0 TO IN VP S SBAR VB VBD VBN VBG VBP VBZ ADJP UCP NP", #added all the VB* ones after SBAR
    "SBAR" => "0 WHNP WHPP WHADVP WHADJP IN DT S SQ SINV SBAR WHSBAR FRAG", #added WHSBAR
    "SBARQ" => "0 SQ S SINV SBARQ FRAG",
    "SINV" => "0 VBZ VBD VBP VB MD VP S SINV ADJP NP",
    "SQ" => "0 VBZ VBD VBP VB MD VP SQ",
    "UCP" => "1 ",
    "VP" => "0 TO VBD VP VBN MD VBZ VB VBG VBP ADJP NN NNS NP", #moved up VP, from after VBP to after VBD
    "WHADJP" => "0 CC WRB JJ ADJP",
    "WHADVP" => "1 CC WRB",
    "WHNP" => "0 WDT WP WP$ WHADJP WHPP WHNP",
    "WHPP" => "1 IN TO FW",

    #add'l rules to count for previous binarization & processing
    "NN"   => "0 NN NNS NNP NNPS",
    "NNS"  => "0 NN NNS NNP NNPS",
    "NNP"  => "0 NN NNS NNP NNPS",
    "NNPS" => "0 NN NNS NNP NNPS",
    "VBZ"  => "1 VBZ VBD VBP VB VBN VBG",
    "VBD"  => "1 VBZ VBD VBP VB VBN VBG",
    "VBP"  => "1 VBZ VBD VBP VB VBN VBG",
    "VB"   => "1 VBZ VBD VBP VB VBN VBG",
    "VBN"  => "1 VBZ VBD VBP VB VBN VBG",
    "VBG"  => "1 VBZ VBD VBP VB VBN VBG",
    "JJ"   => "0 JJ JJR JJS",
    "JJR"  => "0 JJ JJR JJS",
    "JJS"  => "0 JJ JJR JJS",
    "RB"   => "0 RB RBR RBS",
    "RBR"  => "0 RB RBR RBS",
    "RBS"  => "0 RB RBR RBS",
    "CD"   => "0 CD",
    "IN"   => "1 IN",
    "LISTNP" => "1 NP",
    "LISTSBAR" => "1 SBAR",
    "LISTSINV" => "1 SINV",
    "LISTS"  => "1 S",
    "LISTVP" => "1 VP",
    "WHSBAR" => "0 WHNP WHPP WHADVP WHADJP IN DT S SQ SINV SBAR FRAG VP", #taken from SBAR, added VP to the end
    "WHSBARQ" => "0 SQ S SINV SBARQ FRAG",
  }.freeze

  # Last-resort preference: the first letter in this list that begins either
  # child's label decides which child supplies the head word.
  LAST_RULES = %w[S s V v N n W w P p I i U u A a R r J j Q q C c L l].freeze

  # NP/NX (and "PP-tmp") head steps, tried in order.  Each entry is
  # [child index to test and return, bare labels that trigger it].
  # See Collins' notes (magerman-black.txt); the "identify POS words" step is
  # skipped.
  NP_HEAD_STEPS = [
    [1, %w[NN NNP NNPS NNS NX POS JJR]],
    [0, %w[NP]],
    [1, %w[$ ADJP PRN]],
    [1, %w[CD]],
    [1, %w[JJ JJS RB QP]],
  ].freeze

  # The rule table in use (shared, frozen HDWD_RULES).
  attr_reader :hdwdrules
  # Head word assigned to this node by percolate (nil until then).
  attr_accessor :hdwd
  # Label used to look up head rules; for nodes whose head contains '_' this
  # is inherited from the parent.
  attr_accessor :uschead
  # Last-resort letter ordering (defaults to the shared, frozen LAST_RULES).
  attr_accessor :lastrules

  # Builds a tree from a bracketed string.
  #
  # str    - bracketed tree text.  If it contains more "]" than ")", square
  #          brackets are rewritten to parentheses IN PLACE (the caller's
  #          string is mutated, as before).
  # parent - parent node, or nil.
  def initialize(str="", parent=nil)
    @str = str
    @children = Array.new
    @parent = parent
    @num_rules = 0
    if str == ""
      # The original assigned a local `head` here, which had no effect; an
      # empty tree now gets an empty label so later messages can print it.
      @head = ""
    else
      ## Tolerate input that uses square brackets instead of parentheses
      if str.count(")") < str.count("]")
        str.tr!("[]", "()")
      end
      buildStructure(@str)
    end
    @num_rules = getNumRules #+= @children[i].num_rules
    # Share the constant tables instead of rebuilding them for every node.
    @hdwdrules = HDWD_RULES
    @lastrules = LAST_RULES
  end


  # Assigns @hdwd to this node and every node below it, and records one
  # head/modifier event per binary node.
  #
  # hdwdmod     - Hash of pair counts, keyed "<right hdwd> : <left hdwd>".
  # lefthdwdtot - Hash counting how often a word was the LEFT child and head.
  # leftmodtot  - Hash counting how often a word was the LEFT child and modifier.
  #
  # For each binary node whose head word equals one child's head word, three
  # lines (HDMD / HD / MD) are printed to stdout.  Malformed or non-binary
  # nodes are reported on stderr and left without a head word.
  # Returns nil.
  def percolate(hdwdmod, lefthdwdtot, leftmodtot) #on binary trees only
    return if malformed_node?

    @uschead = @head

    case @children.size
    when 1
      only = @children[0]
      if only.children.size == 0
        ## terminal case: the word follows the first '#' in the leaf's label
        @hdwd = unk_filter(only.head.split('#', 2)[1])
      else
        ## unary case: inherit the child's head word
        only.percolate(hdwdmod, lefthdwdtot, leftmodtot)
        @hdwd = unk_filter(only.hdwd)
      end
    when 2
      percolate_binary(hdwdmod, lefthdwdtot, leftmodtot)
    else
      $stderr.print "ERROR: calcHdwdCPT.rb requires binary trees. Node #{@head} has #{@children.size} children.\n"
    end
    nil
  end


  # Decides which child of a binary node supplies the head word, using
  # @hdwdrules, the NP special case, and finally @lastrules.
  # Returns the chosen head word; "" when nothing matches at all.
  def hdwdbattle
    bald = bare_label(@uschead)  # bald= a bare head

    if @hdwdrules.key?(bald) && @uschead != "PP-tmp"
      okhdlist = @hdwdrules.fetch(bald).split
      seekdirection = okhdlist.shift

      if okhdlist.empty?
        # No candidates listed: "1" takes the right corner, "0" the left
        # corner (each falling back to the sibling when it is punctuation).
        return seekdirection == "1" ? return_r_child : return_l_child
      end

      # "1" checks the left child first, "0" the right child first.
      order = seekdirection == "1" ? [0, 1] : [1, 0]
      okhdlist.each do |okhd|
        order.each do |idx|
          return hdwd_from_child(idx) if label_parts_include?(@children[idx], okhd)
        end
      end

    elsif bald == "NP" || bald == "NX" || @uschead == "PP-tmp" #NP is treated as a special case
      NP_HEAD_STEPS.each do |idx, labels|
        return hdwd_from_child(idx) if labels.include?(bare_label(@children[idx].head))
      end
      return return_r_child
    end

    # Nothing matched: fall back on the first preferred initial letter.
    lastresort = ""
    @lastrules.each do |letter|
      winner = [@children[0], @children[1]].find { |child| child.head =~ /^#{letter}/ }
      next unless winner
      lastresort = winner.hdwd
      break
    end

    if $options[:verbose]
      $stderr.print "Unable to resolve hdwd for #{@head} -> #{@children[0].head} #{@children[1].head}, picking #{lastresort}\n"
    end
    lastresort
  end

  # Head word of the right child, unless its label has no letters or starts
  # with '!'; then the left child's, else "!empty!".
  def return_r_child
    pick_hdwd(@children[1], @children[0])
  end

  # Head word of the left child, unless its label has no letters or starts
  # with '!'; then the right child's, else "!empty!".
  def return_l_child
    pick_hdwd(@children[0], @children[1])
  end

  # Debug rendering: "( LABEL...<hdwd> child child )".
  # A childless node renders as its own hdwd, which percolate never sets for
  # word leaves, so leaves normally render as "".
  def hdwd_to_s
    return hdwd.to_s if @children.length == 0

    "( #{@head}...<#{@hdwd}> " + @children.map { |child| child.hdwd_to_s + " " }.join + ")"
  end


  # Points every child's parent link at its actual parent, recursively.
  # Word leaves under a preterminal are left untouched.  Malformed or
  # non-binary nodes are reported on stderr and not descended into.
  # Returns nil.
  def set_parents
    return if malformed_node?

    ## terminal case
    return if @children.size == 1 && @children[0].children.size == 0

    if @children.size > 2
      $stderr.print "ERROR: calcHdwdCPT.rb requires binary trees. Node #{@head} has #{@children.size} children.\n"
      return
    end

    ## unary and binary cases.  Nodes whose head contains '_' are now
    ## descended into as well; percolate reads @parent on exactly those nodes.
    @children.each do |child|
      child.set_parents
      child.parent = self
    end
    nil
  end

  private

  # Reports (on stderr) and returns true when this node breaks the expected
  # preterminal/POS/word layering; false otherwise.
  def malformed_node?
    if @children.size == 0
      $stderr.print "ERROR: EXPECT (PRETERM), POS, WORD NODES IN RCTREE!!! curr:#{@head}\nLine=#{$line}\n"
      return true
    end
    only = @children[0]
    if @children.size == 1 && only.children.size == 1 && only.children[0].children.size != 0
      $stderr.print "ERROR: EXPECT (PRETERM), POS, WORD NODES IN RCTREE!!! curr:#{@head} child:#{only.head} granch:#{only.children[0].head}\n"
      return true
    end
    false
  end

  # Replaces a head word with "unk" when a positional file argument was given
  # and the word is not in $hw.
  def unk_filter(word)
    (ARGV[0] && !$hw.key?(word)) ? "unk" : word
  end

  # Binary step of percolate: recurse, pick the head, record the event.
  def percolate_binary(hdwdmod, lefthdwdtot, leftmodtot)
    ## underscore sections of trees look up head rules under the parent's label
    @uschead = @parent.uschead if @head.include?('_')

    left, right = @children
    left.percolate(hdwdmod, lefthdwdtot, leftmodtot)
    right.percolate(hdwdmod, lefthdwdtot, leftmodtot)

    ## head rules battle it out!  then count up the results
    @hdwd = unk_filter(hdwdbattle)

    if @hdwd == left.hdwd
      # NOTE(review): the table key is right-child-first in BOTH branches,
      # while the printed HDMD line is head-first -- confirm this is intended.
      bump_count(hdwdmod, right.hdwd + " : " + left.hdwd)
      print "HDMD #{left.hdwd} : #{right.hdwd}\n"
      print "HD : #{left.hdwd}\n"
      print "MD : #{right.hdwd}\n"
      ## left child is the head, and therefore the condition to mgh
      bump_count(lefthdwdtot, left.hdwd)
    elsif @hdwd == right.hdwd
      bump_count(hdwdmod, right.hdwd + " : " + left.hdwd)
      print "HDMD #{right.hdwd} : #{left.hdwd}\n"
      print "MD : #{left.hdwd}\n"
      print "HD : #{right.hdwd}\n"
      ## left child is the mod, and therefore the condition to hgm
      bump_count(leftmodtot, left.hdwd)
    end
  end

  # Adds one to table[key], starting from zero for an unseen key.
  def bump_count(table, key)
    table[key] = table.fetch(key, 0) + 1
  end

  # Cuts a label down to everything up to and including its first run of
  # capital letters (e.g. "NP-SBJ" -> "NP").
  def bare_label(label)
    label.gsub(/([A-Z]+).*/, '\1')
  end

  # True when any '_'-separated part of node's label has bare label okhd.
  def label_parts_include?(node, okhd)
    node.head.split('_').any? { |part| bare_label(part) == okhd }
  end

  # Maps a child index onto the matching corner picker.
  def hdwd_from_child(idx)
    idx == 0 ? return_l_child : return_r_child
  end

  # Head word of the first usable node among preferred, fallback.
  def pick_hdwd(preferred, fallback)
    [preferred, fallback].each { |node| return node.hdwd if usable_head?(node) }
    "!empty!"
  end

  # A label is usable when it contains a letter (i.e. is not bare punctuation)
  # and does not start with '!'.  The original tested character index 1 for
  # the right child and compared one character with "!empty!".
  # NOTE(review): the "!empty!" comparison may have been meant to skip a child
  # whose hdwd is "!empty!"; that is not done here -- confirm.
  def usable_head?(node)
    (node.head =~ /[A-Za-z]/) && !node.head.start_with?('!')
  end

end

# Prints one "<name> <key> = <probability>" line per entry of hash.
#
# name    - label printed at the start of every line.
# hash    - counts keyed "<x> : <cond>".
# totcond - totals keyed by <cond> (the text after the first " : " of a key).
#
# Each probability is count / totcond[cond], both taken as floats, so a
# missing total divides by 0.0 rather than raising.
def print_normalized_hash(name, hash, totcond)
  hash.each_pair do |event, count|
    _outcome, condition = event.split(' : ', 2)
    denominator = totcond[condition].to_f
    probability = count.to_f / denominator
    print "#{name} #{event} = #{probability}\n"
  end
end

##########################################

# Tables filled in by Tree#percolate across all input trees.
hm = {}  # pair counts
hd = {}  # totals for words seen as left child and head
md = {}  # totals for words seen as left child and modifier

# Optional positional argument: a file of known head words.  Each line is
# split on runs of spaces, colons and equals signs and its THIRD field is
# recorded in $hw (exact file format not shown here -- confirm against the
# producer of that file).  File.foreach closes the file when done.
if ARGV[0]
#  $stderr.print "trying to open #{ARGV[0]}\n"
  File.foreach(ARGV[0]) do |entry|
    parts = entry.split(/[ :=]+/)
    $hw[parts[2]] = 1
  end
end


# One tree per STDIN line.
ctr = 0
while (line = STDIN.gets)
  t = Tree.new(line)
  t.parent = t   # the root acts as its own parent
  t.set_parents
  t.percolate(hm, hd, md)
#  print t.hdwd_to_s + "\n"
  if t.hdwd
    print "HW : " + t.hdwd + "\n"
  else
    # percolate has already reported the malformed node; keep going instead
    # of dying on a nil head word.
    $stderr.print "WARNING: no head word for tree #{ctr + 1}; HW line skipped\n"
  end
  ctr += 1
  if ctr % 1000 == 0
    $stderr.print " ... did hd md on #{ctr} trees ...\n"
  end
end

#print_normalized_hash( "HDMD", hm, md )

