You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

conv.py 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. from abc import abstractmethod
  9. from typing import Tuple, Union
  10. import numpy as np
  11. from ..core.ops._internal import param_defs as P
  12. from ..functional import conv2d, conv_transpose2d, local_conv2d, relu
  13. from ..functional.types import _pair, _pair_nonzero
  14. from ..tensor import Parameter
  15. from . import init
  16. from .module import Module
  17. class _ConvNd(Module):
  18. """base class for convolution modules, including transposed conv"""
  19. def __init__(
  20. self,
  21. in_channels: int,
  22. out_channels: int,
  23. kernel_size: Union[int, Tuple[int, int]],
  24. stride: Union[int, Tuple[int, int]],
  25. padding: Union[int, Tuple[int, int]],
  26. dilation: Union[int, Tuple[int, int]],
  27. groups: int,
  28. bias: bool = True,
  29. ):
  30. super().__init__()
  31. if in_channels % groups != 0:
  32. raise ValueError("in_channels must be divisible by groups")
  33. if out_channels % groups != 0:
  34. raise ValueError("out_channels must be divisible by groups")
  35. self.in_channels = in_channels
  36. self.out_channels = out_channels
  37. self.kernel_size = kernel_size
  38. self.stride = stride
  39. self.padding = padding
  40. self.dilation = dilation
  41. self.groups = groups
  42. self.weight = Parameter(np.zeros(self._infer_weight_shape(), dtype=np.float32))
  43. self.bias = None
  44. if bias:
  45. self.bias = Parameter(np.zeros(self._infer_bias_shape(), dtype=np.float32))
  46. self.reset_parameters()
  47. @abstractmethod
  48. def _get_fanin(self):
  49. pass
  50. def reset_parameters(self) -> None:
  51. fanin = self._get_fanin()
  52. std = np.sqrt(1 / fanin)
  53. init.normal_(self.weight, 0.0, std)
  54. if self.bias is not None:
  55. init.zeros_(self.bias)
  56. @abstractmethod
  57. def _infer_weight_shape(self):
  58. pass
  59. @abstractmethod
  60. def _infer_bias_shape(self):
  61. pass
  62. def _module_info_string(self):
  63. s = "{in_channels}, {out_channels}, kernel_size={kernel_size}"
  64. if self.stride != (1,) * len(self.stride):
  65. s += ", stride={stride}"
  66. if self.padding != (0,) * len(self.padding):
  67. s += ", padding={padding}"
  68. if self.dilation != (1,) * len(self.dilation):
  69. s += ", dilation={dilation}"
  70. if self.groups != 1:
  71. s += ", groups={groups}"
  72. if self.bias is None:
  73. s += ", bias=False"
  74. return s.format(**self.__dict__)
  75. class Conv2d(_ConvNd):
  76. r"""Applies a 2D convolution over an input tensor.
  77. For instance, given an input of the size :math:`(N, C_{\text{in}}, H, W)`,
  78. this layer generates an output of the size
  79. :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})` through the
  80. process described as below:
  81. .. math::
  82. \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
  83. \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)
  84. where :math:`\star` is the valid 2D cross-correlation operator,
  85. :math:`N` is a batch size, :math:`C` denotes a number of channels,
  86. :math:`H` is a height of input planes in pixels, and :math:`W` is
  87. width in pixels.
  88. When `groups == in_channels` and `out_channels == K * in_channels`,
  89. where K is a positive integer, this operation is also known as depthwise
  90. convolution.
  91. In other words, for an input of size :math:`(N, C_{in}, H_{in}, W_{in})`,
  92. a depthwise convolution with a depthwise multiplier `K`, can be constructed
  93. by arguments :math:`(in\_channels=C_{in}, out\_channels=C_{in} \times K, ..., groups=C_{in})`.
  94. :param in_channels: number of input channels.
  95. :param out_channels: number of output channels.
  96. :param kernel_size: size of weight on spatial dimensions. If kernel_size is
  97. an :class:`int`, the actual kernel size would be
  98. `(kernel_size, kernel_size)`. Default: 1
  99. :param stride: stride of the 2D convolution operation. Default: 1
  100. :param padding: size of the paddings added to the input on both sides of its
  101. spatial dimensions. Only zero-padding is supported. Default: 0
  102. :param dilation: dilation of the 2D convolution operation. Default: 1
  103. :param groups: number of groups to divide input and output channels into,
  104. so as to perform a "grouped convolution". When groups is not 1,
  105. in_channels and out_channels must be divisible by groups,
  106. and there would be an extra dimension at the beginning of the weight's
  107. shape. Specifically, the shape of weight would be `(groups,
  108. out_channel // groups, in_channels // groups, *kernel_size)`.
  109. :param bias: whether to add a bias onto the result of convolution. Default:
  110. True
  111. :param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default:
  112. `CROSS_CORRELATION`
  113. :param compute_mode: When set to `DEFAULT`, no special requirements will be
  114. placed on the precision of intermediate results. When set to `FLOAT32`,
  115. float32 would be used for accumulator and intermediate result, but only
  116. effective when input and output are of float16 dtype.
  117. Examples:
  118. .. testcode::
  119. import numpy as np
  120. import megengine as mge
  121. import megengine.module as M
  122. m = M.Conv2d(in_channels=3, out_channels=1, kernel_size=3)
  123. inp = mge.tensor(np.arange(0, 96).astype("float32").reshape(2, 3, 4, 4))
  124. oup = m(inp)
  125. print(oup.shape)
  126. Outputs:
  127. .. testoutput::
  128. (2, 1, 2, 2)
  129. """
  130. _conv_mode_type = P.Convolution.Mode
  131. _compute_mode_type = P.Convolution.ComputeMode
  132. def __init__(
  133. self,
  134. in_channels: int,
  135. out_channels: int,
  136. kernel_size: Union[int, Tuple[int, int]],
  137. stride: Union[int, Tuple[int, int]] = 1,
  138. padding: Union[int, Tuple[int, int]] = 0,
  139. dilation: Union[int, Tuple[int, int]] = 1,
  140. groups: int = 1,
  141. bias: bool = True,
  142. conv_mode: str = "CROSS_CORRELATION",
  143. compute_mode: str = "DEFAULT",
  144. ):
  145. kernel_size = _pair_nonzero(kernel_size)
  146. stride = _pair_nonzero(stride)
  147. padding = _pair(padding)
  148. dilation = _pair_nonzero(dilation)
  149. self.conv_mode = self._conv_mode_type.convert(conv_mode)
  150. self.compute_mode = self._compute_mode_type.convert(compute_mode)
  151. super().__init__(
  152. in_channels,
  153. out_channels,
  154. kernel_size,
  155. stride,
  156. padding,
  157. dilation,
  158. groups,
  159. bias,
  160. )
  161. def _get_fanin(self):
  162. kh, kw = self.kernel_size
  163. ic = self.in_channels
  164. return kh * kw * ic
  165. def _infer_weight_shape(self):
  166. group = self.groups
  167. ichl = self.in_channels
  168. ochl = self.out_channels
  169. kh, kw = self.kernel_size
  170. if group == 1:
  171. # Assume format is NCHW
  172. return (ochl, ichl, kh, kw)
  173. assert (
  174. ichl % group == 0 and ochl % group == 0
  175. ), "invalid config: input_channels={} output_channels={} group={}".format(
  176. ichl, ochl, group
  177. )
  178. # Assume format is NCHW
  179. return (group, ochl // group, ichl // group, kh, kw)
  180. def _infer_bias_shape(self):
  181. # Assume format is NCHW
  182. return (1, self.out_channels, 1, 1)
  183. def calc_conv(self, inp, weight, bias):
  184. return conv2d(
  185. inp,
  186. weight,
  187. bias,
  188. self.stride,
  189. self.padding,
  190. self.dilation,
  191. self.groups,
  192. self.conv_mode,
  193. self.compute_mode,
  194. )
  195. def forward(self, inp):
  196. return self.calc_conv(inp, self.weight, self.bias)
  197. class ConvTranspose2d(_ConvNd):
  198. r"""Applies a 2D transposed convolution over an input tensor.
  199. This module is also known as a deconvolution or a fractionally-strided convolution.
  200. :class:`ConvTranspose2d` can ben seen as the gradient of :class:`Conv2d` operation
  201. with respect to its input.
  202. Convolution usually reduces the size of input, while transposed convolution works
  203. the opposite way, transforming a smaller input to a larger output while preserving the
  204. connectivity pattern.
  205. :param in_channels: number of input channels.
  206. :param out_channels: number of output channels.
  207. :param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is
  208. an :class:`int`, the actual kernel size would be
  209. ``(kernel_size, kernel_size)``. Default: 1
  210. :param stride: stride of the 2D convolution operation. Default: 1
  211. :param padding: size of the paddings added to the input on both sides of its
  212. spatial dimensions. Only zero-padding is supported. Default: 0
  213. :param dilation: dilation of the 2D convolution operation. Default: 1
  214. :param groups: number of groups to divide input and output channels into,
  215. so as to perform a "grouped convolution". When ``groups`` is not 1,
  216. ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
  217. and there would be an extra dimension at the beginning of the weight's
  218. shape. Specifically, the shape of weight would be ``(groups,
  219. out_channels // groups, in_channels // groups, *kernel_size)``. Default: 1
  220. :param bias: wether to add a bias onto the result of convolution. Default:
  221. True
  222. :param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default:
  223. `CROSS_CORRELATION`
  224. :param compute_mode: When set to `DEFAULT`, no special requirements will be
  225. placed on the precision of intermediate results. When set to `FLOAT32`,
  226. float32 would be used for accumulator and intermediate result, but only
  227. effective when input and output are of float16 dtype.
  228. """
  229. _conv_mode_type = P.Convolution.Mode
  230. _compute_mode_type = P.Convolution.ComputeMode
  231. def __init__(
  232. self,
  233. in_channels: int,
  234. out_channels: int,
  235. kernel_size: Union[int, Tuple[int, int]],
  236. stride: Union[int, Tuple[int, int]] = 1,
  237. padding: Union[int, Tuple[int, int]] = 0,
  238. dilation: Union[int, Tuple[int, int]] = 1,
  239. groups: int = 1,
  240. bias: bool = True,
  241. conv_mode: str = "CROSS_CORRELATION",
  242. compute_mode: str = "DEFAULT",
  243. ):
  244. kernel_size = _pair_nonzero(kernel_size)
  245. stride = _pair_nonzero(stride)
  246. padding = _pair(padding)
  247. dilation = _pair_nonzero(dilation)
  248. self.conv_mode = self._conv_mode_type.convert(conv_mode)
  249. self.compute_mode = self._compute_mode_type.convert(compute_mode)
  250. super().__init__(
  251. in_channels,
  252. out_channels,
  253. kernel_size,
  254. stride,
  255. padding,
  256. dilation,
  257. groups,
  258. bias,
  259. )
  260. def _get_fanin(self):
  261. kh, kw = self.kernel_size
  262. oc = self.out_channels
  263. return kh * kw * oc
  264. def _infer_weight_shape(self):
  265. group = self.groups
  266. ichl = self.in_channels
  267. ochl = self.out_channels
  268. kh, kw = self.kernel_size
  269. if group == 1:
  270. # Assume format is NCHW
  271. return (ichl, ochl, kh, kw)
  272. assert (
  273. ichl % group == 0 and ochl % group == 0
  274. ), "invalid config: input_channels={} output_channels={} group={}".format(
  275. ichl, ochl, group
  276. )
  277. # Assume format is NCHW
  278. return (group, ichl // group, ochl // group, kh, kw)
  279. def _infer_bias_shape(self):
  280. # Assume format is NCHW
  281. return (1, self.out_channels, 1, 1)
  282. def forward(self, inp):
  283. return conv_transpose2d(
  284. inp,
  285. self.weight,
  286. self.bias,
  287. self.stride,
  288. self.padding,
  289. self.dilation,
  290. self.groups,
  291. self.conv_mode,
  292. self.compute_mode,
  293. )
  294. class LocalConv2d(Conv2d):
  295. r"""Applies a spatial convolution with untied kernels over an input 4D tensor.
  296. It is also known as the locally connected layer.
  297. :param in_channels: number of input channels.
  298. :param out_channels: number of output channels.
  299. :param input_height: the height of the input images.
  300. :param input_width: the width of the input images.
  301. :param kernel_size: size of weight on spatial dimensions. If kernel_size is
  302. an :class:`int`, the actual kernel size would be
  303. `(kernel_size, kernel_size)`. Default: 1
  304. :param stride: stride of the 2D convolution operation. Default: 1
  305. :param padding: size of the paddings added to the input on both sides of its
  306. spatial dimensions. Only zero-padding is supported. Default: 0
  307. :param groups: number of groups to divide input and output channels into,
  308. so as to perform a "grouped convolution". When groups is not 1,
  309. in_channels and out_channels must be divisible by groups.
  310. The shape of weight is `(groups, output_height, output_width,
  311. in_channels // groups, *kernel_size, out_channels // groups)`.
  312. """
  313. _conv_mode_type = P.Convolution.Mode
  314. def __init__(
  315. self,
  316. in_channels: int,
  317. out_channels: int,
  318. input_height: int,
  319. input_width: int,
  320. kernel_size: Union[int, Tuple[int, int]],
  321. stride: Union[int, Tuple[int, int]] = 1,
  322. padding: Union[int, Tuple[int, int]] = 0,
  323. dilation: Union[int, Tuple[int, int]] = 1,
  324. groups: int = 1,
  325. conv_mode: str = "CROSS_CORRELATION",
  326. ):
  327. self.input_height = input_height
  328. self.input_width = input_width
  329. super().__init__(
  330. in_channels,
  331. out_channels,
  332. kernel_size,
  333. stride,
  334. padding,
  335. dilation,
  336. groups,
  337. bias=False,
  338. )
  339. def _infer_weight_shape(self):
  340. group = self.groups
  341. output_height = (
  342. self.input_height + self.padding[0] * 2 - self.kernel_size[0]
  343. ) // self.stride[0] + 1
  344. output_width = (
  345. self.input_width + self.padding[1] * 2 - self.kernel_size[1]
  346. ) // self.stride[1] + 1
  347. # Assume format is NCHW
  348. return (
  349. group,
  350. output_height,
  351. output_width,
  352. self.in_channels // group,
  353. self.kernel_size[0],
  354. self.kernel_size[1],
  355. self.out_channels // group,
  356. )
  357. def forward(self, inp):
  358. return local_conv2d(
  359. inp,
  360. self.weight,
  361. None,
  362. self.stride,
  363. self.padding,
  364. self.dilation,
  365. self.conv_mode,
  366. )
  367. class ConvRelu2d(Conv2d):
  368. r"""
  369. A fused :class:`~.Module` including Conv2d and relu. Could be replaced
  370. with :class:`~.QATModule` version :class:`~.qat.conv.ConvRelu2d` using
  371. :func:`~.quantize.quantize_qat`.
  372. """
  373. def forward(self, inp):
  374. return relu(self.calc_conv(inp, self.weight, self.bias))

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台