train.py
  1. """
  2. # -*- coding: utf-8 -*-
  3. -----------------------------------------------------------------------------------
  4. # Author: Nguyen Mau Dung
  5. # DoC: 2020.08.17
  6. # email: nguyenmaudung93.kstn@gmail.com
  7. -----------------------------------------------------------------------------------
  8. # Description: This script for training
  9. """

import time
import numpy as np
import sys
import random
import os
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

import torch
from torch.utils.tensorboard import SummaryWriter
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.utils.data.distributed
from tqdm import tqdm

src_dir = os.path.dirname(os.path.realpath(__file__))
# while not src_dir.endswith("sfa"):
#     src_dir = os.path.dirname(src_dir)
if src_dir not in sys.path:
    sys.path.append(src_dir)

from data_process.kitti_dataloader import create_train_dataloader, create_val_dataloader
from models.model_utils import create_model, make_data_parallel, get_num_parameters
from utils.train_utils import create_optimizer, create_lr_scheduler, get_saved_state, save_checkpoint
from utils.torch_utils import reduce_tensor, to_python_float
from utils.misc import AverageMeter, ProgressMeter
from utils.logger import Logger
from config.train_config import parse_train_configs
from losses.losses import Compute_Loss


def main():
    configs = parse_train_configs()

    # For reproducible results, fix the random seeds
    if configs.seed is not None:
        random.seed(configs.seed)
        np.random.seed(configs.seed)
        torch.manual_seed(configs.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    if configs.gpu_idx is not None:
        print('You have chosen a specific GPU. This will completely disable data parallelism.')

    if configs.dist_url == "env://" and configs.world_size == -1:
        configs.world_size = int(os.environ["WORLD_SIZE"])

    configs.distributed = configs.world_size > 1 or configs.multiprocessing_distributed

    if configs.multiprocessing_distributed:
        configs.world_size = configs.ngpus_per_node * configs.world_size
        mp.spawn(main_worker, nprocs=configs.ngpus_per_node, args=(configs,))
    else:
        main_worker(configs.gpu_idx, configs)
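

# When configs.multiprocessing_distributed is set, mp.spawn() launches one copy of
# main_worker() per GPU on the node and passes the spawned process index as the first
# argument, so gpu_idx is that index; in the single-process path it is configs.gpu_idx.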


def main_worker(gpu_idx, configs):
    configs.gpu_idx = gpu_idx
    # configs.device = torch.device('cpu' if configs.gpu_idx is None else 'cuda:{}'.format(configs.gpu_idx))

    if configs.distributed:
        if configs.dist_url == "env://" and configs.rank == -1:
            configs.rank = int(os.environ["RANK"])
        if configs.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            configs.rank = configs.rank * configs.ngpus_per_node + gpu_idx

        dist.init_process_group(backend=configs.dist_backend, init_method=configs.dist_url,
                                world_size=configs.world_size, rank=configs.rank)
        configs.subdivisions = int(64 / configs.batch_size / configs.ngpus_per_node)
    else:
        configs.subdivisions = int(64 / configs.batch_size)
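
    # `subdivisions` drives gradient accumulation in train_one_epoch(): optimizer.step()
    # runs once every `subdivisions` mini-batches, so each update corresponds to roughly
    # 64 samples regardless of the per-process batch_size (see the standalone sketch
    # after the script).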

    configs.is_master_node = (not configs.distributed) or (
            configs.distributed and (configs.rank % configs.ngpus_per_node == 0))

    if configs.is_master_node:
        logger = Logger(configs.logs_dir, configs.saved_fn)
        logger.info('>>> Created a new logger')
        logger.info('>>> configs: {}'.format(configs))
        tb_writer = SummaryWriter(log_dir=os.path.join(configs.logs_dir, 'tensorboard'))
    else:
        logger = None
        tb_writer = None

    # model
    model = create_model(configs)

    # load weight from a checkpoint
    if configs.pretrained_path is not None:
        # assert os.path.isfile(configs.pretrained_path), "=> no checkpoint found at '{}'".format(configs.pretrained_path)
        if os.path.isfile(configs.pretrained_path):
            model_path = configs.pretrained_path
        else:
            # take the last model in the directory
            model_path = os.path.join(configs.pretrained_path, os.listdir(configs.pretrained_path)[-1])
        model.load_state_dict(torch.load(model_path, map_location=configs.device))
        if logger is not None:
            logger.info('loaded pretrained model at {}'.format(configs.pretrained_path))

    # resume weights of model from a checkpoint
    if configs.resume_path is not None:
        assert os.path.isfile(configs.resume_path), "=> no checkpoint found at '{}'".format(configs.resume_path)
        model.load_state_dict(torch.load(configs.resume_path, map_location='cpu'))
        if logger is not None:
            logger.info('resume training model from checkpoint {}'.format(configs.resume_path))

    # Data Parallel
    model = make_data_parallel(model, configs)

    # Make sure to create optimizer after moving the model to cuda
    optimizer = create_optimizer(configs, model)
    lr_scheduler = create_lr_scheduler(optimizer, configs)
    configs.step_lr_in_epoch = False if configs.lr_type in ['multi_step', 'cosin', 'one_cycle'] else True
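    # For the 'multi_step', 'cosin' and 'one_cycle' schedulers the learning rate is
    # stepped once per epoch (at the bottom of the epoch loop below); any other lr_type
    # is stepped inside train_one_epoch() after every accumulated optimizer update.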

    # resume optimizer, lr_scheduler from a checkpoint
    if configs.resume_path is not None:
        utils_path = configs.resume_path.replace('Model_', 'Utils_')
        assert os.path.isfile(utils_path), "=> no checkpoint found at '{}'".format(utils_path)
        utils_state_dict = torch.load(utils_path, map_location='cuda:{}'.format(configs.gpu_idx))
        optimizer.load_state_dict(utils_state_dict['optimizer'])
        lr_scheduler.load_state_dict(utils_state_dict['lr_scheduler'])
        configs.start_epoch = utils_state_dict['epoch'] + 1
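        # Checkpoints are written in pairs by save_checkpoint(): a 'Model_*' file with the
        # network weights and a matching 'Utils_*' file with the optimizer, lr_scheduler
        # and epoch state restored here.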

    if configs.is_master_node:
        num_parameters = get_num_parameters(model)
        logger.info('number of trained parameters of the model: {}'.format(num_parameters))

    if logger is not None:
        logger.info(">>> Loading dataset & getting dataloader...")
    # Create dataloader
    train_dataloader, train_sampler = create_train_dataloader(configs)
    if logger is not None:
        logger.info('number of batches in training set: {}'.format(len(train_dataloader)))

    if configs.evaluate:
        val_dataloader = create_val_dataloader(configs)
        val_loss = validate(val_dataloader, model, configs)
        print('val_loss: {:.4e}'.format(val_loss))
        return

    for epoch in range(configs.start_epoch, configs.num_epochs + 1):
        if logger is not None:
            logger.info('{}'.format('*-' * 40))
            logger.info('{} {}/{} {}'.format('=' * 35, epoch, configs.num_epochs, '=' * 35))
            logger.info('{}'.format('*-' * 40))
            logger.info('>>> Epoch: [{}/{}]'.format(epoch, configs.num_epochs))

        if configs.distributed:
            train_sampler.set_epoch(epoch)
        # train for one epoch
        train_one_epoch(train_dataloader, model, optimizer, lr_scheduler, epoch, configs, logger, tb_writer)

        if (not configs.no_val) and (epoch % configs.checkpoint_freq == 0):
            val_dataloader = create_val_dataloader(configs)
            print('number of batches in val_dataloader: {}'.format(len(val_dataloader)))
            val_loss = validate(val_dataloader, model, configs)
            print('val_loss: {:.4e}'.format(val_loss))
            if tb_writer is not None:
                tb_writer.add_scalar('Val_loss', val_loss, epoch)

        # Save checkpoint
        if configs.is_master_node and ((epoch % configs.checkpoint_freq) == 0):
            model_state_dict, utils_state_dict = get_saved_state(model, optimizer, lr_scheduler, epoch, configs)
            save_checkpoint(configs.checkpoints_dir, configs.saved_fn, model_state_dict, utils_state_dict, epoch)

        if not configs.step_lr_in_epoch:
            lr_scheduler.step()
            if tb_writer is not None:
                tb_writer.add_scalar('LR', lr_scheduler.get_lr()[0], epoch)

    if tb_writer is not None:
        tb_writer.close()

    if configs.distributed:
        cleanup()


def cleanup():
    dist.destroy_process_group()


def train_one_epoch(train_dataloader, model, optimizer, lr_scheduler, epoch, configs, logger, tb_writer):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')

    progress = ProgressMeter(len(train_dataloader), [batch_time, data_time, losses],
                             prefix="Train - Epoch: [{}/{}]".format(epoch, configs.num_epochs))

    criterion = Compute_Loss(device=configs.device)
    num_iters_per_epoch = len(train_dataloader)

    # switch to train mode
    model.train()
    start_time = time.time()
    for batch_idx, batch_data in enumerate(tqdm(train_dataloader)):
        data_time.update(time.time() - start_time)

        imgs, targets = batch_data
        batch_size = imgs.size(0)
        global_step = num_iters_per_epoch * (epoch - 1) + batch_idx + 1
        for k in targets.keys():
            targets[k] = targets[k].to(configs.device, non_blocking=True)
        imgs = imgs.to(configs.device, non_blocking=True).float()
        outputs = model(imgs)
        total_loss, loss_stats = criterion(outputs, targets)
        # For torch.nn.DataParallel case
        if (not configs.distributed) and (configs.gpu_idx is None):
            total_loss = torch.mean(total_loss)

        # compute gradient and perform backpropagation
        total_loss.backward()
        if global_step % configs.subdivisions == 0:
            optimizer.step()
            # zero the parameter gradients
            optimizer.zero_grad()
            # Adjust learning rate
            if configs.step_lr_in_epoch:
                lr_scheduler.step()
                if tb_writer is not None:
                    tb_writer.add_scalar('LR', lr_scheduler.get_lr()[0], global_step)

        if configs.distributed:
            reduced_loss = reduce_tensor(total_loss.data, configs.world_size)
        else:
            reduced_loss = total_loss.data
        losses.update(to_python_float(reduced_loss), batch_size)

        # measure elapsed time
        # torch.cuda.synchronize()
        batch_time.update(time.time() - start_time)

        if tb_writer is not None:
            if (global_step % configs.tensorboard_freq) == 0:
                loss_stats['avg_loss'] = losses.avg
                tb_writer.add_scalars('Train', loss_stats, global_step)

        # Log message
        if logger is not None:
            if (global_step % configs.print_freq) == 0:
                logger.info(progress.get_message(batch_idx))

        start_time = time.time()


def validate(val_dataloader, model, configs):
    losses = AverageMeter('Loss', ':.4e')
    criterion = Compute_Loss(device=configs.device)

    # switch to evaluation mode
    model.eval()
    with torch.no_grad():
        for batch_idx, batch_data in enumerate(tqdm(val_dataloader)):
            imgs, targets = batch_data
            batch_size = imgs.size(0)
            for k in targets.keys():
                targets[k] = targets[k].to(configs.device, non_blocking=True)
            imgs = imgs.to(configs.device, non_blocking=True).float()
            outputs = model(imgs)
            total_loss, loss_stats = criterion(outputs, targets)
            # For torch.nn.DataParallel case
            if (not configs.distributed) and (configs.gpu_idx is None):
                total_loss = torch.mean(total_loss)
            if configs.distributed:
                reduced_loss = reduce_tensor(total_loss.data, configs.world_size)
            else:
                reduced_loss = total_loss.data
            losses.update(to_python_float(reduced_loss), batch_size)

    return losses.avg


if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        try:
            cleanup()
            sys.exit(0)
        except SystemExit:
            os._exit(0)
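
The least obvious part of train_one_epoch() is the accumulation of gradients over
configs.subdivisions mini-batches before each optimizer update. The following
standalone sketch isolates that pattern with a dummy model and random data; the model,
optimizer, and subdivisions value here are illustrative only and are not taken from
this repository.

import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
subdivisions = 4  # accumulate gradients over 4 mini-batches before stepping

for step in range(1, 17):
    x = torch.randn(16, 8)    # dummy mini-batch
    y = torch.randn(16, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()            # gradients accumulate across iterations
    if step % subdivisions == 0:
        optimizer.step()       # apply the accumulated gradients
        optimizer.zero_grad()  # reset for the next accumulation window

Like train.py, the sketch does not rescale the loss by subdivisions, so each update
applies the sum (not the mean) of the per-batch gradients.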
