import ntpath
import subprocess
from collections import OrderedDict

import tensorflow as tf
import numpy as np
import util
import os
import re
from subprocess import check_output, CalledProcessError


# Names of the transformations scored by conll_srl_eval_with_transformation.
# Each name is used there as the suffix of the "<file>.t.<name>" prediction
# and gold files handed to the eval script, in this order.
transformation_list =[
  'fix_labels',
  'move_arg',
  'merge_spans',
  'split_spans',
  'fix_boundary',
  # currently disabled:
  # 'drop_overlapping_arg',
  'drop_arg',
  'add_arg'
]

# todo simplify to convert_bio
def convert_bilou(bio_predicted_roles):
  '''
  Convert the per-token role labels of one predicate into CoNLL bracket
  notation.

  Despite the parameter name, the tags handled are B, I, L, O and U (taken
  from the first character of each label). A token may carry several labels
  joined by '/'; they are processed left to right. For example
  ['B-A1', 'I-A1', 'L-A1', 'O', 'U-A2'] becomes
  ['(A1*', '*', '*)', '*', '(A2*)'].

  Differs from convert_bilou_true only in which element of the per-token
  I-label list is discarded when an unopened I label is turned into an
  opening bracket (first element here, last element there).

  :param bio_predicted_roles: sequence of per-token label strings; bytes
    entries are decoded as UTF-8
  :return: list with one CoNLL-formatted string per input token
  '''

  converted = []
  # Stack of label types whose span has been opened but not yet closed.
  # It persists across tokens.
  started_types = []
  for i, s in enumerate(bio_predicted_roles):
    s = s if isinstance(s, str) else s.decode('utf-8')
    label_parts = s.split('/')
    # Number of spans expected to still be open after this token; reset to 0
    # by an O label and decremented by each L label below.
    curr_len = len(label_parts)
    combined_str = ''
    # Label types of the I and B labels seen on this token only.
    Itypes = []
    Btypes = []
    for idx, label in enumerate(label_parts):
      # First character is the tag; the type is everything after the tag and
      # one separator character.
      bilou = label[0]
      label_type = label[2:]
      props_str = ''
      if bilou == 'I':
        # Inside a span: nothing is emitted here. An I without a matching
        # open span is handled after this loop.
        Itypes.append(label_type)
        props_str = ''
      elif bilou == 'O':
        curr_len = 0
        props_str = ''
      elif bilou == 'U':
        # Single-token span. The closing '*)' is only emitted when this is
        # the last label on the token.
        props_str = '(' + label_type + ('*)' if idx == len(label_parts) - 1 else "")
      elif bilou == 'B':
        # Opens a span; it is closed by a later L or by the clean-up loops.
        props_str = '(' + label_type
        started_types.append(label_type)
        Btypes.append(label_type)
      elif bilou == 'L':
        # Closes the most recently opened span.
        # NOTE(review): raises IndexError if no span is open.
        props_str = ')'
        started_types.pop()
        curr_len -= 1
      combined_str += props_str
    # More spans are open than this token accounts for: close the surplus on
    # the previous token's output.
    # NOTE(review): converted[-1] raises IndexError if this runs before any
    # token has been appended.
    while len(started_types) > curr_len:
      converted[-1] += ')'
      started_types.pop()
    # Fewer spans are open than this token has I/B labels: open a span for an
    # I label by prepending '(' + type to this token's output.
    # NOTE(review): the type opened is Itypes[-1] but the element discarded
    # is Itypes[0]; convert_bilou_true discards the last one instead --
    # confirm which is intended.
    while len(started_types) < len(Itypes) + len(Btypes):
      combined_str = '(' + Itypes[-1] + combined_str
      started_types.append(Itypes[-1])
      Itypes.pop(0)
    # Place the '*' token marker: alone if nothing was emitted, after an
    # unclosed opening, or before a bare closing.
    if not combined_str:
      combined_str = '*'
    elif combined_str[0] == "(" and combined_str[-1] != ")":
      combined_str += '*'
    elif combined_str[-1] == ")" and combined_str[0] != "(":
      combined_str = '*' + combined_str
    converted.append(combined_str)
  # Close every span still open at the end of the sequence on the last token.
  while len(started_types) > 0:
    converted[-1] += ')'
    started_types.pop()
  return converted
def convert_bilou_true(bio_predicted_roles):
  '''
  Convert the per-token role labels of one predicate into CoNLL bracket
  notation.

  Despite the parameter name, the tags handled are B, I, L, O and U (taken
  from the first character of each label). A token may carry several labels
  joined by '/'; they are processed left to right.

  Same algorithm as convert_bilou, except that when an unopened I label is
  turned into an opening bracket the last element of the per-token I-label
  list is discarded (convert_bilou discards the first).

  :param bio_predicted_roles: sequence of per-token label strings; bytes
    entries are decoded as UTF-8
  :return: list with one CoNLL-formatted string per input token
  '''

  converted = []
  # Stack of label types whose span has been opened but not yet closed.
  # It persists across tokens.
  started_types = []
  for i, s in enumerate(bio_predicted_roles):
    s = s if isinstance(s, str) else s.decode('utf-8')
    label_parts = s.split('/')
    # Number of spans expected to still be open after this token; reset to 0
    # by an O label and decremented by each L label below.
    curr_len = len(label_parts)
    combined_str = ''
    # Label types of the I and B labels seen on this token only.
    Itypes = []
    Btypes = []
    for idx, label in enumerate(label_parts):
      # First character is the tag; the type is everything after the tag and
      # one separator character.
      bilou = label[0]
      label_type = label[2:]
      props_str = ''
      if bilou == 'I':
        # Inside a span: nothing is emitted here. An I without a matching
        # open span is handled after this loop.
        Itypes.append(label_type)
        props_str = ''
      elif bilou == 'O':
        curr_len = 0
        props_str = ''
      elif bilou == 'U':
        # Single-token span. The closing '*)' is only emitted when this is
        # the last label on the token.
        props_str = '(' + label_type + ('*)' if idx == len(label_parts) - 1 else "")
      elif bilou == 'B':
        # Opens a span; it is closed by a later L or by the clean-up loops.
        props_str = '(' + label_type
        started_types.append(label_type)
        Btypes.append(label_type)
      elif bilou == 'L':
        # Closes the most recently opened span.
        # NOTE(review): raises IndexError if no span is open.
        props_str = ')'
        started_types.pop()
        curr_len -= 1
      combined_str += props_str
    # More spans are open than this token accounts for: close the surplus on
    # the previous token's output.
    # NOTE(review): converted[-1] raises IndexError if this runs before any
    # token has been appended.
    while len(started_types) > curr_len:
      converted[-1] += ')'
      started_types.pop()
    # Fewer spans are open than this token has I/B labels: open a span for
    # the last recorded I label and remove it from the pending list.
    while len(started_types) < len(Itypes) + len(Btypes):
      combined_str = '(' + Itypes[-1] + combined_str
      started_types.append(Itypes[-1])
      Itypes.pop()
    # Place the '*' token marker: alone if nothing was emitted, after an
    # unclosed opening, or before a bare closing.
    if not combined_str:
      combined_str = '*'
    elif combined_str[0] == "(" and combined_str[-1] != ")":
      combined_str += '*'
    elif combined_str[-1] == ")" and combined_str[0] != "(":
      combined_str = '*' + combined_str
    converted.append(combined_str)
  # Close every span still open at the end of the sequence on the last token.
  while len(started_types) > 0:
    converted[-1] += ')'
    started_types.pop()
  return converted


def convert_conll(predicted_roles):
  '''
  Wrap each role label as a single-token CoNLL span.

  The placeholder label "_" becomes "*"; any other label X becomes "(X*)".

  :param predicted_roles: iterable of predicted role labels; bytes entries
    are decoded as UTF-8
  :return: list of conll-formatted predicted role labels, one per input label
  '''

  def convert_single(s):
    s = s if isinstance(s, str) else s.decode('utf-8')
    return "*" if s == "_" else "(%s*)" % s

  # Build a real list: in Python 3 a bare map() is a one-shot iterator, which
  # has no len()/indexing and is empty on a second pass.
  return [convert_single(role) for role in predicted_roles]


def accuracy_np(predictions, targets, mask, accumulator):
  '''
  Update running accuracy counts with one batch and return the running
  accuracy.

  :param predictions: predicted labels
  :param targets: gold labels, compared element-wise with predictions
  :param mask: weights multiplied into the match indicator and summed as the
    batch total (presumably 1 for real tokens, 0 for padding -- confirm)
  :param accumulator: dict with 'correct' and 'total' entries; updated in place
  :return: accumulator['correct'] / accumulator['total'] after the update
  '''
  masked_matches = np.multiply(predictions == targets, mask)

  accumulator['correct'] += np.sum(masked_matches)
  accumulator['total'] += np.sum(mask)

  return accumulator['correct'] / accumulator['total']


def precision_np(predictions, targets, mask, accumulator):
  '''
  Update running precision counts with one batch and return the running
  precision. Label 0 is treated as "no prediction".

  :param predictions: predicted labels
  :param targets: gold labels, compared element-wise with predictions
  :param mask: weights multiplied into the counts
  :param accumulator: dict with 'tp' and 'tpfp' entries; updated in place
  :return: accumulator['tp'] / accumulator['tpfp'] after the update
  '''
  # Masked positions where something other than 0 was predicted.
  predicted_positive = np.multiply(mask, predictions != 0)
  # Of those, the positions where the prediction equals the target.
  true_positive = np.multiply(predictions == targets, predicted_positive)

  accumulator['tp'] += np.sum(true_positive)
  accumulator['tpfp'] += np.sum(predicted_positive)

  return accumulator['tp'] / accumulator['tpfp']

def recall_np(predictions, targets, mask, accumulator):
  '''
  Update the running gold-positive count with one batch and return the
  running recall. Label 0 is treated as "no gold label".

  Only accumulator['tpfn'] is updated here; accumulator['tp'] is read but
  not modified, so it must already have been updated for this batch (f1_np
  calls precision_np first, which does that). predictions is not used.

  :param predictions: predicted labels (unused)
  :param targets: gold labels
  :param mask: weights multiplied into the count
  :param accumulator: dict with 'tp' and 'tpfn' entries; 'tpfn' is updated
    in place
  :return: accumulator['tp'] / accumulator['tpfn'] after the update
  '''
  gold_positive = np.multiply(mask, targets != 0)
  accumulator['tpfn'] += np.sum(gold_positive)

  return accumulator['tp'] / accumulator['tpfn']

def f1_np(predictions, targets, mask, accumulator):
  '''
  Update the running precision/recall counts with one batch and return the
  running F1 score.

  precision_np is called before recall_np because recall_np reads the 'tp'
  count that precision_np updates.

  :param predictions: predicted labels
  :param targets: gold labels
  :param mask: weights multiplied into the counts
  :param accumulator: dict with 'tp', 'tpfp' and 'tpfn' entries; updated in
    place
  :return: 2 * precision * recall / (precision + recall)
  '''
  p = precision_np(predictions, targets, mask, accumulator)
  r = recall_np(predictions, targets, mask, accumulator)
  return 2*p*r/(p+r)

# Write targets file w/ format:
# -        (A1*  (A1*
# -          *     *
# -          *)    *)
# -          *     *
# expected (V*)    *
# -        (C-A1*  *
# widen     *     (V*)
# -         *     (A4*
def write_srl_eval(filename, words, predicates, sent_lens, role_labels):
  '''
  Write one batch of SRL labels to filename, one token per line and one
  blank line after each sentence.

  Each line holds the word (or '-' when the token is not a predicate)
  followed by one tab-separated bracket column per predicate of the sentence.

  :param filename: path of the file to (over)write
  :param words: batch of token strings, decoded with util.batch_str_decode
  :param predicates: per-token predicate indicators; summed over the last
    axis to get the number of predicates per sentence
  :param sent_lens: number of tokens to write for each sentence
  :param role_labels: label rows for every predicate in the batch, stored
    consecutively sentence after sentence; each row is converted with
    convert_bilou
  '''
  with open(filename, 'w') as f:
    predicate_counts = np.sum(predicates, -1)
    decoded_words = util.batch_str_decode(words)

    # Index of the first role_labels row belonging to the current sentence.
    row_offset = 0
    for sent_words, sent_predicates, sent_len, n_preds in zip(decoded_words, predicates, sent_lens,
                                                              predicate_counts):
      label_rows = role_labels[row_offset: row_offset + n_preds]
      row_offset += n_preds

      # One converted row per predicate, then transposed so that each entry
      # holds all predicate columns of a single token.
      per_predicate = [convert_bilou(row[:sent_len]) for row in label_rows]
      per_token = [list(column) for column in zip(*per_predicate)]

      tokens = zip(sent_words[:sent_len], sent_predicates[:sent_len])
      for tok_idx, (word, is_predicate) in enumerate(tokens):
        # A sentence without predicates has no role columns at all.
        if per_token:
          columns = per_token[tok_idx]
        else:
          columns = []
        first_column = word if is_predicate else '-'
        print("%s\t%s" % (first_column, '\t'.join(columns)), file=f)
      print(file=f)


# Write targets file w/ format:
# 0	The	_	_	DET	DET	_	_	2	2	det	det	_	_	_	_	_	_
# 1	economy	_	_	NOUN	NOUN	_	_	4	4	nmod:poss	nmod:poss	_	_	A1	_	_	_
# 2	's	_	_	PART	PART	_	_	2	2	case	case	_	_	_	_	_	_
# 3	temperature	_	_	NOUN	NOUN	_	_	7	7	nsubjpass	nsubjpass	Y	temperature.01	A2	A1	_	_
def write_srl_eval_09(filename, words, predicates, sent_lens, role_labels, parse_heads, parse_labels, pos_tags, sense):
  '''
  Write one batch in the tab-separated CoNLL-2009-style layout shown above
  this function, one token per line and one blank line after each sentence.

  :param filename: path of the file to (over)write
  :param words: batch of token strings
  :param predicates: per-token predicate flags; a token is a predicate when
    its decoded value equals 'True'
  :param sent_lens: number of tokens to write for each sentence
  :param role_labels: label rows for every predicate in the batch, stored
    consecutively sentence after sentence
  :param parse_heads: per-token head values, written twice per line
  :param parse_labels: per-token dependency labels, written twice per line
  :param pos_tags: per-token part-of-speech tags, written twice per line
  :param sense: per-token predicate sense, written as "<word>:<sense>" for
    predicate tokens
  '''
  with open(filename, 'w') as f:
    row_offset = 0

    decoded_predicates = util.batch_str_decode(predicates)
    decoded_words = util.batch_str_decode(words)
    decoded_parse_labels = util.batch_str_decode(parse_labels)
    decoded_pos_tags = util.batch_str_decode(pos_tags)
    decoded_roles = util.batch_str_decode(role_labels)

    # Number of predicate tokens in each sentence.
    predicate_counts = np.sum(decoded_predicates == 'True', -1)

    sentences = zip(decoded_words, decoded_predicates, sent_lens, predicate_counts,
                    parse_heads, decoded_parse_labels, decoded_pos_tags, sense)
    for (sent_words, sent_predicates, sent_len, n_preds,
         sent_heads, sent_dep_labels, sent_pos, sent_sense) in sentences:
      # Rows of this sentence's predicates, transposed so that indexing by
      # token position yields that token's role for every predicate.
      sent_roles = np.transpose(decoded_roles[row_offset: row_offset + n_preds])
      row_offset += n_preds

      tokens = zip(sent_words[:sent_len],
                   sent_predicates[:sent_len],
                   sent_heads[:sent_len],
                   sent_dep_labels[:sent_len],
                   sent_pos[:sent_len],
                   sent_sense[:sent_len])
      for tok_idx, (word, predicate, head, dep_label, pos, tok_sense) in enumerate(tokens):
        if len(sent_roles) > 0:
          tok_roles = sent_roles[tok_idx]
        else:
          tok_roles = []
        if predicate == 'True':
          predicate_cols = "Y\t%s:%s" % (word, str(tok_sense))
        else:
          predicate_cols = '_\t_'
        role_cols = '\t'.join(tok_roles)
        fields = (tok_idx, word, pos, pos, head, head, dep_label, dep_label, predicate_cols, role_cols)
        print("%s\t%s\t_\t_\t%s\t%s\t_\t_\t%s\t%s\t%s\t%s\t%s\t%s" % fields, file=f)
      print(file=f)
def write_srl_eval_09_a(filename, words, predicates, sent_lens, role_labels, parse_heads, parse_labels, pos_tags,sense, input_source):
  '''
  Append one batch, in the same layout as write_srl_eval_09, to a file whose
  name is derived from the input source.

  The output path is filename + '.' + the base name of input_source up to
  its first '.'. The file is opened in append mode.

  :param filename: output path prefix; bytes are decoded as UTF-8
  :param words: batch of token strings
  :param predicates: per-token predicate flags; a token is a predicate when
    its decoded value equals 'True'
  :param sent_lens: number of tokens to write for each sentence
  :param role_labels: label rows for every predicate in the batch, stored
    consecutively sentence after sentence
  :param parse_heads: per-token head values, written twice per line
  :param parse_labels: per-token dependency labels, written twice per line
  :param pos_tags: per-token part-of-speech tags, written twice per line
  :param sense: per-token predicate sense
  :param input_source: path of the data source the batch came from
  '''
  # filename may arrive as bytes; anything that cannot be decoded is used
  # unchanged.
  try:
    filename = filename.decode('utf-8')
  except (UnicodeDecodeError, AttributeError):
    pass

  out_path = filename+'.'+ntpath.basename(input_source).split('.')[0]
  tf.logging.log(tf.logging.INFO, "log to {}".format(out_path))
  with open(out_path, 'a') as f:
    row_offset = 0

    decoded_predicates = util.batch_str_decode(predicates)
    decoded_words = util.batch_str_decode(words)
    decoded_parse_labels = util.batch_str_decode(parse_labels)
    decoded_pos_tags = util.batch_str_decode(pos_tags)
    decoded_roles = util.batch_str_decode(role_labels)

    # Number of predicate tokens in each sentence.
    predicate_counts = np.sum(decoded_predicates == 'True', -1)

    sentences = zip(decoded_words, decoded_predicates, sent_lens, predicate_counts,
                    parse_heads, decoded_parse_labels, decoded_pos_tags, sense)
    for (sent_words, sent_predicates, sent_len, n_preds,
         sent_heads, sent_dep_labels, sent_pos, sent_sense) in sentences:
      # Rows of this sentence's predicates, transposed so that indexing by
      # token position yields that token's role for every predicate.
      sent_roles = np.transpose(decoded_roles[row_offset: row_offset + n_preds])
      row_offset += n_preds

      tokens = zip(sent_words[:sent_len],
                   sent_predicates[:sent_len],
                   sent_heads[:sent_len],
                   sent_dep_labels[:sent_len],
                   sent_pos[:sent_len],
                   sent_sense[:sent_len])
      for tok_idx, (word, predicate, head, dep_label, pos, tok_sense) in enumerate(tokens):
        if len(sent_roles) > 0:
          tok_roles = sent_roles[tok_idx]
        else:
          tok_roles = []
        if predicate == 'True':
          predicate_cols = "Y\t%s:%s" % (word, str(tok_sense))
        else:
          predicate_cols = '_\t_'
        role_cols = '\t'.join(tok_roles)
        fields = (tok_idx, word, pos, pos, head, head, dep_label, dep_label, predicate_cols, role_cols)
        print("%s\t%s\t_\t_\t%s\t%s\t_\t_\t%s\t%s\t%s\t%s\t%s\t%s" % fields, file=f)
      print(file=f)

# Write to this format for eval.pl:
# 1       The             _       DT      _       _       2       det
# 2       economy         _       NN      _       _       4       poss
# 3       's              _       POS     _       _       2       possessive
# 4       temperature     _       NN      _       _       7       nsubjpass
# 5       will            _       MD      _       _       7       aux
def write_parse_eval(filename, words, parse_heads, sent_lens, parse_labels, pos_tags):
  '''
  Write one batch of dependency parses to filename, one token per line and
  one blank line after each sentence.

  A token whose head equals its own position is written with head 0; every
  other head is written as head + 1. The first column is the token's
  position within the sentence, counted from 0.

  :param filename: path of the file to (over)write
  :param words: batch of token strings
  :param parse_heads: per-token head positions
  :param sent_lens: number of tokens to write for each sentence
  :param parse_labels: per-token dependency labels
  :param pos_tags: per-token part-of-speech tags
  '''
  decoded_words = util.batch_str_decode(words)
  decoded_pos_tags = util.batch_str_decode(pos_tags)
  decoded_labels = util.batch_str_decode(parse_labels)

  with open(filename, 'w') as f:
    sentences = zip(decoded_words, parse_heads, sent_lens, decoded_labels, decoded_pos_tags)
    for sent_words, sent_heads, sent_len, sent_labels, sent_pos in sentences:
      tokens = zip(sent_words[:sent_len],
                   sent_heads[:sent_len],
                   sent_labels[:sent_len],
                   sent_pos[:sent_len])
      for position, (word, head, dep_label, pos) in enumerate(tokens):
        # A self-loop marks the root.
        if position == head:
          head_out = 0
        else:
          head_out = head + 1
        print("%d\t%s\t_\t%s\t_\t_\t%d\t%s" % (position, word, pos, int(head_out), dep_label), file=f)
      print(file=f)


def write_srl_debug(filename, words, predicates, sent_lens, role_labels, pos_predictions, pos_targets):
  '''
  Write a debugging dump of one batch: per token the word, predicate flag,
  gold and predicted POS tag, converted role columns and raw role labels,
  tab-separated, with one blank line after each sentence.

  Unlike the other writers, words, POS tags and raw role labels are decoded
  from bytes here, element by element.

  :param filename: path of the file to (over)write
  :param words: batch of token byte strings
  :param predicates: per-token predicate indicators; summed over the last
    axis to get the number of predicates per sentence
  :param sent_lens: number of tokens to write for each sentence
  :param role_labels: label rows for every predicate in the batch, stored
    consecutively sentence after sentence
  :param pos_predictions: per-token predicted POS tags (bytes)
  :param pos_targets: per-token gold POS tags (bytes)
  '''
  with open(filename, 'w') as f:
    row_offset = 0
    predicate_counts = np.sum(predicates, -1)
    sentences = zip(words, predicates, sent_lens, predicate_counts, pos_predictions, pos_targets)
    for sent_words, sent_predicates, sent_len, n_preds, pos_preds, pos_targs in sentences:
      raw_rows = role_labels[row_offset: row_offset + n_preds]

      # One converted row per predicate, transposed to one entry per token.
      converted_rows = [convert_bilou(row[:sent_len]) for row in raw_rows]
      roles_by_token = [list(column) for column in zip(*converted_rows)]
      row_offset += n_preds

      # Raw labels transposed the same way (over the full row length).
      raw_by_token = list(zip(*raw_rows))

      pos_preds = [tag.decode('utf-8') for tag in pos_preds]
      pos_targs = [tag.decode('utf-8') for tag in pos_targs]

      tokens = zip(sent_words[:sent_len], sent_predicates[:sent_len],
                   pos_preds[:sent_len], pos_targs[:sent_len])
      for tok_idx, (word, predicate, pos_p, pos_t) in enumerate(tokens):
        # Both columns are left empty when the sentence has no converted
        # role labels.
        if roles_by_token:
          tok_roles = roles_by_token[tok_idx]
        else:
          tok_roles = []
        if roles_by_token:
          tok_raw = raw_by_token[tok_idx][:sent_len]
        else:
          tok_raw = []
        word_text = word.decode('utf-8')
        predicate_text = str(predicate)
        roles_text = '\t'.join(tok_roles)
        raw_text = '\t'.join(label.decode('utf-8') for label in tok_raw)
        print("%s\t%s\t%s\t%s\t%s\t%s" % (word_text, predicate_text, pos_t, pos_p, roles_text, raw_text), file=f)
      print(file=f)


def conll_srl_eval(srl_predictions, predicate_predictions, words, mask, srl_targets, predicate_targets,
                      pred_srl_eval_file, gold_srl_eval_file, pos_predictions=None, pos_targets=None):
  '''
  Write gold and predicted SRL files for one batch and score them with
  bin/srl-eval.pl.

  :param srl_predictions: predicted role-label rows, one per predicate
  :param predicate_predictions: predicted per-token predicate indicators
  :param words: batch of tokens
  :param mask: token mask; summed over the last axis to get sentence lengths
  :param srl_targets: gold role-label rows, one per predicate
  :param predicate_targets: gold per-token predicate indicators
  :param pred_srl_eval_file: path the predictions are written to
  :param gold_srl_eval_file: path the gold labels are written to
  :param pos_predictions: unused
  :param pos_targets: unused
  :return: (correct, excess, missed) parsed from the script output, or
    (0, 0, 0) if the script exits with an error
  '''
  lengths = np.sum(mask, -1).astype(np.int32)

  # Gold file first, then predictions.
  write_srl_eval(gold_srl_eval_file, words, predicate_targets, lengths, srl_targets)
  write_srl_eval(pred_srl_eval_file, words, predicate_predictions, lengths, srl_predictions)

  correct, excess, missed = 0, 0, 0
  with open(os.devnull, 'w') as devnull:
    try:
      raw_output = check_output(["perl", "bin/srl-eval.pl", gold_srl_eval_file, pred_srl_eval_file], stderr=devnull)
      # The counts are fields 1-3 of the seventh output line.
      summary_fields = raw_output.decode('utf-8').split('\n')[6].split()
      correct, excess, missed = [int(field) for field in summary_fields[1:4]]
    except CalledProcessError:
      tf.logging.log(tf.logging.ERROR, "Call to srl-eval.pl (conll srl eval) failed.")

  return correct, excess, missed



def conll_srl_eval_with_transformation(srl_predictions, predicate_predictions, words, mask, srl_targets, predicate_targets,
                      pred_srl_eval_file, gold_srl_eval_file, pos_predictions=None, pos_targets=None):
  '''
  Score one batch with bin/srl-eval.pl, both as written and after each
  transformation named in transformation_list.

  The gold and predicted files are written, bin/make_srl_transformation.py
  is run on them, and then the original pair plus every
  "<file>.t.<name>" pair is scored.

  :param srl_predictions: predicted role-label rows, one per predicate
  :param predicate_predictions: predicted per-token predicate indicators
  :param words: batch of tokens
  :param mask: token mask; summed over the last axis to get sentence lengths
  :param srl_targets: gold role-label rows, one per predicate
  :param predicate_targets: gold per-token predicate indicators
  :param pred_srl_eval_file: path the predictions are written to
  :param gold_srl_eval_file: path the gold labels are written to
  :param pos_predictions: unused
  :param pos_targets: unused
  :return: dict mapping 'original' and each transformation name to a dict
    with 'correct', 'missed' and 'excess' counts (all 0 for a pair whose
    scoring script exits with an error)
  '''

  def score_pair(pred_path, gold_path):
    # Run srl-eval.pl on one prediction/gold pair and parse its counts.
    n_correct, n_excess, n_missed = 0, 0, 0
    with open(os.devnull, 'w') as devnull:
      try:
        raw_output = check_output(["perl", "bin/srl-eval.pl", gold_path, pred_path], stderr=devnull)
        # The counts are fields 1-3 of the seventh output line.
        summary_fields = raw_output.decode('utf-8').split('\n')[6].split()
        n_correct, n_excess, n_missed = [int(field) for field in summary_fields[1:4]]
      except CalledProcessError:
        tf.logging.log(tf.logging.ERROR, "Call to srl-eval.pl (conll srl eval) failed.")
    return {'correct': n_correct, 'missed': n_missed, 'excess': n_excess}

  lengths = np.sum(mask, -1).astype(np.int32)

  # Gold file first, then predictions.
  write_srl_eval(gold_srl_eval_file, words, predicate_targets, lengths, srl_targets)
  write_srl_eval(pred_srl_eval_file, words, predicate_predictions, lengths, srl_predictions)

  # Produces the "<file>.t.<name>" files scored below.
  subprocess.run(["python", "bin/make_srl_transformation.py", pred_srl_eval_file, gold_srl_eval_file])

  results = {'original': score_pair(pred_srl_eval_file, gold_srl_eval_file)}
  results.update(
    (name, score_pair("{}.t.{}".format(pred_srl_eval_file, name), "{}.t.{}".format(gold_srl_eval_file, name)))
    for name in transformation_list
  )
  return results


def conll09_srl_eval(srl_predictions, predicate_predictions, words, mask, srl_targets, predicate_targets,
                     parse_label_predictions, parse_head_predictions, parse_label_targets, parse_head_targets,
                     pos_targets, pos_predictions, pred_srl_eval_file, gold_srl_eval_file, pred_sense, gold_sense):
  '''
  Write gold and predicted CoNLL-2009-style files for one batch and score
  them with bin/eval09.pl, retrying the script up to five times.

  The returned counts combine both terms of the "Labeled precision" and
  "Labeled recall" lines of the script output.

  :param mask: token mask; summed over the last axis to get sentence lengths
  :param pred_srl_eval_file: path the predictions are written to
  :param gold_srl_eval_file: path the gold annotations are written to
  :return: (labeled_correct, labeled_excess, labeled_missed); all 0 if every
    attempt to run the script exits with an error

  The remaining parameters are passed through to write_srl_eval_09 (the
  *_targets and gold_sense values for the gold file, the *_predictions and
  pred_sense values for the prediction file).
  '''
  lengths = np.sum(mask, -1).astype(np.int32)

  # Gold file first, then predictions.
  write_srl_eval_09(gold_srl_eval_file, words, predicate_targets, lengths, srl_targets, parse_head_targets,
                    parse_label_targets, pos_targets, gold_sense)
  write_srl_eval_09(pred_srl_eval_file, words, predicate_predictions, lengths, srl_predictions,
                    parse_head_predictions, parse_label_predictions, pos_predictions, pred_sense)

  def ints_in(line):
    # Keep only digits and spaces, then read every remaining run as an int
    # (so "62.70" turns into 6270; only the leading counts are used).
    return [int(token) for token in re.sub('[^0-9 ]', '', line).split()]

  labeled_correct, labeled_excess, labeled_missed = 0, 0, 0
  with open(os.devnull, 'w') as devnull:
    for attempt in range(5):
      try:
        raw_output = check_output(["perl", "bin/eval09.pl", "-g", gold_srl_eval_file, "-s", pred_srl_eval_file],
                                  stderr=devnull)
      except CalledProcessError:
        tf.logging.log(tf.logging.ERROR, "Call to eval09.pl (conll09 srl eval) failed. retry {}".format(attempt))
        continue

      # Relevant lines of the eval09.pl output, by 0-based line index:
      #   [6]  SEMANTIC SCORES:
      #   [7]  Labeled precision:     (20 + 486) / (200 + 607) * 100 = 62.70 %
      #   [8]  Labeled recall:        (20 + 486) / (220 + 621) * 100 = 60.17 %
      #   [13] Proposition precision: 429 / 607 * 100 = 70.68 %
      #   [14] Proposition recall:    429 / 621 * 100 = 69.08 %
      output_lines = raw_output.decode('utf-8').split('\n')
      labeled_precision = ints_in(output_lines[7])
      labeled_recall = ints_in(output_lines[8])
      prop_precision = ints_in(output_lines[13])
      prop_recall = ints_in(output_lines[14])

      # Both terms of numerator and denominator are summed.
      labeled_correct = labeled_precision[0] + labeled_precision[1]
      labeled_excess = labeled_precision[2] + labeled_precision[3] - labeled_correct
      labeled_missed = labeled_recall[2] + labeled_recall[3] - labeled_correct

      # Proposition counts are parsed but not returned.
      prop_correct = prop_precision[0]
      prop_excess = prop_precision[1] - prop_correct
      prop_missed = prop_recall[1] - prop_correct
      break
    else:
      # Every attempt failed.
      tf.logging.log(tf.logging.ERROR, "Call to eval09.pl (conll09 srl eval) failed.")

  return labeled_correct, labeled_excess, labeled_missed

def conll09_srl_eval_srl_only(srl_predictions, predicate_predictions, words, mask, srl_targets, predicate_targets,
                     parse_label_predictions, parse_head_predictions, parse_label_targets, parse_head_targets,
                     pos_targets, pos_predictions, pred_srl_eval_file, gold_srl_eval_file, pred_sense, gold_sense,  input_source="INVALID"):
  '''
  Write gold and predicted CoNLL-2009-style files for one batch and score
  them with bin/eval09.pl, retrying the script up to five times.

  Unlike conll09_srl_eval, only the first term of the "Labeled precision"
  and "Labeled recall" numerators and denominators is counted. When
  input_source is not "INVALID", each file is additionally appended to a
  per-source file via write_srl_eval_09_a.

  :param mask: token mask; summed over the last axis to get sentence lengths
  :param pred_srl_eval_file: path the predictions are written to
  :param gold_srl_eval_file: path the gold annotations are written to
  :param input_source: path of the data source, or "INVALID" to skip the
    per-source files
  :return: (labeled_correct, labeled_excess, labeled_missed); all 0 if every
    attempt to run the script exits with an error

  The remaining parameters are passed through to the writers (the *_targets
  and gold_sense values for the gold file, the *_predictions and pred_sense
  values for the prediction file).
  '''
  lengths = np.sum(mask, -1).astype(np.int32)
  per_source = not input_source == "INVALID"

  # Gold file(s) first, then predictions.
  write_srl_eval_09(gold_srl_eval_file, words, predicate_targets, lengths, srl_targets, parse_head_targets,
                    parse_label_targets, pos_targets, gold_sense)
  if per_source:
    write_srl_eval_09_a(gold_srl_eval_file, words, predicate_targets, lengths, srl_targets, parse_head_targets,
                        parse_label_targets, pos_targets, gold_sense, input_source)

  write_srl_eval_09(pred_srl_eval_file, words, predicate_predictions, lengths, srl_predictions,
                    parse_head_predictions, parse_label_predictions, pos_predictions, pred_sense)
  if per_source:
    write_srl_eval_09_a(pred_srl_eval_file, words, predicate_predictions, lengths, srl_predictions,
                        parse_head_predictions, parse_label_predictions, pos_predictions, pred_sense, input_source)

  def ints_in(line):
    # Keep only digits and spaces, then read every remaining run as an int
    # (so "62.70" turns into 6270; only the leading counts are used).
    return [int(token) for token in re.sub('[^0-9 ]', '', line).split()]

  labeled_correct, labeled_excess, labeled_missed = 0, 0, 0
  with open(os.devnull, 'w') as devnull:
    for attempt in range(5):
      try:
        raw_output = check_output(["perl", "bin/eval09.pl", "-g", gold_srl_eval_file, "-s", pred_srl_eval_file],
                                  stderr=devnull)
      except CalledProcessError:
        tf.logging.log(tf.logging.ERROR, "Call to eval09.pl (conll09 srl eval) failed. retry {}".format(attempt))
        continue

      # Relevant lines of the eval09.pl output, by 0-based line index:
      #   [6]  SEMANTIC SCORES:
      #   [7]  Labeled precision:     (20 + 486) / (200 + 607) * 100 = 62.70 %
      #   [8]  Labeled recall:        (20 + 486) / (220 + 621) * 100 = 60.17 %
      #   [13] Proposition precision: 429 / 607 * 100 = 70.68 %
      #   [14] Proposition recall:    429 / 621 * 100 = 69.08 %
      output_lines = raw_output.decode('utf-8').split('\n')
      labeled_precision = ints_in(output_lines[7])
      labeled_recall = ints_in(output_lines[8])
      prop_precision = ints_in(output_lines[13])
      prop_recall = ints_in(output_lines[14])

      # Only the first term of each sum is used here.
      labeled_correct = labeled_precision[0]
      labeled_excess = labeled_precision[2] - labeled_correct
      labeled_missed = labeled_recall[2] - labeled_correct

      # Proposition counts are parsed but not returned.
      prop_correct = prop_precision[0]
      prop_excess = prop_precision[1] - prop_correct
      prop_missed = prop_recall[1] - prop_correct
      break
    else:
      # Every attempt failed.
      tf.logging.log(tf.logging.ERROR, "Call to eval09.pl (conll09 srl eval) failed.")

  return labeled_correct, labeled_excess, labeled_missed

def conll_parse_eval(parse_label_predictions, parse_head_predictions, words, mask, parse_label_targets,
                        parse_head_targets, pred_eval_file, gold_eval_file, pos_targets):
  '''
  Write the gold and predicted parses to file, score them with bin/eval.pl and return the counts.

  :param parse_label_predictions: predicted parse labels, written to pred_eval_file
  :param parse_head_predictions: predicted parse heads, written to pred_eval_file
  :param words: words, written to both files
  :param mask: summed over its last axis to obtain one length per sentence
  :param parse_label_targets: gold parse labels, written to gold_eval_file
  :param parse_head_targets: gold parse heads, written to gold_eval_file
  :param pred_eval_file: path of the file that receives the predictions
  :param gold_eval_file: path of the file that receives the gold annotation
  :param pos_targets: passed to write_parse_eval for both the gold and the predicted file
  :return: (total, counts) where total is the sixth whitespace-separated token of the first output line and
           counts is np.array([labeled_correct, unlabeled_correct, label_correct]); everything stays 0 when
           the call to eval.pl raises CalledProcessError
  '''
  # need one length per sentence because every word of every sentence is printed
  lengths = np.sum(mask, -1).astype(np.int32)

  # gold file first, then the predictions
  write_parse_eval(gold_eval_file, words, parse_head_targets, lengths, parse_label_targets, pos_targets)
  write_parse_eval(pred_eval_file, words, parse_head_predictions, lengths, parse_label_predictions, pos_targets)

  # defaults that are returned if the eval script fails
  total = 0
  correct_counts = [0, 0, 0]

  command = ["perl", "bin/eval.pl", "-g", gold_eval_file, "-s", pred_eval_file]
  with open(os.devnull, 'w') as devnull:
    try:
      raw_output = check_output(command, stderr=devnull)
    except CalledProcessError as e:
      tf.logging.log(tf.logging.ERROR, "Call to eval.pl (conll parse eval) failed.")
      print(e)
    else:
      # the first three lines of the output look like:
      #   Labeled attachment score: 26444 / 29058 * 100 = 91.00 %
      #   Unlabeled attachment score: 27251 / 29058 * 100 = 93.78 %
      #   Label accuracy score: 27395 / 29058 * 100 = 94.28 %
      score_lines = raw_output.decode('utf-8').split('\n')[:3]
      total = int(score_lines[0].split()[5])
      labeled, unlabeled, label_only = [int(line.split()[3]) for line in score_lines]
      correct_counts = [labeled, unlabeled, label_only]

  return total, np.array(correct_counts)


def conll_srl_eval_np(predictions, targets, predicate_predictions, words, mask, predicate_targets, reverse_maps,
                   gold_srl_eval_file, pred_srl_eval_file, pos_predictions, pos_targets, accumulator):
  '''
  Score one batch of SRL predictions with conll_srl_eval and return the F1 of the accumulated counts.

  :param predictions: sequences of predicted SRL label ids, decoded with reverse_maps['srl']
  :param targets: sequences of gold SRL label ids, decoded with reverse_maps['srl']
  :param predicate_predictions: forwarded unchanged to conll_srl_eval
  :param words: sequences of word ids, decoded with reverse_maps['word']
  :param mask: forwarded unchanged to conll_srl_eval
  :param predicate_targets: forwarded unchanged to conll_srl_eval
  :param reverse_maps: dict providing the 'srl' and 'word' id-to-string maps
  :param gold_srl_eval_file: forwarded to conll_srl_eval
  :param pred_srl_eval_file: forwarded to conll_srl_eval
  :param pos_predictions: not used by this function
  :param pos_targets: not used by this function
  :param accumulator: dict with 'correct', 'excess' and 'missed' counts; updated in place
  :return: F1 computed from the accumulated counts, or 0. when it is undefined
  '''

  # first, use reverse maps to convert ints to strings
  def decode(map_name, id_seqs):
    return [list(map(reverse_maps[map_name].get, seq)) for seq in id_seqs]

  str_srl_predictions = decode('srl', predictions)
  str_words = decode('word', words)
  str_srl_targets = decode('srl', targets)

  correct, excess, missed = conll_srl_eval(str_srl_predictions, predicate_predictions, str_words, mask, str_srl_targets,
                                           predicate_targets, pred_srl_eval_file, gold_srl_eval_file)

  accumulator['correct'] += correct
  accumulator['excess'] += excess
  accumulator['missed'] += missed

  # A metric whose denominator is zero is reported as 0. instead of raising ZeroDivisionError
  # (this happens e.g. while no argument has been counted as correct yet).
  total_correct = accumulator['correct']
  num_predicted = total_correct + accumulator['excess']
  num_gold = total_correct + accumulator['missed']
  precision = total_correct / num_predicted if num_predicted else 0.
  recall = total_correct / num_gold if num_gold else 0.
  f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.

  return f1

def conll_srl_all_eval_np(predictions, targets, predicate_predictions, words, mask, predicate_targets, reverse_maps,
                   gold_srl_eval_file, pred_srl_eval_file, pos_predictions, pos_targets, accumulator):
  '''
  Score one batch of SRL predictions with conll_srl_eval and return precision, recall and F1 of the
  accumulated counts.

  :param predictions: sequences of predicted SRL label ids, decoded with reverse_maps['srl']
  :param targets: sequences of gold SRL label ids, decoded with reverse_maps['srl']
  :param predicate_predictions: forwarded unchanged to conll_srl_eval
  :param words: sequences of word ids, decoded with reverse_maps['word']
  :param mask: forwarded unchanged to conll_srl_eval
  :param predicate_targets: forwarded unchanged to conll_srl_eval
  :param reverse_maps: dict providing the 'srl' and 'word' id-to-string maps
  :param gold_srl_eval_file: forwarded to conll_srl_eval
  :param pred_srl_eval_file: forwarded to conll_srl_eval
  :param pos_predictions: not used by this function
  :param pos_targets: not used by this function
  :param accumulator: dict with 'correct', 'excess' and 'missed' counts; updated in place
  :return: [precision, recall, f1] computed from the accumulated counts; an undefined metric is 0.
  '''

  # first, use reverse maps to convert ints to strings
  def decode(map_name, id_seqs):
    return [list(map(reverse_maps[map_name].get, seq)) for seq in id_seqs]

  str_srl_predictions = decode('srl', predictions)
  str_words = decode('word', words)
  str_srl_targets = decode('srl', targets)

  correct, excess, missed = conll_srl_eval(str_srl_predictions, predicate_predictions, str_words, mask, str_srl_targets,
                                           predicate_targets, pred_srl_eval_file, gold_srl_eval_file)

  accumulator['correct'] += correct
  accumulator['excess'] += excess
  accumulator['missed'] += missed

  # A metric whose denominator is zero is reported as 0. instead of raising ZeroDivisionError
  # (this happens e.g. while no argument has been counted as correct yet).
  total_correct = accumulator['correct']
  num_predicted = total_correct + accumulator['excess']
  num_gold = total_correct + accumulator['missed']
  precision = total_correct / num_predicted if num_predicted else 0.
  recall = total_correct / num_gold if num_gold else 0.
  f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.

  return [precision, recall, f1]



def conll_srl_eval_with_transformation_np(predictions, targets, predicate_predictions, words, mask, predicate_targets, reverse_maps,
                   gold_srl_eval_file, pred_srl_eval_file, pos_predictions, pos_targets, accumulator):
  '''
  Score one batch with conll_srl_eval_with_transformation and return one accumulated F1 per accumulator entry.

  :param predictions: sequences of predicted SRL label ids, decoded with reverse_maps['srl']
  :param targets: sequences of gold SRL label ids, decoded with reverse_maps['srl']
  :param predicate_predictions: forwarded unchanged to conll_srl_eval_with_transformation
  :param words: sequences of word ids, decoded with reverse_maps['word']
  :param mask: forwarded unchanged to conll_srl_eval_with_transformation
  :param predicate_targets: forwarded unchanged to conll_srl_eval_with_transformation
  :param reverse_maps: dict providing the 'srl' and 'word' id-to-string maps
  :param gold_srl_eval_file: forwarded to conll_srl_eval_with_transformation
  :param pred_srl_eval_file: forwarded to conll_srl_eval_with_transformation
  :param pos_predictions: not used by this function
  :param pos_targets: not used by this function
  :param accumulator: dict mapping a name to a dict of 'correct', 'excess' and 'missed' counts; every inner
                      dict is updated in place with the counts the evaluation returned under the same name
  :return: dict mapping each accumulator name to the F1 of its accumulated counts (0. when undefined)
  '''
  def compute_f1(correct, excess, missed):
    # A metric whose denominator is zero is reported as 0. instead of raising ZeroDivisionError;
    # the accumulators start at zero, so an entry may have no counts at all.
    num_predicted = correct + excess
    num_gold = correct + missed
    precision = correct / num_predicted if num_predicted else 0.
    recall = correct / num_gold if num_gold else 0.
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.

  # first, use reverse maps to convert ints to strings
  def decode(map_name, id_seqs):
    return [list(map(reverse_maps[map_name].get, seq)) for seq in id_seqs]

  str_srl_predictions = decode('srl', predictions)
  str_words = decode('word', words)
  str_srl_targets = decode('srl', targets)

  transformation_count_map = conll_srl_eval_with_transformation(str_srl_predictions, predicate_predictions, str_words, mask, str_srl_targets,
                                           predicate_targets, pred_srl_eval_file, gold_srl_eval_file)

  # add the counts of this batch to the running totals, entry by entry
  for item_name, running_counts in accumulator.items():
    batch_counts = transformation_count_map[item_name]
    for count_name in running_counts:
      running_counts[count_name] += batch_counts[count_name]

  f1_map = {item_name: compute_f1(**running_counts) for item_name, running_counts in accumulator.items()}
  return f1_map


def conll09_srl_eval_np(predictions, targets, predicate_predictions, words, mask, predicate_targets, reverse_maps,
                        gold_srl_eval_file, pred_srl_eval_file, pos_predictions, pos_targets, parse_head_predictions,
                        parse_head_targets, parse_label_predictions, parse_label_targets, accumulator, pred_sense, gold_sense):
  '''
  Score one batch with conll09_srl_eval and return the F1 of the accumulated counts.

  :param predictions: sequences of predicted SRL label ids, decoded with reverse_maps['srl']
  :param targets: sequences of gold SRL label ids, decoded with reverse_maps['srl']
  :param predicate_predictions: sequences of predicted predicate ids, decoded with reverse_maps['predicate']
  :param words: sequences of word ids, decoded with reverse_maps['word']
  :param mask: forwarded unchanged to conll09_srl_eval
  :param predicate_targets: sequences of gold predicate ids, decoded with reverse_maps['predicate']
  :param reverse_maps: dict providing the 'srl', 'word', 'gold_pos', 'parse_label' and 'predicate' maps
  :param gold_srl_eval_file: forwarded to conll09_srl_eval
  :param pred_srl_eval_file: forwarded to conll09_srl_eval
  :param pos_predictions: sequences of predicted POS ids, decoded with reverse_maps['gold_pos']
  :param pos_targets: sequences of gold POS ids, decoded with reverse_maps['gold_pos']
  :param parse_head_predictions: forwarded unchanged to conll09_srl_eval
  :param parse_head_targets: forwarded unchanged to conll09_srl_eval
  :param parse_label_predictions: sequences of predicted parse label ids, decoded with reverse_maps['parse_label']
  :param parse_label_targets: sequences of gold parse label ids, decoded with reverse_maps['parse_label']
  :param accumulator: dict with 'correct', 'excess' and 'missed' counts; updated in place
  :param pred_sense: forwarded unchanged to conll09_srl_eval
  :param gold_sense: forwarded unchanged to conll09_srl_eval
  :return: F1 computed from the accumulated counts, or 0. when it is undefined
  '''

  # first, use reverse maps to convert ints to strings
  def decode(map_name, id_seqs):
    return [list(map(reverse_maps[map_name].get, seq)) for seq in id_seqs]

  str_srl_predictions = decode('srl', predictions)
  str_words = decode('word', words)
  str_srl_targets = decode('srl', targets)
  # both POS inputs are decoded with the 'gold_pos' map
  str_pos_targets = decode('gold_pos', pos_targets)
  str_pos_predictions = decode('gold_pos', pos_predictions)
  str_parse_label_targets = decode('parse_label', parse_label_targets)
  str_parse_label_predictions = decode('parse_label', parse_label_predictions)
  str_predicate_predictions = decode('predicate', predicate_predictions)
  str_predicate_targets = decode('predicate', predicate_targets)

  correct, excess, missed = conll09_srl_eval(str_srl_predictions, str_predicate_predictions, str_words, mask,
                                             str_srl_targets, str_predicate_targets, str_parse_label_predictions,
                                             parse_head_predictions, str_parse_label_targets, parse_head_targets,
                                             str_pos_targets, str_pos_predictions, pred_srl_eval_file, gold_srl_eval_file, pred_sense, gold_sense)

  accumulator['correct'] += correct
  accumulator['excess'] += excess
  accumulator['missed'] += missed

  # A metric whose denominator is zero is reported as 0. instead of raising ZeroDivisionError
  # (this happens e.g. while no argument has been counted as correct yet).
  total_correct = accumulator['correct']
  num_predicted = total_correct + accumulator['excess']
  num_gold = total_correct + accumulator['missed']
  precision = total_correct / num_predicted if num_predicted else 0.
  recall = total_correct / num_gold if num_gold else 0.
  f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.

  return f1

def conll09_srl_eval_srl_only_np(predictions, targets, predicate_predictions, words, mask, predicate_targets, reverse_maps,
                        gold_srl_eval_file, pred_srl_eval_file, pos_predictions, pos_targets, parse_head_predictions,
                        parse_head_targets, parse_label_predictions, parse_label_targets, accumulator, pred_sense, gold_sense, input_source="INVALID"):
  '''
  Score one batch with conll09_srl_eval_srl_only and return the F1 of the accumulated counts.

  :param predictions: sequences of predicted SRL label ids, decoded with reverse_maps['srl']
  :param targets: sequences of gold SRL label ids, decoded with reverse_maps['srl']
  :param predicate_predictions: sequences of predicted predicate ids, decoded with reverse_maps['predicate']
  :param words: sequences of word ids, decoded with reverse_maps['word']
  :param mask: forwarded unchanged to conll09_srl_eval_srl_only
  :param predicate_targets: sequences of gold predicate ids, decoded with reverse_maps['predicate']
  :param reverse_maps: dict providing the 'srl', 'word', 'gold_pos', 'parse_label' and 'predicate' maps
  :param gold_srl_eval_file: forwarded to conll09_srl_eval_srl_only
  :param pred_srl_eval_file: forwarded to conll09_srl_eval_srl_only
  :param pos_predictions: sequences of predicted POS ids, decoded with reverse_maps['gold_pos']
  :param pos_targets: sequences of gold POS ids, decoded with reverse_maps['gold_pos']
  :param parse_head_predictions: forwarded unchanged to conll09_srl_eval_srl_only
  :param parse_head_targets: forwarded unchanged to conll09_srl_eval_srl_only
  :param parse_label_predictions: sequences of predicted parse label ids, decoded with reverse_maps['parse_label']
  :param parse_label_targets: sequences of gold parse label ids, decoded with reverse_maps['parse_label']
  :param accumulator: dict with 'correct', 'excess' and 'missed' counts; updated in place
  :param pred_sense: forwarded unchanged to conll09_srl_eval_srl_only
  :param gold_sense: forwarded unchanged to conll09_srl_eval_srl_only
  :param input_source: forwarded unchanged to conll09_srl_eval_srl_only; defaults to "INVALID"
  :return: F1 computed from the accumulated counts, or 0. when it is undefined
  '''

  # first, use reverse maps to convert ints to strings
  def decode(map_name, id_seqs):
    return [list(map(reverse_maps[map_name].get, seq)) for seq in id_seqs]

  str_srl_predictions = decode('srl', predictions)
  str_words = decode('word', words)
  str_srl_targets = decode('srl', targets)
  # both POS inputs are decoded with the 'gold_pos' map
  str_pos_targets = decode('gold_pos', pos_targets)
  str_pos_predictions = decode('gold_pos', pos_predictions)
  str_parse_label_targets = decode('parse_label', parse_label_targets)
  str_parse_label_predictions = decode('parse_label', parse_label_predictions)
  str_predicate_predictions = decode('predicate', predicate_predictions)
  str_predicate_targets = decode('predicate', predicate_targets)

  correct, excess, missed = conll09_srl_eval_srl_only(str_srl_predictions, str_predicate_predictions, str_words, mask,
                                             str_srl_targets, str_predicate_targets, str_parse_label_predictions,
                                             parse_head_predictions, str_parse_label_targets, parse_head_targets,
                                             str_pos_targets, str_pos_predictions, pred_srl_eval_file, gold_srl_eval_file, pred_sense, gold_sense, input_source)

  accumulator['correct'] += correct
  accumulator['excess'] += excess
  accumulator['missed'] += missed

  # A metric whose denominator is zero is reported as 0. instead of raising ZeroDivisionError
  # (this happens e.g. while no argument has been counted as correct yet).
  total_correct = accumulator['correct']
  num_predicted = total_correct + accumulator['excess']
  num_gold = total_correct + accumulator['missed']
  precision = total_correct / num_predicted if num_predicted else 0.
  recall = total_correct / num_gold if num_gold else 0.
  f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.

  return f1


def conll09_srl_eval_all_np(predictions, targets, predicate_predictions, words, mask, predicate_targets, reverse_maps,
                        gold_srl_eval_file, pred_srl_eval_file, pos_predictions, pos_targets, parse_head_predictions,
                        parse_head_targets, parse_label_predictions, parse_label_targets, accumulator, pred_sense, gold_sense):
  '''
  Score one batch with conll09_srl_eval and return precision, recall and F1 of the accumulated counts.

  :param predictions: sequences of predicted SRL label ids, decoded with reverse_maps['srl']
  :param targets: sequences of gold SRL label ids, decoded with reverse_maps['srl']
  :param predicate_predictions: sequences of predicted predicate ids, decoded with reverse_maps['predicate']
  :param words: sequences of word ids, decoded with reverse_maps['word']
  :param mask: forwarded unchanged to conll09_srl_eval
  :param predicate_targets: sequences of gold predicate ids, decoded with reverse_maps['predicate']
  :param reverse_maps: dict providing the 'srl', 'word', 'gold_pos', 'parse_label' and 'predicate' maps
  :param gold_srl_eval_file: forwarded to conll09_srl_eval
  :param pred_srl_eval_file: forwarded to conll09_srl_eval
  :param pos_predictions: sequences of predicted POS ids, decoded with reverse_maps['gold_pos']
  :param pos_targets: sequences of gold POS ids, decoded with reverse_maps['gold_pos']
  :param parse_head_predictions: forwarded unchanged to conll09_srl_eval
  :param parse_head_targets: forwarded unchanged to conll09_srl_eval
  :param parse_label_predictions: sequences of predicted parse label ids, decoded with reverse_maps['parse_label']
  :param parse_label_targets: sequences of gold parse label ids, decoded with reverse_maps['parse_label']
  :param accumulator: dict with 'correct', 'excess' and 'missed' counts; updated in place
  :param pred_sense: forwarded unchanged to conll09_srl_eval
  :param gold_sense: forwarded unchanged to conll09_srl_eval
  :return: [precision, recall, f1] computed from the accumulated counts; an undefined metric is 0.
  '''

  # first, use reverse maps to convert ints to strings
  def decode(map_name, id_seqs):
    return [list(map(reverse_maps[map_name].get, seq)) for seq in id_seqs]

  str_srl_predictions = decode('srl', predictions)
  str_words = decode('word', words)
  str_srl_targets = decode('srl', targets)
  # both POS inputs are decoded with the 'gold_pos' map
  str_pos_targets = decode('gold_pos', pos_targets)
  str_pos_predictions = decode('gold_pos', pos_predictions)
  str_parse_label_targets = decode('parse_label', parse_label_targets)
  str_parse_label_predictions = decode('parse_label', parse_label_predictions)
  str_predicate_predictions = decode('predicate', predicate_predictions)
  str_predicate_targets = decode('predicate', predicate_targets)

  correct, excess, missed = conll09_srl_eval(str_srl_predictions, str_predicate_predictions, str_words, mask,
                                             str_srl_targets, str_predicate_targets, str_parse_label_predictions,
                                             parse_head_predictions, str_parse_label_targets, parse_head_targets,
                                             str_pos_targets, str_pos_predictions, pred_srl_eval_file, gold_srl_eval_file, pred_sense, gold_sense)

  accumulator['correct'] += correct
  accumulator['excess'] += excess
  accumulator['missed'] += missed

  # A metric whose denominator is zero is reported as 0. instead of raising ZeroDivisionError
  # (this happens e.g. while no argument has been counted as correct yet).
  total_correct = accumulator['correct']
  num_predicted = total_correct + accumulator['excess']
  num_gold = total_correct + accumulator['missed']
  precision = total_correct / num_predicted if num_predicted else 0.
  recall = total_correct / num_gold if num_gold else 0.
  f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.

  return [precision, recall, f1]

def conll09_srl_eval_all_srl_only_np(predictions, targets, predicate_predictions, words, mask, predicate_targets, reverse_maps,
                        gold_srl_eval_file, pred_srl_eval_file, pos_predictions, pos_targets, parse_head_predictions,
                        parse_head_targets, parse_label_predictions, parse_label_targets, accumulator, pred_sense, gold_sense, input_source="INVALID"):
  '''
  Score one batch with conll09_srl_eval_srl_only and return precision, recall and F1 of the accumulated counts.

  :param predictions: sequences of predicted SRL label ids, decoded with reverse_maps['srl']
  :param targets: sequences of gold SRL label ids, decoded with reverse_maps['srl']
  :param predicate_predictions: sequences of predicted predicate ids, decoded with reverse_maps['predicate']
  :param words: sequences of word ids, decoded with reverse_maps['word']
  :param mask: forwarded unchanged to conll09_srl_eval_srl_only
  :param predicate_targets: sequences of gold predicate ids, decoded with reverse_maps['predicate']
  :param reverse_maps: dict providing the 'srl', 'word', 'gold_pos', 'parse_label' and 'predicate' maps
  :param gold_srl_eval_file: forwarded to conll09_srl_eval_srl_only
  :param pred_srl_eval_file: forwarded to conll09_srl_eval_srl_only
  :param pos_predictions: sequences of predicted POS ids, decoded with reverse_maps['gold_pos']
  :param pos_targets: sequences of gold POS ids, decoded with reverse_maps['gold_pos']
  :param parse_head_predictions: forwarded unchanged to conll09_srl_eval_srl_only
  :param parse_head_targets: forwarded unchanged to conll09_srl_eval_srl_only
  :param parse_label_predictions: sequences of predicted parse label ids, decoded with reverse_maps['parse_label']
  :param parse_label_targets: sequences of gold parse label ids, decoded with reverse_maps['parse_label']
  :param accumulator: dict with 'correct', 'excess' and 'missed' counts; updated in place
  :param pred_sense: forwarded unchanged to conll09_srl_eval_srl_only
  :param gold_sense: forwarded unchanged to conll09_srl_eval_srl_only
  :param input_source: forwarded unchanged to conll09_srl_eval_srl_only; defaults to "INVALID"
  :return: [precision, recall, f1] computed from the accumulated counts; an undefined metric is 0.
  '''

  # first, use reverse maps to convert ints to strings
  def decode(map_name, id_seqs):
    return [list(map(reverse_maps[map_name].get, seq)) for seq in id_seqs]

  str_srl_predictions = decode('srl', predictions)
  str_words = decode('word', words)
  str_srl_targets = decode('srl', targets)
  # both POS inputs are decoded with the 'gold_pos' map
  str_pos_targets = decode('gold_pos', pos_targets)
  str_pos_predictions = decode('gold_pos', pos_predictions)
  str_parse_label_targets = decode('parse_label', parse_label_targets)
  str_parse_label_predictions = decode('parse_label', parse_label_predictions)
  str_predicate_predictions = decode('predicate', predicate_predictions)
  str_predicate_targets = decode('predicate', predicate_targets)

  correct, excess, missed = conll09_srl_eval_srl_only(str_srl_predictions, str_predicate_predictions, str_words, mask,
                                             str_srl_targets, str_predicate_targets, str_parse_label_predictions,
                                             parse_head_predictions, str_parse_label_targets, parse_head_targets,
                                             str_pos_targets, str_pos_predictions, pred_srl_eval_file, gold_srl_eval_file, pred_sense, gold_sense, input_source)

  accumulator['correct'] += correct
  accumulator['excess'] += excess
  accumulator['missed'] += missed

  # A metric whose denominator is zero is reported as 0. instead of raising ZeroDivisionError
  # (this happens e.g. while no argument has been counted as correct yet).
  total_correct = accumulator['correct']
  num_predicted = total_correct + accumulator['excess']
  num_gold = total_correct + accumulator['missed']
  precision = total_correct / num_predicted if num_predicted else 0.
  recall = total_correct / num_gold if num_gold else 0.
  f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.

  return [precision, recall, f1]

def conll_parse_eval_np(predictions, targets, parse_head_predictions, words, mask, parse_head_targets, reverse_maps,
                        gold_parse_eval_file, pred_parse_eval_file, pos_targets, accumulator):
  '''
  Score one batch of parses with conll_parse_eval and return the accumulated accuracies.

  :param predictions: sequences of predicted parse label ids, decoded with reverse_maps['parse_label']
  :param targets: sequences of gold parse label ids, decoded with reverse_maps['parse_label']
  :param parse_head_predictions: forwarded unchanged to conll_parse_eval
  :param words: sequences of word ids, decoded with reverse_maps['word']
  :param mask: forwarded unchanged to conll_parse_eval
  :param parse_head_targets: forwarded unchanged to conll_parse_eval
  :param reverse_maps: dict providing the 'word', 'parse_label' and 'gold_pos' maps
  :param gold_parse_eval_file: forwarded to conll_parse_eval
  :param pred_parse_eval_file: forwarded to conll_parse_eval
  :param pos_targets: sequences of POS ids, decoded with reverse_maps['gold_pos']
  :param accumulator: dict with 'total' and 'corrects' entries; updated in place
  :return: accumulator['corrects'] divided by accumulator['total']
  '''

  # first, use reverse maps to convert ints to strings
  def decode(map_name, id_seqs):
    return [list(map(reverse_maps[map_name].get, seq)) for seq in id_seqs]

  word_strs = decode('word', words)
  predicted_label_strs = decode('parse_label', predictions)
  gold_label_strs = decode('parse_label', targets)
  pos_strs = decode('gold_pos', pos_targets)

  batch_total, batch_corrects = conll_parse_eval(predicted_label_strs, parse_head_predictions, word_strs, mask,
                                                 gold_label_strs, parse_head_targets, pred_parse_eval_file,
                                                 gold_parse_eval_file, pos_strs)

  # fold the counts of this batch into the running totals
  accumulator['total'] += batch_total
  accumulator['corrects'] += batch_corrects

  return accumulator['corrects'] / accumulator['total']


# Maps the name of an evaluation function to the function implementing it; looked up by dispatch().
fn_dispatcher = {
  # generic metrics
  'accuracy': accuracy_np,
  'precision': precision_np,
  'recall': recall_np,
  'fscore': f1_np,

  # metrics backed by the conll eval functions
  'conll_srl_eval': conll_srl_eval_np,
  'conll_parse_eval': conll_parse_eval_np,
  'conll09_srl_eval': conll09_srl_eval_np,
  'conll09_srl_eval_srl_only': conll09_srl_eval_srl_only_np,
  'conll09_srl_eval_all': conll09_srl_eval_all_np,
  'conll09_srl_eval_all_srl_only': conll09_srl_eval_all_srl_only_np,
  'conll_srl_all_eval': conll_srl_all_eval_np,
  'conll_srl_eval_with_transformation': conll_srl_eval_with_transformation_np,
}


# Maps the name of an evaluation function to a zero-argument factory that builds a fresh accumulator for it;
# looked up by get_accumulator().
# NOTE(review): fn_dispatcher registers 'precision' but no factory is registered for it here - confirm
# whether that is intended.
accumulator_factory = {
  'accuracy': lambda: dict(correct=0., total=0.),
  'fscore': lambda: dict(tp=0., tpfn=0., tpfp=0.),
  'conll_srl_eval': lambda: dict(correct=0., excess=0., missed=0.),
  'conll_srl_all_eval': lambda: dict(correct=0., excess=0., missed=0.),
  # 'corrects' holds three counts in one array
  'conll_parse_eval': lambda: dict(total=0., corrects=np.zeros(3)),
  'conll09_srl_eval': lambda: dict(correct=0., excess=0., missed=0.),
  'conll09_srl_eval_srl_only': lambda: dict(correct=0., excess=0., missed=0.),

  'conll09_srl_eval_all': lambda: dict(correct=0., excess=0., missed=0.),
  'conll09_srl_eval_all_srl_only': lambda: dict(correct=0., excess=0., missed=0.),
  # one set of integer counts for 'original' and for every entry of transformation_list
  'conll_srl_eval_with_transformation': lambda: OrderedDict(
    (name, dict(correct=0, excess=0, missed=0)) for name in ['original'] + transformation_list),
  'recall': lambda: dict(tp=0., tpfn=0., tpfp=0.),
}


def dispatch(fn_name):
  '''
  Look up the evaluation function registered under fn_name.

  :param fn_name: key into fn_dispatcher
  :return: the registered function; for an unknown name util.fatal_error is called instead
  '''
  if fn_name in fn_dispatcher:
    return fn_dispatcher[fn_name]
  util.fatal_error('Undefined evaluation function `%s' % fn_name)


def get_accumulator(fn_name):
  '''
  Build a fresh accumulator for the evaluation function registered under fn_name.

  :param fn_name: key into accumulator_factory
  :return: the object created by the registered factory; for an unknown name util.fatal_error is called instead
  '''
  if fn_name in accumulator_factory:
    make_accumulator = accumulator_factory[fn_name]
    return make_accumulator()
  util.fatal_error('Undefined evaluation function `%s' % fn_name)


def get_accumulators(task_config):
  '''
  Create one accumulator for every evaluation function listed in task_config.

  :param task_config: dict of task maps; each task map has an 'eval_fns' dict whose values carry the
                      evaluation function name under 'name'
  :return: dict mapping each key of the 'eval_fns' dicts to a fresh accumulator; if the same key occurs for
           several tasks, the accumulator created last is kept
  '''
  return {
    eval_name: get_accumulator(eval_map['name'])
    for task_map in task_config.values()
    for eval_name, eval_map in task_map['eval_fns'].items()
  }


def get_params(task, task_map, predictions, features, labels, reverse_maps, tokens_to_keep):
  '''
  Assemble the keyword arguments for an evaluation function.

  :param task: task name; selects predictions['<task>_predictions'] and labels[task]
  :param task_map: dict whose optional 'params' entry maps a parameter name to a spec dict
  :param predictions: dict of model outputs
  :param features: dict of input features
  :param labels: dict of gold labels
  :param reverse_maps: dict of reverse maps, selected by name
  :param tokens_to_keep: passed through as 'mask'
  :return: dict with 'predictions', 'targets' and 'mask' plus one entry per configured parameter
  '''
  # always pass through predictions, targets and mask
  params = {
    'predictions': predictions['%s_predictions' % task],
    'targets': labels[task],
    'mask': tokens_to_keep,
  }
  if 'params' not in task_map:
    return params

  for name, spec in task_map['params'].items():
    # the first matching key of the spec decides where the value comes from
    if 'reverse_maps' in spec:
      value = {map_name: reverse_maps[map_name] for map_name in spec['reverse_maps']}
    elif 'label' in spec:
      value = labels[spec['label']]
    elif 'feature' in spec:
      value = features[spec['feature']]
    elif 'layer' in spec:
      value = predictions['%s_%s' % (spec['layer'], spec['output'])]
    else:
      value = spec['value']
    params[name] = value
  return params