import argparse
import os

from colbert.evaluation.loaders import load_qrels

def collect_relevant_pids(qrels):
    """Return the set of all passage ids that appear in *qrels*.

    Args:
        qrels: Mapping from a query id to an iterable of passage ids.
            The query ids themselves are not used.

    Returns:
        A set holding the union of every passage-id iterable in the mapping.
    """
    pids = set()
    for pid_list in qrels.values():
        pids.update(pid_list)
    return pids


def copy_annotated_lines(infile, outfile, pids):
    """Copy the score lines of *infile* whose passage id is in *pids*.

    Each non-blank input line must consist of exactly three tab-separated
    fields: pid, tokens, scores. Matching lines are rewritten to *outfile*
    in the same three-field layout, terminated by a newline.

    Args:
        infile: Iterable of text lines (e.g. an open text file).
        outfile: Object with a ``write(str)`` method.
        pids: Container of passage ids; membership is tested with
            ``int(pid)``, so the ids are assumed to be integers.

    Returns:
        The number of lines written to *outfile*.

    Raises:
        ValueError: If a non-blank line does not have exactly three
            tab-separated fields, or its pid field is not an integer.
    """
    n_written = 0
    for line_idx, line in enumerate(infile):
        stripped = line.strip()
        if not stripped:
            # Tolerate blank lines (typically a trailing one at end of file).
            continue
        fields = stripped.split('\t')
        if len(fields) != 3:
            raise ValueError(
                f'line {line_idx + 1}: expected 3 tab-separated fields '
                f'(pid, tokens, scores), got {len(fields)}'
            )
        pid, tokens, scores = fields
        if int(pid) in pids:
            outfile.write(f'{pid}\t{tokens}\t{scores}\n')
            n_written += 1
    return n_written


if __name__ == '__main__':

    parser = argparse.ArgumentParser(
        description='Keep only the pruner-score lines whose passage id is annotated in the qrels.'
    )
    # Both paths are mandatory: without them the script cannot do anything useful.
    parser.add_argument('--qrels', required=True,
                        help='path of the qrels file handed to load_qrels')
    parser.add_argument('--pruner_score_path', required=True,
                        help='path of the TSV file with "[pid]\\t[tokens]\\t[scores]" lines')

    args = parser.parse_args()

    qrels = load_qrels(args.qrels)
    pids = collect_relevant_pids(qrels)

    print(f'# of unique passages = {len(pids)}, e.g., ', list(pids)[:5], '\n\n')

    # The result is written next to the input file under a fixed name.
    outfile_path = os.path.join(os.path.dirname(args.pruner_score_path), 'collection.annotated.pruner_score.tsv')

    # Opening the output with 'w' truncates it; if it is the very file we are
    # about to read, the input would be destroyed before a single line is seen.
    if os.path.exists(outfile_path) and os.path.samefile(outfile_path, args.pruner_score_path):
        parser.error(f'--pruner_score_path must not be the output file itself: "{outfile_path}"')

    print(f'#> \"[pid]\\t[tokens]\\t[scores]\" \non annotated relevant passages in {args.qrels}\nwill be saved into \"{outfile_path}\"')

    with open(args.pruner_score_path, 'r', encoding='utf-8') as infile, \
            open(outfile_path, 'w', encoding='utf-8') as outfile:
        copy_annotated_lines(infile, outfile, pids)

    print('#> Done!')