import os
import ujson
import torch
import random

from collections import defaultdict, OrderedDict

from pruner.parameters import DEVICE
from pruner.modeling.colbert import ColBERTPruner
from pruner.utils.utils import print_message, load_checkpoint
from pruner.evaluation.load_model import load_model
from pruner.utils.runs import Run

def load_pruner(args, do_print=True):
    """Load a pruner and its checkpoint, warning on argument mismatches.

    Loading itself is delegated to ``load_model(args, do_print)``. When the
    returned checkpoint contains an ``'arguments'`` entry, every key in
    ``('doc_maxlen', 'dim', 'amp')`` that exists both as an attribute of
    ``args`` and as a key of the saved arguments is compared; each differing
    value triggers a ``Run.warn``. The saved arguments are then printed as
    indented JSON when ``args.rank < 1`` -- note this print is NOT gated by
    ``do_print``.

    Args:
        args: Runtime configuration object. It must expose ``rank`` whenever
            the checkpoint carries saved arguments.
        do_print: Forwarded to ``load_model``; also controls whether a
            trailing blank separator is printed.

    Returns:
        The ``(pruner, checkpoint)`` pair exactly as produced by ``load_model``.
    """
    pruner, checkpoint = load_model(args, do_print)

    # TODO: If the parameters below were not specified on the command line, their *checkpoint* values should be used.
    # I.e., not their purely (i.e., training) default values.
    compared_keys = ('doc_maxlen', 'dim', 'amp')  #!@ custom

    if 'arguments' in checkpoint:
        saved_args = checkpoint['arguments']

        for key in compared_keys:
            # Only compare keys known on both sides; missing ones are silently skipped.
            if not hasattr(args, key) or key not in saved_args:
                continue

            saved_value, current_value = saved_args[key], getattr(args, key)
            if saved_value != current_value:
                Run.warn(f"Got checkpoint['arguments']['{key}'] != args.{key} (i.e., {saved_value} != {current_value})")

        # Only ranks below 1 echo the saved config (presumably the main /
        # non-distributed process -- confirm against how `rank` is assigned).
        if args.rank < 1:
            print(ujson.dumps(saved_args, indent=4))

    if do_print:
        print('\n')

    return pruner, checkpoint
