import json
from auto_eval import compute_exact
def main(data_file, pred_file, mod_file):
    """Count questions that the modified run "fixes" by switching passages.

    Loads three JSON files and, for every record in *pred_file*, finds the
    record in *mod_file* with the same ``question``. A question is counted
    when all of the following hold:

    * the context the original prediction selected
      (``ctxs[passage_idx]`` in *data_file*) has a falsy ``has_answer``;
    * the context the modified prediction selected has a truthy
      ``has_answer``;
    * the original prediction is not correct, while the modified one is
      (as judged by ``is_correct``).

    Among the counted questions, it also counts those where both selected
    contexts share the same ``title``. Per-question details and the two
    totals are printed to stdout.

    Args:
        data_file: Path to the retriever-results JSON (records with ``ctxs``).
        pred_file: Path to the original predictions JSON.
        mod_file: Path to the modified predictions JSON.

    Raises:
        ValueError: If *pred_file* and *mod_file* hold different numbers of
            records.
        KeyError: If a question in *pred_file* has no match in *mod_file*.
    """
    with open(data_file, encoding='utf-8') as f:
        data = json.load(f)
    with open(pred_file, encoding='utf-8') as f:
        pred_data = json.load(f)
    with open(mod_file, encoding='utf-8') as f:
        mod_data = json.load(f)

    if len(pred_data) != len(mod_data):
        raise ValueError(
            f'{pred_file} has {len(pred_data)} records but '
            f'{mod_file} has {len(mod_data)}'
        )

    # Build the question -> record index once instead of rescanning mod_data
    # for every prediction. A dict comprehension keeps the LAST record for a
    # duplicated question, matching the original scan (which had no break).
    mod_by_question = {inst['question']: inst for inst in mod_data}

    total_pred_incorr_mod_corr = 0
    same_title = 0
    for i, pred_inst in enumerate(pred_data):
        question = pred_inst['question']
        gold_answers = pred_inst['gold_answers']
        try:
            inst_mod = mod_by_question[question]
        except KeyError:
            raise KeyError(
                f'question from {pred_file} not found in {mod_file}: {question!r}'
            ) from None

        # NOTE(review): assumes data_file records are aligned with pred_file
        # by position (data[i] <-> pred_data[i]) -- TODO confirm.
        ctxs = data[i]['ctxs']
        pred_ret_id = int(pred_inst['predictions'][-1]['prediction']['passage_idx'])
        mod_ret_id = int(inst_mod['predictions'][-1]['prediction']['passage_idx'])

        # Keep only cases where the originally selected context lacks the
        # answer but the modified selection has it. This also guarantees
        # pred_ret_id != mod_ret_id.
        if ctxs[pred_ret_id]['has_answer'] or not ctxs[mod_ret_id]['has_answer']:
            continue
        # ...and where the original prediction is wrong but the modified one
        # is right.
        if is_correct(pred_inst, gold_answers) or not is_correct(inst_mod, gold_answers):
            continue

        print('gggggg', i)  # marker line for each counted question
        total_pred_incorr_mod_corr += 1
        pred_title = ctxs[pred_ret_id]['title']
        mod_title = ctxs[mod_ret_id]['title']
        print(pred_title, '\t', mod_title)
        if pred_title == mod_title:
            same_title += 1
            print("same_title", i)

    print('='*50)
    print('total_pred_incorr_mod_corr', total_pred_incorr_mod_corr)
    print('same_title', same_title)

def is_correct(output, gold_answers):
    """Tell whether the last prediction in *output* matches a gold answer.

    The text at ``output['predictions'][-1]['prediction']['text']`` is
    checked against each entry of *gold_answers* with ``compute_exact``. The
    result is True as soon as one check is truthy. If *gold_answers* is empty
    the result is False, and *output* is never read.
    """
    return any(
        compute_exact(gold, output['predictions'][-1]['prediction']['text'])
        for gold in gold_answers
    )
if __name__ == '__main__':
    # Report questions where the DPR prediction is incorrect but becomes
    # correct after the manipulation. Only questions are considered where the
    # passage the original prediction selected (its passage_idx) does not
    # contain the answer, while the modified selection does.
    # Input paths default to the previously hard-coded files and can be
    # overridden on the command line.
    import argparse

    parser = argparse.ArgumentParser(
        description='Compare original and modified DPR predictions.'
    )
    parser.add_argument(
        '--data-file',
        default='DPR/selected_dev_data/retriever_results/dev_top20_not5.json',
        help='retriever results JSON (records with "ctxs")',
    )
    parser.add_argument(
        '--pred-file',
        default='DPR/dpr/pred_dir/pred_top20_not5.json',
        help='original predictions JSON',
    )
    parser.add_argument(
        '--mod-file',
        default='DPR/dpr/pred_dir/mod_top20_not5.json',
        help='modified predictions JSON',
    )
    args = parser.parse_args()

    main(args.data_file, args.pred_file, args.mod_file)