You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

nn.py 55 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. # pylint: disable=too-many-lines
  10. from typing import Iterable, Optional, Sequence, Tuple, Union
  11. from ..core._imperative_rt import CompNode
  12. from ..core._imperative_rt.core2 import apply
  13. from ..core._trace_option import use_symbolic_shape
  14. from ..core.ops import builtin
  15. from ..core.ops.builtin import BatchNorm
  16. from ..core.ops.special import Const
  17. from ..core.tensor import utils
  18. from ..core.tensor.utils import astensor1d, setscalar
  19. from ..distributed import WORLD, is_distributed
  20. from ..jit.tracing import is_tracing
  21. from ..random import uniform
  22. from ..tensor import Tensor
  23. from ..utils.tuple_function import _pair, _pair_nonzero
  24. from .debug_param import get_execution_strategy
  25. from .distributed import all_reduce_sum
  26. from .elemwise import exp, floor, log, log1p, maximum, minimum, relu
  27. from .math import argsort, matmul, max, prod, sum
  28. from .tensor import (
  29. broadcast_to,
  30. concat,
  31. expand_dims,
  32. full,
  33. ones,
  34. reshape,
  35. squeeze,
  36. zeros,
  37. )
# Names exported by ``from <this module> import *``.
# NOTE(review): ``linear`` (and ``sync_batch_norm``) are defined in this module
# but are not listed here, while several listed names (e.g. "dropout",
# "one_hot", "remap") are defined outside the visible chunk — confirm the
# export list is intentionally complete.
__all__ = [
    "adaptive_avg_pool2d",
    "adaptive_max_pool2d",
    "avg_pool2d",
    "batch_norm",
    "conv2d",
    "conv_transpose2d",
    "deformable_conv2d",
    "deformable_psroi_pooling",
    "dropout",
    "indexing_one_hot",
    "leaky_relu",
    "local_conv2d",
    "logsigmoid",
    "logsumexp",
    "logsoftmax",
    "max_pool2d",
    "one_hot",
    "prelu",
    "remap",
    "resize",
    "softmax",
    "softplus",
    "warp_affine",
    "warp_perspective",
    "conv1d",
]
  65. def expand_hw(x):
  66. # NOTE: >1d array is accepted, as long as 1 <= size <= 2
  67. try:
  68. x = int(x)
  69. return [x, x]
  70. except (TypeError, ValueError):
  71. pass
  72. h, w = x
  73. return int(h), int(w)
  74. def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
  75. """
  76. Applies a linear transformation to the input tensor.
  77. Refer to :class:`~.module.linear.Linear` for more information.
  78. :param inp: input tensor with shape `(N, in_features)`.
  79. :param weight: weight with shape `(out_features, in_features)`.
  80. :param bias: bias with shape `(out_features,)`.
  81. Default: None
  82. """
  83. ret = matmul(inp, weight, transpose_b=True)
  84. if bias is not None:
  85. ret += bias
  86. return ret
  87. def conv2d(
  88. inp: Tensor,
  89. weight: Tensor,
  90. bias: Optional[Tensor] = None,
  91. stride: Union[int, Tuple[int, int]] = 1,
  92. padding: Union[int, Tuple[int, int]] = 0,
  93. dilation: Union[int, Tuple[int, int]] = 1,
  94. groups: int = 1,
  95. conv_mode="CROSS_CORRELATION",
  96. compute_mode="DEFAULT",
  97. ) -> Tensor:
  98. """
  99. 2D convolution operation.
  100. Refer to :class:`~.module.Conv2d` for more information.
  101. :param inp: feature map of the convolution operation.
  102. :param weight: convolution kernel.
  103. :param bias: bias added to the result of convolution (if given).
  104. :param stride: stride of the 2D convolution operation. Default: 1
  105. :param padding: size of the paddings added to the input on both sides of its
  106. spatial dimensions. Only zero-padding is supported. Default: 0
  107. :param dilation: dilation of the 2D convolution operation. Default: 1
  108. :param groups: number of groups into which the input and output channels are divided,
  109. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  110. ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
  111. and the shape of weight should be `(groups, out_channel // groups,
  112. in_channels // groups, height, width)`.
  113. :type conv_mode: string or :class:`Convolution.Mode`
  114. :param conv_mode: supports "CROSS_CORRELATION". Default:
  115. "CROSS_CORRELATION"
  116. :type compute_mode: string or
  117. :class:`Convolution.ComputeMode`
  118. :param compute_mode: when set to "DEFAULT", no special requirements will be
  119. placed on the precision of intermediate results. When set to "FLOAT32",
  120. "Float32" would be used for accumulator and intermediate result, but only
  121. effective when input and output are of Float16 dtype.
  122. :return: output tensor.
  123. """
  124. assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION"
  125. assert compute_mode == "DEFAULT" or compute_mode.name == "DEFAULT"
  126. stride_h, stride_w = expand_hw(stride)
  127. pad_h, pad_w = expand_hw(padding)
  128. dilate_h, dilate_w = expand_hw(dilation)
  129. sparse_type = "DENSE" if groups == 1 else "GROUP"
  130. op = builtin.Convolution(
  131. stride_h=stride_h,
  132. stride_w=stride_w,
  133. pad_h=pad_h,
  134. pad_w=pad_w,
  135. dilate_h=dilate_h,
  136. dilate_w=dilate_w,
  137. strategy=get_execution_strategy(),
  138. mode=conv_mode,
  139. compute_mode=compute_mode,
  140. sparse=sparse_type,
  141. )
  142. inp, weight = utils.convert_inputs(inp, weight)
  143. (output,) = apply(op, inp, weight)
  144. if bias is not None:
  145. output += bias
  146. return output
  147. def conv_transpose2d(
  148. inp: Tensor,
  149. weight: Tensor,
  150. bias: Optional[Tensor] = None,
  151. stride: Union[int, Tuple[int, int]] = 1,
  152. padding: Union[int, Tuple[int, int]] = 0,
  153. dilation: Union[int, Tuple[int, int]] = 1,
  154. groups: int = 1,
  155. conv_mode="CROSS_CORRELATION",
  156. compute_mode="DEFAULT",
  157. ) -> Tensor:
  158. """
  159. 2D transposed convolution operation.
  160. Refer to :class:`~.ConvTranspose2d` for more information.
  161. :param inp: feature map of the convolution operation.
  162. :param weight: convolution kernel.
  163. :param bias: bias added to the result of convolution (if given).
  164. :param stride: stride of the 2D convolution operation. Default: 1
  165. :param padding: size of the paddings added to the input on both sides of its
  166. spatial dimensions. Only zero-padding is supported. Default: 0
  167. :param dilation: dilation of the 2D convolution operation. Default: 1
  168. :param groups: number of groups into which the input and output channels are divided,
  169. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  170. ``in_channels`` and ``out_channels`` must be divisible by groups,
  171. and the shape of weight should be `(groups, out_channel // groups,
  172. in_channels // groups, height, width)`. Default: 1
  173. :type conv_mode: string or :class:`Convolution.Mode`
  174. :param conv_mode: supports "CROSS_CORRELATION". Default:
  175. "CROSS_CORRELATION"
  176. :type compute_mode: string or
  177. :class:`Convolution.ComputeMode`
  178. :param compute_mode: when set to "DEFAULT", no special requirements will be
  179. placed on the precision of intermediate results. When set to "FLOAT32",
  180. "Float32" would be used for accumulator and intermediate result, but only
  181. effective when input and output are of Float16 dtype.
  182. :return: output tensor.
  183. """
  184. assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION"
  185. assert compute_mode == "DEFAULT" or compute_mode.name == "DEFAULT"
  186. if groups != 1:
  187. raise NotImplementedError("TODO")
  188. stride_h, stride_w = expand_hw(stride)
  189. pad_h, pad_w = expand_hw(padding)
  190. dilate_h, dilate_w = expand_hw(dilation)
  191. op = builtin.ConvolutionBackwardData(
  192. stride_h=stride_h,
  193. stride_w=stride_w,
  194. pad_h=pad_h,
  195. pad_w=pad_w,
  196. dilate_h=dilate_h,
  197. dilate_w=dilate_w,
  198. strategy=get_execution_strategy(),
  199. )
  200. weight, inp = utils.convert_inputs(weight, inp)
  201. (output,) = apply(op, weight, inp)
  202. if bias is not None:
  203. output += bias
  204. return output
  205. def deformable_conv2d(
  206. inp: Tensor,
  207. weight: Tensor,
  208. offset: Tensor,
  209. mask: Tensor,
  210. bias: Optional[Tensor] = None,
  211. stride: Union[int, Tuple[int, int]] = 1,
  212. padding: Union[int, Tuple[int, int]] = 0,
  213. dilation: Union[int, Tuple[int, int]] = 1,
  214. groups: int = 1,
  215. conv_mode="CROSS_CORRELATION",
  216. compute_mode="DEFAULT",
  217. ) -> Tensor:
  218. """
  219. Deformable Convolution.
  220. :param inp: input feature map.
  221. :param weight: convolution kernel.
  222. :param offset: input offset to kernel, channel of this tensor should match the deformable settings.
  223. :param mask: input mask to kernel, channel of this tensor should match the deformable settings.
  224. :param bias: bias added to the result of convolution (if given).
  225. :param stride: stride of the 2D convolution operation. Default: 1
  226. :param padding: size of the paddings added to the input on both sides of its
  227. spatial dimensions. Only zero-padding is supported. Default: 0
  228. :param dilation: dilation of the 2D convolution operation. Default: 1
  229. :param groups: number of groups into which the input and output channels are divided,
  230. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  231. ``in_channels`` and ``out_channels`` must be divisible by groups,
  232. and the shape of weight should be `(groups, out_channel // groups,
  233. in_channels // groups, height, width)`. Default: 1
  234. :type conv_mode: string or :class:`Convolution.Mode`
  235. :param conv_mode: supports "CROSS_CORRELATION". Default:
  236. "CROSS_CORRELATION"
  237. :type compute_mode: string or
  238. :class:`Convolution.ComputeMode`
  239. :param compute_mode: when set to "DEFAULT", no special requirements will be
  240. placed on the precision of intermediate results. When set to "FLOAT32",
  241. "Float32" would be used for accumulator and intermediate result, but only
  242. effective when input and output are of Float16 dtype.
  243. :return: output tensor.
  244. """
  245. assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION"
  246. assert compute_mode == "DEFAULT" or compute_mode.name == "DEFAULT"
  247. stride_h, stride_w = expand_hw(stride)
  248. pad_h, pad_w = expand_hw(padding)
  249. dilate_h, dilate_w = expand_hw(dilation)
  250. sparse_type = "DENSE" if groups == 1 else "GROUP"
  251. op = builtin.DeformableConv(
  252. stride_h=stride_h,
  253. stride_w=stride_w,
  254. pad_h=pad_h,
  255. pad_w=pad_w,
  256. dilate_h=dilate_h,
  257. dilate_w=dilate_w,
  258. strategy=get_execution_strategy(),
  259. mode=conv_mode,
  260. compute_mode=compute_mode,
  261. sparse=sparse_type,
  262. )
  263. inp, weight, offset, mask = utils.convert_inputs(inp, weight, offset, mask)
  264. (output,) = apply(op, inp, weight, offset, mask)
  265. if bias is not None:
  266. output += bias
  267. return output
  268. def local_conv2d(
  269. inp: Tensor,
  270. weight: Tensor,
  271. bias: Optional[Tensor] = None,
  272. stride: Union[int, Tuple[int, int]] = 1,
  273. padding: Union[int, Tuple[int, int]] = 0,
  274. dilation: Union[int, Tuple[int, int]] = 1,
  275. conv_mode="CROSS_CORRELATION",
  276. ):
  277. """Applies spatial 2D convolution over an groupped channeled image with untied kernels."""
  278. assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION"
  279. stride_h, stride_w = expand_hw(stride)
  280. pad_h, pad_w = expand_hw(padding)
  281. dilate_h, dilate_w = expand_hw(dilation)
  282. op = builtin.GroupLocal(
  283. stride_h=stride_h,
  284. stride_w=stride_w,
  285. pad_h=pad_h,
  286. pad_w=pad_w,
  287. dilate_h=dilate_h,
  288. dilate_w=dilate_w,
  289. mode=conv_mode,
  290. compute_mode="DEFAULT",
  291. sparse="DENSE",
  292. )
  293. inp, weight = utils.convert_inputs(inp, weight)
  294. (output,) = apply(op, inp, weight)
  295. if bias is not None:
  296. output += bias
  297. return output
  298. def max_pool2d(
  299. inp: Tensor,
  300. kernel_size: Union[int, Tuple[int, int]],
  301. stride: Optional[Union[int, Tuple[int, int]]] = None,
  302. padding: Union[int, Tuple[int, int]] = 0,
  303. ) -> Tensor:
  304. """
  305. Applies a 2D max pooling over an input tensor.
  306. Refer to :class:`~.MaxPool2d` for more information.
  307. :param inp: input tensor.
  308. :param kernel_size: size of the window.
  309. :param stride: stride of the window. If not provided, its value is set to kernel_size.
  310. Default: None
  311. :param padding: implicit zero padding added on both sides. Default: 0
  312. :return: output tensor.
  313. """
  314. if stride is None:
  315. stride = kernel_size
  316. window_h, window_w = _pair_nonzero(kernel_size)
  317. stride_h, stride_w = _pair_nonzero(stride)
  318. padding_h, padding_w = _pair(padding)
  319. op = builtin.Pooling(
  320. window_h=window_h,
  321. window_w=window_w,
  322. stride_h=stride_h,
  323. stride_w=stride_w,
  324. pad_h=padding_h,
  325. pad_w=padding_w,
  326. mode="MAX",
  327. )
  328. (output,) = apply(op, inp)
  329. return output
  330. def avg_pool2d(
  331. inp: Tensor,
  332. kernel_size: Union[int, Tuple[int, int]],
  333. stride: Optional[Union[int, Tuple[int, int]]] = None,
  334. padding: Union[int, Tuple[int, int]] = 0,
  335. mode: str = "AVERAGE_COUNT_EXCLUDE_PADDING",
  336. ) -> Tensor:
  337. """
  338. Applies 2D average pooling over an input tensor.
  339. Refer to :class:`~.AvgPool2d` for more information.
  340. :param inp: input tensor.
  341. :param kernel_size: size of the window.
  342. :param stride: stride of the window. If not provided, its value is set to ``kernel_size``.
  343. Default: None
  344. :param padding: implicit zero padding added on both sides. Default: 0
  345. :param mode: whether to count padding values. Default: "AVERAGE_COUNT_EXCLUDE_PADDING"
  346. :return: output tensor.
  347. """
  348. if stride is None:
  349. stride = kernel_size
  350. window_h, window_w = _pair_nonzero(kernel_size)
  351. stride_h, stride_w = _pair_nonzero(stride)
  352. padding_h, padding_w = _pair(padding)
  353. op = builtin.Pooling(
  354. window_h=window_h,
  355. window_w=window_w,
  356. stride_h=stride_h,
  357. stride_w=stride_w,
  358. pad_h=padding_h,
  359. pad_w=padding_w,
  360. mode=mode,
  361. )
  362. (output,) = apply(op, inp)
  363. return output
  364. def adaptive_max_pool2d(
  365. inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor],
  366. ) -> Tensor:
  367. """
  368. Applies a 2D max adaptive pooling over an input.
  369. Refer to :class:`~.MaxAdaptivePool2d` for more information.
  370. :param inp: input tensor.
  371. :param oshp: `(OH, OW)` size of the output shape.
  372. :return: output tensor.
  373. """
  374. if isinstance(oshp, int):
  375. oshp = (oshp, oshp)
  376. op = builtin.AdaptivePooling(mode="MAX", format="NCHW",)
  377. oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device)
  378. (output,) = apply(op, inp, oshp)
  379. return output
  380. def adaptive_avg_pool2d(
  381. inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor],
  382. ) -> Tensor:
  383. """
  384. Applies a 2D average adaptive pooling over an input.
  385. Refer to :class:`~.AvgAdaptivePool2d` for more information.
  386. :param inp: input tensor.
  387. :param oshp: `(OH, OW)` size of the output shape.
  388. :return: output tensor.
  389. """
  390. if isinstance(oshp, int):
  391. oshp = (oshp, oshp)
  392. op = builtin.AdaptivePooling(mode="AVERAGE", format="NCHW",)
  393. oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device)
  394. (output,) = apply(op, inp, oshp)
  395. return output
  396. def deformable_psroi_pooling(
  397. inp: Tensor,
  398. rois: Tensor,
  399. trans: Tensor,
  400. no_trans: bool,
  401. part_size: int,
  402. pooled_h: int,
  403. pooled_w: int,
  404. sample_per_part: int,
  405. spatial_scale: float,
  406. trans_std: float = 0.1,
  407. ):
  408. """
  409. Deformable PSROI(Position Sensitive Region of Interest) Pooling.
  410. :param inp: input feature map.
  411. :param rois: the rois for feature pooling.
  412. :param trans: input offset to psroi_pooling.
  413. :param no_trans: check the phase of DeformablePSROIPooling. False to the
  414. 1st phase, True to the 2nd phase.
  415. :param part_size: part size.
  416. :param sample_per_part: sample points of each part.
  417. :param pooled_shape: kernel shape of convolution.
  418. :param spatial_scale: the spatial_scale w.r.t input image.
  419. :param trans_std: multiplier used in 2nd phase.
  420. """
  421. op = builtin.DeformablePSROIPooling(
  422. no_trans=no_trans,
  423. part_size=part_size,
  424. pooled_h=pooled_h,
  425. pooled_w=pooled_w,
  426. sample_per_part=sample_per_part,
  427. spatial_scale=spatial_scale,
  428. trans_std=trans_std,
  429. )
  430. output, _ = apply(op, inp, rois, trans)
  431. return output
  432. def prelu(inp: Tensor, weight: Tensor) -> Tensor:
  433. r"""
  434. Applies the element-wise PReLU function.
  435. Refer to :class:`~.PReLU` for more information.
  436. """
  437. return maximum(inp, 0) + weight * minimum(inp, 0)
  438. def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor:
  439. r"""
  440. Applies the element-wise leaky_relu function
  441. Refer to :class:`~.LeakyReLU` for more information.
  442. """
  443. return maximum(inp, 0) + negative_slope * minimum(inp, 0)
  444. def softplus(inp: Tensor) -> Tensor:
  445. r"""
  446. Applies the element-wise function:
  447. .. math::
  448. \text{softplus}(x) = \log(1 + \exp(x))
  449. softplus is a smooth approximation to the ReLU function and can be used
  450. to constrain the output to be always positive.
  451. For numerical stability the implementation follows this transformation:
  452. .. math::
  453. \text{softplus}(x) = \log(1 + \exp(x))
  454. = \log(1 + \exp(-\text{abs}(x))) + \max(x, 0)
  455. = \log1p(\exp(-\text{abs}(x))) + \text{relu}(x)
  456. :param inp: input tensor.
  457. Examples:
  458. .. testcode::
  459. import numpy as np
  460. from megengine import tensor
  461. import megengine.functional as F
  462. x = tensor(np.arange(-3, 3, dtype=np.float32))
  463. y = F.softplus(x)
  464. print(y.numpy().round(decimals=4))
  465. Outputs:
  466. .. testoutput::
  467. [0.0486 0.1269 0.3133 0.6931 1.3133 2.1269]
  468. """
  469. return log1p(exp(-abs(inp))) + relu(inp)
  470. def logsoftmax(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
  471. r"""
  472. Applies the :math:`\log(\text{softmax}(x))` function to an n-dimensional
  473. input tensor. The :math:`\text{logsoftmax}(x)` formulation can be simplified as:
  474. .. math::
  475. \text{logsoftmax}(x_{i}) = \log(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} )
  476. For numerical stability the implementation follows this transformation:
  477. .. math::
  478. \text{logsoftmax}(x)
  479. = \log (\frac{\exp (x)}{\sum_{i}(\exp (x_{i}))})
  480. = x - \log (\sum_{i}(\exp (x_{i})))
  481. = x - \text{logsumexp}(x)
  482. :param inp: input tensor.
  483. :param axis: axis along which :math:`\text{logsoftmax}(x)` will be applied.
  484. Examples:
  485. .. testcode::
  486. import numpy as np
  487. from megengine import tensor
  488. import megengine.functional as F
  489. x = tensor(np.arange(-5, 5, dtype=np.float32)).reshape(2,5)
  490. y = F.logsoftmax(x, axis=1)
  491. print(y.numpy().round(decimals=4))
  492. Outputs:
  493. .. testoutput::
  494. [[-4.4519 -3.4519 -2.4519 -1.4519 -0.4519]
  495. [-4.4519 -3.4519 -2.4519 -1.4519 -0.4519]]
  496. """
  497. return inp - logsumexp(inp, axis, keepdims=True)
  498. def logsigmoid(inp: Tensor) -> Tensor:
  499. r"""
  500. Applies the element-wise function:
  501. .. math::
  502. \text{logsigmoid}(x) = \log(\frac{ 1 }{ 1 + \exp(-x)})
  503. = \log(1/(1 + \exp(-x)))
  504. = - \log(1 + \exp(-x))
  505. = - \text{softplus}(-x)
  506. :param inp: input tensor.
  507. Examples:
  508. .. testcode::
  509. import numpy as np
  510. from megengine import tensor
  511. import megengine.functional as F
  512. x = tensor(np.arange(-5, 5, dtype=np.float32))
  513. y = F.logsigmoid(x)
  514. print(y.numpy().round(decimals=4))
  515. Outputs:
  516. .. testoutput::
  517. [-5.0067 -4.0182 -3.0486 -2.1269 -1.3133 -0.6931 -0.3133 -0.1269 -0.0486
  518. -0.0181]
  519. """
  520. return -softplus(-inp)
  521. def logsumexp(
  522. inp: Tensor, axis: Union[int, Sequence[int]], keepdims: bool = False
  523. ) -> Tensor:
  524. r"""
  525. Calculates the logarithm of the inputs' exponential sum along the given :attr:`axis`.
  526. .. math::
  527. \text{logsumexp}(x)= \log \sum_{j=1}^{n} \exp \left(x_{j}\right)
  528. For numerical stability, the implementation follows this transformation:
  529. .. math::
  530. \text{logsumexp}(x)= \log \sum_{j=1}^{n} \exp \left(x_{j}\right)
  531. = \text{logsumexp}(x)=b+\log \sum_{j=1}^{n} \exp \left(x_{j}-b\right)
  532. where
  533. .. math::
  534. b = \max(x_j)
  535. :param inp: input tensor.
  536. :param axis: axis over which the sum is taken. It could be single axis or list of axes.
  537. :param keepdims: whether to retain :attr:`axis` or not for the output tensor.
  538. Examples:
  539. .. testcode::
  540. import numpy as np
  541. from megengine import tensor
  542. import megengine.functional as F
  543. x = tensor(np.arange(-5, 5, dtype=np.float32)).reshape(2,5)
  544. y = F.logsumexp(x, axis=1, keepdims=False)
  545. print(y.numpy().round(decimals=4))
  546. Outputs:
  547. .. testoutput::
  548. [-0.5481 4.4519]
  549. """
  550. max_value = max(inp.detach(), axis, keepdims=True)
  551. if keepdims:
  552. return max_value + log(sum(exp(inp - max_value), axis, keepdims))
  553. else:
  554. return squeeze(max_value, axis=None) + log(
  555. sum(exp(inp - max_value), axis, keepdims)
  556. )
  557. def _get_softmax_axis(ndim: int) -> int:
  558. if ndim in (0, 1, 3):
  559. return 0
  560. return 1
  561. def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
  562. r"""
  563. Applies a :math:`\text{softmax}(x)` function. :math:`\text{softmax}(x)` is defined as:
  564. .. math::
  565. \text{softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
  566. It is applied to all elements along axis, and rescales elements so that
  567. they stay in the range `[0, 1]` and sum to 1.
  568. See :class:`~megengine.module.activation.Softmax` for more details.
  569. :param inp: input tensor.
  570. :param axis: an axis along which :math:`\text{softmax}(x)` will be applied. By default,
  571. :math:`\text{softmax}(x)` will apply along the highest ranked axis.
  572. Examples:
  573. .. testcode::
  574. import numpy as np
  575. from megengine import tensor
  576. import megengine.functional as F
  577. x = tensor(np.arange(-5, 5, dtype=np.float32)).reshape(2,5)
  578. out = F.softmax(x)
  579. print(out.numpy().round(decimals=4))
  580. Outputs:
  581. .. testoutput::
  582. [[0.0117 0.0317 0.0861 0.2341 0.6364]
  583. [0.0117 0.0317 0.0861 0.2341 0.6364]]
  584. """
  585. if axis is None:
  586. axis = _get_softmax_axis(len(inp.shape))
  587. offset = inp.max(axis=axis, keepdims=True).detach()
  588. cached = exp(inp - offset)
  589. down = sum(cached, axis=axis, keepdims=True)
  590. return cached / down
  591. def batch_norm(
  592. inp: Tensor,
  593. running_mean: Tensor = None,
  594. running_var: Tensor = None,
  595. weight: Optional[Tensor] = None,
  596. bias: Optional[Tensor] = None,
  597. *,
  598. training: bool = False,
  599. momentum: float = 0.9,
  600. eps: float = 1e-5,
  601. inplace: bool = True
  602. ):
  603. r"""
  604. Applies batch normalization to the input.
  605. Refer to :class:`~.BatchNorm2d` and :class:`~.BatchNorm1d` for more information.
  606. :param inp: input tensor.
  607. :param running_mean: tensor to store running mean.
  608. :param running_var: tensor to store running variance.
  609. :param weight: scaling tensor in the learnable affine parameters.
  610. See :math:`\gamma` in :class:`~.BatchNorm2d`.
  611. :param bias: bias tensor in the learnable affine parameters.
  612. See :math:`\beta` in :class:`~.BatchNorm2d`.
  613. :param training: a boolean value to indicate whether batch norm is performed
  614. in training mode. Default: False
  615. :param momentum: value used for the ``running_mean`` and ``running_var``
  616. computation.
  617. Default: 0.9
  618. :param eps: a value added to the denominator for numerical stability.
  619. Default: 1e-5
  620. :param inplace: whether to update ``running_mean`` and ``running_var`` inplace or return new tensors
  621. Default: True
  622. :return: output tensor.
  623. """
  624. if inp.ndim != 4:
  625. raise NotImplementedError("batch_norm for ndim != 4")
  626. C = inp.shape[1]
  627. def make_full_if_none(x, value):
  628. if x is None:
  629. (x,) = Const(value, dtype=inp.dtype, device=inp.device)()
  630. shape = utils.astensor1d(
  631. (1, C, 1, 1), inp, dtype="int32", device=inp.device
  632. )
  633. (result,) = apply(builtin.Broadcast(), x, shape)
  634. return result
  635. elif x.ndim == 1:
  636. shape = utils.astensor1d(
  637. (1, C, 1, 1), inp, dtype="int32", device=inp.device
  638. )
  639. (result,) = apply(builtin.Reshape(), x, shape)
  640. return result
  641. return x
  642. has_mean = running_mean is not None
  643. has_var = running_var is not None
  644. if not training:
  645. assert has_mean, "running_mean must be provided in inference mode"
  646. assert has_var, "running_var must be provided in inference mode"
  647. if has_mean and running_mean.ndim != 4:
  648. raise ValueError
  649. if has_var and running_var.ndim != 4:
  650. raise ValueError
  651. inp, weight, bias, running_mean, running_var = utils.convert_inputs(
  652. inp, weight, bias, running_mean, running_var
  653. )
  654. weight = make_full_if_none(weight, 1)
  655. bias = make_full_if_none(bias, 0)
  656. if not training:
  657. op = builtin.BatchNorm(
  658. fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim="DIM_1C11"
  659. )
  660. ret = apply(op, inp, weight, bias, running_mean, running_var)[-1]
  661. return ret
  662. else:
  663. op = builtin.BatchNorm(
  664. avg_factor=1 - momentum, epsilon=eps, param_dim="DIM_1C11"
  665. )
  666. if has_mean or has_var:
  667. running_mean = make_full_if_none(running_mean, 0)
  668. running_var = make_full_if_none(running_var, 1)
  669. new_mean, new_var, _, _, inp = apply(
  670. op, inp, weight, bias, running_mean, running_var
  671. )
  672. if not has_mean:
  673. new_mean = None
  674. if not has_var:
  675. new_var = None
  676. if inplace:
  677. if has_mean:
  678. running_mean[...] = new_mean
  679. if has_var:
  680. running_var[...] = new_var
  681. return inp
  682. else:
  683. return inp, new_mean, new_var
  684. else:
  685. (_, _, inp,) = apply(op, inp, weight, bias)
  686. return inp
  def sync_batch_norm(
      inp: Tensor,
      running_mean: Tensor,
      running_var: Tensor,
      weight: Optional[Tensor] = None,
      bias: Optional[Tensor] = None,
      training: bool = False,
      momentum: Union[float, Tensor] = 0.9,
      eps: float = 1e-5,
      eps_mode="ADDITIVE",
      group=WORLD,
  ) -> Tensor:
      r"""
      Applies synchronized batch normalization to the input.

      Refer to :class:`~.BatchNorm2d` and :class:`~.BatchNorm1d` for more information.

      :param inp: input tensor. Axis 1 is treated as the channel axis; all other
          axes are reduced over when computing statistics.
      :param running_mean: tensor to store running mean. Required when
          ``training`` is False. In training mode, if both running statistics
          are given, it is updated in place.
      :param running_var: tensor to store running variance. Same rules as
          ``running_mean``; in training mode the unbiased batch variance is
          accumulated into it.
      :param weight: scaling tensor in the learnable affine parameters.
          See :math:`\gamma` in :class:`~.BatchNorm2d`.
      :param bias: bias tensor in the learnable affine parameters.
          See :math:`\beta` in :class:`~.BatchNorm2d`.
      :param training: a boolean value to indicate whether batch norm is performed
          in training mode. Default: False
      :param momentum: value used for the ``running_mean`` and ``running_var``
          computation.
          Default: 0.9
      :param eps: a value added to the denominator for numerical stability.
          Default: 1e-5
      :param eps_mode: how ``eps`` is applied to the variance: "ADDITIVE" uses
          ``var + eps`` and "MAX" uses ``max(var, eps)``. Default: "ADDITIVE"
      :param group: communication group whose members' statistics are summed
          when running distributed. Default: WORLD
      :return: output tensor.
      """
      assert eps_mode in {"MAX", "ADDITIVE"}, "unknown eps_mode: {}".format(eps_mode)
      _channels = inp.shape[1]
      _ndim = inp.ndim
      _dtype = inp.dtype
      # Per-channel statistics have shape (1, C, 1, ...) and are reduced over
      # every axis except axis 1.
      _param_shape = (1, _channels) + (1,) * (_ndim - 2)
      _reduce_axis = [0] + list(range(2, _ndim))

      if training:

          def _sum_on_channel(inp):
              return inp.sum(axis=_reduce_axis, keepdims=True)

          # Number of elements reduced per channel on this node.
          reduce_size = inp.shape[0]
          for i in range(2, _ndim):
              reduce_size = reduce_size * inp.shape[i]
          channel_x1s = _sum_on_channel(inp)
          channel_x2s = _sum_on_channel(inp ** 2)

          if is_distributed():
              # reduce all nodes' data to calculate mean and variance:
              # pack [count | sum(x) | sum(x^2)] along axis 1 so that one
              # all-reduce covers all three quantities.
              reduce_size = broadcast_to(
                  Tensor(reduce_size).astype(dtype=_dtype), [1] * _ndim
              )
              stat = concat([reduce_size, channel_x1s, channel_x2s], axis=1)
              stat = all_reduce_sum(stat, group)
              reduce_size = stat[:, :1].reshape(1)
              channel_x1s = stat[:, 1 : 1 + _channels]
              channel_x2s = stat[:, 1 + _channels :]

          channel_mean = channel_x1s / reduce_size
          # Biased variance: E[x^2] - E[x]^2.
          channel_variance = (
              channel_x1s ** 2 / (-reduce_size * reduce_size) + channel_x2s / reduce_size
          )
      else:
          assert running_var is not None and running_mean is not None
          channel_variance = running_var.reshape(*_param_shape)
          channel_mean = running_mean.reshape(*_param_shape)

      invsqrt_channel_variance = (
          maximum(channel_variance, eps) if eps_mode == "MAX" else channel_variance + eps
      ) ** -0.5

      if weight is not None:
          weight = weight.reshape(*_param_shape)
      if bias is not None:
          bias = bias.reshape(*_param_shape)

      # outvar = output * weight + bias
      # where output = inp * invsqrt_channel_variance + (
      #    -channel_mean * invsqrt_channel_variance
      # )
      # Manually expand output for gopt
      if weight is not None:
          inv_var_wt = invsqrt_channel_variance * weight
          neg_channel_mean = -channel_mean
          if bias is not None:
              outvar = inp * inv_var_wt + (neg_channel_mean * inv_var_wt + bias)
          else:
              outvar = inp * inv_var_wt + neg_channel_mean * inv_var_wt
      else:
          outvar = inp * invsqrt_channel_variance + (
              -channel_mean * invsqrt_channel_variance
          )
          if bias is not None:
              outvar = outvar + bias

      if training and running_var is not None and running_mean is not None:
          # Moving averages updated in place through augmented assignment; the
          # running variance tracks the unbiased estimate (divisor n - 1).
          running_mean *= momentum
          running_mean += (1 - momentum) * channel_mean
          channel_variance_unbiased = channel_x1s ** 2 / (
              -reduce_size * (reduce_size - 1)
          ) + channel_x2s / (reduce_size - 1)
          running_var *= momentum
          running_var += (1 - momentum) * channel_variance_unbiased

      return outvar
  def one_hot(inp: Tensor, num_classes: int) -> Tensor:
      r"""
      Performs one-hot encoding for the input tensor.

      The result keeps the shape of ``inp`` and appends a trailing axis of
      length ``num_classes``. It uses ``inp``'s dtype and device.

      :param inp: input tensor of class indices.
      :param num_classes: number of classes denotes the last dimension of the output tensor.
      :return: output tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.arange(1, 4, dtype=np.int32))
          out = F.one_hot(x, num_classes=4)
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[0 1 0 0]
           [0 0 1 0]
           [0 0 0 1]]

      """
      leading_shape = list(inp.shape)
      set_one_hot = builtin.IndexingSetOneHot(axis=inp.ndim)
      # Start from an all-zero canvas and write a single 1 along the new last
      # axis, at the position named by each entry of ``inp``.
      canvas = zeros(leading_shape + [num_classes], inp.dtype, inp.device)
      fill = ones(leading_shape + [1], inp.dtype, inp.device)
      (encoded,) = apply(set_one_hot, canvas, inp, fill)
      return encoded
  def resize(
      inp: Tensor, target_shape: Iterable[int], interp_mode: str = "LINEAR"
  ) -> Tensor:
      r"""
      Applies resize transformation to batched 2D images.

      :param inp: `(N, C, H, W)` input tensor. Currently only support "NCHW" format.
      :param target_shape: `(H, W)` target images shape.
      :param interp_mode: interpolation methods. Default mode is "LINEAR"; currently only "LINEAR" is supported.
      :return: resized output tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.random.randn(10, 3, 32, 32))
          out = F.resize(x, (16, 16))
          print(out.numpy().shape)

      Outputs:

      .. testoutput::

          (10, 3, 16, 16)

      """
      op = builtin.Resize(imode=interp_mode, format="NCHW")
      # The operator takes the target size as an int32 tensor on inp's device.
      shape = astensor1d(target_shape, inp, dtype="int32", device=inp.device)
      (result,) = apply(op, inp, shape)
      return result
  def warp_affine(
      inp: Tensor,
      weight: Tensor,
      out_shape,
      border_mode="REPLICATE",
      border_val=0,
      format="NHWC",
      imode="LINEAR",
  ):
      """
      Batched affine transform on 2D images.

      :param inp: input image.
      :param weight: weight tensor.
      :param out_shape: output tensor shape.
      :param border_mode: pixel extrapolation method.
          Default: "REPLICATE". Currently "CONSTANT", "REFLECT",
          "REFLECT_101", "ISOLATED", "WRAP", "REPLICATE", "TRANSPARENT" are supported.
      :param border_val: value used in case of a constant border. Default: 0
      :param format: "NHWC" is the default for historical reasons;
          "NCHW" is also supported. Default: "NHWC".
      :param imode: interpolation methods. Could be "LINEAR", "NEAREST", "CUBIC", "AREA".
          Default: "LINEAR".
      :return: output tensor.

      .. note::

          Here all available options for params are listed,
          however it does not mean that you can use all the combinations.
          On different platforms, different combinations are supported.
      """
      op = builtin.WarpAffine(
          border_mode=border_mode, border_val=border_val, format=format, imode=imode
      )
      # The operator takes the output shape as an int32 tensor on inp's device.
      out_shape = utils.astensor1d(out_shape, inp, dtype="int32", device=inp.device)
      (result,) = apply(op, inp, weight, out_shape)
      return result
  def warp_perspective(
      inp: Tensor,
      M: Tensor,
      dsize: Union[Tuple[int, int], int, Tensor],
      border_mode: str = "REPLICATE",
      border_val: float = 0.0,
      interp_mode: str = "LINEAR",
  ) -> Tensor:
      r"""
      Applies perspective transformation to batched 2D images.

      The input images are transformed to the output images by the transformation matrix:

      .. math::
              \text{output}(n, c, h, w) = \text{input} \left( n, c,
                  \frac{M_{00}h + M_{01}w + M_{02}}{M_{20}h + M_{21}w + M_{22}},
                  \frac{M_{10}h + M_{11}w + M_{12}}{M_{20}h + M_{21}w + M_{22}}
                  \right)

      :param inp: input image in "NCHW" layout.
      :param M: `(batch, 3, 3)` transformation matrix.
      :param dsize: `(h, w)` size of the output image.
      :param border_mode: pixel extrapolation method.
          Default: "REPLICATE". Currently also support "CONSTANT", "REFLECT",
          "REFLECT_101", "WRAP".
      :param border_val: value used in case of a constant border. Default: 0
      :param interp_mode: interpolation methods.
          Default: "LINEAR". Currently only support "LINEAR" mode.
      :return: output tensor.

      .. note::

          The transformation matrix is the inverse of that used by `cv2.warpPerspective`.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          inp_shape = (1, 1, 4, 4)
          x = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
          M_shape = (1, 3, 3)
          # M defines a translation: dst(1, 1, h, w) = src(1, 1, h+1, w+1)
          M = tensor(np.array([[1., 0., 1.],
                               [0., 1., 1.],
                               [0., 0., 1.]], dtype=np.float32).reshape(M_shape))
          out = F.warp_perspective(x, M, (2, 2))
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[[[ 5.  6.]
             [ 9. 10.]]]]

      """
      # Normalise the image and matrix first; the size tensor below is then
      # placed on the converted image's device.
      inp, M = utils.convert_inputs(inp, M)
      out_hw = astensor1d(dsize, inp, dtype="int32", device=inp.device)
      warp = builtin.WarpPerspective(
          imode=interp_mode,
          bmode=border_mode,
          format="NCHW",
          border_val=border_val,
      )
      (warped,) = apply(warp, inp, M, out_hw)
      return warped
  def remap(
      inp: Tensor,
      map_xy: Tensor,
      border_mode: str = "REPLICATE",
      scalar: float = 0.0,
      interp_mode: str = "LINEAR",
  ) -> Tensor:
      r"""
      Applies remap transformation to batched 2D images.

      The input images are transformed to the output images by the tensor ``map_xy``.
      The output's H and W are the same as ``map_xy``'s H and W.

      :param inp: input image in "NCHW" layout.
      :param map_xy: `(batch, oh, ow, 2)` transformation matrix.
      :param border_mode: pixel extrapolation method.
          Default: "REPLICATE". Currently also support "CONSTANT", "REFLECT",
          "REFLECT_101", "WRAP".
      :param scalar: value used in case of a constant border. Default: 0
      :param interp_mode: interpolation methods.
          Default: "LINEAR". Currently only support "LINEAR" mode.
      :return: output tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          inp_shape = (1, 1, 4, 4)
          inp = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
          map_xy_shape = (1, 2, 2, 2)
          map_xy = tensor(np.array([[[1., 0.],[0., 1.]],
                                    [[0., 1.],[0., 1.]]],
                                   dtype=np.float32).reshape(map_xy_shape))
          out = F.remap(inp, map_xy)
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[[[1. 4.]
             [4. 4.]]]]

      """
      remap_op = builtin.Remap(
          format="NCHW",
          imode=interp_mode,
          border_type=border_mode,
          scalar=scalar,
      )
      (remapped,) = apply(remap_op, inp, map_xy)
      return remapped
  def interpolate(
      inp: Tensor,
      size: Optional[Union[int, Tuple[int, int]]] = None,
      scale_factor: Optional[Union[float, Tuple[float, float]]] = None,
      mode: str = "BILINEAR",
      align_corners: bool = None,
  ) -> Tensor:
      r"""
      Down/up samples the input tensor to either the given size or with the given scale_factor. ``size`` can not coexist with ``scale_factor``.

      :param inp: input tensor. 4-D ``(N, C, H, W)`` in "BILINEAR" mode;
          3-D ``(N, C, L)`` in "LINEAR" mode, where ``L`` is resampled.
      :param size: size of the output tensor. In "BILINEAR" mode either an
          ``(H, W)`` pair or an int used for both dimensions; in "LINEAR" mode
          only an int is accepted. Default: None
      :param scale_factor: scaling factor of the output tensor. In "BILINEAR"
          mode either a pair of floats or a single number used for both
          dimensions; in "LINEAR" mode only a single number is accepted.
          Default: None
      :param mode: interpolation methods, acceptable values are (case-insensitive):
          "BILINEAR", "LINEAR". Default: "BILINEAR"
      :param align_corners: if True, the corner pixels of input and output are
          aligned; otherwise pixel centres are mapped with a half-pixel offset.
          Default: None, which is treated as False.
      :return: output tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))
          out = F.nn.interpolate(x, [4, 4], align_corners=False)
          print(out.numpy())
          out2 = F.nn.interpolate(x, scale_factor=2.)
          np.testing.assert_allclose(out.numpy(), out2.numpy())

      Outputs:

      .. testoutput::

          [[[[1.   1.25 1.75 2.  ]
             [1.5  1.75 2.25 2.5 ]
             [2.5  2.75 3.25 3.5 ]
             [3.   3.25 3.75 4.  ]]]]

      """
      mode = mode.upper()
      if mode not in ["BILINEAR", "LINEAR"]:
          raise ValueError("interpolate only support linear or bilinear mode")
      # Only linear/bilinear modes reach this point, so align_corners is always
      # meaningful here; None selects the default (False).
      if align_corners is None:
          align_corners = False

      if mode == "LINEAR":
          # Treat the 1-D signal as an (L, 1) image so the 2-D warp can be reused.
          inp = expand_dims(inp, 3)

      if inp.ndim != 4:
          raise ValueError("shape of input tensor must correspond to the operartion mode")

      if size is None:
          if scale_factor is None:
              raise ValueError("scale_factor must not be None when size is None")

          if isinstance(scale_factor, (float, int)):
              scale_factor = float(scale_factor)
              if mode == "LINEAR":
                  scale_factor = (scale_factor, float(1))
              else:
                  scale_factor = (scale_factor, scale_factor)
          else:
              if mode == "LINEAR":
                  raise ValueError(
                      "under LINEAR mode, scale_factor can only be single value"
                  )

          assert len(scale_factor) == 2, "shape of scale_factor must be equal to (2, )"
          assert isinstance(scale_factor[0], float) and isinstance(
              scale_factor[1], float
          ), "scale_factor must be float type"
          dsize = tuple(
              floor(
                  Tensor(
                      inp.shape[i + 2] * scale_factor[i],
                      dtype="float32",
                      device=inp.device,
                  )
              )
              for i in range(2)
          )
          dsize = concat([dsize[0], dsize[1]], axis=0)
      else:
          if scale_factor is not None:
              raise ValueError("scale_factor must be None when size is provided")

          if isinstance(size, int):
              # LINEAR keeps the unit W axis; BILINEAR uses the int for both
              # dimensions, matching the scalar scale_factor handling above.
              size = (size, 1) if mode == "LINEAR" else (size, size)
          else:
              if mode == "LINEAR":
                  raise ValueError("under LINEAR mode, size can only be single value")
          dsize = size

      oh, ow = dsize[0], dsize[1]
      ih, iw = inp.shape[2], inp.shape[3]

      if align_corners:
          # Corner-aligned mapping: scale (in - 1) / (out - 1), no offset.
          hscale = (ih - 1.0) / (oh - 1.0)
          # In LINEAR mode the W axis has unit length, so (iw - 1) / (ow - 1)
          # would be 0 / 0; keep iw / ow there.
          wscale = 1.0 * iw / ow
          if mode != "LINEAR":
              wscale = (iw - 1.0) / (ow - 1.0)
          row0 = concat(
              [wscale, Tensor([0, 0], dtype="float32", device=inp.device)], axis=0
          ).reshape(1, 3)
          row1 = concat(
              [
                  Tensor(0, dtype="float32", device=inp.device),
                  hscale,
                  Tensor(0, dtype="float32", device=inp.device),
              ],
              axis=0,
          ).reshape(1, 3)
      else:
          # Half-pixel mapping: src = (dst + 0.5) * scale - 0.5.
          hscale = 1.0 * ih / oh
          wscale = 1.0 * iw / ow
          row0 = concat(
              [wscale, Tensor(0, dtype="float32", device=inp.device), 0.5 * wscale - 0.5],
              axis=0,
          ).reshape(1, 3)
          row1 = concat(
              [Tensor(0, dtype="float32", device=inp.device), hscale, 0.5 * hscale - 0.5],
              axis=0,
          ).reshape(1, 3)

      # Assemble the 3x3 matrix for warp_perspective and repeat it per sample.
      weight = concat(
          [row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)],
          axis=0,
      ).reshape(1, 3, 3)
      weight = broadcast_to(weight, (inp.shape[0], 3, 3))
      weight = weight.astype("float32")

      ret = warp_perspective(inp, weight, dsize, interp_mode="LINEAR")
      if mode == "LINEAR":
          # Drop the unit W axis added for LINEAR mode.
          ret = reshape(ret, ret.shape[0:3])
      return ret
  def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor:
      """
      Returns a new tensor where each of the elements are randomly set to zero
      with probability P = ``drop_prob``. Optionally rescale the output tensor if ``training`` is True.

      :param inp: input tensor. It is left unmodified.
      :param drop_prob: probability to drop (set to zero) a single element;
          must satisfy ``0 <= drop_prob < 1``.
      :param training: the default behavior of ``dropout`` during training is to rescale the output,
          then it can be replaced by an :class:`~.Identity` during inference.
          When False, elements are still dropped but the result is not rescaled.
          Default: True
      :return: the output tensor

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.ones(10, dtype=np.float32))
          out = F.dropout(x, 1./3.)
          print(out.numpy())

      Outputs:

      .. testoutput::
          :options: +SKIP

          [1.5 1.5 0.  1.5 1.5 1.5 1.5 1.5 1.5 1.5]

      """
      assert 0 <= drop_prob < 1
      rv = uniform(size=inp.shape)
      mask = rv > drop_prob
      # Multiply out of place: augmented assignment (``*=``) on a MegEngine
      # Tensor rewrites that tensor in place, which would silently zero parts
      # of the caller's ``inp``.
      ret = inp * mask.astype(inp.dtype)
      if training:
          # Inverted dropout: scale survivors so the expected value is unchanged.
          ret = ret * (1 / (1 - drop_prob))
      return ret
  def embedding(
      inp: Tensor,
      weight: Tensor,
      padding_idx: Optional[int] = None,
      max_norm: Optional[float] = None,
      norm_type: Optional[float] = None,
  ):
      """
      Applies lookup table for embedding.

      Each index in ``inp`` selects a row (first-axis slice) of ``weight``; the
      result has shape ``inp.shape + (weight.shape[-1],)``.

      :param inp: tensor with indices.
      :param weight: learnable weights which embeds from.
      :param padding_idx: should be set to None, not supported now.
      :param max_norm: should be set to None, not supported now.
      :param norm_type: should be set to None, not supported now.
      :return: output tensor.

      Refer to :class:`~.Embedding` for more information.
      """
      if padding_idx is not None:
          raise ValueError("Not support padding_idx Now!")
      if not (max_norm is None and norm_type is None):
          raise ValueError("Not support weight normlization Now!")

      # Gather on a flattened view of the indices, then restore their layout
      # with the embedding width appended as the trailing axis.
      embed_dim = weight.shape[-1]
      rows = weight[inp.reshape(-1)]
      return rows.reshape(list(inp.shape) + [embed_dim])
  def roi_pooling(
      inp: Tensor,
      rois: Tensor,
      output_shape: Union[int, tuple, list],
      mode: str = "max",
      scale: float = 1.0,
  ) -> Tensor:
      """
      Applies roi pooling on input feature.

      :param inp: tensor that represents the input feature, `(N, C, H, W)` images.
      :param rois: `(K, 5)` boxes. First column is the index into N. The other 4 columns are xyxy.
      :param output_shape: `(height, width)` of output rois feature; an int is
          used for both dimensions.
      :param mode: "max" or "average", use max/average align just like max/average pooling. Default: "max"
      :param scale: scale the input boxes by this number. Default: 1.0
      :return: `(K, C, output_shape[0], output_shape[1])` feature of rois.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          np.random.seed(42)
          inp = tensor(np.random.randn(1, 1, 128, 128))
          rois = tensor(np.random.random((4, 5)))
          y = F.nn.roi_pooling(inp, rois, (2, 2))
          print(y.numpy()[0].round(decimals=4))

      Outputs:

      .. testoutput::

          [[[-0.1383 -0.1383]
            [-0.5035 -0.5035]]]

      """
      assert mode in ["max", "average"], "only max/average mode is supported"
      pooled_hw = (
          (output_shape, output_shape) if isinstance(output_shape, int) else output_shape
      )
      pool_op = builtin.ROIPooling(mode=mode, scale=scale)
      inp, rois = utils.convert_inputs(inp, rois)
      # The operator receives the pooled size as an int32 tensor and returns
      # two outputs; only the pooled features are exposed.
      hw_tensor = Tensor(pooled_hw, dtype="int32", device=inp.device)
      pooled, _ = apply(pool_op, inp, rois, hw_tensor)
      return pooled
  def roi_align(
      inp: Tensor,
      rois: Tensor,
      output_shape: Union[int, tuple, list],
      mode: str = "average",
      spatial_scale: float = 1.0,
      sample_points: Union[int, tuple, list] = 2,
      aligned: bool = True,
  ) -> Tensor:
      """
      Applies roi align on input feature.

      :param inp: tensor that represents the input feature, shape is `(N, C, H, W)`.
      :param rois: `(K, 5)` boxes. First column is the index into N. The other 4 columns are ``xyxy``.
      :param output_shape: `(height, width)` shape of output rois feature; an int
          is used for both dimensions.
      :param mode: "max" or "average", use max/average align just like max/average pooling. Default: "average"
      :param spatial_scale: scale the input boxes by this number. Default: 1.0
      :param sample_points: number of inputs samples to take for each output sample;
          an int is used for both height and width.
          0 to take samples densely. Default: 2
      :param aligned: whether to align the input feature, with `aligned=True`,
          we first appropriately scale the ROI and then shift it by -0.5. Default: True
      :return: output tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          np.random.seed(42)
          inp = tensor(np.random.randn(1, 1, 128, 128))
          rois = tensor(np.random.random((4, 5)))
          y = F.nn.roi_align(inp, rois, (2, 2))
          print(y.numpy()[0].round(decimals=4))

      Outputs:

      .. testoutput::

          [[[0.175  0.175 ]
            [0.1359 0.1359]]]

      """
      assert mode in ["max", "average"], "only max/average mode is supported"
      if isinstance(output_shape, int):
          output_shape = (output_shape, output_shape)
      pooled_height, pooled_width = output_shape
      if isinstance(sample_points, int):
          sample_points = (sample_points, sample_points)
      sample_height, sample_width = sample_points
      # Half-pixel shift passed to the operator when ``aligned`` is set.
      offset = 0.5 if aligned else 0.0

      op = builtin.ROIAlign(
          mode=mode,
          format="NCHW",
          spatial_scale=spatial_scale,
          offset=offset,
          pooled_height=pooled_height,
          pooled_width=pooled_width,
          sample_height=sample_height,
          sample_width=sample_width,
      )
      inp, rois = utils.convert_inputs(inp, rois)
      # The operator may return auxiliary outputs; only the features are kept.
      result, *_ = apply(op, inp, rois)
      return result
  def indexing_one_hot(
      src: Tensor, index: Tensor, axis: int = 1, keepdims=False
  ) -> Tensor:
      r"""
      One-hot indexing for some axes.

      For every position, picks the element of ``src`` along ``axis`` whose
      offset is given by ``index``.

      :param src: input tensor.
      :param index: index tensor; converted to int32 on ``src``'s device.
      :param axis: axis of ``src`` along which ``index`` selects values. Default: 1
      :param keepdims: whether to keep ``axis`` in the result instead of
          squeezing it out. Default: False
      :return: output tensor.

      Examples:

      .. testcode::

          import megengine.functional as F
          from megengine import tensor

          src = tensor([[1.0, 2.0]])
          index = tensor([0])
          val = F.indexing_one_hot(src, index)
          print(val.numpy())

      Outputs:

      .. testoutput::

          [1.]

      """
      assert isinstance(src, Tensor), "src must be of Tensor type"
      index = utils.convert_single_value(index, dtype="int32", device=src.device)
      (picked,) = apply(builtin.IndexingOneHot(axis=axis), src, index)
      return picked if keepdims else squeeze(picked, axis)
  def conv1d(
      inp: Tensor,
      weight: Tensor,
      bias: Optional[Tensor] = None,
      stride: int = 1,
      padding: int = 0,
      dilation: int = 1,
      groups: int = 1,
      conv_mode="CROSS_CORRELATION",
      compute_mode="DEFAULT",
  ) -> Tensor:
      """1D convolution operation.

      Refer to :class:`~.Conv1d` for more information.

      :param inp: The feature map of the convolution operation; must be 3-D.
      :param weight: The convolution kernel. 3-D ``(out_channels, in_channels,
          kernel_size)`` when ``groups`` is 1; 4-D ``(groups, out_channels // groups,
          in_channels // groups, kernel_size)`` otherwise.
      :param bias: The bias added to the result of convolution (if given); must be 3-D.
      :param stride: Stride of the 1D convolution operation. Default: 1
      :param padding: Size of the paddings added to the input on both sides of its
          spatial dimensions. Only zero-padding is supported. Default: 0
      :param dilation: Dilation of the 1D convolution operation. Default: 1
      :param groups: number of groups to divide input and output channels into,
          so as to perform a "grouped convolution". When ``groups`` is not 1,
          ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
          and the shape of weight should be ``(groups, out_channels // groups,
          in_channels // groups, kernel_size)``.
      :type conv_mode: string or :class:`mgb.opr_param_defs.Convolution.Mode`
      :param conv_mode: Supports 'CROSS_CORRELATION'. Default:
          'CROSS_CORRELATION'.
      :type compute_mode: string or
          :class:`mgb.opr_param_defs.Convolution.ComputeMode`
      :param compute_mode: only 'DEFAULT' is currently accepted, which places no
          special requirements on the precision of intermediate results.
          Default: 'DEFAULT'.
      :return: output tensor (3-D, like ``inp``).
      """
      assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION"
      assert compute_mode == "DEFAULT" or compute_mode.name == "DEFAULT"
      assert inp.ndim == 3, "the input dimension of conv1d should be 3"
      if groups == 1:
          assert weight.ndim == 3, "the weight dimension of conv1d should be 3"
      else:
          # GROUP sparse mode needs a leading ``groups`` axis on the kernel.
          assert (
              weight.ndim == 4
          ), "the weight dimension of grouped conv1d should be 4"
      if bias is not None:
          assert bias.ndim == 3, "the bias dimension of conv1d should be 3"

      # Reuse the 2-D Convolution operator: the length axis becomes H and a
      # trailing unit W axis is appended to input, kernel and bias.
      inp = expand_dims(inp, 3)
      weight = expand_dims(weight, weight.ndim)
      if bias is not None:
          bias = expand_dims(bias, 3)

      sparse_type = "DENSE" if groups == 1 else "GROUP"
      op = builtin.Convolution(
          stride_h=stride,
          stride_w=1,
          pad_h=padding,
          pad_w=0,
          dilate_h=dilation,
          dilate_w=1,
          strategy=get_execution_strategy(),
          mode=conv_mode,
          compute_mode=compute_mode,
          sparse=sparse_type,
      )
      inp, weight = utils.convert_inputs(inp, weight)
      (output,) = apply(op, inp, weight)
      if bias is not None:
          output += bias
      # Remove the unit W axis again.
      output = squeeze(output, 3)
      return output
  def nms(
      boxes: Tensor, scores: Tensor, iou_thresh: float, max_output: Optional[int] = None
  ) -> Tensor:
      r"""
      Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union (IoU).

      :param boxes: tensor of shape `(N, 4)`; the boxes to perform nms on; each box is expected to be in `(x1, y1, x2, y2)` format.
      :param scores: tensor of shape `(N,)`, the score of boxes.
      :param iou_thresh: IoU threshold for overlapping.
      :param max_output: the maximum number of boxes to keep. It is optional when
          this operator is not traced and required (and must be positive) under
          tracing; if it is not specified, all boxes may be kept.
      :return: indices (into the original ``boxes``) of the elements that have
          been kept by NMS, in descending score order.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = np.zeros((100,4))
          np.random.seed(42)
          x[:,:2] = np.random.rand(100,2)*20
          x[:,2:] = np.random.rand(100,2)*20 + 100
          scores = tensor(np.random.rand(100))
          inp = tensor(x)
          result = F.nn.nms(inp, scores, iou_thresh=0.7)
          print(result.numpy())

      Outputs:

      .. testoutput::

          [75 69]

      """
      assert (
          boxes.ndim == 2 and boxes.shape[1] == 4
      ), "the expected shape of boxes is (N, 4)"
      assert scores.ndim == 1, "the expected shape of scores is (N,)"
      assert (
          boxes.shape[0] == scores.shape[0]
      ), "number of boxes and scores are not matched"

      # Work on detached inputs, with boxes reordered by descending score.
      order = argsort(scores.detach(), descending=True)
      ranked_boxes = boxes.detach()[order]

      if is_tracing():
          assert (
              max_output is not None and max_output > 0
          ), "max_output should be specified under tracing"
      keep_limit = ranked_boxes.shape[0] if max_output is None else max_output

      # NMSKeep expects a batch axis: (1, N, 4).
      (batched_boxes,) = utils.convert_inputs(ranked_boxes.reshape(1, -1, 4))
      kept, kept_count = apply(builtin.NMSKeep(iou_thresh, keep_limit), batched_boxes)
      # Positions refer to the sorted order; map them back to original indices.
      return order[kept[0][: kept_count[0]]]
  def nvof(src: Tensor, precision: int = 1) -> Tensor:
      r"""
      Implements NVIDIA Optical Flow SDK.

      :param src: input tensor with shape ``(n, t, h, w, c4)`` and dtype uint8;
          it must be 5-D with a last axis of length 4.
      :param precision: 0:NV_OF_PERF_LEVEL_SLOW 1:NV_OF_PERF_LEVEL_MEDIUM 2:NV_OF_PERF_LEVEL_FAST.
      :return: output tensor with shape ``(n, t-1, h//4, w//4, c2)`` and dtype int16.

      .. code-block:: python

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = np.random.randint(0, 256, (1, 2, 224, 244, 4)).astype("uint8")
          src = tensor(x)
          result = F.nn.nvof(src, precision=1)
          print(result.numpy())
      """
      assert src.ndim == 5 and src.shape[4] == 4
      # The operator consumes detached frames; no gradient flows back to ``src``.
      src = src.detach()
      op = builtin.NvOf(precision=precision)
      return apply(op, src)[0]
  1409. from .loss import * # isort:skip
  1410. from .quantized import conv_bias_activation # isort:skip

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台