You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

optimizer.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from abc import ABCMeta, abstractmethod
  10. from collections import Iterable
  11. from contextlib import contextmanager
  12. from typing import Dict
  13. from typing import Iterable as Iter
  14. from typing import Set, Union
  15. import numpy as np
  16. from ..core.autodiff.grad import Grad
  17. from ..device import get_default_device
  18. from ..distributed.group import get_client, is_distributed
  19. from ..functional import add_update
  20. from ..functional.distributed import all_reduce_sum, broadcast
  21. from ..functional.utils import copy
  22. from ..logger import get_logger
  23. from ..tensor import Tensor, TensorDict
  24. from ..tensor_nn import Buffer, Parameter
  25. logger = get_logger(__name__)
  26. class _RequiredParameter:
  27. def __repr__(self):
  28. return "<required parameter>"
  29. required = _RequiredParameter()
  30. class Optimizer(metaclass=ABCMeta):
  31. r"""Base class for all optimizers.
  32. :param params: specifies what Tensors should be optimized.
  33. :param defaults: a dict of default parameters of Optimizer, like learning rate or momentum.
  34. """
  35. _recording = None
  36. _grad = None
  37. _gradients = None
  38. def __init__( # pylint: disable=too-many-branches
  39. self, params: Union[Iter[Parameter], dict], defaults: dict,
  40. ):
  41. self._state = TensorDict()
  42. self._defaults = defaults
  43. if isinstance(params, (Parameter, dict)):
  44. params = [params]
  45. else:
  46. if not isinstance(params, Iterable):
  47. raise TypeError(
  48. "params argument given to the optimizer should be "
  49. "Parameter or dict, or Iterable of them"
  50. )
  51. self.param_groups = [] # type: list
  52. self.save_load_state_ignore_keys = set()
  53. param_groups = list(params)
  54. if len(param_groups) == 0:
  55. raise ValueError("optimizer got an empty parameter list")
  56. param_type = type(param_groups[0])
  57. for param in param_groups:
  58. if not isinstance(param, param_type):
  59. raise TypeError(
  60. "types of params argument given to the optimizer shoud be same"
  61. )
  62. if not isinstance(param_groups[0], dict):
  63. param_groups = [{"params": param_groups}]
  64. for group in param_groups:
  65. self.add_param_group(group)
  66. for group in self.param_groups:
  67. self._create_state(group)
  68. def add_param_group(self, param_group: dict):
  69. r"""Add a param group to ``param_groups`` of the :class:`~megengine.optim.optimizer.Optimizer`.
  70. This can be useful when fine tuning a pre-trained network as frozen layers can be made
  71. trainable and added to the :class:`~megengine.optim.optimizer.Optimizer` as training progresses.
  72. :param param_group: specifies what tensors should be optimized along with group.
  73. """
  74. assert isinstance(param_group, dict), "param group must be a dict"
  75. if isinstance(param_group["params"], Parameter):
  76. param_group["params"] = [param_group["params"]]
  77. else:
  78. param_group["params"] = list(param_group["params"])
  79. for param in param_group["params"]:
  80. if not isinstance(param, Parameter):
  81. raise TypeError(
  82. "optimizer can only optimize Parameters, but one of the params is "
  83. + type(param)
  84. )
  85. if not param.requires_grad:
  86. raise ValueError(
  87. "optimizer can only optimize Parameters with requires_grad=True"
  88. )
  89. for name, default in self._defaults.items():
  90. if default is required and name not in param_group:
  91. raise ValueError(
  92. "parameter group didn't specify a value of "
  93. "required optimization parameter " + name
  94. )
  95. param_group.setdefault(name, default)
  96. param_set = set()
  97. for group in self.param_groups:
  98. param_set.update(set(map(id, group["params"])))
  99. assert param_set.isdisjoint(
  100. set(map(id, param_group["params"]))
  101. ), "some parameters appear in more than one parameter group"
  102. self.param_groups.append(param_group)
  103. def _add_state(self, param, state_name, initializer=None):
  104. if initializer is None:
  105. initializer = np.zeros(param.shape, dtype=np.float32)
  106. state_dict = self._state.setdefault(param, {})
  107. assert state_name not in state_dict
  108. state = Buffer(initializer)
  109. state_dict[state_name] = state
  110. @abstractmethod
  111. def _create_state(self, param_group):
  112. pass
  113. @abstractmethod
  114. def _updates(self, param_group):
  115. pass
  116. def _get_params(self):
  117. params = []
  118. for group in self.param_groups:
  119. for param in group["params"]:
  120. params.append(param)
  121. return params
  122. def grad_callback(self, grad, i, group):
  123. pass
  124. def record(self):
  125. @contextmanager
  126. def recorder():
  127. params = self._get_params()
  128. grad = Grad()
  129. gradients = [None] * len(params)
  130. if self._recording:
  131. raise RuntimeError("already recording!")
  132. try:
  133. self._recording = True
  134. self._grad = grad
  135. for group in self.param_groups:
  136. group["grads"] = [None] * len(group["params"])
  137. for i, param in enumerate(group["params"]):
  138. def callback(tensor, grad, i=i, group=group, self=self):
  139. group["grads"][i] = grad
  140. self.grad_callback(grad, i, group)
  141. grad.wrt(param, callback=callback)
  142. with grad:
  143. yield
  144. finally:
  145. self._recording = False
  146. self._grad = None
  147. for group in self.param_groups:
  148. group["grads"] = []
  149. return recorder()
  150. def _calculate_gradients(self, loss: Tensor):
  151. if not self._recording:
  152. raise RuntimeError(
  153. "no computation history. "
  154. "did you forget record() or "
  155. "call a method that clears the history?"
  156. )
  157. assert self._grad is not None
  158. if len(loss.__wrapped__._extra_data) == 0: # in case loss depends on no tensor
  159. self._grad = None
  160. return
  161. one = Tensor([1.0], dtype=loss.dtype, device=loss.device)
  162. one = one.reshape(loss.shape)
  163. try:
  164. self._grad(loss, one)
  165. finally:
  166. self._grad = None
  167. def minimize(self, loss: Tensor):
  168. self.backward(loss)
  169. self.step()
  170. def backward(self, loss: Tensor):
  171. """Computes the back-propagation of the network given loss.
  172. :param loss: The obtained loss tensor
  173. """
  174. rst = []
  175. self._calculate_gradients(loss)
  176. # _grad_skip records the parameters which are not in the path of backward
  177. self._grad_skip = set()
  178. for group in self.param_groups:
  179. # _grad_skip is consumed in optimizer.step()
  180. # XXX: assumptions
  181. # 1. Assume the same execution sequence for all GPUs in data parallel
  182. # 2. If backward is called by multiple times to accumulate grad,
  183. # it's also assumed same _grad_skip for all backward() calls
  184. # Please change the code if any assumption is invalid
  185. for param, grad in zip(group["params"], group["grads"]):
  186. if grad is None:
  187. self._grad_skip.add(param.__wrapped__)
  188. continue
  189. grad = Buffer(grad)
  190. if getattr(param, "grad", None) is None:
  191. param.grad = grad
  192. else:
  193. assert isinstance(param.grad, Buffer)
  194. param.grad += grad
  195. rst.append(param.grad)
  196. if len(self._grad_skip) > 0:
  197. get_logger(__name__).warning(
  198. "{} parameters have no grad! "
  199. "Make sure you pass the right parameters list".format(
  200. len(self._grad_skip)
  201. )
  202. )
  203. return rst
  204. def step(self):
  205. r"""Performs a single optimization step.
  206. """
  207. for group in self.param_groups:
  208. if isinstance(group["params"], set):
  209. raise TypeError(
  210. "optimized parameters need to be organized in ordered collections, "
  211. "but the ordering of parameters in sets will change between runs. "
  212. "Please use a list instead."
  213. )
  214. self._updates(group)
  215. def zero_grad(self):
  216. r"""Reset the grad to zeros.
  217. """
  218. for param_group in self.param_groups:
  219. for param in param_group["params"]:
  220. if getattr(param, "grad", None) is not None:
  221. param.grad = None
  222. def add_save_load_state_ignore_keys(self, keys: Set[str]):
  223. self.save_load_state_ignore_keys |= keys
  224. def state_dict(self) -> Dict:
  225. r"""Export the optimizer state.
  226. :return: optimizer state. Can be loaded by :meth:`load_state_dict`.
  227. """
  228. param_groups = []
  229. state = dict()
  230. param2id = TensorDict()
  231. cur_id = 0
  232. for group in self.param_groups:
  233. for param in group["params"]:
  234. if param not in param2id:
  235. param2id[param] = cur_id
  236. cur_id += 1
  237. for param, st in self._state.items():
  238. state[param2id[param]] = st
  239. for group in self.param_groups:
  240. param_group = {
  241. k: v
  242. for k, v in group.items()
  243. if k != "params" and k not in self.save_load_state_ignore_keys
  244. }
  245. param_group["params"] = [param2id[param] for param in group["params"]]
  246. param_groups.append(param_group)
  247. return {"param_groups": param_groups, "state": state}
  248. def load_state_dict(self, state: dict):
  249. r"""Loads the optimizer state.
  250. :param state: optimizer state. Should be an object returned
  251. from a call to :meth:`state_dict`.
  252. """
  253. if len(self.param_groups) != len(state["param_groups"]):
  254. raise ValueError(
  255. "loaded state dict has a different number of parameter groups"
  256. )
  257. parameter_map = dict() # type: Dict
  258. for group_new, group_saved in zip(self.param_groups, state["param_groups"]):
  259. if len(group_new["params"]) != len(group_saved["params"]):
  260. raise ValueError(
  261. "loaded state dict contains a parameter group that "
  262. "doesn't match the size of optimizer's group"
  263. )
  264. for param_new, param_saved in zip(
  265. group_new["params"], group_saved["params"]
  266. ):
  267. p = param_new
  268. self._state[p] = state["state"][param_saved].copy()
  269. for k, v in self._state[p].items():
  270. if isinstance(v, Buffer):
  271. self._state[p][k] = Buffer(v.numpy())
  272. new_keys = set(group_new.keys()) - self.save_load_state_ignore_keys
  273. saved_keys = set(group_saved.keys()) - self.save_load_state_ignore_keys
  274. if new_keys != saved_keys:
  275. raise ValueError(
  276. "loaded state dict contains a parameter group that "
  277. "doesn't match the keys of optimizer's group"
  278. )
  279. for key in saved_keys:
  280. if key != "params":
  281. group_new[key] = group_saved[key]
  282. if len(self._state.keys()) != len(state["state"].keys()):
  283. raise ValueError(
  284. "loaded state dict contains a state that doesn't match "
  285. "the size of optimizer's state"
  286. )

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台