import os
import random
import torch
import copy

import pruner.utils.distributed as distributed

from pruner.utils.parser import Arguments
from pruner.utils.runs import Run

from pruner.training.training import train


def main():
    """Parse command-line arguments, validate them, and launch training.

    Builds an ``Arguments`` parser with the model, model-training and
    training-input option groups plus an optional ``--checkpoint`` path,
    parses the command line, checks the argument invariants below, and then
    calls ``train(args)`` inside ``Run.context(...)``.

    Raises:
        ValueError: if ``accumsteps`` is not a positive integer, if ``bsize``
            is not divisible by ``accumsteps``, or if ``doc_maxlen`` exceeds 512.
    """
    parser = Arguments(description='Training ColBERTPruner with <passage, label> tuples.')

    parser.add_model_parameters()
    parser.add_model_training_parameters()
    parser.add_training_input()

    parser.add_argument('--checkpoint', type=str, default="")

    args = parser.parse()

    # Validate with explicit raises instead of `assert`: assertions are removed
    # when Python runs with -O, which would let invalid configurations through.
    if args.accumsteps < 1:
        # Guards the modulo below against ZeroDivisionError and rejects
        # nonsensical negative step counts.
        raise ValueError(
            f"The number of gradient accumulation steps must be positive (got accumsteps={args.accumsteps})."
        )
    if args.bsize % args.accumsteps != 0:
        raise ValueError(
            "The batch size must be divisible by the number of gradient accumulation steps "
            f"(bsize={args.bsize}, accumsteps={args.accumsteps})."
        )
    # NOTE(review): 512 is presumably the encoder's maximum input length — confirm.
    if args.doc_maxlen > 512:
        raise ValueError(f"doc_maxlen must be at most 512 (got doc_maxlen={args.doc_maxlen}).")

    with Run.context(consider_failed_if_interrupted=False):
        train(args)


if __name__ == "__main__":
    # Entry point: start training only when this file is executed as a script,
    # not when it is imported as a module.
    main()
