###############################################################################
##                                                                           ##
## This file is part of ModelBlocks. Copyright 2009, ModelBlocks developers. ##
##                                                                           ##
##    ModelBlocks is free software: you can redistribute it and/or modify    ##
##    it under the terms of the GNU General Public License as published by   ##
##    the Free Software Foundation, either version 3 of the License, or      ##
##    (at your option) any later version.                                    ##
##                                                                           ##
##    ModelBlocks is distributed in the hope that it will be useful,         ##
##    but WITHOUT ANY WARRANTY; without even the implied warranty of         ##
##    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the          ##
##    GNU General Public License for more details.                           ##
##                                                                           ##
##    You should have received a copy of the GNU General Public License      ##
##    along with ModelBlocks.  If not, see <http://www.gnu.org/licenses/>.   ##
##                                                                           ##
###############################################################################

## backoff-EM.pl
##   back-off for the M model, or perhaps this will become general at some point
use Getopt::Std;
use strict;

my $MAX_ITER   = 1000; # number of E-M iterations (-i)
my $PRINT_ITER = 100;  # progress / log-dump interval in iterations (-p)
my $VERBOSE    = 0;    # trace every E- and M-step quantity on STDOUT (-v)

## Usage: backoff-EM.pl [-o PREFIX] [-i ITERS] [-p N] [-v] HELDOUT < TRAINING
##   -o PREFIX  additionally dump the model to PREFIX.log every N iterations
##              (except the last one, which goes to STDOUT)
##   -i ITERS   number of E-M iterations (default 1000; 0 skips E-M)
##   -p N       report progress on STDERR every N iterations (default 100)
##   -v         verbose tracing of every E- and M-step quantity
getopts('o:i:p:v');
our( $opt_o, $opt_i, $opt_p, $opt_v);
my $HELDOUT = $ARGV[0];
defined($HELDOUT)
    or die "usage: backoff-EM.pl [-o PREFIX] [-i ITERS] [-p N] [-v] HELDOUT < TRAINING\n";

########### OPTIONS
if ($opt_o ne "") {
    my $filename = $opt_o.".log";
    open( LOG, '>', $filename) or die "ERROR: can't write log file $filename: $!\n";
}
if ($opt_i ne "") {
    # 0 is accepted: the E-M loop is skipped and the initial weights are printed.
    $opt_i =~ /^\d+$/
        or die "ERROR: -i expects a non-negative integer, got '$opt_i'\n";
    $MAX_ITER = $opt_i;
}
if ($opt_p ne "") {
    # Must be positive: it is used as a modulus in the main loop.
    ($opt_p =~ /^\d+$/ && $opt_p > 0)
        or die "ERROR: -p expects a positive integer, got '$opt_p'\n";
    $PRINT_ITER = $opt_p;
}
if ($opt_v ne "") {
    $VERBOSE = 1;
}
###########

# Scratch tables built while reading the held-out data, each keyed by model
# name first.  Freq* accumulate rule weights under each back-off view of a
# rule (FreqALL holds one running total per model); Map* map a full rule key
# "d l:c{e} : targ" to the corresponding back-off key.
my (%FreqMargE, %FreqMargL, %FreqMargLE, %FreqALL);
my (%MapMargE,  %MapMargL,  %MapMargLE,  %MapALL);


# Distributions the E-M loop relies on, also keyed by model name first:
# the back-off marginals, the held-out rule table and its per-condition
# totals.
my (%MargE, %MargL, %MargLE);
my %P_ALL;
my %P_COND;

###############################################################################
##### (1a) Read in held out data from ARGV[0], initialize
###############################################################################

## Read the held-out rules from the file named by ARGV[0].  Each line is
##   MODEL COND : TARG [= PROB]    (conditional rule), or
##   MODEL : TARG [= PROB]         (prior rule);
## lines starting with '#' are comments.  A rule without "= PROB" counts as
## one observation.  For MdL/MdR rules the condition is split into d, l, c, e
## ("d l:c{e}"); the rule's weight is also accumulated under each back-off
## key (l dropped, {e} dropped, both dropped), and the %Map* tables record
## full key -> back-off key.
open(my $heldout_fh, '<', $HELDOUT)
    or die "ERROR: can't read file $HELDOUT: $!\n";
while (my $line = <$heldout_fh>) {
    chomp $line;

    if ($line =~ m/^\#/) {
        next;
    }
    elsif ($line =~ m/^([^ ]+) (.*) : ([^=]*)( = (.*))?$/) {
        # conditional rule; copy the captures before any other regex runs
        my ($model, $cond, $targ, $prob) = ($1, $2, $3, $5);
        my $rule   = "$cond : $targ";
        my $weight = defined($prob) ? $prob : 1;

        # An explicit probability replaces any earlier value; a bare rule is counted.
        $P_ALL{$model}{$rule} = defined($prob) ? $prob : $P_ALL{$model}{$rule} + 1;
        $P_COND{$model}{$cond} += $weight;

        if ($model eq "MdL" || $model eq "MdR") {
            my ($d, $l, $c, $e) = split(/[ :\{\}]/, $cond);
            my $full = "$d $l:$c\{$e\} : $targ";

            $FreqMargE{$model}{"$d $l:$c : $targ"}    += $weight;
            $FreqMargL{$model}{"$d $c\{$e\} : $targ"} += $weight;
            $FreqMargLE{$model}{"$d $c : $targ"}      += $weight;
            $FreqALL{$model}                          += $weight;

            $MapMargE{$model}{$full}  = "$d $l:$c : $targ";
            $MapMargL{$model}{$full}  = "$d $c\{$e\} : $targ";
            $MapMargLE{$model}{$full} = "$d $c : $targ";
            $MapALL{$model}{$full}    = $full;
        }
    }
    elsif ($line =~ m/^([^ ]+) +: ([^=]*)( = (.*))?$/) {
        # prior rule
        my ($model, $targ, $prob) = ($1, $2, $4);
        $P_ALL{$model}{": $targ"} = defined($prob) ? $prob : $P_ALL{$model}{": $targ"} + 1;
        $P_COND{$model}{""}      += defined($prob) ? $prob : 1;
    }
}
close($heldout_fh);

# normalize into probabilities
# Divide the held-out MdL/MdR rule weights, and the back-off marginals
# derived from them, by the model's total conditional weight (%FreqALL).
foreach my $model ( "MdL", "MdR" ) {
    my $total = $FreqALL{$model};
    # No conditional held-out rules (or zero total weight) for this model:
    # nothing to normalize, and dividing by the total would abort the run.
    next unless $total;

    foreach my $all ( sort keys %{$P_ALL{$model}} ) {
        my ($d,$l,$c,$e,$targ) = ($all =~ /([^ ]*) ([^ ]*):([^ ]*)\{([^ ]*)\} : (.*)/);
        # Only keys of the form "d l:c{e} : targ" have back-off marginals.
        if (defined $targ) {
            $MargE{$model}{"$d $l:$c : $targ"}    = $FreqMargE{$model}{"$d $l:$c : $targ"} / $total;
            $MargL{$model}{"$d $c\{$e\} : $targ"} = $FreqMargL{$model}{"$d $c\{$e\} : $targ"} / $total;
            $MargLE{$model}{"$d $c : $targ"}      = $FreqMargLE{$model}{"$d $c : $targ"} / $total;
        }
        $P_ALL{$model}{$all} = $P_ALL{$model}{$all} / $total;
    }
}

## distributions that matter
## Back-off mixture weights P(z) per model, starting uniform.  A z code is a
## 4-character string.  Elsewhere in this script, the 2nd character selects
## whether the "l" part of a condition is kept, and the 4th whether "{e}" is
## kept.  The 1st and 3rd characters are not inspected.  Only "1111" (full
## condition) and "1110" (drop {e}) are enabled; "1011" and "1010" (each .25)
## are disabled.
my %P_Z;
foreach my $side ("MdL", "MdR") {
    $P_Z{$side} = { "1111" => .5, "1110" => .5 };
}
my %CONST_P_Z = %P_Z;   # shallow copy; only its keys (models, z codes) are used
my %P_ZgivALL;          # posterior P(z | rule key), filled in by the E-step


###############################################################################
##### (1b) Read in training data from STDIN, initialize
###############################################################################
my %P_TARGgivZCOND;  # MdL/MdR: weight of "z <back-off key>", normalized by %P_ZCOND
my %P_TRAININGKEYS;  # every training rule key per model (only the keys are read)
my %P_ZCOND;         # MdL/MdR: total training weight per target (normalizer)

my %P_TRAINALL;      # other models: weight per "cond : targ" / ": targ"
my %P_TRAINCOND;     # other models: total weight per condition
### initialize P(lc lc|z d lce)   (disabled; kept for reference)
#foreach my $m ( "MdL", "MdR" ) {
#    foreach my $z ( sort keys %{$CONST_P_Z{$m}} ) {
#	my %P_MODELALL;
#	if    ($z=~/^.0.0$/) { %P_MODELALL = %MargLE; }
#	elsif ($z=~/^.1.0$/) { %P_MODELALL = %MargE; } 
#	elsif ($z=~/^.0.1$/) { %P_MODELALL = %MargL; } 
#	elsif ($z=~/^.1.1$/) { %P_MODELALL = %P_ALL; }
#	foreach my $all ( sort keys %{$P_MODELALL{$m}} ) {
#	    $P_TARGgivZCOND{$m}{"$z $all"} = 1/(keys %{$P_Z{$m}});
#	}
#    }
#}

## Read training rules from STDIN.  The format is the same as the held-out
## file: "MODEL COND : TARG [= PROB]" or "MODEL : TARG [= PROB]".  A missing
## "= PROB" counts as one observation.
##  - MdL/MdR rules feed the back-off tables.  For every z code, the rule's
##    weight is added to %P_TARGgivZCOND under "z <back-off key>".  It is
##    added to the per-target normalizer %P_ZCOND once per rule, not once
##    per z; per-z accumulation made each normalized P(.|z,targ) sum to 1/|Z|.
##  - Every other rule, including other M* models, is kept as plain counts
##    in %P_TRAINALL / %P_TRAINCOND.  The final printout treats all
##    non-MdL/MdR models as relative-frequency models, so they must be
##    counted there.
while (my $line = <STDIN>) {
    chomp $line;
    if ($line =~ m/^(Md[LR]) (.*) : ([^=]*)( = (.*))?$/) {
        my ($model, $cond, $targ, $given) = ($1, $2, $3, $5);
        my $prob = defined($given) ? $given : 1;
        my $rule = "$cond : $targ";

        my ($d, $l, $c, $e) = split(/[ :\{\}]/, $cond);
        $MapMargE{$model}{$rule}  = "$d $l:$c : $targ";
        $MapMargL{$model}{$rule}  = "$d $c\{$e\} : $targ";
        $MapMargLE{$model}{$rule} = "$d $c : $targ";
        $MapALL{$model}{$rule}    = "$d $l:$c\{$e\} : $targ";

        $P_TRAININGKEYS{$model}{$rule} += $prob;
        $P_ZCOND{$model}{$targ}        += $prob;

        foreach my $z (sort keys %{$CONST_P_Z{$model}}) {
            # 2nd char of z: keep l?  4th char: keep {e}?
            my $map = $z =~ /^.0.0$/ ? \%MapMargLE
                    : $z =~ /^.1.0$/ ? \%MapMargE
                    : $z =~ /^.0.1$/ ? \%MapMargL
                    : $z =~ /^.1.1$/ ? \%MapALL
                    :                  {};
            $P_TARGgivZCOND{$model}{$z." ".$map->{$model}{$rule}} += $prob;
        }
    }
    elsif ($line =~ m/^([^ ]+) (.*) : ([^=]*)( = (.*))?$/) {
        # conditional rule of any other model
        my ($model, $cond, $targ, $prob) = ($1, $2, $3, $5);
        my $rule = "$cond : $targ";
        $P_TRAINALL{$model}{$rule}      = defined($prob) ? $prob : $P_TRAINALL{$model}{$rule} + 1;
        $P_TRAINCOND{$model}{$cond}    += defined($prob) ? $prob : 1;
        $P_TRAININGKEYS{$model}{$rule} += defined($prob) ? $prob : 1;
    }
    elsif ($line =~ m/^([^ ]+) +: ([^=]*)( = (.*))?$/) {
        # prior rule
        my ($model, $targ, $prob) = ($1, $2, $4);
        my $rule = ": $targ";
        $P_TRAINALL{$model}{$rule}      = defined($prob) ? $prob : $P_TRAINALL{$model}{$rule} + 1;
        $P_TRAINCOND{$model}{""}       += defined($prob) ? $prob : 1;
        $P_TRAININGKEYS{$model}{$rule} += defined($prob) ? $prob : 1;
    }
}

# normalize into probabilities
# Divide each training entry "z <back-off key>" (its key ends in " : targ")
# by the total weight recorded for that target in %P_ZCOND.  A zero or
# missing total (e.g. only "= 0" rules for the target) leaves the entry at 0
# instead of aborting the run with a division by zero.
foreach my $model (keys %P_TARGgivZCOND) {
    my $table = $P_TARGgivZCOND{$model};
    foreach my $all (keys %{$table}) {
        my (undef, $targ) = split(' : ', $all);
        my $total = $P_ZCOND{$model}{$targ};
        $table->{$all} = $total ? $table->{$all} / $total : 0;
    }
}
# will do other-rule normalization at printout

## main loop
## Run E-M for $MAX_ITER iterations.  Only the back-off weights %P_Z are
## re-estimated.  %P_TARGgivZCOND (from the training data) and the held-out
## tables (%P_ALL, %MargE, %MargL, %MargLE) are read but never modified.

# Returns (distribution, key map) for a z code.  The 2nd character of z
# selects whether the "l" part of a condition is kept, and the 4th whether
# "{e}" is kept.  An unrecognized code gets empty tables.
my $tables_for_z = sub {
    my ($z) = @_;
    return (\%MargLE, \%MapMargLE) if $z =~ /^.0.0$/;
    return (\%MargE,  \%MapMargE)  if $z =~ /^.1.0$/;
    return (\%MargL,  \%MapMargL)  if $z =~ /^.0.1$/;
    return (\%P_ALL,  \%MapALL)    if $z =~ /^.1.1$/;
    return ({}, {});
};

for my $iter (1..$MAX_ITER) {

    if ($iter % $PRINT_ITER == 0) {
        print STDERR "E-M iteration $iter...\n";
    }

    ###########################################################################
    ##### (2) Expectation Step
    ###########################################################################
    # F(z, rule) = P(z) * P_TARGgivZCOND(z, back-off key of rule) is credited
    # to both the full key and the back-off key.  Each entry is then divided
    # by the total F credited to the same key, giving P(z | key).

    my %MargZ_ALL;  # denominator: total F credited to each full/back-off key

    ### prepare for new expectation counts
    foreach my $m (keys %P_Z) {
        delete $P_ZgivALL{$m};
    }

    ## using current back-off parameters, score every held-out rule (numerator)
    foreach my $model ("MdL", "MdR") {
        foreach my $z (sort keys %{$CONST_P_Z{$model}}) {
            if ($VERBOSE) { print " iter=$iter, model=$model, z=$z\n"; }

            my (undef, $map) = $tables_for_z->($z);

            foreach my $all (keys %{$P_ALL{$model}}) {
                my $some = $map->{$model}{$all};
                # P(cond) is omitted: it does not depend on z.
                my $prob = $P_Z{$model}{$z} * $P_TARGgivZCOND{$model}{"$z $some"};
                # For z=1111 the map is the identity (for well-formed keys), so
                # $some eq $all and that entry receives $prob twice.  This was
                # added 2/6 to fix 1010 being favoured once training data was used.
                $P_ZgivALL{$model}{$z}{$all}   = $prob;
                $P_ZgivALL{$model}{$z}{$some} += $prob;
                $MargZ_ALL{$model}{$all}      += $prob;
                $MargZ_ALL{$model}{$some}     += $prob;

                if ($VERBOSE) { print "E-step P($z|all)=$prob ...   P($z)=".$P_Z{$model}{$z}." * P($some;$z)=".$P_TARGgivZCOND{$model}{"$z $some"}."---\n"; }
            }
        }
    }

    ## normalize
    foreach my $model ("MdL", "MdR") {
        foreach my $z (sort keys %{$CONST_P_Z{$model}}) {
            if ($VERBOSE) { print " iter=$iter, model=$model, z=$z\n"; }

            foreach my $some (keys %{$P_ZgivALL{$model}{$z}}) {
                if ($VERBOSE) { print "E-step from F($z|all)=".$P_ZgivALL{$model}{$z}{$some}." / F($some)=".$MargZ_ALL{$model}{$some}."  to  "; }

                # A positive numerator implies a positive denominator (it is part of it).
                $P_ZgivALL{$model}{$z}{$some} = ($P_ZgivALL{$model}{$z}{$some}>0) ? $P_ZgivALL{$model}{$z}{$some} / $MargZ_ALL{$model}{$some} : 0;

                if ($VERBOSE) { print "P($z|$some)=$P_ZgivALL{$model}{$z}{$some}---\n"; }
            }
        }
    }

    ###########################################################################
    ##### (3) Maximization Step
    ###########################################################################
    # P(z) is proportional to Sum_key P(z | key) * P~(key), where P~ is the
    # held-out distribution that z backs off to.  P(targ | z, cond) is held
    # fixed (not re-estimated).

    my %MargZALL;  # per-model normalizer for P(z)

    # Start new expectation counts.  Keep the previous weights so that a model
    # with no held-out evidence (e.g. no MdR rules in the held-out file) retains
    # them, instead of dying on a division by zero in the normalization below.
    my %prev_P_Z;
    foreach my $m (keys %P_Z) {
        $prev_P_Z{$m} = { %{$P_Z{$m}} };
        delete $P_Z{$m};
    }

    # calculate P(z)... the back-off parameter weights
    foreach my $model ("MdL", "MdR") {
        foreach my $z (sort keys %{$CONST_P_Z{$model}}) {
            if ($VERBOSE) { print " iter=$iter, model=$model, z=$z\n"; }

            my ($dist) = $tables_for_z->($z);

            foreach my $some (sort keys %{$dist->{$model}}) {
                my $prob = $P_ZgivALL{$model}{$z}{$some} * $dist->{$model}{$some};
                $P_Z{$model}{$z}  += $prob;  # keep this one for next time step
                $MargZALL{$model} += $prob;

                if ($VERBOSE) { print "M-step P($z)=$P_Z{$model}{$z}... P($z,all)=$prob... P($z|$some)=".$P_ZgivALL{$model}{$z}{$some}." * P~(some)=".$dist->{$model}{$some}."---\n"; }
            }
        }
    }

    # normalize
    foreach my $model ("MdL", "MdR") {
        foreach my $z (sort keys %{$CONST_P_Z{$model}}) {
            if ($VERBOSE) { print " iter=$iter, model=$model, z=$z\n"; }

            if ($VERBOSE) { print "M-step from F($z)=$P_Z{$model}{$z} / P()=$MargZALL{$model} --to-- "; }

            if ($MargZALL{$model}) {
                $P_Z{$model}{$z} = $P_Z{$model}{$z} / $MargZALL{$model};
            }
            else {
                # no held-out evidence for this model: keep last iteration's weight
                $P_Z{$model}{$z} = $prev_P_Z{$model}{$z};
            }

            if ($VERBOSE) { print "P($z)=$P_Z{$model}{$z}\n"; }
        }
    }

    ###########################################################################
    ##### (4) Convergence Criterion / Model dumps
    ###########################################################################
    # Every $PRINT_ITER iterations (except the last, which the final printout
    # covers), write the current weights and mixed MdL/MdR model to LOG.

    if ($iter % $PRINT_ITER == 0 && $iter != $MAX_ITER && $opt_o ne "") {

        my %Pout_TARGgivCOND;
        foreach my $m ("MdL", "MdR") {
            foreach my $z (keys %{$CONST_P_Z{$m}}) {
                my (undef, $map) = $tables_for_z->($z);
                foreach my $all (keys %{$P_TRAININGKEYS{$m}}) {
                    $Pout_TARGgivCOND{$m}{$all} += $P_Z{$m}{$z} * $P_TARGgivZCOND{$m}{$z." ".$map->{$m}{$all}};
                }
            }
        }

        print( LOG "------------------ E-M iteration $iter ---------------------\n" );
        foreach my $m (sort keys %CONST_P_Z) {
            foreach my $z (sort keys %{$CONST_P_Z{$m}}) {
                print LOG "Z $m $z = $P_Z{$m}{$z}\n";
            }
        }
        foreach my $m (sort keys %P_TRAININGKEYS) {
            if ($m eq "MdL" || $m eq "MdR") {
                foreach my $all (sort keys %{$Pout_TARGgivCOND{$m}}) {
                    print( LOG "$m $all = $Pout_TARGgivCOND{$m}{$all}\n");
                }
            }
        }
    }

}


### FINAL PRINTOUT
## Write every training rule to STDOUT as "MODEL RULE = PROB".
##  - MdL/MdR: PROB = Sum_z P(z) * P_TARGgivZCOND("z <back-off key of RULE>").
##  - other models: PROB = training weight of RULE / training weight of its
##    condition.  A rule whose condition has no training weight cannot be
##    normalized; it is reported on STDERR and skipped instead of aborting
##    the whole run with a division by zero.
my %Pout_TARGgivCOND;
foreach my $m (grep { $_ eq "MdL" || $_ eq "MdR" } keys %P_TRAININGKEYS) {
    foreach my $z (keys %{$CONST_P_Z{$m}}) {
        # 2nd char of z: keep l?  4th char: keep {e}?
        my $map = $z =~ /^.0.0$/ ? \%MapMargLE
                : $z =~ /^.1.0$/ ? \%MapMargE
                : $z =~ /^.0.1$/ ? \%MapMargL
                : $z =~ /^.1.1$/ ? \%MapALL
                :                  {};
        foreach my $all (keys %{$P_TRAININGKEYS{$m}}) {
            $Pout_TARGgivCOND{$m}{$all} += $P_Z{$m}{$z} * $P_TARGgivZCOND{$m}{$z." ".$map->{$m}{$all}};
        }
    }
}

foreach my $m (sort keys %P_TRAININGKEYS) {
    if ($m eq "MdL" || $m eq "MdR") {
        foreach my $all (sort keys %{$Pout_TARGgivCOND{$m}}) {
            print( STDOUT "$m $all = $Pout_TARGgivCOND{$m}{$all}\n");
        }
    }
    else {
        foreach my $all (sort keys %{$P_TRAININGKEYS{$m}}) {
            my ($cond, $targ) = split(': ', $all);
            $cond =~ s/\ *$//g;   # drop the space that preceded ':'
            my $denom = $P_TRAINCOND{$m}{$cond};
            if (!$denom) {
                print STDERR "WARNING: model $m has no training weight for condition '$cond'; skipping '$all'\n";
                next;
            }
            my $prob = $P_TRAINALL{$m}{$all} / $denom;
            print( STDOUT "$m $all = $prob\n" );
        }
    }
}
