
optimizer.py

# -*- coding: utf-8 -*-
import copy
from abc import ABCMeta, abstractmethod
from collections.abc import Iterable
from typing import Dict
from typing import Iterable as Iter
from typing import Union

import numpy as np

from ..core._imperative_rt.core2 import (
    get_auto_format_convert,
    pop_scope,
    push_scope,
    set_auto_format_convert,
    set_option,
)
from ..core.tensor.utils import set_convert_inputs
from ..tensor import Parameter, Tensor
from ..utils.deprecation import deprecated


class _RequiredParameter:
    def __repr__(self):
        return "<required parameter>"


required = _RequiredParameter()


class Optimizer(metaclass=ABCMeta):
    r"""Base class for all optimizers.

    Args:
        params: specifies what Tensors should be optimized.
        defaults: a dict of default parameters of Optimizer, like learning rate or momentum.
    """

    def __init__(  # pylint: disable=too-many-branches
        self, params: Union[Iter[Parameter], dict], defaults: dict,
    ):
        self._state = dict()
        self._defaults = defaults
        self._disable_type_convert = False

        if isinstance(params, (Parameter, dict)):
            params = [params]
        else:
            if not isinstance(params, Iterable):
                raise TypeError(
                    "params argument given to the optimizer should be "
                    "Parameter or dict, or Iterable of them"
                )

        self.param_groups = []  # type: list

        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")

        param_type = type(param_groups[0])
        for param in param_groups:
            if not isinstance(param, param_type):
                raise TypeError(
                    "types of params argument given to the optimizer should be same"
                )

        if not isinstance(param_groups[0], dict):
            param_groups = [{"params": param_groups}]

        for group in param_groups:
            self.add_param_group(group)

        for group in self.param_groups:
            self._create_state(group)

    def add_param_group(self, param_group: dict):
        r"""Add a param group to ``param_groups`` of the :class:`~megengine.optim.optimizer.Optimizer`.

        This can be useful when fine-tuning a pre-trained network, as frozen layers can be made
        trainable and added to the :class:`~megengine.optim.optimizer.Optimizer` as training progresses.

        Args:
            param_group: specifies what tensors should be optimized along with group-specific options.
        """
        assert isinstance(param_group, dict), "param group must be a dict"

        if isinstance(param_group["params"], Parameter):
            param_group["params"] = [param_group["params"]]
        else:
            param_group["params"] = list(param_group["params"])

        for param in param_group["params"]:
            if not isinstance(param, Parameter):
                raise TypeError(
                    "optimizer can only optimize Parameters, but one of the params is "
                    + str(type(param))
                )
            param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))

        for name, default in self._defaults.items():
            if default is required and name not in param_group:
                raise ValueError(
                    "parameter group didn't specify a value of "
                    "required optimization parameter " + name
                )
            param_group.setdefault(name, default)

        param_set = set()
        for group in self.param_groups:
            param_set.update(set(map(id, group["params"])))

        assert param_set.isdisjoint(
            set(map(id, param_group["params"]))
        ), "some parameters appear in more than one parameter group"

        self.param_groups.append(param_group)
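
    # Illustrative usage sketch (not part of this module): when fine-tuning, a
    # frozen block can later be registered with its own learning rate. `SGD`,
    # `model` and the learning-rate values below are hypothetical placeholders.
    #
    #     opt = SGD(model.classifier.parameters(), lr=0.1)
    #     ...  # later, once the backbone should start training
    #     opt.add_param_group({"params": model.backbone.parameters(), "lr": 0.01})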

    def _add_state(self, param, state_name, initializer=None):
        if initializer is None:
            initializer = np.zeros(param.shape, dtype=np.float32)
        state_dict = self._state.setdefault(param, {})
        assert state_name not in state_dict
        state = Tensor(initializer, no_cache=True)
        state_dict[state_name] = state

    @abstractmethod
    def _create_state(self, param_group):
        pass

    @abstractmethod
    def _updates(self, param_group):
        pass

    def _get_params(self):
        params = []
        for group in self.param_groups:
            for param in group["params"]:
                params.append(param)
        return params

    def step(self):
        r"""Performs a single optimization step."""
        # set the global state `_enable_convert_inputs` to `False` to disable
        # `convert_inputs` for param updates
        set_option("record_computing_path", 0)
        _origin_auto_format = get_auto_format_convert()
        set_auto_format_convert(False)
        if self._disable_type_convert:
            backup = set_convert_inputs(False)
        for group in self.param_groups:
            if isinstance(group["params"], set):
                raise TypeError(
                    "optimized parameters need to be organized in ordered collections, "
                    "but the ordering of parameters in sets will change between runs. "
                    "Please use a list instead."
                )
            push_scope("step")
            self._updates(group)
            pop_scope("step")
        if self._disable_type_convert:
            # restore the global state `_enable_convert_inputs`
            set_convert_inputs(backup)
        set_option("record_computing_path", 1)
        set_auto_format_convert(_origin_auto_format)
        return self

    @deprecated(version="1.0", reason="use clear_grad instead")
    def zero_grad(self):
        for param_group in self.param_groups:
            for param in param_group["params"]:
                if param.grad is not None:
                    param.grad.reset_zero()

    def clear_grad(self):
        r"""Set the grad attribute to None for all parameters."""
        for param_group in self.param_groups:
            push_scope("clear_grad")
            for param in param_group["params"]:
                param.grad = None
            pop_scope("clear_grad")
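
    # Illustrative usage sketch: `step` and `clear_grad` are meant to be paired
    # with a gradient engine such as `autodiff.GradManager`, which populates
    # `param.grad` during its backward pass. `model`, `loader` and `loss_fn`
    # are hypothetical placeholders.
    #
    #     gm = autodiff.GradManager()
    #     gm.attach(model.parameters())
    #     opt = SGD(model.parameters(), lr=0.01)
    #     for data, label in loader:
    #         with gm:
    #             loss = loss_fn(model(data), label)
    #             gm.backward(loss)
    #         opt.step().clear_grad()  # step() returns self, so the calls chain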

    def state_dict(self, keep_var=False) -> Dict:
        r"""Export the optimizer state.

        Return:
            optimizer state. Can be loaded by :meth:`load_state_dict`.
        """
        param_groups = []
        state = dict()
        param2id = dict()

        cur_id = 0
        for group in self.param_groups:
            for param in group["params"]:
                if param not in param2id:
                    param2id[param] = cur_id
                    cur_id += 1

        for param, st in self._state.items():
            _st = copy.copy(st)
            if not keep_var:
                for k, v in st.items():
                    _st[k] = v.numpy()
            state[param2id[param]] = _st

        for group in self.param_groups:
            param_group = {k: v for k, v in group.items() if k != "params"}
            param_group["params"] = [param2id[param] for param in group["params"]]
            param_groups.append(param_group)

        return {"param_groups": param_groups, "state": state}

    def load_state_dict(self, state: dict):
        r"""Loads the optimizer state.

        Args:
            state: optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        if len(self.param_groups) != len(state["param_groups"]):
            raise ValueError(
                "loaded state dict has a different number of parameter groups"
            )
        for group_new, group_saved in zip(self.param_groups, state["param_groups"]):
            if len(group_new["params"]) != len(group_saved["params"]):
                raise ValueError(
                    "loaded state dict contains a parameter group that "
                    "doesn't match the size of optimizer's group"
                )
            for param_new, param_saved in zip(
                group_new["params"], group_saved["params"]
            ):
                p = param_new
                self._state[p] = state["state"][param_saved].copy()
                for k, v in self._state[p].items():
                    if isinstance(v, Tensor):
                        self._state[p][k] = v.detach()
                    else:
                        self._state[p][k] = Tensor(v)

            if set(group_new.keys()) != set(group_saved.keys()):
                raise ValueError(
                    "loaded state dict contains a parameter group that "
                    "doesn't match the keys of optimizer's group"
                )
            for key in group_new.keys():
                if key != "params":
                    group_new[key] = group_saved[key]

        if len(self._state.keys()) != len(state["state"].keys()):
            raise ValueError(
                "loaded state dict contains a state that doesn't match "
                "the size of optimizer's state"
            )
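
    # Illustrative usage sketch: checkpointing an optimizer next to its model.
    # `megengine.save`/`megengine.load` and the file name are assumptions here;
    # any serialization that round-trips the dict returned by `state_dict`
    # works the same way.
    #
    #     checkpoint = {"model": model.state_dict(), "opt": opt.state_dict()}
    #     megengine.save(checkpoint, "checkpoint.pkl")
    #     ...
    #     checkpoint = megengine.load("checkpoint.pkl")
    #     opt.load_state_dict(checkpoint["opt"])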

    def backward(self, loss):
        raise NotImplementedError("use autodiff.GradManager instead")

    def bcast_param(self):
        raise NotImplementedError("use distributed.bcast_list_ instead")
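

# A minimal concrete subclass, included purely to illustrate the abstract
# interface above; it is not one of MegEngine's shipped optimizers. Real
# optimizers such as megengine.optimizer.SGD use dedicated in-place update
# kernels, so the plain `param - lr * grad` arithmetic below is a
# simplification under that assumption.
class _ExampleSGD(Optimizer):
    def __init__(self, params, lr: float):
        super().__init__(params, {"lr": lr})

    def _create_state(self, param_group):
        # vanilla SGD keeps no per-parameter state; a momentum variant would
        # call self._add_state(param, "momentum_buffer") for each param here
        pass

    def _updates(self, param_group):
        lr = param_group["lr"]
        for param in param_group["params"]:
            if param.grad is None:
                continue
            # write the updated value back into the existing Parameter
            param._reset(param - lr * param.grad)


# With this sketch, the loop commented under `clear_grad` applies unchanged:
# construct `_ExampleSGD(model.parameters(), lr=0.01)`, populate gradients,
# then call `step()` followed by `clear_grad()`.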