

import os
import sys
import argparse
import torchvision
import torch
from time import time
from datetime import datetime
# import mytorch_utils

def model_summary(model):
    """Print a table of trainable-parameter counts for ``model``'s children.

    For each direct child of ``model`` (as yielded by ``model.children()``),
    prints the child's ``str()`` followed by the number of elements in the
    trainable parameters (``requires_grad=True``) that child owns, including
    those of its nested submodules. Finishes with the total number of
    trainable parameter elements in the whole model.

    Args:
        model: a ``torch.nn.Module`` (anything exposing ``children()`` and
            ``parameters()``).

    Returns:
        None. All output goes to stdout.
    """
    print("model_summary")
    print()
    print("Layer_name"+"\t"*7+"Number of Parameters")
    print("="*100)
    print("\t"*10)
    for child in model.children():
        print()
        # Ask the child for its own parameters instead of walking a shared
        # index into model.parameters(): activations/pooling own no tensors,
        # containers (e.g. nn.Sequential) own many, and a layer with a bias
        # owns weight + bias, so a fixed 1-or-2 step cannot stay aligned.
        param = sum(p.numel() for p in child.parameters() if p.requires_grad)
        print(str(child)+"\t"*3+str(param))
    # Total over the whole model: model.parameters() de-duplicates tensors
    # shared between children and includes parameters registered directly on
    # the root module, so this is the true trainable count.
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("="*100)
    print(f"Total Params:{total_params}")

if __name__ == "__main__":
    # Wall-clock start, truncated to whole seconds for a compact elapsed print.
    start_time = datetime.now().replace(microsecond=0)
    parser = argparse.ArgumentParser(description='ADAM Finetuning')
    parser.add_argument('--model', default='model_path', type=str, metavar='PATH',
                        help='path to model')
    parser.add_argument('--arch', '-a', default='resnet18', type=str, metavar='ARCH',
                        help='model architecture [resnet18, resnet50, resnet101, alexnet, vgg, squeezenet, densenet, inception] (default: resnet18)')

    # Assigned at module scope, so `args` is already a module global; the
    # former `global args` statement here was a no-op.
    args = parser.parse_args()
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    # Only consumed by the (currently disabled) model-rebuilding path below.
    model_name = args.arch

    best_state_path = args.model
    # SECURITY NOTE(review): torch.load unpickles the file and can execute
    # arbitrary code -- only load checkpoints from trusted sources.
    # NOTE(review): on newer torch releases torch.load defaults to
    # weights_only=True, which refuses to load a pickled nn.Module; confirm
    # against the installed version if loading fails.
    checkpoint = torch.load(best_state_path, map_location=device)
    # TODO: when `checkpoint` is a dict holding a 'state_dict', rebuild the
    # model and load the weights instead, e.g.:
    # model_ft, img_resize, input_size = mytorch_utils.initialize_model(model_name, num_classes, feature_extract=False,use_pretrained=False)
    # model_ft.load_state_dict(checkpoint['state_dict'])
    # model_ft.to(device)
    # model_ft.eval()

    # model_summary needs a full module (children()/parameters()); a plain
    # state_dict or checkpoint dict would otherwise fail with an opaque
    # AttributeError deep inside the summary.
    if not isinstance(checkpoint, torch.nn.Module):
        sys.exit(
            f"error: {best_state_path!r} does not contain a pickled torch.nn.Module "
            f"(got {type(checkpoint).__name__}); model_summary needs a full model, "
            f"not a state_dict/checkpoint dict."
        )
    model_summary(checkpoint)

    end_time = datetime.now().replace(microsecond=0)
    print ("time taken:")
    print(end_time - start_time)

