import os
import random
import ujson
from collections import OrderedDict

from pruner.utils.parser import Arguments
from pruner.utils.runs import Run

from pruner.evaluation.loaders import load_pruner
from pruner.training.batcher import load_data

from pruner.evaluation.ranking import evaluate_ranking

from colbert.evaluation.loaders import load_queries, load_qrels, load_collection

from colbert.evaluation.load_model import load_model
def load_colbert(args, do_print=True):
    """Load a ColBERT model and cross-check its saved training arguments.

    Delegates the actual loading to ``load_model(args, do_print)``. If the
    returned checkpoint carries an ``'arguments'`` mapping, each watched
    setting that exists both on ``args`` and in that mapping, but with a
    different value, is reported through ``Run.warn``. The saved arguments are
    then dumped as indented JSON when ``args.rank < 1``. When ``do_print`` is
    true, a blank separator is printed before returning.

    Args:
        args: Parsed command-line namespace passed through to ``load_model``.
        do_print: Forwarded to ``load_model``; also controls the trailing
            separator print.

    Returns:
        The ``(colbert, checkpoint)`` pair produced by ``load_model``.
    """
    # Settings compared against the checkpoint. The first five mirror the
    # original ColBERT list; the last three were added for this project.
    watched = ('query_maxlen', 'doc_maxlen', 'dim', 'similarity', 'amp',
               'pruner_filepath', 'pseudo_query_indicator', 'pruned_index_size')

    colbert, checkpoint = load_model(args, do_print)

    # TODO: If a parameter below was not given on the command line, its
    # *checkpoint* value should be used, not its purely (training) default.
    if 'arguments' in checkpoint:
        saved = checkpoint['arguments']

        for name in watched:
            if not hasattr(args, name) or name not in saved:
                continue
            saved_value, current_value = saved[name], getattr(args, name)
            if saved_value != current_value:
                Run.warn(f"Got checkpoint['arguments']['{name}'] != args.{name} (i.e., {saved_value} != {current_value})")

        if args.rank < 1:
            print(ujson.dumps(saved, indent=4))

    if do_print:
        print('\n')

    return colbert, checkpoint


def load_ranking(path):
    """Read a tab-separated ranking file into per-query score dictionaries.

    Each non-blank line must hold exactly four tab-separated fields,
    ``qid  pid  rank  score``. ``qid``, ``pid`` and ``rank`` must parse as
    ints and ``score`` as a float. The rank column is validated but not
    stored. Blank lines are skipped.

    Args:
        path: Path to the ranking TSV file.

    Returns:
        OrderedDict mapping qid (int) -> dict mapping pid (int) -> score
        (float). Queries appear in first-seen file order, and pids keep file
        order within each query. If a (qid, pid) pair repeats, the last
        score wins.

    Raises:
        ValueError: If a non-blank line is malformed. The message names the
            file and the 1-based line number.
    """
    print("#> Loading ranking from", path, "...")

    qid_to_topk = OrderedDict()

    with open(path, mode='r', encoding="utf-8") as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                # Tolerate blank lines instead of failing on the 4-field unpack.
                continue

            try:
                qid, pid, rank, score = line.split('\t')
                # rank is checked to be an int but not kept.
                qid, pid, _ = map(int, (qid, pid, rank))
                score = float(score)
            except ValueError as err:
                raise ValueError(
                    f"Malformed ranking line {line_no} in {path!r}: {line!r}"
                ) from err

            qid_to_topk.setdefault(qid, {})[pid] = score

    return qid_to_topk

def main():
    """Command-line entry point: evaluate a pruner against an existing ranking.

    Parses the command line, then, inside ``Run.context()``, loads the ranking,
    queries, qrels, collection, pruner and ColBERT model. Each result is stored
    back onto ``args``, and everything is handed to ``evaluate_ranking(args)``.
    """
    # Only Python's built-in `random` module is seeded in this function.
    random.seed(12345)

    parser = Arguments(description='Evaluation on pruner.')

    # Model / inference options from the project's Arguments parser
    # (presumably the pruner's own settings -- TODO confirm in pruner.utils.parser).
    parser.add_model_parameters()
    parser.add_model_inference_parameters()

    # Input files. NOTE(review): none of these are marked required and
    # --ranking defaults to None, yet all are passed straight to the loaders
    # below; a missing --ranking makes open(None) raise TypeError inside
    # load_ranking.
    parser.add_argument('--ranking', dest='ranking', default=None, help="ranking.tsv from ColBERT")
    parser.add_argument('--queries')
    parser.add_argument('--collection')
    parser.add_argument('--qrels')

    # ColBERT model options plus the path of the ColBERT checkpoint to load.
    parser.add_colbert_model_parameters()
    parser.add_argument('--colbert_checkpoint',)

    args = parser.parse()

    with Run.context():
        
        args.ranking = load_ranking(args.ranking)
        # OrderedDict[ qid (int) -> Dict [ pid (int) -> score (float) ] ]
        
        args.queries = load_queries(args.queries)
        args.qrels = load_qrels(args.qrels)
        args.collection = load_collection(args.collection)

        args.pruner, args.checkpoint = load_pruner(args)
        
        # Repoint args.checkpoint at the ColBERT checkpoint path before loading
        # ColBERT (presumably load_model reads args.checkpoint -- TODO confirm).
        # NOTE(review): this discards the pruner checkpoint returned by
        # load_pruner just above; store it under another attribute if
        # evaluate_ranking ever needs it. The statement order here matters.
        args.checkpoint = args.colbert_checkpoint
        # From here on args.checkpoint holds the checkpoint returned by
        # load_colbert (a mapping it inspects for 'arguments'), not a path.
        args.colbert, args.checkpoint = load_colbert(args)
        
        evaluate_ranking(args)


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
