
conv.py

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import abstractmethod
from typing import Tuple, Union

import numpy as np

from ..functional import conv1d, conv2d, conv_transpose2d, local_conv2d, relu
from ..functional.types import _pair, _pair_nonzero
from ..tensor import Parameter
from . import init
from .module import Module


class _ConvNd(Module):
    """base class for convolution modules, including transposed conv"""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]],
        padding: Union[int, Tuple[int, int]],
        dilation: Union[int, Tuple[int, int]],
        groups: int,
        bias: bool = True,
    ):
        super().__init__()
        if in_channels % groups != 0:
            raise ValueError("in_channels must be divisible by groups")
        if out_channels % groups != 0:
            raise ValueError("out_channels must be divisible by groups")
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

        self.weight = Parameter(np.zeros(self._infer_weight_shape(), dtype=np.float32))
        self.bias = None
        if bias:
            self.bias = Parameter(np.zeros(self._infer_bias_shape(), dtype=np.float32))
        self.reset_parameters()

    @abstractmethod
    def _get_fanin(self):
        pass

    def reset_parameters(self) -> None:
        fanin = self._get_fanin()
        std = np.sqrt(1 / fanin)
        init.normal_(self.weight, 0.0, std)
        if self.bias is not None:
            init.zeros_(self.bias)

    @abstractmethod
    def _infer_weight_shape(self):
        pass

    @abstractmethod
    def _infer_bias_shape(self):
        pass

    def _module_info_string(self):
        s = "{in_channels}, {out_channels}, kernel_size={kernel_size}"
        if self.stride != (1,) * len(self.stride):
            s += ", stride={stride}"
        if self.padding != (0,) * len(self.padding):
            s += ", padding={padding}"
        if self.dilation != (1,) * len(self.dilation):
            s += ", dilation={dilation}"
        if self.groups != 1:
            s += ", groups={groups}"
        if self.bias is None:
            s += ", bias=False"
        return s.format(**self.__dict__)
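

# Orientation note: _ConvNd acts as a template. Each concrete subclass below only
# supplies _get_fanin(), _infer_weight_shape() and _infer_bias_shape();
# _ConvNd.__init__ then allocates `weight`/`bias` Parameters of those shapes, and
# reset_parameters() draws the weight from a normal distribution with
# std = sqrt(1 / fan_in) and zero-fills the bias.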


class Conv1d(_ConvNd):
    r"""
    Applies a 1D convolution over an input tensor.

    For instance, given an input of the size :math:`(N, C_{\text{in}}, H)`,
    this layer generates an output of the size
    :math:`(N, C_{\text{out}}, H_{\text{out}})` through the
    process described as below:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)

    where :math:`\star` is the valid 1D cross-correlation operator,
    :math:`N` is batch size, :math:`C` denotes number of channels, and
    :math:`H` is the length of the 1D data element.

    When `groups == in_channels` and `out_channels == K * in_channels`,
    where K is a positive integer, this operation is also known as depthwise
    convolution.

    In other words, for an input of size :math:`(N, C_{in}, H_{in})`,
    a depthwise convolution with a depthwise multiplier `K` can be constructed
    by arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param kernel_size: size of weight on the spatial dimension. Default: 1
    :param stride: stride of the 1D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimension. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 1D convolution operation. Default: 1
    :param groups: number of groups into which the input and output channels are divided,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and there would be an extra dimension at the beginning of the weight's
        shape. Specifically, the shape of weight would be `(groups,
        out_channels // groups, in_channels // groups, kernel_size)`.
    :param bias: whether to add a bias onto the result of convolution. Default:
        True
    :param conv_mode: Supports `CROSS_CORRELATION`. Default:
        `CROSS_CORRELATION`
    :param compute_mode: When set to "DEFAULT", no special requirements will be
        placed on the precision of intermediate results. When set to "FLOAT32",
        "Float32" would be used for accumulator and intermediate result, but only
        effective when input and output are of float16 dtype.

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.Conv1d(in_channels=3, out_channels=1, kernel_size=3)
        inp = mge.tensor(np.arange(0, 24).astype("float32").reshape(2, 3, 4))
        oup = m(inp)
        print(oup.numpy().shape)

    Outputs:

    .. testoutput::

        (2, 1, 2)

    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "CROSS_CORRELATION",
        compute_mode: str = "DEFAULT",
    ):
        self.conv_mode = conv_mode
        self.compute_mode = compute_mode
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )

    def _get_fanin(self):
        kh = self.kernel_size
        ic = self.in_channels
        return kh * ic

    def _infer_weight_shape(self):
        group = self.groups
        ichl = self.in_channels
        ochl = self.out_channels
        kh = self.kernel_size
        if group == 1:
            # Assume format is NCH(W=1)
            return (ochl, ichl, kh)

        assert (
            ichl % group == 0 and ochl % group == 0
        ), "invalid config: input_channels={} output_channels={} group={}".format(
            ichl, ochl, group
        )
        # Assume format is NCH(W=1)
        return (group, ochl // group, ichl // group, kh)

    def _infer_bias_shape(self):
        # Assume format is NCH(W=1)
        return (1, self.out_channels, 1)

    def calc_conv(self, inp, weight, bias):
        return conv1d(
            inp,
            weight,
            bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.conv_mode,
            self.compute_mode,
        )

    def forward(self, inp):
        return self.calc_conv(inp, self.weight, self.bias)
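

def _conv1d_depthwise_sketch():
    """A minimal depthwise-Conv1d sketch, assuming the underlying conv1d op supports
    grouped 1D convolution (as the ``groups`` argument implies): groups == in_channels
    and out_channels == K * in_channels with K = 1, as described in the docstring."""
    import megengine as mge

    m = Conv1d(in_channels=3, out_channels=3, kernel_size=3, groups=3)
    # Grouped weight layout from _infer_weight_shape:
    # (groups, ochl // groups, ichl // groups, kernel_size) == (3, 1, 1, 3).
    assert m.weight.numpy().shape == (3, 1, 1, 3)
    inp = mge.tensor(np.ones((2, 3, 8), dtype="float32"))
    oup = m(inp)
    # Valid cross-correlation: output length 8 - 3 + 1 = 6.
    assert oup.numpy().shape == (2, 3, 6)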


class Conv2d(_ConvNd):
    r"""
    Applies a 2D convolution over an input tensor.

    For instance, given an input of the size :math:`(N, C_{\text{in}}, H, W)`,
    this layer generates an output of the size
    :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` through the
    process described as below:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)

    where :math:`\star` is the valid 2D cross-correlation operator,
    :math:`N` is batch size, :math:`C` denotes number of channels,
    :math:`H` is height of input planes in pixels, and :math:`W` is
    width in pixels.

    In general, output feature maps' shapes can be inferred as follows:

    input: :math:`(N, C_{\text{in}}, H_{\text{in}}, W_{\text{in}})`

    output: :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` where

    .. math::
        \text{H}_{out} = \lfloor \frac{\text{H}_{in} + 2 * \text{padding[0]} -
        \text{dilation[0]} * (\text{kernel_size[0]} - 1) - 1}{\text{stride[0]}} + 1 \rfloor

    .. math::
        \text{W}_{out} = \lfloor \frac{\text{W}_{in} + 2 * \text{padding[1]} -
        \text{dilation[1]} * (\text{kernel_size[1]} - 1) - 1}{\text{stride[1]}} + 1 \rfloor

    When `groups == in_channels` and `out_channels == K * in_channels`,
    where K is a positive integer, this operation is also known as depthwise
    convolution.

    In other words, for an input of size :math:`(N, C_{in}, H_{in}, W_{in})`,
    a depthwise convolution with a depthwise multiplier `K` can be constructed
    by arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param kernel_size: size of weight on spatial dimensions. If kernel_size is
        an :class:`int`, the actual kernel size would be
        `(kernel_size, kernel_size)`. Default: 1
    :param stride: stride of the 2D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 2D convolution operation. Default: 1
    :param groups: number of groups into which the input and output channels are divided,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and there would be an extra dimension at the beginning of the weight's
        shape. Specifically, the shape of weight would be `(groups,
        out_channels // groups, in_channels // groups, *kernel_size)`.
    :param bias: whether to add a bias onto the result of convolution. Default:
        True
    :param conv_mode: Supports `CROSS_CORRELATION`. Default:
        `CROSS_CORRELATION`
    :param compute_mode: When set to "DEFAULT", no special requirements will be
        placed on the precision of intermediate results. When set to "FLOAT32",
        "Float32" would be used for accumulator and intermediate result, but only
        effective when input and output are of float16 dtype.

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.Conv2d(in_channels=3, out_channels=1, kernel_size=3)
        inp = mge.tensor(np.arange(0, 96).astype("float32").reshape(2, 3, 4, 4))
        oup = m(inp)
        print(oup.numpy().shape)

    Outputs:

    .. testoutput::

        (2, 1, 2, 2)

    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "CROSS_CORRELATION",
        compute_mode: str = "DEFAULT",
    ):
        kernel_size = _pair_nonzero(kernel_size)
        stride = _pair_nonzero(stride)
        padding = _pair(padding)
        dilation = _pair_nonzero(dilation)
        self.conv_mode = conv_mode
        self.compute_mode = compute_mode
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )

    def _get_fanin(self):
        kh, kw = self.kernel_size
        ic = self.in_channels
        return kh * kw * ic

    def _infer_weight_shape(self):
        group = self.groups
        ichl = self.in_channels
        ochl = self.out_channels
        kh, kw = self.kernel_size
        if group == 1:
            # Assume format is NCHW
            return (ochl, ichl, kh, kw)

        assert (
            ichl % group == 0 and ochl % group == 0
        ), "invalid config: input_channels={} output_channels={} group={}".format(
            ichl, ochl, group
        )
        # Assume format is NCHW
        return (group, ochl // group, ichl // group, kh, kw)

    def _infer_bias_shape(self):
        # Assume format is NCHW
        return (1, self.out_channels, 1, 1)

    def calc_conv(self, inp, weight, bias):
        return conv2d(
            inp,
            weight,
            bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.conv_mode,
            self.compute_mode,
        )

    def forward(self, inp):
        return self.calc_conv(inp, self.weight, self.bias)
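

def _conv2d_output_shape_sketch():
    """A minimal sketch, assuming standard eager execution, that checks the
    output-shape formula from the Conv2d docstring,
    H_out = floor((H_in + 2 * padding - dilation * (kh - 1) - 1) / stride + 1),
    on a padded convolution."""
    import megengine as mge

    m = Conv2d(in_channels=3, out_channels=1, kernel_size=3, padding=1)
    inp = mge.tensor(np.ones((2, 3, 4, 4), dtype="float32"))
    oup = m(inp)
    # floor((4 + 2*1 - 1*(3 - 1) - 1) / 1 + 1) = 4, so the spatial size is preserved.
    assert oup.numpy().shape == (2, 1, 4, 4)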


class ConvTranspose2d(_ConvNd):
    r"""
    Applies a 2D transposed convolution over an input tensor.

    This module is also known as a deconvolution or a fractionally-strided convolution.
    :class:`ConvTranspose2d` can be seen as the gradient of a :class:`Conv2d` operation
    with respect to its input.

    Convolution usually reduces the size of its input, while transposed convolution works
    the opposite way, transforming a smaller input into a larger output while preserving the
    connectivity pattern.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is
        an :class:`int`, the actual kernel size would be
        ``(kernel_size, kernel_size)``. Default: 1
    :param stride: stride of the 2D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 2D convolution operation. Default: 1
    :param groups: number of groups into which the input and output channels are divided,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and there would be an extra dimension at the beginning of the weight's
        shape. Specifically, the shape of weight would be ``(groups,
        in_channels // groups, out_channels // groups, *kernel_size)``. Default: 1
    :param bias: whether to add a bias onto the result of convolution. Default:
        True
    :param conv_mode: Supports `CROSS_CORRELATION`. Default:
        `CROSS_CORRELATION`
    :param compute_mode: When set to "DEFAULT", no special requirements will be
        placed on the precision of intermediate results. When set to "FLOAT32",
        "Float32" would be used for accumulator and intermediate result, but only
        effective when input and output are of float16 dtype.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "CROSS_CORRELATION",
        compute_mode: str = "DEFAULT",
    ):
        kernel_size = _pair_nonzero(kernel_size)
        stride = _pair_nonzero(stride)
        padding = _pair(padding)
        dilation = _pair_nonzero(dilation)
        self.conv_mode = conv_mode
        self.compute_mode = compute_mode
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )

    def _get_fanin(self):
        kh, kw = self.kernel_size
        oc = self.out_channels
        return kh * kw * oc

    def _infer_weight_shape(self):
        group = self.groups
        ichl = self.in_channels
        ochl = self.out_channels
        kh, kw = self.kernel_size
        if group == 1:
            # Assume format is NCHW
            return (ichl, ochl, kh, kw)

        assert (
            ichl % group == 0 and ochl % group == 0
        ), "invalid config: input_channels={} output_channels={} group={}".format(
            ichl, ochl, group
        )
        # Assume format is NCHW
        return (group, ichl // group, ochl // group, kh, kw)

    def _infer_bias_shape(self):
        # Assume format is NCHW
        return (1, self.out_channels, 1, 1)

    def forward(self, inp):
        return conv_transpose2d(
            inp,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.conv_mode,
            self.compute_mode,
        )
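

def _conv_transpose2d_shape_sketch():
    """A minimal sketch of ConvTranspose2d upsampling, assuming the usual
    transposed-convolution shape rule
    H_out = (H_in - 1) * stride - 2 * padding + dilation * (kh - 1) + 1."""
    import megengine as mge

    m = ConvTranspose2d(in_channels=3, out_channels=1, kernel_size=3)
    inp = mge.tensor(np.ones((2, 3, 4, 4), dtype="float32"))
    oup = m(inp)
    # (4 - 1) * 1 - 2 * 0 + 1 * (3 - 1) + 1 = 6
    assert oup.numpy().shape == (2, 1, 6, 6)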


class LocalConv2d(Conv2d):
    r"""
    Applies a spatial convolution with untied kernels over a grouped, channeled input 4D tensor.
    It is also known as the locally connected layer.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param input_height: the height of the input images.
    :param input_width: the width of the input images.
    :param kernel_size: size of weight on spatial dimensions. If kernel_size is
        an :class:`int`, the actual kernel size would be
        `(kernel_size, kernel_size)`. Default: 1
    :param stride: stride of the 2D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param groups: number of groups into which the input and output channels are divided,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``.
        The shape of weight is `(groups, output_height, output_width,
        in_channels // groups, *kernel_size, out_channels // groups)`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        input_height: int,
        input_width: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        conv_mode: str = "CROSS_CORRELATION",
    ):
        self.input_height = input_height
        self.input_width = input_width
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias=False,
            # pass the requested convolution mode through to Conv2d
            conv_mode=conv_mode,
        )

    def _infer_weight_shape(self):
        group = self.groups
        output_height = (
            self.input_height + self.padding[0] * 2 - self.kernel_size[0]
        ) // self.stride[0] + 1
        output_width = (
            self.input_width + self.padding[1] * 2 - self.kernel_size[1]
        ) // self.stride[1] + 1
        # Assume format is NCHW
        return (
            group,
            output_height,
            output_width,
            self.in_channels // group,
            self.kernel_size[0],
            self.kernel_size[1],
            self.out_channels // group,
        )

    def forward(self, inp):
        return local_conv2d(
            inp,
            self.weight,
            None,
            self.stride,
            self.padding,
            self.dilation,
            self.conv_mode,
        )
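

def _local_conv2d_weight_shape_sketch():
    """A minimal sketch of the locally connected layer: one independent kernel per
    output position, so the weight shape follows _infer_weight_shape above."""
    import megengine as mge

    m = LocalConv2d(
        in_channels=3, out_channels=1, input_height=4, input_width=4, kernel_size=3
    )
    # output_height = output_width = (4 + 0 - 3) // 1 + 1 = 2, hence
    # (groups, oh, ow, ichl // groups, kh, kw, ochl // groups) == (1, 2, 2, 3, 3, 3, 1).
    assert m.weight.numpy().shape == (1, 2, 2, 3, 3, 3, 1)
    inp = mge.tensor(np.ones((2, 3, 4, 4), dtype="float32"))
    oup = m(inp)
    assert oup.numpy().shape == (2, 1, 2, 2)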


class ConvRelu2d(Conv2d):
    r"""
    A fused :class:`~.Module` including Conv2d and relu. Could be replaced
    with the :class:`~.QATModule` version :class:`~.qat.conv.ConvRelu2d` using
    :func:`~.quantize.quantize_qat`.
    """

    def forward(self, inp):
        return relu(self.calc_conv(inp, self.weight, self.bias))
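

def _conv_relu2d_sketch():
    """A minimal sketch of the fused module: ConvRelu2d applies relu to the result
    of calc_conv, so every output element is non-negative."""
    import megengine as mge

    m = ConvRelu2d(in_channels=3, out_channels=2, kernel_size=3)
    oup = m(mge.tensor(np.random.randn(2, 3, 4, 4).astype("float32")))
    assert (oup.numpy() >= 0).all()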
