import sys
import os
import json
from tqdm import tqdm, trange
import random


# Which AmbigNQ split to process; it also selects the matching NQ file below.
split='dev'
data_path = '/data/timchen0618/open_domain_data/AmbigNQ/%s_data'%split

# Build every input path once so the logged path and the opened path cannot drift apart.
# The DPR file name uses the split with its first letter upper-cased (e.g. 'Dev').
dpr_path = os.path.join(data_path, 'Ambig_%s_disambiguated_DPR.json'%(split[:1].upper()+split[1:]))
qa_pair_path = os.path.join(data_path, 'ambig_disambig_qa_pairs_%s_old_1.json'%split)
org_path = '/data/timchen0618/open_domain_data/NQ/%s_wgold_indices.json'%split

# Log the path itself. Formatting an open() result here would print a file-object
# repr and leave an extra handle open.
print('loading DPR data from %s'%dpr_path)
print('loading qa pair from %s'%qa_pair_path)

# Read each file inside a context manager so its handle is closed right after parsing.
with open(dpr_path) as f:
    dpr_data = json.load(f)
with open(qa_pair_path) as f:
    qa_pair = json.load(f)

print('reading %s'%org_path)
with open(org_path) as f:
    org_data = json.load(f)  # entries are later read via 'question', 'ctxs' and 'gold_indices'

# Every output instance carries exactly this many contexts in total.
total_ctxs = 100

new_data = []
## build the counterfactual data ##
# Iterate the list directly so tqdm can read its length and show a real progress bar.
for inst in tqdm(qa_pair, desc='check instances of Ambig %s'%split):
    new_inst = {"question": inst['question'], 'answers': inst['answers'], "ctxs":[], "disambiguated_questions": [], "new_answers": []}
    N = len(inst['dpr_inds'])

    # Split the context budget over N+1 sources (the N disambiguated DPR entries
    # plus the original NQ entry, which comes last). This is the closed form of
    # dealing the contexts out round-robin: each source gets the floor share and
    # the first `extra` sources get one more.
    base, extra = divmod(total_ctxs, N+1)
    num_ctxs = [base + 1 if b < extra else base for b in range(N+1)]

    assert sum(num_ctxs) == total_ctxs

    for j, dpr_id in enumerate(inst['dpr_inds']):
        dpr_inst = dpr_data[dpr_id]
        new_inst["disambiguated_questions"].append(dpr_inst['question'])
        new_inst["new_answers"].append(dpr_inst['answers'])

        # Put contexts flagged with has_answer first, then the remaining ones,
        # keeping the original relative order inside each group.
        ctxs = dpr_inst["ctxs"]
        with_answer = [c for c in ctxs if c["has_answer"]]
        without_answer = [c for c in ctxs if not c["has_answer"]]
        new_inst["ctxs"] += (with_answer + without_answer)[:num_ctxs[j]]

    # The original (ambiguous) question fills the last share of the budget.
    org_inst = org_data[inst['dev_nq_id']]
    assert org_inst['question'] == inst['question']

    # Gold indices first, then every other index below total_ctxs in ascending order.
    # Build a fresh list: appending to org_inst["gold_indices"] itself would
    # silently modify the loaded data.
    gold = org_inst["gold_indices"]
    gold_set = set(gold)
    ordering = list(gold) + [k for k in range(total_ctxs) if k not in gold_set]
    assert len(ordering) == total_ctxs
    new_inst["ctxs"] += [org_inst["ctxs"][m] for m in ordering[:num_ctxs[-1]]]

    # Sanity checks: full context budget used, one disambiguation per DPR entry.
    assert len(new_inst["ctxs"]) == total_ctxs
    assert len(new_inst["disambiguated_questions"]) == N
    assert len(new_inst["new_answers"]) == N
    new_data.append(new_inst)

# Every input instance must have produced exactly one output instance.
print(len(new_data))
print(len(qa_pair))
assert len(new_data) == len(qa_pair)


## write results ##
# The context manager closes the file even if serialization raises part-way.
with open('counter_ambigqa_data_%s_old.json'%split, 'w') as fw:
    json.dump(new_data, fw, indent=4)

