You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

conv.py 35 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960
  1. from abc import abstractmethod
  2. from typing import Tuple, Union
  3. import numpy as np
  4. from ..functional import (
  5. conv1d,
  6. conv2d,
  7. conv3d,
  8. conv_transpose2d,
  9. conv_transpose3d,
  10. deformable_conv2d,
  11. local_conv2d,
  12. pad,
  13. relu,
  14. )
  15. from ..tensor import Parameter
  16. from ..utils.tuple_function import _pair, _pair_nonzero, _triple, _triple_nonzero
  17. from . import init
  18. from .module import Module
  19. class _ConvNd(Module):
  20. """base class for convolution modules, including transposed conv"""
  21. def __init__(
  22. self,
  23. in_channels: int,
  24. out_channels: int,
  25. kernel_size: Union[int, Tuple[int, int]],
  26. stride: Union[int, Tuple[int, int]],
  27. padding: Union[int, Tuple[int, int]],
  28. dilation: Union[int, Tuple[int, int]],
  29. groups: int,
  30. bias: bool = True,
  31. **kwargs
  32. ):
  33. super().__init__(**kwargs)
  34. if in_channels % groups != 0:
  35. raise ValueError("in_channels must be divisible by groups")
  36. if out_channels % groups != 0:
  37. raise ValueError("out_channels must be divisible by groups")
  38. self.in_channels = in_channels
  39. self.out_channels = out_channels
  40. self.kernel_size = kernel_size
  41. self.stride = stride
  42. self.padding = padding
  43. self.dilation = dilation
  44. self.groups = groups
  45. self.weight = Parameter(np.zeros(self._infer_weight_shape(), dtype=np.float32))
  46. self.bias = None
  47. if bias:
  48. self.bias = Parameter(np.zeros(self._infer_bias_shape(), dtype=np.float32))
  49. self.reset_parameters()
  50. @abstractmethod
  51. def _get_fanin(self):
  52. pass
  53. def reset_parameters(self) -> None:
  54. fanin = self._get_fanin()
  55. std = np.sqrt(1 / fanin)
  56. init.normal_(self.weight, 0.0, std)
  57. if self.bias is not None:
  58. init.zeros_(self.bias)
  59. @abstractmethod
  60. def _infer_weight_shape(self):
  61. pass
  62. @abstractmethod
  63. def _infer_bias_shape(self):
  64. pass
  65. def _module_info_string(self):
  66. s = "{in_channels}, {out_channels}, kernel_size={kernel_size}"
  67. if self.stride != (1,) * len(self.stride):
  68. s += ", stride={stride}"
  69. if self.padding != (0,) * len(self.padding):
  70. s += ", padding={padding}"
  71. if self.dilation != (1,) * len(self.dilation):
  72. s += ", dilation={dilation}"
  73. if self.groups != 1:
  74. s += ", groups={groups}"
  75. if self.bias is None:
  76. s += ", bias=False"
  77. return s.format(**self.__dict__)
  78. class Conv1d(_ConvNd):
  79. r"""Applies a 1D convolution over an input tensor.
  80. For instance, given an input of the size :math:`(N, C_{\text{in}}, H)`,
  81. this layer generates an output of the size
  82. :math:`(N, C_{\text{out}}, H_{\text{out}})` through the
  83. process described as below:
  84. .. math::
  85. \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
  86. \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)
  87. where :math:`\star` is the valid 1D cross-correlation operator,
  88. :math:`N` is batch size, :math:`C` denotes number of channels, and
  89. :math:`H` is length of 1D data element.
  90. When `groups == in_channels` and `out_channels == K * in_channels`,
  91. where K is a positive integer, this operation is also known as depthwise
  92. convolution.
  93. In other words, for an input of size :math:`(N, C_{in}, H_{in})`,
  94. a depthwise convolution with a depthwise multiplier `K`, can be constructed
  95. by arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.
  96. Args:
  97. in_channels: number of input channels.
  98. out_channels: number of output channels.
  99. kernel_size: size of weight on spatial dimensions.
  100. stride: stride of the 1D convolution operation.
  101. padding: size of the paddings added to the input on both sides of its
  102. spatial dimensions. Default: 0
  103. dilation: dilation of the 1D convolution operation. Default: 1
  104. groups: number of groups to divide input and output channels into,
  105. so as to perform a "grouped convolution". When ``groups`` is not 1,
  106. ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
  107. and the shape of weight should be ``(groups, out_channel // groups,
  108. in_channels // groups, kernel_size)``. Default: 1
  109. bias: whether to add a bias onto the result of convolution. Default: True
  110. conv_mode: Supports `cross_correlation`. Default: `cross_correlation`
  111. compute_mode: When set to "default", no special requirements will be
  112. placed on the precision of intermediate results. When set to "float32",
  113. "float32" would be used for accumulator and intermediate result, but only
  114. effective when input and output are of float16 dtype.
  115. padding_mode: "zeros", "reflect" or "replicate". Default: "zeros".
  116. Refer to :class:`~.module.padding.Pad` for more information.
  117. Note:
  118. * ``weight`` usually has shape ``(out_channels, in_channels, kernel_size)`` ,
  119. if groups is not 1, shape will be ``(groups, out_channels // groups, in_channels // groups, kernel_size)``
  120. * ``bias`` usually has shape ``(1, out_channels, 1)``
  121. Examples:
  122. >>> import numpy as np
  123. >>> m = M.Conv1d(in_channels=3, out_channels=1, kernel_size=3)
  124. >>> inp = mge.tensor(np.arange(0, 24).astype("float32").reshape(2, 3, 4))
  125. >>> oup = m(inp)
  126. >>> oup.numpy().shape
  127. (2, 1, 2)
  128. """
  129. def __init__(
  130. self,
  131. in_channels: int,
  132. out_channels: int,
  133. kernel_size: int,
  134. stride: int = 1,
  135. padding: int = 0,
  136. dilation: int = 1,
  137. groups: int = 1,
  138. bias: bool = True,
  139. conv_mode: str = "cross_correlation",
  140. compute_mode: str = "default",
  141. padding_mode: str = "zeros",
  142. **kwargs
  143. ):
  144. kernel_size = kernel_size
  145. stride = stride
  146. padding = padding
  147. dilation = dilation
  148. self.conv_mode = conv_mode
  149. self.compute_mode = compute_mode
  150. self.padding_mode = padding_mode
  151. super().__init__(
  152. in_channels,
  153. out_channels,
  154. kernel_size,
  155. stride,
  156. padding,
  157. dilation,
  158. groups,
  159. bias,
  160. **kwargs,
  161. )
  162. def _get_fanin(self):
  163. kh = self.kernel_size
  164. ic = self.in_channels
  165. return kh * ic
  166. def _infer_weight_shape(self):
  167. group = self.groups
  168. ichl = self.in_channels
  169. ochl = self.out_channels
  170. kh = self.kernel_size
  171. if group == 1:
  172. # Assume format is NCH(W=1)
  173. return (ochl, ichl, kh)
  174. assert (
  175. ichl % group == 0 and ochl % group == 0
  176. ), "invalid config: in_channels={} out_channels={} group={}".format(
  177. ichl, ochl, group
  178. )
  179. # Assume format is NCH(W=1)
  180. return (group, ochl // group, ichl // group, kh)
  181. def _infer_bias_shape(self):
  182. # Assume format is NCH(W=1)
  183. return (1, self.out_channels, 1)
  184. def get_pad_witdth(self):
  185. return ((0, 0), (0, 0), (self.padding, self.padding))
  186. def calc_conv(self, inp, weight, bias):
  187. assert self.padding_mode in [
  188. "zeros",
  189. "reflect",
  190. "replicate",
  191. ]
  192. if self.padding_mode != "zeros":
  193. return conv1d(
  194. pad(inp, self.get_pad_witdth(), self.padding_mode),
  195. weight,
  196. bias,
  197. self.stride,
  198. 0,
  199. self.dilation,
  200. self.groups,
  201. self.conv_mode,
  202. self.compute_mode,
  203. )
  204. return conv1d(
  205. inp,
  206. weight,
  207. bias,
  208. self.stride,
  209. self.padding,
  210. self.dilation,
  211. self.groups,
  212. self.conv_mode,
  213. self.compute_mode,
  214. )
  215. def forward(self, inp):
  216. return self.calc_conv(inp, self.weight, self.bias)
  217. class Conv2d(_ConvNd):
  218. r"""Applies a 2D convolution over an input tensor.
  219. For instance, given an input of the size :math:`(N, C_{\text{in}}, H, W)`,
  220. this layer generates an output of the size
  221. :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` through the
  222. process described as below:
  223. .. math::
  224. \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
  225. \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)
  226. where :math:`\star` is the valid 2D cross-correlation operator,
  227. :math:`N` is batch size, :math:`C` denotes number of channels,
  228. :math:`H` is height of input planes in pixels, and :math:`W` is
  229. width in pixels.
  230. In general, output feature maps' shapes can be inferred as follows:
  231. input: :math:`(N, C_{\text{in}}, H_{\text{in}}, W_{\text{in}})`
  232. output: :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` where
  233. .. math::
  234. \text{H}_{out} = \lfloor \frac{\text{H}_{in} + 2 * \text{padding[0]} -
  235. \text{dilation[0]} * (\text{kernel_size[0]} - 1) - 1}{\text{stride[0]}} + 1 \rfloor
  236. .. math::
  237. \text{W}_{out} = \lfloor \frac{\text{W}_{in} + 2 * \text{padding[1]} -
  238. \text{dilation[1]} * (\text{kernel_size[1]} - 1) - 1}{\text{stride[1]}} + 1 \rfloor
  239. When `groups == in_channels` and `out_channels == K * in_channels`,
  240. where K is a positive integer, this operation is also known as depthwise
  241. convolution.
  242. In other words, for an input of size :math:`(N, C_{in}, H_{in}, W_{in})`,
  243. a depthwise convolution with a depthwise multiplier `K`, can be constructed
  244. by arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.
  245. Args:
  246. in_channels: number of input channels.
  247. out_channels: number of output channels.
  248. kernel_size: size of weight on spatial dimensions. If kernel_size is
  249. an :class:`int`, the actual kernel size would be
  250. ``(kernel_size, kernel_size)``.
  251. stride: stride of the 2D convolution operation. Default: 1
  252. padding: size of the paddings added to the input on both sides of its
  253. spatial dimensions. Default: 0
  254. dilation: dilation of the 2D convolution operation. Default: 1
  255. groups: number of groups into which the input and output channels are divided,
  256. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  257. ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
  258. and the shape of weight should be ``(groups, out_channel // groups,
  259. in_channels // groups, height, width)``. Default: 1
  260. bias: whether to add a bias onto the result of convolution. Default: True
  261. conv_mode: Supports `cross_correlation`. Default: `cross_correlation`
  262. compute_mode: When set to "default", no special requirements will be
  263. placed on the precision of intermediate results. When set to "float32",
  264. "float32" would be used for accumulator and intermediate result, but only
  265. effective when input and output are of float16 dtype.
  266. padding_mode: "zeros", "reflect" or "replicate". Default: "zeros".
  267. Refer to :class:`~.module.padding.Pad` for more information.
  268. Note:
  269. * ``weight`` usually has shape ``(out_channels, in_channels, height, width)`` ,
  270. if groups is not 1, shape will be ``(groups, out_channels // groups, in_channels // groups, height, width)``
  271. * ``bias`` usually has shape ``(1, out_channels, *1)``
  272. Examples:
  273. >>> import numpy as np
  274. >>> m = M.Conv2d(in_channels=3, out_channels=1, kernel_size=3)
  275. >>> inp = mge.tensor(np.arange(0, 96).astype("float32").reshape(2, 3, 4, 4))
  276. >>> oup = m(inp)
  277. >>> oup.numpy().shape
  278. (2, 1, 2, 2)
  279. """
  280. def __init__(
  281. self,
  282. in_channels: int,
  283. out_channels: int,
  284. kernel_size: Union[int, Tuple[int, int]],
  285. stride: Union[int, Tuple[int, int]] = 1,
  286. padding: Union[int, Tuple[int, int]] = 0,
  287. dilation: Union[int, Tuple[int, int]] = 1,
  288. groups: int = 1,
  289. bias: bool = True,
  290. conv_mode: str = "cross_correlation",
  291. compute_mode: str = "default",
  292. padding_mode: str = "zeros",
  293. **kwargs
  294. ):
  295. kernel_size = _pair_nonzero(kernel_size)
  296. stride = _pair_nonzero(stride)
  297. padding = _pair(padding)
  298. dilation = _pair_nonzero(dilation)
  299. self.conv_mode = conv_mode
  300. self.compute_mode = compute_mode
  301. self.padding_mode = padding_mode
  302. super().__init__(
  303. in_channels,
  304. out_channels,
  305. kernel_size,
  306. stride,
  307. padding,
  308. dilation,
  309. groups,
  310. bias,
  311. **kwargs,
  312. )
  313. def _get_fanin(self):
  314. kh, kw = self.kernel_size
  315. ic = self.in_channels
  316. return kh * kw * ic
  317. def _infer_weight_shape(self):
  318. group = self.groups
  319. ichl = self.in_channels
  320. ochl = self.out_channels
  321. kh, kw = self.kernel_size
  322. if group == 1:
  323. # Assume format is NCHW
  324. return (ochl, ichl, kh, kw)
  325. assert (
  326. ichl % group == 0 and ochl % group == 0
  327. ), "invalid config: in_channels={} out_channels={} group={}".format(
  328. ichl, ochl, group
  329. )
  330. # Assume format is NCHW
  331. return (group, ochl // group, ichl // group, kh, kw)
  332. def _infer_bias_shape(self):
  333. # Assume format is NCHW
  334. return (1, self.out_channels, 1, 1)
  335. def get_pad_witdth(self):
  336. return (
  337. (0, 0),
  338. (0, 0),
  339. (self.padding[0], self.padding[0]),
  340. (self.padding[1], self.padding[1]),
  341. )
  342. def calc_conv(self, inp, weight, bias):
  343. assert self.padding_mode in [
  344. "zeros",
  345. "reflect",
  346. "replicate",
  347. ]
  348. if self.padding_mode != "zeros":
  349. return conv2d(
  350. pad(inp, self.get_pad_witdth(), self.padding_mode),
  351. weight,
  352. bias,
  353. self.stride,
  354. 0,
  355. self.dilation,
  356. self.groups,
  357. self.conv_mode,
  358. self.compute_mode,
  359. )
  360. return conv2d(
  361. inp,
  362. weight,
  363. bias,
  364. self.stride,
  365. self.padding,
  366. self.dilation,
  367. self.groups,
  368. self.conv_mode,
  369. self.compute_mode,
  370. )
  371. def forward(self, inp):
  372. return self.calc_conv(inp, self.weight, self.bias)
  373. class Conv3d(_ConvNd):
  374. r"""Applies a 3D convolution over an input tensor.
  375. For instance, given an input of the size :math:`(N, C_{\text{in}}, T, H, W)`,
  376. this layer generates an output of the size
  377. :math:`(N, C_{\text{out}}, T_{\text{out}}, H_{\text{out}}, W_{\text{out}})` through the
  378. process described as below:
  379. .. math::
  380. \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
  381. \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)
  382. where :math:`\star` is the valid 3D cross-correlation operator,
  383. :math:`N` is batch size, :math:`C` denotes number of channels.
  384. When `groups == in_channels` and `out_channels == K * in_channels`,
  385. where K is a positive integer, this operation is also known as depthwise
  386. convolution.
  387. In other words, for an input of size :math:`(N, C_{in}, T_{int}, H_{in}, W_{in})`,
  388. a depthwise convolution with a depthwise multiplier `K`, can be constructed
  389. by arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.
  390. Args:
  391. in_channels: number of input channels.
  392. out_channels: number of output channels.
  393. kernel_size: size of weight on spatial dimensions. If kernel_size is
  394. an :class:`int`, the actual kernel size would be
  395. `(kernel_size, kernel_size, kernel_size)`.
  396. stride: stride of the 3D convolution operation. Default: 1
  397. padding: size of the paddings added to the input on both sides of its
  398. spatial dimensions. Only zero-padding is supported. Default: 0
  399. dilation: dilation of the 3D convolution operation. Default: 1
  400. groups: number of groups into which the input and output channels are divided,
  401. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  402. ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
  403. and the shape of weight should be ``(groups, out_channel // groups,
  404. in_channels // groups, depth, height, width)``. Default: 1
  405. bias: whether to add a bias onto the result of convolution. Default: True
  406. conv_mode: Supports `cross_correlation`. Default: `cross_correlation`
  407. Note:
  408. * ``weight`` usually has shape ``(out_channels, in_channels, depth, height, width)`` ,
  409. if groups is not 1, shape will be ``(groups, out_channels // groups, in_channels // groups, depth, height, width)``
  410. * ``bias`` usually has shape ``(1, out_channels, *1)``
  411. Examples:
  412. >>> import numpy as np
  413. >>> m = M.Conv3d(in_channels=3, out_channels=1, kernel_size=3)
  414. >>> inp = mge.tensor(np.arange(0, 384).astype("float32").reshape(2, 3, 4, 4, 4))
  415. >>> oup = m(inp)
  416. >>> oup.numpy().shape
  417. (2, 1, 2, 2, 2)
  418. """
  419. def __init__(
  420. self,
  421. in_channels: int,
  422. out_channels: int,
  423. kernel_size: Union[int, Tuple[int, int, int]],
  424. stride: Union[int, Tuple[int, int, int]] = 1,
  425. padding: Union[int, Tuple[int, int, int]] = 0,
  426. dilation: Union[int, Tuple[int, int, int]] = 1,
  427. groups: int = 1,
  428. bias: bool = True,
  429. conv_mode: str = "cross_correlation",
  430. ):
  431. kernel_size = _triple_nonzero(kernel_size)
  432. stride = _triple_nonzero(stride)
  433. padding = _triple(padding)
  434. dilation = _triple_nonzero(dilation)
  435. self.conv_mode = conv_mode
  436. super().__init__(
  437. in_channels,
  438. out_channels,
  439. kernel_size,
  440. stride,
  441. padding,
  442. dilation,
  443. groups,
  444. bias,
  445. )
  446. def _get_fanin(self):
  447. kt, kh, kw = self.kernel_size
  448. ic = self.in_channels
  449. return kt * kh * kw * ic
  450. def _infer_weight_shape(self):
  451. group = self.groups
  452. ichl = self.in_channels
  453. ochl = self.out_channels
  454. kt, kh, kw = self.kernel_size
  455. if group == 1:
  456. # Assume format is NCTHW
  457. return (ochl, ichl, kt, kh, kw)
  458. assert (
  459. ichl % group == 0 and ochl % group == 0
  460. ), "invalid config: in_channels={} out_channels={} group={}".format(
  461. ichl, ochl, group
  462. )
  463. # Assume format is NCTHW
  464. return (group, ochl // group, ichl // group, kt, kh, kw)
  465. def _infer_bias_shape(self):
  466. # Assume format is NCTHW
  467. return (1, self.out_channels, 1, 1, 1)
  468. def calc_conv(self, inp, weight, bias):
  469. return conv3d(
  470. inp,
  471. weight,
  472. bias,
  473. self.stride,
  474. self.padding,
  475. self.dilation,
  476. self.groups,
  477. self.conv_mode,
  478. )
  479. def forward(self, inp):
  480. return self.calc_conv(inp, self.weight, self.bias)
  481. class ConvTranspose2d(_ConvNd):
  482. r"""Applies a 2D transposed convolution over an input tensor.
  483. This module is also known as a deconvolution or a fractionally-strided convolution.
  484. :class:`ConvTranspose2d` can be seen as the gradient of :class:`Conv2d` operation
  485. with respect to its input.
  486. Convolution usually reduces the size of input, while transposed convolution works
  487. the opposite way, transforming a smaller input to a larger output while preserving the
  488. connectivity pattern.
  489. Args:
  490. in_channels: number of input channels.
  491. out_channels: number of output channels.
  492. kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is
  493. an :class:`int`, the actual kernel size would be
  494. ``(kernel_size, kernel_size)``.
  495. stride: stride of the 2D convolution operation. Default: 1
  496. padding: size of the paddings added to the input on both sides of its
  497. spatial dimensions. Only zero-padding is supported. Default: 0
  498. dilation: dilation of the 2D convolution operation. Default: 1
  499. groups: number of groups into which the input and output channels are divided,
  500. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  501. ``in_channels`` and ``out_channels`` must be divisible by groups,
  502. and the shape of weight should be ``(groups, in_channels // groups,
  503. out_channels // groups, height, width)``. Default: 1
  504. bias: wether to add a bias onto the result of convolution. Default: True
  505. conv_mode: Supports `cross_correlation`. Default: `cross_correlation`
  506. compute_mode: When set to "default", no special requirements will be
  507. placed on the precision of intermediate results. When set to "float32",
  508. "float32" would be used for accumulator and intermediate result, but only
  509. effective when input and output are of float16 dtype.
  510. Note:
  511. * ``weight`` usually has shape ``(in_channels, out_channels, height, width)`` ,
  512. if groups is not 1, shape will be ``(groups, in_channels // groups, out_channels // groups, height, width)``
  513. * ``bias`` usually has shape ``(1, out_channels, *1)``
  514. """
  515. def __init__(
  516. self,
  517. in_channels: int,
  518. out_channels: int,
  519. kernel_size: Union[int, Tuple[int, int]],
  520. stride: Union[int, Tuple[int, int]] = 1,
  521. padding: Union[int, Tuple[int, int]] = 0,
  522. dilation: Union[int, Tuple[int, int]] = 1,
  523. groups: int = 1,
  524. bias: bool = True,
  525. conv_mode: str = "cross_correlation",
  526. compute_mode: str = "default",
  527. **kwargs
  528. ):
  529. kernel_size = _pair_nonzero(kernel_size)
  530. stride = _pair_nonzero(stride)
  531. padding = _pair(padding)
  532. dilation = _pair_nonzero(dilation)
  533. self.conv_mode = conv_mode
  534. self.compute_mode = compute_mode
  535. super().__init__(
  536. in_channels,
  537. out_channels,
  538. kernel_size,
  539. stride,
  540. padding,
  541. dilation,
  542. groups,
  543. bias,
  544. **kwargs,
  545. )
  546. def _get_fanin(self):
  547. kh, kw = self.kernel_size
  548. oc = self.out_channels
  549. return kh * kw * oc
  550. def _infer_weight_shape(self):
  551. group = self.groups
  552. ichl = self.in_channels
  553. ochl = self.out_channels
  554. kh, kw = self.kernel_size
  555. if group == 1:
  556. # Assume format is NCHW
  557. return (ichl, ochl, kh, kw)
  558. assert (
  559. ichl % group == 0 and ochl % group == 0
  560. ), "invalid config: in_channels={} out_channels={} group={}".format(
  561. ichl, ochl, group
  562. )
  563. # Assume format is NCHW
  564. return (group, ichl // group, ochl // group, kh, kw)
  565. def _infer_bias_shape(self):
  566. # Assume format is NCHW
  567. return (1, self.out_channels, 1, 1)
  568. def calc_conv_transpose2d(self, inp, weight, bias):
  569. return conv_transpose2d(
  570. inp,
  571. weight,
  572. bias,
  573. self.stride,
  574. self.padding,
  575. self.dilation,
  576. self.groups,
  577. self.conv_mode,
  578. self.compute_mode,
  579. )
  580. def forward(self, inp):
  581. return self.calc_conv_transpose2d(inp, self.weight, self.bias)
  582. class LocalConv2d(Conv2d):
  583. r"""Applies a spatial convolution with untied kernels over an groupped channeled input 4D tensor.
  584. It is also known as the locally connected layer.
  585. Args:
  586. in_channels: number of input channels.
  587. out_channels: number of output channels.
  588. input_height: the height of the input images.
  589. input_width: the width of the input images.
  590. kernel_size: size of weight on spatial dimensions. If kernel_size is
  591. an :class:`int`, the actual kernel size would be
  592. ``(kernel_size, kernel_size)``.
  593. stride: stride of the 2D convolution operation. Default: 1
  594. padding: size of the paddings added to the input on both sides of its
  595. spatial dimensions. Only zero-padding is supported. Default: 0
  596. dilation: dilation of the 2D convolution operation. Default: 1
  597. groups: number of groups into which the input and output channels are divided,
  598. so as to perform a "grouped convolution". When ``groups`` is not 1,
  599. ``in_channels`` and ``out_channels`` must be divisible by ``groups``. Default: 1
  600. Note:
  601. * ``weight`` usually has shape ``(out_height, out_width, in_channels, height, width, in_channels)`` ,
  602. if groups is not 1, shape will be ``(groups, out_height, out_width, in_channels // groups, height, width, out_channels // groups)``
  603. * ``bias`` usually has shape ``(1, out_channels, *1)``
  604. """
  605. def __init__(
  606. self,
  607. in_channels: int,
  608. out_channels: int,
  609. input_height: int,
  610. input_width: int,
  611. kernel_size: Union[int, Tuple[int, int]],
  612. stride: Union[int, Tuple[int, int]] = 1,
  613. padding: Union[int, Tuple[int, int]] = 0,
  614. dilation: Union[int, Tuple[int, int]] = 1,
  615. groups: int = 1,
  616. conv_mode: str = "cross_correlation",
  617. **kwargs
  618. ):
  619. self.input_height = input_height
  620. self.input_width = input_width
  621. super().__init__(
  622. in_channels,
  623. out_channels,
  624. kernel_size,
  625. stride,
  626. padding,
  627. dilation,
  628. groups,
  629. bias=False,
  630. **kwargs,
  631. )
  632. def _infer_weight_shape(self):
  633. group = self.groups
  634. out_height = (
  635. self.input_height + self.padding[0] * 2 - self.kernel_size[0]
  636. ) // self.stride[0] + 1
  637. out_width = (
  638. self.input_width + self.padding[1] * 2 - self.kernel_size[1]
  639. ) // self.stride[1] + 1
  640. # Assume format is NCHW
  641. return (
  642. group,
  643. out_height,
  644. out_width,
  645. self.in_channels // group,
  646. self.kernel_size[0],
  647. self.kernel_size[1],
  648. self.out_channels // group,
  649. )
  650. def forward(self, inp):
  651. return local_conv2d(
  652. inp,
  653. self.weight,
  654. None,
  655. self.stride,
  656. self.padding,
  657. self.dilation,
  658. self.conv_mode,
  659. )
  660. class ConvRelu2d(Conv2d):
  661. r"""A fused :class:`~.Module` including :class:`~.module.Conv2d` and :func:`~.relu`.
  662. Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvRelu2d` using :func:`~.quantize.quantize_qat`.
  663. """
  664. def forward(self, inp):
  665. return relu(self.calc_conv(inp, self.weight, self.bias))
  666. class DeformableConv2d(_ConvNd):
  667. r"""Deformable Convolution.
  668. Args:
  669. in_channels: number of input channels.
  670. out_channels: number of output channels.
  671. kernel_size: size of weight on spatial dimensions. If kernel_size is
  672. an :class:`int`, the actual kernel size would be
  673. ``(kernel_size, kernel_size)``.
  674. stride: stride of the 2D convolution operation. Default: 1
  675. padding: size of the paddings added to the input on both sides of its
  676. spatial dimensions. Only zero-padding is supported. Default: 0
  677. dilation: dilation of the 2D convolution operation. Default: 1
  678. groups: number of groups into which the input and output channels are divided,
  679. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  680. ``in_channels`` and ``out_channels`` must be divisible by groups,
  681. and the shape of weight should be ``(groups, out_channel // groups,
  682. in_channels // groups, height, width)``. Default: 1
  683. bias: whether to add a bias onto the result of convolution. Default: True
  684. conv_mode: Supports `cross_correlation`. Default: `cross_correlation`
  685. compute_mode: When set to "default", no special requirements will be
  686. placed on the precision of intermediate results. When set to "float32",
  687. "float32" would be used for accumulator and intermediate result, but only
  688. effective when input and output are of float16 dtype.
  689. Note:
  690. * ``weight`` usually has shape ``(out_channels, in_channels, height, width)`` ,
  691. if groups is not 1, shape will be ``(groups, out_channels // groups, in_channels // groups, height, width)``
  692. * ``bias`` usually has shape ``(1, out_channels, *1)``
  693. """
  694. def __init__(
  695. self,
  696. in_channels: int,
  697. out_channels: int,
  698. kernel_size: Union[int, Tuple[int, int]],
  699. stride: Union[int, Tuple[int, int]] = 1,
  700. padding: Union[int, Tuple[int, int]] = 0,
  701. dilation: Union[int, Tuple[int, int]] = 1,
  702. groups: int = 1,
  703. bias: bool = True,
  704. conv_mode: str = "cross_correlation",
  705. compute_mode: str = "default",
  706. **kwargs
  707. ):
  708. kernel_size = _pair_nonzero(kernel_size)
  709. stride = _pair_nonzero(stride)
  710. padding = _pair(padding)
  711. dilation = _pair_nonzero(dilation)
  712. self.conv_mode = conv_mode
  713. self.compute_mode = compute_mode
  714. super().__init__(
  715. in_channels,
  716. out_channels,
  717. kernel_size,
  718. stride,
  719. padding,
  720. dilation,
  721. groups,
  722. bias,
  723. **kwargs,
  724. )
  725. def _get_fanin(self):
  726. kh, kw = self.kernel_size
  727. ic = self.in_channels
  728. return kh * kw * ic
  729. def _infer_weight_shape(self):
  730. group = self.groups
  731. ichl = self.in_channels
  732. ochl = self.out_channels
  733. kh, kw = self.kernel_size
  734. if group == 1:
  735. # Assume format is NCHW
  736. return (ochl, ichl, kh, kw)
  737. assert (
  738. ichl % group == 0 and ochl % group == 0
  739. ), "invalid config: in_channels={} out_channels={} group={}".format(
  740. ichl, ochl, group
  741. )
  742. # Assume format is NCHW
  743. return (group, ochl // group, ichl // group, kh, kw)
  744. def _infer_bias_shape(self):
  745. # Assume format is NCHW
  746. return (1, self.out_channels, 1, 1)
  747. def calc_conv(self, inp, weight, offset, mask, bias):
  748. return deformable_conv2d(
  749. inp,
  750. weight,
  751. offset,
  752. mask,
  753. bias,
  754. self.stride,
  755. self.padding,
  756. self.dilation,
  757. self.groups,
  758. self.conv_mode,
  759. self.compute_mode,
  760. )
  761. def forward(self, inp, offset, mask):
  762. return self.calc_conv(inp, self.weight, offset, mask, self.bias)
  763. class ConvTranspose3d(_ConvNd):
  764. r"""Applies a 3D transposed convolution over an input tensor.
  765. Only support the case that groups = 1 and conv_mode = "cross_correlation".
  766. :class:`ConvTranspose3d` can be seen as the gradient of :class:`Conv3d` operation
  767. with respect to its input.
  768. Convolution3D usually reduces the size of input, while transposed convolution3d
  769. works the opposite way, transforming a smaller input to a larger output while
  770. preserving the connectivity pattern.
  771. Args:
  772. in_channels: number of input channels.
  773. out_channels: number of output channels.
  774. kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is
  775. an :class:`int`, the actual kernel size would be
  776. ``(kernel_size, kernel_size, kernel_size)``.
  777. stride: stride of the 3D convolution operation. Default: 1
  778. padding: size of the paddings added to the input on all sides of its
  779. spatial dimensions. Only zero-padding is supported. Default: 0
  780. dilation: dilation of the 3D convolution operation. Default: 1
  781. groups: number of groups into which the input and output channels are divided,
  782. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  783. ``in_channels`` and ``out_channels`` must be divisible by groups,
  784. and the shape of weight should be ``(groups, in_channels // groups,
  785. out_channels // groups, depth, height, width)``. Default: 1
  786. bias: wether to add a bias onto the result of convolution. Default: True
  787. Note:
  788. * ``weight`` usually has shape ``(in_channels, out_channels, depth, height, width)`` .
  789. * ``bias`` usually has shape ``(1, out_channels, *1)``
  790. """
  791. def __init__(
  792. self,
  793. in_channels: int,
  794. out_channels: int,
  795. kernel_size: Union[int, Tuple[int, int, int]],
  796. stride: Union[int, Tuple[int, int, int]] = 1,
  797. padding: Union[int, Tuple[int, int, int]] = 0,
  798. dilation: Union[int, Tuple[int, int, int]] = 1,
  799. groups: int = 1,
  800. bias: bool = True,
  801. ):
  802. kernel_size = _triple_nonzero(kernel_size)
  803. stride = _triple_nonzero(stride)
  804. padding = _triple(padding)
  805. dilation = _triple_nonzero(dilation)
  806. super().__init__(
  807. in_channels=in_channels,
  808. out_channels=out_channels,
  809. kernel_size=kernel_size,
  810. stride=stride,
  811. padding=padding,
  812. dilation=dilation,
  813. groups=groups,
  814. bias=bias,
  815. )
  816. def _get_fanin(self):
  817. kt, kh, kw = self.kernel_size
  818. ic = self.in_channels
  819. return kt * kh * kw * ic
  820. def _infer_weight_shape(self):
  821. group = self.groups
  822. ichl = self.in_channels
  823. ochl = self.out_channels
  824. kt, kh, kw = self.kernel_size
  825. if group == 1:
  826. # Assume format is NCHW
  827. return (ichl, ochl, kt, kh, kw)
  828. assert (
  829. ichl % group == 0 and ochl % group == 0
  830. ), "invalid config: in_channels={} out_channels={} group={}".format(
  831. ichl, ochl, group
  832. )
  833. # Assume format is NCHW
  834. return (group, ichl // group, ochl // group, kt, kh, kw)
  835. def _infer_bias_shape(self):
  836. # Assume format is NCTHW
  837. return (1, self.out_channels, 1, 1, 1)
  838. def forward(self, inp):
  839. return conv_transpose3d(
  840. inp, self.weight, self.bias, self.stride, self.padding, self.dilation,
  841. )