#!/usr/bin/perl -w

# runBitextCotrain.pl

# Shane Bergsma
# December 6, 2010

# Run co-training in the bitext directory.
# Usage (the flags are positional and must appear in exactly this order):
#   runBitextCotrain.pl -n <numIterations> -f <focus> -u <unlabPool>
#                       -p <positiveProportion> -a <num2Add> -c <CValues>
#
# numIterations: how many co-training iterations we run
# focus: which dev set ("bitext" or "wsj") the view-1 C value is tuned on at each iteration
# unlabPool: how many lines from the top of the unlabeled files form the pool we draw from
# positiveProportion: the fraction of the added examples that get the positive label
# num2Add: how many to add each time
# CValues: underscore-separated C values to evaluate (also handed to the Do/ scripts)

# Focus defines what classifier we choose at each iteration -- one
# that does well on wsj or one that does well on bitext.1?  Better:
# just stick with same C-value throughout.

die "Missing or too-many params." unless (@ARGV == 12);
my ($nflag, $numIterations, $fflag, $focus, $uflag, $unlabPool, $pflag, $positiveProportion, $aflag, $num2Add, $cflag, $CValues) = @ARGV;
die "Bad command-line arguments"
  unless ($cflag eq "-c" && $nflag eq "-n" && $uflag eq "-u" && $fflag eq "-f" && $pflag eq "-p" && $aflag eq "-a");
die "Taking more than unlabeled pool size: $unlabPool < $num2Add" unless ($unlabPool > $num2Add);

for (my $i=0; $i<$numIterations; $i++) {
  print STDERR "Running co-training iteration $i\n";
  # Train and produce the classifications on all the data of interest:
  printAndExecute("Do/runEvalCotrain1.sh $CValues");
  printAndExecute("Do/runEvalCotrain2.sh $CValues");
  # Evaluate the data and get the highest scoring c-value for each view:
  my $cval1 = evalIteration($i, $CValues, $focus, 1);
  my $cval2 = evalIteration($i, $CValues, "bitext", 2);
  # Also, for fun, how is the combined classifier doing?
  printCombinedAccuracies($cval1, $cval2);
  # Extract the top guys from the unlabeled set, in accordance with the
  # given proportion and numbers.  %movedExamples collects the 0-based
  # line numbers of everything that gets added to a labeled set:
  my %movedExamples;
  addExamplesFromTo(1, 2, $cval1, $num2Add, $unlabPool, $positiveProportion, \%movedExamples);
  # Add much more from 2 to 1, since that needs more data.  View 2's
  # examples are ranked with view 2's own best C value ($cval2); this
  # used to pass $cval1, i.e. the C value that was selected for view 1.
  addExamplesFromTo(2, 1, $cval2, $num2Add*($i+1), $unlabPool*($i+1), $positiveProportion, \%movedExamples);
  #  addExamplesFromTo(2, 1, $cval2, $num2Add, $unlabPool, $positiveProportion, \%movedExamples);

  ####################################
  # Overwrite the two unlabeled data sets with the remainder
  ####################################
  my $unlabFile1 = "Data/Cotrain/bitext.1.unlab.fvs";
  my $unlabFile2 = "Data/Cotrain/bitext.2.unlab.fvs";
  # A: read them:
  open(my $in1, '<', $unlabFile1) or die "No FVs: $unlabFile1: $!";
  open(my $in2, '<', $unlabFile2) or die "No FVs: $unlabFile2: $!";
  my @fvs1 = <$in1>;
  my @fvs2 = <$in2>;
  close($in1);
  close($in2);
  # The two views must stay line-aligned, so check before truncating anything:
  die "What the?" if (scalar(@fvs1) != scalar(@fvs2));
  # B: write them back, skipping moved ones:
  open(my $out1, '>', $unlabFile1) or die "No FVs: $unlabFile1: $!";
  open(my $out2, '>', $unlabFile2) or die "No FVs: $unlabFile2: $!";
  my $numPrinted = 0;
  for (my $j=0; $j<@fvs1; $j++) {
    next if (exists($movedExamples{$j}));
    print $out1 $fvs1[$j];
    print $out2 $fvs2[$j];
    $numPrinted++;
  }
  # Buffered write errors only show up at close time, so check it:
  close($out1) or die "Could not finish writing $unlabFile1: $!";
  close($out2) or die "Could not finish writing $unlabFile2: $!";
  print STDERR "Wrote $numPrinted unlabeled back to bitext.1/2.unlab\n";
}

# Move the most confidently classified unlabeled examples, as judged by one
# view's classifier, into the OTHER view's training set.
#
# Looks at the first $uPool lines of the two unlabeled feature-vector files
# and of view $fromV's predictions on its unlabeled file, scores every line
# by the predicted probability of label 1, and appends to
# Data/Cotrain/bitext.$toV.train.fvs:
#   - the lowest-scoring lines, labeled 0, and
#   - the highest-scoring lines, labeled 1,
# each written with view $toV's features for that line.
#
# Parameters:
#   $fromV   - number of the view whose predictions rank the examples
#   $toV     - number of the view whose training set receives the examples
#   $cval    - C value naming the prediction file of view $fromV to read
#   $toAdd   - total number of examples to add (positives plus negatives)
#   $uPool   - how many lines from the top of the unlabeled files to consider
#   $posProp - proportion of $toAdd that should be positives
#   $movedExamplesHashName - hash reference; the 0-based line number of every
#                            added example is set to 1 in it
#
# Dies if a file cannot be opened, if the files are not line-aligned, or if
# more examples are requested than are available.
sub addExamplesFromTo {
  my ($fromV, $toV, $cval, $toAdd, $uPool, $posProp, $movedExamplesHashName) = @_;

  my $toFile    = "Data/Cotrain/bitext.$toV.unlab.fvs";
  my $fromFile  = "Data/Cotrain/bitext.$fromV.unlab.fvs";
  my $predsFile = "Data/ML_CT/bitext.$fromV.unlab.fvs.bitext.$fromV.train.fvs.c$cval.model.preds";
  my $trainFile = "Data/Cotrain/bitext.$toV.train.fvs";

  ####################################
  # 1) Read the unlabeled FVs of both views and the predictions
  ####################################
  open(my $toFh, '<', $toFile) or die "No FVs: $toFile: $!";
  my @fvsTo = <$toFh>;
  close($toFh);
  open(my $fromFh, '<', $fromFile) or die "No FVs: $fromFile: $!";
  my @fvsFrom = <$fromFh>;
  close($fromFh);
  open(my $predsFh, '<', $predsFile) or die "No preds: $predsFile";

  ####################################
  # 2) Determine the prob/label mapping
  ####################################
  # The first line of the predictions names the label of each probability
  # column.  We want low scores to be label 0 and high scores to be label 1.
  my $labels = <$predsFh>;
  die "Empty preds file: $predsFile" unless (defined($labels));
  chomp $labels;
  my (undef, $lab1) = split(/ /, $labels);
  my @preds = <$predsFh>;
  close($predsFh);

  ####################################
  # 3) Check to see if enough data
  ####################################
  die "Mismatching unlabeled files: bitext.$toV.unlab.fvs and bitext.$fromV.unlab.fvs.bitext.$fromV.train.fvs.c$cval.model.preds " unless (scalar(@preds) == scalar(@fvsTo));
  die "Mismatching unlabeled files: bitext.$fromV.unlab.fvs and bitext.$toV.unlab.fvs" unless (scalar(@fvsTo) == scalar(@fvsFrom));
  my $numUnlabeled = scalar(@preds);
  print STDERR "Drawing $toAdd from pool of $uPool (of $numUnlabeled) unlabeled\n";
  if ($uPool > $numUnlabeled) {
    print STDERR "Warning:  Pool size bigger than remaining: $uPool > $numUnlabeled\n";
    $uPool = $numUnlabeled;
  }

  ####################################
  # 4) Score every line of the pool
  ####################################
  # One record per pool line.  Records are kept in a list rather than in a
  # hash keyed on the score, so two lines that end up with the same score
  # are both kept.
  my @scored;
  for (my $line=0; $line<$uPool; $line++) {
    my $fvTo = $fvsTo[$line];
    my $pred = $preds[$line];
    chomp $fvTo; chomp $pred;
    # Drop the leading label field of the FV; we assign our own label below.
    my (undef, @restTo) = split(/ /, $fvTo);
    my (undef, $prob1, $prob0) = split(/ /, $pred);
    # Which column holds the label-1 probability depends on the labels line.
    my $prob = ($lab1 == 1) ? $prob1 : $prob0;
    # Add randomness so that ties are broken randomly:
    my $goodProb = sprintf("%.15f", $prob+0.000000001 * rand() + 0.0000000000001 * rand());
    push(@scored, { score => $goodProb, line => $line, features => "@restTo" });
  }

  ####################################
  # 5) Write the low/top-scoring FVs to the train set
  ####################################
  # Figure out how many to add: include some random noise so that we
  # don't systematically add a wrong proportion after rounding:
  my $numPositiveToAdd = int($posProp * $toAdd + rand());
  my $numNegativeToAdd = int($toAdd - $numPositiveToAdd);
  my $totalAdd = $numPositiveToAdd + $numNegativeToAdd;
  die "Error: to-Add size bigger than remaining: $totalAdd > $numUnlabeled" if ($totalAdd > $numUnlabeled);
  # Only the pool was scored, so that is the real upper limit; staying within
  # it also guarantees that the low and the top selections cannot overlap.
  my $numScored = scalar(@scored);
  die "Error: to-Add size bigger than pool: $totalAdd > $numScored" if ($totalAdd > $numScored);

  my @ranked = sort { $a->{score} <=> $b->{score} } @scored;
  # Slices start at index 0 so the single best example on each side is
  # included (the former [1..N] slices skipped it and could run off the end).
  my @lowScoring = @ranked[0 .. $numNegativeToAdd - 1];
  my @topScoring = (reverse @ranked)[0 .. $numPositiveToAdd - 1];

  # Append them to the training set, keeping track of what was moved:
  open(my $train, '>>', $trainFile) or die "No FVs: $trainFile: $!";
  foreach my $example (@lowScoring) {
    print $train "0 $example->{features}\n";
    $movedExamplesHashName->{$example->{line}} = 1;
  }
  foreach my $example (@topScoring) {
    print $train "1 $example->{features}\n";
    $movedExamplesHashName->{$example->{line}} = 1;
  }
  close($train) or die "Could not finish writing $trainFile: $!";
}

# Pick the C value that gives the top dev accuracy for this iteration, and
# report the accuracies obtained with it on STDERR.
#
# Parameters:
#   $iteration - iteration number (used for the report and passed on to getAccuracies)
#   $CVals     - the candidate C values, separated by underscores
#   $foc       - "bitext" or "wsj": which dev accuracy the C value is chosen on
#   $view      - the view being evaluated; "wsj" is not available for view 2
#
# Returns the best C value; on ties the earliest one in $CVals wins.  Returns
# undef when $CVals holds no values.  Dies on an unknown focus, or when "wsj"
# is requested for view 2.
sub evalIteration {
  my ($iteration, $CVals, $foc, $view) = @_;
  # $foc chooses a sub-index in the accuracy array returned by getAccuracies,
  # which is (bitext dev, bitext test, wsj dev, wsj test, wsj train).  The
  # matching test accuracy always sits right after the dev accuracy.
  my $f;
  if ($foc eq "bitext") {
    $f = 0; # For bitext accuracy
  } elsif ($foc eq "wsj") {
    die "bad coding: no view2 for wsj" if ($view == 2);
    $f = 2; # For wsj accuracy
  } else {
    die "Unknown focus.";
  }
  my $devAcc = -1;
  my $testAcc;
  my $bestC;
  my @bestAllAccs;
  foreach my $cval (split(/_/, $CVals)) {
    my @accuracies = getAccuracies($cval, $view, $iteration);
    if ($accuracies[$f] > $devAcc) {
      $devAcc = $accuracies[$f];
      $testAcc = $accuracies[$f+1];
      @bestAllAccs = @accuracies;
      $bestC = $cval;
    }
  }
  # Values go in as printf arguments rather than into the format string, so
  # a '%' inside one of them cannot be mistaken for a conversion.
  if ($view == 1) {
    printf STDERR "%s - Iter=%s - Accuracy, all - bitext - dev:%.1f, test:%.1f, wsj - dev:%.1f, test:%.1f, train:%.1f\n", $view, $iteration, @bestAllAccs[0 .. 4];
  } else {
    printf STDERR "%s - Iter=%s - Accuracy, bitext - %s - Dev:%.1f, Test:%.1f\n", $view, $iteration, $bestC, $devAcc, $testAcc;
  }
  return $bestC;
}

# Collect the accuracies of one view's classifier, trained with C value
# $cval, on all the evaluation sets.
#
# Parameters:
#   $cval      - the C value of the model to evaluate
#   $view      - the view number; the wsj sets are only evaluated for view 1
#   $iteration - iteration number, passed on to getAccuracy
#
# Returns the list (bitext dev, bitext test, wsj dev, wsj test, wsj train).
# For any view other than 1 the three wsj entries are 0.
sub getAccuracies {
  my ($cval, $view, $iteration) = @_;
  my $bitextD = getAccuracy("bitext.$view.dev", $cval, $view, $iteration);
  my $bitextT = getAccuracy("bitext.$view.test", $cval, $view, $iteration);
  # All three wsj values are lexicals initialised to 0.  $wsjT used to be an
  # undeclared package variable, so a call for another view returned whatever
  # the previous view-1 call had left in it.
  my ($wsjD, $wsjT, $wsjTr) = (0, 0, 0);
  if ($view == 1) {
    $wsjD = getAccuracy("wsj.dev", $cval, $view, $iteration);
    $wsjT = getAccuracy("wsj.test", $cval, $view, $iteration);
    $wsjTr = getAccuracy("wsj.train", $cval, $view, $iteration);
  }
  return ($bitextD, $bitextT, $wsjD, $wsjT, $wsjTr);
}

# Compute the accuracy (in percent) of one prediction file against the true
# labels, and keep a per-iteration copy of that prediction file.
#
# Parameters:
#   $file      - name of the evaluation set under Data/Cotrain/
#   $cval      - C value naming the prediction file to score
#   $view      - view number naming the prediction file to score
#   $iteration - iteration number, used as the ".it<N>" suffix of the copy
#
# Returns 100 * (number of lines whose first FV field numerically equals the
# first prediction field) / (number of lines compared).  Dies if a file
# cannot be opened or if there is nothing to compare.
sub getAccuracy {
  my ($file, $cval, $view, $iteration) = @_;
  my $fvsFile = "Data/Cotrain/$file";
  my $predsFile = "Data/ML_CT/$file.fvs.bitext.$view.train.fvs.c$cval.model.preds";
  open(my $fvs, '<', $fvsFile) or die "No FVs";
  open(my $preds, '<', $predsFile) or die "No preds: $predsFile";
  my $numCorrect = 0;
  my $total = 0;
  my $labels = <$preds>; # Just skip over this line
  # Walk both files in step; stop as soon as either one runs out.
  while (defined(my $fv = <$fvs>) && defined(my $pred = <$preds>)) {
    chomp $fv; chomp $pred;
    my ($truth) = split(/ /, $fv);
    # Only the predicted label is needed, so the order of the probability
    # columns does not matter here.
    my ($prd) = split(/ /, $pred);
    $numCorrect++ if ($truth == $prd);
    $total++;
  }
  # Close before returning (these closes used to sit after the return
  # statement, where they were never reached).
  close($preds);
  close($fvs);
  # And while we're here, store for later processing:
  Execute("cp $predsFile $predsFile.it$iteration");
  #	if ($file =~ /^wsj/); -- now, want ALL the test sets... just get all
  die "Nothing to evaluate in $fvsFile / $predsFile" if ($total == 0);
  return 100 * $numCorrect / $total;
}

# Report on STDERR how the combination of the two views' classifiers does on
# the bitext dev and test sets.
#   $cval1 - C value of the view-1 model to combine
#   $cval2 - C value of the view-2 model to combine
sub printCombinedAccuracies {
  my ($cval1, $cval2) = @_;
  # Gather the combined accuracy of every set first, dev before test ...
  my %accuracyOn;
  foreach my $set ("dev", "test") {
    $accuracyOn{$set} = getCombinedAccuracy($set, $cval1, $cval2);
  }
  # ... then put both on a single report line.
  printf STDERR "C - Accuracy, bitext at $cval1/$cval2 -- Dev:%.1f, Test:%.1f\n", $accuracyOn{"dev"}, $accuracyOn{"test"};
}

# Compute the accuracy (in percent) of the two views' classifiers combined:
# an example is predicted as label 1 when the product of the two views'
# label-1 probabilities exceeds the product of their label-0 probabilities,
# and as label 0 otherwise.
#
# Parameters:
#   $set   - which bitext set to score (the script passes "dev" or "test")
#   $cval1 - C value naming the view-1 prediction file
#   $cval2 - C value naming the view-2 prediction file
#
# The true labels are read from the first field of Data/Cotrain/bitext.2.$set.
# Returns 100 * correct / total.  Dies if a file cannot be opened, if a
# prediction file is empty, or if there is nothing to compare.
sub getCombinedAccuracy {
  my ($set, $cval1, $cval2) = @_;
  my $fvsFile = "Data/Cotrain/bitext.2.$set";
  my $predsFile1 = "Data/ML_CT/bitext.1.$set.fvs.bitext.1.train.fvs.c$cval1.model.preds";
  my $predsFile2 = "Data/ML_CT/bitext.2.$set.fvs.bitext.2.train.fvs.c$cval2.model.preds";
  open(my $fvs, '<', $fvsFile) or die "No FVs";
  open(my $preds1, '<', $predsFile1) or die "No preds: $predsFile1";
  open(my $preds2, '<', $predsFile2) or die "No preds: $predsFile2";

  # The first line of each prediction file names the label of each
  # probability column; only the first of those labels is needed.
  my $labels1 = <$preds1>;
  die "Empty preds file: $predsFile1" unless (defined($labels1));
  chomp $labels1;
  my (undef, $lab11) = split(/ /, $labels1);
  my $labels2 = <$preds2>;
  die "Empty preds file: $predsFile2" unless (defined($labels2));
  chomp $labels2;
  my (undef, $lab21) = split(/ /, $labels2);

  my $numCorrect = 0;
  my $total = 0;
  # Walk the three files in step; stop as soon as any of them runs out.
  while (defined(my $fv = <$fvs>) && defined(my $pred1 = <$preds1>) && defined(my $pred2 = <$preds2>)) {
    chomp $fv; chomp $pred1; chomp $pred2;
    my ($truth) = split(/ /, $fv);
    my (undef, $prob11, $prob10) = split(/ /, $pred1);
    my (undef, $prob21, $prob20) = split(/ /, $pred2);
    # It could be that the '0' prob is actually in the first probability
    # column of one or both files.  Swap if the labels tell us so:
    ($prob11, $prob10) = ($prob10, $prob11) if ($lab11 == 0);
    ($prob21, $prob20) = ($prob20, $prob21) if ($lab21 == 0);
    my $prediction = ($prob11 * $prob21 > $prob10 * $prob20) ? 1 : 0;
    $numCorrect++ if ($truth == $prediction);
    $total++;
  }
  # Close every handle that was opened, before returning.
  close($preds2);
  close($preds1);
  close($fvs);
  die "Nothing to evaluate in $fvsFile" if ($total == 0);
  return 100 * $numCorrect / $total;
}

# Print the given command line on STDOUT, then run it.
# Returns the status that system() reported (0 on success).  A failing
# command does not stop the script, but is reported with a warning so that
# it does not go unnoticed.
sub printAndExecute {
  my ($command) = @_;
  print "> $command\n";
  my $status = system($command);
  if ($status == -1) {
    warn "Could not run command: $command: $!\n";
  } elsif ($status != 0) {
    warn "Command failed (wait status $status): $command\n";
  }
  return $status;
}
# Run the given command line without echoing it first.
# Returns the status that system() reported (0 on success).  A failing
# command does not stop the script, but is reported with a warning so that
# it does not go unnoticed.
sub Execute {
  my ($command) = @_;
  my $status = system($command);
  if ($status == -1) {
    warn "Could not run command: $command: $!\n";
  } elsif ($status != 0) {
    warn "Command failed (wait status $status): $command\n";
  }
  return $status;
}
