import sys
import os
import json
from tqdm import tqdm, trange
import random

# Load every SituatedQA retrieval file in rootdir plus the NQ dev set.
rootdir = '/data/timchen0618/open_domain_data/SituatedQA/DPR-finetune'
# os.listdir returns entries in arbitrary order; sort so the concatenation
# order of sit_data (and which record wins for a duplicated question in
# sit_data_dict below) is reproducible across runs/machines.
files = sorted(os.listdir(rootdir))
with open('/data/timchen0618/open_domain_data/NQ/dev_wgold_indices.json') as fin:
    # List of NQ dev instances; each is expected to carry l['gold_indices']
    # (read later when building the NQ half of the contexts).
    data = json.load(fin)

sit_data = []
for f in files:
    print('processing %s' % f)
    with open(os.path.join(rootdir, f)) as fin:
        d = json.load(fin)
    print(len(d))
    sit_data += d

print('situated data, len %d' % len(sit_data))
# Index SituatedQA instances by question text; on duplicate questions the
# record from the later (sorted) file overwrites the earlier one.
sit_data_dict = {l['question']: l for l in sit_data}


# Parallel lists describing the NQ dev instances whose question also appears
# in SituatedQA: nq_in_sit[k] is data[nq_in_sit_inds[k]], adv_ctxs[k] holds
# the SituatedQA contexts for it, and adv_gold_indices[k] lists the positions
# in adv_ctxs[k] whose 'has_answer' flag is set.
nq_in_sit = []
nq_in_sit_inds = []
adv_ctxs = []  # len should be the same with selected nq open instances
adv_gold_indices = []

# check which instances of NQ dev are in SituatedQA
for i, inst in tqdm(enumerate(data), desc='check instances of NQ dev'):
    sit_inst = sit_data_dict.get(inst['question'])
    if sit_inst is None:
        continue
    # NOTE: this annotates the NQ dict in place, so data[i] gains (or has
    # overwritten) 'sit_answers' and 'id' as well.
    inst['sit_answers'] = sit_inst['answers']
    inst['id'] = sit_inst['id']
    nq_in_sit_inds.append(i)
    nq_in_sit.append(inst)
    adv_ctxs.append(sit_inst['ctxs'])
    # has_answer may be stored either as the string 'true' or as a JSON
    # boolean; normalize so both forms are counted as gold contexts.
    adv_gold_indices.append([
        j for j, c in enumerate(sit_inst['ctxs'])
        if str(c['has_answer']).lower() == 'true'
    ])


print('processed data len %d (NQ in SituatedQA)' % len(nq_in_sit))
assert len(nq_in_sit) == len(nq_in_sit_inds)
assert len(nq_in_sit) == len(adv_ctxs)
assert len(nq_in_sit) == len(adv_gold_indices)

## start create counterfactual data ##
# Two recipes were considered:
#   method 1 (implemented below): HALF SituatedQA contexts followed by HALF NQ
#       contexts per question. Within each half, that source's gold
#       (has_answer) contexts come first, and any remaining slots are filled
#       with that source's non-gold contexts in their original order.
#   method 2 (not implemented): just replace the bottom NQ contexts with
#       SituatedQA ones.


def _select_ctxs(ctxs, gold_indices, k):
    """Return up to k contexts from ``ctxs``: gold ones first, then non-gold.

    Gold contexts are taken in the order given by ``gold_indices`` (at most k
    of them). If there are fewer than k, the remaining slots are filled with
    the earliest contexts whose index is not in ``gold_indices``. Fewer than
    k items come back only when ``ctxs`` is too short to fill k slots.
    """
    selected = [ctxs[j] for j in gold_indices[:k]]
    if len(selected) < k:
        gold_set = set(gold_indices)  # O(1) membership instead of a list scan
        for j, c in enumerate(ctxs):
            if len(selected) >= k:
                break
            if j not in gold_set:
                selected.append(c)
    return selected


print('creating new data...')
new_data = []
HALF = 50
for nq_inst, sit_ctxs, sit_gold in tqdm(zip(nq_in_sit, adv_ctxs, adv_gold_indices),
                                        total=len(nq_in_sit)):
    # first half: SituatedQA contexts
    ctxs = _select_ctxs(sit_ctxs, sit_gold, HALF)
    if len(ctxs) != HALF:
        raise ValueError('only %d SituatedQA contexts available for id %s (need %d)'
                         % (len(ctxs), nq_inst['id'], HALF))

    # second half: NQ contexts
    ctxs += _select_ctxs(nq_inst['ctxs'], nq_inst['gold_indices'], HALF)
    if len(ctxs) != 2 * HALF:
        raise ValueError('only %d NQ contexts available for id %s (need %d)'
                         % (len(ctxs) - HALF, nq_inst['id'], HALF))

    new_data.append({'question': nq_inst['question'],
                     'answers': nq_inst['answers'],
                     'sit_answers': nq_inst['sit_answers'],
                     'id': nq_inst['id'],
                     'ctxs': ctxs})


assert len(new_data) == len(nq_in_sit)


## write results ##
# The counterfactual dataset, pretty-printed as one JSON array.
with open('counter_situated_data.json', 'w') as fw:
    fw.write(json.dumps(new_data, indent=4))

# Indices (into the NQ dev list) of the instances kept above, one per line.
with open('nq_dev_in_sit_inds.txt', 'w') as fw:
    fw.writelines(str(l) + '\n' for l in nq_in_sit_inds)