
conv.py 24 kB

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import abstractmethod
from typing import Tuple, Union

import numpy as np

from ..functional import (
    conv1d,
    conv2d,
    conv_transpose2d,
    deformable_conv2d,
    local_conv2d,
    relu,
)
from ..tensor import Parameter
from ..utils.tuple_function import _pair, _pair_nonzero
from . import init
from .module import Module


class _ConvNd(Module):
    """base class for convolution modules, including transposed conv"""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]],
        padding: Union[int, Tuple[int, int]],
        dilation: Union[int, Tuple[int, int]],
        groups: int,
        bias: bool = True,
        **kwargs
    ):
        super().__init__(**kwargs)
        if in_channels % groups != 0:
            raise ValueError("in_channels must be divisible by groups")
        if out_channels % groups != 0:
            raise ValueError("out_channels must be divisible by groups")
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

        self.weight = Parameter(np.zeros(self._infer_weight_shape(), dtype=np.float32))
        self.bias = None
        if bias:
            self.bias = Parameter(np.zeros(self._infer_bias_shape(), dtype=np.float32))
        self.reset_parameters()

    @abstractmethod
    def _get_fanin(self):
        pass

    def reset_parameters(self) -> None:
        fanin = self._get_fanin()
        std = np.sqrt(1 / fanin)
        init.normal_(self.weight, 0.0, std)
        if self.bias is not None:
            init.zeros_(self.bias)

    @abstractmethod
    def _infer_weight_shape(self):
        pass

    @abstractmethod
    def _infer_bias_shape(self):
        pass

    def _module_info_string(self):
        s = "{in_channels}, {out_channels}, kernel_size={kernel_size}"
        if self.stride != (1,) * len(self.stride):
            s += ", stride={stride}"
        if self.padding != (0,) * len(self.padding):
            s += ", padding={padding}"
        if self.dilation != (1,) * len(self.dilation):
            s += ", dilation={dilation}"
        if self.groups != 1:
            s += ", groups={groups}"
        if self.bias is None:
            s += ", bias=False"
        return s.format(**self.__dict__)


class Conv1d(_ConvNd):
    r"""
    Applies a 1D convolution over an input tensor.

    For instance, given an input of the size :math:`(N, C_{\text{in}}, H)`,
    this layer generates an output of the size
    :math:`(N, C_{\text{out}}, H_{\text{out}})` through the
    process described as below:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)

    where :math:`\star` is the valid 1D cross-correlation operator,
    :math:`N` is batch size, :math:`C` denotes number of channels, and
    :math:`H` is length of 1D data element.

    When `groups == in_channels` and `out_channels == K * in_channels`,
    where K is a positive integer, this operation is also known as depthwise
    convolution.

    In other words, for an input of size :math:`(N, C_{in}, H_{in})`,
    a depthwise convolution with a depthwise multiplier `K` can be constructed
    by arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param kernel_size: size of weight on the spatial dimension.
    :param stride: stride of the 1D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 1D convolution operation. Default: 1
    :param groups: number of groups into which the input and output channels are divided,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and there would be an extra dimension at the beginning of the weight's
        shape. Specifically, the shape of weight would be `(groups,
        out_channels // groups, in_channels // groups, kernel_size)`.
    :param bias: whether to add a bias onto the result of convolution. Default:
        True
    :param conv_mode: Supports `CROSS_CORRELATION`. Default:
        `CROSS_CORRELATION`
    :param compute_mode: When set to "DEFAULT", no special requirements will be
        placed on the precision of intermediate results. When set to "FLOAT32",
        "Float32" would be used for accumulator and intermediate result, but only
        effective when input and output are of float16 dtype.

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.Conv1d(in_channels=3, out_channels=1, kernel_size=3)
        inp = mge.tensor(np.arange(0, 24).astype("float32").reshape(2, 3, 4))
        oup = m(inp)
        print(oup.numpy().shape)

    Outputs:

    .. testoutput::

        (2, 1, 2)

    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "CROSS_CORRELATION",
        compute_mode: str = "DEFAULT",
        **kwargs
    ):
        # 1D convolution keeps scalar hyper-parameters; no pair conversion is needed.
        self.conv_mode = conv_mode
        self.compute_mode = compute_mode
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            **kwargs,
        )

    def _get_fanin(self):
        kh = self.kernel_size
        ic = self.in_channels
        return kh * ic

    def _infer_weight_shape(self):
        group = self.groups
        ichl = self.in_channels
        ochl = self.out_channels
        kh = self.kernel_size
        if group == 1:
            # Assume format is NCH(W=1)
            return (ochl, ichl, kh)

        assert (
            ichl % group == 0 and ochl % group == 0
        ), "invalid config: input_channels={} output_channels={} group={}".format(
            ichl, ochl, group
        )
        # Assume format is NCH(W=1)
        return (group, ochl // group, ichl // group, kh)

    def _infer_bias_shape(self):
        # Assume format is NCH(W=1)
        return (1, self.out_channels, 1)

    def calc_conv(self, inp, weight, bias):
        return conv1d(
            inp,
            weight,
            bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.conv_mode,
            self.compute_mode,
        )

    def forward(self, inp):
        return self.calc_conv(inp, self.weight, self.bias)


class Conv2d(_ConvNd):
    r"""
    Applies a 2D convolution over an input tensor.

    For instance, given an input of the size :math:`(N, C_{\text{in}}, H, W)`,
    this layer generates an output of the size
    :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` through the
    process described as below:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)

    where :math:`\star` is the valid 2D cross-correlation operator,
    :math:`N` is batch size, :math:`C` denotes number of channels,
    :math:`H` is height of input planes in pixels, and :math:`W` is
    width in pixels.

    In general, output feature maps' shapes can be inferred as follows:

    input: :math:`(N, C_{\text{in}}, H_{\text{in}}, W_{\text{in}})`

    output: :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` where

    .. math::
        \text{H}_{out} = \lfloor \frac{\text{H}_{in} + 2 * \text{padding[0]} -
        \text{dilation[0]} * (\text{kernel_size[0]} - 1) - 1}{\text{stride[0]}} + 1 \rfloor

    .. math::
        \text{W}_{out} = \lfloor \frac{\text{W}_{in} + 2 * \text{padding[1]} -
        \text{dilation[1]} * (\text{kernel_size[1]} - 1) - 1}{\text{stride[1]}} + 1 \rfloor

    When `groups == in_channels` and `out_channels == K * in_channels`,
    where K is a positive integer, this operation is also known as depthwise
    convolution.

    In other words, for an input of size :math:`(N, C_{in}, H_{in}, W_{in})`,
    a depthwise convolution with a depthwise multiplier `K` can be constructed
    by arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param kernel_size: size of weight on spatial dimensions. If kernel_size is
        an :class:`int`, the actual kernel size would be
        `(kernel_size, kernel_size)`.
    :param stride: stride of the 2D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 2D convolution operation. Default: 1
    :param groups: number of groups into which the input and output channels are divided,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and there would be an extra dimension at the beginning of the weight's
        shape. Specifically, the shape of weight would be `(groups,
        out_channels // groups, in_channels // groups, *kernel_size)`.
    :param bias: whether to add a bias onto the result of convolution. Default:
        True
    :param conv_mode: Supports `CROSS_CORRELATION`. Default:
        `CROSS_CORRELATION`
    :param compute_mode: When set to "DEFAULT", no special requirements will be
        placed on the precision of intermediate results. When set to "FLOAT32",
        "Float32" would be used for accumulator and intermediate result, but only
        effective when input and output are of float16 dtype.

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.Conv2d(in_channels=3, out_channels=1, kernel_size=3)
        inp = mge.tensor(np.arange(0, 96).astype("float32").reshape(2, 3, 4, 4))
        oup = m(inp)
        print(oup.numpy().shape)

    Outputs:

    .. testoutput::

        (2, 1, 2, 2)

    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "CROSS_CORRELATION",
        compute_mode: str = "DEFAULT",
        **kwargs
    ):
        kernel_size = _pair_nonzero(kernel_size)
        stride = _pair_nonzero(stride)
        padding = _pair(padding)
        dilation = _pair_nonzero(dilation)
        self.conv_mode = conv_mode
        self.compute_mode = compute_mode
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            **kwargs,
        )

    def _get_fanin(self):
        kh, kw = self.kernel_size
        ic = self.in_channels
        return kh * kw * ic

    def _infer_weight_shape(self):
        group = self.groups
        ichl = self.in_channels
        ochl = self.out_channels
        kh, kw = self.kernel_size
        if group == 1:
            # Assume format is NCHW
            return (ochl, ichl, kh, kw)

        assert (
            ichl % group == 0 and ochl % group == 0
        ), "invalid config: input_channels={} output_channels={} group={}".format(
            ichl, ochl, group
        )
        # Assume format is NCHW
        return (group, ochl // group, ichl // group, kh, kw)

    def _infer_bias_shape(self):
        # Assume format is NCHW
        return (1, self.out_channels, 1, 1)

    def calc_conv(self, inp, weight, bias):
        return conv2d(
            inp,
            weight,
            bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.conv_mode,
            self.compute_mode,
        )

    def forward(self, inp):
        return self.calc_conv(inp, self.weight, self.bias)


class ConvTranspose2d(_ConvNd):
    r"""
    Applies a 2D transposed convolution over an input tensor.

    This module is also known as a deconvolution or a fractionally-strided convolution.
    :class:`ConvTranspose2d` can be seen as the gradient of :class:`Conv2d` operation
    with respect to its input.

    Convolution usually reduces the size of input, while transposed convolution works
    the opposite way, transforming a smaller input to a larger output while preserving the
    connectivity pattern.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is
        an :class:`int`, the actual kernel size would be
        ``(kernel_size, kernel_size)``.
    :param stride: stride of the 2D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 2D convolution operation. Default: 1
    :param groups: number of groups into which the input and output channels are divided,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and there would be an extra dimension at the beginning of the weight's
        shape. Specifically, the shape of weight would be ``(groups,
        in_channels // groups, out_channels // groups, *kernel_size)``. Default: 1
    :param bias: whether to add a bias onto the result of convolution. Default:
        True
    :param conv_mode: Supports `CROSS_CORRELATION`. Default:
        `CROSS_CORRELATION`
    :param compute_mode: When set to "DEFAULT", no special requirements will be
        placed on the precision of intermediate results. When set to "FLOAT32",
        "Float32" would be used for accumulator and intermediate result, but only
        effective when input and output are of float16 dtype.
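
    Examples:

    A minimal usage sketch (the input mirrors the :class:`Conv2d` example above;
    with ``kernel_size=3``, ``stride=1`` and no padding, the transposed convolution
    enlarges each spatial dimension from 4 to 6):

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.ConvTranspose2d(in_channels=3, out_channels=1, kernel_size=3)
        inp = mge.tensor(np.arange(0, 96).astype("float32").reshape(2, 3, 4, 4))
        oup = m(inp)
        print(oup.numpy().shape)

    Outputs:

    .. testoutput::

        (2, 1, 6, 6)
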
  368. """
  369. def __init__(
  370. self,
  371. in_channels: int,
  372. out_channels: int,
  373. kernel_size: Union[int, Tuple[int, int]],
  374. stride: Union[int, Tuple[int, int]] = 1,
  375. padding: Union[int, Tuple[int, int]] = 0,
  376. dilation: Union[int, Tuple[int, int]] = 1,
  377. groups: int = 1,
  378. bias: bool = True,
  379. conv_mode: str = "CROSS_CORRELATION",
  380. compute_mode: str = "DEFAULT",
  381. **kwargs
  382. ):
  383. kernel_size = _pair_nonzero(kernel_size)
  384. stride = _pair_nonzero(stride)
  385. padding = _pair(padding)
  386. dilation = _pair_nonzero(dilation)
  387. self.conv_mode = conv_mode
  388. self.compute_mode = compute_mode
  389. super().__init__(
  390. in_channels,
  391. out_channels,
  392. kernel_size,
  393. stride,
  394. padding,
  395. dilation,
  396. groups,
  397. bias,
  398. **kwargs,
  399. )
  400. def _get_fanin(self):
  401. kh, kw = self.kernel_size
  402. oc = self.out_channels
  403. return kh * kw * oc
  404. def _infer_weight_shape(self):
  405. group = self.groups
  406. ichl = self.in_channels
  407. ochl = self.out_channels
  408. kh, kw = self.kernel_size
  409. if group == 1:
  410. # Assume format is NCHW
  411. return (ichl, ochl, kh, kw)
  412. assert (
  413. ichl % group == 0 and ochl % group == 0
  414. ), "invalid config: input_channels={} output_channels={} group={}".format(
  415. ichl, ochl, group
  416. )
  417. # Assume format is NCHW
  418. return (group, ichl // group, ochl // group, kh, kw)
  419. def _infer_bias_shape(self):
  420. # Assume format is NCHW
  421. return (1, self.out_channels, 1, 1)
  422. def forward(self, inp):
  423. return conv_transpose2d(
  424. inp,
  425. self.weight,
  426. self.bias,
  427. self.stride,
  428. self.padding,
  429. self.dilation,
  430. self.groups,
  431. self.conv_mode,
  432. self.compute_mode,
  433. )


class LocalConv2d(Conv2d):
    r"""
    Applies a spatial convolution with untied kernels over a grouped, channeled
    4D input tensor. It is also known as the locally connected layer.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param input_height: the height of the input images.
    :param input_width: the width of the input images.
    :param kernel_size: size of weight on spatial dimensions. If kernel_size is
        an :class:`int`, the actual kernel size would be
        `(kernel_size, kernel_size)`.
    :param stride: stride of the 2D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param groups: number of groups into which the input and output channels are divided,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``.
        The shape of weight is `(groups, output_height, output_width,
        in_channels // groups, *kernel_size, out_channels // groups)`.
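
    Examples:

    A minimal usage sketch (shapes follow the :class:`Conv2d` example above;
    ``input_height`` and ``input_width`` are assumed to match the spatial size
    of the input, and the layer has no bias):

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.LocalConv2d(
            in_channels=3, out_channels=1, input_height=4, input_width=4, kernel_size=3
        )
        inp = mge.tensor(np.arange(0, 96).astype("float32").reshape(2, 3, 4, 4))
        oup = m(inp)
        print(oup.numpy().shape)

    Outputs:

    .. testoutput::

        (2, 1, 2, 2)
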
  453. """
  454. def __init__(
  455. self,
  456. in_channels: int,
  457. out_channels: int,
  458. input_height: int,
  459. input_width: int,
  460. kernel_size: Union[int, Tuple[int, int]],
  461. stride: Union[int, Tuple[int, int]] = 1,
  462. padding: Union[int, Tuple[int, int]] = 0,
  463. dilation: Union[int, Tuple[int, int]] = 1,
  464. groups: int = 1,
  465. conv_mode: str = "CROSS_CORRELATION",
  466. **kwargs
  467. ):
  468. self.input_height = input_height
  469. self.input_width = input_width
  470. super().__init__(
  471. in_channels,
  472. out_channels,
  473. kernel_size,
  474. stride,
  475. padding,
  476. dilation,
  477. groups,
  478. bias=False,
  479. **kwargs,
  480. )
  481. def _infer_weight_shape(self):
  482. group = self.groups
  483. output_height = (
  484. self.input_height + self.padding[0] * 2 - self.kernel_size[0]
  485. ) // self.stride[0] + 1
  486. output_width = (
  487. self.input_width + self.padding[1] * 2 - self.kernel_size[1]
  488. ) // self.stride[1] + 1
  489. # Assume format is NCHW
  490. return (
  491. group,
  492. output_height,
  493. output_width,
  494. self.in_channels // group,
  495. self.kernel_size[0],
  496. self.kernel_size[1],
  497. self.out_channels // group,
  498. )
  499. def forward(self, inp):
  500. return local_conv2d(
  501. inp,
  502. self.weight,
  503. None,
  504. self.stride,
  505. self.padding,
  506. self.dilation,
  507. self.conv_mode,
  508. )


class ConvRelu2d(Conv2d):
    r"""
    A fused :class:`~.Module` including Conv2d and relu. Could be replaced
    with :class:`~.QATModule` version :class:`~.qat.conv.ConvRelu2d` using
    :func:`~.quantize.quantize_qat`.
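
    Examples:

    A minimal usage sketch (identical to the :class:`Conv2d` example above, with
    the ReLU applied inside the fused forward pass):

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.ConvRelu2d(in_channels=3, out_channels=1, kernel_size=3)
        inp = mge.tensor(np.arange(0, 96).astype("float32").reshape(2, 3, 4, 4))
        oup = m(inp)
        print(oup.numpy().shape)

    Outputs:

    .. testoutput::

        (2, 1, 2, 2)
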
  514. """
  515. def forward(self, inp):
  516. return relu(self.calc_conv(inp, self.weight, self.bias))


class DeformableConv2d(_ConvNd):
    """
    Deformable Convolution.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param kernel_size: size of weight on spatial dimensions. If kernel_size is
        an :class:`int`, the actual kernel size would be
        `(kernel_size, kernel_size)`.
    :param stride: stride of the 2D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 2D convolution operation. Default: 1
    :param groups: number of groups into which the input and output channels are divided,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and there would be an extra dimension at the beginning of the weight's
        shape. Specifically, the shape of weight would be `(groups,
        out_channels // groups, in_channels // groups, *kernel_size)`.
    :param bias: whether to add a bias onto the result of convolution. Default:
        True
    :param conv_mode: Supports `CROSS_CORRELATION`. Default:
        `CROSS_CORRELATION`
    :param compute_mode: When set to "DEFAULT", no special requirements will be
        placed on the precision of intermediate results. When set to "FLOAT32",
        "Float32" would be used for accumulator and intermediate result, but only
        effective when input and output are of float16 dtype.
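
    Examples:

    A minimal usage sketch. The ``offset`` and ``mask`` shapes below assume the
    conventional DCNv2 layout, i.e. ``(N, 2 * kh * kw, OH, OW)`` for ``offset`` and
    ``(N, kh * kw, OH, OW)`` for ``mask``; with zero offsets and an all-ones mask
    the layer behaves like an ordinary convolution:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.DeformableConv2d(in_channels=3, out_channels=1, kernel_size=3)
        inp = mge.tensor(np.arange(0, 96).astype("float32").reshape(2, 3, 4, 4))
        offset = mge.tensor(np.zeros((2, 2 * 3 * 3, 2, 2), dtype="float32"))
        mask = mge.tensor(np.ones((2, 3 * 3, 2, 2), dtype="float32"))
        oup = m(inp, offset, mask)
        print(oup.numpy().shape)

    Outputs:

    .. testoutput::

        (2, 1, 2, 2)
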
  543. """
  544. def __init__(
  545. self,
  546. in_channels: int,
  547. out_channels: int,
  548. kernel_size: Union[int, Tuple[int, int]],
  549. stride: Union[int, Tuple[int, int]] = 1,
  550. padding: Union[int, Tuple[int, int]] = 0,
  551. dilation: Union[int, Tuple[int, int]] = 1,
  552. groups: int = 1,
  553. bias: bool = True,
  554. conv_mode: str = "CROSS_CORRELATION",
  555. compute_mode: str = "DEFAULT",
  556. ):
  557. kernel_size = _pair_nonzero(kernel_size)
  558. stride = _pair_nonzero(stride)
  559. padding = _pair(padding)
  560. dilation = _pair_nonzero(dilation)
  561. self.conv_mode = conv_mode
  562. self.compute_mode = compute_mode
  563. super().__init__(
  564. in_channels,
  565. out_channels,
  566. kernel_size,
  567. stride,
  568. padding,
  569. dilation,
  570. groups,
  571. bias,
  572. )
  573. def _get_fanin(self):
  574. kh, kw = self.kernel_size
  575. ic = self.in_channels
  576. return kh * kw * ic
  577. def _infer_weight_shape(self):
  578. group = self.groups
  579. ichl = self.in_channels
  580. ochl = self.out_channels
  581. kh, kw = self.kernel_size
  582. if group == 1:
  583. # Assume format is NCHW
  584. return (ochl, ichl, kh, kw)
  585. assert (
  586. ichl % group == 0 and ochl % group == 0
  587. ), "invalid config: input_channels={} output_channels={} group={}".format(
  588. ichl, ochl, group
  589. )
  590. # Assume format is NCHW
  591. return (group, ochl // group, ichl // group, kh, kw)
  592. def _infer_bias_shape(self):
  593. # Assume format is NCHW
  594. return (1, self.out_channels, 1, 1)
  595. def calc_conv(self, inp, weight, offset, mask, bias):
  596. return deformable_conv2d(
  597. inp,
  598. weight,
  599. offset,
  600. mask,
  601. bias,
  602. self.stride,
  603. self.padding,
  604. self.dilation,
  605. self.groups,
  606. self.conv_mode,
  607. self.compute_mode,
  608. )
  609. def forward(self, inp, offset, mask):
  610. return self.calc_conv(inp, self.weight, offset, mask, self.bias)

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU device and its driver installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.