
batchnorm.py

# -*- coding: utf-8 -*-
from typing import Optional

import numpy as np

from ..distributed.group import WORLD, Group
from ..functional.nn import batch_norm, sync_batch_norm
from ..tensor import Parameter, Tensor
from . import init
from .module import Module


class _BatchNorm(Module):
    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        freeze=False,
        param_dim="dim_1c11",
        **kwargs
    ):
        super(_BatchNorm, self).__init__(**kwargs)
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        self._track_running_stats_saved = track_running_stats
        self.freeze = freeze
        self.param_dim = param_dim
        if self.freeze:
            assert (
                self._track_running_stats_saved
            ), "track_running_stats must be initialized to True if freeze is True"
        tshape = (1, self.num_features, 1, 1)
        if self.affine:
            self.weight = Parameter(np.ones(tshape, dtype=np.float32))
            self.bias = Parameter(np.zeros(tshape, dtype=np.float32))
        else:
            self.weight = None
            self.bias = None

        if self.track_running_stats:
            self.running_mean = Tensor(np.zeros(tshape, dtype=np.float32))
            self.running_var = Tensor(np.ones(tshape, dtype=np.float32))
        else:
            self.running_mean = None
            self.running_var = None

    def reset_running_stats(self) -> None:
        if self.track_running_stats:
            init.zeros_(self.running_mean)
            init.ones_(self.running_var)

    def reset_parameters(self) -> None:
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_ndim(self, inp):
        raise NotImplementedError

    def forward(self, inp):
        self._check_input_ndim(inp)
        if self._track_running_stats_saved == False:
            assert (
                self.track_running_stats == False
            ), "track_running_stats can not be initialized to False and changed to True later"

        _weight = self.weight
        _bias = self.bias

        if self.freeze:
            if _weight is not None:
                _weight = _weight.detach()
            if _bias is not None:
                _bias = _bias.detach()

            # fastpath execution for freeze: fold the running statistics and the
            # affine parameters into a single affine transform (see the
            # illustrative check after this class)
            scale = (self.running_var + self.eps) ** (-0.5)
            if _weight is not None:
                scale *= _weight
            bias = -self.running_mean * scale
            if _bias is not None:
                bias += _bias
            return inp * scale + bias

        if self.training and self.track_running_stats:
            exponential_average_factor = self.momentum
        else:
            exponential_average_factor = 0.0  # unused: running statistics are not updated

        output = batch_norm(
            inp,
            self.running_mean if self.track_running_stats else None,
            self.running_var if self.track_running_stats else None,
            _weight,
            _bias,
            training=self.training
            or ((self.running_mean is None) and (self.running_var is None)),
            momentum=exponential_average_factor,
            eps=self.eps,
            param_dim=self.param_dim,
        )

        return output

    def _module_info_string(self) -> str:
        s = (
            "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
            "track_running_stats={track_running_stats}"
        )
        return s.format(**self.__dict__)
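

# The ``freeze`` fastpath in ``_BatchNorm.forward`` above folds the running
# statistics and the affine parameters into a single ``inp * scale + bias``
# transform. The helper below is an illustrative check of that algebra in plain
# numpy; it is not part of the original module and only reuses the ``np``
# import at the top of this file.
def _freeze_fastpath_demo():
    eps = 1e-5
    x = np.random.rand(2, 3, 4, 4).astype("float32")
    running_mean = np.random.rand(1, 3, 1, 1).astype("float32")
    running_var = np.random.rand(1, 3, 1, 1).astype("float32") + 0.5
    weight = np.full((1, 3, 1, 1), 2.0, dtype="float32")
    bias = np.full((1, 3, 1, 1), 0.25, dtype="float32")
    # fold the statistics into one affine transform, as the freeze fastpath does
    scale = weight * (running_var + eps) ** -0.5
    shift = bias - running_mean * scale
    y_fast = x * scale + shift
    # reference: normalize with running statistics, then apply the affine parameters
    y_ref = (x - running_mean) / np.sqrt(running_var + eps) * weight + bias
    assert np.allclose(y_fast, y_ref, atol=1e-6)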


class SyncBatchNorm(_BatchNorm):
    r"""Applies Synchronized Batch Normalization for distributed training.

    Args:
        num_features: usually :math:`C` from an input of shape
            :math:`(N, C, H, W)` or the highest ranked dimension of an input
            less than 4D.
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the ``running_mean`` and ``running_var`` computation.
            Default: 0.9
        affine: a boolean value that when set to True, this module has
            learnable affine parameters. Default: True
        track_running_stats: when set to True, this module tracks the
            running mean and variance. When set to False, this module does not
            track such statistics and always uses batch statistics in both training
            and eval modes. Default: True
        freeze: when set to True, this module does not update the
            running mean and variance, and uses the running mean and variance instead of
            the batch mean and batch variance to normalize the input. The parameter takes effect
            only when the module is initialized with track_running_stats as True.
            Default: False
        group: communication group over which the mean and variance are calculated.
            Default: :obj:`~.distributed.WORLD`
    """

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        freeze=False,
        group: Optional[Group] = WORLD,
        **kwargs
    ) -> None:
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, freeze, **kwargs
        )
        self.group = group

    def _check_input_ndim(self, inp):
        if len(inp.shape) not in {2, 3, 4}:
            raise ValueError(
                "expected 2D, 3D or 4D input (got {}D input)".format(len(inp.shape))
            )

    def forward(self, inp):
        self._check_input_ndim(inp)

        inp_shape = inp.shape
        _ndims = len(inp_shape)
        if _ndims != 4:
            new_shape = Tensor([1, 1, 1, 1], device=inp.device)
            origin_shape = inp_shape
            if _ndims == 2:
                new_shape[:2] = origin_shape[:2]
            elif _ndims == 3:
                new_shape[:3] = origin_shape[:3]
            else:
                raise ValueError(
                    "expected 2D, 3D or 4D input (got {}D input)".format(len(inp_shape))
                )
            inp = inp.reshape(new_shape)

        if self.training and self.track_running_stats:
            exponential_average_factor = self.momentum
        else:
            exponential_average_factor = 0.0  # unused: running statistics are not updated

        _weight = self.weight
        _bias = self.bias
        if self.freeze:
            if _weight is not None:
                _weight = _weight.detach()
            if _bias is not None:
                _bias = _bias.detach()

        output = sync_batch_norm(
            inp,
            self.running_mean,
            self.running_var,
            _weight,
            _bias,
            training=(self.training and not self.freeze)
            or ((self.running_mean is None) and (self.running_var is None)),
            momentum=exponential_average_factor,
            eps=self.eps,
            group=self.group,
        )

        if _ndims != 4:
            output = output.reshape(origin_shape)

        return output
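

# ``SyncBatchNorm`` delegates the cross-process reduction to ``sync_batch_norm``;
# the docstring above only states that the mean and variance are calculated over
# ``group``. The helper below is an illustrative sketch, not part of the original
# module and in plain numpy only, of the per-channel statistics such a reduction
# has to produce: combining per-worker partial sums gives the same result as
# computing the statistics over the concatenated batch.
def _sync_stats_demo():
    a = np.random.rand(8, 16).astype("float32")  # worker 0 batch, shape (N, C)
    b = np.random.rand(4, 16).astype("float32")  # worker 1 batch, shape (N, C)
    # partial sums that an all-reduce across the group would accumulate
    n = a.shape[0] + b.shape[0]
    s = a.sum(axis=0) + b.sum(axis=0)
    sq = (a ** 2).sum(axis=0) + (b ** 2).sum(axis=0)
    mean = s / n
    var = sq / n - mean ** 2
    # reference: statistics computed over the full (concatenated) batch
    full = np.concatenate([a, b], axis=0)
    assert np.allclose(mean, full.mean(axis=0), atol=1e-6)
    assert np.allclose(var, full.var(axis=0), atol=1e-5)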


class BatchNorm1d(_BatchNorm):
    r"""Applies Batch Normalization over a 2D/3D tensor.

    Refer to :class:`~.BatchNorm2d` for more information.
    """

    def _check_input_ndim(self, inp):
        if len(inp.shape) not in {2, 3}:
            raise ValueError(
                "expected 2D or 3D input (got {}D input)".format(len(inp.shape))
            )


class BatchNorm2d(_BatchNorm):
    r"""Applies Batch Normalization over a 4D tensor.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable
    parameter vectors.

    By default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.9.

    If :attr:`track_running_stats` is set to ``False``, this layer will not
    keep running estimates, and batch statistics are used during
    evaluation time instead.

    Because the Batch Normalization is done over the `C` dimension, computing
    statistics on `(N, H, W)` slices, it's common terminology to call this
    Spatial Batch Normalization.

    .. note::

        The update formula for ``running_mean`` and ``running_var`` (taking ``running_mean`` as an example) is

        .. math::

            \textrm{running\_mean} = \textrm{momentum} \times \textrm{running\_mean} + (1 - \textrm{momentum}) \times \textrm{batch\_mean}

        which could be defined differently in other frameworks. Most notably, a ``momentum`` of 0.1 in PyTorch
        is equivalent to a ``momentum`` of 0.9 here.

    Args:
        num_features: usually :math:`C` from an input of shape
            :math:`(N, C, H, W)` or the highest ranked dimension of an input
            less than 4D.
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the ``running_mean`` and ``running_var`` computation.
            Default: 0.9
        affine: a boolean value that when set to True, this module has
            learnable affine parameters. Default: True
        track_running_stats: when set to True, this module tracks the
            running mean and variance. When set to False, this module does not
            track such statistics and always uses batch statistics in both training
            and eval modes. Default: True
        freeze: when set to True, this module does not update the
            running mean and variance, and uses the running mean and variance instead of
            the batch mean and batch variance to normalize the input. The parameter takes effect
            only when the module is initialized with track_running_stats as True.
            Default: False

    Examples:
        >>> import numpy as np
        >>> # With Learnable Parameters
        >>> m = M.BatchNorm2d(4)
        >>> inp = mge.tensor(np.random.rand(1, 4, 3, 3).astype("float32"))
        >>> oup = m(inp)
        >>> print(m.weight.numpy().flatten(), m.bias.numpy().flatten())
        [1. 1. 1. 1.] [0. 0. 0. 0.]
        >>> # Without Learnable Parameters
        >>> m = M.BatchNorm2d(4, affine=False)
        >>> oup = m(inp)
        >>> print(m.weight, m.bias)
        None None
    """

    def _check_input_ndim(self, inp):
        if len(inp.shape) != 4:
            raise ValueError("expected 4D input (got {}D input)".format(len(inp.shape)))
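

A quick way to observe the running-statistics semantics described in the BatchNorm2d
note is to run one training step and compare running_mean against the batch mean.
The snippet below is a minimal sketch, assuming MegEngine is installed and imported
under the same names used in the docstring example above (mge for megengine, M for
megengine.module); the printed values depend on the random input.

import numpy as np
import megengine as mge
import megengine.module as M

m = M.BatchNorm2d(4, momentum=0.9)
x = mge.tensor(np.random.rand(8, 4, 16, 16).astype("float32"))

m.train()
_ = m(x)
# running_mean starts at zeros, so after one training step it should be close to
# 0.9 * 0 + 0.1 * batch_mean, following the update formula in the docstring
print(m.running_mean.numpy().flatten())
print(0.1 * x.numpy().mean(axis=(0, 2, 3)))

m.eval()
_ = m(x)  # eval mode uses the running statistics and leaves them unchanged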