
lr_scheduler.py 2.5 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import ABCMeta

from .distributed_optimizer import DistributedOptimizer


class LRScheduler(metaclass=ABCMeta):
    r"""Base class for all learning rate based schedulers.

    :param optimizer: Wrapped optimizer.
    :param current_epoch: The index of current epoch. Default: -1
    """

    def __init__(  # pylint: disable=too-many-branches
        self, optimizer: DistributedOptimizer, current_epoch: int = -1
    ):
        if not isinstance(optimizer, DistributedOptimizer):
            raise TypeError(
                "optimizer argument given to the lr_scheduler should be Optimizer"
            )
        self.optimizer = optimizer
        self.current_epoch = current_epoch
        if current_epoch == -1:
            for group in self.optimizer.param_groups:
                group.setdefault("initial_lr", group["lr"])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if "initial_lr" not in group:
                    raise KeyError(
                        "param 'initial_lr' is not specified in "
                        "param_groups[{}] when resuming an optimizer".format(i)
                    )

        self.base_lrs = list(
            map(lambda group: group["initial_lr"], self.optimizer.param_groups)
        )

        self.step()

    def state_dict(self):
        r"""Returns the state of the scheduler as a :class:`dict`.
        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        raise NotImplementedError

    def load_state_dict(self, state_dict):
        r"""Loads the scheduler's state.

        :param state_dict (dict): scheduler state.
        """
        raise NotImplementedError

    def get_lr(self):
        r"""Compute current learning rate for the scheduler."""
        raise NotImplementedError

    def step(self, epoch=None):
        if epoch is None:
            self.current_epoch += 1
        else:
            self.current_epoch = epoch

        values = self.get_lr()
        for param_group, lr in zip(self.optimizer.param_groups, values):
            param_group["lr"] = lr
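
For reference, a concrete scheduler only has to store its own hyper-parameters and override get_lr (plus the state_dict/load_state_dict pair). The sketch below is not part of this file: the ExponentialLR name and its gamma parameter are illustrative assumptions, showing a per-epoch multiplicative decay built on the base class above.

class ExponentialLR(LRScheduler):
    r"""Illustrative subclass (not in this file): decays every parameter
    group's learning rate by ``gamma`` each epoch.

    :param optimizer: Wrapped optimizer.
    :param gamma: multiplicative decay factor per epoch (assumed name).
    :param current_epoch: The index of current epoch. Default: -1
    """

    def __init__(self, optimizer, gamma: float, current_epoch: int = -1):
        # Set gamma before calling the base __init__, because the base
        # constructor immediately calls self.step(), which calls get_lr().
        self.gamma = gamma
        super().__init__(optimizer, current_epoch)

    def state_dict(self):
        # Persist every attribute except the wrapped optimizer, matching
        # the contract described in the base-class docstring.
        return {k: v for k, v in self.__dict__.items() if k != "optimizer"}

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        # Scale each group's initial lr by gamma ** current_epoch.
        return [
            base_lr * self.gamma ** self.current_epoch for base_lr in self.base_lrs
        ]

In a training loop one would typically call scheduler.step() once per epoch after the optimizer updates, so that each param_group["lr"] is refreshed before the next epoch begins.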

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has GPU hardware and the drivers are installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.