
batchnorm.py 11 kB

# -*- coding: utf-8 -*-
from typing import Optional

import numpy as np

from ..distributed.group import WORLD, Group
from ..functional.nn import batch_norm, sync_batch_norm
from ..tensor import Parameter, Tensor
from . import init
from .module import Module


class _BatchNorm(Module):
    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        freeze=False,
        **kwargs
    ):
        super(_BatchNorm, self).__init__(**kwargs)
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        self._track_running_stats_saved = track_running_stats
        self.freeze = freeze
        if self.freeze:
            assert (
                self._track_running_stats_saved
            ), "track_running_stats must be initialized to True if freeze is True"
        tshape = (1, self.num_features, 1, 1)
        if self.affine:
            self.weight = Parameter(np.ones(tshape, dtype=np.float32))
            self.bias = Parameter(np.zeros(tshape, dtype=np.float32))
        else:
            self.weight = None
            self.bias = None
        if self.track_running_stats:
            self.running_mean = Tensor(np.zeros(tshape, dtype=np.float32))
            self.running_var = Tensor(np.ones(tshape, dtype=np.float32))
        else:
            self.running_mean = None
            self.running_var = None
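
    # Note: the affine parameters and the running statistics are all stored
    # with shape (1, num_features, 1, 1), so they broadcast directly against
    # an (N, C, H, W) input without any explicit reshaping in ``forward``.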

    def reset_running_stats(self) -> None:
        if self.track_running_stats:
            init.zeros_(self.running_mean)
            init.ones_(self.running_var)

    def reset_parameters(self) -> None:
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_ndim(self, inp):
        raise NotImplementedError

    def forward(self, inp):
        self._check_input_ndim(inp)
        if self._track_running_stats_saved == False:
            assert (
                self.track_running_stats == False
            ), "track_running_stats can not be initialized to False and changed to True later"

        _weight = self.weight
        _bias = self.bias

        if self.freeze:
            if _weight is not None:
                _weight = _weight.detach()
            if _bias is not None:
                _bias = _bias.detach()

            # fast path execution for freeze: normalize with the constant
            # running statistics only
            scale = (self.running_var + self.eps) ** (-0.5)
            if _weight is not None:
                scale *= _weight
            bias = -self.running_mean * scale
            if _bias is not None:
                bias += _bias
            return inp * scale + bias

        if self.training and self.track_running_stats:
            exponential_average_factor = self.momentum
        else:
            exponential_average_factor = 0.0  # unused in this case

        output = batch_norm(
            inp,
            self.running_mean if self.track_running_stats else None,
            self.running_var if self.track_running_stats else None,
            _weight,
            _bias,
            training=self.training
            or ((self.running_mean is None) and (self.running_var is None)),
            momentum=exponential_average_factor,
            eps=self.eps,
        )
        return output
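
    # The freeze fast path above folds the whole normalization into a single
    # affine transform built from the detached parameters and the frozen
    # running statistics:
    #
    #     scale = weight * (running_var + eps) ** (-0.5)
    #     bias  = bias - running_mean * scale
    #     out   = inp * scale + bias
    #
    # which is algebraically the same as
    # (inp - running_mean) / sqrt(running_var + eps) * weight + bias,
    # but skips the batch_norm kernel and never updates the running statistics.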

    def _module_info_string(self) -> str:
        s = (
            "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
            "track_running_stats={track_running_stats}"
        )
        return s.format(**self.__dict__)


class SyncBatchNorm(_BatchNorm):
    r"""Applies Synchronized Batch Normalization for distributed training.

    Args:
        num_features: usually :math:`C` from an input of shape
            :math:`(N, C, H, W)` or the highest ranked dimension of an input
            less than 4D.
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the ``running_mean`` and ``running_var`` computation.
            Default: 0.9
        affine: a boolean value that when set to True, this module has
            learnable affine parameters. Default: True
        track_running_stats: when set to True, this module tracks the
            running mean and variance. When set to False, this module does not
            track such statistics and always uses batch statistics in both training
            and eval modes. Default: True
        freeze: when set to True, this module does not update the
            running mean and variance, and uses the running mean and variance instead of
            the batch mean and batch variance to normalize the input. The parameter takes effect
            only when the module is initialized with track_running_stats as True.
            Default: False
        group: communication group; the mean and variance are calculated across this group.
            Default: :obj:`~.distributed.WORLD`
    """

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        freeze=False,
        group: Optional[Group] = WORLD,
        **kwargs
    ) -> None:
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, freeze, **kwargs
        )
        self.group = group

    def _check_input_ndim(self, inp):
        if len(inp.shape) not in {2, 3, 4}:
            raise ValueError(
                "expected 2D, 3D or 4D input (got {}D input)".format(len(inp.shape))
            )

    def forward(self, inp):
        self._check_input_ndim(inp)

        inp_shape = inp.shape
        _ndims = len(inp_shape)
        if _ndims != 4:
            new_shape = Tensor([1, 1, 1, 1], device=inp.device)
            origin_shape = inp_shape
            if _ndims == 2:
                new_shape[:2] = origin_shape[:2]
            elif _ndims == 3:
                new_shape[:3] = origin_shape[:3]
            else:
                raise ValueError(
                    "expected 2D, 3D or 4D input (got {}D input)".format(len(inp_shape))
                )
            inp = inp.reshape(new_shape)

        if self.training and self.track_running_stats:
            exponential_average_factor = self.momentum
        else:
            exponential_average_factor = 0.0  # unused in this case

        _weight = self.weight
        _bias = self.bias
        if self.freeze:
            if _weight is not None:
                _weight = _weight.detach()
            if _bias is not None:
                _bias = _bias.detach()

        output = sync_batch_norm(
            inp,
            self.running_mean,
            self.running_var,
            _weight,
            _bias,
            training=(self.training and not self.freeze)
            or ((self.running_mean is None) and (self.running_var is None)),
            momentum=exponential_average_factor,
            eps=self.eps,
            group=self.group,
        )

        if _ndims != 4:
            output = output.reshape(origin_shape)

        return output
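
# For 2D / 3D inputs, ``SyncBatchNorm.forward`` above pads the shape with
# trailing singleton dimensions so that ``sync_batch_norm`` always receives an
# (N, C, H, W) tensor, then restores the original shape on the way out; the
# per-channel statistics are therefore the same as for the equivalent 4D input.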


class BatchNorm1d(_BatchNorm):
    r"""Applies Batch Normalization over a 2D/3D tensor.

    Refer to :class:`~.BatchNorm2d` for more information.
    """

    def _check_input_ndim(self, inp):
        if len(inp.shape) not in {2, 3}:
            raise ValueError(
                "expected 2D or 3D input (got {}D input)".format(len(inp.shape))
            )
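
# A minimal usage sketch for the 1d variant (same ``mge`` / ``M`` aliases as
# the ``BatchNorm2d`` docstring examples below; the channel axis is axis 1 for
# both (N, C) and (N, C, L) inputs):
#
#     m = M.BatchNorm1d(4)
#     inp = mge.tensor(np.random.rand(2, 4).astype("float32"))
#     oup = m(inp)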


class BatchNorm2d(_BatchNorm):
    r"""Applies Batch Normalization over a 4D tensor.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable
    parameter vectors.

    By default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.9.

    If :attr:`track_running_stats` is set to ``False``, this layer does not
    keep running estimates, and batch statistics are used during evaluation
    instead.

    Because the Batch Normalization is done over the `C` dimension, computing
    statistics on `(N, H, W)` slices, it's common terminology to call this
    Spatial Batch Normalization.

    .. note::

        The update formula for ``running_mean`` and ``running_var`` (taking ``running_mean`` as an example) is

        .. math::

            \textrm{running\_mean} = \textrm{momentum} \times \textrm{running\_mean} + (1 - \textrm{momentum}) \times \textrm{batch\_mean}

        which could be defined differently in other frameworks. Most notably, ``momentum`` of 0.1 in PyTorch
        is equivalent to ``momentum`` of 0.9 here.

    Args:
        num_features: usually :math:`C` from an input of shape
            :math:`(N, C, H, W)` or the highest ranked dimension of an input
            less than 4D.
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the ``running_mean`` and ``running_var`` computation.
            Default: 0.9
        affine: a boolean value that when set to True, this module has
            learnable affine parameters. Default: True
        track_running_stats: when set to True, this module tracks the
            running mean and variance. When set to False, this module does not
            track such statistics and always uses batch statistics in both training
            and eval modes. Default: True
        freeze: when set to True, this module does not update the
            running mean and variance, and uses the running mean and variance instead of
            the batch mean and batch variance to normalize the input. The parameter takes effect
            only when the module is initialized with track_running_stats as True.
            Default: False

    Examples:
        >>> import numpy as np
        >>> # With Learnable Parameters
        >>> m = M.BatchNorm2d(4)
        >>> inp = mge.tensor(np.random.rand(1, 4, 3, 3).astype("float32"))
        >>> oup = m(inp)
        >>> print(m.weight.numpy().flatten(), m.bias.numpy().flatten())
        [1. 1. 1. 1.] [0. 0. 0. 0.]
        >>> # Without Learnable Parameters
        >>> m = M.BatchNorm2d(4, affine=False)
        >>> oup = m(inp)
        >>> print(m.weight, m.bias)
        None None
    """

    def _check_input_ndim(self, inp):
        if len(inp.shape) != 4:
            raise ValueError("expected 4D input (got {}D input)".format(len(inp.shape)))
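

# A minimal sketch of the ``freeze`` mode documented above (same aliases as the
# docstring examples): with ``freeze=True`` the running statistics are used for
# normalization and are never updated, even in training mode.
#
#     m = M.BatchNorm2d(4, freeze=True)
#     m.train()
#     oup = m(inp)   # takes the affine fast path in _BatchNorm.forward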