
lr_scheduler.py
import types

import numpy as np
import torch
from torch.optim import SGD, lr_scheduler


class _LRMomentumScheduler(lr_scheduler._LRScheduler):
    """Base class for schedulers that update both learning rate and momentum.

    Subclasses must implement get_lr() and get_momentum().
    """

    def __init__(self, optimizer, last_epoch=-1):
        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_momentum', group['momentum'])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_momentum' not in group:
                    raise KeyError("param 'initial_momentum' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
        self.base_momentums = list(map(lambda group: group['initial_momentum'], optimizer.param_groups))
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        raise NotImplementedError

    def get_momentum(self):
        raise NotImplementedError

    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        for param_group, lr, momentum in zip(self.optimizer.param_groups, self.get_lr(), self.get_momentum()):
            param_group['lr'] = lr
            param_group['momentum'] = momentum

class ParameterUpdate(object):
    """A callable class used to define an arbitrary schedule defined by a list.

    This object is designed to be passed to the LambdaLR or LambdaScheduler scheduler to apply
    the given schedule. If base_param is zero, no updates are applied.

    Arguments:
        params {list or numpy.array} -- List or numpy array defining parameter schedule.
        base_param {float} -- Parameter value used to initialize the optimizer.
    """

    def __init__(self, params, base_param):
        self.params = np.hstack([params, 0])
        self.base_param = base_param
        if base_param < 1e-12:
            # Degenerate base value: fall back to a constant factor of 1, so no updates are applied.
            self.base_param = 1
            self.params = self.params * 0.0 + 1.0

    def __call__(self, epoch):
        return self.params[epoch] / self.base_param
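

# A minimal usage sketch (not part of the original module): ParameterUpdate wraps an explicit
# per-epoch schedule so it can act as the multiplicative factor expected by LambdaScheduler
# below (or torch's LambdaLR). The model `net` and the concrete values are assumptions for
# illustration only.
#
#     optimizer = SGD(net.parameters(), lr=0.1, momentum=0.9)
#     lr_schedule = ParameterUpdate(params=np.linspace(0.1, 0.01, 100), base_param=0.1)
#     scheduler = LambdaScheduler(optimizer, lr_lambda=lr_schedule)
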
def apply_lambda(last_epoch, bases, lambdas):
    return [base * lmbda(last_epoch) for lmbda, base in zip(lambdas, bases)]

class LambdaScheduler(_LRMomentumScheduler):
    """Sets the learning rate and momentum of each parameter group to the initial lr and momentum
    times a given function. When last_epoch=-1, sets initial lr and momentum to the optimizer
    values.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_lambda (function or list): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list of such
            functions, one for each group in optimizer.param_groups.
            Default: lambda x: x.
        momentum_lambda (function or list): As for lr_lambda but applied to momentum.
            Default: lambda x: x.
        last_epoch (int): The index of last epoch. Default: -1.

    Example:
        >>> # Assuming optimizer has two groups.
        >>> lr_lambda = [
        ...     lambda epoch: epoch // 30,
        ...     lambda epoch: 0.95 ** epoch
        ... ]
        >>> mom_lambda = [
        ...     lambda epoch: max(0, (50 - epoch) // 50),
        ...     lambda epoch: 0.99 ** epoch
        ... ]
        >>> scheduler = LambdaScheduler(optimizer, lr_lambda, mom_lambda)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, lr_lambda=lambda x: x, momentum_lambda=lambda x: x, last_epoch=-1):
        self.optimizer = optimizer

        if not isinstance(lr_lambda, (list, tuple)):
            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
        else:
            if len(lr_lambda) != len(optimizer.param_groups):
                raise ValueError("Expected {} lr_lambdas, but got {}".format(
                    len(optimizer.param_groups), len(lr_lambda)))
            self.lr_lambdas = list(lr_lambda)

        if not isinstance(momentum_lambda, (list, tuple)):
            self.momentum_lambdas = [momentum_lambda] * len(optimizer.param_groups)
        else:
            if len(momentum_lambda) != len(optimizer.param_groups):
                raise ValueError("Expected {} momentum_lambdas, but got {}".format(
                    len(optimizer.param_groups), len(momentum_lambda)))
            self.momentum_lambdas = list(momentum_lambda)

        self.last_epoch = last_epoch
        super().__init__(optimizer, last_epoch)

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The learning rate and momentum lambda functions will only be saved if they are
        callable objects and not if they are functions or lambdas.
        """
        state_dict = {key: value for key, value in self.__dict__.items()
                      if key not in ('optimizer', 'lr_lambdas', 'momentum_lambdas')}
        state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas)
        state_dict['momentum_lambdas'] = [None] * len(self.momentum_lambdas)

        for idx, (lr_fn, mom_fn) in enumerate(zip(self.lr_lambdas, self.momentum_lambdas)):
            if not isinstance(lr_fn, types.FunctionType):
                state_dict['lr_lambdas'][idx] = lr_fn.__dict__.copy()
            if not isinstance(mom_fn, types.FunctionType):
                state_dict['momentum_lambdas'][idx] = mom_fn.__dict__.copy()

        return state_dict

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.

        Arguments:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        lr_lambdas = state_dict.pop('lr_lambdas')
        momentum_lambdas = state_dict.pop('momentum_lambdas')
        self.__dict__.update(state_dict)

        for idx, fn in enumerate(lr_lambdas):
            if fn is not None:
                self.lr_lambdas[idx].__dict__.update(fn)
        for idx, fn in enumerate(momentum_lambdas):
            if fn is not None:
                self.momentum_lambdas[idx].__dict__.update(fn)

    def get_lr(self):
        return apply_lambda(self.last_epoch, self.base_lrs, self.lr_lambdas)

    def get_momentum(self):
        return apply_lambda(self.last_epoch, self.base_momentums, self.momentum_lambdas)

class ListScheduler(LambdaScheduler):
    """Sets the learning rate and momentum of each parameter group to values defined by lists.
    When last_epoch=-1, sets initial lr and momentum to the optimizer values. One or both of the
    lr and momentum schedules may be specified.

    Note that the parameters used to initialize the optimizer are overridden by those defined by
    this scheduler.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lrs (list or numpy.ndarray): A list of learning rates, or a list of lists, one for each
            parameter group. One- or two-dimensional numpy arrays may also be passed.
        momentums (list or numpy.ndarray): A list of momentums, or a list of lists, one for each
            parameter group. One- or two-dimensional numpy arrays may also be passed.
        last_epoch (int): The index of last epoch. Default: -1.

    Example:
        >>> # Assuming optimizer has two groups.
        >>> lrs = [
        ...     np.linspace(0.01, 0.1, 100),
        ...     np.logspace(-2, 0, 100)
        ... ]
        >>> momentums = [
        ...     np.linspace(0.85, 0.95, 100),
        ...     np.linspace(0.8, 0.99, 100)
        ... ]
        >>> scheduler = ListScheduler(optimizer, lrs, momentums)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, lrs=None, momentums=None, last_epoch=-1):
        groups = optimizer.param_groups

        if lrs is None:
            lr_lambda = lambda x: x
        else:
            lrs = np.array(lrs) if isinstance(lrs, (list, tuple)) else lrs
            if len(lrs.shape) == 1:
                # A single schedule shared by all parameter groups.
                lr_lambda = [ParameterUpdate(lrs, g['lr']) for g in groups]
            else:
                # One schedule per parameter group.
                lr_lambda = [ParameterUpdate(l, g['lr']) for l, g in zip(lrs, groups)]

        if momentums is None:
            momentum_lambda = lambda x: x
        else:
            momentums = np.array(momentums) if isinstance(momentums, (list, tuple)) else momentums
            if len(momentums.shape) == 1:
                momentum_lambda = [ParameterUpdate(momentums, g['momentum']) for g in groups]
            else:
                momentum_lambda = [ParameterUpdate(l, g['momentum']) for l, g in zip(momentums, groups)]

        super().__init__(optimizer, lr_lambda, momentum_lambda)

class RangeFinder(ListScheduler):
    """Scheduler class that implements the LR range search specified in:

    A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch
    size, momentum, and weight decay. Leslie N. Smith, 2018, arXiv:1803.09820.

    Logarithmically spaced learning rates from 1e-7 to 1 are searched. The number of increments in
    that range is determined by 'epochs'.

    Note that the parameters used to initialize the optimizer are overridden by those defined by
    this scheduler.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        epochs (int): Number of epochs over which to run the test.

    Example:
        >>> scheduler = RangeFinder(optimizer, 100)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, epochs):
        lrs = np.logspace(-7, 0, epochs)
        super().__init__(optimizer, lrs)

class OneCyclePolicy(ListScheduler):
    """Scheduler class that implements the 1cycle policy search specified in:

    A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch
    size, momentum, and weight decay. Leslie N. Smith, 2018, arXiv:1803.09820.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr (float or list): Maximum learning rate in range. If a list of values is passed, they
            should correspond to parameter groups.
        epochs (int): The number of epochs to use during search.
        momentum_rng (list): Optional lower and upper momentum values (the two may be equal). Set
            to None to run without momentum. Default: [0.85, 0.95]. If a list of lists is passed,
            they should correspond to parameter groups.
        phase_ratio (float): Fraction of epochs used for each of the increasing and decreasing
            phases of the schedule. For example, if phase_ratio=0.45 and epochs=100, the learning
            rate will increase from lr/10 to lr over 45 epochs, then decrease back to lr/10 over
            45 epochs, then decrease to lr/100 over the remaining 10 epochs. Default: 0.45.
    """

    def __init__(self, optimizer, lr, epochs, momentum_rng=[0.85, 0.95], phase_ratio=0.45):
        phase_epochs = int(phase_ratio * epochs)

        if isinstance(lr, (list, tuple)):
            lrs = [
                np.hstack([
                    np.linspace(l * 1e-1, l, phase_epochs),
                    np.linspace(l, l * 1e-1, phase_epochs),
                    np.linspace(l * 1e-1, l * 1e-2, epochs - 2 * phase_epochs),
                ]) for l in lr
            ]
        else:
            lrs = np.hstack([
                np.linspace(lr * 1e-1, lr, phase_epochs),
                np.linspace(lr, lr * 1e-1, phase_epochs),
                np.linspace(lr * 1e-1, lr * 1e-2, epochs - 2 * phase_epochs),
            ])

        if momentum_rng is not None:
            momentum_rng = np.array(momentum_rng)
            if len(momentum_rng.shape) == 2:
                # Per-group momentum ranges: momentum cycles from the upper value down and back up.
                for i, g in enumerate(optimizer.param_groups):
                    g['momentum'] = momentum_rng[i][1]
                momentums = [
                    np.hstack([
                        np.linspace(m[1], m[0], phase_epochs),
                        np.linspace(m[0], m[1], phase_epochs),
                        np.linspace(m[1], m[1], epochs - 2 * phase_epochs),
                    ]) for m in momentum_rng
                ]
            else:
                # A single momentum range shared by all parameter groups.
                for i, g in enumerate(optimizer.param_groups):
                    g['momentum'] = momentum_rng[1]
                momentums = np.hstack([
                    np.linspace(momentum_rng[1], momentum_rng[0], phase_epochs),
                    np.linspace(momentum_rng[0], momentum_rng[1], phase_epochs),
                    np.linspace(momentum_rng[1], momentum_rng[1], epochs - 2 * phase_epochs),
                ])
        else:
            momentums = None

        super().__init__(optimizer, lrs, momentums)
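

# A minimal end-to-end sketch, not part of the original module: it assumes a toy linear model and
# simply steps the scheduler to show how OneCyclePolicy rewrites the optimizer's lr and momentum
# each epoch. A real training/validation loop would replace the body of the loop.
if __name__ == "__main__":
    model = torch.nn.Linear(10, 1)
    optimizer = SGD(model.parameters(), lr=0.1, momentum=0.95)
    epochs = 100
    scheduler = OneCyclePolicy(optimizer, lr=0.1, epochs=epochs, momentum_rng=[0.85, 0.95])
    for epoch in range(epochs):
        # train(...) / validate(...) would run here.
        scheduler.step()
        print(epoch, optimizer.param_groups[0]['lr'], optimizer.param_groups[0]['momentum'])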
