
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import abstractmethod
from typing import Tuple, Union

import numpy as np

from ..functional import (
    conv1d,
    conv2d,
    conv3d,
    conv_transpose2d,
    deformable_conv2d,
    local_conv2d,
    relu,
)
from ..tensor import Parameter
from ..utils.tuple_function import _pair, _pair_nonzero, _triple, _triple_nonzero
from . import init
from .module import Module


class _ConvNd(Module):
    """base class for convolution modules, including transposed conv"""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]],
        padding: Union[int, Tuple[int, int]],
        dilation: Union[int, Tuple[int, int]],
        groups: int,
        bias: bool = True,
        **kwargs
    ):
        super().__init__(**kwargs)
        if in_channels % groups != 0:
            raise ValueError("in_channels must be divisible by groups")
        if out_channels % groups != 0:
            raise ValueError("out_channels must be divisible by groups")
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

        self.weight = Parameter(np.zeros(self._infer_weight_shape(), dtype=np.float32))
        self.bias = None
        if bias:
            self.bias = Parameter(np.zeros(self._infer_bias_shape(), dtype=np.float32))
        self.reset_parameters()

    @abstractmethod
    def _get_fanin(self):
        pass

    def reset_parameters(self) -> None:
        fanin = self._get_fanin()
        std = np.sqrt(1 / fanin)
        init.normal_(self.weight, 0.0, std)
        if self.bias is not None:
            init.zeros_(self.bias)

    @abstractmethod
    def _infer_weight_shape(self):
        pass

    @abstractmethod
    def _infer_bias_shape(self):
        pass

    def _module_info_string(self):
        s = "{in_channels}, {out_channels}, kernel_size={kernel_size}"
        if self.stride != (1,) * len(self.stride):
            s += ", stride={stride}"
        if self.padding != (0,) * len(self.padding):
            s += ", padding={padding}"
        if self.dilation != (1,) * len(self.dilation):
            s += ", dilation={dilation}"
        if self.groups != 1:
            s += ", groups={groups}"
        if self.bias is None:
            s += ", bias=False"
        return s.format(**self.__dict__)


class Conv1d(_ConvNd):
    r"""
    Applies a 1D convolution over an input tensor.

    For instance, given an input of the size :math:`(N, C_{\text{in}}, H)`,
    this layer generates an output of the size
    :math:`(N, C_{\text{out}}, H_{\text{out}})` through the
    process described as below:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)

    where :math:`\star` is the valid 1D cross-correlation operator,
    :math:`N` is batch size, :math:`C` denotes number of channels, and
    :math:`H` is length of 1D data element.

    When ``groups == in_channels`` and ``out_channels == K * in_channels``,
    where `K` is a positive integer, this operation is also known as depthwise
    convolution.

    In other words, for an input of size :math:`(N, C_{in}, H_{in})`,
    a depthwise convolution with a depthwise multiplier `K` can be constructed
    by the arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param kernel_size: size of weight on the spatial dimension.
    :param stride: stride of the 1D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimension. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 1D convolution operation. Default: 1
    :param groups: number of groups into which the input and output channels are divided,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and there would be an extra dimension at the beginning of the weight's
        shape. Specifically, the shape of weight would be ``(groups,
        out_channels // groups, in_channels // groups, kernel_size)``.
    :param bias: whether to add a bias onto the result of convolution. Default: True
    :param conv_mode: Supports `CROSS_CORRELATION`. Default: `CROSS_CORRELATION`
    :param compute_mode: When set to "DEFAULT", no special requirements will be
        placed on the precision of intermediate results. When set to "FLOAT32",
        "Float32" would be used for accumulator and intermediate result, but only
        effective when input and output are of float16 dtype.

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.Conv1d(in_channels=3, out_channels=1, kernel_size=3)
        inp = mge.tensor(np.arange(0, 24).astype("float32").reshape(2, 3, 4))
        oup = m(inp)
        print(oup.numpy().shape)

    Outputs:

    .. testoutput::

        (2, 1, 2)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "CROSS_CORRELATION",
        compute_mode: str = "DEFAULT",
        **kwargs
    ):
        self.conv_mode = conv_mode
        self.compute_mode = compute_mode
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            **kwargs,
        )

    def _get_fanin(self):
        kh = self.kernel_size
        ic = self.in_channels
        return kh * ic

    def _infer_weight_shape(self):
        group = self.groups
        ichl = self.in_channels
        ochl = self.out_channels
        kh = self.kernel_size
        if group == 1:
            # Assume format is NCH(W=1)
            return (ochl, ichl, kh)
        assert (
            ichl % group == 0 and ochl % group == 0
        ), "invalid config: input_channels={} output_channels={} group={}".format(
            ichl, ochl, group
        )
        # Assume format is NCH(W=1)
        return (group, ochl // group, ichl // group, kh)

    def _infer_bias_shape(self):
        # Assume format is NCH(W=1)
        return (1, self.out_channels, 1)

    def calc_conv(self, inp, weight, bias):
        return conv1d(
            inp,
            weight,
            bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.conv_mode,
            self.compute_mode,
        )

    def forward(self, inp):
        return self.calc_conv(inp, self.weight, self.bias)


class Conv2d(_ConvNd):
    r"""
    Applies a 2D convolution over an input tensor.

    For instance, given an input of the size :math:`(N, C_{\text{in}}, H, W)`,
    this layer generates an output of the size
    :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` through the
    process described as below:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)

    where :math:`\star` is the valid 2D cross-correlation operator,
    :math:`N` is batch size, :math:`C` denotes number of channels,
    :math:`H` is height of input planes in pixels, and :math:`W` is
    width in pixels.

    In general, output feature maps' shapes can be inferred as follows:

    input: :math:`(N, C_{\text{in}}, H_{\text{in}}, W_{\text{in}})`

    output: :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})`, where

    .. math::
        \text{H}_{out} = \lfloor \frac{\text{H}_{in} + 2 * \text{padding[0]} -
        \text{dilation[0]} * (\text{kernel\_size[0]} - 1) - 1}{\text{stride[0]}} + 1 \rfloor

    .. math::
        \text{W}_{out} = \lfloor \frac{\text{W}_{in} + 2 * \text{padding[1]} -
        \text{dilation[1]} * (\text{kernel\_size[1]} - 1) - 1}{\text{stride[1]}} + 1 \rfloor

    When ``groups == in_channels`` and ``out_channels == K * in_channels``,
    where `K` is a positive integer, this operation is also known as depthwise
    convolution.

    In other words, for an input of size :math:`(N, C_{in}, H_{in}, W_{in})`,
    a depthwise convolution with a depthwise multiplier `K` can be constructed
    by the arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param kernel_size: size of weight on spatial dimensions. If kernel_size is
        an :class:`int`, the actual kernel size would be
        ``(kernel_size, kernel_size)``.
    :param stride: stride of the 2D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 2D convolution operation. Default: 1
    :param groups: number of groups into which the input and output channels are divided,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and there would be an extra dimension at the beginning of the weight's
        shape. Specifically, the shape of weight would be ``(groups,
        out_channels // groups, in_channels // groups, *kernel_size)``.
    :param bias: whether to add a bias onto the result of convolution. Default: True
    :param conv_mode: Supports `CROSS_CORRELATION`. Default: `CROSS_CORRELATION`
    :param compute_mode: When set to "DEFAULT", no special requirements will be
        placed on the precision of intermediate results. When set to "FLOAT32",
        "Float32" would be used for accumulator and intermediate result, but only
        effective when input and output are of float16 dtype.

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.Conv2d(in_channels=3, out_channels=1, kernel_size=3)
        inp = mge.tensor(np.arange(0, 96).astype("float32").reshape(2, 3, 4, 4))
        oup = m(inp)
        print(oup.numpy().shape)

    Outputs:

    .. testoutput::

        (2, 1, 2, 2)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "CROSS_CORRELATION",
        compute_mode: str = "DEFAULT",
        **kwargs
    ):
        kernel_size = _pair_nonzero(kernel_size)
        stride = _pair_nonzero(stride)
        padding = _pair(padding)
        dilation = _pair_nonzero(dilation)
        self.conv_mode = conv_mode
        self.compute_mode = compute_mode
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            **kwargs,
        )

    def _get_fanin(self):
        kh, kw = self.kernel_size
        ic = self.in_channels
        return kh * kw * ic

    def _infer_weight_shape(self):
        group = self.groups
        ichl = self.in_channels
        ochl = self.out_channels
        kh, kw = self.kernel_size
        if group == 1:
            # Assume format is NCHW
            return (ochl, ichl, kh, kw)
        assert (
            ichl % group == 0 and ochl % group == 0
        ), "invalid config: input_channels={} output_channels={} group={}".format(
            ichl, ochl, group
        )
        # Assume format is NCHW
        return (group, ochl // group, ichl // group, kh, kw)

    def _infer_bias_shape(self):
        # Assume format is NCHW
        return (1, self.out_channels, 1, 1)

    def calc_conv(self, inp, weight, bias):
        return conv2d(
            inp,
            weight,
            bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.conv_mode,
            self.compute_mode,
        )

    def forward(self, inp):
        return self.calc_conv(inp, self.weight, self.bias)


class Conv3d(_ConvNd):
    r"""
    Applies a 3D convolution over an input tensor.

    For instance, given an input of the size :math:`(N, C_{\text{in}}, T, H, W)`,
    this layer generates an output of the size
    :math:`(N, C_{\text{out}}, T_{\text{out}}, H_{\text{out}}, W_{\text{out}})` through the
    process described as below:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)

    where :math:`\star` is the valid 3D cross-correlation operator,
    :math:`N` is batch size, and :math:`C` denotes number of channels.

    When ``groups == in_channels`` and ``out_channels == K * in_channels``,
    where `K` is a positive integer, this operation is also known as depthwise
    convolution.

    In other words, for an input of size :math:`(N, C_{in}, T_{in}, H_{in}, W_{in})`,
    a depthwise convolution with a depthwise multiplier `K` can be constructed
    by the arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param kernel_size: size of weight on spatial dimensions. If kernel_size is
        an :class:`int`, the actual kernel size would be
        ``(kernel_size, kernel_size, kernel_size)``.
    :param stride: stride of the 3D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 3D convolution operation. Default: 1
    :param groups: number of groups into which the input and output channels are divided,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and there would be an extra dimension at the beginning of the weight's
        shape. Specifically, the shape of weight would be ``(groups,
        out_channels // groups, in_channels // groups, *kernel_size)``.
    :param bias: whether to add a bias onto the result of convolution. Default: True
    :param conv_mode: Supports `CROSS_CORRELATION`. Default: `CROSS_CORRELATION`

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.Conv3d(in_channels=3, out_channels=1, kernel_size=3)
        inp = mge.tensor(np.arange(0, 384).astype("float32").reshape(2, 3, 4, 4, 4))
        oup = m(inp)
        print(oup.numpy().shape)

    Outputs:

    .. testoutput::

        (2, 1, 2, 2, 2)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Union[int, Tuple[int, int, int]] = 0,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "CROSS_CORRELATION",
    ):
        kernel_size = _triple_nonzero(kernel_size)
        stride = _triple_nonzero(stride)
        padding = _triple(padding)
        dilation = _triple_nonzero(dilation)
        self.conv_mode = conv_mode
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )

    def _get_fanin(self):
        kt, kh, kw = self.kernel_size
        ic = self.in_channels
        return kt * kh * kw * ic

    def _infer_weight_shape(self):
        group = self.groups
        ichl = self.in_channels
        ochl = self.out_channels
        kt, kh, kw = self.kernel_size
        if group == 1:
            # Assume format is NCTHW
            return (ochl, ichl, kt, kh, kw)
        assert (
            ichl % group == 0 and ochl % group == 0
        ), "invalid config: input_channels={} output_channels={} group={}".format(
            ichl, ochl, group
        )
        # Assume format is NCTHW
        return (group, ochl // group, ichl // group, kt, kh, kw)

    def _infer_bias_shape(self):
        # Assume format is NCTHW
        return (1, self.out_channels, 1, 1, 1)

    def calc_conv(self, inp, weight, bias):
        return conv3d(
            inp,
            weight,
            bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.conv_mode,
        )

    def forward(self, inp):
        return self.calc_conv(inp, self.weight, self.bias)


class ConvTranspose2d(_ConvNd):
    r"""
    Applies a 2D transposed convolution over an input tensor.

    This module is also known as a deconvolution or a fractionally-strided convolution.
    :class:`ConvTranspose2d` can be seen as the gradient of the :class:`Conv2d` operation
    with respect to its input.

    Convolution usually reduces the size of the input, while transposed convolution works
    the opposite way, transforming a smaller input into a larger output while preserving the
    connectivity pattern.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is
        an :class:`int`, the actual kernel size would be
        ``(kernel_size, kernel_size)``.
    :param stride: stride of the 2D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 2D convolution operation. Default: 1
    :param groups: number of groups into which the input and output channels are divided,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and there would be an extra dimension at the beginning of the weight's
        shape. Specifically, the shape of weight would be ``(groups,
        out_channels // groups, in_channels // groups, *kernel_size)``. Default: 1
    :param bias: whether to add a bias onto the result of convolution. Default: True
    :param conv_mode: Supports `CROSS_CORRELATION`. Default: `CROSS_CORRELATION`
    :param compute_mode: When set to "DEFAULT", no special requirements will be
        placed on the precision of intermediate results. When set to "FLOAT32",
        "Float32" would be used for accumulator and intermediate result, but only
        effective when input and output are of float16 dtype.
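
    Examples (a minimal usage sketch, not from the original docstring; the expected
    shape follows the standard transposed-convolution size formula
    :math:`H_{out} = (H_{in} - 1) \times stride - 2 \times padding + dilation \times (kernel\_size - 1) + 1`):

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.ConvTranspose2d(in_channels=3, out_channels=1, kernel_size=3)
        inp = mge.tensor(np.arange(0, 96).astype("float32").reshape(2, 3, 4, 4))
        oup = m(inp)  # expected (2, 1, 6, 6) by the size formula above
        print(oup.numpy().shape)

    Outputs:

    .. testoutput::

        (2, 1, 6, 6)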
  480. """
  481. def __init__(
  482. self,
  483. in_channels: int,
  484. out_channels: int,
  485. kernel_size: Union[int, Tuple[int, int]],
  486. stride: Union[int, Tuple[int, int]] = 1,
  487. padding: Union[int, Tuple[int, int]] = 0,
  488. dilation: Union[int, Tuple[int, int]] = 1,
  489. groups: int = 1,
  490. bias: bool = True,
  491. conv_mode: str = "CROSS_CORRELATION",
  492. compute_mode: str = "DEFAULT",
  493. **kwargs
  494. ):
  495. kernel_size = _pair_nonzero(kernel_size)
  496. stride = _pair_nonzero(stride)
  497. padding = _pair(padding)
  498. dilation = _pair_nonzero(dilation)
  499. self.conv_mode = conv_mode
  500. self.compute_mode = compute_mode
  501. super().__init__(
  502. in_channels,
  503. out_channels,
  504. kernel_size,
  505. stride,
  506. padding,
  507. dilation,
  508. groups,
  509. bias,
  510. **kwargs,
  511. )
  512. def _get_fanin(self):
  513. kh, kw = self.kernel_size
  514. oc = self.out_channels
  515. return kh * kw * oc
  516. def _infer_weight_shape(self):
  517. group = self.groups
  518. ichl = self.in_channels
  519. ochl = self.out_channels
  520. kh, kw = self.kernel_size
  521. if group == 1:
  522. # Assume format is NCHW
  523. return (ichl, ochl, kh, kw)
  524. assert (
  525. ichl % group == 0 and ochl % group == 0
  526. ), "invalid config: input_channels={} output_channels={} group={}".format(
  527. ichl, ochl, group
  528. )
  529. # Assume format is NCHW
  530. return (group, ichl // group, ochl // group, kh, kw)
  531. def _infer_bias_shape(self):
  532. # Assume format is NCHW
  533. return (1, self.out_channels, 1, 1)
  534. def forward(self, inp):
  535. return conv_transpose2d(
  536. inp,
  537. self.weight,
  538. self.bias,
  539. self.stride,
  540. self.padding,
  541. self.dilation,
  542. self.groups,
  543. self.conv_mode,
  544. self.compute_mode,
  545. )


class LocalConv2d(Conv2d):
    r"""
    Applies a spatial convolution with untied kernels over a grouped, channeled 4D input tensor.
    It is also known as the locally connected layer.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param input_height: the height of the input images.
    :param input_width: the width of the input images.
    :param kernel_size: size of weight on spatial dimensions. If kernel_size is
        an :class:`int`, the actual kernel size would be
        ``(kernel_size, kernel_size)``.
    :param stride: stride of the 2D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 2D convolution operation. Default: 1
    :param groups: number of groups into which the input and output channels are divided,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``.
        The shape of weight is ``(groups, output_height, output_width,
        in_channels // groups, *kernel_size, out_channels // groups)``.
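
    Examples (a minimal usage sketch, not from the original docstring; the output
    spatial size follows the same arithmetic as :class:`Conv2d`, so a 3x3 kernel
    over a 4x4 input yields 2x2):

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.LocalConv2d(in_channels=3, out_channels=1, input_height=4, input_width=4, kernel_size=3)
        inp = mge.tensor(np.arange(0, 96).astype("float32").reshape(2, 3, 4, 4))
        oup = m(inp)  # each output location uses its own untied kernel
        print(oup.numpy().shape)

    Outputs:

    .. testoutput::

        (2, 1, 2, 2)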
  565. """
  566. def __init__(
  567. self,
  568. in_channels: int,
  569. out_channels: int,
  570. input_height: int,
  571. input_width: int,
  572. kernel_size: Union[int, Tuple[int, int]],
  573. stride: Union[int, Tuple[int, int]] = 1,
  574. padding: Union[int, Tuple[int, int]] = 0,
  575. dilation: Union[int, Tuple[int, int]] = 1,
  576. groups: int = 1,
  577. conv_mode: str = "CROSS_CORRELATION",
  578. **kwargs
  579. ):
  580. self.input_height = input_height
  581. self.input_width = input_width
  582. super().__init__(
  583. in_channels,
  584. out_channels,
  585. kernel_size,
  586. stride,
  587. padding,
  588. dilation,
  589. groups,
  590. bias=False,
  591. **kwargs,
  592. )
  593. def _infer_weight_shape(self):
  594. group = self.groups
  595. output_height = (
  596. self.input_height + self.padding[0] * 2 - self.kernel_size[0]
  597. ) // self.stride[0] + 1
  598. output_width = (
  599. self.input_width + self.padding[1] * 2 - self.kernel_size[1]
  600. ) // self.stride[1] + 1
  601. # Assume format is NCHW
  602. return (
  603. group,
  604. output_height,
  605. output_width,
  606. self.in_channels // group,
  607. self.kernel_size[0],
  608. self.kernel_size[1],
  609. self.out_channels // group,
  610. )
  611. def forward(self, inp):
  612. return local_conv2d(
  613. inp,
  614. self.weight,
  615. None,
  616. self.stride,
  617. self.padding,
  618. self.dilation,
  619. self.conv_mode,
  620. )


class ConvRelu2d(Conv2d):
    r"""
    A fused :class:`~.Module` including :class:`~.module.Conv2d` and :func:`~.relu`.

    Could be replaced with the :class:`~.QATModule` version :class:`~.qat.ConvRelu2d`
    using :func:`~.quantize.quantize_qat`.
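
    Examples (a minimal usage sketch, not from the original docstring; the result
    equals :class:`Conv2d` followed by :func:`~.relu`, so the output shape matches
    the plain convolution):

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.ConvRelu2d(in_channels=3, out_channels=1, kernel_size=3)
        inp = mge.tensor(np.arange(0, 96).astype("float32").reshape(2, 3, 4, 4))
        oup = m(inp)  # convolution result with negative values clamped to 0
        print(oup.numpy().shape)

    Outputs:

    .. testoutput::

        (2, 1, 2, 2)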
  625. """
  626. def forward(self, inp):
  627. return relu(self.calc_conv(inp, self.weight, self.bias))


class DeformableConv2d(_ConvNd):
    """
    Deformable Convolution.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param kernel_size: size of weight on spatial dimensions. If kernel_size is
        an :class:`int`, the actual kernel size would be
        ``(kernel_size, kernel_size)``.
    :param stride: stride of the 2D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 2D convolution operation. Default: 1
    :param groups: number of groups into which the input and output channels are divided,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and there would be an extra dimension at the beginning of the weight's
        shape. Specifically, the shape of weight would be ``(groups,
        out_channels // groups, in_channels // groups, *kernel_size)``.
    :param bias: whether to add a bias onto the result of convolution. Default: True
    :param conv_mode: Supports `CROSS_CORRELATION`. Default: `CROSS_CORRELATION`
    :param compute_mode: When set to "DEFAULT", no special requirements will be
        placed on the precision of intermediate results. When set to "FLOAT32",
        "Float32" would be used for accumulator and intermediate result, but only
        effective when input and output are of float16 dtype.
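
    Examples (a minimal usage sketch, not from the original docstring, assuming the
    DCNv2 convention for a single deformable group: ``offset`` of shape
    :math:`(N, 2 \\times kh \\times kw, OH, OW)` and ``mask`` of shape
    :math:`(N, kh \\times kw, OH, OW)`; with zero offsets and an all-ones mask the
    operation reduces to an ordinary convolution):

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.module as M

        m = M.DeformableConv2d(in_channels=3, out_channels=1, kernel_size=3)
        inp = mge.tensor(np.arange(0, 96).astype("float32").reshape(2, 3, 4, 4))
        # assumed DCNv2 layout: zero offsets / all-ones mask, one entry per
        # kernel tap at each of the 2x2 output locations
        offset = mge.tensor(np.zeros((2, 2 * 3 * 3, 2, 2), dtype="float32"))
        mask = mge.tensor(np.ones((2, 3 * 3, 2, 2), dtype="float32"))
        oup = m(inp, offset, mask)
        print(oup.numpy().shape)

    Outputs:

    .. testoutput::

        (2, 1, 2, 2)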
  654. """
  655. def __init__(
  656. self,
  657. in_channels: int,
  658. out_channels: int,
  659. kernel_size: Union[int, Tuple[int, int]],
  660. stride: Union[int, Tuple[int, int]] = 1,
  661. padding: Union[int, Tuple[int, int]] = 0,
  662. dilation: Union[int, Tuple[int, int]] = 1,
  663. groups: int = 1,
  664. bias: bool = True,
  665. conv_mode: str = "CROSS_CORRELATION",
  666. compute_mode: str = "DEFAULT",
  667. **kwargs
  668. ):
  669. kernel_size = _pair_nonzero(kernel_size)
  670. stride = _pair_nonzero(stride)
  671. padding = _pair(padding)
  672. dilation = _pair_nonzero(dilation)
  673. self.conv_mode = conv_mode
  674. self.compute_mode = compute_mode
  675. super().__init__(
  676. in_channels,
  677. out_channels,
  678. kernel_size,
  679. stride,
  680. padding,
  681. dilation,
  682. groups,
  683. bias,
  684. **kwargs,
  685. )
  686. def _get_fanin(self):
  687. kh, kw = self.kernel_size
  688. ic = self.in_channels
  689. return kh * kw * ic
  690. def _infer_weight_shape(self):
  691. group = self.groups
  692. ichl = self.in_channels
  693. ochl = self.out_channels
  694. kh, kw = self.kernel_size
  695. if group == 1:
  696. # Assume format is NCHW
  697. return (ochl, ichl, kh, kw)
  698. assert (
  699. ichl % group == 0 and ochl % group == 0
  700. ), "invalid config: input_channels={} output_channels={} group={}".format(
  701. ichl, ochl, group
  702. )
  703. # Assume format is NCHW
  704. return (group, ochl // group, ichl // group, kh, kw)
  705. def _infer_bias_shape(self):
  706. # Assume format is NCHW
  707. return (1, self.out_channels, 1, 1)
  708. def calc_conv(self, inp, weight, offset, mask, bias):
  709. return deformable_conv2d(
  710. inp,
  711. weight,
  712. offset,
  713. mask,
  714. bias,
  715. self.stride,
  716. self.padding,
  717. self.dilation,
  718. self.groups,
  719. self.conv_mode,
  720. self.compute_mode,
  721. )
  722. def forward(self, inp, offset, mask):
  723. return self.calc_conv(inp, self.weight, offset, mask, self.bias)

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose from. To run GPU programs, make sure the machine has GPU hardware and the proper driver installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.