import collections
import json
import numpy as np
import os
import re
import string
import sys
import argparse

def normalize_answer(s):
  """Canonicalize an answer string so two answers can be compared.

  Steps, applied in this order:
    1. lowercase the text;
    2. delete every character found in ``string.punctuation``;
    3. replace the standalone words "a", "an" and "the" with a space;
    4. collapse runs of whitespace into single spaces (this also trims
       leading and trailing whitespace).
  """
  lowered = s.lower()
  # One-pass deletion of punctuation through a translation table.
  without_punct = lowered.translate(str.maketrans('', '', string.punctuation))
  # Articles are swapped for a space so their neighbours stay separated.
  without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punct,
                            flags=re.UNICODE)
  return ' '.join(without_articles.split())

def get_tokens(s):
  """Return the whitespace-separated tokens of the normalized form of `s`.

  A falsy `s` (for example ``None`` or the empty string) yields an empty
  list without being normalized.
  """
  return normalize_answer(s).split() if s else []

def compute_exact(a_gold, a_pred):
  """Return 1 when both answers are equal after normalization, else 0."""
  gold_norm = normalize_answer(a_gold)
  pred_norm = normalize_answer(a_pred)
  return 1 if gold_norm == pred_norm else 0

def compute_f1(a_gold, a_pred):
  """Compute token-level F1 between a gold answer and a predicted answer.

  Both strings are tokenized with `get_tokens`. Overlap is counted as a
  multiset intersection, so a repeated token only matches as many times as
  it occurs on both sides.

  Returns:
    1 or 0 (int) when either side has no tokens, depending on whether both
    are empty; 0 (int) when no token is shared; otherwise the harmonic mean
    of precision and recall as a float.
  """
  reference = get_tokens(a_gold)
  predicted = get_tokens(a_pred)

  # A no-answer on either side only scores when both sides are no-answer.
  if not reference or not predicted:
    return int(reference == predicted)

  shared = collections.Counter(reference) & collections.Counter(predicted)
  overlap = sum(shared.values())
  if not overlap:
    return 0

  precision = float(overlap) / len(predicted)
  recall = float(overlap) / len(reference)
  return (2 * precision * recall) / (precision + recall)

def parse_args():
  """Build the command-line parser and parse ``sys.argv``.

  When the script is started without any argument, the help text is printed
  and the process exits with status 1.

  Returns:
    The ``argparse.Namespace`` with attributes ``data_file``, ``pred_file``,
    ``dpr``, ``fid``, ``out_file``, ``correct_indices`` and ``nao_indices``.
  """
  cli = argparse.ArgumentParser('evaluation script')

  # Input files.
  cli.add_argument('-data_file', metavar='data.json',
                   help='Input data JSON file.')
  cli.add_argument('-pred_file', metavar='pred.txt',
                   help='Model predictions.')

  # Boolean switches, both off by default.
  for switch in ('-dpr', '-fid'):
    cli.add_argument(switch, action='store_true')

  cli.add_argument(
      '--out-file', '-o', metavar='eval.json',
      help='Write accuracy metrics to file (default is stdout).')

  # Optional files holding example indices, used as filters by the caller.
  for index_option in ('-correct_indices', '-nao_indices'):
    cli.add_argument(index_option)

  invoked_without_arguments = len(sys.argv) == 1
  if invoked_without_arguments:
    cli.print_help()
    sys.exit(1)
  return cli.parse_args()


def _load_index_set(path):
  """Read one integer per line from `path` and return them as a set."""
  with open(path) as handle:
    return {int(line.strip('\n')) for line in handle}


def main():
  """Score predictions against gold answers and print the mean scores.

  Scoring only happens when ``-fid`` is given. The prediction file is read as
  tab-separated lines whose second column is the predicted answer; the data
  file is JSON indexed by example position, each entry providing ``answers``
  and ``sub_answers``. For every example that passes the optional
  ``-correct_indices`` / ``-nao_indices`` filters, the best exact-match score
  against each of the two answer lists is recorded. Indices of examples that
  match one of their ``answers`` are written to ``see.txt``.

  NOTE: the printed labels say "f1", but the values are exact-match means
  (they come from `compute_exact`). The labels are kept unchanged so that
  anything parsing this output keeps working.

  Raises:
    ValueError: if the number of predictions differs from the number of data
      entries.
    SystemExit: if no example was scored, since the means are then undefined.
  """
  args = parse_args()
  old_scores = []
  new_scores = []

  if args.fid:
    with open(args.pred_file) as pred_handle:
      answers = [line.strip('\n').split('\t')[1].strip()
                 for line in pred_handle]
    with open(args.data_file) as data_handle:
      data = json.load(data_handle)

    # Optional filters. They are sets so the per-example membership tests in
    # the loop below are O(1). An empty or missing filter means "keep all".
    corr = None
    if args.correct_indices:
      print(args.correct_indices)
      corr = _load_index_set(args.correct_indices)

    nao = None
    if args.nao_indices:
      print(args.nao_indices)
      nao = _load_index_set(args.nao_indices)

    # Raised explicitly (instead of `assert`) so the check survives `-O`.
    if len(answers) != len(data):
      raise ValueError(
          'Got %d predictions but %d data entries.'
          % (len(answers), len(data)))

    with open('see.txt', 'w') as matched_out:
      for i, ans in enumerate(answers):
        if corr and i not in corr:
          continue
        if nao and i not in nao:
          continue
        old_em = max(compute_exact(ans, gold) for gold in data[i]['answers'])
        new_em = max(compute_exact(ans, gold)
                     for gold in data[i]['sub_answers'])
        old_scores.append(old_em)
        new_scores.append(new_em)
        if old_em != 0:
          matched_out.write(str(i) + '\n')

  # Nothing scored (no -fid, or the filters excluded every example): fail
  # with a clear message instead of a ZeroDivisionError traceback.
  if not old_scores:
    sys.exit('No examples were scored; pass -fid and check the index filters.')

  print('old_f1:', sum(old_scores) / float(len(old_scores)))
  print(len(old_scores))
  print('new_f1:', sum(new_scores) / float(len(new_scores)))
  print(len(new_scores))


if __name__ == '__main__':
  main()
