
batchnorm.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional

import numpy as np

from ..distributed.group import WORLD, Group
from ..functional.nn import batch_norm, sync_batch_norm
from ..tensor import Parameter, Tensor
from . import init
from .module import Module


class _BatchNorm(Module):
    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        freeze=False,
    ):
        super(_BatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        self._track_running_stats_saved = track_running_stats
        self.freeze = freeze
        if self.affine:
            self.weight = Parameter(np.ones(num_features, dtype=np.float32))
            self.bias = Parameter(np.zeros(num_features, dtype=np.float32))
        else:
            self.weight = None
            self.bias = None

        tshape = (1, self.num_features, 1, 1)
        if self.track_running_stats:
            self.running_mean = Tensor(np.zeros(tshape, dtype=np.float32))
            self.running_var = Tensor(np.ones(tshape, dtype=np.float32))
        else:
            self.running_mean = None
            self.running_var = None

    def reset_running_stats(self) -> None:
        if self.track_running_stats:
            init.zeros_(self.running_mean)
            init.ones_(self.running_var)

    def reset_parameters(self) -> None:
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_ndim(self, inp):
        raise NotImplementedError

    def forward(self, inp):
        self._check_input_ndim(inp)
        if self._track_running_stats_saved == False:
            assert (
                self.track_running_stats == False
            ), "track_running_stats can not be initialized to False and changed to True later"

        # batch_norm operates on 4D NCHW tensors; temporarily pad 2D/3D inputs to 4D
        inp_shape = inp.shape
        _ndims = len(inp_shape)
        if _ndims != 4:
            origin_shape = inp_shape
            if _ndims == 2:
                n, c = inp_shape[0], inp_shape[1]
                new_shape = (n, c, 1, 1)
            elif _ndims == 3:
                n, c, h = inp_shape[0], inp_shape[1], inp_shape[2]
                new_shape = (n, c, h, 1)

            inp = inp.reshape(new_shape)

        if self.freeze and self.training and self._track_running_stats_saved:
            # freeze: normalize with the stored running statistics and do not update them
            scale = self.weight.reshape(1, -1, 1, 1) * (
                self.running_var + self.eps
            ) ** (-0.5)
            bias = self.bias.reshape(1, -1, 1, 1) - self.running_mean * scale
            return inp * scale.detach() + bias.detach()

        if self.training and self.track_running_stats:
            exponential_average_factor = self.momentum
        else:
            exponential_average_factor = 0.0  # useless

        output = batch_norm(
            inp,
            self.running_mean if self.track_running_stats else None,
            self.running_var if self.track_running_stats else None,
            self.weight,
            self.bias,
            training=self.training
            or ((self.running_mean is None) and (self.running_var is None)),
            momentum=exponential_average_factor,
            eps=self.eps,
        )

        if _ndims != 4:
            output = output.reshape(origin_shape)

        return output

    def _module_info_string(self) -> str:
        s = (
            "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
            "track_running_stats={track_running_stats}"
        )
        return s.format(**self.__dict__)


class SyncBatchNorm(_BatchNorm):
    r"""
    Applies Synchronized Batch Normalization.
    """

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        freeze=False,
        group: Optional[Group] = WORLD,
    ) -> None:
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, freeze
        )
        self.group = group

    def _check_input_ndim(self, inp):
        if len(inp.shape) not in {2, 3, 4}:
            raise ValueError(
                "expected 2D, 3D or 4D input (got {}D input)".format(len(inp.shape))
            )

    def forward(self, inp):
        self._check_input_ndim(inp)

        inp_shape = inp.shape
        _ndims = len(inp_shape)
        if _ndims != 4:
            # pad 2D/3D inputs to 4D before normalizing
            new_shape = Tensor([1, 1, 1, 1], device=inp.device)
            origin_shape = inp_shape
            if _ndims == 2:
                new_shape[:2] = origin_shape[:2]
            elif _ndims == 3:
                new_shape[:3] = origin_shape[:3]
            else:
                raise ValueError(
                    "expected 2D, 3D or 4D input (got {}D input)".format(len(inp_shape))
                )
            inp = inp.reshape(new_shape)

        if self.training and self.track_running_stats:
            exponential_average_factor = self.momentum
        else:
            exponential_average_factor = 0.0  # useless

        output = sync_batch_norm(
            inp,
            self.running_mean,
            self.running_var,
            self.weight,
            self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor,
            self.eps,
            group=self.group,
        )

        if _ndims != 4:
            output = output.reshape(origin_shape)

        return output


class BatchNorm1d(_BatchNorm):
    r"""
    Applies Batch Normalization over a 2D/3D tensor.

    Refer to :class:`~.BatchNorm2d` for more information.
    """

    def _check_input_ndim(self, inp):
        if len(inp.shape) not in {2, 3}:
            raise ValueError(
                "expected 2D or 3D input (got {}D input)".format(len(inp.shape))
            )


class BatchNorm2d(_BatchNorm):
    r"""
    Applies Batch Normalization over a 4D tensor.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches, and :math:`\gamma` and :math:`\beta` are learnable
    parameter vectors.

    By default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.9.

    If :attr:`track_running_stats` is set to ``False``, this layer will not
    keep running estimates, and batch statistics are used during
    evaluation time instead.

    .. note::
        This :attr:`momentum` argument is different from the one used in
        optimizer classes and the conventional notion of momentum.
        Mathematically, the update rule for running statistics here is
        :math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1 - \text{momentum}) \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because Batch Normalization is done over the `C` dimension, computing
    statistics on the `(N, H, W)` slices, it is common terminology to call this
    Spatial Batch Normalization.

    :type num_features: int
    :param num_features: usually :math:`C` from an input of shape
        :math:`(N, C, H, W)` or the highest ranked dimension of an input
        less than 4D.
    :type eps: float
    :param eps: a value added to the denominator for numerical stability.
        Default: 1e-5
    :type momentum: float
    :param momentum: the value used for the ``running_mean`` and ``running_var``
        computation. Default: 0.9
    :type affine: bool
    :param affine: a boolean value that when set to True, this module has
        learnable affine parameters. Default: True
    :type track_running_stats: bool
    :param track_running_stats: when set to True, this module tracks the
        running mean and variance. When set to False, this module does not
        track such statistics and always uses batch statistics in both training
        and eval modes. Default: True
    :type freeze: bool
    :param freeze: when set to True, this module does not update the
        running mean and variance, and uses the running mean and variance
        instead of the batch mean and batch variance to normalize the input.
        This parameter takes effect only when the module is initialized with
        track_running_stats as True and the module is in training mode.
        Default: False

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        # With Learnable Parameters
        m = M.BatchNorm2d(4)
        inp = mge.tensor(np.random.rand(1, 4, 3, 3).astype("float32"))
        oup = m(inp)
        print(m.weight.numpy(), m.bias.numpy())
        # Without Learnable Parameters
        m = M.BatchNorm2d(4, affine=False)
        oup = m(inp)
        print(m.weight, m.bias)

    Outputs:

    .. testoutput::

        [1. 1. 1. 1.] [0. 0. 0. 0.]
        None None
    """

    def _check_input_ndim(self, inp):
        if len(inp.shape) != 4:
            raise ValueError("expected 4D input (got {}D input)".format(len(inp.shape)))
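
The note in the ``BatchNorm2d`` docstring above gives the update rule for running statistics as :math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1 - \text{momentum}) \times x_t`. Below is a minimal sketch of how that rule can be checked numerically after a single training-mode forward pass; the script is not part of batchnorm.py, assumes MegEngine is installed, and its shapes and tolerance are illustrative choices only.

    # Minimal sketch (not part of batchnorm.py): verify the documented
    # running_mean update rule after one training-mode forward pass.
    import numpy as np
    import megengine as mge
    import megengine.module as M

    momentum = 0.9
    m = M.BatchNorm2d(2, momentum=momentum)
    m.train()  # modules start in training mode; made explicit here

    inp = mge.tensor(np.random.rand(4, 2, 3, 3).astype("float32"))
    _ = m(inp)  # one forward pass updates running_mean / running_var

    batch_mean = inp.numpy().mean(axis=(0, 2, 3))  # per-channel mean over (N, H, W)
    expected = momentum * 0.0 + (1 - momentum) * batch_mean  # running_mean starts at zeros
    np.testing.assert_allclose(m.running_mean.numpy().reshape(-1), expected, rtol=1e-4)

Since ``running_mean`` is initialized to zeros, a single update should move it to ``(1 - momentum)`` times the per-channel batch mean.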

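The ``freeze`` parameter documented above can be exercised the same way: with ``freeze=True`` in training mode (and ``track_running_stats`` left at its default of True), the module normalizes with the stored running statistics and leaves them untouched. Another hedged sketch under the same assumptions:

    # Sketch (not part of batchnorm.py): freeze=True in training mode should
    # leave the running statistics unchanged by forward passes.
    import numpy as np
    import megengine as mge
    import megengine.module as M

    m = M.BatchNorm2d(3, freeze=True)
    m.train()
    before = m.running_mean.numpy().copy()

    inp = mge.tensor(np.random.rand(2, 3, 4, 4).astype("float32"))
    _ = m(inp)

    np.testing.assert_array_equal(before, m.running_mean.numpy())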