import os
import sys
import json
import random

# Base examples; each entry is read for its 'diff_indices' and 'ctxs' keys.
DATA_PATH = '/data/timchen0618/open_domain_data/NQ/dev_entity_examples_wnumgold.json'
# Modified examples, parallel to the base ones; source of the replacement
# contexts and of the final 'answers'.
MOD_DATA_PATH = '/home/timchen0618/data/open_domain_data/NQ/entity_data/NQOpenPart/NQOpen_TOP100_ent_auto.json'

# True: take the smallest indices first. False: sample indices at random.
SAMPLE_TOP = True

# At most this many leading contexts per example get their title/id blanked.
MAX_CTXS = 100

# Possible increments when the fractional part of the target count is exactly
# 0.5; one of them is drawn at random so ties round up or down.
_TIE_OFFSETS = [0, 1]


def _round_half_random(value):
    """Round *value* to an int, breaking exact .5 ties with a random draw.

    The random module is consulted only for exact ties, so the sequence of
    random draws depends solely on how many ties are encountered.
    """
    whole = int(value)
    fraction = value - whole
    if fraction > 0.5:
        return whole + 1
    if fraction < 0.5:
        return whole
    return whole + random.choice(_TIE_OFFSETS)


def _choose_indices(indices, count, sample_top):
    """Return *count* entries of *indices*: the smallest ones or a random sample."""
    if not indices or count <= 0:
        return []
    if sample_top:
        return sorted(indices)[:count]
    return random.sample(indices, count)


def merge_examples(data, mod_data, topk, sample_top=True):
    """Overwrite a fraction of each example's contexts with modified ones.

    For every pair of examples, ``len(diff_indices) * topk`` (rounded, ties
    broken randomly) positions from the example's 'diff_indices' are chosen
    and the contexts at those positions are replaced by the ones at the same
    positions in the modified example.  Afterwards 'title' and 'id' of the
    first ``MAX_CTXS`` contexts are blanked and 'answers' is taken from the
    modified example.

    Args:
        data: list of example dicts with 'diff_indices' and 'ctxs'; mutated
            in place.
        mod_data: list of example dicts with 'ctxs' and 'answers', parallel
            to *data*.
        topk: fraction of 'diff_indices' to replace per example.
        sample_top: choose the smallest indices when true, otherwise sample.

    Returns:
        The same *data* list, after modification.

    Raises:
        ValueError: if *data* and *mod_data* differ in length.
    """
    if len(data) != len(mod_data):
        raise ValueError(
            'data and mod_data differ in length: %d vs %d'
            % (len(data), len(mod_data))
        )

    for example, mod_example in zip(data, mod_data):
        indices = example['diff_indices']
        count = _round_half_random(len(indices) * topk)
        chosen_indices = _choose_indices(indices, count, sample_top)

        print('===')
        print(len(indices), indices)
        print(len(chosen_indices), chosen_indices)

        ctxs = example['ctxs']
        mod_ctxs = mod_example['ctxs']
        for idx in chosen_indices:
            ctxs[idx] = mod_ctxs[idx]

        # Slicing (instead of indexing 0..99) tolerates examples that carry
        # fewer than MAX_CTXS contexts; the dicts themselves are still
        # mutated in place.
        for ctx in ctxs[:MAX_CTXS]:
            ctx['title'] = ""
            ctx['id'] = "0"

        example['answers'] = mod_example['answers']

    return data


def main(argv=None):
    """Command-line entry point: ``script.py <topk>`` with topk a float."""
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        raise SystemExit('usage: %s <topk>' % argv[0])

    # Fixed seed so tie-breaking (and sampling, if enabled) is reproducible.
    random.seed(42)

    topk = float(argv[1])
    print(topk)

    with open(DATA_PATH) as fin:
        data = json.load(fin)
    with open(MOD_DATA_PATH) as fin:
        mod_data = json.load(fin)

    merge_examples(data, mod_data, topk, sample_top=SAMPLE_TOP)

    # round() rather than int(): e.g. 0.29 * 100 is 28.999999999999996 in
    # floating point and would otherwise be truncated to 28.
    out_path = 'NQOpen_Part_sample_top_TOP%d_mod_ent.json' % round(topk * 100)
    with open(out_path, 'w') as fw:
        fw.write(json.dumps(data, indent=4))


if __name__ == '__main__':
    main()
