
lr_scheduler.py

# -*- coding: utf-8 -*-
from abc import ABCMeta

from .optimizer import Optimizer


class LRScheduler(metaclass=ABCMeta):
    r"""Base class for all learning rate based schedulers.

    Args:
        optimizer: wrapped optimizer.
        current_epoch: the index of current epoch. Default: -1
    """

    def __init__(  # pylint: disable=too-many-branches
        self, optimizer: Optimizer, current_epoch: int = -1
    ):
        if not isinstance(optimizer, Optimizer):
            raise TypeError(
                "optimizer argument given to the lr_scheduler should be Optimizer"
            )
        self.optimizer = optimizer
        self.current_epoch = current_epoch
        if current_epoch == -1:
            # Fresh start: record each group's starting lr so that
            # get_lr() can derive its schedule from it.
            for group in self.optimizer.param_groups:
                group.setdefault("initial_lr", group["lr"])
        else:
            # Resuming: every group must already carry its starting lr.
            for i, group in enumerate(optimizer.param_groups):
                if "initial_lr" not in group:
                    raise KeyError(
                        "param 'initial_lr' is not specified in "
                        "param_groups[{}] when resuming an optimizer".format(i)
                    )
        self.base_lrs = list(
            map(lambda group: group["initial_lr"], self.optimizer.param_groups)
        )

        self.step()

    def state_dict(self):
        r"""Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        raise NotImplementedError

    def load_state_dict(self, state_dict):
        r"""Loads the scheduler's state.

        Args:
            state_dict: scheduler state.
        """
        raise NotImplementedError

    def get_lr(self):
        r"""Compute current learning rate for the scheduler."""
        raise NotImplementedError

    def step(self, epoch=None):
        r"""Advance to the next epoch (or jump to ``epoch`` if given) and
        write the learning rates computed by :meth:`get_lr` back into the
        optimizer's parameter groups.
        """
        if epoch is None:
            self.current_epoch += 1
        else:
            self.current_epoch = epoch
        values = self.get_lr()
        for param_group, lr in zip(self.optimizer.param_groups, values):
            param_group["lr"] = lr
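
Concrete schedulers subclass LRScheduler and implement get_lr() (plus state_dict()/load_state_dict() if they need checkpointing); the base class then records base_lrs and applies the computed rates on every step(). Below is a minimal sketch of such a subclass, assuming the step-decay scheme: the name StepDecayLR and its milestones/gamma parameters are illustrative assumptions, not part of this file.

# A minimal sketch of a concrete scheduler built on LRScheduler.
# ``StepDecayLR``, ``milestones`` and ``gamma`` are hypothetical names
# chosen for illustration; they are not defined anywhere in this file.
from bisect import bisect_right

class StepDecayLR(LRScheduler):
    r"""Multiplies every base learning rate by ``gamma`` once per
    milestone epoch that has passed.
    """

    def __init__(self, optimizer, milestones, gamma=0.1, current_epoch=-1):
        # Set the attributes used by get_lr() *before* calling the base
        # constructor, because LRScheduler.__init__ ends with self.step(),
        # which in turn calls get_lr().
        self.milestones = sorted(milestones)
        self.gamma = gamma
        super().__init__(optimizer, current_epoch)

    def state_dict(self):
        # As the base-class docstring specifies: every variable in
        # self.__dict__ except the wrapped optimizer.
        return {k: v for k, v in self.__dict__.items() if k != "optimizer"}

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        # Count the milestones at or before the current epoch and decay
        # each base lr by gamma that many times.
        decays = bisect_right(self.milestones, self.current_epoch)
        return [base_lr * self.gamma ** decays for base_lr in self.base_lrs]

A training loop would then call scheduler.step() once per epoch after the optimizer's updates; for example, StepDecayLR(optimizer, milestones=[30, 60], gamma=0.1) cuts the rate tenfold at epochs 30 and 60.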