import collections
import json
import numpy as np
import os
import re
import string
import sys
import argparse
from tqdm import tqdm, trange

def normalize_answer(s):
    """Return a canonical form of ``s`` for answer-string comparison.

    The steps run in this order: lowercase the text, drop every character
    listed in ``string.punctuation``, replace the standalone words "a", "an"
    and "the" with a space, and finally collapse each run of whitespace into
    a single space (which also trims both ends).
    """
    lowered = s.lower()

    # Only the ASCII punctuation from the string module is stripped.
    punctuation = frozenset(string.punctuation)
    kept_chars = [ch for ch in lowered if ch not in punctuation]
    unpunctuated = ''.join(kept_chars)

    # Articles are swapped for a space (not deleted) so that the words on
    # either side stay separated; the extra spaces are squeezed out below.
    without_articles = re.sub(
        r'\b(a|an|the)\b', ' ', unpunctuated, flags=re.UNICODE)

    tokens = without_articles.split()
    return ' '.join(tokens)


def check_overlap(dev_path='/data/timchen0618/open_domain_data/TQA/dev.json',
                  train_path='/data/timchen0618/open_domain_data/TQA/train.json',
                  output_path='overlap_inds.txt'):
    """Record which dev examples share a normalized answer with train.

    Both input files are read as JSON and are expected to hold a sequence of
    records, each with an ``'answers'`` list of strings. Every answer is
    passed through ``normalize_answer``; a dev example counts as overlapping
    when at least one of its normalized answers equals a normalized answer of
    any train example. The zero-based positions of the overlapping dev
    examples are written to ``output_path``, one per line, in ascending order.

    Args:
        dev_path: JSON file holding the dev examples.
        train_path: JSON file holding the train examples.
        output_path: Text file that receives the overlapping dev indices.
    """
    # Context managers make sure the file handles are released even when
    # parsing fails.
    with open(dev_path, encoding='utf-8') as dev_file:
        dev_data = json.load(dev_file)
    with open(train_path, encoding='utf-8') as train_file:
        train_data = json.load(train_file)

    # Sharing an answer with *any* train example is the same as sharing an
    # answer with the union of all train answers, so pool them once and avoid
    # comparing every dev example against every train example.
    train_answer_pool = set()
    for example in train_data:
        train_answer_pool.update(
            normalize_answer(answer) for answer in example['answers'])

    overlap_inds = []
    for i in trange(len(dev_data)):
        dev_answers = {normalize_answer(answer)
                       for answer in dev_data[i]['answers']}
        if not train_answer_pool.isdisjoint(dev_answers):
            overlap_inds.append(i)

    with open(output_path, 'w', encoding='utf-8') as out_file:
        out_file.writelines(str(index) + '\n' for index in overlap_inds)

# Run the overlap check only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    check_overlap()
