import os
import sys
import json
from tqdm import tqdm

############### normalize answer ######################
import regex
import unicodedata
import re
import string
import argparse

def normalize_answer(s):
    """Return a canonical form of an answer string for comparison.

    The steps are applied in this order:
      1. lowercase the text;
      2. drop every character found in ``string.punctuation``;
      3. replace the standalone words "a", "an" and "the" with a space;
      4. collapse all whitespace runs to single spaces and trim the ends.

    Punctuation is removed *before* articles, so e.g. "a-b" becomes "ab"
    and the leading "a" is not treated as an article.
    """
    punctuation = frozenset(string.punctuation)

    lowered = s.lower()
    without_punct = ''.join(ch for ch in lowered if ch not in punctuation)
    without_articles = re.sub(
        r'\b(a|an|the)\b', ' ', without_punct, flags=re.UNICODE
    )
    # split() with no argument discards leading/trailing/repeated whitespace.
    return ' '.join(without_articles.split())

def get_tokens(s):
    """Split an answer into normalized whitespace-separated tokens.

    Any falsy input (``None``, empty string, ...) yields an empty list
    without being normalized; otherwise the result of
    ``normalize_answer(s)`` is split on whitespace.
    """
    return normalize_answer(s).split() if s else []

def compute_exact(a_gold, a_pred):
    """Return 1 if both answers are equal after normalization, else 0."""
    gold_norm = normalize_answer(a_gold)
    pred_norm = normalize_answer(a_pred)
    return 1 if gold_norm == pred_norm else 0


class SimpleTokenizer(object):
    """Regex-based tokenizer built on the third-party ``regex`` module.

    A token is either a maximal run of characters matching ``ALPHA_NUM``
    or, failing that, a single character matching ``NON_WS``.
    """

    # One or more characters from the Unicode L, N or M general categories.
    ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
    # Any single character outside the Unicode Z and C general categories.
    NON_WS = r'[^\p{Z}\p{C}]'

    def __init__(self):
        """Compile the token pattern; the tokenizer takes no options."""
        pattern = '({})|({})'.format(self.ALPHA_NUM, self.NON_WS)
        flags = regex.IGNORECASE | regex.UNICODE | regex.MULTILINE
        self._regexp = regex.compile(pattern, flags=flags)

    def tokenize(self, text, uncased=False):
        """Return the list of token strings found in ``text``.

        Args:
            text: the string to tokenize.
            uncased: when true, every token is lowercased.
        """
        pieces = (match.group() for match in self._regexp.finditer(text))
        if uncased:
            return [piece.lower() for piece in pieces]
        return list(pieces)
############### normalize answer ######################


# First, collect every answer in the NQ train set (one-off step; the
# commented-out snippet below is what produced train_answers.txt).
# train_data = json.load(open('/data/timchen0618/open_domain_data/NQ/train.json'))
# all_ans = set()
# for l in tqdm(train_data):
#     for ans in l['answers']:
#         all_ans.add(ans)

# fw = open('train_answers.txt', 'w')
# for l in all_ans:
#     fw.write(l+'\n')


# Load the train answers (one per line) and normalize them so that the
# membership test below compares like with like.
# The file handle is closed by the `with` block instead of being leaked.
with open('train_answers.txt') as answers_file:
    train_answers = [line.strip('\n') for line in answers_file]
print(train_answers[0])  # sanity check: first raw answer
train_answers = [normalize_answer(answer) for answer in train_answers]
print(train_answers[0])  # sanity check: same answer after normalization
# A set gives O(1) membership tests in the loop over the dev data.
train_answers = set(train_answers)

with open('/data/timchen0618/open_domain_data/NQ/dev.json') as dev_file:
    valid_data = json.load(dev_file)

# Keep the index of every dev instance for which none of its answers,
# once normalized, appears among the normalized train answers.
nao_indices = [
    index
    for index, inst in enumerate(valid_data)
    if not any(normalize_answer(ans) in train_answers
               for ans in inst['answers'])
]

# Write one index per line; the `with` block guarantees the output is
# flushed and closed rather than relying on interpreter shutdown.
with open('nao_indices.txt', 'w') as fw:
    fw.writelines(str(index) + '\n' for index in nao_indices)

