You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

conv.py 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. from abc import abstractmethod
  9. from typing import Tuple, Union
  10. import numpy as np
  11. from ..core.ops._internal import param_defs as P
  12. from ..functional import conv2d, conv_transpose2d, local_conv2d, relu
  13. from ..functional.types import _pair, _pair_nonzero
  14. from ..tensor import Parameter
  15. from . import init
  16. from .module import Module
  17. class _ConvNd(Module):
  18. """base class for convolution modules, including transposed conv"""
  19. def __init__(
  20. self,
  21. in_channels: int,
  22. out_channels: int,
  23. kernel_size: Union[int, Tuple[int, int]],
  24. stride: Union[int, Tuple[int, int]],
  25. padding: Union[int, Tuple[int, int]],
  26. dilation: Union[int, Tuple[int, int]],
  27. groups: int,
  28. bias: bool = True,
  29. ):
  30. super().__init__()
  31. if in_channels % groups != 0:
  32. raise ValueError("in_channels must be divisible by groups")
  33. if out_channels % groups != 0:
  34. raise ValueError("out_channels must be divisible by groups")
  35. self.in_channels = in_channels
  36. self.out_channels = out_channels
  37. self.kernel_size = kernel_size
  38. self.stride = stride
  39. self.padding = padding
  40. self.dilation = dilation
  41. self.groups = groups
  42. self.weight = Parameter(np.zeros(self._infer_weight_shape(), dtype=np.float32))
  43. self.bias = None
  44. if bias:
  45. self.bias = Parameter(np.zeros(self._infer_bias_shape(), dtype=np.float32))
  46. self.reset_parameters()
  47. @abstractmethod
  48. def _get_fanin(self):
  49. pass
  50. def reset_parameters(self) -> None:
  51. fanin = self._get_fanin()
  52. std = np.sqrt(1 / fanin)
  53. init.normal_(self.weight, 0.0, std)
  54. if self.bias is not None:
  55. init.zeros_(self.bias)
  56. @abstractmethod
  57. def _infer_weight_shape(self):
  58. pass
  59. @abstractmethod
  60. def _infer_bias_shape(self):
  61. pass
  62. def _module_info_string(self):
  63. s = "{in_channels}, {out_channels}, kernel_size={kernel_size}"
  64. if self.stride != (1,) * len(self.stride):
  65. s += ", stride={stride}"
  66. if self.padding != (0,) * len(self.padding):
  67. s += ", padding={padding}"
  68. if self.dilation != (1,) * len(self.dilation):
  69. s += ", dilation={dilation}"
  70. if self.groups != 1:
  71. s += ", groups={groups}"
  72. if self.bias is None:
  73. s += ", bias=False"
  74. return s.format(**self.__dict__)
  75. class Conv2d(_ConvNd):
  76. r"""Applies a 2D convolution over an input tensor.
  77. For instance, given an input of the size :math:`(N, C_{\text{in}}, H, W)`,
  78. this layer generates an output of the size
  79. :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` through the
  80. process described as below:
  81. .. math::
  82. \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
  83. \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)
  84. where :math:`\star` is the valid 2D cross-correlation operator,
  85. :math:`N` is batch size, :math:`C` denotes number of channels,
  86. :math:`H` is height of input planes in pixels, and :math:`W` is
  87. width in pixels.
  88. When `groups == in_channels` and `out_channels == K * in_channels`,
  89. where K is a positive integer, this operation is also known as depthwise
  90. convolution.
  91. In other words, for an input of size :math:`(N, C_{in}, H_{in}, W_{in})`,
  92. a depthwise convolution with a depthwise multiplier `K`, can be constructed
  93. by arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.
  94. :param in_channels: number of input channels.
  95. :param out_channels: number of output channels.
  96. :param kernel_size: size of weight on spatial dimensions. If kernel_size is
  97. an :class:`int`, the actual kernel size would be
  98. `(kernel_size, kernel_size)`. Default: 1
  99. :param stride: stride of the 2D convolution operation. Default: 1
  100. :param padding: size of the paddings added to the input on both sides of its
  101. spatial dimensions. Only zero-padding is supported. Default: 0
  102. :param dilation: dilation of the 2D convolution operation. Default: 1
  103. :param groups: number of groups into which the input and output channels are divided, so as to perform a "grouped convolution". When ``groups`` is not 1,
  104. ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
  105. and there would be an extra dimension at the beginning of the weight's
  106. shape. Specifically, the shape of weight would be `(groups,
  107. out_channel // groups, in_channels // groups, *kernel_size)`.
  108. :param bias: whether to add a bias onto the result of convolution. Default:
  109. True
  110. :param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default:
  111. `CROSS_CORRELATION`
  112. :param compute_mode: When set to "DEFAULT", no special requirements will be
  113. placed on the precision of intermediate results. When set to "FLOAT32",
  114. "Float32" would be used for accumulator and intermediate result, but only
  115. effective when input and output are of float16 dtype.
  116. Examples:
  117. .. testcode::
  118. import numpy as np
  119. import megengine as mge
  120. import megengine.module as M
  121. m = M.Conv2d(in_channels=3, out_channels=1, kernel_size=3)
  122. inp = mge.tensor(np.arange(0, 96).astype("float32").reshape(2, 3, 4, 4))
  123. oup = m(inp)
  124. print(oup.shape)
  125. Outputs:
  126. .. testoutput::
  127. (2, 1, 2, 2)
  128. """
  129. _conv_mode_type = P.Convolution.Mode
  130. _compute_mode_type = P.Convolution.ComputeMode
  131. def __init__(
  132. self,
  133. in_channels: int,
  134. out_channels: int,
  135. kernel_size: Union[int, Tuple[int, int]],
  136. stride: Union[int, Tuple[int, int]] = 1,
  137. padding: Union[int, Tuple[int, int]] = 0,
  138. dilation: Union[int, Tuple[int, int]] = 1,
  139. groups: int = 1,
  140. bias: bool = True,
  141. conv_mode: str = "CROSS_CORRELATION",
  142. compute_mode: str = "DEFAULT",
  143. ):
  144. kernel_size = _pair_nonzero(kernel_size)
  145. stride = _pair_nonzero(stride)
  146. padding = _pair(padding)
  147. dilation = _pair_nonzero(dilation)
  148. self.conv_mode = self._conv_mode_type.convert(conv_mode)
  149. self.compute_mode = self._compute_mode_type.convert(compute_mode)
  150. super().__init__(
  151. in_channels,
  152. out_channels,
  153. kernel_size,
  154. stride,
  155. padding,
  156. dilation,
  157. groups,
  158. bias,
  159. )
  160. def _get_fanin(self):
  161. kh, kw = self.kernel_size
  162. ic = self.in_channels
  163. return kh * kw * ic
  164. def _infer_weight_shape(self):
  165. group = self.groups
  166. ichl = self.in_channels
  167. ochl = self.out_channels
  168. kh, kw = self.kernel_size
  169. if group == 1:
  170. # Assume format is NCHW
  171. return (ochl, ichl, kh, kw)
  172. assert (
  173. ichl % group == 0 and ochl % group == 0
  174. ), "invalid config: input_channels={} output_channels={} group={}".format(
  175. ichl, ochl, group
  176. )
  177. # Assume format is NCHW
  178. return (group, ochl // group, ichl // group, kh, kw)
  179. def _infer_bias_shape(self):
  180. # Assume format is NCHW
  181. return (1, self.out_channels, 1, 1)
  182. def calc_conv(self, inp, weight, bias):
  183. return conv2d(
  184. inp,
  185. weight,
  186. bias,
  187. self.stride,
  188. self.padding,
  189. self.dilation,
  190. self.groups,
  191. self.conv_mode,
  192. self.compute_mode,
  193. )
  194. def forward(self, inp):
  195. return self.calc_conv(inp, self.weight, self.bias)
  196. class ConvTranspose2d(_ConvNd):
  197. r"""Applies a 2D transposed convolution over an input tensor.
  198. This module is also known as a deconvolution or a fractionally-strided convolution.
  199. :class:`ConvTranspose2d` can be seen as the gradient of :class:`Conv2d` operation
  200. with respect to its input.
  201. Convolution usually reduces the size of input, while transposed convolution works
  202. the opposite way, transforming a smaller input to a larger output while preserving the
  203. connectivity pattern.
  204. :param in_channels: number of input channels.
  205. :param out_channels: number of output channels.
  206. :param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is
  207. an :class:`int`, the actual kernel size would be
  208. ``(kernel_size, kernel_size)``. Default: 1
  209. :param stride: stride of the 2D convolution operation. Default: 1
  210. :param padding: size of the paddings added to the input on both sides of its
  211. spatial dimensions. Only zero-padding is supported. Default: 0
  212. :param dilation: dilation of the 2D convolution operation. Default: 1
  213. :param groups: number of groups into which the input and output channels are divided, so as to perform a "grouped convolution". When ``groups`` is not 1,
  214. ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
  215. and there would be an extra dimension at the beginning of the weight's
  216. shape. Specifically, the shape of weight would be ``(groups,
  217. out_channels // groups, in_channels // groups, *kernel_size)``. Default: 1
  218. :param bias: wether to add a bias onto the result of convolution. Default:
  219. True
  220. :param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default:
  221. `CROSS_CORRELATION`
  222. :param compute_mode: When set to "DEFAULT", no special requirements will be
  223. placed on the precision of intermediate results. When set to "FLOAT32",
  224. "Float32" would be used for accumulator and intermediate result, but only
  225. effective when input and output are of float16 dtype.
  226. """
  227. _conv_mode_type = P.Convolution.Mode
  228. _compute_mode_type = P.Convolution.ComputeMode
  229. def __init__(
  230. self,
  231. in_channels: int,
  232. out_channels: int,
  233. kernel_size: Union[int, Tuple[int, int]],
  234. stride: Union[int, Tuple[int, int]] = 1,
  235. padding: Union[int, Tuple[int, int]] = 0,
  236. dilation: Union[int, Tuple[int, int]] = 1,
  237. groups: int = 1,
  238. bias: bool = True,
  239. conv_mode: str = "CROSS_CORRELATION",
  240. compute_mode: str = "DEFAULT",
  241. ):
  242. kernel_size = _pair_nonzero(kernel_size)
  243. stride = _pair_nonzero(stride)
  244. padding = _pair(padding)
  245. dilation = _pair_nonzero(dilation)
  246. self.conv_mode = self._conv_mode_type.convert(conv_mode)
  247. self.compute_mode = self._compute_mode_type.convert(compute_mode)
  248. super().__init__(
  249. in_channels,
  250. out_channels,
  251. kernel_size,
  252. stride,
  253. padding,
  254. dilation,
  255. groups,
  256. bias,
  257. )
  258. def _get_fanin(self):
  259. kh, kw = self.kernel_size
  260. oc = self.out_channels
  261. return kh * kw * oc
  262. def _infer_weight_shape(self):
  263. group = self.groups
  264. ichl = self.in_channels
  265. ochl = self.out_channels
  266. kh, kw = self.kernel_size
  267. if group == 1:
  268. # Assume format is NCHW
  269. return (ichl, ochl, kh, kw)
  270. assert (
  271. ichl % group == 0 and ochl % group == 0
  272. ), "invalid config: input_channels={} output_channels={} group={}".format(
  273. ichl, ochl, group
  274. )
  275. # Assume format is NCHW
  276. return (group, ichl // group, ochl // group, kh, kw)
  277. def _infer_bias_shape(self):
  278. # Assume format is NCHW
  279. return (1, self.out_channels, 1, 1)
  280. def forward(self, inp):
  281. return conv_transpose2d(
  282. inp,
  283. self.weight,
  284. self.bias,
  285. self.stride,
  286. self.padding,
  287. self.dilation,
  288. self.groups,
  289. self.conv_mode,
  290. self.compute_mode,
  291. )
  292. class LocalConv2d(Conv2d):
  293. r"""Applies a spatial convolution with untied kernels over an groupped channeled input 4D tensor.
  294. It is also known as the locally connected layer.
  295. :param in_channels: number of input channels.
  296. :param out_channels: number of output channels.
  297. :param input_height: the height of the input images.
  298. :param input_width: the width of the input images.
  299. :param kernel_size: size of weight on spatial dimensions. If kernel_size is
  300. an :class:`int`, the actual kernel size would be
  301. `(kernel_size, kernel_size)`. Default: 1
  302. :param stride: stride of the 2D convolution operation. Default: 1
  303. :param padding: size of the paddings added to the input on both sides of its
  304. spatial dimensions. Only zero-padding is supported. Default: 0
  305. :param groups: number of groups into which the input and output channels are divided,
  306. so as to perform a "grouped convolution". When ``groups`` is not 1,
  307. ``in_channels`` and ``out_channels`` must be divisible by ``groups``.
  308. The shape of weight is `(groups, output_height, output_width,
  309. in_channels // groups, *kernel_size, out_channels // groups)`.
  310. """
  311. _conv_mode_type = P.Convolution.Mode
  312. def __init__(
  313. self,
  314. in_channels: int,
  315. out_channels: int,
  316. input_height: int,
  317. input_width: int,
  318. kernel_size: Union[int, Tuple[int, int]],
  319. stride: Union[int, Tuple[int, int]] = 1,
  320. padding: Union[int, Tuple[int, int]] = 0,
  321. dilation: Union[int, Tuple[int, int]] = 1,
  322. groups: int = 1,
  323. conv_mode: str = "CROSS_CORRELATION",
  324. ):
  325. self.input_height = input_height
  326. self.input_width = input_width
  327. super().__init__(
  328. in_channels,
  329. out_channels,
  330. kernel_size,
  331. stride,
  332. padding,
  333. dilation,
  334. groups,
  335. bias=False,
  336. )
  337. def _infer_weight_shape(self):
  338. group = self.groups
  339. output_height = (
  340. self.input_height + self.padding[0] * 2 - self.kernel_size[0]
  341. ) // self.stride[0] + 1
  342. output_width = (
  343. self.input_width + self.padding[1] * 2 - self.kernel_size[1]
  344. ) // self.stride[1] + 1
  345. # Assume format is NCHW
  346. return (
  347. group,
  348. output_height,
  349. output_width,
  350. self.in_channels // group,
  351. self.kernel_size[0],
  352. self.kernel_size[1],
  353. self.out_channels // group,
  354. )
  355. def forward(self, inp):
  356. return local_conv2d(
  357. inp,
  358. self.weight,
  359. None,
  360. self.stride,
  361. self.padding,
  362. self.dilation,
  363. self.conv_mode,
  364. )
  365. class ConvRelu2d(Conv2d):
  366. r"""
  367. A fused :class:`~.Module` including Conv2d and relu. Could be replaced
  368. with :class:`~.QATModule` version :class:`~.qat.conv.ConvRelu2d` using
  369. :func:`~.quantize.quantize_qat`.
  370. """
  371. def forward(self, inp):
  372. return relu(self.calc_conv(inp, self.weight, self.bias))

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台