#!/usr/bin/python

# This script takes German sentences parsed by the Berkeley parser
# with a grammar that includes function labels on each node
# and reorders each sentence according to the rules in
#
# Michael Collins, Philipp Koehn and Ivona Kucerova. Clause restructuring
# for statistical machine translation. In Proceedings of the 43rd Annual
# Meeting of the Association for Computational Linguistics, pages 531-540, 2005.
#
# The script takes 4 command-line arguments:
# sentences-file: file containing the original German sentences that were passed
#     to the parser (tokenized)
# parses-file: the parses of the sentences in sentences-file. Each should
#     have been parsed with "-tree_likelihood -confidence" so for each sentence
#     parses-file contains "tree-likelihood \t confidence \t parse"
# output-file: file created by the script containing the reordered version
#     of each sentence in sentences-file
# features-file [optional]: file created by the script containing feature
#     information for each sentence. If this argument is included, features
#     will be calculated; if not, feature collection code will be skipped.
#     Features will be stored in features-file as a string representation of
#     a Python dictionary.
#
# Features calculated are listed below. Every feature value is written to
# features-file as e^x rather than x (since Moses applies a log-transform):
# used_rules: array of 6 values counting the number of times each
#     of the reordering rules is applied
# countN, countV, countP, countConj: counts of the number of nouns,
#     verbs, prepositions and conjunctions (respectively) in the sentence
# jump_bins: array of 10 counts of how many words a node has been moved by, binned
# ntokens: number of tokens in the original German sentence
# likelihood, confidence: P(w) and P(T|w) (respectively) as calculated by
#     the Berkeley parser (the parser's values are divided by 10 before
#     being exponentiated)


import sys
import math


# Global flag to control whether the feature collection code should be executed
CALCULATE_FEATURES = True


# Annotate nodes[i], and every node below it, with its span: the number of
# leaves (words) it covers. The span is appended to each node's list, and the
# span of nodes[i] is returned.
def calculate_span(nodes, i):
  node = nodes[i]
  children = node[2]
  if isinstance(children, list):   # internal node: spans of the children add up
    width = sum(calculate_span(nodes, child) for child in children)
  else:   # leaf node: covers exactly one word
    width = 1
  node.append(width)
  return width


# Take parse as a string and turn it into a list of nodes
# Format of each internal node is:
# [ parent, node label, [ child1, child2, ... ], span ]
# and of each leaf is:
# [ parent, node label, token, span ]
# where 'parent' and 'child*' are indices in the list of nodes
# and 'span' is added by the call to the function 'calculate_span'.
# The root is nodes[0] and its parent is None.
# An empty string yields an empty list of nodes.
# The script exits (sys.exit) if the string does not start with '(' or if
# more input follows an unexpected character found after a closed node.
def readin_parse(line):
  nodes = []
  current = None   # index of the node currently being read
  parent = None

  # States of the character-by-character reader:
  # NODE    - expecting the '(' that opens the very first node
  # LABEL   - reading the characters of a node label
  # LB_TOK  - right after the space that ends a label: either a child node
  #           opens here or a token starts here
  # TOKEN   - reading the characters of a leaf's token
  # NODESEQ - after a closed node: a sibling may open or the parent may close
  # FINISH  - an unexpected character was seen in NODESEQ; anything further is an error
  NODE, LABEL, LB_TOK, TOKEN, NODESEQ, FINISH = range(6)
  state = NODE

  for c in line:
    if state == NODE:
      if c != '(':
        sys.exit("Expected left parenthesis, encountered '" + c + "'")
      parent = current
      current = len(nodes)
      nodes.append([parent])
      if parent is not None:
        nodes[parent][2].append(current)
      node_label = ''
      state = LABEL

    elif state == LABEL:
      if c == ' ':
        # label complete; assume an internal node until a token shows up
        nodes[current].append(node_label)
        nodes[current].append([])
        state = LB_TOK
      elif c == '(':
        # label directly followed by a child node (no separating space)
        nodes[current].append(node_label)
        nodes[current].append([])
        parent = current
        current = len(nodes)
        nodes.append([parent])
        nodes[parent][2].append(current)
        node_label = ''
        state = LABEL
      elif c == ')':
        # node closed right after its label: store it as a leaf with an empty token
        nodes[current].append(node_label)
        nodes[current].append('')
        current = nodes[current][0]
        state = NODESEQ
      else:
        node_label += c

    elif state == LB_TOK:
      if c == '(':
        parent = current
        current = len(nodes)
        nodes.append([parent])
        nodes[parent][2].append(current)
        node_label = ''
        state = LABEL
      else:
        token = c
        state = TOKEN

    elif state == TOKEN:
      if c == ')':
        # replace the provisional empty child list with the token
        nodes[current][2] = token
        current = nodes[current][0]
        state = NODESEQ
      else:
        token += c

    elif state == NODESEQ:
      if c == ' ':
        pass
      elif c == '(':
        parent = current
        current = len(nodes)
        nodes.append([parent])
        nodes[parent][2].append(current)
        node_label = ''
        state = LABEL
      elif c == ')':
        current = nodes[current][0]
      else:
        state = FINISH

    else:   # state == FINISH
      sys.exit('Unexpected characters at end of string')

  if not nodes:   # empty input: there is no root whose span could be calculated
    return nodes

  calculate_span(nodes, 0)
  return nodes


# A full node label is made of two parts separated by *:
# the label proper (e.g. NP) and the function (e.g. SB (=subject)).
# 'label' and 'function' each extract one of the two parts, upper-cased.
# (A full label without * is taken to be a label with no function.)
def label(node):
  # everything before the first * ; the whole string when there is no *
  return node[1].split('*')[0].upper()

# Return the upper-cased function part of a node's full label, i.e. the text
# between the first and second * ; empty string if the label contains no * .
def function(node):
  pieces = node[1].split('*')
  if len(pieces) < 2:   # no * present, so no function part
    return ''
  return pieces[1].upper()


# Starting at a particular node in the tree, recurse down and for each
# VP node found, remove it from its parent's list of children, and
# move all of the VP node's children to be children of the VP node's parent.
# The VP node is still in the list of nodes, but not referred to by other nodes.
# A VP without a parent (the root) is left in place.
# 'nodes' is modified in place and also returned.
def remove_VPs(nodes, nodeindex):
  node = nodes[nodeindex]

  if isinstance(node[2], str):   # leaf node - nothing below it
    return nodes

  # Iterate over a snapshot of the children: a child VP splices its own
  # children into node[2] while we are looping, and iterating the live list
  # would revisit the spliced-in nodes (or skip a sibling if the VP had no
  # children at all).
  for cindex in list(node[2]):
    remove_VPs(nodes, cindex)

  pindex = node[0]
  if label(node) == 'VP' and pindex is not None:
    parent = nodes[pindex]
    i = parent[2].index(nodeindex)  # which child of its parent is this VP node
    parent[2][i:i+1] = node[2]  # replace this VP node with its children in parent's child list
    for n in node[2]:
      nodes[n][0] = pindex  # update parent pointer for each child of the VP

  return nodes


# Collect, left to right, the words spanned by nodes[i].
# The words are appended to the list 'words', which is also returned.
def read_off_leaves(nodes, i, words):
  pending = [i]   # nodes still to be visited; the next one to visit is at the end
  while pending:
    content = nodes[pending.pop()][2]
    if isinstance(content, str):   # leaf: content is the token itself
      words.append(content)
    else:   # internal node: push children reversed so the leftmost is visited first
      pending.extend(reversed(content))
  return words


# FEATURE CODE
# 'jumpsize' is the number of words by which a node has been moved.
# Jump sizes are counted in 10 bins: 1-2, 3-5, 6-8, 9-12, 13-16, 17-20,
# 21-25, 26-30, 31-35 and 36 or more. This method increments the count of
# the bin that 'jumpsize' falls into (no bin is touched for a jump size
# outside all of these ranges) and returns 'bins'.
def bin_jump(bins, jumpsize):
  bounded = ((1, 2), (3, 5), (6, 8), (9, 12), (13, 16),
             (17, 20), (21, 25), (26, 30), (31, 35))
  for slot, (low, high) in enumerate(bounded):
    if low <= jumpsize <= high:
      bins[slot] += 1
      return bins
  if jumpsize >= 36:   # open-ended last bin
    bins[9] += 1
  return bins


# Given the parse of a single sentence (as a string), reorder the sentence
# according to the rules of Collins et al. (05).
# Simultaneously acquire information for various features, if the flag
# CALCULATE_FEATURES is set to True. Features calculated by this method are
# used_rules, countN, countV, countP, countConj, jump_bins
#
# Returns a tuple (reordered, features):
# reordered: the reordered sentence as a space-separated string, or the
#     empty string if the parse has at most 2 nodes (parse failed)
# features: dictionary of the features above, each value stored as e^x
#     (lists for used_rules and jump_bins); empty dictionary if
#     CALCULATE_FEATURES is False
def reorder(line):
  nodes = readin_parse(line)

  features = {}

  if CALCULATE_FEATURES:
    used_rules = 6*[0]
    countN = 0
    countV = 0
    countP = 0
    countConj = 0
    jump_bins = 10*[0]

  if len(nodes) <= 2:  # no parse - parse failed. Reordered sentence is empty string
    if CALCULATE_FEATURES:
      # use e^0 for used_rules and jump_bins, e^(-1) for POS counts
      # (lists are built explicitly: a bare map() is a lazy iterator on Python 3
      # and would not survive the str() conversion done by the caller)
      features = { 'used_rules': [math.exp(x) for x in used_rules],
                   'countN': math.exp(countN-1),
                   'countV': math.exp(countV-1),
                   'countP': math.exp(countP-1),
                   'countConj': math.exp(countConj-1),
                   'jump_bins': [math.exp(x) for x in jump_bins]
                 }

    return ('', features)

  if CALCULATE_FEATURES:
    # accumulate POS counts
    for node in nodes:
      l = label(node)
      if l in ['NN','NE','PDS','PIS','PPER','PPOSS','PRF']:
        countN += 1
      elif l in ['VVFIN','VVIMP','VVINF','VVIZU','VVPP','VAFIN','VAIMP','VAINF','VAPP','VMFIN','VMINF','VMPP']:
        countV += 1
      elif l in ['APPR','APPRART','APPO','APZR']:
        countP += 1
      elif l in ['KOUI','KOUS','KON','KOKOM','PRELS','PRELAT','PWS','PWAV']:
        countConj += 1


  # BEGIN REORDERING
  # For the features, the jump size of a move is the number of words in the
  # siblings that are jumped over or, if larger, the number of words in the
  # moved node itself (rule [5] uses only the former).

  # [1] Verb initial: within each VP, move the head to be the first child
  for node in nodes:
    if label(node) == 'VP':
      if not isinstance(node[2], list):  # for some reason we have a VP leaf - skip it
        continue
      for (i,cindex) in enumerate(node[2]):  # cindex: child of the VP node, expressed as an index in 'nodes'
        if function(nodes[cindex]) == 'HD':

          if i != 0:  # it's not already in the right place

            if CALCULATE_FEATURES:
              used_rules[0] += 1

              jumpsize = sum(nodes[sibling][3] for sibling in node[2][0:i])
              jumpsize = max(jumpsize, nodes[cindex][3])
              jump_bins = bin_jump(jump_bins, jumpsize)

            del node[2][i]
            node[2].insert(0,cindex)

          break  # we only want the first head (not that there should be any more)


  # [2] Verb 2nd: within each S that has a complementizer, move the head
  # to directly follow the complementizer
  for node in nodes:
    if label(node) == 'S':
      if not isinstance(node[2], list):  # for some reason we have an S leaf - skip it
        continue

      # traverse list of children once only, finding first complementizer and first head
      comp = None
      compi = -1
      head = None
      headi = -1
      for (i,cindex) in enumerate(node[2]):
        child = nodes[cindex]
        if comp is None and label(child) in ['KOUS', 'PRELS', 'PRELAT', 'PWS', 'PWAV']:
          comp = cindex  # index in 'nodes'
          compi = i  # which child of 'node' is it
        if head is None and function(child) == 'HD':
          head = cindex  # index in 'nodes'
          headi = i  # which child of 'node' is it

      if comp is None or head is None:  # this S doesn't have both a complementizer and a head - skip it
        continue

      if headi != compi+1:  # it's not already in the right place

        if CALCULATE_FEATURES:
          used_rules[1] += 1

          if compi < headi:
            jumpsize = sum(nodes[sibling][3] for sibling in node[2][compi+1:headi])
          else:
            jumpsize = sum(nodes[sibling][3] for sibling in node[2][headi+1:compi])
          jumpsize = max(jumpsize, nodes[head][3])
          jump_bins = bin_jump(jump_bins, jumpsize)

        del node[2][headi]
        if compi > headi: # compensate for deletion
          compi -= 1
        node[2].insert(compi+1,head)


  # [3] Move Subject: within each S, move the subject to directly precede the head
  for node in nodes:
    if label(node) == 'S':
      if not isinstance(node[2], list):  # for some reason we have an S leaf - skip it
        continue

      # traverse list of children once only, finding first subject and first head
      subj = None
      subji = -1
      head = None
      headi = -1
      for (i,cindex) in enumerate(node[2]):
        child = nodes[cindex]
        if subj is None and ( function(child) == 'SB' or
                              ( label(child) == 'PPER' and function(child) == 'EP' ) ):
          subj = cindex  # index in 'nodes'
          subji = i  # which child of 'node' is it
        if head is None and function(child) == 'HD':
          head = cindex  # index in 'nodes'
          headi = i  # which child of 'node' is it

      if subj is None or head is None:  # this S doesn't have both a subject and a head - skip it
        continue

      if subji != headi - 1:  # it's not already in the right place

        if CALCULATE_FEATURES:
          used_rules[2] += 1

          if headi < subji:
            jumpsize = sum(nodes[sibling][3] for sibling in node[2][headi:subji])
          else:
            jumpsize = sum(nodes[sibling][3] for sibling in node[2][subji+1:headi])
          jumpsize = max(jumpsize, nodes[subj][3])
          jump_bins = bin_jump(jump_bins, jumpsize)

        del node[2][subji]
        if headi > subji: # compensate for deletion
          headi -= 1
        node[2].insert(headi,subj)


  # [4] Particles: within each S, move the verb particle to directly precede the finite verb
  for node in nodes:
    if label(node) == 'S':
      if not isinstance(node[2], list):  # for some reason we have an S leaf - skip it
        continue

      # traverse list of children once only, finding first finite verb and first particle
      vfin = None
      vfini = -1
      ptk = None
      ptki = -1
      for (i,cindex) in enumerate(node[2]):
        child = nodes[cindex]
        if vfin is None and label(child) == 'VVFIN':
          vfin = cindex  # index in 'nodes'
          vfini = i  # which child of 'node' is it
        if ptk is None and label(child) == 'PTKVZ':
          ptk = cindex  # index in 'nodes'
          ptki = i  # which child of 'node' is it

      if vfin is None or ptk is None:  # this S doesn't have both a finite verb and a particle - skip it
        continue

      if ptki != vfini - 1:  # it's not already in the right place

        if CALCULATE_FEATURES:
          used_rules[3] += 1

          if vfini < ptki:
            jumpsize = sum(nodes[sibling][3] for sibling in node[2][vfini:ptki])
          else:
            jumpsize = sum(nodes[sibling][3] for sibling in node[2][ptki+1:vfini])
          jumpsize = max(jumpsize, nodes[ptk][3])
          jump_bins = bin_jump(jump_bins, jumpsize)

        del node[2][ptki]
        if vfini > ptki: # compensate for deletion
          vfini -= 1
        node[2].insert(vfini,ptk)


  # [5] Infinitives: after flattening the VPs, move infinitives that follow
  # an argument up to directly follow the finite verb
  nodes = remove_VPs(nodes, 0)

  # From this point, there may be some nodes in 'nodes' that are no longer connected
  # to their parents, but these are all VPs

  for node in nodes:
    if label(node) == 'S':
      if not isinstance(node[2], list):  # for some reason we have an S leaf - skip it
        continue

      # traverse list of children once, finding first finite verb
      vfin = None
      vfini = -1
      for (i,cindex) in enumerate(node[2]):
        child = nodes[cindex]
        if vfin is None and label(child) in ['VVFIN', 'VAFIN', 'VMFIN']:
          vfin = cindex  # index in 'nodes'
          vfini = i  # which child of 'node' is it

      if vfin is None:  # this S doesn't have a finite verb - skip it
        continue

      movetarget = vfini + 1  # where infinitive will move to

      seenarg = False  # whether we've seen an argument to jump over

      # traverse list of children again, starting from after the first finite verb,
      # to find infinitives to move
      # (the child list is modified inside the loop, hence the explicit index)
      j = vfini + 1
      while j < len(node[2]):
        cindex = node[2][j]
        child = nodes[cindex]
        if label(child) in ['VVINF', 'VVIZU', 'VAINF', 'VMINF', 'VZ']:
          if seenarg:
            # move infinitive to position indicated by movetarget
            if CALCULATE_FEATURES:
              used_rules[4] += 1

              jumpsize = sum(nodes[sibling][3] for sibling in node[2][movetarget:j])
              jump_bins = bin_jump(jump_bins, jumpsize)

            del node[2][j]
            node[2].insert(movetarget, cindex)

            movetarget += 1  # any more moved infinitives should be placed after this one

          else:
            # this infinitive didn't have to move - any later infinitives should only
            # be moved back to here at most
            movetarget = j + 1
            seenarg = False

        elif function(child) in ['DA', 'OA', 'OA2', 'OG', 'PD', 'SB', 'SBP', 'SP']:
          seenarg = True

        j += 1


  # [6] Negation: within each S that has a finite verb and an infinitive,
  # move the negative particle to directly follow the finite verb
  for node in nodes:
    if label(node) == 'S':
      if not isinstance(node[2], list):  # for some reason we have an S leaf - skip it
        continue

      # traverse list of children once only, finding first finite verb, an infinitival and first negative
      vfin = None
      vfini = -1
      vinf = None
      neg = None
      negi = -1
      for (i,cindex) in enumerate(node[2]):
        child = nodes[cindex]
        if vfin is None and label(child) in ['VVFIN', 'VAFIN', 'VMFIN']:
          vfin = cindex  # index in 'nodes'
          vfini = i  # which child of 'node' is it
        if vinf is None and label(child) in ['VVINF', 'VVIZU', 'VAINF', 'VMINF']:
          vinf = cindex  # index in 'nodes'
        if neg is None and label(child) == 'PTKNEG':
          neg = cindex  # index in 'nodes'
          negi = i  # which child of 'node' is it

      if vfin is None or vinf is None or neg is None:  # this S doesn't have all three - skip it
        continue

      if negi != vfini + 1:  # it's not already in the right place

        if CALCULATE_FEATURES:
          used_rules[5] += 1

          if vfini < negi:
            jumpsize = sum(nodes[sibling][3] for sibling in node[2][vfini+1:negi])
          else:
            jumpsize = sum(nodes[sibling][3] for sibling in node[2][negi+1:vfini+1])
          jumpsize = max(jumpsize, nodes[neg][3])
          jump_bins = bin_jump(jump_bins, jumpsize)

        del node[2][negi]
        if vfini > negi: # compensate for deletion
          vfini -= 1
        node[2].insert(vfini+1,neg)


  # REORDERING COMPLETE

  words = read_off_leaves(nodes, 0, [])
  reordered = ' '.join(words)

  if CALCULATE_FEATURES:
    # use e^x instead of x for every feature value (since Moses applies a log-transform)
    features = { 'used_rules': [math.exp(x) for x in used_rules],
                 'countN': math.exp(countN),
                 'countV': math.exp(countV),
                 'countP': math.exp(countP),
                 'countConj': math.exp(countConj),
                 'jump_bins': [math.exp(x) for x in jump_bins]
               }

  return (reordered, features)


# Entry point to script
# Reads sentences-file and parses-file line by line in parallel, writes the
# reordered sentence (or the original sentence, if reordering produced the
# empty string) to output-file and, if a features-file was given, one
# feature dictionary per sentence to features-file.
if __name__ == "__main__":
  if len(sys.argv) not in [4,5]:
    sys.exit("Usage is " + sys.argv[0] + " sentences-file parses-file output-file (features-file)")

  # Features are calculated only if the optional features-file argument is present
  CALCULATE_FEATURES = (len(sys.argv) == 5)

  featsfile = None
  # 'with' / 'finally' make sure every file is closed (and the output flushed)
  # even if processing stops early, e.g. through sys.exit on a malformed parse
  with open(sys.argv[1]) as sentfile, \
       open(sys.argv[2]) as parsefile, \
       open(sys.argv[3], 'w') as outfile:
    try:
      if CALCULATE_FEATURES:
        featsfile = open(sys.argv[4], 'w')

      # zip pairs the files up line by line and stops at the end of the shorter one
      for (sent, parseline) in zip(sentfile, parsefile):
        parseparts = parseline.split('\t')
        if len(parseparts) == 3:
          (likelihood, confidence, parse) = parseparts
        else:
          likelihood = "-Infinity"
          confidence = "-Infinity"
          parse = parseparts[-1]  # len(parseparts) probably only 1, but to be sure we'll throw out all but the last

        likelihood = likelihood.strip()
        confidence = confidence.strip()
        parse = parse.strip()

        (reordered, features) = reorder(parse)

        if reordered == '':
          outfile.write(sent)  # fall back to the original sentence (still has its line ending)
        else:
          outfile.write(reordered + '\n')

        if CALCULATE_FEATURES:
          features['ntokens'] = math.exp(len(sent.split()))
          features['likelihood'] = math.exp(float(likelihood)/10)
          features['confidence'] = math.exp(float(confidence)/10)

          featsfile.write(str(features) + '\n')

    finally:
      if featsfile is not None:
        featsfile.close()

