You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

multi_step_lr.py 2.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from bisect import bisect_right
  10. from typing import Iterable as Iter
  11. from .lr_scheduler import LRScheduler
  12. from .optimizer import Optimizer
  13. class MultiStepLR(LRScheduler):
  14. r"""Decays the learning rate of each parameter group by gamma once the
  15. number of epoch reaches one of the milestones.
  16. :param optimizer: wrapped optimizer.
  17. :type milestones: list
  18. :param milestones: list of epoch indices which should be increasing.
  19. :type gamma: float
  20. :param gamma: multiplicative factor of learning rate decay. Default: 0.1
  21. :param current_epoch: the index of current epoch. Default: -1
  22. """
  23. def __init__(
  24. self,
  25. optimizer: Optimizer,
  26. milestones: Iter[int],
  27. gamma: float = 0.1,
  28. current_epoch: int = -1,
  29. ):
  30. if not list(milestones) == sorted(milestones):
  31. raise ValueError(
  32. "Milestones should be a list of increasing integers. Got {}".format(
  33. milestones
  34. )
  35. )
  36. self.milestones = milestones
  37. self.gamma = gamma
  38. super().__init__(optimizer, current_epoch)
  39. def state_dict(self):
  40. r"""Returns the state of the scheduler as a :class:`dict`.
  41. It contains an entry for every variable in self.__dict__ which
  42. is not the optimizer.
  43. """
  44. return {
  45. key: value
  46. for key, value in self.__dict__.items()
  47. if key in ["milestones", "gamma", "current_epoch"]
  48. }
  49. def load_state_dict(self, state_dict):
  50. r"""Loads the schedulers state.
  51. :type state_dict: dict
  52. :param state_dict: scheduler state.
  53. """
  54. tmp_dict = {}
  55. for key in ["milestones", "gamma", "current_epoch"]:
  56. if not key in state_dict.keys():
  57. raise KeyError(
  58. "key '{}'' is not specified in "
  59. "state_dict when loading state dict".format(key)
  60. )
  61. tmp_dict[key] = state_dict[key]
  62. self.__dict__.update(tmp_dict)
  63. def get_lr(self):
  64. return [
  65. base_lr * self.gamma ** bisect_right(self.milestones, self.current_epoch)
  66. for base_lr in self.base_lrs
  67. ]

MegEngine 安装包中已集成使用 GPU 运行代码所需的 CUDA 环境，无需区分 CPU 版和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并已安装好驱动。如果你想体验在云端 GPU 算力平台上进行深度学习开发，欢迎访问 MegStudio 平台。