import os
import random
import ujson

from colbert.evaluation.loaders import load_model, load_topK, load_qrels
from colbert.evaluation.loaders import load_queries, load_topK_pids #!@ custom

from colbert.training.lazy_batcher import load_collection #!@ custom

from pruner.utils.parser import Arguments
from pruner.utils.runs import Run
from pruner.evaluation.loaders import load_pruner
from pruner.evaluation.reranking import evaluate

# Re-define to prevent: AttributeError: '_RunManager' object has no attribute 'warn'
def load_colbert(args, do_print=True):
    """Load the ColBERT model and its checkpoint, reporting argument mismatches.

    The model and checkpoint come from `load_model(args, do_print)`. When the
    checkpoint carries an 'arguments' entry, each tracked setting that is
    present both on `args` and in that entry is compared, and every
    difference is reported through `Run.warn`. The stored arguments are then
    printed as JSON when `args.rank < 1`.

    Args:
        args: Parsed argument object; forwarded to `load_model` and inspected
            for the tracked settings (and `rank`, if the checkpoint has
            stored arguments).
        do_print: Forwarded to `load_model`; when truthy, a blank separator
            line is also printed at the end.

    Returns:
        The `(colbert, checkpoint)` pair returned by `load_model`.
    """
    colbert, checkpoint = load_model(args, do_print)

    # The first five keys are the original list; the last three are custom additions. #!@ custom
    tracked_keys = (
        'query_maxlen', 'doc_maxlen', 'dim', 'similarity', 'amp',
        'pruner_filepath', 'pseudo_query_indicator', 'pruned_index_size',
    )

    if 'arguments' in checkpoint:
        saved_args = checkpoint['arguments']

        for k in tracked_keys:
            # Only settings known to both sides can be compared.
            if not hasattr(args, k) or k not in saved_args:
                continue

            a, b = saved_args[k], getattr(args, k)
            if a != b:
                Run.warn(f"Got checkpoint['arguments']['{k}'] != args.{k} (i.e., {a} != {b})")

        # Dump the stored arguments once, on the process whose rank is below 1.
        if args.rank < 1:
            print(ujson.dumps(saved_args, indent=4))

    if do_print:
        print('\n')

    return colbert, checkpoint

def main():
    """Parse the command line, load models and data, and run the re-ranking evaluation.

    Steps, in order: seed `random`, build and parse the argument set, then
    (inside `Run.context()`) load the pruner, load ColBERT, load qrels, load
    queries and candidates, and finally call `evaluate(args)`. All loaded
    objects are attached to `args`.

    Raises:
        AssertionError: If only one of `--collection` / `--queries` is given,
            or if short-circuiting is requested without qrels.
    """
    random.seed(12345)

    cli = Arguments(description='Exhaustive (slow, not index-based) evaluation of re-ranking with ColBERT.')

    # Option groups provided by the Arguments helper.
    cli.add_model_parameters()
    cli.add_model_inference_parameters()
    cli.add_reranking_input()
    cli.add_colbert_model_parameters()

    # Options declared by this script.
    cli.add_argument('--colbert_checkpoint')
    cli.add_argument('--pruned_index_size', type=int, default=24)
    cli.add_argument('--depth', dest='depth', required=False, default=None, type=int)

    args = cli.parse()

    with Run.context():
        args.pruner, args.checkpoint = load_pruner(args)

        # args.checkpoint is replaced by the --colbert_checkpoint value before
        # load_colbert runs, and then by the checkpoint that load_colbert returns.
        # NOTE(review): presumably load_model reads args.checkpoint -- confirm.
        args.checkpoint = args.colbert_checkpoint
        args.colbert, args.checkpoint = load_colbert(args)

        args.qrels = load_qrels(args.qrels)

        if not (args.collection or args.queries):
            # Neither given: queries, documents and pids all come from the topK input.
            args.queries, args.topK_docs, args.topK_pids = load_topK(args.topK)

        else:
            # Collection and queries must be supplied together.
            assert args.collection and args.queries

            args.queries = load_queries(args.queries)
            args.collection = load_collection(args.collection)
            args.topK_pids, args.qrels = load_topK_pids(args.topK, args.qrels)

        assert not (args.shortcircuit and not args.qrels), \
            "Short-circuiting (i.e., applying minimal computation to queries with no positives in the re-ranked set) " \
            "can only be applied if qrels is provided."

        evaluate(args)


# Run the evaluation only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
