You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

conv.py 36 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984
  1. from abc import abstractmethod
  2. from typing import Tuple, Union
  3. import numpy as np
  4. from ..functional import (
  5. conv1d,
  6. conv2d,
  7. conv3d,
  8. conv_transpose2d,
  9. conv_transpose3d,
  10. deformable_conv2d,
  11. local_conv2d,
  12. pad,
  13. relu,
  14. )
  15. from ..tensor import Parameter
  16. from ..utils.tuple_function import _pair, _pair_nonzero, _triple, _triple_nonzero
  17. from . import init
  18. from .module import Module
  19. class _ConvNd(Module):
  20. """base class for convolution modules, including transposed conv"""
  21. def __init__(
  22. self,
  23. in_channels: int,
  24. out_channels: int,
  25. kernel_size: Union[int, Tuple[int, int]],
  26. stride: Union[int, Tuple[int, int]],
  27. padding: Union[int, Tuple[int, int]],
  28. output_padding: Union[int, Tuple[int, int]],
  29. dilation: Union[int, Tuple[int, int]],
  30. groups: int,
  31. bias: bool = True,
  32. **kwargs
  33. ):
  34. super().__init__(**kwargs)
  35. if in_channels % groups != 0:
  36. raise ValueError("in_channels must be divisible by groups")
  37. if out_channels % groups != 0:
  38. raise ValueError("out_channels must be divisible by groups")
  39. self.in_channels = in_channels
  40. self.out_channels = out_channels
  41. self.kernel_size = kernel_size
  42. self.stride = stride
  43. self.padding = padding
  44. self.output_padding = output_padding
  45. self.dilation = dilation
  46. self.groups = groups
  47. self.weight = Parameter(np.zeros(self._infer_weight_shape(), dtype=np.float32))
  48. self.bias = None
  49. if bias:
  50. self.bias = Parameter(np.zeros(self._infer_bias_shape(), dtype=np.float32))
  51. self.reset_parameters()
  52. @abstractmethod
  53. def _get_fanin(self):
  54. pass
  55. def reset_parameters(self) -> None:
  56. fanin = self._get_fanin()
  57. std = np.sqrt(1 / fanin)
  58. init.normal_(self.weight, 0.0, std)
  59. if self.bias is not None:
  60. init.zeros_(self.bias)
  61. @abstractmethod
  62. def _infer_weight_shape(self):
  63. pass
  64. @abstractmethod
  65. def _infer_bias_shape(self):
  66. pass
  67. def _module_info_string(self):
  68. s = "{in_channels}, {out_channels}, kernel_size={kernel_size}"
  69. if self.stride != (1,) * len(self.stride):
  70. s += ", stride={stride}"
  71. if self.padding != (0,) * len(self.padding):
  72. s += ", padding={padding}"
  73. if self.dilation != (1,) * len(self.dilation):
  74. s += ", dilation={dilation}"
  75. if self.groups != 1:
  76. s += ", groups={groups}"
  77. if self.bias is None:
  78. s += ", bias=False"
  79. return s.format(**self.__dict__)
  80. class Conv1d(_ConvNd):
  81. r"""Applies a 1D convolution over an input tensor.
  82. For instance, given an input of the size :math:`(N, C_{\text{in}}, H)`,
  83. this layer generates an output of the size
  84. :math:`(N, C_{\text{out}}, H_{\text{out}})` through the
  85. process described as below:
  86. .. math::
  87. \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
  88. \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)
  89. where :math:`\star` is the valid 1D cross-correlation operator,
  90. :math:`N` is batch size, :math:`C` denotes number of channels, and
  91. :math:`H` is length of 1D data element.
  92. When `groups == in_channels` and `out_channels == K * in_channels`,
  93. where K is a positive integer, this operation is also known as depthwise
  94. convolution.
  95. In other words, for an input of size :math:`(N, C_{in}, H_{in})`,
  96. a depthwise convolution with a depthwise multiplier `K`, can be constructed
  97. by arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.
  98. Args:
  99. in_channels: number of input channels.
  100. out_channels: number of output channels.
  101. kernel_size: size of weight on spatial dimensions.
  102. stride: stride of the 1D convolution operation.
  103. padding: size of the paddings added to the input on both sides of its
  104. spatial dimensions. Default: 0
  105. dilation: dilation of the 1D convolution operation. Default: 1
  106. groups: number of groups to divide input and output channels into,
  107. so as to perform a "grouped convolution". When ``groups`` is not 1,
  108. ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
  109. and the shape of weight should be ``(groups, out_channel // groups,
  110. in_channels // groups, kernel_size)``. Default: 1
  111. bias: whether to add a bias onto the result of convolution. Default: True
  112. conv_mode: Supports `cross_correlation`. Default: `cross_correlation`
  113. compute_mode: When set to "default", no special requirements will be
  114. placed on the precision of intermediate results. When set to "float32",
  115. "float32" would be used for accumulator and intermediate result, but only
  116. effective when input and output are of float16 dtype.
  117. padding_mode: "zeros", "reflect" or "replicate". Default: "zeros".
  118. Refer to :class:`~.module.padding.Pad` for more information.
  119. Note:
  120. * ``weight`` usually has shape ``(out_channels, in_channels, kernel_size)`` ,
  121. if groups is not 1, shape will be ``(groups, out_channels // groups, in_channels // groups, kernel_size)``
  122. * ``bias`` usually has shape ``(1, out_channels, 1)``
  123. Examples:
  124. >>> import numpy as np
  125. >>> m = M.Conv1d(in_channels=3, out_channels=1, kernel_size=3)
  126. >>> inp = mge.tensor(np.arange(0, 24).astype("float32").reshape(2, 3, 4))
  127. >>> oup = m(inp)
  128. >>> oup.numpy().shape
  129. (2, 1, 2)
  130. """
  131. def __init__(
  132. self,
  133. in_channels: int,
  134. out_channels: int,
  135. kernel_size: int,
  136. stride: int = 1,
  137. padding: int = 0,
  138. dilation: int = 1,
  139. groups: int = 1,
  140. bias: bool = True,
  141. conv_mode: str = "cross_correlation",
  142. compute_mode: str = "default",
  143. padding_mode: str = "zeros",
  144. **kwargs
  145. ):
  146. kernel_size = kernel_size
  147. stride = stride
  148. padding = padding
  149. dilation = dilation
  150. self.conv_mode = conv_mode
  151. self.compute_mode = compute_mode
  152. self.padding_mode = padding_mode
  153. super().__init__(
  154. in_channels,
  155. out_channels,
  156. kernel_size,
  157. stride,
  158. padding,
  159. 0,
  160. dilation,
  161. groups,
  162. bias,
  163. **kwargs,
  164. )
  165. def _get_fanin(self):
  166. kh = self.kernel_size
  167. ic = self.in_channels
  168. return kh * ic
  169. def _infer_weight_shape(self):
  170. group = self.groups
  171. ichl = self.in_channels
  172. ochl = self.out_channels
  173. kh = self.kernel_size
  174. if group == 1:
  175. # Assume format is NCH(W=1)
  176. return (ochl, ichl, kh)
  177. assert (
  178. ichl % group == 0 and ochl % group == 0
  179. ), "invalid config: in_channels={} out_channels={} group={}".format(
  180. ichl, ochl, group
  181. )
  182. # Assume format is NCH(W=1)
  183. return (group, ochl // group, ichl // group, kh)
  184. def _infer_bias_shape(self):
  185. # Assume format is NCH(W=1)
  186. return (1, self.out_channels, 1)
  187. def get_pad_witdth(self):
  188. return ((0, 0), (0, 0), (self.padding, self.padding))
  189. def calc_conv(self, inp, weight, bias):
  190. assert self.padding_mode in [
  191. "zeros",
  192. "reflect",
  193. "replicate",
  194. ]
  195. if self.padding_mode != "zeros":
  196. return conv1d(
  197. pad(inp, self.get_pad_witdth(), self.padding_mode),
  198. weight,
  199. bias,
  200. self.stride,
  201. 0,
  202. self.dilation,
  203. self.groups,
  204. self.conv_mode,
  205. self.compute_mode,
  206. )
  207. return conv1d(
  208. inp,
  209. weight,
  210. bias,
  211. self.stride,
  212. self.padding,
  213. self.dilation,
  214. self.groups,
  215. self.conv_mode,
  216. self.compute_mode,
  217. )
  218. def forward(self, inp):
  219. return self.calc_conv(inp, self.weight, self.bias)
  220. class Conv2d(_ConvNd):
  221. r"""Applies a 2D convolution over an input tensor.
  222. For instance, given an input of the size :math:`(N, C_{\text{in}}, H, W)`,
  223. this layer generates an output of the size
  224. :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` through the
  225. process described as below:
  226. .. math::
  227. \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
  228. \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)
  229. where :math:`\star` is the valid 2D cross-correlation operator,
  230. :math:`N` is batch size, :math:`C` denotes number of channels,
  231. :math:`H` is height of input planes in pixels, and :math:`W` is
  232. width in pixels.
  233. In general, output feature maps' shapes can be inferred as follows:
  234. input: :math:`(N, C_{\text{in}}, H_{\text{in}}, W_{\text{in}})`
  235. output: :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` where
  236. .. math::
  237. \text{H}_{out} = \lfloor \frac{\text{H}_{in} + 2 * \text{padding[0]} -
  238. \text{dilation[0]} * (\text{kernel_size[0]} - 1) - 1}{\text{stride[0]}} + 1 \rfloor
  239. .. math::
  240. \text{W}_{out} = \lfloor \frac{\text{W}_{in} + 2 * \text{padding[1]} -
  241. \text{dilation[1]} * (\text{kernel_size[1]} - 1) - 1}{\text{stride[1]}} + 1 \rfloor
  242. When `groups == in_channels` and `out_channels == K * in_channels`,
  243. where K is a positive integer, this operation is also known as depthwise
  244. convolution.
  245. In other words, for an input of size :math:`(N, C_{in}, H_{in}, W_{in})`,
  246. a depthwise convolution with a depthwise multiplier `K`, can be constructed
  247. by arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.
  248. Args:
  249. in_channels: number of input channels.
  250. out_channels: number of output channels.
  251. kernel_size: size of weight on spatial dimensions. If kernel_size is
  252. an :class:`int`, the actual kernel size would be
  253. ``(kernel_size, kernel_size)``.
  254. stride: stride of the 2D convolution operation. Default: 1
  255. padding: size of the paddings added to the input on both sides of its
  256. spatial dimensions. Default: 0
  257. dilation: dilation of the 2D convolution operation. Default: 1
  258. groups: number of groups into which the input and output channels are divided,
  259. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  260. ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
  261. and the shape of weight should be ``(groups, out_channel // groups,
  262. in_channels // groups, height, width)``. Default: 1
  263. bias: whether to add a bias onto the result of convolution. Default: True
  264. conv_mode: Supports `cross_correlation`. Default: `cross_correlation`
  265. compute_mode: When set to "default", no special requirements will be
  266. placed on the precision of intermediate results. When set to "float32",
  267. "float32" would be used for accumulator and intermediate result, but only
  268. effective when input and output are of float16 dtype.
  269. padding_mode: "zeros", "reflect" or "replicate". Default: "zeros".
  270. Refer to :class:`~.module.padding.Pad` for more information.
  271. Note:
  272. * ``weight`` usually has shape ``(out_channels, in_channels, height, width)`` ,
  273. if groups is not 1, shape will be ``(groups, out_channels // groups, in_channels // groups, height, width)``
  274. * ``bias`` usually has shape ``(1, out_channels, *1)``
  275. Examples:
  276. >>> import numpy as np
  277. >>> m = M.Conv2d(in_channels=3, out_channels=1, kernel_size=3)
  278. >>> inp = mge.tensor(np.arange(0, 96).astype("float32").reshape(2, 3, 4, 4))
  279. >>> oup = m(inp)
  280. >>> oup.numpy().shape
  281. (2, 1, 2, 2)
  282. """
  283. def __init__(
  284. self,
  285. in_channels: int,
  286. out_channels: int,
  287. kernel_size: Union[int, Tuple[int, int]],
  288. stride: Union[int, Tuple[int, int]] = 1,
  289. padding: Union[int, Tuple[int, int]] = 0,
  290. dilation: Union[int, Tuple[int, int]] = 1,
  291. groups: int = 1,
  292. bias: bool = True,
  293. conv_mode: str = "cross_correlation",
  294. compute_mode: str = "default",
  295. padding_mode: str = "zeros",
  296. **kwargs
  297. ):
  298. kernel_size = _pair_nonzero(kernel_size)
  299. stride = _pair_nonzero(stride)
  300. padding = _pair(padding)
  301. dilation = _pair_nonzero(dilation)
  302. self.conv_mode = conv_mode
  303. self.compute_mode = compute_mode
  304. self.padding_mode = padding_mode
  305. super().__init__(
  306. in_channels,
  307. out_channels,
  308. kernel_size,
  309. stride,
  310. padding,
  311. 0,
  312. dilation,
  313. groups,
  314. bias,
  315. **kwargs,
  316. )
  317. def _get_fanin(self):
  318. kh, kw = self.kernel_size
  319. ic = self.in_channels
  320. return kh * kw * ic
  321. def _infer_weight_shape(self):
  322. group = self.groups
  323. ichl = self.in_channels
  324. ochl = self.out_channels
  325. kh, kw = self.kernel_size
  326. if group == 1:
  327. # Assume format is NCHW
  328. return (ochl, ichl, kh, kw)
  329. assert (
  330. ichl % group == 0 and ochl % group == 0
  331. ), "invalid config: in_channels={} out_channels={} group={}".format(
  332. ichl, ochl, group
  333. )
  334. # Assume format is NCHW
  335. return (group, ochl // group, ichl // group, kh, kw)
  336. def _infer_bias_shape(self):
  337. # Assume format is NCHW
  338. return (1, self.out_channels, 1, 1)
  339. def get_pad_witdth(self):
  340. return (
  341. (0, 0),
  342. (0, 0),
  343. (self.padding[0], self.padding[0]),
  344. (self.padding[1], self.padding[1]),
  345. )
  346. def calc_conv(self, inp, weight, bias):
  347. assert self.padding_mode in [
  348. "zeros",
  349. "reflect",
  350. "replicate",
  351. ]
  352. if self.padding_mode != "zeros":
  353. return conv2d(
  354. pad(inp, self.get_pad_witdth(), self.padding_mode),
  355. weight,
  356. bias,
  357. self.stride,
  358. 0,
  359. self.dilation,
  360. self.groups,
  361. self.conv_mode,
  362. self.compute_mode,
  363. )
  364. return conv2d(
  365. inp,
  366. weight,
  367. bias,
  368. self.stride,
  369. self.padding,
  370. self.dilation,
  371. self.groups,
  372. self.conv_mode,
  373. self.compute_mode,
  374. )
  375. def forward(self, inp):
  376. return self.calc_conv(inp, self.weight, self.bias)
  377. class Conv3d(_ConvNd):
  378. r"""Applies a 3D convolution over an input tensor.
  379. For instance, given an input of the size :math:`(N, C_{\text{in}}, T, H, W)`,
  380. this layer generates an output of the size
  381. :math:`(N, C_{\text{out}}, T_{\text{out}}, H_{\text{out}}, W_{\text{out}})` through the
  382. process described as below:
  383. .. math::
  384. \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
  385. \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)
  386. where :math:`\star` is the valid 3D cross-correlation operator,
  387. :math:`N` is batch size, :math:`C` denotes number of channels.
  388. When `groups == in_channels` and `out_channels == K * in_channels`,
  389. where K is a positive integer, this operation is also known as depthwise
  390. convolution.
  391. In other words, for an input of size :math:`(N, C_{in}, T_{int}, H_{in}, W_{in})`,
  392. a depthwise convolution with a depthwise multiplier `K`, can be constructed
  393. by arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.
  394. Args:
  395. in_channels: number of input channels.
  396. out_channels: number of output channels.
  397. kernel_size: size of weight on spatial dimensions. If kernel_size is
  398. an :class:`int`, the actual kernel size would be
  399. `(kernel_size, kernel_size, kernel_size)`.
  400. stride: stride of the 3D convolution operation. Default: 1
  401. padding: size of the paddings added to the input on both sides of its
  402. spatial dimensions. Only zero-padding is supported. Default: 0
  403. dilation: dilation of the 3D convolution operation. Default: 1
  404. groups: number of groups into which the input and output channels are divided,
  405. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  406. ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
  407. and the shape of weight should be ``(groups, out_channel // groups,
  408. in_channels // groups, depth, height, width)``. Default: 1
  409. bias: whether to add a bias onto the result of convolution. Default: True
  410. conv_mode: Supports `cross_correlation`. Default: `cross_correlation`
  411. Note:
  412. * ``weight`` usually has shape ``(out_channels, in_channels, depth, height, width)`` ,
  413. if groups is not 1, shape will be ``(groups, out_channels // groups, in_channels // groups, depth, height, width)``
  414. * ``bias`` usually has shape ``(1, out_channels, *1)``
  415. Examples:
  416. >>> import numpy as np
  417. >>> m = M.Conv3d(in_channels=3, out_channels=1, kernel_size=3)
  418. >>> inp = mge.tensor(np.arange(0, 384).astype("float32").reshape(2, 3, 4, 4, 4))
  419. >>> oup = m(inp)
  420. >>> oup.numpy().shape
  421. (2, 1, 2, 2, 2)
  422. """
  423. def __init__(
  424. self,
  425. in_channels: int,
  426. out_channels: int,
  427. kernel_size: Union[int, Tuple[int, int, int]],
  428. stride: Union[int, Tuple[int, int, int]] = 1,
  429. padding: Union[int, Tuple[int, int, int]] = 0,
  430. dilation: Union[int, Tuple[int, int, int]] = 1,
  431. groups: int = 1,
  432. bias: bool = True,
  433. conv_mode: str = "cross_correlation",
  434. ):
  435. kernel_size = _triple_nonzero(kernel_size)
  436. stride = _triple_nonzero(stride)
  437. padding = _triple(padding)
  438. dilation = _triple_nonzero(dilation)
  439. self.conv_mode = conv_mode
  440. super().__init__(
  441. in_channels,
  442. out_channels,
  443. kernel_size,
  444. stride,
  445. padding,
  446. 0,
  447. dilation,
  448. groups,
  449. bias,
  450. )
  451. def _get_fanin(self):
  452. kt, kh, kw = self.kernel_size
  453. ic = self.in_channels
  454. return kt * kh * kw * ic
  455. def _infer_weight_shape(self):
  456. group = self.groups
  457. ichl = self.in_channels
  458. ochl = self.out_channels
  459. kt, kh, kw = self.kernel_size
  460. if group == 1:
  461. # Assume format is NCTHW
  462. return (ochl, ichl, kt, kh, kw)
  463. assert (
  464. ichl % group == 0 and ochl % group == 0
  465. ), "invalid config: in_channels={} out_channels={} group={}".format(
  466. ichl, ochl, group
  467. )
  468. # Assume format is NCTHW
  469. return (group, ochl // group, ichl // group, kt, kh, kw)
  470. def _infer_bias_shape(self):
  471. # Assume format is NCTHW
  472. return (1, self.out_channels, 1, 1, 1)
  473. def calc_conv(self, inp, weight, bias):
  474. return conv3d(
  475. inp,
  476. weight,
  477. bias,
  478. self.stride,
  479. self.padding,
  480. self.dilation,
  481. self.groups,
  482. self.conv_mode,
  483. )
  484. def forward(self, inp):
  485. return self.calc_conv(inp, self.weight, self.bias)
  486. class ConvTranspose2d(_ConvNd):
  487. r"""Applies a 2D transposed convolution over an input tensor.
  488. This module is also known as a deconvolution or a fractionally-strided convolution.
  489. :class:`ConvTranspose2d` can be seen as the gradient of :class:`Conv2d` operation
  490. with respect to its input.
  491. Convolution usually reduces the size of input, while transposed convolution works
  492. the opposite way, transforming a smaller input to a larger output while preserving the
  493. connectivity pattern.
  494. Args:
  495. in_channels: number of input channels.
  496. out_channels: number of output channels.
  497. kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is
  498. an :class:`int`, the actual kernel size would be
  499. ``(kernel_size, kernel_size)``.
  500. stride: stride of the 2D convolution operation. Default: 1
  501. padding: size of the paddings added to the input on both sides of its
  502. spatial dimensions. Only zero-padding is supported. Default: 0
  503. output_padding: size of paddings appended to output. Default: 0
  504. dilation: dilation of the 2D convolution operation. Default: 1
  505. groups: number of groups into which the input and output channels are divided,
  506. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  507. ``in_channels`` and ``out_channels`` must be divisible by groups,
  508. and the shape of weight should be ``(groups, in_channels // groups,
  509. out_channels // groups, height, width)``. Default: 1
  510. bias: wether to add a bias onto the result of convolution. Default: True
  511. conv_mode: Supports `cross_correlation`. Default: `cross_correlation`
  512. compute_mode: When set to "default", no special requirements will be
  513. placed on the precision of intermediate results. When set to "float32",
  514. "float32" would be used for accumulator and intermediate result, but only
  515. effective when input and output are of float16 dtype.
  516. Note:
  517. * ``weight`` usually has shape ``(in_channels, out_channels, height, width)`` ,
  518. if groups is not 1, shape will be ``(groups, in_channels // groups, out_channels // groups, height, width)``
  519. * ``bias`` usually has shape ``(1, out_channels, *1)``
  520. """
  521. output_padding = 0
  522. def __init__(
  523. self,
  524. in_channels: int,
  525. out_channels: int,
  526. kernel_size: Union[int, Tuple[int, int]],
  527. stride: Union[int, Tuple[int, int]] = 1,
  528. padding: Union[int, Tuple[int, int]] = 0,
  529. output_padding: Union[int, Tuple[int, int]] = 0,
  530. dilation: Union[int, Tuple[int, int]] = 1,
  531. groups: int = 1,
  532. bias: bool = True,
  533. conv_mode: str = "cross_correlation",
  534. compute_mode: str = "default",
  535. **kwargs
  536. ):
  537. kernel_size = _pair_nonzero(kernel_size)
  538. stride = _pair_nonzero(stride)
  539. padding = _pair(padding)
  540. output_padding = _pair(output_padding)
  541. dilation = _pair_nonzero(dilation)
  542. self.conv_mode = conv_mode
  543. self.compute_mode = compute_mode
  544. super().__init__(
  545. in_channels,
  546. out_channels,
  547. kernel_size,
  548. stride,
  549. padding,
  550. output_padding,
  551. dilation,
  552. groups,
  553. bias,
  554. **kwargs,
  555. )
  556. def _get_fanin(self):
  557. kh, kw = self.kernel_size
  558. oc = self.out_channels
  559. return kh * kw * oc
  560. def _infer_weight_shape(self):
  561. group = self.groups
  562. ichl = self.in_channels
  563. ochl = self.out_channels
  564. kh, kw = self.kernel_size
  565. if group == 1:
  566. # Assume format is NCHW
  567. return (ichl, ochl, kh, kw)
  568. assert (
  569. ichl % group == 0 and ochl % group == 0
  570. ), "invalid config: in_channels={} out_channels={} group={}".format(
  571. ichl, ochl, group
  572. )
  573. # Assume format is NCHW
  574. return (group, ichl // group, ochl // group, kh, kw)
  575. def _infer_bias_shape(self):
  576. # Assume format is NCHW
  577. return (1, self.out_channels, 1, 1)
  578. def calc_conv_transpose2d(self, inp, weight, bias):
  579. return conv_transpose2d(
  580. inp,
  581. weight,
  582. bias,
  583. self.stride,
  584. self.padding,
  585. self.output_padding,
  586. self.dilation,
  587. self.groups,
  588. self.conv_mode,
  589. self.compute_mode,
  590. )
  591. def forward(self, inp):
  592. return self.calc_conv_transpose2d(inp, self.weight, self.bias)
  593. class LocalConv2d(Conv2d):
  594. r"""Applies a spatial convolution with untied kernels over an groupped channeled input 4D tensor.
  595. It is also known as the locally connected layer.
  596. Args:
  597. in_channels: number of input channels.
  598. out_channels: number of output channels.
  599. input_height: the height of the input images.
  600. input_width: the width of the input images.
  601. kernel_size: size of weight on spatial dimensions. If kernel_size is
  602. an :class:`int`, the actual kernel size would be
  603. ``(kernel_size, kernel_size)``.
  604. stride: stride of the 2D convolution operation. Default: 1
  605. padding: size of the paddings added to the input on both sides of its
  606. spatial dimensions. Only zero-padding is supported. Default: 0
  607. dilation: dilation of the 2D convolution operation. Default: 1
  608. groups: number of groups into which the input and output channels are divided,
  609. so as to perform a "grouped convolution". When ``groups`` is not 1,
  610. ``in_channels`` and ``out_channels`` must be divisible by ``groups``. Default: 1
  611. Note:
  612. * ``weight`` usually has shape ``(out_height, out_width, in_channels, height, width, in_channels)`` ,
  613. if groups is not 1, shape will be ``(groups, out_height, out_width, in_channels // groups, height, width, out_channels // groups)``
  614. * ``bias`` usually has shape ``(1, out_channels, *1)``
  615. """
  616. def __init__(
  617. self,
  618. in_channels: int,
  619. out_channels: int,
  620. input_height: int,
  621. input_width: int,
  622. kernel_size: Union[int, Tuple[int, int]],
  623. stride: Union[int, Tuple[int, int]] = 1,
  624. padding: Union[int, Tuple[int, int]] = 0,
  625. dilation: Union[int, Tuple[int, int]] = 1,
  626. groups: int = 1,
  627. conv_mode: str = "cross_correlation",
  628. **kwargs
  629. ):
  630. self.input_height = input_height
  631. self.input_width = input_width
  632. super().__init__(
  633. in_channels,
  634. out_channels,
  635. kernel_size,
  636. stride,
  637. padding,
  638. dilation,
  639. groups,
  640. bias=False,
  641. **kwargs,
  642. )
  643. def _infer_weight_shape(self):
  644. group = self.groups
  645. out_height = (
  646. self.input_height + self.padding[0] * 2 - self.kernel_size[0]
  647. ) // self.stride[0] + 1
  648. out_width = (
  649. self.input_width + self.padding[1] * 2 - self.kernel_size[1]
  650. ) // self.stride[1] + 1
  651. # Assume format is NCHW
  652. return (
  653. group,
  654. out_height,
  655. out_width,
  656. self.in_channels // group,
  657. self.kernel_size[0],
  658. self.kernel_size[1],
  659. self.out_channels // group,
  660. )
  661. def forward(self, inp):
  662. return local_conv2d(
  663. inp,
  664. self.weight,
  665. None,
  666. self.stride,
  667. self.padding,
  668. self.dilation,
  669. self.conv_mode,
  670. )
  671. class ConvRelu2d(Conv2d):
  672. r"""A fused :class:`~.Module` including :class:`~.module.Conv2d` and :func:`~.relu`.
  673. Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvRelu2d` using :func:`~.quantize.quantize_qat`.
  674. """
  675. def forward(self, inp):
  676. return relu(self.calc_conv(inp, self.weight, self.bias))
  677. class DeformableConv2d(_ConvNd):
  678. r"""Deformable Convolution.
  679. Args:
  680. in_channels: number of input channels.
  681. out_channels: number of output channels.
  682. kernel_size: size of weight on spatial dimensions. If kernel_size is
  683. an :class:`int`, the actual kernel size would be
  684. ``(kernel_size, kernel_size)``.
  685. stride: stride of the 2D convolution operation. Default: 1
  686. padding: size of the paddings added to the input on both sides of its
  687. spatial dimensions. Only zero-padding is supported. Default: 0
  688. dilation: dilation of the 2D convolution operation. Default: 1
  689. groups: number of groups into which the input and output channels are divided,
  690. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  691. ``in_channels`` and ``out_channels`` must be divisible by groups,
  692. and the shape of weight should be ``(groups, out_channel // groups,
  693. in_channels // groups, height, width)``. Default: 1
  694. bias: whether to add a bias onto the result of convolution. Default: True
  695. conv_mode: Supports `cross_correlation`. Default: `cross_correlation`
  696. compute_mode: When set to "default", no special requirements will be
  697. placed on the precision of intermediate results. When set to "float32",
  698. "float32" would be used for accumulator and intermediate result, but only
  699. effective when input and output are of float16 dtype.
  700. Note:
  701. * ``weight`` usually has shape ``(out_channels, in_channels, height, width)`` ,
  702. if groups is not 1, shape will be ``(groups, out_channels // groups, in_channels // groups, height, width)``
  703. * ``bias`` usually has shape ``(1, out_channels, *1)``
  704. """
  705. def __init__(
  706. self,
  707. in_channels: int,
  708. out_channels: int,
  709. kernel_size: Union[int, Tuple[int, int]],
  710. stride: Union[int, Tuple[int, int]] = 1,
  711. padding: Union[int, Tuple[int, int]] = 0,
  712. dilation: Union[int, Tuple[int, int]] = 1,
  713. groups: int = 1,
  714. bias: bool = True,
  715. conv_mode: str = "cross_correlation",
  716. compute_mode: str = "default",
  717. **kwargs
  718. ):
  719. kernel_size = _pair_nonzero(kernel_size)
  720. stride = _pair_nonzero(stride)
  721. padding = _pair(padding)
  722. dilation = _pair_nonzero(dilation)
  723. self.conv_mode = conv_mode
  724. self.compute_mode = compute_mode
  725. super().__init__(
  726. in_channels,
  727. out_channels,
  728. kernel_size,
  729. stride,
  730. padding,
  731. 0,
  732. dilation,
  733. groups,
  734. bias,
  735. **kwargs,
  736. )
  737. def _get_fanin(self):
  738. kh, kw = self.kernel_size
  739. ic = self.in_channels
  740. return kh * kw * ic
  741. def _infer_weight_shape(self):
  742. group = self.groups
  743. ichl = self.in_channels
  744. ochl = self.out_channels
  745. kh, kw = self.kernel_size
  746. if group == 1:
  747. # Assume format is NCHW
  748. return (ochl, ichl, kh, kw)
  749. assert (
  750. ichl % group == 0 and ochl % group == 0
  751. ), "invalid config: in_channels={} out_channels={} group={}".format(
  752. ichl, ochl, group
  753. )
  754. # Assume format is NCHW
  755. return (group, ochl // group, ichl // group, kh, kw)
  756. def _infer_bias_shape(self):
  757. # Assume format is NCHW
  758. return (1, self.out_channels, 1, 1)
  759. def calc_conv(self, inp, weight, offset, mask, bias):
  760. return deformable_conv2d(
  761. inp,
  762. weight,
  763. offset,
  764. mask,
  765. bias,
  766. self.stride,
  767. self.padding,
  768. self.dilation,
  769. self.groups,
  770. self.conv_mode,
  771. self.compute_mode,
  772. )
  773. def forward(self, inp, offset, mask):
  774. return self.calc_conv(inp, self.weight, offset, mask, self.bias)
  775. class ConvTranspose3d(_ConvNd):
  776. r"""Applies a 3D transposed convolution over an input tensor.
  777. Only support the case that groups = 1 and conv_mode = "cross_correlation".
  778. :class:`ConvTranspose3d` can be seen as the gradient of :class:`Conv3d` operation
  779. with respect to its input.
  780. Convolution3D usually reduces the size of input, while transposed convolution3d
  781. works the opposite way, transforming a smaller input to a larger output while
  782. preserving the connectivity pattern.
  783. Args:
  784. in_channels: number of input channels.
  785. out_channels: number of output channels.
  786. kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is
  787. an :class:`int`, the actual kernel size would be
  788. ``(kernel_size, kernel_size, kernel_size)``.
  789. stride: stride of the 3D convolution operation. Default: 1
  790. padding: size of the paddings added to the input on all sides of its
  791. spatial dimensions. Only zero-padding is supported. Default: 0
  792. output_padding: size of paddings appended to output. Default: 0
  793. dilation: dilation of the 3D convolution operation. Default: 1
  794. groups: number of groups into which the input and output channels are divided,
  795. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  796. ``in_channels`` and ``out_channels`` must be divisible by groups,
  797. and the shape of weight should be ``(groups, in_channels // groups,
  798. out_channels // groups, depth, height, width)``. Default: 1
  799. bias: wether to add a bias onto the result of convolution. Default: True
  800. Note:
  801. * ``weight`` usually has shape ``(in_channels, out_channels, depth, height, width)`` .
  802. * ``bias`` usually has shape ``(1, out_channels, *1)``
  803. """
  804. output_padding = 0
  805. def __init__(
  806. self,
  807. in_channels: int,
  808. out_channels: int,
  809. kernel_size: Union[int, Tuple[int, int, int]],
  810. stride: Union[int, Tuple[int, int, int]] = 1,
  811. padding: Union[int, Tuple[int, int, int]] = 0,
  812. output_padding: Union[int, Tuple[int, int, int]] = 0,
  813. dilation: Union[int, Tuple[int, int, int]] = 1,
  814. groups: int = 1,
  815. bias: bool = True,
  816. ):
  817. kernel_size = _triple_nonzero(kernel_size)
  818. stride = _triple_nonzero(stride)
  819. padding = _triple(padding)
  820. dilation = _triple_nonzero(dilation)
  821. super().__init__(
  822. in_channels=in_channels,
  823. out_channels=out_channels,
  824. kernel_size=kernel_size,
  825. stride=stride,
  826. padding=padding,
  827. output_padding=output_padding,
  828. dilation=dilation,
  829. groups=groups,
  830. bias=bias,
  831. )
  832. def _get_fanin(self):
  833. kt, kh, kw = self.kernel_size
  834. ic = self.in_channels
  835. return kt * kh * kw * ic
  836. def _infer_weight_shape(self):
  837. group = self.groups
  838. ichl = self.in_channels
  839. ochl = self.out_channels
  840. kt, kh, kw = self.kernel_size
  841. if group == 1:
  842. # Assume format is NCHW
  843. return (ichl, ochl, kt, kh, kw)
  844. assert (
  845. ichl % group == 0 and ochl % group == 0
  846. ), "invalid config: in_channels={} out_channels={} group={}".format(
  847. ichl, ochl, group
  848. )
  849. # Assume format is NCHW
  850. return (group, ichl // group, ochl // group, kt, kh, kw)
  851. def _infer_bias_shape(self):
  852. # Assume format is NCTHW
  853. return (1, self.out_channels, 1, 1, 1)
  854. def forward(self, inp):
  855. return conv_transpose3d(
  856. inp,
  857. self.weight,
  858. self.bias,
  859. self.stride,
  860. self.padding,
  861. self.output_padding,
  862. self.dilation,
  863. )