import os
import sys
import json

# Inputs/outputs of this script.
# DEV_PATH: the original NQ dev set (a JSON list of examples).
DEV_PATH = '/data/timchen0618/open_domain_data/NQ/dev.json'
# ENT_PATH: a modified copy of the selected dev examples whose contexts are
# compared against the originals (per the file name, entity-edited data --
# NOTE(review): confirm provenance).
ENT_PATH = '/home/timchen0618/data/open_domain_data/NQ/entity_data/NQOpenPart/NQOpen_TOP100_ent_auto.json'
# One integer per line: positions in DEV_PATH of the examples that appear,
# in the same order, in ENT_PATH.
INDICES_PATH = 'indices_dev_ent_ans.txt'
OUTPUT_PATH = 'new_data.json'


def load_indices(path):
    """Return the integers listed one per line in *path*.

    Blank lines (e.g. a trailing newline at end of file) are skipped instead
    of crashing ``int('')``.
    """
    with open(path, encoding='utf-8') as fh:
        return [int(line) for line in fh if line.strip()]


def load_json(path):
    """Load and return the JSON document stored in *path* (decoded as UTF-8)."""
    # JSON text is UTF-8 by specification; don't depend on the locale default.
    with open(path, encoding='utf-8') as fh:
        return json.load(fh)


def annotate_diff_contexts(originals, modified):
    """Tag each original example with the contexts whose text differs in *modified*.

    ``originals[i]`` and ``modified[i]`` must describe the same question.
    Each dict in *originals* is mutated in place, gaining:

    - ``'diff_indices'``: positions ``c`` (over ``originals[i]['ctxs']``) where
      ``originals[i]['ctxs'][c]['text'] != modified[i]['ctxs'][c]['text']``;
    - ``'num_gold'``: ``len(diff_indices)``.

    Contexts in ``modified[i]`` beyond the length of ``originals[i]['ctxs']``
    are ignored.

    Returns *originals* for convenience.

    Raises:
        ValueError: if the two lists differ in length, a pair of examples has
            different ``'question'`` values, or a modified example has fewer
            contexts than its original.
    """
    if len(originals) != len(modified):
        raise ValueError(
            f'expected {len(originals)} modified examples, got {len(modified)}'
        )
    for i, (orig, mod) in enumerate(zip(originals, modified)):
        if orig['question'] != mod['question']:
            raise ValueError(
                f'question mismatch at example {i}: '
                f'{orig["question"]!r} != {mod["question"]!r}'
            )
        orig_ctxs = orig['ctxs']
        mod_ctxs = mod['ctxs']
        # Every original context needs a counterpart to compare against;
        # surplus modified contexts are simply not compared.
        if len(mod_ctxs) < len(orig_ctxs):
            raise ValueError(
                f'example {i}: modified data has {len(mod_ctxs)} ctxs, '
                f'original has {len(orig_ctxs)}'
            )
        diff_indices = [
            c
            for c, (o_ctx, m_ctx) in enumerate(zip(orig_ctxs, mod_ctxs))
            if o_ctx['text'] != m_ctx['text']
        ]
        orig['diff_indices'] = diff_indices
        orig['num_gold'] = len(diff_indices)
    return originals


def main():
    """Select the indexed dev examples, annotate context diffs, write OUTPUT_PATH."""
    indices = load_indices(INDICES_PATH)
    data = load_json(DEV_PATH)
    mod_data = load_json(ENT_PATH)

    ent_data = [data[i] for i in indices]
    annotate_diff_contexts(ent_data, mod_data)

    # The context manager guarantees the output is flushed and closed.
    with open(OUTPUT_PATH, 'w', encoding='utf-8') as fw:
        json.dump(ent_data, fw, indent=4)


if __name__ == '__main__':
    main()
