# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional

import numpy as np

from ..distributed.group import WORLD, Group
from ..functional.nn import batch_norm, sync_batch_norm
from ..tensor import Parameter, Tensor
from . import init
from .module import Module


class _BatchNorm(Module):
    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        freeze=False,
        compute_mode="default",
        param_dim="dim_1c11",
        **kwargs
    ):
        super(_BatchNorm, self).__init__(**kwargs)
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        self._track_running_stats_saved = track_running_stats
        self.freeze = freeze
        self.compute_mode = compute_mode
        self.param_dim = param_dim
        if self.freeze:
            assert (
                self._track_running_stats_saved
            ), "track_running_stats must be initialized to True if freeze is True"
        tshape = (1, self.num_features, 1, 1)
        if self.affine:
            self.weight = Parameter(np.ones(tshape, dtype=np.float32))
            self.bias = Parameter(np.zeros(tshape, dtype=np.float32))
        else:
            self.weight = None
            self.bias = None
        if self.track_running_stats:
            self.running_mean = Tensor(np.zeros(tshape, dtype=np.float32))
            self.running_var = Tensor(np.ones(tshape, dtype=np.float32))
        else:
            self.running_mean = None
            self.running_var = None
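
    # Note: parameters and running statistics are stored with the broadcastable
    # shape (1, C, 1, 1) (the default param_dim="dim_1c11" layout); e.g.
    # num_features=4 gives weight/bias/running_mean/running_var of shape (1, 4, 1, 1).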

    def reset_running_stats(self) -> None:
        if self.track_running_stats:
            init.zeros_(self.running_mean)
            init.ones_(self.running_var)

    def reset_parameters(self) -> None:
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_ndim(self, inp):
        raise NotImplementedError

    def forward(self, inp):
        self._check_input_ndim(inp)
        if self._track_running_stats_saved == False:
            assert (
                self.track_running_stats == False
            ), "track_running_stats can not be initialized to False and changed to True later"

        inp_shape = inp.shape
        _ndims = len(inp_shape)
        if _ndims != 4:
            # reshape 2D/3D inputs to 4D so the kernel can handle them; the
            # original shape is restored before returning
            origin_shape = inp_shape
            if _ndims == 2:
                n, c = inp_shape[0], inp_shape[1]
                new_shape = (n, c, 1, 1)
            elif _ndims == 3:
                n, c, h = inp_shape[0], inp_shape[1], inp_shape[2]
                new_shape = (n, c, h, 1)
            inp = inp.reshape(new_shape)

        _weight = self.weight
        _bias = self.bias

        if self.freeze:
            if _weight is not None:
                _weight = _weight.detach()
            if _bias is not None:
                _bias = _bias.detach()

            # fastpath execution for freeze
            scale = (self.running_var + self.eps) ** (-0.5)
            if _weight is not None:
                scale *= _weight
            bias = -self.running_mean * scale
            if _bias is not None:
                bias += _bias
            output = inp * scale + bias
            if _ndims != 4:
                # restore the original 2D/3D shape, mirroring the non-freeze path
                output = output.reshape(origin_shape)
            return output

        if self.training and self.track_running_stats:
            exponential_average_factor = self.momentum
        else:
            exponential_average_factor = 0.0  # useless

        output = batch_norm(
            inp,
            self.running_mean if self.track_running_stats else None,
            self.running_var if self.track_running_stats else None,
            _weight,
            _bias,
            training=self.training
            or ((self.running_mean is None) and (self.running_var is None)),
            momentum=exponential_average_factor,
            eps=self.eps,
            compute_mode=self.compute_mode,
            param_dim=self.param_dim,
        )

        if _ndims != 4:
            output = output.reshape(origin_shape)

        return output

    def _module_info_string(self) -> str:
        s = (
            "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
            "track_running_stats={track_running_stats}"
        )
        return s.format(**self.__dict__)
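

# Note on the freeze fastpath in ``_BatchNorm.forward`` above: folding the frozen
# statistics and affine parameters into a single scale/bias pair is just the
# usual inference-time normalization rewritten (a sketch of the algebra, not
# additional behaviour):
#
#     y = (x - running_mean) / sqrt(running_var + eps) * weight + bias
#       = x * scale + shift
#     scale = weight * (running_var + eps) ** -0.5
#     shift = bias - running_mean * scale
#
# so a frozen module reduces to one elementwise multiply-add.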


class SyncBatchNorm(_BatchNorm):
    r"""Applies Synchronized Batch Normalization for distributed training.

    Args:
        num_features: usually :math:`C` from an input of shape
            :math:`(N, C, H, W)` or the highest ranked dimension of an input
            less than 4D.
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the ``running_mean`` and ``running_var`` computation.
            Default: 0.9
        affine: a boolean value that when set to True, this module has
            learnable affine parameters. Default: True
        track_running_stats: when set to True, this module tracks the
            running mean and variance. When set to False, this module does not
            track such statistics and always uses batch statistics in both training
            and eval modes. Default: True
        freeze: when set to True, this module does not update the
            running mean and variance, and uses the running mean and variance instead of
            the batch mean and batch variance to normalize the input. This parameter takes effect
            only when the module is initialized with ``track_running_stats`` set to True.
            Default: False
        group: communication group over which the mean and variance are calculated.
            Default: :obj:`~.distributed.WORLD`
    """

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        freeze=False,
        group: Optional[Group] = WORLD,
        **kwargs
    ) -> None:
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, freeze, **kwargs
        )
        self.group = group

    def _check_input_ndim(self, inp):
        if len(inp.shape) not in {2, 3, 4}:
            raise ValueError(
                "expected 2D, 3D or 4D input (got {}D input)".format(len(inp.shape))
            )

    def forward(self, inp):
        self._check_input_ndim(inp)

        inp_shape = inp.shape
        _ndims = len(inp_shape)
        if _ndims != 4:
            new_shape = Tensor([1, 1, 1, 1], device=inp.device)
            origin_shape = inp_shape
            if _ndims == 2:
                new_shape[:2] = origin_shape[:2]
            elif _ndims == 3:
                new_shape[:3] = origin_shape[:3]
            else:
                raise ValueError(
                    "expected 2D, 3D or 4D input (got {}D input)".format(len(inp_shape))
                )
            inp = inp.reshape(new_shape)

        if self.training and self.track_running_stats:
            exponential_average_factor = self.momentum
        else:
            exponential_average_factor = 0.0  # useless

        _weight = self.weight
        _bias = self.bias

        if self.freeze:
            if _weight is not None:
                _weight = _weight.detach()
            if _bias is not None:
                _bias = _bias.detach()

        output = sync_batch_norm(
            inp,
            self.running_mean,
            self.running_var,
            _weight,
            _bias,
            training=(self.training and not self.freeze)
            or ((self.running_mean is None) and (self.running_var is None)),
            momentum=exponential_average_factor,
            eps=self.eps,
            group=self.group,
        )

        if _ndims != 4:
            output = output.reshape(origin_shape)

        return output
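

# A minimal ``SyncBatchNorm`` usage sketch (illustrative only; the
# ``dist.launcher`` entry point and the multi-GPU setup are assumptions about
# the surrounding training script, not part of this module):
#
#     import numpy as np
#     import megengine as mge
#     import megengine.distributed as dist
#     import megengine.module as M
#
#     @dist.launcher            # spawns one worker process per visible GPU
#     def worker():
#         m = M.SyncBatchNorm(4)        # statistics are reduced across dist.WORLD
#         inp = mge.tensor(np.random.rand(2, 4, 3, 3).astype("float32"))
#         oup = m(inp)
#
#     worker()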


class BatchNorm1d(_BatchNorm):
    r"""Applies Batch Normalization over a 2D/3D tensor.

    Refer to :class:`~.BatchNorm2d` for more information.
    """

    def _check_input_ndim(self, inp):
        if len(inp.shape) not in {2, 3}:
            raise ValueError(
                "expected 2D or 3D input (got {}D input)".format(len(inp.shape))
            )
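

# A minimal ``BatchNorm1d`` usage sketch (mirrors the ``BatchNorm2d`` example
# below; shapes and values here are purely illustrative):
#
#     import numpy as np
#     import megengine as mge
#     import megengine.module as M
#
#     m = M.BatchNorm1d(4)
#     inp = mge.tensor(np.random.rand(2, 4, 16).astype("float32"))  # (N, C, L)
#     oup = m(inp)  # normalized per channel over (N, L), same shape as inp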


class BatchNorm2d(_BatchNorm):
    r"""Applies Batch Normalization over a 4D tensor.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable
    parameter vectors.

    By default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.9.

    If :attr:`track_running_stats` is set to ``False``, this layer will not
    keep running estimates, and batch statistics are used during
    evaluation instead.

    Because the Batch Normalization is done over the ``C`` dimension, computing
    statistics on ``(N, H, W)`` slices, it's common terminology to call this
    Spatial Batch Normalization.

    .. note::

        The update formula for ``running_mean`` and ``running_var`` (taking ``running_mean`` as an example) is

        .. math::

            \textrm{running\_mean} = \textrm{momentum} \times \textrm{running\_mean} + (1 - \textrm{momentum}) \times \textrm{batch\_mean}

        which could be defined differently in other frameworks. Most notably, a ``momentum`` of 0.1 in PyTorch
        is equivalent to a ``momentum`` of 0.9 here.

    Args:
        num_features: usually :math:`C` from an input of shape
            :math:`(N, C, H, W)` or the highest ranked dimension of an input
            less than 4D.
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the ``running_mean`` and ``running_var`` computation.
            Default: 0.9
        affine: a boolean value that when set to True, this module has
            learnable affine parameters. Default: True
        track_running_stats: when set to True, this module tracks the
            running mean and variance. When set to False, this module does not
            track such statistics and always uses batch statistics in both training
            and eval modes. Default: True
        freeze: when set to True, this module does not update the
            running mean and variance, and uses the running mean and variance instead of
            the batch mean and batch variance to normalize the input. This parameter takes effect
            only when the module is initialized with ``track_running_stats`` set to True.
            Default: False

    Examples:

        .. testcode::

            import numpy as np
            import megengine as mge
            import megengine.module as M

            # With learnable parameters
            m = M.BatchNorm2d(4)
            inp = mge.tensor(np.random.rand(1, 4, 3, 3).astype("float32"))
            oup = m(inp)
            print(m.weight.numpy().flatten(), m.bias.numpy().flatten())

            # Without learnable parameters
            m = M.BatchNorm2d(4, affine=False)
            oup = m(inp)
            print(m.weight, m.bias)

        Outputs:

        .. testoutput::

            [1. 1. 1. 1.] [0. 0. 0. 0.]
            None None
    """

    def _check_input_ndim(self, inp):
        if len(inp.shape) != 4:
            raise ValueError("expected 4D input (got {}D input)".format(len(inp.shape)))
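

# A quick numeric check of the momentum convention documented above (the
# arithmetic is purely illustrative): with momentum=0.9, running_mean=0.0 and a
# batch mean of 1.0,
#
#     running_mean = 0.9 * 0.0 + (1 - 0.9) * 1.0 = 0.1
#
# i.e. each step moves the running statistic 10% of the way toward the batch
# statistic, which is what momentum=0.1 does in PyTorch.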