
model_utils.py
  1. """
  2. # -*- coding: utf-8 -*-
  3. -----------------------------------------------------------------------------------
  4. # Author: Nguyen Mau Dung
  5. # DoC: 2020.08.09
  6. # email: nguyenmaudung93.kstn@gmail.com
  7. -----------------------------------------------------------------------------------
  8. # Description: utils functions that use for model
  9. """
import os
import sys

import torch

src_dir = os.path.dirname(os.path.realpath(__file__))
# while not src_dir.endswith("sfa"):
#     src_dir = os.path.dirname(src_dir)
if src_dir not in sys.path:
    sys.path.append(src_dir)

from models import resnet, fpn_resnet


def create_model(configs):
    """Create a model based on the architecture name, e.g. 'resnet_18' or 'fpn_resnet_18'"""
    try:
        arch_parts = configs.arch.split('_')
        num_layers = int(arch_parts[-1])
    except (AttributeError, IndexError, ValueError):
        raise ValueError("configs.arch must end with the number of layers, e.g. 'resnet_18'")

    if 'fpn_resnet' in configs.arch:
        print('using ResNet architecture with feature pyramid')
        model = fpn_resnet.get_pose_net(num_layers=num_layers, heads=configs.heads, head_conv=configs.head_conv,
                                        imagenet_pretrained=configs.imagenet_pretrained)
    elif 'resnet' in configs.arch:
        print('using ResNet architecture')
        model = resnet.get_pose_net(num_layers=num_layers, heads=configs.heads, head_conv=configs.head_conv,
                                    imagenet_pretrained=configs.imagenet_pretrained)
    else:
        assert False, 'Undefined model backbone'

    return model


def get_num_parameters(model):
    """Count the number of trainable parameters of the model"""
    if hasattr(model, 'module'):  # unwrap DataParallel / DistributedDataParallel
        num_parameters = sum(p.numel() for p in model.module.parameters() if p.requires_grad)
    else:
        num_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)

    return num_parameters


def make_data_parallel(model, configs):
    """Wrap the model for single-GPU, DataParallel, or DistributedDataParallel execution"""
    if configs.distributed:
        # For multiprocessing distributed, the DistributedDataParallel constructor
        # should always set the single device scope, otherwise
        # DistributedDataParallel will use all available devices.
        if configs.gpu_idx is not None:
            torch.cuda.set_device(configs.gpu_idx)
            model.cuda(configs.gpu_idx)
            # When using a single GPU per process and per DistributedDataParallel,
            # we need to divide the batch size ourselves based on the total number
            # of GPUs we have
            configs.batch_size = int(configs.batch_size / configs.ngpus_per_node)
            configs.num_workers = int((configs.num_workers + configs.ngpus_per_node - 1) / configs.ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[configs.gpu_idx])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif configs.gpu_idx is not None:
        torch.cuda.set_device(configs.gpu_idx)
        model = model.cuda(configs.gpu_idx)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        model = torch.nn.DataParallel(model).cuda()

    return model
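

# ---------------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): how make_data_parallel() is
# typically driven, one process per GPU, via torch.multiprocessing.spawn. The
# worker name, init_method address, and single-node world size are assumptions;
# the configs fields match those consumed by make_data_parallel() above.
# ---------------------------------------------------------------------------------
def _example_distributed_worker(gpu_idx, configs):
    """Per-process setup sketch; assumes configs.distributed is True"""
    configs.gpu_idx = gpu_idx
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='tcp://127.0.0.1:29500',
                                         world_size=configs.ngpus_per_node,
                                         rank=gpu_idx)
    model = create_model(configs)
    model = make_data_parallel(model, configs)
    return model

# Usage sketch:
# torch.multiprocessing.spawn(_example_distributed_worker,
#                             nprocs=configs.ngpus_per_node, args=(configs,))
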
if __name__ == '__main__':
    import argparse

    from torchsummary import summary
    from easydict import EasyDict as edict

    parser = argparse.ArgumentParser(description='RTM3D Implementation')
    parser.add_argument('-a', '--arch', type=str, default='resnet_18', metavar='ARCH',
                        help='The name of the model architecture')
    parser.add_argument('--head_conv', type=int, default=-1,
                        help='conv layer channels for output head; '
                             '0 for no conv layer; '
                             '-1 for the default setting: '
                             '64 for resnets and 256 for dla.')
    configs = edict(vars(parser.parse_args()))

    if configs.head_conv == -1:  # init default head_conv
        configs.head_conv = 256 if 'dla' in configs.arch else 64
    configs.imagenet_pretrained = False  # required by create_model()
    configs.num_classes = 3
    configs.num_vertexes = 8
    configs.num_center_offset = 2
    configs.num_vertexes_offset = 2
    configs.num_dimension = 3
    configs.num_rot = 8
    configs.num_depth = 1
    configs.num_wh = 2
    configs.heads = {
        'hm_mc': configs.num_classes,
        'hm_ver': configs.num_vertexes,
        'vercoor': configs.num_vertexes * 2,
        'cenoff': configs.num_center_offset,
        'veroff': configs.num_vertexes_offset,
        'dim': configs.num_dimension,
        'rot': configs.num_rot,
        'depth': configs.num_depth,
        'wh': configs.num_wh
    }

    configs.device = torch.device('cuda:1')
    # configs.device = torch.device('cpu')

    model = create_model(configs).to(device=configs.device)
    sample_input = torch.randn((1, 3, 224, 224)).to(device=configs.device)
    # summary(model.cuda(1), (3, 224, 224))
    output = model(sample_input)
    for hm_name, hm_out in output.items():
        print('hm_name: {}, hm_out size: {}'.format(hm_name, hm_out.size()))

    print('number of parameters: {}'.format(get_num_parameters(model)))
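
    # -----------------------------------------------------------------------------
    # Illustrative sketch (not part of the original file): restoring trained
    # weights for inference with plain PyTorch; the checkpoint path is a
    # placeholder.
    # -----------------------------------------------------------------------------
    # model = create_model(configs)
    # model.load_state_dict(torch.load('path/to/checkpoint.pth', map_location='cpu'))
    # model.eval()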
