You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

nn.py 47 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. # pylint: disable=too-many-lines
  10. from typing import Optional, Sequence, Tuple, Union
  11. from ..core._imperative_rt.core2 import apply
  12. from ..core.ops import builtin
  13. from ..core.ops.builtin import BatchNorm, Elemwise
  14. from ..core.ops.special import Const
  15. from ..core.tensor import amp, megbrain_graph
  16. from ..core.tensor.array_method import _elwise_apply
  17. from ..core.tensor.utils import (
  18. astensor1d,
  19. astype,
  20. cast_tensors,
  21. convert_single_value,
  22. setscalar,
  23. )
  24. from ..device import get_default_device
  25. from ..distributed import WORLD, is_distributed
  26. from ..random import uniform
  27. from ..tensor import Tensor
  28. from ..utils.deprecation import deprecated_func
  29. from ..utils.tuple_function import _pair, _pair_nonzero, _triple, _triple_nonzero
  30. from .debug_param import get_execution_strategy
  31. from .distributed import all_reduce_sum
  32. from .elemwise import _elwise, exp, log, log1p, maximum, minimum
  33. from .math import matmul, max, sum
  34. from .tensor import broadcast_to, concat, expand_dims, ones, squeeze, zeros
# Public API of this module: the names exported by a star-import.
# Keep this list in sync when adding or removing public functions below.
__all__ = [
    "adaptive_avg_pool2d",
    "adaptive_max_pool2d",
    "avg_pool2d",
    "batch_norm",
    "conv1d",
    "conv2d",
    "conv3d",
    "conv_transpose2d",
    "conv_transpose3d",
    "deformable_conv2d",
    "deformable_psroi_pooling",
    "dropout",
    "embedding",
    "gelu",
    "hsigmoid",
    "hswish",
    "indexing_one_hot",
    "leaky_relu",
    "linear",
    "local_conv2d",
    "logsigmoid",
    "logsumexp",
    "logsoftmax",
    "max_pool2d",
    "one_hot",
    "prelu",
    "relu",
    "relu6",
    "remap",
    "resize",
    "sigmoid",
    "sliding_window",
    "sliding_window_transpose",
    "silu",
    "softmax",
    "softplus",
    "sync_batch_norm",
    "warp_affine",
    "warp_perspective",
]
  76. def expand_hw(x):
  77. # NOTE: >1d array is accepted, as long as 1 <= size <= 2
  78. try:
  79. x = int(x)
  80. return [x, x]
  81. except (TypeError, ValueError):
  82. pass
  83. h, w = x
  84. return int(h), int(w)
  85. def linear(
  86. inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None, compute_mode="default",
  87. ) -> Tensor:
  88. """
  89. Applies a linear transformation to the input tensor.
  90. Refer to :class:`~.module.linear.Linear` for more information.
  91. :param inp: input tensor with shape `(N, in_features)`.
  92. :param weight: weight with shape `(out_features, in_features)`.
  93. :param bias: bias with shape `(out_features,)`.
  94. Default: None
  95. """
  96. ret = matmul(inp, weight, transpose_b=True, compute_mode=compute_mode)
  97. if bias is not None:
  98. if amp._enabled:
  99. bias = bias.astype("float16")
  100. ret += bias
  101. return ret
  102. def conv1d(
  103. inp: Tensor,
  104. weight: Tensor,
  105. bias: Optional[Tensor] = None,
  106. stride: int = 1,
  107. padding: int = 0,
  108. dilation: int = 1,
  109. groups: int = 1,
  110. conv_mode="cross_correlation",
  111. compute_mode="default",
  112. ) -> Tensor:
  113. """1D convolution operation.
  114. Refer to :class:`~.Conv1d` for more information.
  115. :param inp: The feature map of the convolution operation
  116. :param weight: The convolution kernel.
  117. :param bias: The bias added to the result of convolution (if given)
  118. :param stride: Stride of the 1D convolution operation. Default: 1
  119. :param padding: Size of the paddings added to the input on both sides of its
  120. spatial dimensions. Only zero-padding is supported. Default: 0
  121. :param dilation: Dilation of the 1D convolution operation. Default: 1
  122. :param groups: number of groups to divide input and output channels into,
  123. so as to perform a "grouped convolution". When ``groups`` is not 1,
  124. ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
  125. and the shape of weight should be ``(groups, out_channel // groups,
  126. in_channels // groups, kernel_size)``. Default: 1
  127. :type conv_mode: string or :class:`mgb.opr_param_defs.Convolution.Mode`
  128. :param conv_mode: Supports 'cross_correlation'. Default:
  129. 'cross_correlation'.
  130. :type compute_mode: string or
  131. :class:`mgb.opr_param_defs.Convolution.ComputeMode`
  132. :param compute_mode: When set to 'default', no special requirements will be
  133. placed on the precision of intermediate results. When set to 'float32',
  134. float32 would be used for accumulator and intermediate result, but only
  135. effective when input and output are of float16 dtype.
  136. """
  137. assert (
  138. conv_mode.lower() == "cross_correlation"
  139. or conv_mode.name == "CROSS_CORRELATION"
  140. )
  141. assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
  142. assert inp.ndim == 3, "the input dimension of conv1d should be 3"
  143. assert weight.ndim == 3, "the weight dimension of conv1d should be 3"
  144. if amp._enabled:
  145. compute_mode = "float32"
  146. inp, weight, bias = cast_tensors(inp, weight, bias)
  147. inp = expand_dims(inp, 3)
  148. weight = expand_dims(weight, 3)
  149. if bias is not None:
  150. assert bias.ndim == 3, "the bias dimension of conv1d should be 3"
  151. bias = expand_dims(bias, 3)
  152. stride_h = stride
  153. pad_h = padding
  154. dilate_h = dilation
  155. sparse_type = "dense" if groups == 1 else "group"
  156. op = builtin.Convolution(
  157. stride_h=stride_h,
  158. stride_w=1,
  159. pad_h=pad_h,
  160. pad_w=0,
  161. dilate_h=dilate_h,
  162. dilate_w=1,
  163. strategy=get_execution_strategy(),
  164. mode=conv_mode,
  165. compute_mode=compute_mode,
  166. sparse=sparse_type,
  167. )
  168. (output,) = apply(op, inp, weight)
  169. if bias is not None:
  170. output += bias
  171. output = squeeze(output, 3)
  172. return output
  173. def conv2d(
  174. inp: Tensor,
  175. weight: Tensor,
  176. bias: Optional[Tensor] = None,
  177. stride: Union[int, Tuple[int, int]] = 1,
  178. padding: Union[int, Tuple[int, int]] = 0,
  179. dilation: Union[int, Tuple[int, int]] = 1,
  180. groups: int = 1,
  181. conv_mode="cross_correlation",
  182. compute_mode="default",
  183. ) -> Tensor:
  184. """
  185. 2D convolution operation.
  186. Refer to :class:`~.module.Conv2d` for more information.
  187. :param inp: feature map of the convolution operation.
  188. :param weight: convolution kernel.
  189. :param bias: bias added to the result of convolution (if given).
  190. :param stride: stride of the 2D convolution operation. Default: 1
  191. :param padding: size of the paddings added to the input on both sides of its
  192. spatial dimensions. Only zero-padding is supported. Default: 0
  193. :param dilation: dilation of the 2D convolution operation. Default: 1
  194. :param groups: number of groups into which the input and output channels are divided,
  195. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  196. ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
  197. and the shape of weight should be ``(groups, out_channel // groups,
  198. in_channels // groups, height, width)``. Default: 1
  199. :type conv_mode: string or :class:`Convolution.Mode`
  200. :param conv_mode: supports "cross_correlation". Default:
  201. "cross_correlation"
  202. :type compute_mode: string or
  203. :class:`Convolution.ComputeMode`
  204. :param compute_mode: when set to "default", no special requirements will be
  205. placed on the precision of intermediate results. When set to "float32",
  206. "float32" would be used for accumulator and intermediate result, but only
  207. effective when input and output are of float16 dtype.
  208. :return: output tensor.
  209. """
  210. assert (
  211. conv_mode.lower() == "cross_correlation"
  212. or conv_mode.name == "CROSS_CORRELATION"
  213. )
  214. if amp._enabled:
  215. compute_mode = "float32"
  216. inp, weight, bias = cast_tensors(inp, weight, bias)
  217. stride_h, stride_w = expand_hw(stride)
  218. pad_h, pad_w = expand_hw(padding)
  219. dilate_h, dilate_w = expand_hw(dilation)
  220. sparse_type = "dense" if groups == 1 else "group"
  221. op = builtin.Convolution(
  222. stride_h=stride_h,
  223. stride_w=stride_w,
  224. pad_h=pad_h,
  225. pad_w=pad_w,
  226. dilate_h=dilate_h,
  227. dilate_w=dilate_w,
  228. strategy=get_execution_strategy(),
  229. mode=conv_mode,
  230. compute_mode=compute_mode,
  231. sparse=sparse_type,
  232. )
  233. (output,) = apply(op, inp, weight)
  234. if bias is not None:
  235. output += bias
  236. return output
  237. def conv3d(
  238. inp: Tensor,
  239. weight: Tensor,
  240. bias: Optional[Tensor] = None,
  241. stride: Union[int, Tuple[int, int, int]] = 1,
  242. padding: Union[int, Tuple[int, int, int]] = 0,
  243. dilation: Union[int, Tuple[int, int, int]] = 1,
  244. groups: int = 1,
  245. conv_mode: str = "cross_correlation",
  246. ) -> Tensor:
  247. """
  248. 3D convolution operation.
  249. Refer to :class:`~.Conv3d` for more information.
  250. :param inp: feature map of the convolution operation.
  251. :param weight: convolution kernel.
  252. :param bias: bias added to the result of convolution (if given).
  253. :param stride: stride of the 3D convolution operation. Default: 1
  254. :param padding: size of the paddings added to the input on both sides of its
  255. spatial dimensions. Only zero-padding is supported. Default: 0
  256. :param dilation: dilation of the 3D convolution operation. Default: 1
  257. :param groups: number of groups into which the input and output channels are divided,
  258. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  259. ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
  260. and the shape of weight should be ``(groups, out_channel // groups,
  261. in_channels // groups, depth, height, width)``. Default: 1
  262. :param conv_mode: supports "cross_correlation". Default:
  263. "cross_correlation"
  264. :return: output tensor.
  265. """
  266. assert conv_mode.lower() == "cross_correlation"
  267. D, H, W = 0, 1, 2
  268. pad = _triple(padding)
  269. stride = _triple_nonzero(stride)
  270. dilate = _triple_nonzero(dilation)
  271. sparse_type = "dense" if groups == 1 else "group"
  272. op = builtin.Convolution3D(
  273. pad_d=pad[D],
  274. pad_h=pad[H],
  275. pad_w=pad[W],
  276. stride_d=stride[D],
  277. stride_h=stride[H],
  278. stride_w=stride[W],
  279. dilate_d=dilate[D],
  280. dilate_h=dilate[H],
  281. dilate_w=dilate[W],
  282. strategy=get_execution_strategy(),
  283. mode=conv_mode,
  284. sparse=sparse_type,
  285. )
  286. (output,) = apply(op, inp, weight)
  287. if bias is not None:
  288. output += bias
  289. return output
  290. def conv_transpose2d(
  291. inp: Tensor,
  292. weight: Tensor,
  293. bias: Optional[Tensor] = None,
  294. stride: Union[int, Tuple[int, int]] = 1,
  295. padding: Union[int, Tuple[int, int]] = 0,
  296. dilation: Union[int, Tuple[int, int]] = 1,
  297. groups: int = 1,
  298. conv_mode="cross_correlation",
  299. compute_mode="default",
  300. ) -> Tensor:
  301. """
  302. 2D transposed convolution operation.
  303. Refer to :class:`~.ConvTranspose2d` for more information.
  304. :param inp: feature map of the convolution operation.
  305. :param weight: convolution kernel.
  306. :param bias: bias added to the result of convolution (if given).
  307. :param stride: stride of the 2D convolution operation. Default: 1
  308. :param padding: size of the paddings added to the input on both sides of its
  309. spatial dimensions. Only zero-padding is supported. Default: 0
  310. :param dilation: dilation of the 2D convolution operation. Default: 1
  311. :param groups: number of groups into which the input and output channels are divided,
  312. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  313. ``in_channels`` and ``out_channels`` must be divisible by groups,
  314. and the shape of weight should be ``(groups, in_channels // groups,
  315. out_channels // groups, height, width)``. Default: 1
  316. :type conv_mode: string or :class:`Convolution.Mode`
  317. :param conv_mode: supports "cross_correlation". Default:
  318. "cross_correlation"
  319. :type compute_mode: string or
  320. :class:`Convolution.ComputeMode`
  321. :param compute_mode: when set to "default", no special requirements will be
  322. placed on the precision of intermediate results. When set to "float32",
  323. "float32" would be used for accumulator and intermediate result, but only
  324. effective when input and output are of float16 dtype.
  325. :return: output tensor.
  326. """
  327. assert (
  328. conv_mode.lower() == "cross_correlation"
  329. or conv_mode.name == "CROSS_CORRELATION"
  330. )
  331. if amp._enabled:
  332. compute_mode = "float32"
  333. inp, weight, bias = cast_tensors(inp, weight, bias)
  334. if groups != 1:
  335. raise NotImplementedError("group transposed conv2d is not supported yet.")
  336. stride_h, stride_w = expand_hw(stride)
  337. pad_h, pad_w = expand_hw(padding)
  338. dilate_h, dilate_w = expand_hw(dilation)
  339. op = builtin.ConvolutionBackwardData(
  340. stride_h=stride_h,
  341. stride_w=stride_w,
  342. pad_h=pad_h,
  343. pad_w=pad_w,
  344. dilate_h=dilate_h,
  345. dilate_w=dilate_w,
  346. strategy=get_execution_strategy(),
  347. compute_mode=compute_mode,
  348. )
  349. (output,) = apply(op, weight, inp)
  350. if bias is not None:
  351. output += bias
  352. return output
  353. def deformable_conv2d(
  354. inp: Tensor,
  355. weight: Tensor,
  356. offset: Tensor,
  357. mask: Tensor,
  358. bias: Optional[Tensor] = None,
  359. stride: Union[int, Tuple[int, int]] = 1,
  360. padding: Union[int, Tuple[int, int]] = 0,
  361. dilation: Union[int, Tuple[int, int]] = 1,
  362. groups: int = 1,
  363. conv_mode="cross_correlation",
  364. compute_mode="default",
  365. ) -> Tensor:
  366. """
  367. Deformable Convolution.
  368. :param inp: input feature map.
  369. :param weight: convolution kernel.
  370. :param offset: input offset to kernel, channel of this tensor should match the deformable settings.
  371. :param mask: input mask to kernel, channel of this tensor should match the deformable settings.
  372. :param bias: bias added to the result of convolution (if given).
  373. :param stride: stride of the 2D convolution operation. Default: 1
  374. :param padding: size of the paddings added to the input on both sides of its
  375. spatial dimensions. Only zero-padding is supported. Default: 0
  376. :param dilation: dilation of the 2D convolution operation. Default: 1
  377. :param groups: number of groups into which the input and output channels are divided,
  378. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  379. ``in_channels`` and ``out_channels`` must be divisible by groups,
  380. and the shape of weight should be ``(groups, out_channel // groups,
  381. in_channels // groups, height, width)``. Default: 1
  382. :type conv_mode: string or :class:`Convolution.Mode`
  383. :param conv_mode: supports "cross_correlation". Default:
  384. "cross_correlation"
  385. :type compute_mode: string or
  386. :class:`Convolution.ComputeMode`
  387. :param compute_mode: when set to "default", no special requirements will be
  388. placed on the precision of intermediate results. When set to "float32",
  389. "float32" would be used for accumulator and intermediate result, but only
  390. effective when input and output are of float16 dtype.
  391. :return: output tensor.
  392. """
  393. assert (
  394. conv_mode.lower() == "cross_correlation"
  395. or conv_mode.name == "CROSS_CORRELATION"
  396. )
  397. if amp._enabled:
  398. compute_mode = "float32"
  399. inp, weight, offset, mask, bias = cast_tensors(inp, weight, offset, mask, bias)
  400. else:
  401. offset = offset.astype("float32")
  402. mask = mask.astype("float32")
  403. stride_h, stride_w = expand_hw(stride)
  404. pad_h, pad_w = expand_hw(padding)
  405. dilate_h, dilate_w = expand_hw(dilation)
  406. sparse_type = "dense" if groups == 1 else "group"
  407. op = builtin.DeformableConv(
  408. stride_h=stride_h,
  409. stride_w=stride_w,
  410. pad_h=pad_h,
  411. pad_w=pad_w,
  412. dilate_h=dilate_h,
  413. dilate_w=dilate_w,
  414. strategy=get_execution_strategy(),
  415. mode=conv_mode,
  416. compute_mode=compute_mode,
  417. sparse=sparse_type,
  418. )
  419. (output,) = apply(op, inp, weight, offset, mask)
  420. if bias is not None:
  421. output += bias
  422. return output
  423. def local_conv2d(
  424. inp: Tensor,
  425. weight: Tensor,
  426. bias: Optional[Tensor] = None,
  427. stride: Union[int, Tuple[int, int]] = 1,
  428. padding: Union[int, Tuple[int, int]] = 0,
  429. dilation: Union[int, Tuple[int, int]] = 1,
  430. conv_mode="cross_correlation",
  431. ):
  432. """Applies spatial 2D convolution over an groupped channeled image with untied kernels."""
  433. assert (
  434. conv_mode.lower() == "cross_correlation"
  435. or conv_mode.name == "CROSS_CORRELATION"
  436. )
  437. stride_h, stride_w = expand_hw(stride)
  438. pad_h, pad_w = expand_hw(padding)
  439. dilate_h, dilate_w = expand_hw(dilation)
  440. op = builtin.GroupLocal(
  441. stride_h=stride_h,
  442. stride_w=stride_w,
  443. pad_h=pad_h,
  444. pad_w=pad_w,
  445. dilate_h=dilate_h,
  446. dilate_w=dilate_w,
  447. mode=conv_mode,
  448. sparse="dense",
  449. )
  450. (output,) = apply(op, inp, weight)
  451. if bias is not None:
  452. output += bias
  453. return output
  454. def conv_transpose3d(
  455. inp: Tensor,
  456. weight: Tensor,
  457. bias: Optional[Tensor] = None,
  458. stride: Union[int, Tuple[int, int, int]] = 1,
  459. padding: Union[int, Tuple[int, int, int]] = 0,
  460. dilation: Union[int, Tuple[int, int, int]] = 1,
  461. ) -> Tensor:
  462. """
  463. 3D transposed convolution operation. Only support the case that groups = 1
  464. and conv_mode = "cross_correlation".
  465. Refer to :class:`~.ConvTranspose3d` for more information.
  466. :param inp: feature map of the convolution operation.
  467. :param weight: convolution kernel.
  468. weight usually has shape ``(in_channels, out_channels, depth, height, width)``.
  469. :param bias: bias added to the result of convolution (if given).
  470. :param stride: stride of the 3D convolution operation. Default: 1
  471. :param padding: size of the paddings added to the input on all sides of its
  472. spatial dimensions. Only zero-padding is supported. Default: 0
  473. :param dilation: dilation of the 3D convolution operation. Default: 1
  474. :return: output tensor.
  475. """
  476. D, H, W = 0, 1, 2
  477. pad = _triple(padding)
  478. stride = _triple_nonzero(stride)
  479. dilate = _triple_nonzero(dilation)
  480. op = builtin.Convolution3DBackwardData(
  481. pad_d=pad[D],
  482. pad_h=pad[H],
  483. pad_w=pad[W],
  484. stride_d=stride[D],
  485. stride_h=stride[H],
  486. stride_w=stride[W],
  487. dilate_d=dilate[D],
  488. dilate_h=dilate[H],
  489. dilate_w=dilate[W],
  490. strategy=get_execution_strategy(),
  491. )
  492. (output,) = apply(op, weight, inp)
  493. if bias is not None:
  494. output += bias
  495. return output
  496. def max_pool2d(
  497. inp: Tensor,
  498. kernel_size: Union[int, Tuple[int, int]],
  499. stride: Optional[Union[int, Tuple[int, int]]] = None,
  500. padding: Union[int, Tuple[int, int]] = 0,
  501. ) -> Tensor:
  502. """
  503. Applies a 2D max pooling over an input tensor.
  504. Refer to :class:`~.MaxPool2d` for more information.
  505. :param inp: input tensor.
  506. :param kernel_size: size of the window.
  507. :param stride: stride of the window. If not provided, its value is set to kernel_size.
  508. Default: None
  509. :param padding: implicit zero padding added on both sides. Default: 0
  510. :return: output tensor.
  511. """
  512. if stride is None:
  513. stride = kernel_size
  514. window_h, window_w = _pair_nonzero(kernel_size)
  515. stride_h, stride_w = _pair_nonzero(stride)
  516. padding_h, padding_w = _pair(padding)
  517. op = builtin.Pooling(
  518. window_h=window_h,
  519. window_w=window_w,
  520. stride_h=stride_h,
  521. stride_w=stride_w,
  522. pad_h=padding_h,
  523. pad_w=padding_w,
  524. mode="max",
  525. )
  526. (output,) = apply(op, inp)
  527. return output
  528. def avg_pool2d(
  529. inp: Tensor,
  530. kernel_size: Union[int, Tuple[int, int]],
  531. stride: Optional[Union[int, Tuple[int, int]]] = None,
  532. padding: Union[int, Tuple[int, int]] = 0,
  533. mode: str = "average_count_exclude_padding",
  534. ) -> Tensor:
  535. """
  536. Applies 2D average pooling over an input tensor.
  537. Refer to :class:`~.AvgPool2d` for more information.
  538. :param inp: input tensor.
  539. :param kernel_size: size of the window.
  540. :param stride: stride of the window. If not provided, its value is set to ``kernel_size``.
  541. Default: None
  542. :param padding: implicit zero padding added on both sides. Default: 0
  543. :param mode: whether to count padding values, set to "average" will do counting.
  544. Default: "average_count_exclude_padding"
  545. :return: output tensor.
  546. """
  547. if stride is None:
  548. stride = kernel_size
  549. window_h, window_w = _pair_nonzero(kernel_size)
  550. stride_h, stride_w = _pair_nonzero(stride)
  551. padding_h, padding_w = _pair(padding)
  552. op = builtin.Pooling(
  553. window_h=window_h,
  554. window_w=window_w,
  555. stride_h=stride_h,
  556. stride_w=stride_w,
  557. pad_h=padding_h,
  558. pad_w=padding_w,
  559. mode=mode,
  560. )
  561. (output,) = apply(op, inp)
  562. return output
  563. def adaptive_max_pool2d(
  564. inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor],
  565. ) -> Tensor:
  566. """
  567. Applies a 2D max adaptive pooling over an input.
  568. Refer to :class:`~.MaxAdaptivePool2d` for more information.
  569. :param inp: input tensor.
  570. :param oshp: `(OH, OW)` size of the output shape.
  571. :return: output tensor.
  572. """
  573. if isinstance(oshp, int):
  574. oshp = (oshp, oshp)
  575. op = builtin.AdaptivePooling(mode="max", format="NCHW",)
  576. oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device)
  577. (output,) = apply(op, inp, oshp)
  578. return output
  579. def adaptive_avg_pool2d(
  580. inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor],
  581. ) -> Tensor:
  582. """
  583. Applies a 2D average adaptive pooling over an input.
  584. Refer to :class:`~.AvgAdaptivePool2d` for more information.
  585. :param inp: input tensor.
  586. :param oshp: `(OH, OW)` size of the output shape.
  587. :return: output tensor.
  588. """
  589. if isinstance(oshp, int):
  590. oshp = (oshp, oshp)
  591. op = builtin.AdaptivePooling(mode="average", format="NCHW",)
  592. oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device)
  593. (output,) = apply(op, inp, oshp)
  594. return output
  595. def deformable_psroi_pooling(
  596. inp: Tensor,
  597. rois: Tensor,
  598. trans: Tensor,
  599. no_trans: bool,
  600. part_size: int,
  601. pooled_h: int,
  602. pooled_w: int,
  603. sample_per_part: int,
  604. spatial_scale: float,
  605. trans_std: float = 0.1,
  606. ):
  607. """
  608. Deformable PSROI(Position Sensitive Region of Interest) Pooling.
  609. :param inp: input feature map.
  610. :param rois: the rois for feature pooling.
  611. :param trans: input offset to psroi_pooling.
  612. :param no_trans: check the phase of DeformablePSROIPooling. False to the
  613. 1st phase, True to the 2nd phase.
  614. :param part_size: part size.
  615. :param sample_per_part: sample points of each part.
  616. :param pooled_shape: kernel shape of convolution.
  617. :param spatial_scale: the spatial_scale w.r.t input image.
  618. :param trans_std: multiplier used in 2nd phase.
  619. """
  620. op = builtin.DeformablePSROIPooling(
  621. no_trans=no_trans,
  622. part_size=part_size,
  623. pooled_h=pooled_h,
  624. pooled_w=pooled_w,
  625. sample_per_part=sample_per_part,
  626. spatial_scale=spatial_scale,
  627. trans_std=trans_std,
  628. )
  629. output, _ = apply(op, inp, rois, trans)
  630. return output
  631. def hswish(x):
  632. """
  633. Element-wise `x * relu6(x + 3) / 6`.
  634. :param x: input tensor.
  635. :return: computed tensor.
  636. Example:
  637. .. testcode::
  638. import numpy as np
  639. from megengine import tensor
  640. import megengine.functional as F
  641. x = tensor(np.arange(5).astype(np.float32))
  642. out = F.hswish(x)
  643. print(out.numpy().round(decimals=4))
  644. .. testoutput::
  645. [0. 0.6667 1.6667 3. 4. ]
  646. """
  647. return _elwise(x, mode=Elemwise.Mode.H_SWISH)
  648. def sigmoid(x):
  649. """Element-wise `1 / ( 1 + exp( -x ) )`."""
  650. return _elwise(x, mode=Elemwise.Mode.SIGMOID)
  651. def hsigmoid(x):
  652. """Element-wise `relu6(x + 3) / 6`."""
  653. return relu6(x + 3) / 6
  654. def relu(x):
  655. """Element-wise `max(x, 0)`."""
  656. return _elwise(x, mode=Elemwise.Mode.RELU)
  657. def relu6(x):
  658. """Element-wise `min(max(x, 0), 6)`."""
  659. return minimum(maximum(x, 0), 6)
  660. def prelu(inp: Tensor, weight: Tensor) -> Tensor:
  661. r"""
  662. Applies the element-wise PReLU function.
  663. Refer to :class:`~.PReLU` for more information.
  664. """
  665. return maximum(inp, 0) + weight * minimum(inp, 0)
  666. def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor:
  667. r"""
  668. Applies the element-wise leaky_relu function
  669. Refer to :class:`~.LeakyReLU` for more information.
  670. """
  671. return maximum(inp, 0) + negative_slope * minimum(inp, 0)
  672. def silu(x):
  673. r"""
  674. Applies the element-wise Sigmoid Linear Unit function, i.e. `x * sigmoid(x)`.
  675. """
  676. return _elwise(x, mode=Elemwise.Mode.SILU)
  677. def gelu(x):
  678. r"""
  679. Applies the element-wise function:
  680. .. math::
  681. \text{gelu}(x) = x\Phi(x)
  682. where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.
  683. """
  684. return _elwise(x, mode=Elemwise.Mode.GELU)
  685. def softplus(inp: Tensor) -> Tensor:
  686. r"""
  687. Applies the element-wise function:
  688. .. math::
  689. \text{softplus}(x) = \log(1 + \exp(x))
  690. softplus is a smooth approximation to the ReLU function and can be used
  691. to constrain the output to be always positive.
  692. For numerical stability the implementation follows this transformation:
  693. .. math::
  694. \text{softplus}(x) = \log(1 + \exp(x))
  695. = \log(1 + \exp(-\text{abs}(x))) + \max(x, 0)
  696. = \log1p(\exp(-\text{abs}(x))) + \text{relu}(x)
  697. :param inp: input tensor.
  698. Examples:
  699. .. testcode::
  700. import numpy as np
  701. from megengine import tensor
  702. import megengine.functional as F
  703. x = tensor(np.arange(-3, 3, dtype=np.float32))
  704. y = F.softplus(x)
  705. print(y.numpy().round(decimals=4))
  706. Outputs:
  707. .. testoutput::
  708. [0.0486 0.1269 0.3133 0.6931 1.3133 2.1269]
  709. """
  710. return log1p(exp(-abs(inp))) + relu(inp)
  711. def logsoftmax(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
  712. r"""
  713. Applies the :math:`\log(\text{softmax}(x))` function to an n-dimensional
  714. input tensor. The :math:`\text{logsoftmax}(x)` formulation can be simplified as:
  715. .. math::
  716. \text{logsoftmax}(x_{i}) = \log(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} )
  717. For numerical stability the implementation follows this transformation:
  718. .. math::
  719. \text{logsoftmax}(x)
  720. = \log (\frac{\exp (x)}{\sum_{i}(\exp (x_{i}))})
  721. = x - \log (\sum_{i}(\exp (x_{i})))
  722. = x - \text{logsumexp}(x)
  723. :param inp: input tensor.
  724. :param axis: axis along which :math:`\text{logsoftmax}(x)` will be applied.
  725. Examples:
  726. .. testcode::
  727. import numpy as np
  728. from megengine import tensor
  729. import megengine.functional as F
  730. x = tensor(np.arange(-5, 5, dtype=np.float32)).reshape(2,5)
  731. y = F.logsoftmax(x, axis=1)
  732. print(y.numpy().round(decimals=4))
  733. Outputs:
  734. .. testoutput::
  735. [[-4.4519 -3.4519 -2.4519 -1.4519 -0.4519]
  736. [-4.4519 -3.4519 -2.4519 -1.4519 -0.4519]]
  737. """
  738. return inp - logsumexp(inp, axis, keepdims=True)
  739. def logsigmoid(inp: Tensor) -> Tensor:
  740. r"""
  741. Applies the element-wise function:
  742. .. math::
  743. \text{logsigmoid}(x) = \log(\frac{ 1 }{ 1 + \exp(-x)})
  744. = \log(1/(1 + \exp(-x)))
  745. = - \log(1 + \exp(-x))
  746. = - \text{softplus}(-x)
  747. :param inp: input tensor.
  748. Examples:
  749. .. testcode::
  750. import numpy as np
  751. from megengine import tensor
  752. import megengine.functional as F
  753. x = tensor(np.arange(-5, 5, dtype=np.float32))
  754. y = F.logsigmoid(x)
  755. print(y.numpy().round(decimals=4))
  756. Outputs:
  757. .. testoutput::
  758. [-5.0067 -4.0182 -3.0486 -2.1269 -1.3133 -0.6931 -0.3133 -0.1269 -0.0486
  759. -0.0181]
  760. """
  761. return -softplus(-inp)
  762. def logsumexp(
  763. inp: Tensor, axis: Union[int, Sequence[int]], keepdims: bool = False
  764. ) -> Tensor:
  765. r"""
  766. Calculates the logarithm of the inputs' exponential sum along the given :attr:`axis`.
  767. .. math::
  768. \text{logsumexp}(x)= \log \sum_{j=1}^{n} \exp \left(x_{j}\right)
  769. For numerical stability, the implementation follows this transformation:
  770. .. math::
  771. \text{logsumexp}(x)= \log \sum_{j=1}^{n} \exp \left(x_{j}\right)
  772. = \text{logsumexp}(x)=b+\log \sum_{j=1}^{n} \exp \left(x_{j}-b\right)
  773. where
  774. .. math::
  775. b = \max(x_j)
  776. :param inp: input tensor.
  777. :param axis: axis over which the sum is taken. It could be single axis or list of axes.
  778. :param keepdims: whether to retain :attr:`axis` or not for the output tensor.
  779. Examples:
  780. .. testcode::
  781. import numpy as np
  782. from megengine import tensor
  783. import megengine.functional as F
  784. x = tensor(np.arange(-5, 5, dtype=np.float32)).reshape(2,5)
  785. y = F.logsumexp(x, axis=1, keepdims=False)
  786. print(y.numpy().round(decimals=4))
  787. Outputs:
  788. .. testoutput::
  789. [-0.5481 4.4519]
  790. """
  791. max_value = max(inp.detach(), axis, keepdims=True)
  792. if keepdims:
  793. return max_value + log(sum(exp(inp - max_value), axis, keepdims))
  794. else:
  795. return squeeze(max_value, axis=None) + log(
  796. sum(exp(inp - max_value), axis, keepdims)
  797. )
  798. def _get_softmax_axis(ndim: int) -> int:
  799. if ndim in (0, 1, 3):
  800. return 0
  801. return 1
  802. def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
  803. r"""
  804. Applies a :math:`\text{softmax}(x)` function. :math:`\text{softmax}(x)` is defined as:
  805. .. math::
  806. \text{softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
  807. It is applied to all elements along axis, and rescales elements so that
  808. they stay in the range `[0, 1]` and sum to 1.
  809. See :class:`~megengine.module.activation.Softmax` for more details.
  810. :param inp: input tensor.
  811. :param axis: an axis along which :math:`\text{softmax}(x)` will be applied. By default,
  812. :math:`\text{softmax}(x)` will apply along the highest ranked axis.
  813. Examples:
  814. .. testcode::
  815. import numpy as np
  816. from megengine import tensor
  817. import megengine.functional as F
  818. x = tensor(np.arange(-5, 5, dtype=np.float32)).reshape(2,5)
  819. out = F.softmax(x)
  820. print(out.numpy().round(decimals=4))
  821. Outputs:
  822. .. testoutput::
  823. [[0.0117 0.0317 0.0861 0.2341 0.6364]
  824. [0.0117 0.0317 0.0861 0.2341 0.6364]]
  825. """
  826. if axis is None:
  827. axis = _get_softmax_axis(len(inp.shape))
  828. offset = inp.max(axis=axis, keepdims=True).detach()
  829. cached = exp(inp - offset)
  830. down = sum(cached, axis=axis, keepdims=True)
  831. return cached / down
  832. def batch_norm(
  833. inp: Tensor,
  834. running_mean: Tensor = None,
  835. running_var: Tensor = None,
  836. weight: Optional[Tensor] = None,
  837. bias: Optional[Tensor] = None,
  838. *,
  839. training: bool = False,
  840. momentum: float = 0.9,
  841. eps: float = 1e-5,
  842. inplace: bool = True,
  843. compute_mode="default"
  844. ):
  845. r"""
  846. Applies batch normalization to the input.
  847. Refer to :class:`~.BatchNorm2d` and :class:`~.BatchNorm1d` for more information.
  848. :param inp: input tensor.
  849. :param running_mean: tensor to store running mean.
  850. :param running_var: tensor to store running variance.
  851. :param weight: scaling tensor in the learnable affine parameters.
  852. See :math:`\gamma` in :class:`~.BatchNorm2d`.
  853. :param bias: bias tensor in the learnable affine parameters.
  854. See :math:`\beta` in :class:`~.BatchNorm2d`.
  855. :param training: a boolean value to indicate whether batch norm is performed
  856. in training mode. Default: False
  857. :param momentum: value used for the ``running_mean`` and ``running_var``
  858. computation.
  859. Default: 0.9
  860. :param eps: a value added to the denominator for numerical stability.
  861. Default: 1e-5
  862. :param inplace: whether to update ``running_mean`` and ``running_var`` inplace or return new tensors
  863. Default: True
  864. :return: output tensor.
  865. """
  866. if inp.ndim != 4:
  867. raise NotImplementedError("batch_norm for ndim != 4")
  868. C = inp.shape[1]
  869. def make_full_if_none(x, value):
  870. if x is None:
  871. (x,) = Const(value, dtype=inp.dtype, device=inp.device)()
  872. shape = astensor1d((1, C, 1, 1), inp, dtype="int32", device=inp.device)
  873. (result,) = apply(builtin.Broadcast(), x, shape)
  874. return result
  875. elif x.ndim == 1:
  876. shape = astensor1d((1, C, 1, 1), inp, dtype="int32", device=inp.device)
  877. (result,) = apply(builtin.Reshape(), x, shape)
  878. return result
  879. return x
  880. has_mean = running_mean is not None
  881. has_var = running_var is not None
  882. if not training:
  883. assert has_mean, "running_mean must be provided in inference mode"
  884. assert has_var, "running_var must be provided in inference mode"
  885. if has_mean and running_mean.ndim != 4:
  886. raise ValueError
  887. if has_var and running_var.ndim != 4:
  888. raise ValueError
  889. if amp._enabled:
  890. inp = inp.astype("float16")
  891. weight, bias, running_mean, running_var = cast_tensors(
  892. weight, bias, running_mean, running_var, promote=True
  893. )
  894. weight = make_full_if_none(weight, 1)
  895. bias = make_full_if_none(bias, 0)
  896. if not training:
  897. op = builtin.BatchNorm(
  898. fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim="dim_1c11"
  899. )
  900. ret = apply(op, inp, weight, bias, running_mean, running_var)[-1]
  901. return ret
  902. else:
  903. op = builtin.BatchNorm(
  904. avg_factor=1 - momentum, epsilon=eps, param_dim="dim_1c11"
  905. )
  906. if has_mean or has_var:
  907. running_mean = make_full_if_none(running_mean, 0)
  908. running_var = make_full_if_none(running_var, 1)
  909. new_mean, new_var, _, _, inp = apply(
  910. op, inp, weight, bias, running_mean, running_var
  911. )
  912. if not has_mean:
  913. new_mean = None
  914. if not has_var:
  915. new_var = None
  916. if inplace:
  917. if has_mean:
  918. running_mean[...] = new_mean
  919. if has_var:
  920. running_var[...] = new_var
  921. return inp
  922. else:
  923. return inp, new_mean, new_var
  924. else:
  925. (_, _, inp,) = apply(op, inp, weight, bias)
  926. return inp
  927. def sync_batch_norm(
  928. inp: Tensor,
  929. running_mean: Tensor,
  930. running_var: Tensor,
  931. weight: Optional[Tensor] = None,
  932. bias: Optional[Tensor] = None,
  933. training: bool = False,
  934. momentum: Union[float, Tensor] = 0.9,
  935. eps: float = 1e-5,
  936. eps_mode="additive",
  937. group=WORLD,
  938. ) -> Tensor:
  939. r"""
  940. Applies synchronized batch normalization to the input.
  941. Refer to :class:`~.BatchNorm2d` and :class:`~.BatchNorm1d` for more information.
  942. :param inp: input tensor.
  943. :param running_mean: tensor to store running mean.
  944. :param running_var: tensor to store running variance.
  945. :param weight: scaling tensor in the learnable affine parameters.
  946. See :math:`\gamma` in :class:`~.BatchNorm2d`.
  947. :param bias: bias tensor in the learnable affine parameters.
  948. See :math:`\beta` in :class:`~.BatchNorm2d`.
  949. :param training: a boolean value to indicate whether batch norm is performed
  950. in traning mode. Default: False
  951. :param momentum: value used for the ``running_mean`` and ``running_var``
  952. computation.
  953. Default: 0.9
  954. :param eps: a value added to the denominator for numerical stability.
  955. Default: 1e-5
  956. :param eps_mode: mode of calculation for eps, "max" or "additive".
  957. Default: "additive"
  958. :param group: communication group, caculate mean and variance between this group.
  959. Default: :obj:`~megengine.distributed.WORLD`
  960. :return: output tensor.
  961. """
  962. assert eps_mode.lower() in {"max", "additive"}, "unknown eps_mode: {}".format(
  963. eps_mode
  964. )
  965. _channels = inp.shape[1]
  966. _ndim = inp.ndim
  967. _device = inp.device
  968. _dtype = inp.dtype
  969. _param_shape = (1, _channels) + (1,) * (_ndim - 2)
  970. _reduce_axis = [0] + [i for i in range(2, _ndim)]
  971. if training:
  972. def _sum_on_channel(inp):
  973. return inp.sum(axis=_reduce_axis, keepdims=True)
  974. reduce_size = inp.shape[0]
  975. for i in range(2, _ndim):
  976. reduce_size = reduce_size * inp.shape[i]
  977. channel_x1s = _sum_on_channel(inp)
  978. channel_x2s = _sum_on_channel(inp ** 2)
  979. if is_distributed():
  980. # reduce all nodes' data to calculate mean and variance
  981. reduce_size = broadcast_to(
  982. Tensor(reduce_size).astype(dtype=_dtype), [1] * _ndim
  983. )
  984. stat = concat([reduce_size, channel_x1s, channel_x2s], axis=1)
  985. stat = all_reduce_sum(stat, group)
  986. reduce_size = stat[:, :1].reshape(1)
  987. channel_x1s = stat[:, 1 : 1 + _channels]
  988. channel_x2s = stat[:, 1 + _channels :]
  989. channel_mean = channel_x1s / reduce_size
  990. channel_variance = (
  991. channel_x1s ** 2 / (-reduce_size * reduce_size) + channel_x2s / reduce_size
  992. )
  993. else:
  994. assert running_var is not None and running_mean is not None
  995. channel_variance = running_var.reshape(*_param_shape)
  996. channel_mean = running_mean.reshape(*_param_shape)
  997. invsqrt_channel_variance = (
  998. maximum(channel_variance, eps) if eps_mode == "max" else channel_variance + eps
  999. ) ** -0.5
  1000. if weight is not None:
  1001. weight = weight.reshape(*_param_shape)
  1002. if bias is not None:
  1003. bias = bias.reshape(*_param_shape)
  1004. # outvar = output * weight + bias
  1005. # where output = inp * invsqrt_channel_variance + (
  1006. # -channel_mean * invsqrt_channel_variance
  1007. # )
  1008. # Manually expand output for gopt
  1009. if weight is not None:
  1010. inv_var_wt = invsqrt_channel_variance * weight
  1011. neg_channel_mean = -channel_mean
  1012. if bias is not None:
  1013. outvar = inp * inv_var_wt + (neg_channel_mean * inv_var_wt + bias)
  1014. else:
  1015. outvar = inp * inv_var_wt + neg_channel_mean * inv_var_wt
  1016. else:
  1017. outvar = inp * invsqrt_channel_variance + (
  1018. -channel_mean * invsqrt_channel_variance
  1019. )
  1020. if bias is not None:
  1021. outvar = outvar + bias
  1022. if training and running_var is not None and running_mean is not None:
  1023. running_mean *= momentum
  1024. running_mean += (1 - momentum) * channel_mean
  1025. channel_variance_unbiased = channel_x1s ** 2 / (
  1026. -reduce_size * (reduce_size - 1)
  1027. ) + channel_x2s / (reduce_size - 1)
  1028. running_var *= momentum
  1029. running_var += (1 - momentum) * channel_variance_unbiased
  1030. return outvar
  1031. def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor:
  1032. """
  1033. Returns a new tensor where each of the elements are randomly set to zero
  1034. with probability P = ``drop_prob``. Optionally rescale the output tensor if ``training`` is True.
  1035. :param inp: input tensor.
  1036. :param drop_prob: probability to drop (set to zero) a single element.
  1037. :param training: the default behavior of ``dropout`` during training is to rescale the output,
  1038. then it can be replaced by an :class:`~.Identity` during inference. Default: True
  1039. :return: the output tensor
  1040. Examples:
  1041. .. testcode::
  1042. import numpy as np
  1043. from megengine import tensor
  1044. import megengine.functional as F
  1045. x = tensor(np.ones(10, dtype=np.float32))
  1046. out = F.dropout(x, 1./3.)
  1047. print(out.numpy())
  1048. Outputs:
  1049. .. testoutput::
  1050. :options: +SKIP
  1051. [1.5 1.5 0. 1.5 1.5 1.5 1.5 1.5 1.5 1.5]
  1052. """
  1053. assert 0 <= drop_prob < 1
  1054. rv = uniform(size=inp.shape)
  1055. mask = rv > drop_prob
  1056. inp *= mask.astype(inp.dtype)
  1057. if training:
  1058. inp *= 1 / (1 - drop_prob)
  1059. return inp
  1060. def one_hot(inp: Tensor, num_classes: int) -> Tensor:
  1061. r"""
  1062. Performs one-hot encoding for the input tensor.
  1063. :param inp: input tensor.
  1064. :param num_classes: number of classes denotes the last dimension of the output tensor.
  1065. :return: output tensor.
  1066. Examples:
  1067. .. testcode::
  1068. import numpy as np
  1069. from megengine import tensor
  1070. import megengine.functional as F
  1071. x = tensor(np.arange(1, 4, dtype=np.int32))
  1072. out = F.one_hot(x, num_classes=4)
  1073. print(out.numpy())
  1074. Outputs:
  1075. .. testoutput::
  1076. [[0 1 0 0]
  1077. [0 0 1 0]
  1078. [0 0 0 1]]
  1079. """
  1080. zeros_tensor = zeros(list(inp.shape) + [num_classes], inp.dtype, inp.device)
  1081. ones_tensor = ones(list(inp.shape) + [1], inp.dtype, inp.device)
  1082. op = builtin.IndexingSetOneHot(axis=inp.ndim)
  1083. (result,) = apply(op, zeros_tensor, inp, ones_tensor)
  1084. return result
  1085. def embedding(
  1086. inp: Tensor,
  1087. weight: Tensor,
  1088. padding_idx: Optional[int] = None,
  1089. max_norm: Optional[float] = None,
  1090. norm_type: Optional[float] = None,
  1091. ):
  1092. """
  1093. Applies lookup table for embedding.
  1094. :param inp: tensor with indices.
  1095. :param weight: learnable weights which embeds from.
  1096. :param padding_idx: should be set to None, not supported now.
  1097. :param max_norm: should be set to None, not supported now.
  1098. :param norm_type: should be set to None, not supported now.
  1099. :return: output tensor.
  1100. Refer to :class:`~.Embedding` for more information.
  1101. """
  1102. if padding_idx is not None:
  1103. raise ValueError("Not support padding_idx Now!")
  1104. if max_norm is not None or norm_type is not None:
  1105. raise ValueError("Not support weight normlization Now!")
  1106. dest_shp = list(inp.shape) + [weight.shape[-1]]
  1107. return weight[inp.reshape(-1)].reshape(dest_shp)
  1108. def indexing_one_hot(
  1109. src: Tensor, index: Tensor, axis: int = 1, keepdims=False
  1110. ) -> Tensor:
  1111. r"""
  1112. One-hot indexing for some axes.
  1113. :param src: input tensor.
  1114. :param index: index tensor.
  1115. :param axis: axis on src for which values in index index. Default: 1
  1116. :param keepdims: whether not to remove the axis in result. Default: False
  1117. :return: output tensor.
  1118. Examples:
  1119. .. testcode::
  1120. import megengine.functional as F
  1121. from megengine import tensor
  1122. src = tensor([[1.0, 2.0]])
  1123. index = tensor([0])
  1124. val = F.indexing_one_hot(src, index)
  1125. print(val.numpy())
  1126. Outputs:
  1127. .. testoutput::
  1128. [1.]
  1129. """
  1130. assert isinstance(src, Tensor), "src must be of Tensor type"
  1131. op = builtin.IndexingOneHot(axis=axis)
  1132. index = convert_single_value(index, dtype="int32", device=src.device)
  1133. (result,) = apply(op, src, index)
  1134. if not keepdims:
  1135. result = squeeze(result, axis)
  1136. return result
  1137. def sliding_window(
  1138. inp: Tensor,
  1139. kernel_size: Union[int, Tuple[int, int]],
  1140. padding: Union[int, Tuple[int, int]] = 0,
  1141. stride: Union[int, Tuple[int, int]] = 1,
  1142. dilation: Union[int, Tuple[int, int]] = 1,
  1143. ) -> Tensor:
  1144. """
  1145. Extracts sliding local blocks from a batched input tensor.
  1146. Refer to :class:`~.SlidingWindow` for more information.
  1147. :param inp: input tensor.
  1148. :param kernel_size: size of the window.
  1149. :param padding: implicit zero padding added on both sides of input. Default: 0
  1150. :param stride: stride of the window. Default: 1
  1151. :param dilation: dilation of the window. Default: 1
  1152. :return: output tensor.
  1153. """
  1154. padding_h, padding_w = _pair(padding)
  1155. stride_h, stride_w = _pair_nonzero(stride)
  1156. dilation_h, dilation_w = _pair_nonzero(dilation)
  1157. window_h, window_w = _pair_nonzero(kernel_size)
  1158. op = builtin.Images2Neibs(
  1159. pad_h=padding_h,
  1160. pad_w=padding_w,
  1161. stride_h=stride_h,
  1162. stride_w=stride_w,
  1163. dilate_h=dilation_h,
  1164. dilate_w=dilation_w,
  1165. window_h=window_h,
  1166. window_w=window_w,
  1167. )
  1168. (output,) = apply(op, inp)
  1169. return output
  1170. def sliding_window_transpose(
  1171. inp: Tensor,
  1172. output_size: Union[int, Tuple[int, int]],
  1173. kernel_size: Union[int, Tuple[int, int]],
  1174. padding: Union[int, Tuple[int, int]] = 0,
  1175. stride: Union[int, Tuple[int, int]] = 1,
  1176. dilation: Union[int, Tuple[int, int]] = 1,
  1177. ) -> Tensor:
  1178. """
  1179. Sum over the sliding windows on the corresponding input location.
  1180. Refer to :class:`~.SlidingWindowTranspose` for more information.
  1181. :param inp: input tensor.
  1182. :param output_size: shape of output tensor.
  1183. :param kernel_size: size of the window.
  1184. :param padding: implicit zero padding added on both sides of input. Default: 0
  1185. :param stride: stride of the window. Default: 1
  1186. :param dilation: dilation of the window. Default: 1
  1187. :return: output tensor.
  1188. """
  1189. output_h, output_w = _pair_nonzero(output_size)
  1190. padding_h, padding_w = _pair(padding)
  1191. stride_h, stride_w = _pair_nonzero(stride)
  1192. dilation_h, dilation_w = _pair_nonzero(dilation)
  1193. window_h, window_w = _pair_nonzero(kernel_size)
  1194. expected_h = (
  1195. output_h + 2 * padding_h - dilation_h * (window_h - 1) - 1
  1196. ) // stride_h + 1
  1197. expected_w = (
  1198. output_w + 2 * padding_w - dilation_w * (window_w - 1) - 1
  1199. ) // stride_w + 1
  1200. assert inp.ndim == 6, "the input dimension of sliding_window_transpose should be 6"
  1201. assert (
  1202. inp.shape[2] == expected_h and inp.shape[3] == expected_w
  1203. ), "the input shape and output size do not match"
  1204. op = builtin.SlidingWindowTranspose(
  1205. out_h=output_h,
  1206. out_w=output_w,
  1207. pad_h=padding_h,
  1208. pad_w=padding_w,
  1209. stride_h=stride_h,
  1210. stride_w=stride_w,
  1211. dilate_h=dilation_h,
  1212. dilate_w=dilation_w,
  1213. window_h=window_h,
  1214. window_w=window_w,
  1215. )
  1216. (output,) = apply(op, inp)
  1217. return output
# Backward-compatible aliases for functions that now live in
# ``megengine.functional.vision``. Each name is produced by ``deprecated_func``
# with version "1.3" and the target function name.
# NOTE(review): the warning behavior and the meaning of the trailing ``True``
# argument are defined by ``deprecated_func``, which is not shown here — confirm there.
interpolate = deprecated_func("1.3", "megengine.functional.vision", "interpolate", True)
roi_pooling = deprecated_func("1.3", "megengine.functional.vision", "roi_pooling", True)
roi_align = deprecated_func("1.3", "megengine.functional.vision", "roi_align", True)
nms = deprecated_func("1.3", "megengine.functional.vision", "nms", True)
resize = deprecated_func("1.3", "megengine.functional.vision", "resize", True)
remap = deprecated_func("1.3", "megengine.functional.vision", "remap", True)
nvof = deprecated_func("1.3", "megengine.functional.vision", "nvof", True)
warp_affine = deprecated_func("1.3", "megengine.functional.vision", "warp_affine", True)
warp_perspective = deprecated_func(
    "1.3", "megengine.functional.vision", "warp_perspective", True
)
  1229. from .quantized import conv_bias_activation # isort:skip
  1230. from .loss import * # isort:skip

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。