You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

nn.py 61 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. # pylint: disable=too-many-lines
  10. from functools import lru_cache
  11. from typing import NamedTuple, Optional, Sequence, Tuple, Union
  12. from ..core import _config
  13. from ..core._imperative_rt.core2 import apply, dtype_promotion
  14. from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
  15. from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
  16. from ..core.ops import builtin
  17. from ..core.ops.builtin import (
  18. BatchNorm,
  19. Dimshuffle,
  20. Dropout,
  21. Elemwise,
  22. GetVarShape,
  23. Identity,
  24. Reduce,
  25. Reshape,
  26. TypeCvt,
  27. )
  28. from ..core.ops.special import Const
  29. from ..core.tensor import amp, megbrain_graph
  30. from ..core.tensor.array_method import _elwise_apply
  31. from ..core.tensor.utils import (
  32. astensor1d,
  33. astype,
  34. cast_tensors,
  35. convert_single_value,
  36. make_shape_tuple,
  37. subgraph,
  38. )
  39. from ..device import get_default_device
  40. from ..distributed import WORLD, is_distributed
  41. from ..jit import exclude_from_trace
  42. from ..tensor import Tensor
  43. from ..utils.deprecation import deprecated_func
  44. from ..utils.tuple_function import _pair, _pair_nonzero, _triple, _triple_nonzero
  45. from .debug_param import get_execution_strategy
  46. from .distributed import all_reduce_sum
  47. from .elemwise import _elwise, exp, log, log1p, maximum, minimum
  48. from .math import matmul, max, sum
  49. from .tensor import broadcast_to, concat, expand_dims, ones, squeeze, zeros
  50. __all__ = [
  51. "adaptive_avg_pool2d",
  52. "adaptive_max_pool2d",
  53. "avg_pool2d",
  54. "batch_norm",
  55. "conv1d",
  56. "conv2d",
  57. "conv3d",
  58. "conv_transpose2d",
  59. "conv_transpose3d",
  60. "deformable_conv2d",
  61. "deformable_psroi_pooling",
  62. "dropout",
  63. "embedding",
  64. "gelu",
  65. "hsigmoid",
  66. "hswish",
  67. "indexing_one_hot",
  68. "leaky_relu",
  69. "linear",
  70. "local_conv2d",
  71. "local_response_norm",
  72. "logsigmoid",
  73. "logsumexp",
  74. "logsoftmax",
  75. "max_pool2d",
  76. "one_hot",
  77. "prelu",
  78. "pad",
  79. "relu",
  80. "relu6",
  81. "remap",
  82. "sigmoid",
  83. "sliding_window",
  84. "sliding_window_transpose",
  85. "silu",
  86. "softmax",
  87. "softplus",
  88. "sync_batch_norm",
  89. "warp_affine",
  90. "warp_perspective",
  91. "pixel_shuffle",
  92. ]
  93. def expand_hw(x):
  94. # NOTE: >1d array is accepted, as long as 1 <= size <= 2
  95. try:
  96. x = int(x)
  97. return [x, x]
  98. except (TypeError, ValueError):
  99. pass
  100. h, w = x
  101. return int(h), int(w)
  102. def linear(
  103. inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None, compute_mode="default",
  104. ) -> Tensor:
  105. r"""Applies a linear transformation to the input tensor.
  106. Refer to :class:`~.module.linear.Linear` for more information.
  107. Args:
  108. inp: input tensor with shape `(N, in_features)`.
  109. weight: weight with shape `(out_features, in_features)`.
  110. bias: bias with shape `(out_features,)`. Default: None
  111. """
  112. compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
  113. ret = matmul(inp, weight, transpose_b=True, compute_mode=compute_mode)
  114. if bias is not None:
  115. if amp._enabled:
  116. bias = bias.astype("float16")
  117. ret += bias
  118. return ret
  119. def conv1d(
  120. inp: Tensor,
  121. weight: Tensor,
  122. bias: Optional[Tensor] = None,
  123. stride: int = 1,
  124. padding: int = 0,
  125. dilation: int = 1,
  126. groups: int = 1,
  127. conv_mode="cross_correlation",
  128. compute_mode="default",
  129. ) -> Tensor:
  130. r"""1D convolution operation.
  131. Refer to :class:`~.Conv1d` for more information.
  132. Args:
  133. inp: The feature map of the convolution operation
  134. weight: The convolution kernel.
  135. bias: The bias added to the result of convolution (if given)
  136. stride: Stride of the 1D convolution operation. Default: 1
  137. padding: Size of the paddings added to the input on both sides of its
  138. spatial dimensions. Only zero-padding is supported. Default: 0
  139. dilation: Dilation of the 1D convolution operation. Default: 1
  140. groups: number of groups to divide input and output channels into,
  141. so as to perform a "grouped convolution". When ``groups`` is not 1,
  142. ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
  143. and the shape of weight should be ``(groups, out_channel // groups,
  144. in_channels // groups, kernel_size)``. Default: 1
  145. conv_mode: Supports 'cross_correlation'. Default:
  146. 'cross_correlation'.
  147. compute_mode: When set to 'default', no special requirements will be
  148. placed on the precision of intermediate results. When set to 'float32',
  149. float32 would be used for accumulator and intermediate result, but only
  150. effective when input and output are of float16 dtype.
  151. """
  152. assert (
  153. conv_mode.lower() == "cross_correlation"
  154. or conv_mode.name == "CROSS_CORRELATION"
  155. )
  156. assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
  157. assert inp.ndim == 3, "the input dimension of conv1d should be 3"
  158. assert weight.ndim == 3, "the weight dimension of conv1d should be 3"
  159. if amp._enabled:
  160. compute_mode = "float32"
  161. inp, weight, bias = cast_tensors(inp, weight, bias)
  162. else:
  163. dtype = dtype_promotion(inp, weight)
  164. if inp.dtype != dtype:
  165. inp = inp.astype(dtype)
  166. if weight.dtype != dtype:
  167. weight = weight.astype(dtype)
  168. inp = expand_dims(inp, 3)
  169. weight = expand_dims(weight, 3)
  170. if bias is not None:
  171. assert bias.ndim == 3, "the bias dimension of conv1d should be 3"
  172. bias = expand_dims(bias, 3)
  173. stride_h = stride
  174. pad_h = padding
  175. dilate_h = dilation
  176. compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
  177. conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
  178. sparse_type = "dense" if groups == 1 else "group"
  179. op = builtin.Convolution(
  180. stride_h=stride_h,
  181. stride_w=1,
  182. pad_h=pad_h,
  183. pad_w=0,
  184. dilate_h=dilate_h,
  185. dilate_w=1,
  186. strategy=get_execution_strategy(),
  187. mode=conv_mode,
  188. compute_mode=compute_mode,
  189. sparse=sparse_type,
  190. format=conv_format,
  191. )
  192. (output,) = apply(op, inp, weight)
  193. if bias is not None:
  194. output += bias
  195. output = squeeze(output, 3)
  196. return output
  197. def conv2d(
  198. inp: Tensor,
  199. weight: Tensor,
  200. bias: Optional[Tensor] = None,
  201. stride: Union[int, Tuple[int, int]] = 1,
  202. padding: Union[int, Tuple[int, int]] = 0,
  203. dilation: Union[int, Tuple[int, int]] = 1,
  204. groups: int = 1,
  205. conv_mode="cross_correlation",
  206. compute_mode="default",
  207. ) -> Tensor:
  208. r"""2D convolution operation.
  209. Refer to :class:`~.module.Conv2d` for more information.
  210. Args:
  211. inp: feature map of the convolution operation.
  212. weight: convolution kernel.
  213. bias: bias added to the result of convolution (if given).
  214. stride: stride of the 2D convolution operation. Default: 1
  215. padding: size of the paddings added to the input on both sides of its
  216. spatial dimensions. Only zero-padding is supported. Default: 0
  217. dilation: dilation of the 2D convolution operation. Default: 1
  218. groups: number of groups into which the input and output channels are divided,
  219. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  220. ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
  221. and the shape of weight should be ``(groups, out_channel // groups,
  222. in_channels // groups, height, width)``. Default: 1
  223. conv_mode: supports "cross_correlation". Default: "cross_correlation"
  224. compute_mode: when set to "default", no special requirements will be
  225. placed on the precision of intermediate results. When set to "float32",
  226. "float32" would be used for accumulator and intermediate result, but only
  227. effective when input and output are of float16 dtype.
  228. Returns:
  229. output tensor.
  230. """
  231. assert (
  232. conv_mode.lower() == "cross_correlation"
  233. or conv_mode.name == "CROSS_CORRELATION"
  234. )
  235. if amp._enabled:
  236. compute_mode = "float32"
  237. inp, weight, bias = cast_tensors(inp, weight, bias)
  238. else:
  239. dtype = dtype_promotion(inp, weight)
  240. if inp.dtype != dtype:
  241. inp = inp.astype(dtype)
  242. if weight.dtype != dtype:
  243. weight = weight.astype(dtype)
  244. stride_h, stride_w = expand_hw(stride)
  245. pad_h, pad_w = expand_hw(padding)
  246. dilate_h, dilate_w = expand_hw(dilation)
  247. sparse_type = "dense" if groups == 1 else "group"
  248. compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
  249. conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
  250. op = builtin.Convolution(
  251. stride_h=stride_h,
  252. stride_w=stride_w,
  253. pad_h=pad_h,
  254. pad_w=pad_w,
  255. dilate_h=dilate_h,
  256. dilate_w=dilate_w,
  257. strategy=get_execution_strategy(),
  258. mode=conv_mode,
  259. compute_mode=compute_mode,
  260. sparse=sparse_type,
  261. format=conv_format,
  262. )
  263. (output,) = apply(op, inp, weight)
  264. if bias is not None:
  265. output += bias
  266. return output
  267. def conv3d(
  268. inp: Tensor,
  269. weight: Tensor,
  270. bias: Optional[Tensor] = None,
  271. stride: Union[int, Tuple[int, int, int]] = 1,
  272. padding: Union[int, Tuple[int, int, int]] = 0,
  273. dilation: Union[int, Tuple[int, int, int]] = 1,
  274. groups: int = 1,
  275. conv_mode: str = "cross_correlation",
  276. ) -> Tensor:
  277. r"""3D convolution operation.
  278. Refer to :class:`~.Conv3d` for more information.
  279. Args:
  280. inp: feature map of the convolution operation.
  281. weight: convolution kernel.
  282. bias: bias added to the result of convolution (if given).
  283. stride: stride of the 3D convolution operation. Default: 1
  284. padding: size of the paddings added to the input on both sides of its
  285. spatial dimensions. Only zero-padding is supported. Default: 0
  286. dilation: dilation of the 3D convolution operation. Default: 1
  287. groups: number of groups into which the input and output channels are divided,
  288. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  289. ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
  290. and the shape of weight should be ``(groups, out_channel // groups,
  291. in_channels // groups, depth, height, width)``. Default: 1
  292. conv_mode: supports "cross_correlation". Default: "cross_correlation"
  293. Returns:
  294. output tensor.
  295. """
  296. assert conv_mode.lower() == "cross_correlation"
  297. D, H, W = 0, 1, 2
  298. pad = _triple(padding)
  299. stride = _triple_nonzero(stride)
  300. dilate = _triple_nonzero(dilation)
  301. dtype = dtype_promotion(inp, weight)
  302. if inp.dtype != dtype:
  303. inp = inp.astype(dtype)
  304. if weight.dtype != dtype:
  305. weight = weight.astype(dtype)
  306. sparse_type = "dense" if groups == 1 else "group"
  307. op = builtin.Convolution3D(
  308. pad_d=pad[D],
  309. pad_h=pad[H],
  310. pad_w=pad[W],
  311. stride_d=stride[D],
  312. stride_h=stride[H],
  313. stride_w=stride[W],
  314. dilate_d=dilate[D],
  315. dilate_h=dilate[H],
  316. dilate_w=dilate[W],
  317. strategy=get_execution_strategy(),
  318. mode=conv_mode,
  319. sparse=sparse_type,
  320. )
  321. (output,) = apply(op, inp, weight)
  322. if bias is not None:
  323. output += bias
  324. return output
  325. def conv_transpose2d(
  326. inp: Tensor,
  327. weight: Tensor,
  328. bias: Optional[Tensor] = None,
  329. stride: Union[int, Tuple[int, int]] = 1,
  330. padding: Union[int, Tuple[int, int]] = 0,
  331. dilation: Union[int, Tuple[int, int]] = 1,
  332. groups: int = 1,
  333. conv_mode="cross_correlation",
  334. compute_mode="default",
  335. ) -> Tensor:
  336. r"""2D transposed convolution operation.
  337. Refer to :class:`~.module.conv.ConvTranspose2d` for more information.
  338. Args:
  339. inp: feature map of the convolution operation.
  340. weight: convolution kernel.
  341. weight usually has shape ``(in_channels, out_channels, height, width)``.
  342. bias: bias added to the result of convolution (if given).
  343. stride: stride of the 2D convolution operation. Default: 1
  344. padding: size of the paddings added to the input on both sides of its
  345. spatial dimensions. Only zero-padding is supported. Default: 0
  346. dilation: dilation of the 2D convolution operation. Default: 1
  347. groups: number of groups into which the input and output channels are divided,
  348. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  349. ``in_channels`` and ``out_channels`` must be divisible by groups,
  350. and the shape of weight should be ``(groups, in_channels // groups,
  351. out_channels // groups, height, width)``. Default: 1
  352. conv_mode: supports "cross_correlation". Default: "cross_correlation"
  353. compute_mode: when set to "default", no special requirements will be
  354. placed on the precision of intermediate results. When set to "float32",
  355. "float32" would be used for accumulator and intermediate result, but only
  356. effective when input and output are of float16 dtype.
  357. Returns:
  358. output tensor.
  359. """
  360. assert (
  361. conv_mode.lower() == "cross_correlation"
  362. or conv_mode.name == "CROSS_CORRELATION"
  363. )
  364. if amp._enabled:
  365. compute_mode = "float32"
  366. inp, weight, bias = cast_tensors(inp, weight, bias)
  367. else:
  368. dtype = dtype_promotion(inp, weight)
  369. if inp.dtype != dtype:
  370. inp = inp.astype(dtype)
  371. if weight.dtype != dtype:
  372. weight = weight.astype(dtype)
  373. stride_h, stride_w = expand_hw(stride)
  374. pad_h, pad_w = expand_hw(padding)
  375. dilate_h, dilate_w = expand_hw(dilation)
  376. compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
  377. sparse_type = "dense" if groups == 1 else "group"
  378. op = builtin.ConvolutionBackwardData(
  379. stride_h=stride_h,
  380. stride_w=stride_w,
  381. pad_h=pad_h,
  382. pad_w=pad_w,
  383. dilate_h=dilate_h,
  384. dilate_w=dilate_w,
  385. strategy=get_execution_strategy(),
  386. compute_mode=compute_mode,
  387. sparse=sparse_type,
  388. )
  389. (output,) = apply(op, weight, inp)
  390. if bias is not None:
  391. output += bias
  392. return output
  393. def deformable_conv2d(
  394. inp: Tensor,
  395. weight: Tensor,
  396. offset: Tensor,
  397. mask: Tensor,
  398. bias: Optional[Tensor] = None,
  399. stride: Union[int, Tuple[int, int]] = 1,
  400. padding: Union[int, Tuple[int, int]] = 0,
  401. dilation: Union[int, Tuple[int, int]] = 1,
  402. groups: int = 1,
  403. conv_mode="cross_correlation",
  404. compute_mode="default",
  405. ) -> Tensor:
  406. r"""Deformable Convolution.
  407. Args:
  408. inp: input feature map.
  409. weight: convolution kernel.
  410. weight usually has shape ``(out_channels, in_channels, height, width)``.
  411. offset: input offset to kernel, channel of this tensor should match the deformable settings.
  412. mask: input mask to kernel, channel of this tensor should match the deformable settings.
  413. bias: bias added to the result of convolution (if given).
  414. stride: stride of the 2D convolution operation. Default: 1
  415. padding: size of the paddings added to the input on both sides of its
  416. spatial dimensions. Only zero-padding is supported. Default: 0
  417. dilation: dilation of the 2D convolution operation. Default: 1
  418. groups: number of groups into which the input and output channels are divided,
  419. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  420. ``in_channels`` and ``out_channels`` must be divisible by groups,
  421. and the shape of weight should be ``(groups, out_channel // groups,
  422. in_channels // groups, height, width)``. Default: 1
  423. conv_mode: supports "cross_correlation". Default: "cross_correlation"
  424. compute_mode: when set to "default", no special requirements will be
  425. placed on the precision of intermediate results. When set to "float32",
  426. "float32" would be used for accumulator and intermediate result, but only
  427. effective when input and output are of float16 dtype.
  428. Returns:
  429. output tensor.
  430. """
  431. assert (
  432. conv_mode.lower() == "cross_correlation"
  433. or conv_mode.name == "CROSS_CORRELATION"
  434. )
  435. if amp._enabled:
  436. compute_mode = "float32"
  437. inp, weight, offset, mask, bias = cast_tensors(inp, weight, offset, mask, bias)
  438. else:
  439. offset = offset.astype("float32")
  440. mask = mask.astype("float32")
  441. stride_h, stride_w = expand_hw(stride)
  442. pad_h, pad_w = expand_hw(padding)
  443. dilate_h, dilate_w = expand_hw(dilation)
  444. compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
  445. sparse_type = "dense" if groups == 1 else "group"
  446. op = builtin.DeformableConv(
  447. stride_h=stride_h,
  448. stride_w=stride_w,
  449. pad_h=pad_h,
  450. pad_w=pad_w,
  451. dilate_h=dilate_h,
  452. dilate_w=dilate_w,
  453. strategy=get_execution_strategy(),
  454. mode=conv_mode,
  455. compute_mode=compute_mode,
  456. sparse=sparse_type,
  457. )
  458. (output,) = apply(op, inp, weight, offset, mask)
  459. if bias is not None:
  460. output += bias
  461. return output
  462. def local_conv2d(
  463. inp: Tensor,
  464. weight: Tensor,
  465. bias: Optional[Tensor] = None,
  466. stride: Union[int, Tuple[int, int]] = 1,
  467. padding: Union[int, Tuple[int, int]] = 0,
  468. dilation: Union[int, Tuple[int, int]] = 1,
  469. conv_mode="cross_correlation",
  470. ):
  471. r"""Applies a spatial convolution with untied kernels over an groupped channeled input 4D tensor.
  472. It is also known as the locally connected layer.
  473. Args:
  474. inp: input feature map.
  475. weight: convolution kernel.
  476. weight usually has shape ``(out_channels, in_channels, height, width)``.
  477. bias: bias added to the result of convolution (if given).
  478. stride: stride of the 2D convolution operation. Default: 1
  479. padding: size of the paddings added to the input on both sides of its
  480. spatial dimensions. Only zero-padding is supported. Default: 0
  481. dilation: dilation of the 2D convolution operation. Default: 1
  482. Returns:
  483. output tensor.
  484. """
  485. assert (
  486. conv_mode.lower() == "cross_correlation"
  487. or conv_mode.name == "CROSS_CORRELATION"
  488. )
  489. stride_h, stride_w = expand_hw(stride)
  490. pad_h, pad_w = expand_hw(padding)
  491. dilate_h, dilate_w = expand_hw(dilation)
  492. dtype = dtype_promotion(inp, weight)
  493. if inp.dtype != dtype:
  494. inp = inp.astype(dtype)
  495. if weight.dtype != dtype:
  496. weight = weight.astype(dtype)
  497. # local conv only support "dense" mode, but weight could contain group dimension.
  498. op = builtin.GroupLocal(
  499. stride_h=stride_h,
  500. stride_w=stride_w,
  501. pad_h=pad_h,
  502. pad_w=pad_w,
  503. dilate_h=dilate_h,
  504. dilate_w=dilate_w,
  505. mode=conv_mode,
  506. sparse="dense",
  507. )
  508. (output,) = apply(op, inp, weight)
  509. if bias is not None:
  510. output += bias
  511. return output
  512. def conv_transpose3d(
  513. inp: Tensor,
  514. weight: Tensor,
  515. bias: Optional[Tensor] = None,
  516. stride: Union[int, Tuple[int, int, int]] = 1,
  517. padding: Union[int, Tuple[int, int, int]] = 0,
  518. dilation: Union[int, Tuple[int, int, int]] = 1,
  519. groups: int = 1,
  520. ) -> Tensor:
  521. r"""3D transposed convolution operation. Only support the case that groups = 1
  522. and conv_mode = "cross_correlation".
  523. Refer to :class:`~.ConvTranspose3d` for more information.
  524. Args:
  525. inp: feature map of the convolution operation.
  526. weight: convolution kernel.
  527. weight usually has shape ``(in_channels, out_channels, depth, height, width)``.
  528. bias: bias added to the result of convolution (if given).
  529. stride: stride of the 3D convolution operation. Default: 1
  530. padding: size of the paddings added to the input on all sides of its
  531. spatial dimensions. Only zero-padding is supported. Default: 0
  532. dilation: dilation of the 3D convolution operation. Default: 1
  533. groups: number of groups into which the input and output channels are divided,
  534. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  535. ``in_channels`` and ``out_channels`` must be divisible by groups,
  536. and the shape of weight should be ``(groups, in_channels // groups,
  537. out_channels // groups, depth, height, width)``. Default: 1
  538. Returns:
  539. output tensor.
  540. """
  541. D, H, W = 0, 1, 2
  542. pad = _triple(padding)
  543. stride = _triple_nonzero(stride)
  544. dilate = _triple_nonzero(dilation)
  545. dtype = dtype_promotion(inp, weight)
  546. if inp.dtype != dtype:
  547. inp = inp.astype(dtype)
  548. if weight.dtype != dtype:
  549. weight = weight.astype(dtype)
  550. sparse_type = "dense" if groups == 1 else "group"
  551. op = builtin.Convolution3DBackwardData(
  552. pad_d=pad[D],
  553. pad_h=pad[H],
  554. pad_w=pad[W],
  555. stride_d=stride[D],
  556. stride_h=stride[H],
  557. stride_w=stride[W],
  558. dilate_d=dilate[D],
  559. dilate_h=dilate[H],
  560. dilate_w=dilate[W],
  561. strategy=get_execution_strategy(),
  562. sparse=sparse_type,
  563. )
  564. (output,) = apply(op, weight, inp)
  565. if bias is not None:
  566. output += bias
  567. return output
  568. def max_pool2d(
  569. inp: Tensor,
  570. kernel_size: Union[int, Tuple[int, int]],
  571. stride: Optional[Union[int, Tuple[int, int]]] = None,
  572. padding: Union[int, Tuple[int, int]] = 0,
  573. ) -> Tensor:
  574. r"""Applies a 2D max pooling over an input tensor.
  575. Refer to :class:`~.MaxPool2d` for more information.
  576. Args:
  577. inp: input tensor.
  578. kernel_size: size of the window.
  579. stride: stride of the window. If not provided, its value is set to kernel_size.
  580. Default: None
  581. padding: implicit zero padding added on both sides. Default: 0
  582. Returns:
  583. output tensor.
  584. """
  585. if stride is None:
  586. stride = kernel_size
  587. window_h, window_w = _pair_nonzero(kernel_size)
  588. stride_h, stride_w = _pair_nonzero(stride)
  589. padding_h, padding_w = _pair(padding)
  590. conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
  591. op = builtin.Pooling(
  592. window_h=window_h,
  593. window_w=window_w,
  594. stride_h=stride_h,
  595. stride_w=stride_w,
  596. pad_h=padding_h,
  597. pad_w=padding_w,
  598. mode="max",
  599. format=conv_format,
  600. )
  601. (output,) = apply(op, inp)
  602. return output
  603. def avg_pool2d(
  604. inp: Tensor,
  605. kernel_size: Union[int, Tuple[int, int]],
  606. stride: Optional[Union[int, Tuple[int, int]]] = None,
  607. padding: Union[int, Tuple[int, int]] = 0,
  608. mode: str = "average_count_exclude_padding",
  609. ) -> Tensor:
  610. r"""Applies 2D average pooling over an input tensor.
  611. Refer to :class:`~.AvgPool2d` for more information.
  612. Args:
  613. inp: input tensor.
  614. kernel_size: size of the window.
  615. stride: stride of the window. If not provided, its value is set to ``kernel_size``.
  616. Default: None
  617. padding: implicit zero padding added on both sides. Default: 0
  618. mode: whether to count padding values, set to "average" will do counting.
  619. Default: "average_count_exclude_padding"
  620. Returns:
  621. output tensor.
  622. """
  623. if stride is None:
  624. stride = kernel_size
  625. window_h, window_w = _pair_nonzero(kernel_size)
  626. stride_h, stride_w = _pair_nonzero(stride)
  627. padding_h, padding_w = _pair(padding)
  628. conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
  629. op = builtin.Pooling(
  630. window_h=window_h,
  631. window_w=window_w,
  632. stride_h=stride_h,
  633. stride_w=stride_w,
  634. pad_h=padding_h,
  635. pad_w=padding_w,
  636. mode=mode,
  637. format=conv_format,
  638. )
  639. (output,) = apply(op, inp)
  640. return output
  641. def adaptive_max_pool2d(
  642. inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor],
  643. ) -> Tensor:
  644. r"""Applies a 2D max adaptive pooling over an input.
  645. Refer to :class:`~.MaxAdaptivePool2d` for more information.
  646. Args:
  647. inp: input tensor.
  648. oshp: OH, OW)` size of the output shape.
  649. Returns:
  650. output tensor.
  651. """
  652. if isinstance(oshp, int):
  653. oshp = (oshp, oshp)
  654. conv_format = _config._get_actual_op_param("NCHW", _config.__conv_format)
  655. op = builtin.AdaptivePooling(mode="max", format=conv_format,)
  656. oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device)
  657. (output,) = apply(op, inp, oshp)
  658. return output
  659. def adaptive_avg_pool2d(
  660. inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor],
  661. ) -> Tensor:
  662. r"""Applies a 2D average adaptive pooling over an input.
  663. Refer to :class:`~.AvgAdaptivePool2d` for more information.
  664. Args:
  665. inp: input tensor.
  666. oshp: OH, OW)` size of the output shape.
  667. Returns:
  668. output tensor.
  669. """
  670. if isinstance(oshp, int):
  671. oshp = (oshp, oshp)
  672. op = builtin.AdaptivePooling(mode="average", format="NCHW",)
  673. oshp = astensor1d(oshp, inp, dtype="int32", device=inp.device)
  674. (output,) = apply(op, inp, oshp)
  675. return output
  676. def deformable_psroi_pooling(
  677. inp: Tensor,
  678. rois: Tensor,
  679. trans: Tensor,
  680. no_trans: bool,
  681. part_size: int,
  682. pooled_h: int,
  683. pooled_w: int,
  684. sample_per_part: int,
  685. spatial_scale: float,
  686. trans_std: float = 0.1,
  687. ):
  688. r"""Deformable PSROI(Position Sensitive Region of Interest) Pooling.
  689. Args:
  690. inp: input feature map.
  691. rois: the rois for feature pooling.
  692. trans: input offset to psroi_pooling.
  693. no_trans: check the phase of DeformablePSROIPooling. False to the
  694. 1st phase, True to the 2nd phase.
  695. part_size: part size.
  696. sample_per_part: sample points of each part.
  697. pooled_shape: kernel shape of convolution.
  698. spatial_scale: the spatial_scale w.r.t input image.
  699. trans_std: multiplier used in 2nd phase.
  700. """
  701. op = builtin.DeformablePSROIPooling(
  702. no_trans=no_trans,
  703. part_size=part_size,
  704. pooled_h=pooled_h,
  705. pooled_w=pooled_w,
  706. sample_per_part=sample_per_part,
  707. spatial_scale=spatial_scale,
  708. trans_std=trans_std,
  709. )
  710. output, _ = apply(op, inp, rois, trans)
  711. return output
  712. def hswish(x):
  713. r"""Element-wise `x * relu6(x + 3) / 6`.
  714. Example:
  715. .. testcode::
  716. import numpy as np
  717. from megengine import tensor
  718. import megengine.functional as F
  719. x = tensor(np.arange(5).astype(np.float32))
  720. out = F.hswish(x)
  721. print(out.numpy().round(decimals=4))
  722. .. testoutput::
  723. [0. 0.6667 1.6667 3. 4. ]
  724. """
  725. return _elwise(x, mode=Elemwise.Mode.H_SWISH)
  726. def sigmoid(x):
  727. r"""Element-wise `1 / ( 1 + exp( -x ) )`."""
  728. return _elwise(x, mode=Elemwise.Mode.SIGMOID)
  729. def hsigmoid(x):
  730. r"""Element-wise `relu6(x + 3) / 6`."""
  731. return relu6(x + 3) / 6
  732. def relu(x):
  733. r"""Element-wise `max(x, 0)`."""
  734. return _elwise(x, mode=Elemwise.Mode.RELU)
  735. def relu6(x):
  736. r"""Element-wise `min(max(x, 0), 6)`."""
  737. return minimum(maximum(x, 0), 6)
  738. def prelu(inp: Tensor, weight: Tensor) -> Tensor:
  739. r"""Elememt-wise PReLU function.
  740. Refer to :class:`~.PReLU` for more information.
  741. """
  742. return maximum(inp, 0) + weight * minimum(inp, 0)
  743. def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor:
  744. r"""Element-wose LeakyReLU function
  745. Refer to :class:`~.LeakyReLU` for more information.
  746. """
  747. return maximum(inp, 0) + negative_slope * minimum(inp, 0)
  748. def silu(x):
  749. r"""Applies the element-wise Sigmoid Linear Unit function, i.e. `x * sigmoid(x)`."""
  750. return _elwise(x, mode=Elemwise.Mode.SILU)
  751. def gelu(x):
  752. r"""Applies the element-wise function:
  753. .. math::
  754. \text{gelu}(x) = x\Phi(x)
  755. where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.
  756. """
  757. return _elwise(x, mode=Elemwise.Mode.GELU)
  758. def softplus(inp: Tensor) -> Tensor:
  759. r"""Applies the element-wise function:
  760. .. math::
  761. \text{softplus}(x) = \log(1 + \exp(x))
  762. softplus is a smooth approximation to the ReLU function and can be used
  763. to constrain the output to be always positive.
  764. For numerical stability the implementation follows this transformation:
  765. .. math::
  766. \text{softplus}(x) = \log(1 + \exp(x))
  767. = \log(1 + \exp(-\text{abs}(x))) + \max(x, 0)
  768. = \log1p(\exp(-\text{abs}(x))) + \text{relu}(x)
  769. Examples:
  770. .. testcode::
  771. import numpy as np
  772. from megengine import tensor
  773. import megengine.functional as F
  774. x = tensor(np.arange(-3, 3, dtype=np.float32))
  775. y = F.softplus(x)
  776. print(y.numpy().round(decimals=4))
  777. Outputs:
  778. .. testoutput::
  779. [0.0486 0.1269 0.3133 0.6931 1.3133 2.1269]
  780. """
  781. return log1p(exp(-abs(inp))) + relu(inp)
  782. def logsoftmax(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
  783. r"""Applies the :math:`\log(\text{softmax}(x))` function to an n-dimensional
  784. input tensor. The :math:`\text{logsoftmax}(x)` formulation can be simplified as:
  785. .. math::
  786. \text{logsoftmax}(x_{i}) = \log(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} )
  787. For numerical stability the implementation follows this transformation:
  788. .. math::
  789. \text{logsoftmax}(x)
  790. = \log (\frac{\exp (x)}{\sum_{i}(\exp (x_{i}))})
  791. = x - \log (\sum_{i}(\exp (x_{i})))
  792. = x - \text{logsumexp}(x)
  793. Examples:
  794. .. testcode::
  795. import numpy as np
  796. from megengine import tensor
  797. import megengine.functional as F
  798. x = tensor(np.arange(-5, 5, dtype=np.float32)).reshape(2,5)
  799. y = F.logsoftmax(x, axis=1)
  800. print(y.numpy().round(decimals=4))
  801. Outputs:
  802. .. testoutput::
  803. [[-4.4519 -3.4519 -2.4519 -1.4519 -0.4519]
  804. [-4.4519 -3.4519 -2.4519 -1.4519 -0.4519]]
  805. """
  806. return inp - logsumexp(inp, axis, keepdims=True)
  807. def logsigmoid(inp: Tensor) -> Tensor:
  808. r"""Applies the element-wise function:
  809. .. math::
  810. \text{logsigmoid}(x) = \log(\frac{ 1 }{ 1 + \exp(-x)})
  811. = \log(1/(1 + \exp(-x)))
  812. = - \log(1 + \exp(-x))
  813. = - \text{softplus}(-x)
  814. Examples:
  815. .. testcode::
  816. import numpy as np
  817. from megengine import tensor
  818. import megengine.functional as F
  819. x = tensor(np.arange(-5, 5, dtype=np.float32))
  820. y = F.logsigmoid(x)
  821. print(y.numpy().round(decimals=4))
  822. Outputs:
  823. .. testoutput::
  824. [-5.0067 -4.0182 -3.0486 -2.1269 -1.3133 -0.6931 -0.3133 -0.1269 -0.0486
  825. -0.0181]
  826. """
  827. return -softplus(-inp)
  828. def logsumexp(
  829. inp: Tensor, axis: Union[int, Sequence[int]], keepdims: bool = False
  830. ) -> Tensor:
  831. r"""Calculates the logarithm of the inputs' exponential sum along the given :attr:`axis`.
  832. .. math::
  833. \text{logsumexp}(x)= \log \sum_{j=1}^{n} \exp \left(x_{j}\right)
  834. For numerical stability, the implementation follows this transformation:
  835. .. math::
  836. \text{logsumexp}(x)= \log \sum_{j=1}^{n} \exp \left(x_{j}\right)
  837. = \text{logsumexp}(x)=b+\log \sum_{j=1}^{n} \exp \left(x_{j}-b\right)
  838. where
  839. .. math::
  840. b = \max(x_j)
  841. Examples:
  842. .. testcode::
  843. import numpy as np
  844. from megengine import tensor
  845. import megengine.functional as F
  846. x = tensor(np.arange(-5, 5, dtype=np.float32)).reshape(2,5)
  847. y = F.logsumexp(x, axis=1, keepdims=False)
  848. print(y.numpy().round(decimals=4))
  849. Outputs:
  850. .. testoutput::
  851. [-0.5481 4.4519]
  852. """
  853. max_value = max(inp.detach(), axis, keepdims=True)
  854. if keepdims:
  855. return max_value + log(sum(exp(inp - max_value), axis, keepdims))
  856. else:
  857. return squeeze(max_value, axis=None) + log(
  858. sum(exp(inp - max_value), axis, keepdims)
  859. )
  860. def _get_softmax_axis(ndim: int) -> int:
  861. if ndim in (0, 1, 3):
  862. return 0
  863. return 1
  864. def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
  865. r"""Applies a :math:`\text{softmax}(x)` function. :math:`\text{softmax}(x)` is defined as:
  866. .. math::
  867. \text{softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
  868. It is applied to all elements along axis, and rescales elements so that
  869. they stay in the range `[0, 1]` and sum to 1.
  870. See :class:`~.module.Softmax` for more details.
  871. Examples:
  872. .. testcode::
  873. import numpy as np
  874. from megengine import tensor
  875. import megengine.functional as F
  876. x = tensor(np.arange(-5, 5, dtype=np.float32)).reshape(2,5)
  877. out = F.softmax(x)
  878. print(out.numpy().round(decimals=4))
  879. Outputs:
  880. .. testoutput::
  881. [[0.0117 0.0317 0.0861 0.2341 0.6364]
  882. [0.0117 0.0317 0.0861 0.2341 0.6364]]
  883. """
  884. if axis is None:
  885. axis = _get_softmax_axis(len(inp.shape))
  886. if isinstance(axis, list):
  887. offset = inp.max(axis=axis, keepdims=True).detach()
  888. cached = exp(inp - offset)
  889. down = sum(cached, axis=axis, keepdims=True)
  890. return cached / down
  891. else:
  892. op = builtin.Softmax(axis=axis,)
  893. (output,) = apply(op, inp)
  894. return output
  895. def layer_norm(
  896. inp: Tensor,
  897. normalized_shape: tuple,
  898. affine: bool,
  899. weight: Optional[Tensor] = None,
  900. bias: Optional[Tensor] = None,
  901. eps: float = 1e-5,
  902. ):
  903. r"""Applies layer normalization to the input. Support tensor of any shape as input.
  904. Reference: https://arxiv.org/pdf/1803.08494.pdf.
  905. Args:
  906. inp: input tensor.
  907. normalized_shape: the shape that you want to be normalizated
  908. affine: whether to use weight and bias
  909. weight: must not be None when the affine is true
  910. bias: must not be None when the affine is true
  911. eps: a value added to the denominator for numerical stability. Default: 1e-5
  912. """
  913. if amp._enabled:
  914. inp, weight, bias = cast_tensors(inp, weight, bias, promote=True)
  915. if isinstance(normalized_shape, int):
  916. normalized_shape = [normalized_shape]
  917. normalized_dim = len(normalized_shape)
  918. assert normalized_dim > 0
  919. normalized_size = 1
  920. for i in range(normalized_dim):
  921. normalized_size = normalized_size * normalized_shape[i]
  922. op = builtin.LayerNorm(
  923. affine=affine,
  924. eps=eps,
  925. normalized_dim=normalized_dim,
  926. normalized_size=normalized_size,
  927. )
  928. if affine:
  929. assert weight is not None and bias is not None
  930. return apply(op, inp, weight, bias)[0]
  931. else:
  932. # assert weight is None and bias is None
  933. return apply(op, inp)[0]
  934. def batch_norm(
  935. inp: Tensor,
  936. running_mean: Tensor = None,
  937. running_var: Tensor = None,
  938. weight: Optional[Tensor] = None,
  939. bias: Optional[Tensor] = None,
  940. *,
  941. training: bool = False,
  942. momentum: float = 0.9,
  943. eps: float = 1e-5,
  944. inplace: bool = True,
  945. compute_mode="default",
  946. param_dim="dim_1c11"
  947. ):
  948. r"""Applies batch normalization to the input.
  949. Refer to :class:`~.BatchNorm2d` and :class:`~.BatchNorm1d` for more information.
  950. Args:
  951. inp: input tensor.
  952. running_mean: tensor to store running mean.
  953. running_var: tensor to store running variance.
  954. weight: scaling tensor in the learnable affine parameters.
  955. See :math:`\gamma` in :class:`~.BatchNorm2d`.
  956. bias: bias tensor in the learnable affine parameters.
  957. See :math:`\beta` in :class:`~.BatchNorm2d`.
  958. training: a boolean value to indicate whether batch norm is performed
  959. in training mode. Default: False
  960. momentum: value used for the ``running_mean`` and ``running_var``
  961. computation. Default: 0.9
  962. eps: a value added to the denominator for numerical stability. Default: 1e-5
  963. inplace: whether to update ``running_mean`` and ``running_var``
  964. inplace or return new tensors. Default: True
  965. """
  966. if inp.ndim != 4:
  967. raise NotImplementedError("batch_norm for ndim != 4")
  968. if param_dim == "dim_1c11":
  969. C = inp.shape[1]
  970. pshape = (1, C, 1, 1)
  971. elif param_dim == "dim_111c":
  972. C = inp.shape[3]
  973. pshape = (1, 1, 1, C)
  974. else:
  975. raise ValueError("Invalid param_dim {}".format(param_dim))
  976. def make_full_if_none(x, value):
  977. if x is None:
  978. (x,) = Const(value, dtype=inp.dtype, device=inp.device)()
  979. shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
  980. (result,) = apply(builtin.Broadcast(), x, shape)
  981. return result
  982. elif x.ndim == 1:
  983. shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
  984. (result,) = apply(builtin.Reshape(), x, shape)
  985. return result
  986. return x
  987. has_mean = running_mean is not None
  988. has_var = running_var is not None
  989. if not training:
  990. assert has_mean, "running_mean must be provided in inference mode"
  991. assert has_var, "running_var must be provided in inference mode"
  992. if has_mean and running_mean.ndim != 4:
  993. raise ValueError
  994. if has_var and running_var.ndim != 4:
  995. raise ValueError
  996. if amp._enabled:
  997. inp = inp.astype("float16")
  998. weight, bias, running_mean, running_var = cast_tensors(
  999. weight, bias, running_mean, running_var, promote=True
  1000. )
  1001. weight = make_full_if_none(weight, 1)
  1002. bias = make_full_if_none(bias, 0)
  1003. if not training:
  1004. op = builtin.BatchNorm(
  1005. fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim=param_dim
  1006. )
  1007. ret = apply(op, inp, weight, bias, running_mean, running_var)[-1]
  1008. return ret
  1009. else:
  1010. op = builtin.BatchNorm(
  1011. avg_factor=1 - momentum, epsilon=eps, param_dim=param_dim
  1012. )
  1013. if has_mean or has_var:
  1014. running_mean = make_full_if_none(running_mean, 0)
  1015. running_var = make_full_if_none(running_var, 1)
  1016. new_mean, new_var, *_, inp = apply(
  1017. op, inp, weight, bias, running_mean, running_var
  1018. )
  1019. if not has_mean:
  1020. new_mean = None
  1021. if not has_var:
  1022. new_var = None
  1023. if inplace:
  1024. if has_mean:
  1025. running_mean[...] = new_mean
  1026. if has_var:
  1027. running_var[...] = new_var
  1028. return inp
  1029. else:
  1030. return inp, new_mean, new_var
  1031. else:
  1032. inp = apply(op, inp, weight, bias)[-1]
  1033. return inp
@lru_cache(maxsize=None)
def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels):
    """Build (and memoize per argument tuple) the subgraphs used by sync_batch_norm.

    Returns a 6-tuple:
    (stage0, stage1, stage1_inference, stage2, concat_stats, split_stats).

    Each subgraph body takes ``(inputs, f, c)``: ``f`` applies an op (either an
    op object or a string op name) and ``c`` creates a constant. Each returns
    a pair (outputs, per-output bool flags). NOTE(review): the meaning of the
    flags is defined by ``subgraph`` — confirm there.
    """
    # fmt: off
    @subgraph("SyncBnStage0", dtype, device, 1)
    def syncbn_stage0(inputs, f, c):
        # Per-channel sums of x and x**2, plus the element count per channel.
        input = inputs[0]
        # (1, C, 1, ..., 1): the shape the per-channel reductions target.
        reduce_shape = c((1, channels) + (1,) * (ndim - 2), dtype="int32", device=device)
        input_shape = f(GetVarShape(), input)
        input_elems = f(Reduce(mode="product", axis=0), input_shape)
        reduce_elems = f(Reduce(mode="product", axis=0), reduce_shape)
        # number of input elements contributing to each channel
        reduce_size = f("//", input_elems, reduce_elems)
        channel_x1s = f(Reduce(mode="sum"), input, reduce_shape)
        channel_x2s = f(Reduce(mode="sum_sqr"), input, reduce_shape)
        reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size)
        return (reduce_shape, reduce_size_f, channel_x1s, channel_x2s), (False, False, True, True)

    @subgraph("SyncBnStage1", dtype, device, 7)
    def syncbn_stage1(inputs, f, c):
        # Training forward: normalize with the batch statistics.
        input, reduce_size, channel_x1s, channel_x2s, eps = inputs[0:5]
        weight, bias = inputs[5:7]
        channel_mean = f("/", channel_x1s, reduce_size)
        # Biased variance: x1**2 / (-(n*n)) + x2 / n  ==  E[x^2] - E[x]^2
        channel_var =\
            f("+", f("/", f("**", channel_x1s, c(2)),
                          f("-", f("*", reduce_size, reduce_size))),
                   f("/", channel_x2s, reduce_size))
        # eps_mode ("max" / "additive") is used directly as the op name that
        # combines variance and eps; its mapping is defined by ``subgraph``.
        invsqrt_channel_var = f("**", f(eps_mode, channel_var, eps), c(-0.5))
        inv_var_wt = f("*", invsqrt_channel_var, weight)
        neg_channel_mean = f("-", channel_mean)
        # presumably fma3(a, b, c) == a * b + c (compare the explicit form in
        # syncbn_stage1_inference) — i.e. x * s + (-mean * s + bias).
        outvar =\
            f("fma3", input, inv_var_wt,
                  f("+", f("*", neg_channel_mean, inv_var_wt),
                         bias))
        return (outvar, channel_mean, channel_var, inv_var_wt), (True, False, False, False)

    @subgraph("SyncBnStage1Inference", dtype, device, 6)
    def syncbn_stage1_inference(inputs, f, c):
        # Inference forward: normalize with the supplied (running) statistics.
        input, channel_mean, channel_var, eps = inputs[0:4]
        weight, bias = inputs[4:6]
        invsqrt_channel_var = f("**", f(eps_mode, channel_var, eps), c(-0.5))
        inv_var_wt = f("*", invsqrt_channel_var, weight)
        neg_channel_mean = f("-", channel_mean)
        # x * s + (-mean * s + bias), with s = weight / sqrt(var (+|max) eps)
        outvar =\
            f("+", f("*", input, inv_var_wt),
                   f("+", f("*", neg_channel_mean, inv_var_wt),
                          bias))
        return (outvar,), (True,)

    @subgraph("SyncBnStage2", dtype, device, 7)
    def syncbn_stage2(inputs, f, c):
        # Running-statistics update.
        running_mean, running_var, momentum = inputs[0:3]
        reduce_size, channel_x1s, channel_x2s, channel_mean = inputs[3:7]
        c1_minus_momentum = f("-", c(1), momentum)
        reduce_size_minus_c1 = f("-", reduce_size, c(1))
        # presumably fma4(a, b, c, d) == a * b + c * d, i.e.
        # running_mean * momentum + (1 - momentum) * channel_mean
        running_mean = f("fma4",
            running_mean, momentum,
            c1_minus_momentum, channel_mean,
        )
        # Unbiased variance: x1**2 / (-n * (n - 1)) + x2 / (n - 1)
        channel_variance_unbiased =\
            f("+", f("/", f("**", channel_x1s, c(2)),
                          f("*", f("-", reduce_size),
                                 reduce_size_minus_c1)),
                   f("/", channel_x2s,
                          reduce_size_minus_c1))
        running_var = f("fma4",
            running_var, momentum,
            c1_minus_momentum, channel_variance_unbiased
        )
        return (running_mean, running_var), (True, True)

    @subgraph("SyncBnConcatStats", dtype, device, 3)
    def syncbn_concat_stats(inputs, f, c):
        # Pack [n | sum(x) per channel | sum(x^2) per channel] along axis 1 so
        # a single collective call can reduce all statistics at once.
        reduce_size, channel_x1s, channel_x2s = inputs[0:3]
        reduce_size = f(builtin.Broadcast(), reduce_size, c([1]*ndim, dtype="int32"))
        stats = f(builtin.Concat(axis=1, comp_node=device), reduce_size, channel_x1s, channel_x2s)
        return (stats,), (True,)

    @subgraph("SyncBnSplitStats", dtype, device, 1)
    def syncbn_split_stats(inputs, f, c):
        # Inverse of syncbn_concat_stats: slice axis 1 into [0:1], [1:C+1], [C+1:].
        stats = inputs[0]
        c_1 = c(1, dtype="int32")
        channel_x1s_end = c(channels+1, dtype="int32")

        def _subtensor(src, axis, begin, end):
            # NOTE(review): items tuple looks like
            # (axis, has_begin, has_end, has_step, has_idx) — confirm
            # against builtin.Subtensor.
            items = (axis, (begin is not None), (end is not None), False, False),
            args = ()
            if begin is not None:
                args += begin,
            if end is not None:
                args += end,
            return f(builtin.Subtensor(items=items), src, *args)

        reduce_size = _subtensor(stats, 1, None, c_1)
        channel_x1s = _subtensor(stats, 1, c_1, channel_x1s_end)
        channel_x2s = _subtensor(stats, 1, channel_x1s_end, None)
        reduce_size = f(builtin.Reshape(), reduce_size, c_1)
        return (reduce_size, channel_x1s, channel_x2s), (False, True, True)
    # fmt: on

    return (
        syncbn_stage0,
        syncbn_stage1,
        syncbn_stage1_inference,
        syncbn_stage2,
        syncbn_concat_stats,
        syncbn_split_stats,
    )
  1132. def sync_batch_norm(
  1133. inp: Tensor,
  1134. running_mean: Tensor,
  1135. running_var: Tensor,
  1136. weight: Optional[Tensor] = None,
  1137. bias: Optional[Tensor] = None,
  1138. training: bool = False,
  1139. momentum: Union[float, Tensor] = 0.9,
  1140. eps: float = 1e-5,
  1141. eps_mode="additive",
  1142. group=WORLD,
  1143. ) -> Tensor:
  1144. r"""Applies synchronized batch normalization to the input.
  1145. Refer to :class:`~.BatchNorm2d` and :class:`~.BatchNorm1d` for more information.
  1146. Args:
  1147. inp: input tensor.
  1148. running_mean: tensor to store running mean.
  1149. running_var: tensor to store running variance.
  1150. weight: scaling tensor in the learnable affine parameters.
  1151. See :math:`\gamma` in :class:`~.BatchNorm2d`.
  1152. bias: bias tensor in the learnable affine parameters.
  1153. See :math:`\beta` in :class:`~.BatchNorm2d`.
  1154. training: a boolean value to indicate whether batch norm is performed
  1155. in traning mode. Default: False
  1156. momentum: value used for the ``running_mean`` and ``running_var``
  1157. computation. Default: 0.9
  1158. eps: a value added to the denominator for numerical stability.
  1159. Default: 1e-5
  1160. eps_mode: mode of calculation for eps, "max" or "additive".
  1161. Default: "additive"
  1162. group: communication group, caculate mean and variance between this group.
  1163. Default: :obj:`~megengine.distributed.WORLD`
  1164. """
  1165. _eps_mode = eps_mode.lower()
  1166. assert _eps_mode in {"max", "additive"}, "unknown eps_mode: {}".format(eps_mode)
  1167. if _eps_mode == "additive" and not (is_distributed() and training):
  1168. return batch_norm(
  1169. inp,
  1170. running_mean,
  1171. running_var,
  1172. weight,
  1173. bias,
  1174. training=training,
  1175. momentum=momentum,
  1176. eps=eps,
  1177. )
  1178. _channels = make_shape_tuple(inp.shape)[1]
  1179. _ndim = inp.ndim
  1180. _device = inp.device
  1181. _dtype = inp.dtype
  1182. if _ndim != 4:
  1183. raise NotImplementedError("sync_batch_norm for ndim != 4")
  1184. def _make_full_if_none(x, value):
  1185. if x is None:
  1186. (x,) = Const(value, dtype=inp.dtype, device=_device)()
  1187. (result,) = apply(builtin.Broadcast(), x, reduce_shape)
  1188. return result
  1189. elif x.ndim == 1:
  1190. (result,) = apply(builtin.Reshape(), x, reduce_shape)
  1191. return result
  1192. return x
  1193. (
  1194. syncbn_stage0,
  1195. syncbn_stage1,
  1196. syncbn_stage1_inference,
  1197. syncbn_stage2,
  1198. syncbn_concat_stats,
  1199. syncbn_split_stats,
  1200. ) = _get_sync_bn_ops(_device, _dtype, eps_mode, _ndim, _channels)
  1201. reduce_shape, reduce_size, channel_x1s, channel_x2s = apply(syncbn_stage0(), inp)
  1202. eps = convert_single_value(eps, dtype=inp.dtype, device=inp.device)
  1203. weight = _make_full_if_none(weight, 1)
  1204. bias = _make_full_if_none(bias, 0)
  1205. if training:
  1206. if is_distributed():
  1207. # reduce all nodes' data to calculate mean and variance
  1208. (stat,) = apply(
  1209. syncbn_concat_stats(), reduce_size, channel_x1s, channel_x2s
  1210. )
  1211. stat = all_reduce_sum(stat, group)
  1212. reduce_size, channel_x1s, channel_x2s = apply(syncbn_split_stats(), stat)
  1213. outvar, channel_mean, *_ = apply(
  1214. syncbn_stage1(),
  1215. inp,
  1216. reduce_size,
  1217. channel_x1s,
  1218. channel_x2s,
  1219. eps,
  1220. weight,
  1221. bias,
  1222. )
  1223. else:
  1224. assert running_var is not None and running_mean is not None
  1225. channel_mean = running_mean
  1226. channel_var = running_var
  1227. outvar, *_ = apply(
  1228. syncbn_stage1_inference(), inp, channel_mean, channel_var, eps, weight, bias
  1229. )
  1230. # outvar = output * weight + bias
  1231. # where output = inp * invsqrt_channel_variance + (
  1232. # -channel_mean * invsqrt_channel_variance
  1233. # )
  1234. # Manually expand output for gopt
  1235. if training and running_var is not None and running_mean is not None:
  1236. momentum = convert_single_value(momentum, dtype=inp.dtype, device=inp.device)
  1237. running_mean[...], running_var[...] = apply(
  1238. syncbn_stage2(),
  1239. running_mean,
  1240. running_var,
  1241. momentum,
  1242. reduce_size,
  1243. channel_x1s,
  1244. channel_x2s,
  1245. channel_mean,
  1246. )
  1247. return outvar
  1248. def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor:
  1249. r"""Returns a new tensor where each of the elements are randomly set to zero
  1250. with probability P = ``drop_prob``. Optionally rescale the output tensor if ``training`` is True.
  1251. Args:
  1252. inp: input tensor.
  1253. drop_prob: probability to drop (set to zero) a single element.
  1254. training: the default behavior of ``dropout`` during training is to rescale the output,
  1255. then it can be replaced by an :class:`~.module.identify.Identity` during inference. Default: True
  1256. Returns:
  1257. the ouput tensor
  1258. Examples:
  1259. .. testcode::
  1260. import numpy as np
  1261. from megengine import tensor
  1262. import megengine.functional as F
  1263. # test training mode
  1264. data = tensor(np.ones(10000000, dtype=np.float32))
  1265. out = F.nn.dropout(data, 1.0 / 3.0, training=True)
  1266. assert not out.numpy().all()
  1267. # test eval mode
  1268. out = F.nn.dropout(data, 1.0 / 3.0, training=False)
  1269. assert out.numpy().all()
  1270. Outputs:
  1271. .. testoutput::
  1272. :options: +SKIP
  1273. [1.5 1.5 0. 1.5 1.5 1.5 1.5 1.5 1.5 1.5]
  1274. """
  1275. assert 0 <= drop_prob < 1
  1276. if not training or drop_prob == 0:
  1277. return inp
  1278. # model in training mode, e.g. model.train()
  1279. op = Dropout(drop_prob=drop_prob, seed=_get_global_rng_seed(), handle=0)
  1280. outputs = apply(op, inp)
  1281. return outputs[0]
  1282. def one_hot(inp: Tensor, num_classes: int) -> Tensor:
  1283. r"""Performs one-hot encoding for the input tensor.
  1284. Args:
  1285. inp: input tensor.
  1286. num_classes: number of classes denotes the last dimension of the output tensor.
  1287. Examples:
  1288. .. testcode::
  1289. import numpy as np
  1290. from megengine import tensor
  1291. import megengine.functional as F
  1292. x = tensor(np.arange(1, 4, dtype=np.int32))
  1293. out = F.one_hot(x, num_classes=4)
  1294. print(out.numpy())
  1295. Outputs:
  1296. .. testoutput::
  1297. [[0 1 0 0]
  1298. [0 0 1 0]
  1299. [0 0 0 1]]
  1300. """
  1301. zeros_tensor = zeros(
  1302. list(inp.shape) + [num_classes], dtype=inp.dtype, device=inp.device
  1303. )
  1304. ones_tensor = ones(list(inp.shape) + [1], dtype=inp.dtype, device=inp.device)
  1305. op = builtin.IndexingSetOneHot(axis=inp.ndim)
  1306. (result,) = apply(op, zeros_tensor, inp, ones_tensor)
  1307. return result
  1308. def embedding(
  1309. inp: Tensor,
  1310. weight: Tensor,
  1311. padding_idx: Optional[int] = None,
  1312. max_norm: Optional[float] = None,
  1313. norm_type: Optional[float] = None,
  1314. ):
  1315. r"""Applies lookup table for embedding.
  1316. Args:
  1317. inp: tensor with indices.
  1318. weight: learnable weights which embeds from.
  1319. padding_idx: should be set to None, not supported now.
  1320. max_norm: should be set to None, not supported now.
  1321. norm_type: should be set to None, not supported now.
  1322. Refer to :class:`~.module.Embedding` for more information.
  1323. """
  1324. if padding_idx is not None:
  1325. raise ValueError("Not support padding_idx Now!")
  1326. if max_norm is not None or norm_type is not None:
  1327. raise ValueError("Not support weight normlization Now!")
  1328. dest_shp = list(inp.shape) + [weight.shape[-1]]
  1329. return weight[inp.reshape(-1)].reshape(dest_shp)
  1330. def indexing_one_hot(
  1331. src: Tensor, index: Tensor, axis: int = 1, keepdims=False
  1332. ) -> Tensor:
  1333. r"""One-hot indexing for some axes.
  1334. Args:
  1335. src: input tensor.
  1336. index: index tensor.
  1337. axis: axis on src for which values in index index. Default: 1
  1338. keepdims: whether not to remove the axis in result. Default: False
  1339. Examples:
  1340. .. testcode::
  1341. import megengine.functional as F
  1342. from megengine import tensor
  1343. src = tensor([[1.0, 2.0]])
  1344. index = tensor([0])
  1345. val = F.indexing_one_hot(src, index)
  1346. print(val.numpy())
  1347. Outputs:
  1348. .. testoutput::
  1349. [1.]
  1350. """
  1351. assert isinstance(src, Tensor), "src must be of Tensor type"
  1352. op = builtin.IndexingOneHot(axis=axis)
  1353. index = convert_single_value(index, dtype="int32", device=src.device)
  1354. (result,) = apply(op, src, index)
  1355. if not keepdims:
  1356. result = squeeze(result, axis)
  1357. return result
  1358. def sliding_window(
  1359. inp: Tensor,
  1360. kernel_size: Union[int, Tuple[int, int]],
  1361. padding: Union[int, Tuple[int, int]] = 0,
  1362. stride: Union[int, Tuple[int, int]] = 1,
  1363. dilation: Union[int, Tuple[int, int]] = 1,
  1364. ) -> Tensor:
  1365. r"""Extracts sliding local blocks from a batched input tensor.
  1366. Refer to :class:`~.module.sliding_window.SlidingWindow` for more information.
  1367. Args:
  1368. inp: input tensor.
  1369. kernel_size: size of the window.
  1370. padding: implicit zero padding added on both sides of input. Default: 0
  1371. stride: stride of the window. Default: 1
  1372. dilation: dilation of the window. Default: 1
  1373. """
  1374. padding_h, padding_w = _pair(padding)
  1375. stride_h, stride_w = _pair_nonzero(stride)
  1376. dilation_h, dilation_w = _pair_nonzero(dilation)
  1377. window_h, window_w = _pair_nonzero(kernel_size)
  1378. op = builtin.Images2Neibs(
  1379. pad_h=padding_h,
  1380. pad_w=padding_w,
  1381. stride_h=stride_h,
  1382. stride_w=stride_w,
  1383. dilate_h=dilation_h,
  1384. dilate_w=dilation_w,
  1385. window_h=window_h,
  1386. window_w=window_w,
  1387. )
  1388. (output,) = apply(op, inp)
  1389. return output
  1390. def sliding_window_transpose(
  1391. inp: Tensor,
  1392. output_size: Union[int, Tuple[int, int]],
  1393. kernel_size: Union[int, Tuple[int, int]],
  1394. padding: Union[int, Tuple[int, int]] = 0,
  1395. stride: Union[int, Tuple[int, int]] = 1,
  1396. dilation: Union[int, Tuple[int, int]] = 1,
  1397. ) -> Tensor:
  1398. r"""Sum over the sliding windows on the corresponding input location.
  1399. Refer to :class:`~.module.sliding_window.SlidingWindowTranspose` for more information.
  1400. Args:
  1401. inp: input tensor.
  1402. output_size: shape of output tensor.
  1403. kernel_size: size of the window.
  1404. padding: implicit zero padding added on both sides of input. Default: 0
  1405. stride: stride of the window. Default: 1
  1406. dilation: dilation of the window. Default: 1
  1407. """
  1408. output_h, output_w = _pair_nonzero(output_size)
  1409. padding_h, padding_w = _pair(padding)
  1410. stride_h, stride_w = _pair_nonzero(stride)
  1411. dilation_h, dilation_w = _pair_nonzero(dilation)
  1412. window_h, window_w = _pair_nonzero(kernel_size)
  1413. expected_h = (
  1414. output_h + 2 * padding_h - dilation_h * (window_h - 1) - 1
  1415. ) // stride_h + 1
  1416. expected_w = (
  1417. output_w + 2 * padding_w - dilation_w * (window_w - 1) - 1
  1418. ) // stride_w + 1
  1419. assert inp.ndim == 6, "the input dimension of sliding_window_transpose should be 6"
  1420. assert (
  1421. inp.shape[2] == expected_h and inp.shape[3] == expected_w
  1422. ), "the input shape and output size do not match"
  1423. op = builtin.SlidingWindowTranspose(
  1424. out_h=output_h,
  1425. out_w=output_w,
  1426. pad_h=padding_h,
  1427. pad_w=padding_w,
  1428. stride_h=stride_h,
  1429. stride_w=stride_w,
  1430. dilate_h=dilation_h,
  1431. dilate_w=dilation_w,
  1432. window_h=window_h,
  1433. window_w=window_w,
  1434. )
  1435. (output,) = apply(op, inp)
  1436. return output
  1437. def pad(
  1438. src: Tensor,
  1439. pad_witdth: Tuple[Tuple[int, int], ...],
  1440. mode: str = "constant",
  1441. constant_value: float = 0.0,
  1442. ) -> Tensor:
  1443. """
  1444. Pad is python warpper for padding opr in megbrain, can padding in random one of the max 7 dimensions.
  1445. Supported constant, edge(replicate) and reflect mode, constatnt is the default mode.
  1446. """
  1447. p_offsets = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
  1448. assert mode.lower() in ["constant", "edge", "replicate", "reflect"]
  1449. if mode.lower() == "edge":
  1450. mode = "replicate"
  1451. for i in range(0, len(pad_witdth)):
  1452. p_offsets[i * 2] = pad_witdth[i][0]
  1453. p_offsets[i * 2 + 1] = pad_witdth[i][1]
  1454. op = builtin.Padding(
  1455. front_offset_dim0=p_offsets[0],
  1456. front_offset_dim1=p_offsets[2],
  1457. front_offset_dim2=p_offsets[4],
  1458. front_offset_dim3=p_offsets[6],
  1459. front_offset_dim4=p_offsets[8],
  1460. front_offset_dim5=p_offsets[10],
  1461. front_offset_dim6=p_offsets[12],
  1462. back_offset_dim0=p_offsets[1],
  1463. back_offset_dim1=p_offsets[3],
  1464. back_offset_dim2=p_offsets[5],
  1465. back_offset_dim3=p_offsets[7],
  1466. back_offset_dim4=p_offsets[9],
  1467. back_offset_dim5=p_offsets[11],
  1468. back_offset_dim6=p_offsets[13],
  1469. padding_val=constant_value,
  1470. padding_mode=mode.upper(),
  1471. )
  1472. (output,) = apply(op, src)
  1473. return output
  1474. def local_response_norm(
  1475. inp: Tensor,
  1476. kernel_size: int = 5,
  1477. k: float = 2.0,
  1478. alpha: float = 1e-4,
  1479. beta: float = 0.75,
  1480. ) -> Tensor:
  1481. r"""
  1482. Apply local response normalization to the input tensor.
  1483. Args:
  1484. kernel_size: the size of the kernel to apply LRN on.
  1485. k: hyperparameter k. The default vaule is 2.0.
  1486. alpha: hyperparameter alpha. The default value is 1e-4.
  1487. beta: hyperparameter beta. The default value is 0.75.
  1488. Example:
  1489. .. testcode::
  1490. from megengine import tensor
  1491. import megengine.functional as f
  1492. import numpy as np
  1493. inp = tensor(np.arange(25, dtype=np.float32).reshape(1,1,5,5))
  1494. GT = np.array([[[[ 0., 0.999925, 1.9994003, 2.9979765, 3.9952066],
  1495. [ 4.9906454, 5.983851, 6.974385, 7.961814, 8.945709 ],
  1496. [ 9.925651, 10.90122, 11.872011, 12.837625, 13.7976675],
  1497. [14.751757, 15.699524, 16.640602, 17.574642, 18.501305 ],
  1498. [19.420258, 20.331186, 21.233786, 22.127764, 23.012836 ]]]])
  1499. out = f.local_response_norm(inp, kernel_size=3, k=1.0, alpha=1e-4, beta=0.75)
  1500. np.testing.assert_allclose(GT, out.numpy(), rtol=1e-6, atol=1e-6)
  1501. print('pass')
  1502. Outputs:
  1503. .. testoutput::
  1504. pass
  1505. """
  1506. op = builtin.LRN(n=kernel_size, k=k, alpha=alpha, beta=beta,)
  1507. (output,) = apply(op, inp)
  1508. return output
  1509. @lru_cache(maxsize=None)
  1510. def _get_layerPixelShuffle(device, dtype, dim_order):
  1511. @subgraph("LayerPixelShuffle", dtype, device, 3)
  1512. def layerPixelShuffle(inputs, f, c):
  1513. inp, shape_0, shape_1 = inputs
  1514. inp = f(Reshape(), inp, shape_0)
  1515. inp = f(Dimshuffle(dim_order), inp)
  1516. oup = f(Reshape(), inp, shape_1)
  1517. return (oup,), (True,)
  1518. return layerPixelShuffle
  1519. def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
  1520. """
  1521. Rearranges elements in a tensor of shape (*, C x r^2, H, W) to a tensor of
  1522. shape (*, C, H x r, W x r), where r is an upscale factor, where * is zero
  1523. or more batch dimensions.
  1524. :param inp: input tensor.
  1525. :param upscale_factor: upscale factor of pixel_shuffle.
  1526. :return: output tensor.
  1527. """
  1528. assert upscale_factor > 0, "upscale_factor should larger than 0"
  1529. assert inp.ndim >= 3, "the input dimension of pixel_shuffle should be larger than 3"
  1530. assert (
  1531. inp.shape[-3] % (upscale_factor ** 2) == 0
  1532. ), "the -3 dimension should be divided by (upscale_factor ** 2)"
  1533. _device = inp.device
  1534. _dtype = inp.dtype
  1535. shape_ori = inp.shape
  1536. high_dim = shape_ori[:-3]
  1537. square = upscale_factor ** 2
  1538. n = 1
  1539. for item in high_dim:
  1540. n *= item
  1541. shape_0 = (
  1542. n,
  1543. int(shape_ori[-3] / square),
  1544. upscale_factor,
  1545. upscale_factor,
  1546. shape_ori[-2],
  1547. shape_ori[-1],
  1548. )
  1549. shape_1 = (
  1550. *high_dim,
  1551. shape_ori[-3] / square,
  1552. shape_ori[-2] * upscale_factor,
  1553. shape_ori[-1] * upscale_factor,
  1554. )
  1555. dim_order = (0, 1, 4, 2, 5, 3)
  1556. layerPixelShuffle = _get_layerPixelShuffle(_device, _dtype, dim_order)
  1557. shape_0 = convert_single_value(shape_0, dtype=inp.dtype, device=inp.device)
  1558. shape_1 = convert_single_value(shape_1, dtype=inp.dtype, device=inp.device)
  1559. outvar, *_ = apply(layerPixelShuffle(), inp, shape_0, shape_1)
  1560. return outvar
  1561. from .quantized import conv_bias_activation # isort:skip
  1562. from .loss import * # isort:skip
  1563. from .metric import * # isort:skip
  1564. from .vision import * # isort:skip