import json
import sys
import os
import torch
from tqdm import trange
import random

# Command line: <script> <retrieval_results.json> <K>
#   <K> is either an absolute count (e.g. "5") or a percentage (e.g. "50%").
if len(sys.argv) < 3:
    # Fail with a readable message instead of a bare IndexError traceback.
    sys.exit("usage: %s <retrieval_results.json> <K | K%%>" % sys.argv[0])

k_arg = sys.argv[2]
if k_arg.endswith('%'):
    # Percentage mode: K is stored as a fraction; the main loop multiplies it
    # by each example's 'num_gold' to obtain a per-example passage count.
    K = float(k_arg[:-1]) / 100.0
    topk_percent = True
else:
    # Count mode: K is the number of passages to select per example.
    K = int(k_arg)
    topk_percent = False

# NOTE(review): the top-P code paths below read a global `P` that is never
# assigned; it used to be taken from the command line as shown here. Restore
# an assignment to `P` before switching `topP` on.
#P = int(sys.argv[2])/100.0
#topP = True if sys.argv[3] == 'topP' else False
topP = False
# sub_top=True substitutes the selected top-ranked passages themselves;
# sub_top=False substitutes every passage except those.
sub_top = True

# Load the retrieval results and keep only the examples whose positions are
# listed in only_corr_inds.txt (presumably the correctly answered ones, going
# by the file name). Then load the substitute passages, which must line up
# one-to-one with the kept examples.
with open(sys.argv[1]) as f:
    data = json.load(f)
path = '/data/timchen0618/open_domain_data/NQ/entity_data/NQOpenPart'
with open(os.path.join(path, 'only_corr_inds.txt')) as f:
    # One integer index per line; blank lines (e.g. a trailing newline at the
    # end of the file) are skipped instead of crashing int().
    inds = [int(line) for line in f if line.strip()]

data = [data[i] for i in inds]
with open(os.path.join(path, "NQOpen_Part_only_corr_TOP100_mod_ent.json")) as f:
    sub_data = json.load(f)
# Explicit check rather than `assert`, so it still runs under `python -O`.
if len(sub_data) != len(data):
    raise ValueError(
        "substitute data has %d examples but %d examples were selected"
        % (len(sub_data), len(data))
    )

# For every example, rank its 100 passages by score and replace some of them
# with the corresponding passage from `sub_data`. Only passages whose text
# differs between `data` and `sub_data` count as candidates.
cnt = 0  # number of substitutions made in top-P mode (reported after the loop)
for i in trange(len(data)):
    ctxs = data[i]['ctxs']          # mutated in place below
    sub_ctxs = sub_data[i]['ctxs']
    scores = torch.Tensor([c['score'] for c in ctxs])
    assert len(scores) == 100
    scores_sorted, inds_sorted = scores.sort(descending=True)

    # calculate topP
    if topP:
        # NOTE(review): `P` is not assigned anywhere in this script, so this
        # branch raises NameError if `topP` is switched on without defining it.
        # Substitute differing passages in rank order until the accumulated
        # score of the substituted passages reaches P.
        j = 0
        prob = 0
        while True:
            index = inds_sorted[j].item()
            if not ctxs[index]['text'] == sub_ctxs[index]['text']:
                ctxs[index] = sub_ctxs[index]
                prob += scores_sorted[j].item()
                cnt += 1
            j += 1
            if prob >= P:
                break
            if j >= 100:
                break
    # calculate topK
    else:
        if topk_percent:
            # Per-example quota: K (a fraction) of the gold count, rounded to
            # the nearest integer with exact .5 ties broken at random.
            th = float(sub_data[i]['num_gold']) * K
            frac = th - int(th)
            if frac > 0.5:
                threshold = int(th) + 1
            elif frac < 0.5:
                threshold = int(th)
            else:
                threshold = int(th) + random.choice([0, 1])
            print(sub_data[i]['num_gold'], threshold)
        else:
            threshold = K

        # Collect the `threshold` best-ranked passages whose text differs from
        # the substitute. The quota is checked BEFORE looking at each passage,
        # so a threshold of 0 selects nothing (previously the equality test ran
        # only after a passage had been counted, so a zero threshold could be
        # skipped over and every differing passage was selected).
        top_inds = []
        for index in inds_sorted.tolist():
            if len(top_inds) >= threshold:
                break
            if ctxs[index]['text'] != sub_ctxs[index]['text']:
                top_inds.append(index)

        if sub_top:
            sub_inds = top_inds
        else:
            # Substitute everything except the selected top passages; build
            # the lookup set once rather than once per candidate index.
            selected = set(top_inds)
            sub_inds = [k for k in range(100) if k not in selected]

        for index in sub_inds:
            ctxs[index] = sub_ctxs[index]
            

# Write the modified examples to a JSON file in the current directory; the
# file name encodes the selection mode and its parameter.
if topP:
    # NOTE(review): `P` is never assigned in this script; this branch raises
    # NameError unless `P` is defined before `topP` is switched on.
    out_name = "NQOpen_Part_only_corr_TOP%2.2f_mod_ent.json"%P
    print(cnt/float(len(data)))
else:
    if topk_percent:
        # Turn the fraction back into a whole percentage for the file name.
        # This is done for BOTH naming branches (a fractional K formatted with
        # %d would otherwise always print 0), and with round() rather than
        # int(), since e.g. 0.29 * 100 == 28.999999999999996 truncates to 28.
        K = round(K * 100)
    if sub_top:
        out_name = "NQOpen_Part_only_corr_attn_TOP%d_softmax_mod_ent.json"%K
    else:
        out_name = "NQOpen_Part_only_corr_NOT_TOP%d_softmax_mod_ent.json"%K

# Serialize first so a serialization error cannot leave a truncated file, and
# use a context manager so the handle is closed even if the write fails.
serialized = json.dumps(data, indent=4)
with open(out_name, 'w') as fw:
    fw.write(serialized)
