
batchnorm.py 8.4 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

from ..core import Buffer, Parameter
from ..core.device import get_default_device
from ..functional import batch_norm2d, sync_batch_norm
from . import init
from .module import Module


class _BatchNorm(Module):
    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
    ):
        super(_BatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(np.ones(num_features, dtype=np.float32))
            self.bias = Parameter(np.zeros(num_features, dtype=np.float32))
        else:
            self.weight = None
            self.bias = None

        # Running statistics are kept in NCHW layout, one value per channel.
        tshape = (1, self.num_features, 1, 1)
        if self.track_running_stats:
            self.running_mean = Buffer(np.zeros(tshape, dtype=np.float32))
            self.running_var = Buffer(np.ones(tshape, dtype=np.float32))
        else:
            self.running_mean = None
            self.running_var = None

    def reset_running_stats(self) -> None:
        if self.track_running_stats:
            init.zeros_(self.running_mean)
            init.ones_(self.running_var)

    def reset_parameters(self) -> None:
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_ndim(self, inp):
        # Overridden by subclasses to validate the supported input rank.
        raise NotImplementedError

    def forward(self, inp):
        self._check_input_ndim(inp)

        _ndims = len(inp.shape)
        if _ndims != 4:
            # batch_norm2d expects NCHW input; expand 2D/3D input with
            # singleton trailing dimensions and restore the shape afterwards.
            origin_shape = inp.shapeof()
            if _ndims == 2:
                n, c = inp.shapeof(0), inp.shapeof(1)
                new_shape = (n, c, 1, 1)
            elif _ndims == 3:
                n, c, h = inp.shapeof(0), inp.shapeof(1), inp.shapeof(2)
                new_shape = (n, c, h, 1)

            inp = inp.reshape(new_shape)

        if self.training and self.track_running_stats:
            exponential_average_factor = self.momentum
        else:
            exponential_average_factor = 0.0  # not used: running stats are not updated

        # FIXME: ROCm does not support a real bn opr yet, so we fall back to
        # sync_batch_norm (implemented via elemwise) here; this will be fixed
        # in the next version.
        if get_default_device() == "rocmx":
            output = sync_batch_norm(
                inp,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                self.training or not self.track_running_stats,
                exponential_average_factor,
                self.eps,
            )
        else:
            output = batch_norm2d(
                inp,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                self.training or not self.track_running_stats,
                exponential_average_factor,
                self.eps,
            )

        if _ndims != 4:
            output = output.reshape(origin_shape)

        return output


class SyncBatchNorm(_BatchNorm):
    r"""
    Applies Synchronized Batch Normalization.
    """

    def _check_input_ndim(self, inp):
        if len(inp.shape) not in {2, 3, 4}:
            raise ValueError(
                "expected 2D, 3D or 4D input (got {}D input)".format(len(inp.shape))
            )

    def forward(self, inp):
        self._check_input_ndim(inp)

        _ndims = len(inp.shape)
        if _ndims != 4:
            origin_shape = inp.shapeof()
            if _ndims == 2:
                n, c = inp.shapeof(0), inp.shapeof(1)
                new_shape = (n, c, 1, 1)
            elif _ndims == 3:
                n, c, h = inp.shapeof(0), inp.shapeof(1), inp.shapeof(2)
                new_shape = (n, c, h, 1)

            inp = inp.reshape(new_shape)

        if self.training and self.track_running_stats:
            exponential_average_factor = self.momentum
        else:
            exponential_average_factor = 0.0  # not used: running stats are not updated

        output = sync_batch_norm(
            inp,
            self.running_mean,
            self.running_var,
            self.weight,
            self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor,
            self.eps,
        )

        if _ndims != 4:
            output = output.reshape(origin_shape)

        return output


class BatchNorm1d(_BatchNorm):
    r"""
    Applies Batch Normalization over a 2D/3D tensor.

    Refer to :class:`~.BatchNorm2d` for more information.
    """

    def _check_input_ndim(self, inp):
        if len(inp.shape) not in {2, 3}:
            raise ValueError(
                "expected 2D or 3D input (got {}D input)".format(len(inp.shape))
            )


class BatchNorm2d(_BatchNorm):
    r"""
    Applies Batch Normalization over a 4D tensor.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable
    parameter vectors.

    By default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.9.

    If :attr:`track_running_stats` is set to ``False``, this layer will not
    keep running estimates, and batch statistics are instead used during
    evaluation time.

    .. note::
        This :attr:`momentum` argument is different from the one used in
        optimizer classes and the conventional notion of momentum.
        Mathematically, the update rule for running statistics here is
        :math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1 - \text{momentum}) \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing
    statistics on `(N, H, W)` slices, it's common terminology to call this
    Spatial Batch Normalization.

    :type num_features: int
    :param num_features: usually :math:`C` from an input of size
        :math:`(N, C, H, W)`, or the highest-ranked dimension of an input with
        fewer than 4 dimensions.
    :type eps: float
    :param eps: a value added to the denominator for numerical stability.
        Default: 1e-5
    :type momentum: float
    :param momentum: the value used for the `running_mean` and `running_var`
        computation. Default: 0.9
    :type affine: bool
    :param affine: a boolean value that when set to ``True``, this module has
        learnable affine parameters. Default: ``True``
    :type track_running_stats: bool
    :param track_running_stats: when set to ``True``, this module tracks the
        running mean and variance. When set to ``False``, this module does not
        track such statistics and always uses batch statistics in both training
        and eval modes. Default: ``True``

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        # With learnable parameters
        m = M.BatchNorm2d(4)
        inp = mge.tensor(np.random.rand(1, 4, 3, 3).astype("float32"))
        oup = m(inp)
        print(m.weight, m.bias)

        # Without learnable parameters
        m = M.BatchNorm2d(4, affine=False)
        oup = m(inp)
        print(m.weight, m.bias)

    .. testoutput::

        Tensor([1. 1. 1. 1.]) Tensor([0. 0. 0. 0.])
        None None
    """

    def _check_input_ndim(self, inp):
        if len(inp.shape) != 4:
            raise ValueError("expected 4D input (got {}D input)".format(len(inp.shape)))
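
The note in the BatchNorm2d docstring gives the update rule for running statistics. As a quick numeric check, here is a minimal NumPy sketch of that rule, independent of the MegEngine API; the batch values are illustrative only:

import numpy as np

momentum = 0.9
running_mean = np.zeros(4, dtype=np.float32)                    # estimated statistic, initially 0
batch_mean = np.array([0.2, 0.4, 0.6, 0.8], dtype=np.float32)   # newly observed batch statistic

# x_hat_new = momentum * x_hat + (1 - momentum) * x_t
running_mean = momentum * running_mean + (1 - momentum) * batch_mean
print(running_mean)  # approximately [0.02 0.04 0.06 0.08]

With the default momentum of 0.9, each update moves the estimate only 10% of the way toward the new batch statistic, so the running estimates change slowly across iterations.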

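For reference, here is a minimal NumPy sketch (not the MegEngine implementation) of the per-channel statistics that training-mode batch normalization computes, without the affine transform. It also shows why forward() expands 2D/3D inputs to NCHW first: the statistics are always taken over the (N, H, W) axes.

import numpy as np

def ref_batch_norm2d(x, eps=1e-5):
    # Per-channel mean and (biased) variance over the (N, H, W) axes.
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

x = np.random.rand(2, 4, 3, 3).astype("float32")
y = ref_batch_norm2d(x)
# Each output channel is normalized to roughly zero mean and unit variance.
print(y.mean(axis=(0, 2, 3)))  # approximately [0. 0. 0. 0.]
print(y.var(axis=(0, 2, 3)))   # approximately [1. 1. 1. 1.]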