You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

multi_step_lr.py 2.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. # -*- coding: utf-8 -*-
  2. from bisect import bisect_right
  3. from typing import Iterable as Iter
  4. from .lr_scheduler import LRScheduler
  5. from .optimizer import Optimizer
  6. class MultiStepLR(LRScheduler):
  7. r"""Decays the learning rate of each parameter group by gamma once the
  8. number of epoch reaches one of the milestones.
  9. Args:
  10. optimizer: wrapped optimizer.
  11. milestones: list of epoch indices which should be increasing.
  12. gamma: multiplicative factor of learning rate decay. Default: 0.1
  13. current_epoch: the index of current epoch. Default: -1
  14. """
  15. def __init__(
  16. self,
  17. optimizer: Optimizer,
  18. milestones: Iter[int],
  19. gamma: float = 0.1,
  20. current_epoch: int = -1,
  21. ):
  22. if not list(milestones) == sorted(milestones):
  23. raise ValueError(
  24. "Milestones should be a list of increasing integers. Got {}".format(
  25. milestones
  26. )
  27. )
  28. self.milestones = milestones
  29. self.gamma = gamma
  30. super().__init__(optimizer, current_epoch)
  31. def state_dict(self):
  32. r"""Returns the state of the scheduler as a :class:`dict`.
  33. It contains an entry for every variable in self.__dict__ which
  34. is not the optimizer.
  35. """
  36. return {
  37. key: value
  38. for key, value in self.__dict__.items()
  39. if key in ["milestones", "gamma", "current_epoch"]
  40. }
  41. def load_state_dict(self, state_dict):
  42. r"""Loads the schedulers state.
  43. Args:
  44. state_dict: scheduler state.
  45. """
  46. tmp_dict = {}
  47. for key in ["milestones", "gamma", "current_epoch"]:
  48. if not key in state_dict.keys():
  49. raise KeyError(
  50. "key '{}'' is not specified in "
  51. "state_dict when loading state dict".format(key)
  52. )
  53. tmp_dict[key] = state_dict[key]
  54. self.__dict__.update(tmp_dict)
  55. def get_lr(self):
  56. return [
  57. base_lr * self.gamma ** bisect_right(self.milestones, self.current_epoch)
  58. for base_lr in self.base_lrs
  59. ]