#!/usr/bin/ruby
###############################################################################
##                                                                           ##
## This file is part of ModelBlocks. Copyright 2009, ModelBlocks developers. ##
##                                                                           ##
##    ModelBlocks is free software: you can redistribute it and/or modify    ##
##    it under the terms of the GNU General Public License as published by   ##
##    the Free Software Foundation, either version 3 of the License, or      ##
##    (at your option) any later version.                                    ##
##                                                                           ##
##    ModelBlocks is distributed in the hope that it will be useful,         ##
##    but WITHOUT ANY WARRANTY; without even the implied warranty of         ##
##    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the          ##
##    GNU General Public License for more details.                           ##
##                                                                           ##
##    You should have received a copy of the GNU General Public License      ##
##    along with ModelBlocks.  If not, see <http://www.gnu.org/licenses/>.   ##
##                                                                           ##
###############################################################################

require 'scripts/umnlp.rb'
require 'scripts/emutils.rb'

sentences = []
wordnums = ArrayHash.new
# Pin the placeholder token to 0; every sentence below starts with a 0 entry.
wordnums.store("__fake__", 0)

$stderr.puts "Reading in data..."
$maxw = 0
## Keep every sentence in memory: each EM iteration walks over all of them.
while (line = gets)
  line.chomp!
  # Slot 0 is always 0; the looked-up values of the real words follow.
  sentence = [0]
  line.split.each do |word|
    # NOTE(review): ArrayHash is defined in a required script; presumably
    # wordnums[word] returns (and, for new words, assigns) an integer id.
    sentence << wordnums[word]
    # Remember the largest value seen so far.
    $maxw = wordnums[word] if wordnums[word] > $maxw
  end
  sentences << sentence
end

#$stderr.puts "#{wordnums}"
#$stderr.puts sentences[0]
#exit
# Bump by one so that $maxw is one past the largest value seen
# (the original note: include the fake word that the init state emits).
$maxw += 1
$stderr.puts "Max word is #{$maxw}"

## Model dimensions: how many values each of f, act and awa can take.
$maxf = 5
$maxact = 5
$maxawa = 5
#maxw = 20000

## Very negative sentinel; used as the fill value for the w_cpt rows below.
MIN_INT = -2147483647

$stderr.puts "Allocating probability tables..."
## Tables, with the index order in which the rest of this file addresses them:
##   act_cpt[prev_act][prev_awa][cur_f][cur_act]
##   awa_cpt[prev_awa][cur_f][cur_act][cur_awa]
##   f_cpt[prev_awa][cur_f]
##   w_cpt[awa][word]
act_cpt = Array.new($maxact) do
  Array.new($maxawa) do
    Array.new($maxf) { Array.new($maxact) }
  end
end
awa_cpt = Array.new($maxawa) do
  Array.new($maxf) do
    Array.new($maxact) { Array.new($maxawa) }
  end
end
f_cpt = Array.new($maxawa) { Array.new($maxf) }
w_cpt = Array.new($maxawa) { Array.new($maxw, MIN_INT) }

## Function to index cpts
## Flatten an (f, act, awa) triple into a single trellis offset.
## awa varies fastest, then act, then f; the strides are taken from the
## $maxact and $maxawa globals.
def getTrellisIndex(f_i,act_i,awa_i)
  (f_i * $maxact + act_i) * $maxawa + awa_i
end

## Split a flat trellis offset back into its [f, act, awa] components.
##
## The layout matches getTrellisIndex: awa has stride 1, act has stride
## $maxawa and f has stride $maxact * $maxawa. (The f component used to be
## divided by $maxf * $maxact, which is only right when all three sizes are
## equal.)
##
## ind - Integer offset. Negative values follow Ruby's floor division, so
##       -1 yields [-1, $maxact - 1, $maxawa - 1].
##
## Returns the three components as an Array [f_i, act_i, awa_i].
def breakdownIndex(ind)
  # Peel off the fastest-varying component first, then the next one.
  rest, awa_i = ind.divmod($maxawa)
  f_i, act_i = rest.divmod($maxact)
  return f_i, act_i, awa_i
end


###################################
## Initialize cpts
###################################

$stderr.puts "Initializing tables..."
## f_cpt - fill each row with random values, normalize, then store logs.
## NOTE(review): normalize! comes from a required script; presumably it
## rescales the array in place so that it sums to 1 - confirm.
f_cpt.each do |row|
  $maxf.times { |j| row[j] = rand }
  normalize!(row)
  row.map! { |v| Math.log(v) }
end

## act_cpt - one random distribution over cur_act for every
## (prev_act, prev_awa, cur_f) context, stored as logs.
## NOTE(review): normalize! is defined elsewhere; presumably in-place
## normalization to sum 1 - confirm.
act_cpt.each do |by_prev_awa|
  by_prev_awa.each do |by_cur_f|
    by_cur_f.each do |dist|
      $maxact.times { |cur_act| dist[cur_act] = rand }
      normalize!(dist)
      dist.map! { |v| Math.log(v) }
    end
  end
end

## awa_cpt - one random distribution over cur_awa for every
## (prev_awa, cur_f, cur_act) context, stored as logs.
## NOTE(review): normalize! is defined elsewhere; presumably in-place
## normalization to sum 1 - confirm.
awa_cpt.each do |by_cur_f|
  by_cur_f.each do |by_cur_act|
    by_cur_act.each do |dist|
      $maxawa.times { |cur_awa| dist[cur_awa] = rand }
      normalize!(dist)
      dist.map! { |v| Math.log(v) }
    end
  end
end

## w_cpt
## awa 0 emits word 0 with log probability 0.0 (i.e. log 1.0); the rest of
## its row keeps the MIN_INT fill it was allocated with.
w_cpt[0][0] = Math.log(1.0)
#w_cpt[0].collect!{|v|Math.log(v)}
## Every other awa gets random weights over words 1..$maxw-1 and a MIN_INT
## entry for word 0.
1.upto($maxawa-1){ |i|
  w_cpt[i][0] = 0.0
  ## Stop at $maxw-1: each row was allocated with $maxw slots (0..$maxw-1).
  ## Writing slot $maxw would grow the row by one and hand part of the
  ## normalized mass to a word id that never occurs in the data.
  1.upto($maxw-1){ |j|
    w_cpt[i][j] = rand
  }
  ## NOTE(review): normalize! is defined elsewhere; presumably rescales the
  ## row in place so it sums to 1 - confirm.
  normalize!(w_cpt[i])
  ## Special case for 0 prob: park 1.0 in slot 0 so the log pass below
  ## produces a finite value there, then overwrite it with the sentinel.
  w_cpt[i][0] = 1.0
  w_cpt[i].collect!{|v|Math.log(v)}
  w_cpt[i][0] = MIN_INT
}

$stderr.puts "Allocating trellises..."
## A trellis column has one slot per (f, act, awa) combination.
trell_size = $maxf * $maxact * $maxawa

#trellis = Array.new(sentences.length)
#sentences.each_index{ |i|
#  $stderr.puts "Initializing trellis for sentence #{i}"
#  s = sentences[i]
#  trellis[i] = Array.new(s.length+1)
#  0.upto(s.length){ |j|
#    trellis[i][j] = Array.new(trell_size){TrellNode.new}
#  }
#  trellis[i][0][0].prob = Math.log(1.0)
#  trellis[i][0][0].prev = -1
#}

#############################
## Start iterations
#############################
## Scratch variable for the best final state. It is created here, outside
## the blocks below, so that assignments inside them share this one variable.
max_state = 0
## Number of EM passes over the data.
num_iters = 2

$stderr.print "Performing #{num_iters} EM iterations...\n"
1.upto(num_iters){ |iter|
  $stderr.print "Running iteration #{iter}..."
  ## Fresh count tables for this iteration. Every outcome count starts at 1
  ## and every conditioning total starts at the number of outcomes, so the
  ## re-estimated ratios below are never zero (add-one smoothing).
  act_counts = Array.new($maxact){Array.new($maxawa){Array.new($maxf){Array.new($maxact,1)}}}
  act_cond_counts = Array.new($maxact){Array.new($maxawa){Array.new($maxf,$maxact)}}
  awa_counts = Array.new($maxawa){Array.new($maxf){Array.new($maxact){Array.new($maxawa,1)}}}
  awa_cond_counts = Array.new($maxawa){Array.new($maxf){Array.new($maxact,$maxawa)}}
  f_counts = Array.new($maxawa){Array.new($maxf,1)}
  f_cond_counts = Array.new($maxawa,$maxf)
  w_counts = Array.new($maxawa){Array.new($maxw,1)}
  w_cond_counts= Array.new($maxawa,$maxw)

  ########################
  # Expectation - find the best-scoring state sequence for each sentence
  # (Viterbi-style: each state keeps only its best predecessor) and count
  # the events along that sequence.
  ########################
  sentences.each_index{ |i|
    $stderr.print "."
    s = sentences[i]
    ## Reset per sentence: a sentence with no real words never enters the
    ## loop below and must not inherit the previous sentence's final state.
    max_state = 0
    trellis = Array.new(s.length+1)
    ## NOTE(review): TrellNode is defined in a required script; its default
    ## prob is presumably <= MIN_INT so that untouched nodes are skipped below.
    trellis[0] = Array.new(trell_size){TrellNode.new}
    trellis[0][0].prob = 0.0 # Log(1.0)
    trellis[0][0].prev = -1  # start marker: no predecessor
    ## Iterate from word 1 (first real word) to word (length-1)
    1.upto(s.length-1){ |j|
      max_state = 0
      word = s[j]
      trellis[j] = Array.new(trell_size){TrellNode.new}
      ## For each state at the previous time step (j-1):
      0.upto(trell_size-1){ |ind|
        node = trellis[j-1][ind]
        ## Skip states whose score is at or below the sentinel.
        if node.prob <= MIN_INT
          next
        end
        old_f, old_act, old_awa = breakdownIndex(ind)
        ## Iterate over future states, adding log probabilities factor by factor
        0.upto($maxf-1){ |cur_f|
          p_old_f = f_cpt[old_awa][cur_f] + node.prob
          0.upto($maxact-1){ |cur_act|
            p_old_f_act = p_old_f + act_cpt[old_act][old_awa][cur_f][cur_act]
            0.upto($maxawa-1){ |cur_awa|
              new_index = getTrellisIndex(cur_f,cur_act,cur_awa)
              p_tot = p_old_f_act +
                      awa_cpt[old_awa][cur_f][cur_act][cur_awa] +
                      w_cpt[cur_awa][word]
              ## Keep only the best predecessor for each state.
              if p_tot > trellis[j][new_index].prob
                trellis[j][new_index].prob = p_tot
                trellis[j][new_index].prev = ind
              end
              ## On the last word, also track the best final state.
              if j == s.length-1 && p_tot > trellis[j][max_state].prob
                max_state = new_index
              end
            }
          }
        }
      }
    }
    ## Walk the back-pointers from the best final state and count what the
    ## path did at every position.
    state = max_state
    (s.length-1).downto(0){ |j|
      prev_state = trellis[j][state].prev
      cur_f, cur_act, cur_awa = breakdownIndex(state)
      word = s[j]
      ## A negative prev is the start marker set above, not a real state.
      ## Decoding it would credit a transition out of a made-up
      ## (act, awa) context once per sentence, so only real predecessors
      ## contribute transition counts.
      if prev_state >= 0
        _prev_f, prev_act, prev_awa = breakdownIndex(prev_state)
        f_cond_counts[prev_awa] += 1
        f_counts[prev_awa][cur_f] += 1
        act_cond_counts[prev_act][prev_awa][cur_f] += 1
        act_counts[prev_act][prev_awa][cur_f][cur_act] += 1
        awa_cond_counts[prev_awa][cur_f][cur_act] += 1
        awa_counts[prev_awa][cur_f][cur_act][cur_awa] += 1
      end
      ## The emission is counted at every position, including position 0.
      w_cond_counts[cur_awa] += 1
      w_counts[cur_awa][word] += 1
      state = prev_state
    }
  }

  $stderr.puts "Reweighting cpts..."
  $stderr.puts
  ########################
  # Maximization - recompute the distributions of all the hidden variables
  # as log(count / conditioning total), printing each table to stdout.
  ########################
  ## Reweight f cpt
  puts "################# F_CPT ###############"
  0.upto($maxawa-1){ |awa|
    print "#{awa} : "
    0.upto($maxf-1){ |f|
      f_cpt[awa][f] = Math.log(f_counts[awa][f].to_f / f_cond_counts[awa].to_f)
      print "#{f_cpt[awa][f]} "
    }
    print "\n"
  }

  ## Reweight act cpt
  puts "################# ACT_CPT #################"
  0.upto($maxact-1){ |old_act|
    0.upto($maxawa-1){ |old_awa|
      0.upto($maxf-1){ |f|
        print "#{old_act} #{old_awa} #{f} : "
        0.upto($maxact-1){ |new_act|
          act_cpt[old_act][old_awa][f][new_act] = Math.log(act_counts[old_act][old_awa][f][new_act].to_f / act_cond_counts[old_act][old_awa][f].to_f)
          print "#{act_cpt[old_act][old_awa][f][new_act]} "
        }
        print "\n"
      }
    }
  }

  ## Reweight awa cpt
  puts "############### AWA_CPT ###################"
  0.upto($maxawa-1){ |old_awa|
    0.upto($maxf-1){ |f|
      0.upto($maxact-1){ |act|
        print "#{old_awa} #{f} #{act} : "
        0.upto($maxawa-1){ |awa|
          awa_cpt[old_awa][f][act][awa] = Math.log(awa_counts[old_awa][f][act][awa].to_f / awa_cond_counts[old_awa][f][act].to_f)
          print "#{awa_cpt[old_awa][f][act][awa]} "
        }
        print "\n"
      }
    }
  }

  ## Reweight w cpt
  puts "################## W_CPT #####################"
  0.upto($maxawa-1){ |old_awa|
    print "#{old_awa} : "
    0.upto($maxw-1){ |word|
      w_cpt[old_awa][word] = Math.log(w_counts[old_awa][word].to_f / w_cond_counts[old_awa].to_f)
      print "#{w_cpt[old_awa][word]} "
    }
    print "\n"
  }
}

## Print cpts
#puts "################# F_CPT ###############"
#0.upto($maxawa-1){ |awa|
#  print "#{awa} : "
#  0.upto($maxf-1){ |f|
#    print "#{(f_cpt[awa][f])} "
#  }
#  print "\n"
#}

#puts "################# ACT_CPT #################"
#0.upto($maxact-1){ |prev_act|
#  0.upto($maxawa-1){ |prev_awa|
#    0.upto($maxf-1){ |f|
#      print "#{prev_act} #{prev_awa} #{f} : "
#      0.upto($maxact-1){ |act|
#        print "#{act_cpt[prev_act][prev_awa][f][act]} "
#      }
#      print "\n"
#    }
#  }
#}
#
#puts "############### AWA_CPT ###################"
#0.upto($maxawa-1){ |prev_awa|
#  0.upto($maxf-1){ |f|
#    0.upto($maxact-1){ |act|
#      print "#{prev_awa} #{f} #{act} : "
#      0.upto($maxawa-1){ |awa|
#        print "#{(awa_cpt[prev_awa][f][act][awa])} "
#      }
#      print "\n"
#    }
#  }
#}
#
#puts "################## W_CPT #####################"
#0.upto($maxawa-1){ |awa|
#  print "#{awa} : "
#  0.upto($maxw-1){ |w|
#    print "#{(w_cpt[awa][w])} "
#  }
#  print "\n"
#}

        