import os
import ujson
import torch
import random

from collections import defaultdict, OrderedDict

from pruner.parameters import DEVICE
from pruner.modeling.colbert import ColBERTPruner
from pruner.utils.utils import print_message, load_checkpoint


def load_model(args, do_print=True):
    """Construct a ColBERTPruner, restore its checkpoint, and put it in eval mode.

    The model is initialised from the ``'bert-base-uncased'`` pretrained
    weights and moved to ``DEVICE`` *before* the checkpoint is loaded into it.

    Args:
        args: Configuration namespace. The attributes read here are
            ``doc_maxlen``, ``dim`` and ``mask_punctuation`` (forwarded to
            ``ColBERTPruner.from_pretrained``) and ``checkpoint`` (forwarded
            to ``load_checkpoint``).
        do_print: Controls whether the progress message is printed. It is
            also forwarded to ``load_checkpoint``.

    Returns:
        A ``(model, checkpoint)`` tuple. ``model`` is the ColBERTPruner in
        eval mode. ``checkpoint`` is whatever ``load_checkpoint`` returns
        (presumably the loaded checkpoint object; verify against its
        definition).
    """
    model = ColBERTPruner.from_pretrained(
        'bert-base-uncased',
        doc_maxlen=args.doc_maxlen,
        dim=args.dim,
        mask_punctuation=args.mask_punctuation,
    ).to(DEVICE)

    print_message("#> Loading model checkpoint.", condition=do_print)
    checkpoint = load_checkpoint(args.checkpoint, model, do_print=do_print)

    # Inference only: switch layers such as dropout to evaluation behaviour.
    model.eval()
    return model, checkpoint
