#!/usr/bin/python

import sys

def count(counts, item):
  """Increment the tally stored for item in the counts dict.

  An item not yet present starts from 0, so its first call stores 1.
  The dict is modified in place; nothing is returned.
  """
  counts[item] = counts.get(item, 0) + 1


def ngrams(sentence):
  """Count the 1- to 4-grams of a whitespace-tokenized sentence.

  Returns a tuple (unigrams, bigrams, trigrams, fourgrams). Each element is a
  dict mapping an n-gram (its tokens joined by single spaces) to the number of
  times it occurs in the sentence. Orders longer than the sentence yield an
  empty dict; an empty sentence yields four empty dicts.
  """
  words = sentence.split()
  tables = ({}, {}, {}, {})
  # tables[order - 1] holds the n-grams of length `order`.
  for order, table in enumerate(tables, 1):
    # Slide a window of `order` tokens across the sentence; range() is empty
    # when the sentence is shorter than the window.
    for start in range(len(words) - order + 1):
      count(table, " ".join(words[start:start + order]))
  return tables


def overlaps(reference, candidate):
  """Return the clipped n-gram overlap of a candidate with a reference.

  Both arguments map n-gram -> occurrence count for the same n. Every
  candidate n-gram that also occurs in the reference contributes
  min(candidate count, reference count), so repeating an n-gram more often
  than the reference does earns nothing extra. N-grams absent from the
  reference contribute 0.
  """
  return sum(min(freq, reference[gram])
             for gram, freq in candidate.items()
             if gram in reference)


# Entry point to script


def _pick_translation(ref_grams, orig_grams, reord_grams):
  """Decide which of two candidate translations is closer to the reference.

  Each argument is the (unigrams, bigrams, trigrams, fourgrams) tuple returned
  by ngrams(). The clipped overlap with the reference is compared for 4-grams
  first; only on a tie is the next shorter order consulted, down to unigrams.

  Returns ('orig', n) or ('reord', n), where n is the n-gram order that broke
  the tie, or (None, None) if the overlaps are equal at every order.
  """
  for n in (4, 3, 2, 1):
    orig_overlap = overlaps(ref_grams[n - 1], orig_grams[n - 1])
    reord_overlap = overlaps(ref_grams[n - 1], reord_grams[n - 1])
    if orig_overlap > reord_overlap:
      return ('orig', n)
    if orig_overlap < reord_overlap:
      return ('reord', n)
  return (None, None)


def _print_report(stats):
  """Write the selection statistics gathered by _main to stdout.

  sys.stdout.write is used instead of the print statement so the output is
  the same byte-for-byte under Python 2 and Python 3.
  """
  out = sys.stdout.write
  out("Sentences identical: %d\n" % stats['identical'])
  out("Not identical but same number of n-gram overlaps (original used): %d\n"
      % stats['default'])
  for side, heading in (('orig', "Original"), ('reord', "Reordered")):
    out("%s preferred based on overlap of\n" % heading)
    for n, label in ((4, "4-grams"), (3, "trigrams"), (2, "bigrams"), (1, "unigrams")):
      out("- %s: %d\n" % (label, stats[side + str(n)]))
    out("( total: %d )\n" % sum(stats[side + str(n)] for n in (4, 3, 2, 1)))


def _main(argv):
  """Write the oracle choice for every sentence and report how it was made.

  argv has the shape of sys.argv: the script name, then the reference file,
  the baseline translations, the reordered translations and the output path.
  The three input files are read in parallel, one sentence per line.
  """
  if len(argv) != 5:
    sys.exit("Usage is: " + argv[0] + " reference original-translations reordered-translations oracle-output\n" +
             "e.g. " + argv[0] + " baseline/evaluation/newstest2009.reference.txt.1.tokenized " +
             "baseline/evaluation/newstest2009.recased.1 reordered/evaluation/newstest2009.recased.1 " +
             "oracle-evaluation/evaluation/newstest2009.oracle-output")

  # Pre-seed every counter with 0 so the report never hits a KeyError for a
  # category that did not occur (e.g. no identical sentences).
  stats = dict.fromkeys(
      ['identical', 'default'] +
      [side + str(n) for side in ('orig', 'reord') for n in (4, 3, 2, 1)],
      0)

  with open(argv[1]) as ref_file, \
       open(argv[2]) as orig_file, \
       open(argv[3]) as reord_file, \
       open(argv[4], 'w') as oracle_file:
    # NOTE: zip stops at the shortest file; the inputs are assumed to be
    # line-aligned and of equal length.
    for ref, orig, reord in zip(ref_file, orig_file, reord_file):
      # Matching is case-insensitive; the chosen line is written unmodified.
      winner, order = _pick_translation(ngrams(ref.lower()),
                                        ngrams(orig.lower()),
                                        ngrams(reord.lower()))
      # The reordered output is used only when it strictly wins; a full tie
      # keeps the original.
      oracle_file.write(reord if winner == 'reord' else orig)
      if winner is not None:
        count(stats, winner + str(order))
      elif orig.lower().strip() == reord.lower().strip():
        count(stats, 'identical')
      else:
        count(stats, 'default')

  _print_report(stats)


if __name__ == "__main__":
  _main(sys.argv)

