You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

nn.py 31 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. # pylint: disable=too-many-lines
  10. from typing import Optional, Tuple, Union
  11. import megengine._internal as mgb
  12. from megengine._internal import CompGraph, CompNode
  13. from ..core import Tensor, wrap_io_tensor
  14. from ..core.graph import _use_default_if_none
  15. from ..jit import barrier, mark_impure
  16. from ..random import uniform
  17. from ..utils.types import _pair, _pair_nonzero
  18. from .debug_param import get_conv_execution_strategy
  19. from .tensor import concat
  20. from .utils import _decide_comp_node_and_comp_graph
  21. @wrap_io_tensor
  22. def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
  23. """Applies a linear transformation to the input.
  24. Refer to :class:`~.Linear` for more information.
  25. :param inp: the input tensor with shape `(N, in_features)`.
  26. :param weight: the weight with shape `(out_features, in_features)`.
  27. :param bias: the bias with shape `(out_features,)`.
  28. Default: ``None``
  29. """
  30. orig_shape = inp.shape
  31. inp = inp.reshape(-1, orig_shape[-1])
  32. ret = mgb.opr.matrix_mul(inp, weight, transposeB=True)
  33. ret = ret.reshape(orig_shape[:-1], weight.shape[0])
  34. if bias is not None:
  35. ret += bias
  36. return ret
  37. @wrap_io_tensor
  38. def conv2d(
  39. inp: Tensor,
  40. weight: Tensor,
  41. bias: Optional[Tensor] = None,
  42. stride: Union[int, Tuple[int, int]] = 1,
  43. padding: Union[int, Tuple[int, int]] = 0,
  44. dilation: Union[int, Tuple[int, int]] = 1,
  45. groups: int = 1,
  46. conv_mode="CROSS_CORRELATION",
  47. compute_mode="DEFAULT",
  48. ) -> Tensor:
  49. """2D convolution operation.
  50. Refer to :class:`~.Conv2d` for more information.
  51. :param inp: The feature map of the convolution operation
  52. :param weight: The convolution kernel
  53. :param bias: The bias added to the result of convolution (if given)
  54. :param stride: Stride of the 2D convolution operation. Default: 1
  55. :param padding: Size of the paddings added to the input on both sides of its
  56. spatial dimensions. Only zero-padding is supported. Default: 0
  57. :param dilation: Dilation of the 2D convolution operation. Default: 1
  58. :param groups: number of groups to divide input and output channels into,
  59. so as to perform a "grouped convolution". When ``groups`` is not 1,
  60. ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
  61. and the shape of weight should be ``(groups, out_channel // groups,
  62. in_channels // groups, height, width)``.
  63. :type conv_mode: string or :class:`mgb.opr_param_defs.Convolution.Mode`
  64. :param conv_mode: Supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default:
  65. 'CROSS_CORRELATION'.
  66. :type compute_mode: string or
  67. :class:`mgb.opr_param_defs.Convolution.ComputeMode`
  68. :param compute_mode: When set to 'DEFAULT', no special requirements will be
  69. placed on the precision of intermediate results. When set to 'FLOAT32',
  70. Float32 would be used for accumulator and intermediate result, but only
  71. effective when input and output are of Float16 dtype.
  72. """
  73. ph, pw = _pair(padding)
  74. sh, sw = _pair_nonzero(stride)
  75. dh, dw = _pair_nonzero(dilation)
  76. Sparse = mgb.opr_param_defs.Convolution.Sparse
  77. sparse_type = Sparse.DENSE if groups == 1 else Sparse.GROUP
  78. res = mgb.opr.convolution(
  79. inp,
  80. weight,
  81. pad_h=ph,
  82. pad_w=pw,
  83. stride_h=sh,
  84. stride_w=sw,
  85. dilate_h=dh,
  86. dilate_w=dw,
  87. format="NCHW",
  88. strategy=get_conv_execution_strategy(),
  89. mode=conv_mode,
  90. compute_mode=compute_mode,
  91. sparse=sparse_type,
  92. )
  93. if bias is not None:
  94. res += bias
  95. return res
  96. @wrap_io_tensor
  97. def conv_transpose2d(
  98. inp: Tensor,
  99. weight: Tensor,
  100. bias: Optional[Tensor] = None,
  101. stride: Union[int, Tuple[int, int]] = 1,
  102. padding: Union[int, Tuple[int, int]] = 0,
  103. dilation: Union[int, Tuple[int, int]] = 1,
  104. groups: int = 1,
  105. conv_mode="CROSS_CORRELATION",
  106. compute_mode="DEFAULT",
  107. ) -> Tensor:
  108. """2D transposed convolution operation.
  109. Refer to :class:`~.ConvTranspose2d` for more information.
  110. :param inp: The feature map of the convolution operation
  111. :param weight: The convolution kernel
  112. :param bias: The bias added to the result of convolution (if given)
  113. :param stride: Stride of the 2D convolution operation. Default: 1
  114. :param padding: Size of the paddings added to the input on both sides of its
  115. spatial dimensions. Only zero-padding is supported. Default: 0
  116. :param dilation: Dilation of the 2D convolution operation. Default: 1
  117. :param groups: number of groups to divide input and output channels into,
  118. so as to perform a "grouped convolution". When ``groups`` is not 1,
  119. ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
  120. and the shape of weight should be ``(groups, out_channel // groups,
  121. in_channels // groups, height, width)``. Default: 1
  122. :type conv_mode: string or :class:`mgb.opr_param_defs.Convolution.Mode`
  123. :param conv_mode: Supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default:
  124. 'CROSS_CORRELATION'.
  125. :type compute_mode: string or
  126. :class:`mgb.opr_param_defs.Convolution.ComputeMode`
  127. :param compute_mode: When set to 'DEFAULT', no special requirements will be
  128. placed on the precision of intermediate results. When set to 'FLOAT32',
  129. Float32 would be used for accumulator and intermediate result, but only
  130. effective when input and output are of Float16 dtype.
  131. """
  132. ph, pw = _pair(padding)
  133. sh, sw = _pair_nonzero(stride)
  134. dh, dw = _pair_nonzero(dilation)
  135. Sparse = mgb.opr_param_defs.Convolution.Sparse
  136. sparse_type = Sparse.DENSE if groups == 1 else Sparse.GROUP
  137. res = mgb.opr.deconvolution(
  138. inp,
  139. weight,
  140. pad_h=ph,
  141. pad_w=pw,
  142. stride_h=sh,
  143. stride_w=sw,
  144. dilate_h=dh,
  145. dilate_w=dw,
  146. format="NCHW",
  147. strategy=get_conv_execution_strategy(),
  148. mode=conv_mode,
  149. compute_mode=compute_mode,
  150. sparse=sparse_type,
  151. )
  152. if bias is not None:
  153. res += bias
  154. return res
  155. @wrap_io_tensor
  156. def max_pool2d(
  157. inp: Tensor,
  158. kernel_size: Union[int, Tuple[int, int]],
  159. stride: Optional[Union[int, Tuple[int, int]]] = None,
  160. padding: Union[int, Tuple[int, int]] = 0,
  161. ) -> Tensor:
  162. """Applies a 2D max pooling over an input.
  163. Refer to :class:`~.MaxPool2d` for more information.
  164. :param inp: The input tensor.
  165. :param kernel_size: The size of the window.
  166. :param stride: The stride of the window. If not provided, its value is set to ``kernel_size``.
  167. Default: None
  168. :param padding: Implicit zero padding to be added on both sides. Default: 0
  169. """
  170. kh, kw = _pair_nonzero(kernel_size)
  171. sh, sw = _pair_nonzero(stride or kernel_size)
  172. ph, pw = _pair(padding)
  173. mode = mgb.opr_param_defs.Pooling.Mode.MAX
  174. return mgb.opr.pooling(
  175. inp,
  176. mode=mode,
  177. format="NCHW",
  178. stride_h=sh,
  179. stride_w=sw,
  180. pad_h=ph,
  181. pad_w=pw,
  182. window_h=kh,
  183. window_w=kw,
  184. )
  185. @wrap_io_tensor
  186. def avg_pool2d(
  187. inp: Tensor,
  188. kernel_size: Union[int, Tuple[int, int]],
  189. stride: Optional[Union[int, Tuple[int, int]]] = None,
  190. padding: Union[int, Tuple[int, int]] = 0,
  191. ) -> Tensor:
  192. """ Applies a 2D average pooling over an input.
  193. Refer to :class:`~.AvgPool2d` for more information.
  194. :param inp: The input tensor.
  195. :param kernel_size: The size of the window.
  196. :param stride: The stride of the window. If not provided, its value is set to ``kernel_size``.
  197. Default: None
  198. :param padding: Implicit zero padding to be added on both sides. Default: 0
  199. """
  200. kh, kw = _pair_nonzero(kernel_size)
  201. sh, sw = _pair_nonzero(stride or kernel_size)
  202. ph, pw = _pair(padding)
  203. mode = mgb.opr_param_defs.Pooling.Mode.AVERAGE
  204. return mgb.opr.pooling(
  205. inp,
  206. mode=mode,
  207. format="NCHW",
  208. stride_h=sh,
  209. stride_w=sw,
  210. pad_h=ph,
  211. pad_w=pw,
  212. window_h=kh,
  213. window_w=kw,
  214. )
  215. @wrap_io_tensor
  216. def prelu(inp: Tensor, weight: Tensor) -> Tensor:
  217. r"""
  218. Applies the element-wise PReLU function.
  219. Refer to :class:`~.PReLU` for more information.
  220. """
  221. return mgb.opr.elemwise(inp, 0, mode="MAX") + weight * mgb.opr.elemwise(
  222. inp, 0, mode="MIN"
  223. )
  224. @wrap_io_tensor
  225. def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor:
  226. r"""
  227. Applies the element-wise leaky_relu function
  228. Refer to :class:`~.LeakyReLU` for more information.
  229. """
  230. return mgb.opr.elemwise(inp, 0, mode="MAX") + negative_slope * mgb.opr.elemwise(
  231. inp, 0, mode="MIN"
  232. )
  233. @wrap_io_tensor
  234. def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor:
  235. r"""
  236. Reshapes the tensor by flattening the sub-tensor from dimension ``start_axis`` to dimension ``end_axis``.
  237. :param inp: The input tensor.
  238. :param start_axis: The start dimension that the sub-tensor to be flattened. Default: 0
  239. :param end_axis: The end dimension that the sub-tensor to be flattened. Default: -1
  240. Examples:
  241. .. testcode::
  242. import numpy as np
  243. from megengine import tensor
  244. import megengine.functional as F
  245. inp_shape = (2, 2, 3, 3)
  246. inp = tensor(
  247. np.arange(36, dtype=np.int32).reshape(inp_shape),
  248. )
  249. oup = F.flatten(inp, 2)
  250. print(inp.numpy().shape)
  251. print(oup.numpy().shape)
  252. Outputs:
  253. .. testoutput::
  254. (2, 2, 3, 3)
  255. (2, 2, 9)
  256. """
  257. target_shape = tuple(inp.shape[i] for i in range(start_axis)) + (-1,)
  258. if end_axis != -1:
  259. target_shape += (inp.shape[end_axis + 1 :],)
  260. return inp.reshape(*target_shape)
  261. def _get_softmax_axis(ndim: int) -> int:
  262. if ndim in (0, 1, 3):
  263. return 0
  264. return 1
  265. @wrap_io_tensor
  266. def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
  267. r"""
  268. Applies a softmax function. Softmax is defined as:
  269. .. math::
  270. \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
  271. It is applied to all elements along axis, and will re-scale them so that
  272. the elements lie in the range `[0, 1]` and sum to 1.
  273. See :class:`~megengine.module.activation.Softmax` for more details.
  274. :param inp: The input tensor.
  275. :param axis: An axis along which softmax will be applied. By default,
  276. softmax will apply along the highest ranked axis.
  277. """
  278. if axis is None:
  279. axis = _get_softmax_axis(len(inp.imm_shape))
  280. offset = mgb.opr.zero_grad(inp.max(axis=axis, keepdims=True))
  281. inp = inp - offset
  282. down = mgb.opr.elem.exp(inp).sum(axis=axis, keepdims=True)
  283. return mgb.opr.elem.exp(inp) / down
  284. @wrap_io_tensor
  285. def batch_norm2d(
  286. inp: Tensor,
  287. running_mean: Tensor,
  288. running_var: Tensor,
  289. weight: Optional[Tensor] = None,
  290. bias: Optional[Tensor] = None,
  291. training: bool = False,
  292. momentum: float = 0.9,
  293. eps: float = 1e-5,
  294. ) -> Tensor:
  295. """Applies batch normalization to the input.
  296. Refer to :class:`~.BatchNorm2d` and :class:`~.BatchNorm1d` for more information.
  297. :param inp: input tensor.
  298. :param running_mean: tensor to store running mean.
  299. :param running_var: tensor to store running variance.
  300. :param weight: scaling tensor in the learnable affine parameters.
  301. See :math:`\gamma` in :class:`~.BatchNorm2d`
  302. :param bias: bias tensor in the learnable affine parameters.
  303. See :math:`\beta` in :class:`~.BatchNorm2d`
  304. :param training: a boolean value to indicate whether batch norm is performed
  305. in traning mode. Default: ``False``
  306. :param momentum: the value used for the ``running_mean`` and ``running_var``
  307. computation.
  308. Default: 0.9
  309. :param eps: a value added to the denominator for numerical stability.
  310. Default: 1e-5.
  311. """
  312. inp = mgb.opr.mark_no_broadcast_elemwise(inp)
  313. _channels = inp.imm_shape[1]
  314. _ndim = len(inp.imm_shape)
  315. _param_shape = (1, _channels) + (1,) * (_ndim - 2)
  316. assert _ndim == 4, "only 4D tensor supported"
  317. if weight is not None:
  318. weight = weight.reshape(*_param_shape)
  319. else:
  320. weight = mgb.make_immutable(*_use_default_if_none(None, None), 1.0).broadcast(
  321. *_param_shape
  322. )
  323. if bias is not None:
  324. bias = bias.reshape(*_param_shape)
  325. else:
  326. bias = mgb.make_immutable(*_use_default_if_none(None, None), 0.0).broadcast(
  327. *_param_shape
  328. )
  329. FwdMode = mgb.opr_param_defs.BN.FwdMode
  330. fwdmode = FwdMode.TRAINING if training else FwdMode.INFERENCE
  331. avg_factor = 1 - momentum
  332. if running_mean is not None and running_var is not None:
  333. if training:
  334. inp = barrier(inp)
  335. output = mgb.opr.batch_norm(
  336. inp,
  337. weight,
  338. bias,
  339. running_mean,
  340. running_var,
  341. param_dim="DIM_1C11",
  342. fwd_mode=fwdmode,
  343. epsilon=eps,
  344. avg_factor=avg_factor,
  345. )[-1]
  346. if training:
  347. mark_impure(output)
  348. else:
  349. output = mgb.opr.batch_norm_no_statistic(
  350. inp,
  351. weight,
  352. bias,
  353. param_dim="DIM_1C11",
  354. fwd_mode=fwdmode,
  355. epsilon=eps,
  356. avg_factor=avg_factor,
  357. )[-1]
  358. return output
  359. def one_hot(inp: Tensor, num_classes: int = -1) -> Tensor:
  360. r"""
  361. Perform one-hot encoding for the input tensor.
  362. :param inp: input tensor
  363. :param num_classes: number of classes denotes the last dimension of the output tensor
  364. Examples:
  365. .. testcode::
  366. import numpy as np
  367. from megengine import tensor
  368. import megengine.functional as F
  369. inp = tensor(np.arange(1, 4, dtype=np.int32))
  370. out = F.one_hot(inp)
  371. print(out.numpy())
  372. Outputs:
  373. .. testoutput::
  374. [[0 1 0 0]
  375. [0 0 1 0]
  376. [0 0 0 1]]
  377. """
  378. comp_node, comp_graph = _decide_comp_node_and_comp_graph(inp)
  379. if num_classes == -1:
  380. num_classes = inp.max() + 1
  381. zeros = mgb.make_immutable(value=0, comp_node=comp_node, comp_graph=comp_graph)
  382. zeros_symvar = zeros.broadcast(inp.shapeof(), num_classes)
  383. ones = mgb.make_immutable(value=1, comp_node=comp_node, comp_graph=comp_graph)
  384. ones_symvar = ones.broadcast(inp.shapeof(), 1)
  385. return Tensor(
  386. mgb.opr.indexing_set_one_hot(
  387. zeros_symvar, axis=len(inp.shapeof()), index=inp, value=ones_symvar
  388. )
  389. )
  390. @wrap_io_tensor
  391. def warp_perspective(
  392. inp: Tensor,
  393. M: Tensor,
  394. dsize: Union[Tuple[int, int], int, Tensor],
  395. border_mode: str = "REPLICATE",
  396. border_val: float = 0.0,
  397. interp_mode: str = "LINEAR",
  398. ):
  399. r"""
  400. Applies perspective transformation to batched 2D images.
  401. The input images are transformed to the output images by the transformation matrix:
  402. .. math::
  403. \text{output}(n, c, h, w) = \text{input} \left( n, c,
  404. \frac{M_{00}h + M_{01}w + M_{02}}{M_{20}h + M_{21}w + M_{22}},
  405. \frac{M_{10}h + M_{11}w + M_{12}}{M_{20}h + M_{21}w + M_{22}}
  406. \right)
  407. :param inp: input image
  408. :param M: (batch, 3, 3) transformation matrix
  409. :param dsize: (h, w) size of the output image
  410. :param border_mode: pixel extrapolation method. Default: ``"REPLICATE"``
  411. :param border_val: value used in case of a constant border. Default: ``0``
  412. :param interp_mode: interpolation methods. Default: ``"LINEAR"``
  413. Examples:
  414. .. testcode::
  415. import numpy as np
  416. from megengine import tensor
  417. import megengine.functional as F
  418. inp_shape = (1, 1, 4, 4)
  419. inp = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
  420. M_shape = (1, 3, 3)
  421. # M defines a translation: dst(1, 1, h, w) = rst(1, 1, h+1, w+1)
  422. M = tensor(np.array([[1., 0., 1.],
  423. [0., 1., 1.],
  424. [0., 0., 1.]], dtype=np.float32).reshape(M_shape))
  425. out = F.warp_perspective(inp, M, (2, 2))
  426. print(out.numpy())
  427. Outputs:
  428. .. testoutput::
  429. [[[[ 5. 6.]
  430. [ 9. 10.]]]]
  431. """
  432. return mgb.opr.warp_perspective(
  433. inp,
  434. M,
  435. dsize,
  436. bmode=border_mode,
  437. border_val=border_val,
  438. imode=interp_mode,
  439. format="NCHW",
  440. )
  441. @wrap_io_tensor
  442. def eye(
  443. n: int,
  444. m: Optional[int] = None,
  445. *,
  446. dtype=None,
  447. device: Optional[CompNode] = None,
  448. comp_graph: Optional[CompGraph] = None
  449. ) -> Tensor:
  450. """
  451. Fills the 2-dimensional input :class:`SymbolVar` with the identity matrix.
  452. :param n: The number of rows
  453. :param m: The number of columns. Default: None
  454. :param dtype: The data type. Default: None
  455. :param device: Compute node of the matrix. Default: None
  456. :param comp_graph: Compute graph of the matrix. Default: None
  457. :return: The eye matrix
  458. Examples:
  459. .. testcode::
  460. import numpy as np
  461. import megengine.functional as F
  462. data_shape = (4, 6)
  463. n, m = data_shape
  464. out = F.eye(n, m, dtype=np.float32)
  465. print(out.numpy())
  466. Outputs:
  467. .. testoutput::
  468. [[1. 0. 0. 0. 0. 0.]
  469. [0. 1. 0. 0. 0. 0.]
  470. [0. 0. 1. 0. 0. 0.]
  471. [0. 0. 0. 1. 0. 0.]]
  472. """
  473. device, comp_graph = _use_default_if_none(device, comp_graph)
  474. if m is None:
  475. m = n
  476. return mgb.opr.eye((n, m), dtype=dtype, comp_node=device, comp_graph=comp_graph)
  477. @wrap_io_tensor
  478. def matrix_mul(inp1: Tensor, inp2: Tensor) -> Tensor:
  479. """
  480. Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``
  481. :param inp1: The first matrix to be multiplied (a, b)
  482. :param inp2: The second matrix to be multiplied (b, c)
  483. :return: The output tensor (a, c)
  484. Examples:
  485. .. testcode::
  486. import numpy as np
  487. from megengine import tensor
  488. import megengine.functional as F
  489. shape_1 = (2, 3)
  490. shape_2 = (3, 4)
  491. data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  492. data2 = tensor(np.arange(0, 6, dtype=np.float32).reshape(3, 2))
  493. out = F.matrix_mul(data1, data2)
  494. print(out.numpy())
  495. Outputs:
  496. .. testoutput::
  497. [[10. 13.]
  498. [28. 40.]]
  499. """
  500. return mgb.opr.matrix_mul(inp1, inp2)
  501. @wrap_io_tensor
  502. def batched_matrix_mul(inp1: Tensor, inp2: Tensor) -> Tensor:
  503. """
  504. Performs a batched multiplication of th batched matrices ``inp1`` and ``inp2``
  505. :param inp1: The first batch matrix to be multiplied (n, a, b)
  506. :param inp2: The second batch matrix to be multiplied (n, b, c)
  507. :return: The output batch (n, a, c)
  508. Examples:
  509. .. testcode::
  510. import numpy as np
  511. from megengine import tensor
  512. import megengine.functional as F
  513. batch_size = 3
  514. shape_1 = (batch_size, 2, 3)
  515. shape_2 = (batch_size, 3, 4)
  516. data1 = tensor(
  517. np.arange(0, batch_size * 6, dtype=np.float32).reshape(batch_size, 2, 3))
  518. data2 = tensor(
  519. np.arange(0, batch_size * 12, dtype=np.float32).reshape(batch_size, 3, 4))
  520. out = F.batched_matrix_mul(data1, data2)
  521. print(out.numpy())
  522. Outputs:
  523. .. testoutput::
  524. [[[ 20. 23. 26. 29.]
  525. [ 56. 68. 80. 92.]]
  526. [[ 344. 365. 386. 407.]
  527. [ 488. 518. 548. 578.]]
  528. [[1100. 1139. 1178. 1217.]
  529. [1352. 1400. 1448. 1496.]]]
  530. """
  531. return mgb.opr.batched_matrix_mul(inp1, inp2)
  532. @wrap_io_tensor
  533. def interpolate(
  534. inp: Tensor,
  535. size: Optional[Union[int, Tuple[int, int]]] = None,
  536. scale_factor: Optional[Union[float, Tuple[float, float]]] = None,
  537. mode: str = "BILINEAR",
  538. align_corners: bool = None,
  539. ) -> Tensor:
  540. r"""
  541. Down/up samples the input tensor to either the given :attr:`size` or the given
  542. :attr:`scale_factor`
  543. :param inp: input tensor
  544. :param size: size of the output tensor. Default: ``None``
  545. :param scale_factor: scaling factor of the output tensor. Default: ``None``
  546. :param mode: interpolation methods, acceptable values are:
  547. 'BILINEAR', 'LINEAR'. Default: ``BILINEAR``
  548. Examples:
  549. .. testcode::
  550. import numpy as np
  551. from megengine import tensor
  552. import megengine.functional as F
  553. from megengine.test import assertTensorClose
  554. inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))
  555. out = F.interpolate(inp, [4, 4], align_corners=False)
  556. print(out.numpy())
  557. out2 = F.interpolate(inp, scale_factor=2.)
  558. assertTensorClose(out.numpy(), out2.numpy())
  559. Outputs:
  560. .. testoutput::
  561. [[[[1. 1.25 1.75 2. ]
  562. [1.5 1.75 2.25 2.5 ]
  563. [2.5 2.75 3.25 3.5 ]
  564. [3. 3.25 3.75 4. ]]]]
  565. """
  566. mode = mode.upper()
  567. if mode not in ["BILINEAR", "LINEAR"]:
  568. raise ValueError("interpolate only support linear or bilinear mode")
  569. if mode not in ["BILINEAR", "LINEAR"]:
  570. if align_corners is not None:
  571. raise ValueError(
  572. "align_corners option can only be set in the bilinear/linear interpolating mode"
  573. )
  574. else:
  575. if align_corners is None:
  576. align_corners = False
  577. if mode == "LINEAR":
  578. inp = mgb.opr.add_axis(inp, 3)
  579. if len(inp.imm_shape) != 4:
  580. raise ValueError("shape of input tensor must correspond to the operartion mode")
  581. if size is None:
  582. if scale_factor is None:
  583. raise ValueError("scale_factor must not be None when size is None")
  584. if isinstance(scale_factor, (float, int)):
  585. scale_factor = float(scale_factor)
  586. if mode == "LINEAR":
  587. scale_factor = (scale_factor, float(1))
  588. else:
  589. scale_factor = (scale_factor, scale_factor)
  590. else:
  591. if mode == "LINEAR":
  592. raise ValueError(
  593. "under LINEAR mode, scale_factor can only be single value"
  594. )
  595. assert len(scale_factor) == 2, "shape of scale_factor must be equal to (2, )"
  596. assert isinstance(scale_factor[0], float) and isinstance(
  597. scale_factor[1], float
  598. ), "scale_factor must be float type"
  599. dsize = tuple(
  600. mgb.opr.elemwise(inp.shape[i + 2] * scale_factor[i], mode="FLOOR")
  601. for i in range(2)
  602. )
  603. dsize = mgb.opr.concat([dsize[0], dsize[1]], axis=0)
  604. else:
  605. if scale_factor is not None:
  606. raise ValueError("scale_factor must be None when size is provided")
  607. if isinstance(size, int):
  608. size = (size, 1)
  609. else:
  610. if mode == "LINEAR":
  611. raise ValueError("under LINEAR mode, size can only be single value")
  612. dsize = size
  613. oh, ow = dsize[0], dsize[1]
  614. ih, iw = inp.shape[2], inp.shape[3]
  615. if align_corners:
  616. hscale = (ih - 1.0) / (oh - 1.0)
  617. wscale = 1.0 * iw / ow
  618. if mode != "LINEAR":
  619. wscale = (iw - 1.0) / (ow - 1.0)
  620. row0 = mgb.opr.concat([wscale, [0, 0]], axis=0).reshape(1, 3)
  621. row1 = mgb.opr.concat([[0], hscale, [0]], axis=0).reshape(1, 3)
  622. weight = mgb.opr.concat([row0, row1, [[0, 0, 1]]], axis=0).reshape(1, 3, 3)
  623. weight = mgb.opr.broadcast(weight, (inp.shape[0], 3, 3))
  624. else:
  625. hscale = 1.0 * ih / oh
  626. wscale = 1.0 * iw / ow
  627. row0 = mgb.opr.concat([wscale, [0], 0.5 * wscale - 0.5], axis=0).reshape(1, 3)
  628. row1 = mgb.opr.concat([[0], hscale, 0.5 * hscale - 0.5], axis=0).reshape(1, 3)
  629. weight = mgb.opr.concat([row0, row1, [[0, 0, 1]]], axis=0).reshape(1, 3, 3)
  630. weight = mgb.opr.broadcast(weight, (inp.shape[0], 3, 3))
  631. ret = mgb.opr.warp_perspective(inp, weight, dsize, imode="LINEAR", format="NCHW")
  632. if mode == "LINEAR":
  633. ret = mgb.opr.reshape(ret, ret.shape[0:3])
  634. return ret
  635. @wrap_io_tensor
  636. def dropout(inp: Tensor, drop_prob: float, rescale: bool = True) -> Tensor:
  637. """
  638. Returns a new tensor where each of the elements are randomly set to zero
  639. with probability P = ``drop_prob``. Optionally rescale the output tensor.
  640. :param inp: The input tensor
  641. :param drop_prob: The probability to drop (set to zero) a single element
  642. :param rescale: The default behavior of ``dropout`` during training is to rescale the output,
  643. then it can be replaced by an :class:`~.Identity` during inference, default to True.
  644. :return: The output tensor
  645. Examples:
  646. .. testcode::
  647. import numpy as np
  648. import megengine as mge
  649. import megengine.functional as F
  650. from megengine import tensor
  651. data = tensor(np.ones(10, dtype=np.float32))
  652. out = F.dropout(data, 1./3.)
  653. print(out.numpy())
  654. Outputs:
  655. .. testoutput::
  656. :options: +SKIP
  657. [1.5 1.5 0. 1.5 1.5 1.5 1.5 1.5 1.5 1.5]
  658. """
  659. assert 0 <= drop_prob < 1
  660. rv = uniform(inp.shape)
  661. mask = rv > drop_prob
  662. inp *= mask.astype(inp.dtype)
  663. if rescale:
  664. inp *= 1 / (1 - drop_prob)
  665. return inp
  666. @wrap_io_tensor
  667. def identity(inp: Tensor) -> Tensor:
  668. """applies an identity transform to the input tensor.
  669. :param inp: The input tensor
  670. """
  671. return mgb.opr.identity(inp)
  672. @wrap_io_tensor
  673. def embedding(
  674. input: Tensor,
  675. weight: Tensor,
  676. padding_idx: Optional[int] = None,
  677. max_norm: Optional[float] = None,
  678. norm_type: Optional[float] = None,
  679. ):
  680. """
  681. Applies lookup table for embedding.
  682. :param input: the tensor with indices.
  683. :param weight: the learnable weights which embedding from.
  684. :param padding_idx: should be set to None, not support now.
  685. :param max_norm: should be set to None, not support now.
  686. :param norm_type: should be set to None, not support now.
  687. Refer to :class:`~.Embedding` for more information.
  688. """
  689. if padding_idx is not None:
  690. raise ValueError("Not support padding_idx Now!")
  691. if max_norm is not None or norm_type is not None:
  692. raise ValueError("Not support weight normlization Now!")
  693. return mgb.opr.advanced_indexing(weight)[input.reshape(-1), :].reshape(
  694. input.shape, weight.shape[-1]
  695. )
  696. @wrap_io_tensor
  697. def roi_pooling(
  698. input: Tensor,
  699. rois: Tensor,
  700. output_shape: Union[int, tuple, list],
  701. mode: str = "max",
  702. scale: float = 1.0,
  703. ) -> Tensor:
  704. """
  705. Apply roi pooling on input feature
  706. :param input: tensor that represents the input feature, (N, C, H, W) images
  707. :param rois: (K, 5) boxes. First column is the index into N. The other 4 columns are xyxy
  708. :param output_shape: (height, width) of output rois feature
  709. :param mode: "max" or "average", use max/average align just like max/average pooling. Default: ``"max"``
  710. :param scale: scale the input boxes by this number. Default: 1.0
  711. :return: (K, C, output_shape[0], output_shape[1]) feature of rois
  712. """
  713. assert mode in ["max", "average"], "only max/average mode is supported"
  714. if isinstance(output_shape, int):
  715. output_shape = (output_shape, output_shape)
  716. return mgb.opr.roi_pooling(
  717. input, rois, output_shape, mode=mode.upper(), scale=scale
  718. )
  719. @wrap_io_tensor
  720. def roi_align(
  721. input: Tensor,
  722. rois: Tensor,
  723. output_shape: Union[int, tuple, list],
  724. mode: str = "average",
  725. spatial_scale: float = 1.0,
  726. sample_points: Union[int, tuple, list] = 2,
  727. aligned: bool = True,
  728. ) -> Tensor:
  729. """
  730. Apply roi align on input feature
  731. :param input: tensor that represents the input feature, (N, C, H, W) images
  732. :param rois: (N, 5) boxes. First column is the index into N. The other 4 columns are xyxy
  733. :param output_shape: (height, width) shape of output rois feature.
  734. :param mode: "max" or "average", use max/average align just like max/average pooling. Default: ``"average"``
  735. :param spatial_scale: scale the input boxes by this number. Default: 1.0
  736. :param sample_points: number of inputs samples to take for each output sample.
  737. 0 to take samples densely. Default: 2
  738. :param aligned: wheather align the input feature, with `aligned=True`,
  739. we first appropriately scale the ROI and then shift it by -0.5. Default: True
  740. """
  741. assert mode in ["max", "average"], "only max/average mode is supported"
  742. if isinstance(output_shape, int):
  743. output_shape = (output_shape, output_shape)
  744. pooled_height, pooled_width = output_shape
  745. if isinstance(sample_points, int):
  746. sample_points = (sample_points, sample_points)
  747. sample_height, sample_width = sample_points
  748. offset = 0.5 if aligned else 0.0
  749. return mgb.opr.roi_align(
  750. input,
  751. rois,
  752. mode=mode.upper(),
  753. spatial_scale=spatial_scale,
  754. offset=offset,
  755. pooled_height=pooled_height,
  756. pooled_width=pooled_width,
  757. sample_height=sample_height,
  758. sample_width=sample_width,
  759. )
  760. @wrap_io_tensor
  761. def assert_equal(
  762. get: Tensor, expect: Tensor, max_err: float = 1e-4, verbose: bool = False
  763. ) -> Tensor:
  764. r"""
  765. Asserts that ``get`` equals to ``expect``, and returns value of ``expect``.
  766. :param get: tensor to be checked.
  767. :param expect: tensor with expected values.
  768. :param max_err: tolerance that two float values are asserted equal. Default: 1e-4
  769. :param verbose: whether to print details if two tensors are not equal. Default: False
  770. Examples:
  771. .. testcode::
  772. import megengine.functional as F
  773. from megengine import tensor
  774. get = tensor([1.0, 2.0])
  775. max_err = 0.1
  776. expect = get + max_err / 2.0
  777. val = F.assert_equal(expect, get, max_err=max_err)
  778. print(val.numpy())
  779. Outputs:
  780. .. testoutput::
  781. [1.05 2.05]
  782. """
  783. return mgb.opr.assert_equal(get, expect, maxerr=max_err, verbose=verbose)
  784. @wrap_io_tensor
  785. def indexing_one_hot(
  786. src: Tensor, index: Tensor, axis: int = 1, keepdims=False
  787. ) -> Tensor:
  788. r"""
  789. One-hot indexing for some axis.
  790. :param src: input data tensor.
  791. :param index: index tensor.
  792. :param axis: the axis on src for which values in index index. Default: 1
  793. :param keepdims: whether not to remove the axis in result. Default: ``False``
  794. Examples:
  795. .. testcode::
  796. import megengine.functional as F
  797. from megengine import tensor
  798. src = tensor([[1.0, 2.0]])
  799. index = tensor([0])
  800. val = F.indexing_one_hot(src, index)
  801. print(val.numpy())
  802. .. testoutput::
  803. [1.]
  804. """
  805. return mgb.opr.indexing_one_hot(src, axis, index, keepdims=keepdims)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。