
multi_step_lr.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from bisect import bisect_right
from typing import Iterable as Iter

from .lr_scheduler import LRScheduler
from .optimizer import Optimizer


class MultiStepLR(LRScheduler):
    r"""Decays the learning rate of each parameter group by gamma once the
    number of epochs reaches one of the milestones.

    Args:
        optimizer: wrapped optimizer.
        milestones: list of epoch indices; must be increasing.
        gamma: multiplicative factor of learning rate decay. Default: 0.1
        current_epoch: the index of the current epoch. Default: -1
    """

    def __init__(
        self,
        optimizer: Optimizer,
        milestones: Iter[int],
        gamma: float = 0.1,
        current_epoch: int = -1,
    ):
        if list(milestones) != sorted(milestones):
            raise ValueError(
                "Milestones should be a list of increasing integers. Got {}".format(
                    milestones
                )
            )

        self.milestones = milestones
        self.gamma = gamma
        super().__init__(optimizer, current_epoch)

    def state_dict(self):
        r"""Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in ``self.__dict__`` which
        is not the optimizer.
        """
        return {
            key: value
            for key, value in self.__dict__.items()
            if key in ["milestones", "gamma", "current_epoch"]
        }

    def load_state_dict(self, state_dict):
        r"""Loads the scheduler's state.

        Args:
            state_dict: scheduler state.
        """
        tmp_dict = {}
        for key in ["milestones", "gamma", "current_epoch"]:
            if key not in state_dict:
                raise KeyError(
                    "key '{}' is not specified in "
                    "state_dict when loading state dict".format(key)
                )
            tmp_dict[key] = state_dict[key]

        self.__dict__.update(tmp_dict)

    def get_lr(self):
        return [
            base_lr * self.gamma ** bisect_right(self.milestones, self.current_epoch)
            for base_lr in self.base_lrs
        ]
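
The decay rule in get_lr is easy to trace by hand: bisect_right(milestones, current_epoch) counts how many milestones the current epoch has already reached, and the base learning rate is multiplied by gamma that many times. The sketch below reproduces that arithmetic with the standard library only; the milestone values and base rate are illustrative assumptions, not taken from this file.

    from bisect import bisect_right

    base_lr = 0.1          # illustrative base learning rate (assumption, not from the file)
    milestones = [30, 80]  # illustrative milestones; must be increasing
    gamma = 0.1            # decay factor, matching the class default

    def lr_at(epoch):
        # Mirror MultiStepLR.get_lr: one factor of gamma per milestone already reached.
        return base_lr * gamma ** bisect_right(milestones, epoch)

    for epoch in (0, 29, 30, 79, 80, 100):
        print(f"epoch {epoch:3d} -> lr {lr_at(epoch):.4f}")
    # epoch   0 -> lr 0.1000
    # epoch  29 -> lr 0.1000
    # epoch  30 -> lr 0.0100   (first milestone reached, one decay applied)
    # epoch  79 -> lr 0.0100
    # epoch  80 -> lr 0.0010   (second milestone reached)
    # epoch 100 -> lr 0.0010

In a training loop, assuming the base LRScheduler follows the usual pattern of exposing a step() method that advances current_epoch and writes the rates from get_lr back into the optimizer's parameter groups (as the super().__init__ call suggests), calling it once per epoch reproduces exactly this schedule.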