You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

train_config.py 9.6 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. """
  2. # -*- coding: utf-8 -*-
  3. -----------------------------------------------------------------------------------
  4. # Author: Nguyen Mau Dung
  5. # DoC: 2020.08.17
  6. # email: nguyenmaudung93.kstn@gmail.com
  7. -----------------------------------------------------------------------------------
  8. # Description: The configurations of the project will be defined here
  9. """
  10. import os
  11. import argparse
  12. import torch
  13. from easydict import EasyDict as edict
  14. def parse_train_configs():
  15. parser = argparse.ArgumentParser(description='The Implementation using PyTorch')
  16. parser.add_argument('--seed', type=int, default=2020,
  17. help='re-produce the results with seed random')
  18. parser.add_argument('--saved_fn', type=str, default='fpn_resnet_18', metavar='FN',
  19. help='The name using for saving logs, models,...')
  20. parser.add_argument('--root_dir', type=str, default='../', metavar='PATH',
  21. help='The ROOT working directory')
  22. ####################################################################
  23. ############## Model configs ########################
  24. ####################################################################
  25. parser.add_argument('--arch', type=str, default='fpn_resnet_18', metavar='ARCH',
  26. help='The name of the model architecture')
  27. parser.add_argument('--model_load_dir', type=str, default=None, metavar='PATH',
  28. help='the path of the pretrained checkpoint')
  29. ####################################################################
  30. ############## Dataloader and Running configs #######
  31. ####################################################################
  32. parser.add_argument('--data_url', type=str, default='../dataset/apollo/training', metavar='PATH',
  33. help='the path of the dataset')
  34. parser.add_argument('--val_data_url', type=str, default='../dataset/apollo/val', metavar='PATH',
  35. help='the path of the dataset')
  36. parser.add_argument('--train_model_out', type=str, default='../checkpoints', metavar='PATH',
  37. help='the path of the model output')
  38. parser.add_argument('--train_out', type=str, default='../logs', metavar='PATH',
  39. help='the path of the logs output')
  40. parser.add_argument('--hflip_prob', type=float, default=0.5,
  41. help='The probability of horizontal flip')
  42. parser.add_argument('--no-val', action='store_true',
  43. help='If true, dont evaluate the model on the val set')
  44. parser.add_argument('--num_samples', type=int, default=None,
  45. help='Take a subset of the dataset to run and debug')
  46. parser.add_argument('--num_workers', type=int, default=4,
  47. help='Number of threads for loading data')
  48. parser.add_argument('--batch_size', type=int, default=8,
  49. help='mini-batch size (default: 16), this is the total'
  50. 'batch size of all GPUs on the current node when using'
  51. 'Data Parallel or Distributed Data Parallel')
  52. parser.add_argument('--print_freq', type=int, default=50, metavar='N',
  53. help='print frequency (default: 50)')
  54. parser.add_argument('--tensorboard_freq', type=int, default=50, metavar='N',
  55. help='frequency of saving tensorboard (default: 50)')
  56. parser.add_argument('--checkpoint_freq', type=int, default=2, metavar='N',
  57. help='frequency of saving checkpoints (default: 5)')
  58. parser.add_argument('--gpu_num_per_node', type=int, default=1,
  59. help='Number of GPU')
  60. ####################################################################
  61. ############## Training strategy ####################
  62. ####################################################################
  63. parser.add_argument('--start_epoch', type=int, default=1, metavar='N',
  64. help='the starting epoch')
  65. parser.add_argument('--num_epochs', type=int, default=300, metavar='N',
  66. help='number of total epochs to run')
  67. parser.add_argument('--lr_type', type=str, default='cosin',
  68. help='the type of learning rate scheduler (cosin or multi_step or one_cycle)')
  69. parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
  70. help='initial learning rate')
  71. parser.add_argument('--minimum_lr', type=float, default=1e-7, metavar='MIN_LR',
  72. help='minimum learning rate during training')
  73. parser.add_argument('--momentum', type=float, default=0.949, metavar='M',
  74. help='momentum')
  75. parser.add_argument('-wd', '--weight_decay', type=float, default=0., metavar='WD',
  76. help='weight decay (default: 0.)')
  77. parser.add_argument('--optimizer_type', type=str, default='adam', metavar='OPTIMIZER',
  78. help='the type of optimizer, it can be sgd or adam')
  79. parser.add_argument('--steps', nargs='*', default=[150, 180],
  80. help='number of burn in step')
  81. ####################################################################
  82. ############## Loss weight ##########################
  83. ####################################################################
  84. ####################################################################
  85. ############## Distributed Data Parallel ############
  86. ####################################################################
  87. parser.add_argument('--world-size', default=-1, type=int, metavar='N',
  88. help='number of nodes for distributed training')
  89. parser.add_argument('--rank', default=-1, type=int, metavar='N',
  90. help='node rank for distributed training')
  91. parser.add_argument('--dist-url', default='tcp://127.0.0.1:29500', type=str,
  92. help='url used to set up distributed training')
  93. parser.add_argument('--dist-backend', default='nccl', type=str,
  94. help='distributed backend')
  95. parser.add_argument('--gpu_idx', default=0, type=int,
  96. help='GPU index to use.')
  97. parser.add_argument('--no_cuda', default= False,
  98. help='If true, cuda is not used.')
  99. parser.add_argument('--multiprocessing-distributed', action='store_true',
  100. help='Use multi-processing distributed training to launch '
  101. 'N processes per node, which has N GPUs. This is the '
  102. 'fastest way to use PyTorch for either single node or '
  103. 'multi node data parallel training')
  104. ####################################################################
  105. ############## Evaluation configurations ###################
  106. ####################################################################
  107. parser.add_argument('--evaluate', action='store_true',
  108. help='only evaluate the model, not training')
  109. parser.add_argument('--resume_path', type=str, default=None, metavar='PATH',
  110. help='the path of the resumed checkpoint')
  111. parser.add_argument('--K', type=int, default=50,
  112. help='the number of top K')
  113. configs = edict(vars(parser.parse_args()))
  114. ####################################################################
  115. ############## Hardware configurations #############################
  116. ####################################################################
  117. # configs.device = torch.device('cpu' if configs.no_cuda else 'cuda')
  118. configs.device = torch.device('cpu' if configs.no_cuda else 'cuda:{}'.format(configs.gpu_idx))
  119. configs.ngpus_per_node = torch.cuda.device_count()
  120. configs.pin_memory = True
  121. configs.input_size = (1216, 608)
  122. configs.hm_size = (304, 152)
  123. configs.down_ratio = 4
  124. configs.max_objects = 50
  125. configs.imagenet_pretrained = True
  126. configs.head_conv = 64
  127. configs.num_classes = 3
  128. configs.num_center_offset = 2
  129. configs.num_z = 1
  130. configs.num_dim = 3
  131. configs.num_direction = 2 # sin, cos
  132. configs.heads = {
  133. 'hm_cen': configs.num_classes,
  134. 'cen_offset': configs.num_center_offset,
  135. 'direction': configs.num_direction,
  136. 'z_coor': configs.num_z,
  137. 'dim': configs.num_dim
  138. }
  139. configs.num_input_features = 4
  140. ####################################################################
  141. ############## Dataset, logs, Checkpoints dir ######################
  142. ####################################################################
  143. configs.dataset = 'apollo' # or kitti
  144. configs.dataset_dir = configs.data_url
  145. # configs.checkpoints_dir = os.path.join(configs.train_model_out, configs.saved_fn)
  146. configs.checkpoints_dir = configs.train_model_out
  147. # configs.logs_dir = os.path.join(configs.train_out, configs.saved_fn)
  148. configs.logs_dir = configs.train_out
  149. configs.pretrained_path = configs.model_load_dir
  150. if not os.path.isdir(configs.checkpoints_dir):
  151. os.makedirs(configs.checkpoints_dir)
  152. if not os.path.isdir(configs.logs_dir):
  153. os.makedirs(configs.logs_dir)
  154. return configs

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能