import argparse

import torch

def _load_checkpoint(path):
    """Load the checkpoint file at *path* onto the CPU and return the unpickled object.

    Checkpoints inspected by this tool may hold arbitrary Python objects (for
    example an ``argparse.Namespace`` under ``'args'``).  torch>=2.6 defaults
    ``torch.load`` to ``weights_only=True``, which refuses to unpickle such
    objects, so full unpickling is requested explicitly whenever the installed
    torch accepts the ``weights_only`` keyword.

    SECURITY: full unpickling can execute arbitrary code -- only inspect
    checkpoint files that come from a trusted source.
    """
    import inspect

    load_kwargs = {"map_location": "cpu"}
    # Older torch releases have no ``weights_only`` keyword at all.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    return torch.load(path, **load_kwargs)


def _format_aligned(items):
    """Return ``"key = value"`` lines with keys left-aligned to a common width.

    Args:
        items: sequence of ``(key, value)`` pairs whose keys are strings.

    Returns:
        A list of formatted lines.  An empty *items* yields an empty list
        instead of letting ``max()`` fail on an empty sequence.
    """
    if not items:
        return []
    key_width = max(len(key) for key, _ in items)
    return [f'{key: <{key_width}} = {value}' for key, value in items]


def _describe_param(value):
    """Return ``list(value.size())`` for tensor-like values, else the type name."""
    size = getattr(value, "size", None)
    if callable(size):
        return list(size())
    # Non-tensor entries (if any) cannot report a shape; show what they are.
    return type(value).__name__


def main(argv=None):
    """Parse command-line options and print details of a checkpoint file.

    Args:
        argv: argument list to parse instead of ``sys.argv[1:]``; ``None``
            (the default) parses the real command line.
    """
    parser = argparse.ArgumentParser(
        description="Tool to show details (e.g., args, parameters, etc.) of a checkpoint file",
    )
    # fmt: off
    parser.add_argument('input', metavar='FILE',
                        help='Input checkpoint file path.')
    parser.add_argument('-a', '--args', action='store_true',
                        help='if set, will print the arguments.')
    parser.add_argument('-p', '--params', action='store_true',
                        help='if set, will print the parameters.')
    # fmt: on
    args = parser.parse_args(argv)
    print(args)

    checkpoint = _load_checkpoint(args.input)
    if not isinstance(checkpoint, dict):
        # e.g. a pickled module rather than a dict of entries.
        print('| This checkpoint is a {}, not a dict'.format(type(checkpoint).__name__))
        return
    print('| This checkpoint contains {}'.format(list(checkpoint.keys())))

    if args.args:
        print("\n| arguments:")
        ckpt_args = checkpoint.get('args')
        if ckpt_args is None:
            # Entry missing, or stored as None.
            print("| (no 'args' entry in this checkpoint)")
        else:
            # Accept either a plain dict or an attribute object such as a Namespace.
            pairs = ckpt_args.items() if isinstance(ckpt_args, dict) else vars(ckpt_args).items()
            for line in _format_aligned(list(pairs)):
                print(line)

    if args.params:
        print("\n| parameters:")
        model = checkpoint.get('model')
        if model is None:
            print("| (no 'model' entry in this checkpoint)")
        else:
            pairs = [(key, _describe_param(value)) for key, value in model.items()]
            for line in _format_aligned(pairs):
                print(line)


if __name__ == "__main__":
    main()
