
batchnorm.py 12 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional

import numpy as np

from ..distributed.group import WORLD, Group
from ..functional.nn import batch_norm, sync_batch_norm
from ..tensor import Parameter, Tensor
from . import init
from .module import Module


class _BatchNorm(Module):
    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        freeze=False,
        compute_mode="default",
        **kwargs
    ):
        super(_BatchNorm, self).__init__(**kwargs)
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        self._track_running_stats_saved = track_running_stats
        self.freeze = freeze
        self.compute_mode = compute_mode
        if self.freeze:
            assert (
                self._track_running_stats_saved
            ), "track_running_stats must be initialized to True if freeze is True"
        tshape = (1, self.num_features, 1, 1)
        if self.affine:
            self.weight = Parameter(np.ones(tshape, dtype=np.float32))
            self.bias = Parameter(np.zeros(tshape, dtype=np.float32))
        else:
            self.weight = None
            self.bias = None
        if self.track_running_stats:
            self.running_mean = Tensor(np.zeros(tshape, dtype=np.float32))
            self.running_var = Tensor(np.ones(tshape, dtype=np.float32))
        else:
            self.running_mean = None
            self.running_var = None

    def reset_running_stats(self) -> None:
        if self.track_running_stats:
            init.zeros_(self.running_mean)
            init.ones_(self.running_var)

    def reset_parameters(self) -> None:
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_ndim(self, inp):
        raise NotImplementedError

    def forward(self, inp):
        self._check_input_ndim(inp)
        if not self._track_running_stats_saved:
            assert (
                not self.track_running_stats
            ), "track_running_stats can not be initialized to False and changed to True later"

        # Reshape non-4D inputs to NCHW so the underlying kernel always sees a
        # 4D tensor; the output is reshaped back to the original shape below.
        inp_shape = inp.shape
        _ndims = len(inp_shape)
        if _ndims != 4:
            origin_shape = inp_shape
            if _ndims == 2:
                n, c = inp_shape[0], inp_shape[1]
                new_shape = (n, c, 1, 1)
            elif _ndims == 3:
                n, c, h = inp_shape[0], inp_shape[1], inp_shape[2]
                new_shape = (n, c, h, 1)
            inp = inp.reshape(new_shape)

        _weight = self.weight
        _bias = self.bias

        if self.freeze:
            if _weight is not None:
                _weight = _weight.detach()
            if _bias is not None:
                _bias = _bias.detach()

            # Need to expand to elementwise operations here,
            # see MGB_IMPL_OPR_GRAD(BatchNormForward) in src/opr/impl/dnn/batch_norm.cpp
            scale = (self.running_var + self.eps) ** (-0.5)
            if _weight is not None:
                scale *= _weight
            bias = -self.running_mean * scale
            if _bias is not None:
                bias += _bias
            output = inp * scale + bias
            if _ndims != 4:
                # restore the original shape, mirroring the non-freeze path
                output = output.reshape(origin_shape)
            return output

        if self.training and self.track_running_stats:
            exponential_average_factor = self.momentum
        else:
            exponential_average_factor = 0.0  # unused when running stats are not updated

        output = batch_norm(
            inp,
            self.running_mean if self.track_running_stats else None,
            self.running_var if self.track_running_stats else None,
            _weight,
            _bias,
            training=self.training
            or ((self.running_mean is None) and (self.running_var is None)),
            momentum=exponential_average_factor,
            eps=self.eps,
            compute_mode=self.compute_mode,
        )

        if _ndims != 4:
            output = output.reshape(origin_shape)

        return output

    def _module_info_string(self) -> str:
        s = (
            "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
            "track_running_stats={track_running_stats}"
        )
        return s.format(**self.__dict__)
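

# A minimal sketch, not part of the original module: the freeze branch in
# _BatchNorm.forward folds batch normalization into a single elementwise
# affine transform. The helper name ``_folded_affine`` is an assumption for
# illustration; it reproduces that folding for plain NumPy arrays or tensors
# and can be used to check the frozen fast path by hand.
def _folded_affine(x, running_mean, running_var, weight=None, bias=None, eps=1e-5):
    # y = (x - mean) / sqrt(var + eps) * weight + bias  ==  x * scale + shift
    scale = (running_var + eps) ** (-0.5)
    if weight is not None:
        scale = scale * weight
    shift = -running_mean * scale
    if bias is not None:
        shift = shift + bias
    return x * scale + shift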


class SyncBatchNorm(_BatchNorm):
    r"""
    Applies Synchronized Batch Normalization for distributed training.

    :type num_features: int
    :param num_features: usually :math:`C` from an input of shape
        :math:`(N, C, H, W)` or the highest ranked dimension of an input
        less than 4D.
    :type eps: float
    :param eps: a value added to the denominator for numerical stability.
        Default: 1e-5
    :type momentum: float
    :param momentum: the value used for the ``running_mean`` and ``running_var`` computation.
        Default: 0.9
    :type affine: bool
    :param affine: a boolean value that when set to True, this module has
        learnable affine parameters. Default: True
    :type track_running_stats: bool
    :param track_running_stats: when set to True, this module tracks the
        running mean and variance. When set to False, this module does not
        track such statistics and always uses batch statistics in both training
        and eval modes. Default: True
    :type freeze: bool
    :param freeze: when set to True, this module does not update the
        running mean and variance, and uses the running mean and variance instead of
        the batch mean and batch variance to normalize the input. The parameter takes effect
        only when the module is initialized with track_running_stats as True.
        Default: False
    :type group: :class:`~megengine.distributed.Group`
    :param group: communication group over which the mean and variance are calculated.
        Default: :obj:`~megengine.distributed.WORLD`
    :return: output tensor.
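
    Example (an illustrative sketch, not a runnable doctest; it assumes the
    process has been started under a distributed launcher such as
    :func:`~megengine.distributed.launcher` so that ``WORLD`` is a valid group):

    .. code-block:: python

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.SyncBatchNorm(4)  # mean/variance are synchronized across the group
        inp = mge.tensor(np.random.rand(1, 4, 3, 3).astype("float32"))
        oup = m(inp)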
  154. """
  155. def __init__(
  156. self,
  157. num_features,
  158. eps=1e-5,
  159. momentum=0.9,
  160. affine=True,
  161. track_running_stats=True,
  162. freeze=False,
  163. group: Optional[Group] = WORLD,
  164. **kwargs
  165. ) -> None:
  166. super().__init__(
  167. num_features, eps, momentum, affine, track_running_stats, freeze, **kwargs
  168. )
  169. self.group = group
  170. def _check_input_ndim(self, inp):
  171. if len(inp.shape) not in {2, 3, 4}:
  172. raise ValueError(
  173. "expected 2D, 3D or 4D input (got {}D input)".format(len(inp.shape))
  174. )
  175. def forward(self, inp):
  176. self._check_input_ndim(inp)
  177. inp_shape = inp.shape
  178. _ndims = len(inp_shape)
  179. if _ndims != 4:
  180. new_shape = Tensor([1, 1, 1, 1], device=inp.device)
  181. origin_shape = inp_shape
  182. if _ndims == 2:
  183. new_shape[:2] = origin_shape[:2]
  184. elif _ndims == 3:
  185. new_shape[:3] = origin_shape[:3]
  186. else:
  187. raise ValueError(
  188. "expected 2D, 3D or 4D input (got {}D input)".format(len(inp_shape))
  189. )
  190. inp = inp.reshape(new_shape)
  191. if self.training and self.track_running_stats:
  192. exponential_average_factor = self.momentum
  193. else:
  194. exponential_average_factor = 0.0 # useless
  195. _weight = self.weight
  196. _bias = self.bias
  197. if self.freeze:
  198. if _weight is not None:
  199. _weight = _weight.detach()
  200. if _bias is not None:
  201. _bias = _bias.detach()
  202. output = sync_batch_norm(
  203. inp,
  204. self.running_mean,
  205. self.running_var,
  206. _weight,
  207. _bias,
  208. training=(self.training and not self.freeze)
  209. or ((self.running_mean is None) and (self.running_var is None)),
  210. momentum=exponential_average_factor,
  211. eps=self.eps,
  212. group=self.group,
  213. )
  214. if _ndims != 4:
  215. output = output.reshape(origin_shape)
  216. return output


class BatchNorm1d(_BatchNorm):
    r"""
    Applies Batch Normalization over a 2D/3D tensor.

    Refer to :class:`~.BatchNorm2d` for more information.
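
    Example (a minimal sketch added for illustration, mirroring the
    :class:`~.BatchNorm2d` doctest; a 2D ``(N, C)`` input is internally
    reshaped to ``(N, C, 1, 1)`` and restored on output):

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.BatchNorm1d(4)
        inp = mge.tensor(np.random.rand(2, 4).astype("float32"))
        oup = m(inp)
        print(oup.numpy().shape)

    Outputs:

    .. testoutput::

        (2, 4)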
  221. """
  222. def _check_input_ndim(self, inp):
  223. if len(inp.shape) not in {2, 3}:
  224. raise ValueError(
  225. "expected 2D or 3D input (got {}D input)".format(len(inp.shape))
  226. )


class BatchNorm2d(_BatchNorm):
    r"""
    Applies Batch Normalization over a 4D tensor.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable
    parameter vectors.

    By default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.9.

    If :attr:`track_running_stats` is set to ``False``, this layer will not
    keep running estimates, and batch statistics are used during
    evaluation time instead.

    Because the Batch Normalization is done over the `C` dimension, computing
    statistics on `(N, H, W)` slices, it's common terminology to call this
    Spatial Batch Normalization.

    :type num_features: int
    :param num_features: usually :math:`C` from an input of shape
        :math:`(N, C, H, W)` or the highest ranked dimension of an input
        less than 4D.
    :type eps: float
    :param eps: a value added to the denominator for numerical stability.
        Default: 1e-5
    :type momentum: float
    :param momentum: the value used for the ``running_mean`` and ``running_var`` computation.
        Default: 0.9
    :type affine: bool
    :param affine: a boolean value that when set to True, this module has
        learnable affine parameters. Default: True
    :type track_running_stats: bool
    :param track_running_stats: when set to True, this module tracks the
        running mean and variance. When set to False, this module does not
        track such statistics and always uses batch statistics in both training
        and eval modes. Default: True
    :type freeze: bool
    :param freeze: when set to True, this module does not update the
        running mean and variance, and uses the running mean and variance instead of
        the batch mean and batch variance to normalize the input. The parameter takes effect
        only when the module is initialized with track_running_stats as True.
        Default: False

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        # With Learnable Parameters
        m = M.BatchNorm2d(4)
        inp = mge.tensor(np.random.rand(1, 4, 3, 3).astype("float32"))
        oup = m(inp)
        print(m.weight.numpy().flatten(), m.bias.numpy().flatten())

        # Without Learnable Parameters
        m = M.BatchNorm2d(4, affine=False)
        oup = m(inp)
        print(m.weight, m.bias)

    Outputs:

    .. testoutput::

        [1. 1. 1. 1.] [0. 0. 0. 0.]
        None None
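
    A further sketch (added for illustration): in ``eval`` mode the tracked
    running statistics are used instead of batch statistics, so with the
    initial ``running_mean`` of 0 and ``running_var`` of 1 the output is
    :math:`x / \sqrt{1 + \epsilon}`:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.BatchNorm2d(4)
        m.eval()
        inp = mge.tensor(np.ones((1, 4, 3, 3), dtype="float32"))
        oup = m(inp)
        print(np.allclose(oup.numpy(), inp.numpy() / np.sqrt(1 + 1e-5)))

    Outputs:

    .. testoutput::

        True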
  287. """
  288. def _check_input_ndim(self, inp):
  289. if len(inp.shape) != 4:
  290. raise ValueError("expected 4D input (got {}D input)".format(len(inp.shape)))
