import os
import sys
import string
import re

def normalize_answer(s):
  """Return a canonical form of *s* for answer comparison.

  The steps run in this order: lowercase, drop every character found in
  ``string.punctuation``, replace the standalone words "a", "an" and "the"
  with a space, then collapse all whitespace runs into single spaces
  (which also trims both ends).
  """
  punctuation = frozenset(string.punctuation)
  article_pattern = re.compile(r'\b(a|an|the)\b', re.UNICODE)

  lowered = s.lower()
  # Punctuation is deleted outright (not replaced by a space), so "U.S." -> "us".
  kept_chars = [ch for ch in lowered if ch not in punctuation]
  without_punc = ''.join(kept_chars)
  # Articles become a space so that neighbouring words do not get glued together.
  without_articles = article_pattern.sub(' ', without_punc)
  words = without_articles.split()
  return ' '.join(words)

def get_tokens(s):
  """Return the whitespace-separated tokens of the normalized form of *s*.

  Any falsy input (empty string, ``None``) yields an empty list without
  being normalized.
  """
  if s:
    return normalize_answer(s).split()
  return []

def compute_exact(a_gold, a_pred):
  """Return 1 if both answers are equal after normalization, otherwise 0."""
  gold_norm = normalize_answer(a_gold)
  pred_norm = normalize_answer(a_pred)
  return 1 if gold_norm == pred_norm else 0


############# main ##############

# Command line: PRED_FILE TGT_FILE ID_FLAG IDS_FILE MODE
#   PRED_FILE  one prediction per line, taken from the second tab-separated column.
#   TGT_FILE   one target per line; layout depends on MODE.
#   ID_FLAG    '-id' restricts scoring to the integer indices listed (one per
#              line) in IDS_FILE; any other value scores every example and
#              IDS_FILE is ignored.
#   MODE       '-fid': the target is the second tab-separated column;
#              '-rag': the whole line is the target.
# Prints the exact-match accuracy and the number of scored examples, then
# writes the indices of non-matching examples to 'out.txt', one per line.
if len(sys.argv) < 6:
    sys.exit('usage: %s PRED_FILE TGT_FILE ID_FLAG IDS_FILE (-fid|-rag)' % sys.argv[0])

mode = sys.argv[5]
if mode not in ('-fid', '-rag'):
    # Previously an unknown mode left `tgt` undefined and failed later with NameError.
    sys.exit("unknown mode %r: expected '-fid' or '-rag'" % mode)

with open(sys.argv[1]) as pred_file:
    pred = [l.strip('\n').split('\t')[1] for l in pred_file]

with open(sys.argv[2]) as tgt_file:
    if mode == '-fid':
        tgt = [l.strip('\n').split('\t')[1] for l in tgt_file]
    else:
        tgt = [l.strip('\n') for l in tgt_file]

# Explicit check instead of `assert`, which is stripped when running with -O.
if len(pred) != len(tgt):
    sys.exit('prediction/target count mismatch: %d vs %d' % (len(pred), len(tgt)))

use_id = sys.argv[3] == '-id'
ids = []
if use_id:
    with open(sys.argv[4]) as ids_file:
        ids = [int(l.strip('\n')) for l in ids_file]

# Score every pair once; the accuracy and the failure list both reuse it.
exact = [compute_exact(p, t) for p, t in zip(pred, tgt)]

if use_id:
    # Set membership keeps the selection linear in the number of examples.
    wanted = set(ids)
    same = [hit for i, hit in enumerate(exact) if i in wanted]
else:
    same = list(exact)

if not same:
    # Avoid a bare ZeroDivisionError when nothing was selected.
    sys.exit('no examples selected for scoring')

print(sum(same)/float(len(same)))
print(len(same))

# In id mode the failures follow the order (and any repeats) of the ids file;
# otherwise every example index is considered in order.
candidates = ids if use_id else range(len(pred))
with open('out.txt', 'w') as fw:
    for idx in candidates:
        if not exact[idx]:
            fw.write(str(idx) + '\n')
