You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

train_utils.py 5.1 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. """
  2. # -*- coding: utf-8 -*-
  3. -----------------------------------------------------------------------------------
  4. # Author: Nguyen Mau Dung
  5. # DoC: 2020.08.09
  6. # email: nguyenmaudung93.kstn@gmail.com
  7. -----------------------------------------------------------------------------------
  8. # Description: utils functions that use for training process
  9. """
  10. import copy
  11. import os
  12. import math
  13. import sys
  14. import torch
  15. from torch.optim.lr_scheduler import LambdaLR
  16. import matplotlib.pyplot as plt
  17. src_dir = os.path.dirname(os.path.realpath(__file__))
  18. # while not src_dir.endswith("sfa"):
  19. # src_dir = os.path.dirname(src_dir)
  20. if src_dir not in sys.path:
  21. sys.path.append(src_dir)
  22. from utils.lr_scheduler import OneCyclePolicy
  23. def create_optimizer(configs, model):
  24. """Create optimizer for training process
  25. """
  26. if hasattr(model, 'module'):
  27. train_params = [param for param in model.module.parameters() if param.requires_grad]
  28. else:
  29. train_params = [param for param in model.parameters() if param.requires_grad]
  30. if configs.optimizer_type == 'sgd':
  31. optimizer = torch.optim.SGD(train_params, lr=configs.lr, momentum=configs.momentum, nesterov=True)
  32. elif configs.optimizer_type == 'adam':
  33. optimizer = torch.optim.Adam(train_params, lr=configs.lr, weight_decay=configs.weight_decay)
  34. else:
  35. assert False, "Unknown optimizer type"
  36. return optimizer
  37. def create_lr_scheduler(optimizer, configs):
  38. """Create learning rate scheduler for training process"""
  39. if configs.lr_type == 'multi_step':
  40. def multi_step_scheduler(i):
  41. if i < configs.steps[0]:
  42. factor = 1.
  43. elif i < configs.steps[1]:
  44. factor = 0.1
  45. else:
  46. factor = 0.01
  47. return factor
  48. lr_scheduler = LambdaLR(optimizer, multi_step_scheduler)
  49. elif configs.lr_type == 'cosin':
  50. # Scheduler https://arxiv.org/pdf/1812.01187.pdf
  51. lf = lambda x: (((1 + math.cos(x * math.pi / configs.num_epochs)) / 2) ** 1.0) * 0.9 + 0.1 # cosine
  52. lr_scheduler = LambdaLR(optimizer, lr_lambda=lf)
  53. elif configs.lr_type == 'one_cycle':
  54. lr_scheduler = OneCyclePolicy(optimizer, configs.lr, configs.num_epochs, momentum_rng=[0.85, 0.95],
  55. phase_ratio=0.45)
  56. else:
  57. raise ValueError
  58. plot_lr_scheduler(optimizer, lr_scheduler, configs.num_epochs, save_dir=configs.logs_dir, lr_type=configs.lr_type)
  59. return lr_scheduler
  60. def get_saved_state(model, optimizer, lr_scheduler, epoch, configs):
  61. """Get the information to save with checkpoints"""
  62. if hasattr(model, 'module'):
  63. model_state_dict = model.module.state_dict()
  64. else:
  65. model_state_dict = model.state_dict()
  66. utils_state_dict = {
  67. 'epoch': epoch,
  68. 'configs': configs,
  69. 'optimizer': copy.deepcopy(optimizer.state_dict()),
  70. 'lr_scheduler': copy.deepcopy(lr_scheduler.state_dict())
  71. }
  72. return model_state_dict, utils_state_dict
  73. def save_checkpoint(checkpoints_dir, saved_fn, model_state_dict, utils_state_dict, epoch):
  74. """Save checkpoint every epoch only is best model or after every checkpoint_freq epoch"""
  75. model_save_path = os.path.join(checkpoints_dir, 'Model_{}_epoch_{}.pth'.format(saved_fn, epoch))
  76. utils_save_path = os.path.join(checkpoints_dir, 'Utils_{}_epoch_{}.pth'.format(saved_fn, epoch))
  77. torch.save(model_state_dict, model_save_path)
  78. torch.save(utils_state_dict, utils_save_path)
  79. print('save a checkpoint at {}'.format(model_save_path))
  80. def plot_lr_scheduler(optimizer, scheduler, num_epochs=300, save_dir='', lr_type=''):
  81. # Plot LR simulating training for full num_epochs
  82. optimizer, scheduler = copy.copy(optimizer), copy.copy(scheduler) # do not modify originals
  83. y = []
  84. for _ in range(num_epochs):
  85. scheduler.step()
  86. y.append(optimizer.param_groups[0]['lr'])
  87. plt.plot(y, '.-', label='LR')
  88. plt.xlabel('epoch')
  89. plt.ylabel('LR')
  90. plt.grid()
  91. plt.xlim(0, num_epochs)
  92. plt.ylim(0)
  93. plt.tight_layout()
  94. plt.savefig(os.path.join(save_dir, 'LR_{}.png'.format(lr_type)), dpi=200)
  95. if __name__ == '__main__':
  96. from easydict import EasyDict as edict
  97. from torchvision.models import resnet18
  98. configs = edict()
  99. configs.steps = [150, 180]
  100. configs.lr_type = 'one_cycle' # multi_step, cosin, one_csycle
  101. configs.logs_dir = '../../logs/'
  102. configs.num_epochs = 50
  103. configs.lr = 2.25e-3
  104. net = resnet18()
  105. optimizer = torch.optim.Adam(net.parameters(), 0.0002)
  106. # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6, 9], gamma=0.1)
  107. # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, gamma=0.1)
  108. scheduler = create_lr_scheduler(optimizer, configs)
  109. for i in range(configs.num_epochs):
  110. print(i, scheduler.get_lr())
  111. scheduler.step()

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能