You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

lr_scheduler.py 2.1 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. from abc import ABCMeta
  2. from .optimizer import Optimizer
  3. class LRScheduler(metaclass=ABCMeta):
  4. r"""Base class for all learning rate based schedulers.
  5. :param optimizer: Wrapped optimizer.
  6. :param current_epoch: The index of current epoch. Default: -1
  7. """
  8. def __init__( # pylint: disable=too-many-branches
  9. self, optimizer: Optimizer, current_epoch: int = -1
  10. ):
  11. if not isinstance(optimizer, Optimizer):
  12. raise TypeError(
  13. "optimizer argument given to the lr_scheduler should be Optimizer"
  14. )
  15. self.optimizer = optimizer
  16. self.current_epoch = current_epoch
  17. if current_epoch == -1:
  18. for group in self.optimizer.param_groups:
  19. group.setdefault("initial_lr", group["lr"])
  20. else:
  21. for i, group in enumerate(optimizer.param_groups):
  22. if "initial_lr" not in group:
  23. raise KeyError(
  24. "param 'initial_lr' is not specified in "
  25. "param_groups[{}] when resuming an optimizer".format(i)
  26. )
  27. self.base_lrs = list(
  28. map(lambda group: group["initial_lr"], self.optimizer.param_groups)
  29. )
  30. self.step()
  31. def state_dict(self):
  32. r"""Returns the state of the scheduler as a :class:`dict`.
  33. It contains an entry for every variable in self.__dict__ which
  34. is not the optimizer.
  35. """
  36. raise NotImplementedError
  37. def load_state_dict(self, state_dict):
  38. r"""Loads the schedulers state.
  39. :param state_dict (dict): scheduler state.
  40. """
  41. raise NotImplementedError
  42. def get_lr(self):
  43. r""" Compute current learning rate for the scheduler.
  44. """
  45. raise NotImplementedError
  46. def step(self, epoch=None):
  47. if epoch is None:
  48. self.current_epoch += 1
  49. else:
  50. self.current_epoch = epoch
  51. values = self.get_lr()
  52. for param_group, lr in zip(self.optimizer.param_groups, values):
  53. param_group["lr"] = lr

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台