import argparse
from typing import DefaultDict
from tqdm import tqdm
import numpy as np
import ujson
import pandas as pd

from collections import defaultdict, OrderedDict

def load_qrels(qrels_path):
    """Load a TREC-style qrels file into a pyterrier-compatible DataFrame.

    Each non-blank line must contain four whitespace-separated fields,
    ``qid <ignored> pid rel``, where ``qid``, ``pid`` and ``rel`` parse as
    integers. Blank lines are skipped.

    Args:
        qrels_path: Path to the qrels file, or ``None``.

    Returns:
        ``None`` when ``qrels_path`` is ``None``. Otherwise a DataFrame with
        string columns ``qid`` and ``docno`` and an int64 ``label`` column,
        one row per non-blank input line, in file order.

    Raises:
        ValueError: If a non-blank line does not have exactly four fields,
            or if a used field is not an integer.
    """
    if qrels_path is None:
        return None

    print("#> Loading qrels from", qrels_path, "...")

    qid_list = []
    pid_list = []
    rel_list = []
    with open(qrels_path, mode='r', encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a stray empty line near EOF)
                # instead of failing on the tuple unpack below.
                continue
            qid, _, pid, rel = fields
            # Round-tripping ids through int normalises e.g. "007" -> "7".
            # NOTE(review): load_ranking keeps ids verbatim, so ids with
            # leading zeros would not match between qrels and ranking.
            qid, pid, rel = map(int, (qid, pid, rel))
            qid_list.append(str(qid))
            pid_list.append(str(pid))
            rel_list.append(rel)
    qrels = pd.DataFrame({
        'qid': qid_list,
        'docno': pid_list,
        'label': np.array(rel_list, dtype=np.int64),
    })
    return qrels

def load_ranking(path, qrels_exclude):
    """Load a ranking file into a pyterrier-compatible results DataFrame.

    The format is selected by file extension:

    * ``.jsonl`` -- each line is ``qid<TAB>json-list-of-pids``; list order is
      the ranking. Ranks restart at 1 on every line.
    * ``.tsv`` -- each line is ``qid<TAB>pid<TAB>rank<TAB>score``. The rank and
      score columns are ignored; line order within each qid is the ranking,
      and ranks keep counting across all lines of the same qid.

    Pids listed in ``qrels_exclude[qid]`` are dropped, and the remaining pids
    are re-ranked contiguously from 1. Scores are synthesised as
    ``1000 - rank``, so a better rank gets a higher score. Blank lines are
    skipped.

    Args:
        path: Path to the ranking file (``.jsonl`` or ``.tsv``).
        qrels_exclude: Mapping from qid (str) to a collection of pids (str)
            to drop. Queries absent from the mapping exclude nothing, and the
            mapping is not modified.

    Returns:
        DataFrame with string ``qid``/``docno`` columns, an int64 ``rank``
        column and a float64 ``score`` column, in file order.

    Raises:
        ValueError: If ``path`` has neither a ``.jsonl`` nor a ``.tsv`` suffix.
    """
    print("#> Loading ranking from", path, "...")

    if path.endswith('.jsonl'):
        is_jsonl = True
    elif path.endswith('.tsv'):
        is_jsonl = False
    else:
        # Previously an unknown suffix silently produced an empty ranking,
        # which made every metric look like zero retrieval.
        raise ValueError(
            f"Unsupported ranking format (expected .jsonl or .tsv): {path}")

    qid_list = []
    pid_list = []
    rank_list = []
    score_list = []

    def _add(qid, pid, rank):
        # Append one retained (qid, pid) row with its re-computed rank/score.
        qid_list.append(qid)
        pid_list.append(pid)
        rank_list.append(rank)
        score_list.append(1000 - float(rank))

    with open(path, mode='r', encoding="utf-8") as f:
        if is_jsonl:
            for line in f:
                if not line.strip():
                    continue
                qid, pids = line.strip().split('\t')
                # .get avoids inserting keys into a caller's defaultdict and
                # avoids KeyError when a plain dict is passed.
                excluded = qrels_exclude.get(qid, ())
                rank = 1
                for pid in ujson.loads(pids):
                    pid = str(pid)
                    if pid not in excluded:
                        _add(qid, pid, rank)
                        rank += 1
        else:
            # Running rank per qid, counting only retained pids.
            qid_rank = defaultdict(int)
            for line in f:
                if not line.strip():
                    continue
                qid, pid, _rank, _score = line.strip().split('\t')
                if pid not in qrels_exclude.get(qid, ()):
                    qid_rank[qid] += 1
                    _add(qid, pid, qid_rank[qid])

    ranking = pd.DataFrame({
        'qid': qid_list,
        'docno': pid_list,
        'rank': np.array(rank_list, dtype=np.int64),
        'score': np.array(score_list, dtype=np.float64),
    })
    return ranking



import pyterrier as pt

if __name__ == '__main__':
    # Evaluate a ranking file against qrels with pyterrier. Optionally, first
    # drop (qid, pid) pairs listed in a second, qrels-formatted exclude file.
    parser = argparse.ArgumentParser()
    # --qrels and --ranking are mandatory: without them load_qrels returns
    # None (len() then fails) and load_ranking cannot open the file.
    parser.add_argument('--qrels', dest='qrels', required=True,
                        help='qrels file with "qid _ pid rel" per line.')
    parser.add_argument('--qrels_exclude', type=str,
                        help='Optional qrels-format file of (qid, pid) pairs '
                             'to drop from the ranking before evaluation.')
    parser.add_argument('--ranking', dest='ranking', required=True,
                        help='Ranking file (.jsonl or .tsv).')

    args = parser.parse_args()

    qrels = load_qrels(args.qrels)
    print(f'#> The # of samples in qrels = {len(qrels)}')
    print(qrels.head())

    # qid -> set of pids to exclude. Ids are kept as raw strings so that they
    # compare equal to the raw ids read by load_ranking.
    qrels_exclude = defaultdict(set)
    if args.qrels_exclude:
        with open(args.qrels_exclude, mode='r', encoding="utf-8") as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines
                qid, _, pid, _rel = fields
                qrels_exclude[qid].add(pid)

    ranking = load_ranking(args.ranking, qrels_exclude)

    print('\n\n')

    if not pt.started():
        pt.init()

    # RR: [Mean] Reciprocal Rank ([M]RR)
    # nDCG: normalized Discounted Cumulative Gain (nDCG)
    # AP: [Mean] Average Precision ([M]AP)
    # R / P: Recall@k / Precision@k; NumRet: number of retrieved documents
    from pyterrier.measures import RR, nDCG, AP, NumRet, R, P

    # Evaluation protocol, quoted from ColBERT-PRF:
    #   "we report mean reciprocal rank (MRR) and normalised discounted
    #   cumulative gain (NDCG) calculated at rank 10, as well as Recall and
    #   Mean Average Precision (MAP) at rank 1000 [8]. For the MRR, MAP and
    #   Recall metrics, we treat passages with label grade 1 as
    #   non-relevant, following [7, 8]."
    # Hence rel=2 on the R/P/AP/RR measures below.
    # Measure definitions: https://github.com/terrierteam/ir_measures/tree/f6b5dc62fd80f9e4ca5678e7fc82f6e8173a800d/ir_measures/measures
    results = pt.Utils.evaluate(
        ranking, qrels,
        metrics=[
            nDCG@10, nDCG@25, nDCG@50, nDCG@100, nDCG@200, nDCG@500, nDCG@1000,
            R(rel=2)@3, R(rel=2)@5, R(rel=2)@10, R(rel=2)@25, R(rel=2)@50, R(rel=2)@100, R(rel=2)@200, R(rel=2)@1000,
            P(rel=2)@1, P(rel=2)@3, P(rel=2)@5, P(rel=2)@10, P(rel=2)@25, P(rel=2)@50, P(rel=2)@100, P(rel=2)@200, P(rel=2)@1000,
            AP(rel=2)@1000, RR(rel=2)@10,
            NumRet, "num_q",
        ],
    )
    print(results)
