import os
import random
import time
import torch
import torch.nn as nn
import numpy as np

from itertools import accumulate
from math import ceil

from pruner.utils.runs import Run
from pruner.utils.utils import print_message

from pruner.evaluation.metrics import RankingMetrics
from pruner.modeling.inference import ModelInference

from colbert.modeling.inference import ModelInference as ColbertModelInference

def evaluate_ranking(args):
    """Evaluate pruner token scores against ColBERT query-passage similarities.

    For every relevant ``(qid, pid)`` pair listed in ``args.qrels`` this:

    1. scores the passage's tokens with the pruner model,
    2. encodes the query and the passage with ColBERT and computes the
       token-level similarity matrix ``Q @ D^T``,
    3. passes both, together with the other candidates retrieved for the
       query (``args.ranking[qid]`` minus ``pid``), to ``RankingMetrics.add``.

    Pairs are processed in batches of ``args.bsize``. Final metrics are
    written to ``<Run.path>/ranking.metrics``.

    Args:
        args: namespace providing ``colbert`` and ``pruner`` (model handles
            passed to the inference wrappers), ``amp``, ``qrels``
            (qid -> list of relevant pids), ``ranking``
            (qid -> dict of pid -> score), ``queries`` (indexed by qid),
            ``collection`` (indexed by pid), ``bsize`` (batch size) and
            ``checkpoint`` (a mapping with a ``'batch'`` key, only logged).
    """
    colbert_inference = ColbertModelInference(args.colbert, amp=args.amp)
    pruner_inference = ModelInference(args.pruner, amp=args.amp)

    qrels = args.qrels  # qid -> List[pid] of relevant passages
    ranking = args.ranking  # qid -> Dict[pid -> score] of retrieved candidates

    # Flatten qrels into one (qid, pid) sample per relevant passage.
    data = [(qid, pid) for qid, pids in qrels.items() for pid in pids]

    metrics = RankingMetrics(topks={10, 20, 24, 30, 40, 50, 60}, total_num_data=len(data))

    n_samples = 0
    with torch.no_grad():
        for offset in range(0, len(data), args.bsize):
            # Slicing past the end of ``data`` simply yields a shorter last batch.
            qid_list, pid_list = zip(*data[offset:offset + args.bsize])

            queries = [args.queries[qid] for qid in qid_list]
            passages = [args.collection[pid] for pid in pid_list]

            # scores: List[List[float]], one score per passage token.
            # The returned tokens are not needed for evaluation.
            scores, _tokens = pruner_inference.scoreFromText(docs=passages)

            Q = colbert_inference.queryFromText(queries)
            # (bsize, query_maxlen, dim)
            D = colbert_inference.docFromText([(passage, []) for passage in passages])
            # (bsize, doc_maxlen, dim)

            qd_sims = Q @ D.permute(0, 2, 1)
            # (bsize, query_maxlen, doc_maxlen)

            for qid, pid, qd_sim, score in zip(qid_list, pid_list, qd_sims, scores):
                # Keep only the document columns that have a pruner score so the
                # similarity matrix lines up with ``score``:
                # resulting shape is (query_maxlen, len(score)).
                sim = qd_sim[:, :len(score)].detach().cpu().numpy()

                # Candidates are the passages retrieved for this query,
                # excluding the gold passage currently being evaluated
                # (retrieval order is preserved).
                candidate_rels = [
                    (cand_pid, cand_score)
                    for cand_pid, cand_score in ranking[qid].items()
                    if cand_pid != pid
                ]

                metrics.add(
                    pid=pid,
                    qd_sim=sim,
                    score=score,
                    gold_pids=set(qrels[qid]),
                    candidate_rels=candidate_rels,
                )
                n_samples += 1

    print('\n\n')
    # zip() in the inner loop silently truncates if the pruner returns fewer
    # score lists than passages; this check catches any dropped sample.
    assert n_samples == len(data)
    print_message("#> checkpoint['batch'] =", args.checkpoint['batch'])
    metrics.output_final_metrics(os.path.join(Run.path, 'ranking.metrics'), n_samples, len(data))
    print('\n\n')
