import os
import json

import collections
import json
import numpy as np
import os
import re
import string
import sys
import argparse

def normalize_answer(s):
  """Canonicalise an answer string for comparison.

  The steps are applied in this order: lower-case the text, drop every
  character found in `string.punctuation`, replace the standalone words
  "a", "an" and "the" with a space, then collapse all whitespace runs
  into single spaces (also trimming both ends).
  """
  punctuation = set(string.punctuation)
  article_pattern = re.compile(r'\b(a|an|the)\b', re.UNICODE)

  lowered = s.lower()
  # Punctuation is removed before articles, so "a.b" becomes "ab" and the
  # "a" is not treated as an article.
  without_punc = ''.join(ch for ch in lowered if ch not in punctuation)
  without_articles = article_pattern.sub(' ', without_punc)
  return ' '.join(without_articles.split())

def get_tokens(s):
  """Return the whitespace-separated tokens of the normalized answer.

  Any falsy input (empty string, None) yields an empty list without
  being normalized.
  """
  return normalize_answer(s).split() if s else []

def compute_exact(a_gold, a_pred):
  """Return 1 if gold and prediction are equal after normalization, else 0."""
  gold_norm = normalize_answer(a_gold)
  pred_norm = normalize_answer(a_pred)
  return 1 if gold_norm == pred_norm else 0

def compute_f1(a_gold, a_pred):
  """Token-level F1 between a gold answer and a predicted answer.

  Both strings are tokenized with `get_tokens`. When either side has no
  tokens the result is 1 if both are empty and 0 otherwise; when the two
  token multisets share nothing the result is 0. Otherwise the harmonic
  mean of precision and recall over the multiset overlap is returned.
  """
  gold_toks = get_tokens(a_gold)
  pred_toks = get_tokens(a_pred)

  # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
  if not gold_toks or not pred_toks:
    return int(gold_toks == pred_toks)

  # Counter intersection keeps, per token, the smaller of the two counts.
  overlap = collections.Counter(gold_toks) & collections.Counter(pred_toks)
  num_same = sum(overlap.values())
  if num_same == 0:
    return 0

  precision = float(num_same) / len(pred_toks)
  recall = float(num_same) / len(gold_toks)
  return (2 * precision * recall) / (precision + recall)


def load_json(path):
  """Parse the JSON document stored at `path`, closing the file afterwards."""
  with open(path) as f:
    return json.load(f)


def read_scores(path):
  """Read one float per line from `path` and return them as a list."""
  with open(path) as f:
    return [float(line.strip('\n')) for line in f]


def read_outputs(path):
  """Return the second tab-separated field of every line in `path`."""
  with open(path) as f:
    return [line.strip('\n').split('\t')[1] for line in f]


def summarize_top(conf_score, org_conf_score, fraction=0.2):
  """Measure how many of the highest confidence scores come from the "mod" run.

  The two score lists are pooled (each entry tagged 'mod_<i>' or 'org_<i>'),
  sorted by score in descending order, and the first
  `int(fraction * pool_size)` entries are inspected.

  Args:
    conf_score: confidence scores of the "mod" run.
    org_conf_score: confidence scores of the "org" run.
    fraction: share of the pooled scores to inspect (default 0.2).

  Returns:
    A tuple `(mod_portion, total, share)`: the number of 'mod' entries in the
    inspected slice, the size of that slice, and `mod_portion / cut_off`.
    `share` is 0.0 when the slice is empty, instead of dividing by zero.
  """
  all_conf_score = [('mod_%d' % i, score) for i, score in enumerate(conf_score)]
  all_conf_score += [('org_%d' % i, score) for i, score in enumerate(org_conf_score)]
  # sorted() is stable, so on equal scores 'mod' entries stay ahead of 'org'.
  sorted_score = sorted(all_conf_score, key=lambda x: x[1], reverse=True)

  cut_off = int(fraction * len(all_conf_score))
  top = sorted_score[:cut_off]
  mod_portion = sum(1 for tag, _ in top if tag[:3] == 'mod')
  total = len(top)
  # Fewer than 1/fraction pooled scores gives cut_off == 0; avoid the crash.
  share = float(mod_portion) / cut_off if cut_off else 0.0
  return mod_portion, total, share


if __name__ == '__main__':
  if len(sys.argv) < 7:
    sys.exit('usage: %s output_file conf_score_file data_file '
             'org_output_file org_conf_score_file org_data_file' % sys.argv[0])

  output_file = sys.argv[1]
  conf_score_file = sys.argv[2]
  data_file = sys.argv[3]
  org_output_file = sys.argv[4]
  org_conf_score_file = sys.argv[5]
  org_data_file = sys.argv[6]

  data = load_json(data_file)
  conf_score = read_scores(conf_score_file)
  outputs = read_outputs(output_file)

  org_outputs = read_outputs(org_output_file)
  org_conf_score = read_scores(org_conf_score_file)
  org_data = load_json(org_data_file)

  # Every file of a run must describe the same examples, and both runs must
  # cover the same number of examples.
  assert len(data) == len(conf_score)
  assert len(outputs) == len(conf_score)
  assert len(org_data) == len(org_conf_score)
  assert len(org_outputs) == len(org_conf_score)
  assert len(conf_score) == len(org_conf_score)

  mod_portion, total, share = summarize_top(conf_score, org_conf_score)
  print(mod_portion, total)
  print(share)

# new_data = [(data[i], conf_score[i], outputs[i]) for i in range(len(data))]

# sorted_data = sorted(new_data, key=lambda x:x[1], reverse=True)
# print(sorted_data[:2])

# em = []
# for i in range(len(data)//2):
#   ans = sorted_data[i][2]
#   em.append(max(compute_exact(ans, gold) for gold in sorted_data[i][0]['answers']))
#   # if max(compute_exact(ans, gold) for gold in data[i]['answers']) != 0:
#       # fw.write(str(i)+'\n')
#   # em.append()

# print(float(sum(em))/len(em))

