import os
import random
import time
import numpy as np
import torch
import torch.nn as nn

from itertools import accumulate
from math import ceil

from colbert.evaluation.ranking_logger import RankingLogger

from colbert.modeling.inference import ModelInference as ColbertModelInference

from pruner.utils.runs import Run
from pruner.utils.utils import print_message
from pruner.modeling.inference import ModelInference as PrunerModelInference
from pruner.evaluation.metrics import RerankingMetrics



def _rerank_with_pruned_tokens(colbert_inference, pruner_inference, query, pids, passages,
                               pruned_index_size):
    """Score and rank ``passages`` for ``query`` using only pruner-selected document tokens.

    For every passage, the ``pruned_index_size`` token positions with the highest
    pruner score are kept. The passage score is the sum, over the first axis of
    its similarity matrix, of the maximum similarity among the kept positions.

    Args:
        colbert_inference: object providing ``queryFromText`` and ``docFromTextForPruner``.
        pruner_inference: object providing ``scoreFromText``.
        query: the query text.
        pids: passage ids, positionally aligned with ``passages``.
        passages: candidate passage texts.
        pruned_index_size: maximum number of token positions kept per passage.

    Returns:
        A list of ``(score, pid, passage)`` tuples sorted by descending score.

    Raises:
        AssertionError: if the ranked pids contain duplicates.
    """
    # ---
    # Compute similarities between query x document tokens
    # ---
    Q = colbert_inference.queryFromText([query])
    # The second returned value (a mask, judging by its original name) is not needed here.
    D_, _ = colbert_inference.docFromTextForPruner([(psg, []) for psg in passages])
    # Assumes QD_sims is (n_passages, query_tokens, doc_tokens) -- TODO confirm against
    # the inference classes; the code only fixes D_ as rank-3.
    QD_sims = (Q @ D_.permute(0, 2, 1)).cpu()

    # ---
    # Aggregate scores for selected document tokens
    # ---
    # pruner_scores: one sequence of per-token scores for each passage.
    pruner_scores, _ = pruner_inference.scoreFromText(docs=passages)

    # Highest-scoring positions first, truncated to the pruned index size.
    selected_positions = [
        np.argsort(token_scores)[::-1][:pruned_index_size]
        for token_scores in pruner_scores
    ]

    # `[::-1]` yields a reversed NumPy view; `.copy()` materialises it before it is used
    # as a tensor index (presumably because torch rejects negative strides).
    scores = torch.tensor([
        sims[:, positions.copy()].max(1).values.sum(0).item()
        for sims, positions in zip(QD_sims, selected_positions)
    ])

    # ---
    # Rank by scores
    # ---
    order = scores.sort(descending=True)
    ranked = order.indices.tolist()

    ranked_scores = order.values.tolist()
    ranked_pids = [pids[position] for position in ranked]
    ranked_passages = [passages[position] for position in ranked]
    assert len(ranked_pids) == len(set(ranked_pids))

    return list(zip(ranked_scores, ranked_pids, ranked_passages))


def evaluate(args):
    """Re-rank the top-K candidates of every query with pruned document tokens.

    Reads from ``args``: ``pruner``, ``colbert``, ``amp``, ``qrels``, ``queries``,
    ``topK_pids``, ``depth``, ``collection`` (or ``topK_docs`` when ``collection``
    is None), ``shortcircuit``, ``pruned_index_size`` and ``checkpoint``.
    Sets ``args.milliseconds`` to a fresh empty list.

    Rankings are written through a ``RankingLogger`` to ``ranking.tsv`` under
    ``Run.path``. When ``qrels`` is truthy, metrics are accumulated per query and
    final metrics are written to ``ranking.metrics`` under ``Run.path``.

    Queries are processed in shuffled order. If ``queries`` is empty, the final
    metrics step is skipped, because there is no last query index to report.

    Args:
        args: namespace-like object carrying the attributes listed above.

    Returns:
        None.
    """
    pruner_inference = PrunerModelInference(args.pruner, amp=args.amp)
    colbert_inference = ColbertModelInference(args.colbert, amp=args.amp)

    qrels, queries, topK_pids = args.qrels, args.queries, args.topK_pids

    depth = args.depth
    collection = args.collection
    if collection is None:
        # Pre-fetched passage texts are only required when no collection is given.
        topK_docs = args.topK_docs

    def qid2passages(qid):
        """Return the texts of the first ``depth`` candidate passages of ``qid``."""
        if collection is not None:
            return [collection[pid] for pid in topK_pids[qid][:depth]]
        return topK_docs[qid][:depth]

    metrics = RerankingMetrics(mrr_depths={10, 100}, recall_depths={50, 200, 1000},
                               success_depths={5, 10, 20, 50, 100, 1000},
                               total_queries=len(queries))

    ranking_logger = RankingLogger(Run.path, qrels=qrels)

    args.milliseconds = []

    with ranking_logger.context('ranking.tsv', also_save_annotations=(qrels is not None)) as rlogger:
        with torch.no_grad():
            keys = sorted(queries)
            random.shuffle(keys)

            for query_idx, qid in enumerate(keys):
                query = queries[qid]

                print_message(query_idx, qid, query, '\n')

                # Short-circuit: skip queries whose candidates contain no relevant passage.
                if qrels and args.shortcircuit and set(qrels[qid]).isdisjoint(topK_pids[qid]):
                    continue

                ranking = _rerank_with_pruned_tokens(
                    colbert_inference, pruner_inference, query,
                    topK_pids[qid], qid2passages(qid), args.pruned_index_size,
                )

                rlogger.log(qid, ranking, [0, 1])

                if qrels:
                    metrics.add(query_idx, qid, ranking, qrels[qid])

                    # Report the best-ranked relevant passage, if any.
                    for rank, (score, pid, passage) in enumerate(ranking, start=1):
                        if pid in qrels[qid]:
                            print("\n#> Found", pid, "at position", rank, "with score", score)

                            print(passage)
                            break

                    metrics.print_metrics(query_idx)
                    metrics.log(query_idx)

                print_message("#> checkpoint['batch'] =", args.checkpoint['batch'], '\n')
                print("rlogger.filename =", rlogger.filename)

                # NOTE(review): nothing in this function appends to args.milliseconds, so
                # this only prints if something else holding `args` records latencies.
                if len(args.milliseconds) > 1:
                    print('Slow-Ranking Avg Latency =', sum(args.milliseconds[1:]) / len(args.milliseconds[1:]))

                print("\n\n")

        print("\n\n")
        print("\n\n")

    print('\n\n')
    if qrels:
        if keys:
            # `continue` above still advances query_idx, so the last index covers all keys.
            assert query_idx + 1 == len(keys) == len(set(keys))
            metrics.output_final_metrics(os.path.join(Run.path, 'ranking.metrics'), query_idx, len(queries))
        else:
            # With no queries the loop never binds query_idx; avoid a NameError.
            print_message("#> No queries were evaluated; skipping final metrics.")
    print('\n\n')
