import json
import sys
from match import has_answer, SimpleTokenizer
from tqdm import tqdm

# Annotate each example of an open-domain QA split with the indices of the
# retrieved contexts that contain a gold answer (stored under
# 'gold_indices'), then print the average number of such contexts per example.

tok = SimpleTokenizer()
train = False
dataset = "TQA"
# Opt-in: write the annotated examples back to disk next to the input split.
# Off by default, so a plain run only prints the statistic.
save_output = False

split = 'train' if train else 'dev'
data_dir = '/data/timchen0618/open_domain_data/%s' % dataset
# Context manager closes the input file promptly; JSON text is UTF-8.
with open('%s/%s.json' % (data_dir, split), encoding='utf-8') as fin:
    data = json.load(fin)

num_golds = []
for example in tqdm(data):
    # Indices into example['ctxs'] whose 'text' matches any of the example's
    # 'answers', as decided by match.has_answer with the shared tokenizer.
    golds = [
        j
        for j, c in enumerate(example['ctxs'])
        if has_answer(example['answers'], c['text'], tok)
    ]
    # Mutates the dicts inside `data` in place, so `data` carries the result.
    example['gold_indices'] = golds
    num_golds.append(len(golds))

# An empty split would otherwise raise ZeroDivisionError here.
if num_golds:
    print('num gold: %f' % (sum(num_golds) / float(len(num_golds))))
else:
    print('num gold: n/a (empty dataset)')

if save_output:
    # Name the output after the split actually processed (train or dev).
    out_path = '%s/%s_wgold_indices.json' % (data_dir, split)
    with open(out_path, 'w', encoding='utf-8') as fw:
        json.dump(data, fw, indent=4)
