You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

multi_step_lr.py 2.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. from bisect import bisect_right
  2. from typing import Iterable as Iter
  3. from .lr_scheduler import LRScheduler
  4. from .optimizer import Optimizer
  5. class MultiStepLR(LRScheduler):
  6. r"""Decays the learning rate of each parameter group by gamma once the
  7. number of epoch reaches one of the milestones.
  8. :param optimizer: Wrapped optimizer.
  9. :param milestones (list): List of epoch indices. Must be increasing.
  10. :param gamma (float): Multiplicative factor of learning rate decay. Default: 0.1.
  11. :param current_epoch: The index of current epoch. Default: -1.
  12. """
  13. def __init__(
  14. self,
  15. optimizer: Optimizer,
  16. milestones: Iter[int],
  17. gamma: float = 0.1,
  18. current_epoch: int = -1,
  19. ):
  20. if not list(milestones) == sorted(milestones):
  21. raise ValueError(
  22. "Milestones should be a list of increasing integers. Got {}".format(
  23. milestones
  24. )
  25. )
  26. self.milestones = milestones
  27. self.gamma = gamma
  28. super().__init__(optimizer, current_epoch)
  29. def state_dict(self):
  30. r"""Returns the state of the scheduler as a :class:`dict`.
  31. It contains an entry for every variable in self.__dict__ which
  32. is not the optimizer.
  33. """
  34. return {
  35. key: value
  36. for key, value in self.__dict__.items()
  37. if key in ["milestones", "gamma", "current_epoch"]
  38. }
  39. def load_state_dict(self, state_dict):
  40. r"""Loads the schedulers state.
  41. :param state_dict (dict): scheduler state.
  42. """
  43. tmp_dict = {}
  44. for key in ["milestones", "gamma", "current_epoch"]:
  45. if not key in state_dict.keys():
  46. raise KeyError(
  47. "key '{}'' is not specified in "
  48. "state_dict when loading state dict".format(key)
  49. )
  50. tmp_dict[key] = state_dict[key]
  51. self.__dict__.update(tmp_dict)
  52. def get_lr(self):
  53. return [
  54. base_lr * self.gamma ** bisect_right(self.milestones, self.current_epoch)
  55. for base_lr in self.base_lrs
  56. ]

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台