You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

nn.py 69 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041
  1. # -*- coding: utf-8 -*-
  2. # pylint: disable=too-many-lines
  3. from functools import lru_cache
  4. from typing import NamedTuple, Optional, Sequence, Tuple, Union
  5. from ..core import _config
  6. from ..core._imperative_rt.core2 import (
  7. Const,
  8. adaptive_pool2d_cpp,
  9. apply,
  10. dtype_promotion,
  11. pixel_shuffle_cpp,
  12. )
  13. from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
  14. from ..core.ops import builtin
  15. from ..core.ops.builtin import (
  16. BatchNorm,
  17. Dimshuffle,
  18. Dropout,
  19. Elemwise,
  20. GetVarShape,
  21. Identity,
  22. Reduce,
  23. Reshape,
  24. TypeCvt,
  25. )
  26. from ..core.tensor import amp, megbrain_graph
  27. from ..core.tensor.array_method import _matmul
  28. from ..core.tensor.utils import (
  29. astensor1d,
  30. cast_tensors,
  31. convert_single_value,
  32. make_shape_tuple,
  33. subgraph,
  34. subgraph_fn,
  35. )
  36. from ..device import get_default_device
  37. from ..distributed import WORLD, is_distributed
  38. from ..jit import exclude_from_trace
  39. from ..tensor import Tensor
  40. from ..utils.deprecation import deprecated_func
  41. from .debug_param import get_execution_strategy
  42. from .distributed import all_reduce_sum
  43. from .elemwise import _elwise, exp, log, log1p, maximum, minimum
  44. from .math import max, normalize, sum
  45. from .metric import topk_accuracy
  46. from .tensor import broadcast_to, concat, expand_dims, ones, squeeze, zeros
  47. __all__ = [
  48. "adaptive_avg_pool2d",
  49. "adaptive_max_pool2d",
  50. "avg_pool2d",
  51. "batch_norm",
  52. "conv1d",
  53. "conv2d",
  54. "conv3d",
  55. "conv_transpose2d",
  56. "conv_transpose3d",
  57. "deformable_conv2d",
  58. "deformable_psroi_pooling",
  59. "dropout",
  60. "embedding",
  61. "gelu",
  62. "hsigmoid",
  63. "hswish",
  64. "indexing_one_hot",
  65. "layer_norm",
  66. "leaky_relu",
  67. "linear",
  68. "local_conv2d",
  69. "local_response_norm",
  70. "logsigmoid",
  71. "logsumexp",
  72. "logsoftmax",
  73. "max_pool2d",
  74. "normalize",
  75. "one_hot",
  76. "prelu",
  77. "pad",
  78. "relu",
  79. "relu6",
  80. "remap",
  81. "sigmoid",
  82. "sliding_window",
  83. "sliding_window_transpose",
  84. "silu",
  85. "softmax",
  86. "softplus",
  87. "sync_batch_norm",
  88. "topk_accuracy",
  89. "warp_affine",
  90. "warp_perspective",
  91. "pixel_shuffle",
  92. "region_restricted_conv",
  93. ]
  94. def _expand_hw(x):
  95. # judge int is 5 times faster than judge Sequence
  96. if isinstance(x, int):
  97. return x, x
  98. if isinstance(x, Sequence):
  99. return int(x[0]), int(x[1])
  100. return int(x), int(x)
  101. def _expand_dhw(x):
  102. if isinstance(x, int):
  103. return x, x, x
  104. if isinstance(x, Sequence):
  105. return int(x[0]), int(x[1]), int(x[2])
  106. return int(x), int(x), int(x)
  107. def linear(
  108. inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None, compute_mode="default",
  109. ) -> Tensor:
  110. r"""Applies a linear transformation to the input tensor.
  111. Refer to :class:`~.module.linear.Linear` for more information.
  112. Args:
  113. inp: input tensor with shape `(N, in_features)`.
  114. weight: weight with shape `(out_features, in_features)`.
  115. bias: bias with shape `(out_features,)`. Default: None
  116. """
  117. compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
  118. ret = _matmul(inp, weight, transpose_b=True, compute_mode=compute_mode)
  119. if bias is not None:
  120. if amp._enabled:
  121. bias = bias.astype("float16")
  122. ret += bias
  123. return ret
  124. def conv1d(
  125. inp: Tensor,
  126. weight: Tensor,
  127. bias: Optional[Tensor] = None,
  128. stride: int = 1,
  129. padding: int = 0,
  130. dilation: int = 1,
  131. groups: int = 1,
  132. conv_mode="cross_correlation",
  133. compute_mode="default",
  134. ) -> Tensor:
  135. r"""1D convolution operation.
  136. Refer to :class:`~.Conv1d` for more information.
  137. Args:
  138. inp: The feature map of the convolution operation
  139. weight: The convolution kernel.
  140. bias: The bias added to the result of convolution (if given)
  141. stride: Stride of the 1D convolution operation. Default: 1
  142. padding: Size of the paddings added to the input on both sides of its
  143. spatial dimensions. Only zero-padding is supported. Default: 0
  144. dilation: Dilation of the 1D convolution operation. Default: 1
  145. groups: number of groups to divide input and output channels into,
  146. so as to perform a "grouped convolution". When ``groups`` is not 1,
  147. ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
  148. and the shape of weight should be ``(groups, out_channel // groups,
  149. in_channels // groups, kernel_size)``. Default: 1
  150. conv_mode: Supports 'cross_correlation'. Default:
  151. 'cross_correlation'.
  152. compute_mode: When set to 'default', no special requirements will be
  153. placed on the precision of intermediate results. When set to 'float32',
  154. float32 would be used for accumulator and intermediate result, but only
  155. effective when input and output are of float16 dtype.
  156. """
  157. assert (
  158. conv_mode.lower() == "cross_correlation"
  159. or conv_mode.name == "CROSS_CORRELATION"
  160. )
  161. assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT"
  162. assert inp.ndim == 3, "the input dimension of conv1d should be 3"
  163. assert weight.ndim == 3, "the weight dimension of conv1d should be 3"
  164. if bias is not None:
  165. assert bias.ndim == 3, "the bias dimension of conv1d should be 3"
  166. stride_h = stride
  167. pad_h = padding
  168. dilate_h = dilation
  169. compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
  170. sparse_type = "dense" if groups == 1 else "group"
  171. op = builtin.Convolution(
  172. stride_h=stride_h,
  173. stride_w=1,
  174. pad_h=pad_h,
  175. pad_w=0,
  176. dilate_h=dilate_h,
  177. dilate_w=1,
  178. strategy=get_execution_strategy(),
  179. mode=conv_mode,
  180. compute_mode=compute_mode,
  181. sparse=sparse_type,
  182. )
  183. (output,) = apply(op, inp, weight)
  184. if bias is not None:
  185. if amp._enabled:
  186. (bias,) = cast_tensors(bias)
  187. output += bias
  188. return output
  189. def conv2d(
  190. inp: Tensor,
  191. weight: Tensor,
  192. bias: Optional[Tensor] = None,
  193. stride: Union[int, Tuple[int, int]] = 1,
  194. padding: Union[int, Tuple[int, int]] = 0,
  195. dilation: Union[int, Tuple[int, int]] = 1,
  196. groups: int = 1,
  197. conv_mode="cross_correlation",
  198. compute_mode="default",
  199. ) -> Tensor:
  200. r"""2D convolution operation.
  201. Refer to :class:`~.module.Conv2d` for more information.
  202. Args:
  203. inp: feature map of the convolution operation.
  204. weight: convolution kernel.
  205. bias: bias added to the result of convolution (if given).
  206. stride: stride of the 2D convolution operation. Default: 1
  207. padding: size of the paddings added to the input on both sides of its
  208. spatial dimensions. Only zero-padding is supported. Default: 0
  209. dilation: dilation of the 2D convolution operation. Default: 1
  210. groups: number of groups into which the input and output channels are divided,
  211. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  212. ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
  213. and the shape of weight should be ``(groups, out_channel // groups,
  214. in_channels // groups, height, width)``. Default: 1
  215. conv_mode: supports "cross_correlation". Default: "cross_correlation"
  216. compute_mode: when set to "default", no special requirements will be
  217. placed on the precision of intermediate results. When set to "float32",
  218. "float32" would be used for accumulator and intermediate result, but only
  219. effective when input and output are of float16 dtype.
  220. Returns:
  221. output tensor.
  222. """
  223. assert (
  224. conv_mode.lower() == "cross_correlation"
  225. or conv_mode.name == "CROSS_CORRELATION"
  226. )
  227. stride_h, stride_w = _expand_hw(stride)
  228. pad_h, pad_w = _expand_hw(padding)
  229. dilate_h, dilate_w = _expand_hw(dilation)
  230. sparse_type = "dense" if groups == 1 else "group"
  231. compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
  232. op = builtin.Convolution(
  233. stride_h=stride_h,
  234. stride_w=stride_w,
  235. pad_h=pad_h,
  236. pad_w=pad_w,
  237. dilate_h=dilate_h,
  238. dilate_w=dilate_w,
  239. strategy=get_execution_strategy(),
  240. mode=conv_mode,
  241. compute_mode=compute_mode,
  242. sparse=sparse_type,
  243. )
  244. (output,) = apply(op, inp, weight)
  245. if bias is not None:
  246. if amp._enabled:
  247. (bias,) = cast_tensors(bias)
  248. output += bias
  249. return output
  250. def conv3d(
  251. inp: Tensor,
  252. weight: Tensor,
  253. bias: Optional[Tensor] = None,
  254. stride: Union[int, Tuple[int, int, int]] = 1,
  255. padding: Union[int, Tuple[int, int, int]] = 0,
  256. dilation: Union[int, Tuple[int, int, int]] = 1,
  257. groups: int = 1,
  258. conv_mode: str = "cross_correlation",
  259. ) -> Tensor:
  260. r"""3D convolution operation.
  261. Refer to :class:`~.Conv3d` for more information.
  262. Args:
  263. inp: feature map of the convolution operation.
  264. weight: convolution kernel.
  265. bias: bias added to the result of convolution (if given).
  266. stride: stride of the 3D convolution operation. Default: 1
  267. padding: size of the paddings added to the input on both sides of its
  268. spatial dimensions. Only zero-padding is supported. Default: 0
  269. dilation: dilation of the 3D convolution operation. Default: 1
  270. groups: number of groups into which the input and output channels are divided,
  271. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  272. ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
  273. and the shape of weight should be ``(groups, out_channel // groups,
  274. in_channels // groups, depth, height, width)``. Default: 1
  275. conv_mode: supports "cross_correlation". Default: "cross_correlation"
  276. Returns:
  277. output tensor.
  278. """
  279. assert conv_mode.lower() == "cross_correlation"
  280. D, H, W = 0, 1, 2
  281. pad = _expand_dhw(padding)
  282. stride = _expand_dhw(stride)
  283. dilate = _expand_dhw(dilation)
  284. sparse_type = "dense" if groups == 1 else "group"
  285. op = builtin.Convolution3D(
  286. pad_d=pad[D],
  287. pad_h=pad[H],
  288. pad_w=pad[W],
  289. stride_d=stride[D],
  290. stride_h=stride[H],
  291. stride_w=stride[W],
  292. dilate_d=dilate[D],
  293. dilate_h=dilate[H],
  294. dilate_w=dilate[W],
  295. strategy=get_execution_strategy(),
  296. mode=conv_mode,
  297. sparse=sparse_type,
  298. )
  299. (output,) = apply(op, inp, weight)
  300. if bias is not None:
  301. output += bias
  302. return output
  303. def conv_transpose2d(
  304. inp: Tensor,
  305. weight: Tensor,
  306. bias: Optional[Tensor] = None,
  307. stride: Union[int, Tuple[int, int]] = 1,
  308. padding: Union[int, Tuple[int, int]] = 0,
  309. output_padding: Union[int, Tuple[int, int]] = 0,
  310. dilation: Union[int, Tuple[int, int]] = 1,
  311. groups: int = 1,
  312. conv_mode="cross_correlation",
  313. compute_mode="default",
  314. ) -> Tensor:
  315. r"""2D transposed convolution operation.
  316. Refer to :class:`~.module.conv.ConvTranspose2d` for more information.
  317. Args:
  318. inp: feature map of the convolution operation.
  319. weight: convolution kernel.
  320. weight usually has shape ``(in_channels, out_channels, height, width)``.
  321. bias: bias added to the result of convolution (if given).
  322. stride: stride of the 2D convolution operation. Default: 1
  323. padding: size of the paddings added to the input on both sides of its
  324. spatial dimensions. Only zero-padding is supported. Default: 0
  325. output_padding: size of paddings appended to output. Default: 0
  326. dilation: dilation of the 2D convolution operation. Default: 1
  327. groups: number of groups into which the input and output channels are divided,
  328. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  329. ``in_channels`` and ``out_channels`` must be divisible by groups,
  330. and the shape of weight should be ``(groups, in_channels // groups,
  331. out_channels // groups, height, width)``. Default: 1
  332. conv_mode: supports "cross_correlation". Default: "cross_correlation"
  333. compute_mode: when set to "default", no special requirements will be
  334. placed on the precision of intermediate results. When set to "float32",
  335. "float32" would be used for accumulator and intermediate result, but only
  336. effective when input and output are of float16 dtype.
  337. Returns:
  338. output tensor.
  339. """
  340. assert (
  341. conv_mode.lower() == "cross_correlation"
  342. or conv_mode.name == "CROSS_CORRELATION"
  343. )
  344. stride_h, stride_w = _expand_hw(stride)
  345. pad_h, pad_w = _expand_hw(padding)
  346. output_pad_h, output_pad_w = _expand_hw(output_padding)
  347. dilate_h, dilate_w = _expand_hw(dilation)
  348. compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
  349. sparse_type = "dense" if groups == 1 else "group"
  350. op = builtin.ConvolutionBackwardData(
  351. stride_h=stride_h,
  352. stride_w=stride_w,
  353. pad_h=pad_h,
  354. pad_w=pad_w,
  355. dilate_h=dilate_h,
  356. dilate_w=dilate_w,
  357. strategy=get_execution_strategy(),
  358. compute_mode=compute_mode,
  359. sparse=sparse_type,
  360. )
  361. if output_pad_h != 0 or output_pad_h != 0:
  362. assert (
  363. output_pad_h < stride[0]
  364. ), "output_padding[0] shoule be less than stride[0]"
  365. assert (
  366. output_pad_w < stride[1]
  367. ), "output_padding[1] shoule be less than stride[1]"
  368. Hout = (
  369. (inp.shape[2] - 1) * stride[0]
  370. - 2 * padding[0]
  371. + dilation[0] * (weight.shape[2] - 1)
  372. + output_pad_h
  373. + 1
  374. )
  375. Wout = (
  376. (inp.shape[3] - 1) * stride[1]
  377. - 2 * padding[1]
  378. + dilation[1] * (weight.shape[3] - 1)
  379. + output_pad_w
  380. + 1
  381. )
  382. output_shape = [inp.shape[0], weight.shape[1], Hout, Wout]
  383. output_shape = astensor1d(output_shape)
  384. (output,) = apply(op, weight, inp, output_shape)
  385. else:
  386. (output,) = apply(op, weight, inp)
  387. if bias is not None:
  388. if amp._enabled:
  389. bias = cast_tensors(bias)
  390. output += bias
  391. return output
  392. def deformable_conv2d(
  393. inp: Tensor,
  394. weight: Tensor,
  395. offset: Tensor,
  396. mask: Tensor,
  397. bias: Optional[Tensor] = None,
  398. stride: Union[int, Tuple[int, int]] = 1,
  399. padding: Union[int, Tuple[int, int]] = 0,
  400. dilation: Union[int, Tuple[int, int]] = 1,
  401. groups: int = 1,
  402. conv_mode="cross_correlation",
  403. compute_mode="default",
  404. ) -> Tensor:
  405. r"""Deformable Convolution.
  406. Args:
  407. inp: input feature map.
  408. weight: convolution kernel.
  409. weight usually has shape ``(out_channels, in_channels, height, width)``.
  410. offset: input offset to kernel, channel of this tensor should match the deformable settings.
  411. mask: input mask to kernel, channel of this tensor should match the deformable settings.
  412. bias: bias added to the result of convolution (if given).
  413. stride: stride of the 2D convolution operation. Default: 1
  414. padding: size of the paddings added to the input on both sides of its
  415. spatial dimensions. Only zero-padding is supported. Default: 0
  416. dilation: dilation of the 2D convolution operation. Default: 1
  417. groups: number of groups into which the input and output channels are divided,
  418. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  419. ``in_channels`` and ``out_channels`` must be divisible by groups,
  420. and the shape of weight should be ``(groups, out_channel // groups,
  421. in_channels // groups, height, width)``. Default: 1
  422. conv_mode: supports "cross_correlation". Default: "cross_correlation"
  423. compute_mode: when set to "default", no special requirements will be
  424. placed on the precision of intermediate results. When set to "float32",
  425. "float32" would be used for accumulator and intermediate result, but only
  426. effective when input and output are of float16 dtype.
  427. Returns:
  428. output tensor.
  429. """
  430. assert (
  431. conv_mode.lower() == "cross_correlation"
  432. or conv_mode.name == "CROSS_CORRELATION"
  433. )
  434. if amp._enabled:
  435. inp, weight, offset, mask, bias = cast_tensors(inp, weight, offset, mask, bias)
  436. else:
  437. offset = offset.astype("float32")
  438. mask = mask.astype("float32")
  439. stride_h, stride_w = _expand_hw(stride)
  440. pad_h, pad_w = _expand_hw(padding)
  441. dilate_h, dilate_w = _expand_hw(dilation)
  442. compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
  443. sparse_type = "dense" if groups == 1 else "group"
  444. op = builtin.DeformableConv(
  445. stride_h=stride_h,
  446. stride_w=stride_w,
  447. pad_h=pad_h,
  448. pad_w=pad_w,
  449. dilate_h=dilate_h,
  450. dilate_w=dilate_w,
  451. strategy=get_execution_strategy(),
  452. mode=conv_mode,
  453. compute_mode=compute_mode,
  454. sparse=sparse_type,
  455. )
  456. (output,) = apply(op, inp, weight, offset, mask)
  457. if bias is not None:
  458. output += bias
  459. return output
  460. def local_conv2d(
  461. inp: Tensor,
  462. weight: Tensor,
  463. bias: Optional[Tensor] = None,
  464. stride: Union[int, Tuple[int, int]] = 1,
  465. padding: Union[int, Tuple[int, int]] = 0,
  466. dilation: Union[int, Tuple[int, int]] = 1,
  467. conv_mode="cross_correlation",
  468. ):
  469. r"""Applies a spatial convolution with untied kernels over an groupped channeled input 4D tensor.
  470. It is also known as the locally connected layer.
  471. Args:
  472. inp: input feature map.
  473. weight: convolution kernel.
  474. weight usually has shape ``(out_channels, in_channels, height, width)``.
  475. bias: bias added to the result of convolution (if given).
  476. stride: stride of the 2D convolution operation. Default: 1
  477. padding: size of the paddings added to the input on both sides of its
  478. spatial dimensions. Only zero-padding is supported. Default: 0
  479. dilation: dilation of the 2D convolution operation. Default: 1
  480. Returns:
  481. output tensor.
  482. """
  483. assert (
  484. conv_mode.lower() == "cross_correlation"
  485. or conv_mode.name == "CROSS_CORRELATION"
  486. )
  487. stride_h, stride_w = _expand_hw(stride)
  488. pad_h, pad_w = _expand_hw(padding)
  489. dilate_h, dilate_w = _expand_hw(dilation)
  490. # local conv only support "dense" mode, but weight could contain group dimension.
  491. op = builtin.GroupLocal(
  492. stride_h=stride_h,
  493. stride_w=stride_w,
  494. pad_h=pad_h,
  495. pad_w=pad_w,
  496. dilate_h=dilate_h,
  497. dilate_w=dilate_w,
  498. mode=conv_mode,
  499. sparse="dense",
  500. )
  501. (output,) = apply(op, inp, weight)
  502. if bias is not None:
  503. output += bias
  504. return output
  505. def conv_transpose3d(
  506. inp: Tensor,
  507. weight: Tensor,
  508. bias: Optional[Tensor] = None,
  509. stride: Union[int, Tuple[int, int, int]] = 1,
  510. padding: Union[int, Tuple[int, int, int]] = 0,
  511. output_padding: Union[int, Tuple[int, int, int]] = 0,
  512. dilation: Union[int, Tuple[int, int, int]] = 1,
  513. groups: int = 1,
  514. ) -> Tensor:
  515. r"""3D transposed convolution operation. Only support the case that groups = 1
  516. and conv_mode = "cross_correlation".
  517. Refer to :class:`~.ConvTranspose3d` for more information.
  518. Args:
  519. inp: feature map of the convolution operation.
  520. weight: convolution kernel.
  521. weight usually has shape ``(in_channels, out_channels, depth, height, width)``.
  522. bias: bias added to the result of convolution (if given).
  523. stride: stride of the 3D convolution operation. Default: 1
  524. padding: size of the paddings added to the input on all sides of its
  525. spatial dimensions. Only zero-padding is supported. Default: 0
  526. output_padding: size of paddings appended to output. Default: 0
  527. dilation: dilation of the 3D convolution operation. Default: 1
  528. groups: number of groups into which the input and output channels are divided,
  529. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  530. ``in_channels`` and ``out_channels`` must be divisible by groups,
  531. and the shape of weight should be ``(groups, in_channels // groups,
  532. out_channels // groups, depth, height, width)``. Default: 1
  533. Returns:
  534. output tensor.
  535. """
  536. D, H, W = 0, 1, 2
  537. pad = _expand_dhw(padding)
  538. stride = _expand_dhw(stride)
  539. dilate = _expand_dhw(dilation)
  540. output_padding = _expand_dhw(output_padding)
  541. sparse_type = "dense" if groups == 1 else "group"
  542. op = builtin.Convolution3DBackwardData(
  543. pad_d=pad[D],
  544. pad_h=pad[H],
  545. pad_w=pad[W],
  546. stride_d=stride[D],
  547. stride_h=stride[H],
  548. stride_w=stride[W],
  549. dilate_d=dilate[D],
  550. dilate_h=dilate[H],
  551. dilate_w=dilate[W],
  552. strategy=get_execution_strategy(),
  553. sparse=sparse_type,
  554. )
  555. if output_padding[0] != 0 or output_padding[1] != 0 or output_padding[2] != 0:
  556. assert (
  557. output_padding[0] < stride[0]
  558. ), "output_padding[0] shoule be less than stride[0]"
  559. assert (
  560. output_padding[1] < stride[1]
  561. ), "output_padding[1] shoule be less than stride[1]"
  562. assert (
  563. output_padding[2] < stride[2]
  564. ), "output_padding[2] shoule be less than stride[2]"
  565. Dout = (
  566. (inp.shape[2] - 1) * stride[0]
  567. - 2 * padding[0]
  568. + dilation[0] * (weight.shape[2] - 1)
  569. + output_padding[0]
  570. + 1
  571. )
  572. Hout = (
  573. (inp.shape[3] - 1) * stride[1]
  574. - 2 * padding[1]
  575. + dilation[1] * (weight.shape[3] - 1)
  576. + output_padding[1]
  577. + 1
  578. )
  579. Wout = (
  580. (inp.shape[4] - 1) * stride[2]
  581. - 2 * padding[2]
  582. + dilation[2] * (weight.shape[4] - 1)
  583. + output_padding[2]
  584. + 1
  585. )
  586. output_shape = [inp.shape[0], weight.shape[1], Dout, Hout, Wout]
  587. output_shape = astensor1d(output_shape)
  588. (output,) = apply(op, weight, inp, output_shape)
  589. else:
  590. (output,) = apply(op, weight, inp)
  591. if bias is not None:
  592. output += bias
  593. return output
  594. def max_pool2d(
  595. inp: Tensor,
  596. kernel_size: Union[int, Tuple[int, int]],
  597. stride: Optional[Union[int, Tuple[int, int]]] = None,
  598. padding: Union[int, Tuple[int, int]] = 0,
  599. ) -> Tensor:
  600. r"""Applies a 2D max pooling over an input tensor.
  601. Refer to :class:`~.MaxPool2d` for more information.
  602. Args:
  603. inp: input tensor of shape :math:`(N, C, H_{\text{in}}, W_{\text{in}})`.
  604. kernel_size: size of the window used to calculate the max value.
  605. stride: stride of the window. Default value is ``kernel_size``.
  606. padding: implicit zero padding added on both sides. Default: 0.
  607. Returns:
  608. output tensor of shape `(N, C, H_{\text{out}}, W_{\text{out}})`.
  609. Examples:
  610. >>> import numpy as np
  611. >>> input = Tensor(np.arange(1 * 1 * 3 * 4).astype(np.float32).reshape(1, 1, 3, 4))
  612. >>> F.nn.max_pool2d(input, 2, 1, 0)
  613. Tensor([[[[ 5. 6. 7.]
  614. [ 9. 10. 11.]]]], device=xpux:0)
  615. """
  616. if stride is None:
  617. stride = kernel_size
  618. window_h, window_w = _expand_hw(kernel_size)
  619. stride_h, stride_w = _expand_hw(stride)
  620. padding_h, padding_w = _expand_hw(padding)
  621. op = builtin.Pooling(
  622. window_h=window_h,
  623. window_w=window_w,
  624. stride_h=stride_h,
  625. stride_w=stride_w,
  626. pad_h=padding_h,
  627. pad_w=padding_w,
  628. mode="max",
  629. strategy=get_execution_strategy(),
  630. )
  631. (output,) = apply(op, inp)
  632. return output
  633. def avg_pool2d(
  634. inp: Tensor,
  635. kernel_size: Union[int, Tuple[int, int]],
  636. stride: Optional[Union[int, Tuple[int, int]]] = None,
  637. padding: Union[int, Tuple[int, int]] = 0,
  638. mode: str = "average_count_exclude_padding",
  639. ) -> Tensor:
  640. r"""Applies 2D average pooling over an input tensor.
  641. Refer to :class:`~.AvgPool2d` for more information.
  642. Args:
  643. inp: input tensor of shape :math:`(N, C, H_{\text{in}}, W_{\text{in}})` .
  644. kernel_size: size of the window used to calculate the average value.
  645. stride: stride of the window. Default value is ``kernel_size``.
  646. padding: implicit zero padding added on both sides. Default: 0.
  647. mode: whether to include the padding values while calculating the average, set
  648. to "average" will do counting.
  649. Default: "average_count_exclude_padding"
  650. Returns:
  651. output tensor of shape :math:`(N, C, H_{\text{out}}, W_{\text{out}})`.
  652. Examples:
  653. >>> import numpy as np
  654. >>> inp = Tensor(np.arange(1 * 1 * 3 * 4).astype(np.float32).reshape(1, 1, 3, 4))
  655. >>> F.avg_pool2d(inp, kernel_size=2, stride=2, padding=[1,0], mode="average")
  656. Tensor([[[[0.25 1.25]
  657. [6.5 8.5 ]]]], device=xpux:0)
  658. """
  659. if stride is None:
  660. stride = kernel_size
  661. window_h, window_w = _expand_hw(kernel_size)
  662. stride_h, stride_w = _expand_hw(stride)
  663. padding_h, padding_w = _expand_hw(padding)
  664. op = builtin.Pooling(
  665. window_h=window_h,
  666. window_w=window_w,
  667. stride_h=stride_h,
  668. stride_w=stride_w,
  669. pad_h=padding_h,
  670. pad_w=padding_w,
  671. mode=mode,
  672. strategy=get_execution_strategy(),
  673. )
  674. (output,) = apply(op, inp)
  675. return output
  676. def adaptive_max_pool2d(
  677. inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor],
  678. ) -> Tensor:
  679. r"""Applies a 2D max adaptive pooling over an input.
  680. Refer to :class:`~.MaxAdaptivePool2d` for more information.
  681. Args:
  682. inp: input tensor.
  683. oshp: `(OH, OW)` size of the output shape.
  684. Returns:
  685. output tensor.
  686. """
  687. return adaptive_pool2d_cpp(inp, oshp, "MAX")
  688. def adaptive_avg_pool2d(
  689. inp: Tensor, oshp: Union[Tuple[int, int], int, Tensor],
  690. ) -> Tensor:
  691. r"""Applies a 2D average adaptive pooling over an input.
  692. Refer to :class:`~.AvgAdaptivePool2d` for more information.
  693. Args:
  694. inp: input tensor.
  695. oshp: `(OH, OW)` size of the output shape.
  696. Returns:
  697. output tensor.
  698. """
  699. return adaptive_pool2d_cpp(inp, oshp, "AVERAGE")
  700. def deformable_psroi_pooling(
  701. inp: Tensor,
  702. rois: Tensor,
  703. trans: Tensor,
  704. no_trans: bool,
  705. part_size: int,
  706. pooled_h: int,
  707. pooled_w: int,
  708. sample_per_part: int,
  709. spatial_scale: float,
  710. trans_std: float = 0.1,
  711. ):
  712. r"""Deformable PSROI(Position Sensitive Region of Interest) Pooling.
  713. Args:
  714. inp: input feature map.
  715. rois: the rois for feature pooling.
  716. trans: input offset to psroi_pooling.
  717. no_trans: check the phase of DeformablePSROIPooling. False to the
  718. 1st phase, True to the 2nd phase.
  719. part_size: part size.
  720. sample_per_part: sample points of each part.
  721. pooled_shape: kernel shape of convolution.
  722. spatial_scale: the spatial_scale w.r.t input image.
  723. trans_std: multiplier used in 2nd phase.
  724. """
  725. op = builtin.DeformablePSROIPooling(
  726. no_trans=no_trans,
  727. part_size=part_size,
  728. pooled_h=pooled_h,
  729. pooled_w=pooled_w,
  730. sample_per_part=sample_per_part,
  731. spatial_scale=spatial_scale,
  732. trans_std=trans_std,
  733. )
  734. output, _ = apply(op, inp, rois, trans)
  735. return output
  736. def hswish(x):
  737. r"""Element-wise `x * relu6(x + 3) / 6`.
  738. Example:
  739. >>> import numpy as np
  740. >>> x = Tensor(np.arange(5).astype(np.float32))
  741. >>> out = F.hswish(x)
  742. >>> out.numpy().round(decimals=4)
  743. array([0. , 0.6667, 1.6667, 3. , 4. ], dtype=float32)
  744. """
  745. return _elwise(x, mode=Elemwise.Mode.H_SWISH)
  746. def sigmoid(x):
  747. r"""Element-wise `1 / ( 1 + exp( -x ) )`."""
  748. return _elwise(x, mode=Elemwise.Mode.SIGMOID)
@lru_cache(maxsize=None)
def _get_hsigmoid_op(dtype=None, device=None):
    """Build the fused hard-sigmoid subgraph ``min(max(x + 3, 0), 6) / 6``.

    The built op is memoized per ``(dtype, device)`` by ``lru_cache``. The inner
    generator first yields the forward output, then receives the output gradient
    and yields the input gradient (``custom_grad=True``).
    """
    @subgraph_fn(
        "Hsigmoid",
        dtype=dtype,
        device=device,
        nr_inputs=1,
        jit_fusion=True,
        custom_grad=True,
    )
    def hsigmoid(inputs, f, c):
        (inp,) = inputs[0:1]
        # Forward: shift by 3, clamp to [0, 6], scale by 1/6.
        inp = f("+", inp, c(3))
        max_0 = f("max", inp, c(0))
        min_6 = f("min", max_0, c(6))
        oup = f("/", min_6, c(6))
        (oup_grad,) = yield (oup,)
        # Backward: slope 1/6 inside the clamp range, masked to 0 outside it.
        # NOTE(review): "cond_leq_mov"(a, b, g) presumably keeps g where a <= b and
        # gives 0 elsewhere, so both endpoints of [0, 6] keep the gradient — confirm.
        inp_grad = f("/", oup_grad, c(6))
        inp_grad = f("cond_leq_mov", max_0, c(6), inp_grad)
        inp_grad = f("cond_leq_mov", c(0), inp, inp_grad)
        yield (inp_grad,)
    return hsigmoid
  771. def hsigmoid(x):
  772. r"""Element-wise `relu6(x + 3) / 6`."""
  773. hsigmoid = _get_hsigmoid_op(x.dtype, x.device)
  774. (x,) = hsigmoid(x)
  775. return x
  776. # return relu6(x + 3) / 6
  777. def relu(x):
  778. r"""Element-wise `max(x, 0)`."""
  779. return _elwise(x, mode=Elemwise.Mode.RELU)
@lru_cache(maxsize=None)
def _get_relu6_op(dtype=None, device=None):
    """Build the fused ReLU6 subgraph ``min(max(x, 0), 6)`` with a custom gradient.

    The built op is memoized per ``(dtype, device)`` by ``lru_cache``. The inner
    generator yields the forward output, receives the output gradient, and
    yields the input gradient.
    """
    @subgraph_fn(
        "ReLU6",
        dtype=dtype,
        device=device,
        nr_inputs=1,
        jit_fusion=True,
        custom_grad=True,
    )
    def relu6(inputs, f, c):
        (inp,) = inputs[0:1]
        max_0 = f("max", inp, c(0))
        min_6 = f("min", max_0, c(6))
        oup = min_6
        (oup_grad,) = yield (oup,)
        # Backward: pass the gradient through only inside the clamp range.
        # NOTE(review): "cond_leq_mov"(a, b, g) presumably keeps g where a <= b
        # and gives 0 elsewhere — confirm against the Elemwise mode definition.
        inp_grad = f("cond_leq_mov", max_0, c(6), oup_grad)
        inp_grad = f("cond_leq_mov", c(0), inp, inp_grad)
        yield (inp_grad,)
    return relu6
  800. def relu6(x):
  801. r"""Element-wise `min(max(x, 0), 6)`."""
  802. relu6 = _get_relu6_op(x.dtype, x.device)
  803. (x,) = relu6(x)
  804. return x
  805. @lru_cache(maxsize=None)
  806. def _get_prelu_op(dtype=None, device=None):
  807. @subgraph_fn(
  808. "PReLU",
  809. dtype=dtype,
  810. device=device,
  811. nr_inputs=2,
  812. jit_fusion=True,
  813. custom_grad=True,
  814. )
  815. def prelu(inputs, f, c):
  816. (inp, weight) = inputs[0:2]
  817. max_0 = f("max", inp, c(0))
  818. min_0 = f("min", inp, c(0))
  819. oup = f("fma3", min_0, weight, max_0)
  820. (oup_grad,) = yield (oup,)
  821. inp_grad_0 = f("cond_leq_mov", c(0), inp, oup_grad)
  822. inp_grad_1 = f("*", oup_grad, weight)
  823. inp_grad_1 = f("cond_leq_mov", inp, c(0), inp_grad_1)
  824. inp_grad = f("+", inp_grad_0, inp_grad_1)
  825. weight_grad = f("*", oup_grad, min_0)
  826. yield (inp_grad, weight_grad)
  827. return prelu
  828. def prelu(inp: Tensor, weight: Tensor) -> Tensor:
  829. r"""Element-wise PReLU function.
  830. Refer to :class:`~.PReLU` for more information.
  831. """
  832. prelu = _get_prelu_op(dtype=inp.dtype, device=inp.device)
  833. (oup,) = prelu(inp, broadcast_to(weight, inp.shape))
  834. return oup
  835. @lru_cache(maxsize=None)
  836. def _get_leaky_relu_op(negative_slope, *, dtype=None, device=None):
  837. @subgraph_fn(
  838. "LeakyReLU",
  839. dtype=dtype,
  840. device=device,
  841. nr_inputs=1,
  842. jit_fusion=True,
  843. custom_grad=True,
  844. )
  845. def leakyReLU(inputs, f, c):
  846. (inp,) = inputs[0:1]
  847. max_0 = f("max", inp, c(0))
  848. min_0 = f("min", inp, c(0))
  849. oup = f("+", max_0, f("*", min_0, c(negative_slope)))
  850. (oup_grad,) = yield (oup,)
  851. inp_grad_0 = f("cond_leq_mov", c(0), inp, oup_grad)
  852. inp_grad_1 = f("*", oup_grad, c(negative_slope))
  853. inp_grad_1 = f("cond_leq_mov", inp, c(0), inp_grad_1)
  854. inp_grad = f("+", inp_grad_0, inp_grad_1)
  855. yield (inp_grad,)
  856. return leakyReLU
  857. def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor:
  858. r"""Element-wise LeakyReLU function
  859. Refer to :class:`~.LeakyReLU` for more information.
  860. """
  861. leakyReLU = _get_leaky_relu_op(negative_slope, dtype=inp.dtype, device=inp.device)
  862. (oup,) = leakyReLU(inp)
  863. return oup
  864. def silu(x):
  865. r"""Applies the element-wise Sigmoid Linear Unit function, i.e. `x * sigmoid(x)`."""
  866. return _elwise(x, mode=Elemwise.Mode.SILU)
  867. def gelu(x):
  868. r"""Applies the element-wise function:
  869. .. math::
  870. \text{gelu}(x) = x\Phi(x)
  871. where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.
  872. """
  873. return _elwise(x, mode=Elemwise.Mode.GELU)
  874. @lru_cache(maxsize=None)
  875. def _get_softplus_op(dtype=None, device=None):
  876. @subgraph_fn(
  877. "Softplus",
  878. dtype=dtype,
  879. device=device,
  880. nr_inputs=1,
  881. jit_fusion=True,
  882. custom_grad=True,
  883. )
  884. def softplus(inputs, f, c):
  885. (inp,) = inputs[0:1]
  886. neg_abs = f("-", f("abs", inp))
  887. exp = f("exp", neg_abs)
  888. oup0 = f("log1p", exp)
  889. oup1 = f("relu", inp)
  890. oup = f("+", oup0, oup1)
  891. (oup_grad,) = yield (oup,)
  892. inp_grad_0 = f("switch_gt0", oup1, oup_grad)
  893. inp_grad_1 = oup_grad
  894. inp_grad_1 = f("/", oup_grad, f("+", exp, c(1)))
  895. inp_grad_1 = f("*", inp_grad_1, exp)
  896. inp_grad_1 = f("-", inp_grad_1)
  897. inp_grad_1 = f("abs_grad", inp, inp_grad_1)
  898. inp_grad = f("+", inp_grad_0, inp_grad_1)
  899. yield (inp_grad,)
  900. return softplus
  901. def softplus(inp: Tensor) -> Tensor:
  902. r"""Applies the element-wise function:
  903. .. math::
  904. \text{softplus}(x) = \log(1 + \exp(x))
  905. softplus is a smooth approximation to the ReLU function and can be used
  906. to constrain the output to be always positive.
  907. For numerical stability the implementation follows this transformation:
  908. .. math::
  909. \text{softplus}(x) = \log(1 + \exp(x))
  910. = \log(1 + \exp(-\text{abs}(x))) + \max(x, 0)
  911. = \log1p(\exp(-\text{abs}(x))) + \text{relu}(x)
  912. Examples:
  913. >>> import numpy as np
  914. >>> x = Tensor(np.arange(-3, 3, dtype=np.float32))
  915. >>> y = F.softplus(x)
  916. >>> y.numpy().round(decimals=4)
  917. array([0.0486, 0.1269, 0.3133, 0.6931, 1.3133, 2.1269], dtype=float32)
  918. """
  919. softplus = _get_softplus_op(inp.dtype, inp.device)
  920. (oup,) = softplus(inp)
  921. return oup
  922. def logsoftmax(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
  923. r"""Applies the :math:`\log(\text{softmax}(x))` function to an n-dimensional
  924. input tensor. The :math:`\text{logsoftmax}(x)` formulation can be simplified as:
  925. .. math::
  926. \text{logsoftmax}(x_{i}) = \log(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} )
  927. For numerical stability the implementation follows this transformation:
  928. .. math::
  929. \text{logsoftmax}(x)
  930. = \log (\frac{\exp (x)}{\sum_{i}(\exp (x_{i}))})
  931. = x - \log (\sum_{i}(\exp (x_{i})))
  932. = x - \text{logsumexp}(x)
  933. Examples:
  934. >>> import numpy as np
  935. >>> x = Tensor(np.arange(-5, 5, dtype=np.float32)).reshape(2,5)
  936. >>> y = F.logsoftmax(x, axis=1)
  937. >>> y.numpy().round(decimals=4)
  938. array([[-4.4519, -3.4519, -2.4519, -1.4519, -0.4519],
  939. [-4.4519, -3.4519, -2.4519, -1.4519, -0.4519]], dtype=float32)
  940. """
  941. return inp - logsumexp(inp, axis, keepdims=True)
@lru_cache(maxsize=None)
def _get_logsigmoid_op(dtype=None, device=None):
    """Build the fused, numerically stable log-sigmoid subgraph with a custom gradient.

    The op is memoized per ``(dtype, device)`` by ``lru_cache``.
    Forward: ``-(log1p(exp(-|x|)) + relu(-x))``, i.e. ``-softplus(-x)``.
    """
    @subgraph_fn(
        "LogSigmoid",
        dtype=dtype,
        device=device,
        nr_inputs=1,
        jit_fusion=True,
        custom_grad=True,
    )
    def logsigmoid(inputs, f, c):
        (inp,) = inputs[0:1]
        neg_abs = f("-", f("abs", inp))
        exp = f("exp", neg_abs)
        oup0 = f("log1p", exp)
        oup1 = f("relu", f("-", inp))
        oup = f("+", oup0, oup1)
        oup = f("-", oup)
        (oup_grad,) = yield (oup,)
        # Negate once to account for the outer minus sign of the forward pass.
        oup_grad = f("-", oup_grad)
        # relu(-x) term: gradient is non-zero only where -x > 0, and its
        # inner derivative is -1, hence the second negation.
        inp_grad_0 = f("switch_gt0", oup1, oup_grad)
        inp_grad_0 = f("-", inp_grad_0)
        # log1p(exp(-|x|)) term: -exp(-|x|) / (1 + exp(-|x|)) times d|x|/dx,
        # where abs_grad supplies the d|x|/dx (sign) factor.
        inp_grad_1 = oup_grad
        inp_grad_1 = f("/", inp_grad_1, f("+", exp, c(1)))
        inp_grad_1 = f("*", inp_grad_1, exp)
        inp_grad_1 = f("-", inp_grad_1)
        inp_grad_1 = f("abs_grad", inp, inp_grad_1)
        inp_grad = f("+", inp_grad_0, inp_grad_1)
        yield (inp_grad,)
    return logsigmoid
  972. def logsigmoid(inp: Tensor) -> Tensor:
  973. r"""Applies the element-wise function:
  974. .. math::
  975. \text{logsigmoid}(x) = \log(\frac{ 1 }{ 1 + \exp(-x)})
  976. = \log(1/(1 + \exp(-x)))
  977. = - \log(1 + \exp(-x))
  978. = - \text{softplus}(-x)
  979. Examples:
  980. >>> import numpy as np
  981. >>> x = Tensor(np.arange(-5, 5, dtype=np.float32))
  982. >>> y = F.logsigmoid(x)
  983. >>> y.numpy().round(decimals=4)
  984. array([-5.0067, -4.0182, -3.0486, -2.1269, -1.3133, -0.6931, -0.3133,
  985. -0.1269, -0.0486, -0.0181], dtype=float32)
  986. """
  987. logsigmoid = _get_logsigmoid_op(inp.dtype, inp.device)
  988. (oup,) = logsigmoid(inp)
  989. return oup
  990. def logsumexp(
  991. inp: Tensor, axis: Union[int, Sequence[int]], keepdims: bool = False
  992. ) -> Tensor:
  993. r"""Calculates the logarithm of the inputs' exponential sum along the given :attr:`axis`.
  994. .. math::
  995. \text{logsumexp}(x)= \log \sum_{j=1}^{n} \exp \left(x_{j}\right)
  996. For numerical stability, the implementation follows this transformation:
  997. .. math::
  998. \text{logsumexp}(x)= \log \sum_{j=1}^{n} \exp \left(x_{j}\right)
  999. = \text{logsumexp}(x)=b+\log \sum_{j=1}^{n} \exp \left(x_{j}-b\right)
  1000. where
  1001. .. math::
  1002. b = \max(x_j)
  1003. Examples:
  1004. >>> import numpy as np
  1005. >>> x = Tensor(np.arange(-5, 5, dtype=np.float32)).reshape(2,5)
  1006. >>> y = F.logsumexp(x, axis=1, keepdims=False)
  1007. >>> y.numpy().round(decimals=4)
  1008. array([-0.5481, 4.4519], dtype=float32)
  1009. """
  1010. max_value = max(inp.detach(), axis, keepdims=True)
  1011. if keepdims:
  1012. return max_value + log(sum(exp(inp - max_value), axis, keepdims))
  1013. else:
  1014. return squeeze(max_value, axis=None) + log(
  1015. sum(exp(inp - max_value), axis, keepdims)
  1016. )
  1017. def _get_softmax_axis(ndim: int) -> int:
  1018. if ndim in (0, 1, 3):
  1019. return 0
  1020. return 1
  1021. def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
  1022. r"""Applies a :math:`\text{softmax}(x)` function. :math:`\text{softmax}(x)` is defined as:
  1023. .. math::
  1024. \text{softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
  1025. It is applied to all elements along axis, and rescales elements so that
  1026. they stay in the range `[0, 1]` and sum to 1.
  1027. See :class:`~.module.Softmax` for more details.
  1028. Examples:
  1029. >>> import numpy as np
  1030. >>> x = Tensor(np.arange(-5, 5, dtype=np.float32)).reshape(2,5)
  1031. >>> out = F.softmax(x)
  1032. >>> out.numpy().round(decimals=4)
  1033. array([[0.0117, 0.0317, 0.0861, 0.2341, 0.6364],
  1034. [0.0117, 0.0317, 0.0861, 0.2341, 0.6364]], dtype=float32)
  1035. """
  1036. if axis is None:
  1037. axis = _get_softmax_axis(len(inp.shape))
  1038. if isinstance(axis, list):
  1039. offset = inp.max(axis=axis, keepdims=True).detach()
  1040. cached = exp(inp - offset)
  1041. down = sum(cached, axis=axis, keepdims=True)
  1042. return cached / down
  1043. else:
  1044. op = builtin.Softmax(axis=axis,)
  1045. (output,) = apply(op, inp)
  1046. return output
  1047. def layer_norm(
  1048. inp: Tensor,
  1049. normalized_shape: tuple,
  1050. affine: bool,
  1051. weight: Optional[Tensor] = None,
  1052. bias: Optional[Tensor] = None,
  1053. eps: float = 1e-5,
  1054. ):
  1055. r"""Applies layer normalization to the input. Support tensor of any shape as input.
  1056. Reference: https://arxiv.org/pdf/1803.08494.pdf.
  1057. Args:
  1058. inp: input tensor.
  1059. normalized_shape: the shape that you want to be normalizated
  1060. affine: whether to use weight and bias
  1061. weight: must not be None when the affine is true
  1062. bias: must not be None when the affine is true
  1063. eps: a value added to the denominator for numerical stability. Default: 1e-5
  1064. """
  1065. if isinstance(normalized_shape, int):
  1066. normalized_shape = [normalized_shape]
  1067. normalized_dim = len(normalized_shape)
  1068. assert normalized_dim > 0
  1069. normalized_size = 1
  1070. for i in range(normalized_dim):
  1071. normalized_size = normalized_size * normalized_shape[i]
  1072. op = builtin.LayerNorm(
  1073. affine=affine,
  1074. eps=eps,
  1075. normalized_dim=normalized_dim,
  1076. normalized_size=normalized_size,
  1077. )
  1078. if affine:
  1079. assert weight is not None and bias is not None
  1080. return apply(op, inp, weight, bias)[0]
  1081. else:
  1082. # assert weight is None and bias is None
  1083. return apply(op, inp)[0]
  1084. def batch_norm(
  1085. inp: Tensor,
  1086. running_mean: Tensor = None,
  1087. running_var: Tensor = None,
  1088. weight: Optional[Tensor] = None,
  1089. bias: Optional[Tensor] = None,
  1090. *,
  1091. training: bool = False,
  1092. momentum: float = 0.9,
  1093. eps: float = 1e-5,
  1094. inplace: bool = True,
  1095. ):
  1096. r"""Applies batch normalization to the input.
  1097. Refer to :class:`~.BatchNorm2d` and :class:`~.BatchNorm1d` for more information.
  1098. Args:
  1099. inp: input tensor.
  1100. running_mean: tensor to store running mean.
  1101. running_var: tensor to store running variance.
  1102. weight: scaling tensor in the learnable affine parameters.
  1103. See :math:`\gamma` in :class:`~.BatchNorm2d`.
  1104. bias: bias tensor in the learnable affine parameters.
  1105. See :math:`\beta` in :class:`~.BatchNorm2d`.
  1106. training: a boolean value to indicate whether batch norm is performed
  1107. in training mode. Default: False
  1108. momentum: value used for the ``running_mean`` and ``running_var``
  1109. computation. Default: 0.9
  1110. eps: a value added to the denominator for numerical stability. Default: 1e-5
  1111. inplace: whether to update ``running_mean`` and ``running_var``
  1112. inplace or return new tensors. Default: True
  1113. compute_mode: When set to 'default', no special requirements will be
  1114. placed on the precision of intermediate results. When set to 'float32',
  1115. float32 would be used for accumulator and intermediate result, but only
  1116. effective when input and output are of float16 dtype.
  1117. param_dim: a value indicating in which format the parameters are.
  1118. Default: 'dim_1c11', which means NCHW format.
  1119. And 'dim_111c' means NHWC format.
  1120. """
  1121. def make_full_if_none(x, value):
  1122. x_ndim = None if x is None else x.ndim
  1123. # in general case, x will be returned here directly
  1124. if x_ndim is not None and x_ndim != 1:
  1125. return x
  1126. C = inp.shape[1]
  1127. pshape = (1, C, 1, 1)
  1128. if x is None:
  1129. x = Const(value, inp.dtype, inp.device)
  1130. shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
  1131. (result,) = apply(builtin.Broadcast(), x, shape)
  1132. result.format = inp.format
  1133. return result
  1134. else:
  1135. assert x_ndim == 1
  1136. shape = astensor1d(pshape, inp, dtype="int32", device=inp.device)
  1137. (result,) = apply(builtin.Reshape(), x, shape)
  1138. return result
  1139. has_mean = running_mean is not None
  1140. has_var = running_var is not None
  1141. if not training:
  1142. assert has_mean, "running_mean must be provided in inference mode"
  1143. assert has_var, "running_var must be provided in inference mode"
  1144. weight = make_full_if_none(weight, 1)
  1145. bias = make_full_if_none(bias, 0)
  1146. if not training:
  1147. op = builtin.BatchNorm(
  1148. fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim="dim_1c11"
  1149. )
  1150. ret = apply(op, inp, weight, bias, running_mean, running_var)[-1]
  1151. return ret
  1152. else:
  1153. op = builtin.BatchNorm(
  1154. avg_factor=1 - momentum, epsilon=eps, param_dim="dim_1c11"
  1155. )
  1156. if has_mean or has_var:
  1157. running_mean = make_full_if_none(running_mean, 0)
  1158. running_var = make_full_if_none(running_var, 1)
  1159. new_mean, new_var, *_, inp = apply(
  1160. op, inp, weight, bias, running_mean, running_var
  1161. )
  1162. if not has_mean:
  1163. new_mean = None
  1164. if not has_var:
  1165. new_var = None
  1166. if inplace:
  1167. if has_mean:
  1168. running_mean[...] = new_mean
  1169. if has_var:
  1170. running_var[...] = new_var
  1171. return inp
  1172. else:
  1173. return inp, new_mean, new_var
  1174. else:
  1175. inp = apply(op, inp, weight, bias)[-1]
  1176. return inp
@lru_cache(maxsize=None)
def _get_sync_bn_ops(device, dtype, eps_mode, ndim, channels):
    """Build (and memoize per argument tuple) the subgraphs used by ``sync_batch_norm``.

    Each subgraph returns ``(outputs, flags)``, where ``flags`` is a tuple of
    per-output booleans consumed by ``subgraph``. NOTE(review): their exact
    meaning is not visible here (presumably which outputs take part in
    autodiff) — confirm.

    Returns:
        ``(stage0, stage1, stage1_inference, stage2, concat_stats, split_stats)``.
    """
    # fmt: off
    # Stage 0: per-channel sum and sum of squares of the input, plus the element
    # count per channel (total elements / channels), cast to ``dtype``.
    @subgraph("SyncBnStage0", dtype, device, 1)
    def syncbn_stage0(inputs, f, c):
        input = inputs[0]
        reduce_shape = c((1, channels) + (1,) * (ndim - 2), dtype="int32", device=device)
        input_shape = f(GetVarShape(), input)
        input_elems = f(Reduce(mode="product", axis=0), input_shape)
        reduce_elems = f(Reduce(mode="product", axis=0), reduce_shape)
        reduce_size = f("//", input_elems, reduce_elems)
        channel_x1s = f(Reduce(mode="sum"), input, reduce_shape)
        channel_x2s = f(Reduce(mode="sum_sqr"), input, reduce_shape)
        reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size)
        return (reduce_shape, reduce_size_f, channel_x1s, channel_x2s), (False, False, True, True)

    # Stage 1 (training): mean = x1s / n, biased var = x2s / n - x1s^2 / n^2,
    # then out = input * w / sqrt(var (+|max) eps) + (bias - mean * w / sqrt(...)).
    @subgraph("SyncBnStage1", dtype, device, 7)
    def syncbn_stage1(inputs, f, c):
        input, reduce_size, channel_x1s, channel_x2s, eps = inputs[0:5]
        weight, bias = inputs[5:7]
        channel_mean = f("/", channel_x1s, reduce_size)
        channel_var =\
            f("+",  f("/",  f("**", channel_x1s, c(2)),
                            f("-",  f("*", reduce_size, reduce_size))),
                    f("/",  channel_x2s, reduce_size))
        # eps_mode is used directly as the elementwise op name combining var and eps.
        invsqrt_channel_var = f("**", f(eps_mode, channel_var, eps), c(-0.5))
        inv_var_wt = f("*", invsqrt_channel_var, weight)
        neg_channel_mean = f("-", channel_mean)
        outvar =\
            f("fma3",   input, inv_var_wt,
                        f("+",  f("*", neg_channel_mean, inv_var_wt),
                                bias))
        return (outvar, channel_mean, channel_var), (True, True, True)

    # Stage 1 (inference): same affine normalization using given mean/var.
    @subgraph("SyncBnStage1Inference", dtype, device, 6)
    def syncbn_stage1_inference(inputs, f, c):
        input, channel_mean, channel_var, eps = inputs[0:4]
        weight, bias = inputs[4:6]
        invsqrt_channel_var = f("**", f(eps_mode, channel_var, eps), c(-0.5))
        inv_var_wt = f("*", invsqrt_channel_var, weight)
        neg_channel_mean = f("-", channel_mean)
        outvar =\
            f("+",  f("*", input, inv_var_wt),
                    f("+",  f("*", neg_channel_mean, inv_var_wt),
                            bias))
        return (outvar,), (True,)

    # Stage 2: running = running * momentum + (1 - momentum) * batch_stat, where the
    # variance used is the unbiased one: x2s / (n - 1) - x1s^2 / (n * (n - 1)).
    @subgraph("SyncBnStage2", dtype, device, 7)
    def syncbn_stage2(inputs, f, c):
        running_mean, running_var, momentum = inputs[0:3]
        reduce_size, channel_x1s, channel_x2s, channel_mean = inputs[3:7]
        c1_minus_momentum = f("-", c(1), momentum)
        reduce_size_minus_c1 = f("-", reduce_size, c(1))
        running_mean = f("fma4",
            running_mean, momentum,
            c1_minus_momentum, channel_mean,
        )
        channel_variance_unbiased =\
            f("+",  f("/",  f("**", channel_x1s, c(2)),
                            f("*",  f("-", reduce_size),
                                    reduce_size_minus_c1)),
                    f("/",  channel_x2s,
                            reduce_size_minus_c1))
        running_var = f("fma4",
            running_var, momentum,
            c1_minus_momentum, channel_variance_unbiased
        )
        return (running_mean, running_var), (True, True)

    # Pack [count | x1s | x2s] along axis 1 so one all-reduce syncs all statistics.
    @subgraph("SyncBnConcatStats", dtype, device, 3)
    def syncbn_concat_stats(inputs, f, c):
        reduce_size, channel_x1s, channel_x2s = inputs[0:3]
        reduce_size = f(builtin.Broadcast(), reduce_size, c([1]*ndim, dtype="int32"))
        stats = f(builtin.Concat(axis=1, comp_node=device), reduce_size, channel_x1s, channel_x2s)
        return (stats,), (True,)

    # Inverse of concat_stats: slice axis 1 into [0:1], [1:channels+1], [channels+1:].
    @subgraph("SyncBnSplitStats", dtype, device, 1)
    def syncbn_split_stats(inputs, f, c):
        stats = inputs[0]
        c_1 = c(1, dtype="int32")
        channel_x1s_end = c(channels+1, dtype="int32")

        def _subtensor(src, axis, begin, end):
            # Single-axis slice src[begin:end] along `axis`. Note the trailing comma:
            # `items` is a one-element tuple of the item descriptor.
            items = (axis, (begin is not None), (end is not None), False, False),
            args = ()
            if begin is not None:
                args += begin,
            if end is not None:
                args += end,
            return f(builtin.Subtensor(items=items), src, *args)

        reduce_size = _subtensor(stats, 1, None, c_1)
        channel_x1s = _subtensor(stats, 1, c_1, channel_x1s_end)
        channel_x2s = _subtensor(stats, 1, channel_x1s_end, None)
        reduce_size = f(builtin.Reshape(), reduce_size, c_1)
        return (reduce_size, channel_x1s, channel_x2s), (False, True, True)
    # fmt: on

    return (
        syncbn_stage0,
        syncbn_stage1,
        syncbn_stage1_inference,
        syncbn_stage2,
        syncbn_concat_stats,
        syncbn_split_stats,
    )
  1275. def sync_batch_norm(
  1276. inp: Tensor,
  1277. running_mean: Tensor,
  1278. running_var: Tensor,
  1279. weight: Optional[Tensor] = None,
  1280. bias: Optional[Tensor] = None,
  1281. training: bool = False,
  1282. momentum: Union[float, Tensor] = 0.9,
  1283. eps: float = 1e-5,
  1284. eps_mode="additive",
  1285. group=WORLD,
  1286. ) -> Tensor:
  1287. r"""Applies synchronized batch normalization to the input.
  1288. Refer to :class:`~.BatchNorm2d` and :class:`~.BatchNorm1d` for more information.
  1289. Args:
  1290. inp: input tensor.
  1291. running_mean: tensor to store running mean.
  1292. running_var: tensor to store running variance.
  1293. weight: scaling tensor in the learnable affine parameters.
  1294. See :math:`\gamma` in :class:`~.BatchNorm2d`.
  1295. bias: bias tensor in the learnable affine parameters.
  1296. See :math:`\beta` in :class:`~.BatchNorm2d`.
  1297. training: a boolean value to indicate whether batch norm is performed
  1298. in traning mode. Default: False
  1299. momentum: value used for the ``running_mean`` and ``running_var``
  1300. computation. Default: 0.9
  1301. eps: a value added to the denominator for numerical stability.
  1302. Default: 1e-5
  1303. eps_mode: mode of calculation for eps, "max" or "additive".
  1304. Default: "additive"
  1305. group: communication group, caculate mean and variance between this group.
  1306. Default: :obj:`~megengine.distributed.WORLD`
  1307. """
  1308. _eps_mode = eps_mode.lower()
  1309. assert _eps_mode in {"max", "additive"}, "unknown eps_mode: {}".format(eps_mode)
  1310. if _eps_mode == "additive" and not (is_distributed() and training):
  1311. return batch_norm(
  1312. inp,
  1313. running_mean,
  1314. running_var,
  1315. weight,
  1316. bias,
  1317. training=training,
  1318. momentum=momentum,
  1319. eps=eps,
  1320. )
  1321. if amp._enabled:
  1322. inp, weight, bias, running_mean, running_var = cast_tensors(
  1323. inp, weight, bias, running_mean, running_var, promote=True
  1324. )
  1325. _channels = make_shape_tuple(inp.shape)[1]
  1326. _ndim = inp.ndim
  1327. _device = inp.device
  1328. _dtype = inp.dtype
  1329. if _ndim != 4:
  1330. raise NotImplementedError("sync_batch_norm for ndim != 4")
  1331. def _make_full_if_none(x, value):
  1332. if x is None:
  1333. x = Const(value, inp.dtype, _device)
  1334. (result,) = apply(builtin.Broadcast(), x, reduce_shape)
  1335. return result
  1336. elif x.ndim == 1:
  1337. (result,) = apply(builtin.Reshape(), x, reduce_shape)
  1338. return result
  1339. return x
  1340. (
  1341. syncbn_stage0,
  1342. syncbn_stage1,
  1343. syncbn_stage1_inference,
  1344. syncbn_stage2,
  1345. syncbn_concat_stats,
  1346. syncbn_split_stats,
  1347. ) = _get_sync_bn_ops(_device, _dtype, eps_mode, _ndim, _channels)
  1348. reduce_shape, reduce_size, channel_x1s, channel_x2s = apply(syncbn_stage0(), inp)
  1349. eps = convert_single_value(eps, dtype=inp.dtype, device=inp.device)
  1350. weight = _make_full_if_none(weight, 1)
  1351. bias = _make_full_if_none(bias, 0)
  1352. if training:
  1353. if is_distributed():
  1354. # reduce all nodes' data to calculate mean and variance
  1355. (stat,) = apply(
  1356. syncbn_concat_stats(), reduce_size, channel_x1s, channel_x2s
  1357. )
  1358. stat = all_reduce_sum(stat, group)
  1359. reduce_size, channel_x1s, channel_x2s = apply(syncbn_split_stats(), stat)
  1360. outvar, channel_mean, *_ = apply(
  1361. syncbn_stage1(),
  1362. inp,
  1363. reduce_size,
  1364. channel_x1s,
  1365. channel_x2s,
  1366. eps,
  1367. weight,
  1368. bias,
  1369. )
  1370. else:
  1371. assert running_var is not None and running_mean is not None
  1372. channel_mean = running_mean
  1373. channel_var = running_var
  1374. outvar, *_ = apply(
  1375. syncbn_stage1_inference(), inp, channel_mean, channel_var, eps, weight, bias
  1376. )
  1377. # outvar = output * weight + bias
  1378. # where output = inp * invsqrt_channel_variance + (
  1379. # -channel_mean * invsqrt_channel_variance
  1380. # )
  1381. # Manually expand output for gopt
  1382. if training and running_var is not None and running_mean is not None:
  1383. momentum = convert_single_value(momentum, dtype=inp.dtype, device=inp.device)
  1384. running_mean[...], running_var[...] = apply(
  1385. syncbn_stage2(),
  1386. running_mean,
  1387. running_var,
  1388. momentum,
  1389. reduce_size,
  1390. channel_x1s,
  1391. channel_x2s,
  1392. channel_mean,
  1393. )
  1394. if amp._enabled:
  1395. outvar = outvar.astype("float16")
  1396. return outvar
  1397. def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor:
  1398. r"""Returns a new tensor where each of the elements are randomly set to zero
  1399. with probability P = ``drop_prob``. Optionally rescale the output tensor if ``training`` is True.
  1400. Args:
  1401. inp: input tensor.
  1402. drop_prob: probability to drop (set to zero) a single element.
  1403. training: the default behavior of ``dropout`` during training is to rescale the output,
  1404. then it can be replaced by an :class:`~.module.identify.Identity` during inference. Default: True
  1405. Returns:
  1406. the ouput tensor
  1407. Examples:
  1408. >>> import numpy as np
  1409. >>> data = Tensor(np.ones(10000000, dtype=np.float32))
  1410. >>> out = F.nn.dropout(data, 1.0 / 3.0, training=True)
  1411. >>> assert not out.numpy().all()
  1412. >>> out = F.nn.dropout(data, 1.0 / 3.0, training=False)
  1413. >>> assert out.numpy().all()
  1414. >>> out.numpy()
  1415. array([1., 1., 1., ..., 1., 1., 1.], dtype=float32)
  1416. """
  1417. assert 0 <= drop_prob < 1
  1418. if not training or drop_prob == 0:
  1419. return inp
  1420. # model in training mode, e.g. model.train()
  1421. op = Dropout(drop_prob=drop_prob, seed=_get_global_rng_seed(), handle=0)
  1422. outputs = apply(op, inp)
  1423. return outputs[0]
  1424. def one_hot(inp: Tensor, num_classes: int) -> Tensor:
  1425. r"""Performs one-hot encoding for the input tensor.
  1426. Args:
  1427. inp: input tensor.
  1428. num_classes: number of classes denotes the last dimension of the output tensor.
  1429. Examples:
  1430. >>> import numpy as np
  1431. >>> x = Tensor(np.arange(1, 4, dtype=np.int32))
  1432. >>> F.one_hot(x, num_classes=4)
  1433. Tensor([[0 1 0 0]
  1434. [0 0 1 0]
  1435. [0 0 0 1]], dtype=int32, device=xpux:0)
  1436. """
  1437. zeros_tensor = zeros(
  1438. list(inp.shape) + [num_classes], dtype=inp.dtype, device=inp.device
  1439. )
  1440. ones_tensor = ones(list(inp.shape) + [1], dtype=inp.dtype, device=inp.device)
  1441. op = builtin.IndexingSetOneHot(axis=inp.ndim, ndim=inp.ndim)
  1442. (result,) = apply(op, zeros_tensor, inp, ones_tensor)
  1443. return result
  1444. def embedding(
  1445. inp: Tensor,
  1446. weight: Tensor,
  1447. padding_idx: Optional[int] = None,
  1448. max_norm: Optional[float] = None,
  1449. norm_type: Optional[float] = None,
  1450. ):
  1451. r"""Applies lookup table for embedding.
  1452. Args:
  1453. inp: tensor with indices.
  1454. weight: learnable weights which embeds from.
  1455. padding_idx: should be set to None, not supported now.
  1456. max_norm: should be set to None, not supported now.
  1457. norm_type: should be set to None, not supported now.
  1458. Refer to :class:`~.module.Embedding` for more information.
  1459. """
  1460. if padding_idx is not None:
  1461. raise ValueError("Not support padding_idx Now!")
  1462. if max_norm is not None or norm_type is not None:
  1463. raise ValueError("Not support weight normlization Now!")
  1464. dest_shp = list(inp.shape) + [weight.shape[-1]]
  1465. return weight[inp.reshape(-1)].reshape(dest_shp)
  1466. def indexing_one_hot(
  1467. src: Tensor, index: Tensor, axis: int = 1, keepdims=False
  1468. ) -> Tensor:
  1469. r"""One-hot indexing for some axes.
  1470. Args:
  1471. src: input tensor.
  1472. index: index tensor.
  1473. axis: axis on src for which values in index index. Default: 1
  1474. keepdims: whether not to remove the axis in result. Default: False
  1475. Examples:
  1476. >>> src = Tensor([[1.0, 2.0]])
  1477. >>> index = Tensor([0])
  1478. >>> val = F.indexing_one_hot(src, index)
  1479. >>> val.numpy()
  1480. array([1.], dtype=float32)
  1481. """
  1482. assert isinstance(src, Tensor), "src must be of Tensor type"
  1483. op = builtin.IndexingOneHot(axis=axis, ndim=src.ndim)
  1484. index = convert_single_value(index, dtype="int32", device=src.device)
  1485. (result,) = apply(op, src, index)
  1486. if not keepdims:
  1487. result = squeeze(result, axis)
  1488. return result
  1489. def sliding_window(
  1490. inp: Tensor,
  1491. kernel_size: Union[int, Tuple[int, int]],
  1492. padding: Union[int, Tuple[int, int]] = 0,
  1493. stride: Union[int, Tuple[int, int]] = 1,
  1494. dilation: Union[int, Tuple[int, int]] = 1,
  1495. ) -> Tensor:
  1496. r"""Extracts sliding local blocks from a batched input tensor.
  1497. Refer to :class:`~.module.sliding_window.SlidingWindow` for more information.
  1498. Args:
  1499. inp: input tensor.
  1500. kernel_size: size of the window.
  1501. padding: implicit zero padding added on both sides of input. Default: 0
  1502. stride: stride of the window. Default: 1
  1503. dilation: dilation of the window. Default: 1
  1504. """
  1505. padding_h, padding_w = _expand_hw(padding)
  1506. stride_h, stride_w = _expand_hw(stride)
  1507. dilation_h, dilation_w = _expand_hw(dilation)
  1508. window_h, window_w = _expand_hw(kernel_size)
  1509. op = builtin.Images2Neibs(
  1510. pad_h=padding_h,
  1511. pad_w=padding_w,
  1512. stride_h=stride_h,
  1513. stride_w=stride_w,
  1514. dilate_h=dilation_h,
  1515. dilate_w=dilation_w,
  1516. window_h=window_h,
  1517. window_w=window_w,
  1518. )
  1519. (output,) = apply(op, inp)
  1520. return output
  1521. def sliding_window_transpose(
  1522. inp: Tensor,
  1523. output_size: Union[int, Tuple[int, int]],
  1524. kernel_size: Union[int, Tuple[int, int]],
  1525. padding: Union[int, Tuple[int, int]] = 0,
  1526. stride: Union[int, Tuple[int, int]] = 1,
  1527. dilation: Union[int, Tuple[int, int]] = 1,
  1528. ) -> Tensor:
  1529. r"""Sum over the sliding windows on the corresponding input location.
  1530. Refer to :class:`~.module.sliding_window.SlidingWindowTranspose` for more information.
  1531. Args:
  1532. inp: input tensor.
  1533. output_size: shape of output tensor.
  1534. kernel_size: size of the window.
  1535. padding: implicit zero padding added on both sides of input. Default: 0
  1536. stride: stride of the window. Default: 1
  1537. dilation: dilation of the window. Default: 1
  1538. """
  1539. output_h, output_w = _expand_hw(output_size)
  1540. padding_h, padding_w = _expand_hw(padding)
  1541. stride_h, stride_w = _expand_hw(stride)
  1542. dilation_h, dilation_w = _expand_hw(dilation)
  1543. window_h, window_w = _expand_hw(kernel_size)
  1544. expected_h = (
  1545. output_h + 2 * padding_h - dilation_h * (window_h - 1) - 1
  1546. ) // stride_h + 1
  1547. expected_w = (
  1548. output_w + 2 * padding_w - dilation_w * (window_w - 1) - 1
  1549. ) // stride_w + 1
  1550. assert inp.ndim == 6, "the input dimension of sliding_window_transpose should be 6"
  1551. assert (
  1552. inp.shape[2] == expected_h and inp.shape[3] == expected_w
  1553. ), "the input shape and output size do not match"
  1554. op = builtin.SlidingWindowTranspose(
  1555. out_h=output_h,
  1556. out_w=output_w,
  1557. pad_h=padding_h,
  1558. pad_w=padding_w,
  1559. stride_h=stride_h,
  1560. stride_w=stride_w,
  1561. dilate_h=dilation_h,
  1562. dilate_w=dilation_w,
  1563. window_h=window_h,
  1564. window_w=window_w,
  1565. )
  1566. (output,) = apply(op, inp)
  1567. return output
  1568. def pad(
  1569. src: Tensor,
  1570. pad_width: Tuple[Tuple[int, int], ...],
  1571. mode: str = "constant",
  1572. constant_value: float = 0.0,
  1573. ) -> Tensor:
  1574. r"""Pads the input tensor.
  1575. Args:
  1576. pad_width: A tuple. Each element in the tuple is the tuple of 2-elements,
  1577. the 2 elements represent the padding size on both sides of the current dimension, ``(front_offset, back_offset)``
  1578. mode: One of the following string values. Default: ``'constant'``
  1579. * ``'constant'``: Pads with a constant value.
  1580. * ``'reflect'``: Pads with the reflection of the tensor mirrored on the first and last values of the tensor along each axis.
  1581. * ``'replicate'``: Pads with the edge values of tensor.
  1582. constant_val: Fill value for ``'constant'`` padding. Default: 0
  1583. Examples:
  1584. >>> import numpy as np
  1585. >>> inp = Tensor([[1., 2., 3.],[4., 5., 6.]])
  1586. >>> inp
  1587. Tensor([[1. 2. 3.]
  1588. [4. 5. 6.]], device=xpux:0)
  1589. >>> F.nn.pad(inp, pad_width=((1, 1),), mode="constant")
  1590. Tensor([[0. 0. 0.]
  1591. [1. 2. 3.]
  1592. [4. 5. 6.]
  1593. [0. 0. 0.]], device=xpux:0)
  1594. >>> F.nn.pad(inp, pad_width=((1, 1),), mode="constant", constant_value=9)
  1595. Tensor([[9. 9. 9.]
  1596. [1. 2. 3.]
  1597. [4. 5. 6.]
  1598. [9. 9. 9.]], device=xpux:0)
  1599. >>> F.nn.pad(inp, pad_width=((1, 1), (1, 2)), mode="reflect")
  1600. Tensor([[5. 4. 5. 6. 5. 4.]
  1601. [2. 1. 2. 3. 2. 1.]
  1602. [5. 4. 5. 6. 5. 4.]
  1603. [2. 1. 2. 3. 2. 1.]], device=xpux:0)
  1604. >>> F.nn.pad(inp, pad_width=((1, 1), (1, 2)), mode="replicate")
  1605. Tensor([[1. 1. 2. 3. 3. 3.]
  1606. [1. 1. 2. 3. 3. 3.]
  1607. [4. 4. 5. 6. 6. 6.]
  1608. [4. 4. 5. 6. 6. 6.]], device=xpux:0)
  1609. """
  1610. p_offsets = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
  1611. assert mode.lower() in ["constant", "edge", "replicate", "reflect"]
  1612. if mode.lower() == "edge":
  1613. mode = "replicate"
  1614. for i in range(0, len(pad_width)):
  1615. p_offsets[i * 2] = pad_width[i][0]
  1616. p_offsets[i * 2 + 1] = pad_width[i][1]
  1617. op = builtin.Padding(
  1618. front_offset_dim0=p_offsets[0],
  1619. front_offset_dim1=p_offsets[2],
  1620. front_offset_dim2=p_offsets[4],
  1621. front_offset_dim3=p_offsets[6],
  1622. front_offset_dim4=p_offsets[8],
  1623. front_offset_dim5=p_offsets[10],
  1624. front_offset_dim6=p_offsets[12],
  1625. back_offset_dim0=p_offsets[1],
  1626. back_offset_dim1=p_offsets[3],
  1627. back_offset_dim2=p_offsets[5],
  1628. back_offset_dim3=p_offsets[7],
  1629. back_offset_dim4=p_offsets[9],
  1630. back_offset_dim5=p_offsets[11],
  1631. back_offset_dim6=p_offsets[13],
  1632. padding_val=constant_value,
  1633. padding_mode=mode.upper(),
  1634. )
  1635. (output,) = apply(op, src)
  1636. return output
  1637. def local_response_norm(
  1638. inp: Tensor,
  1639. kernel_size: int = 5,
  1640. k: float = 2.0,
  1641. alpha: float = 1e-4,
  1642. beta: float = 0.75,
  1643. ) -> Tensor:
  1644. r"""
  1645. Apply local response normalization to the input tensor.
  1646. Args:
  1647. kernel_size: the size of the kernel to apply LRN on.
  1648. k: hyperparameter k. The default vaule is 2.0.
  1649. alpha: hyperparameter alpha. The default value is 1e-4.
  1650. beta: hyperparameter beta. The default value is 0.75.
  1651. Example:
  1652. >>> import numpy as np
  1653. >>> inp = Tensor(np.arange(25, dtype=np.float32).reshape(1,1,5,5))
  1654. >>> GT = np.array([[[[ 0., 0.999925, 1.9994003, 2.9979765, 3.9952066],
  1655. ... [ 4.9906454, 5.983851, 6.974385, 7.961814, 8.945709 ],
  1656. ... [ 9.925651, 10.90122, 11.872011, 12.837625, 13.7976675],
  1657. ... [14.751757, 15.699524, 16.640602, 17.574642, 18.501305 ],
  1658. ... [19.420258, 20.331186, 21.233786, 22.127764, 23.012836 ]]]])
  1659. >>> out = F.local_response_norm(inp, kernel_size=3, k=1.0, alpha=1e-4, beta=0.75)
  1660. >>> np.testing.assert_allclose(GT, out.numpy(), rtol=1e-6, atol=1e-6)
  1661. """
  1662. op = builtin.LRN(n=kernel_size, k=k, alpha=alpha, beta=beta,)
  1663. (output,) = apply(op, inp)
  1664. return output
  1665. @lru_cache(maxsize=None)
  1666. def _get_layerPixelShuffle(device, dtype, dim_order):
  1667. @subgraph("LayerPixelShuffle", dtype, device, 3)
  1668. def layerPixelShuffle(inputs, f, c):
  1669. inp, shape_0, shape_1 = inputs
  1670. inp = f(Reshape(), inp, shape_0)
  1671. inp = f(Dimshuffle(dim_order), inp)
  1672. oup = f(Reshape(), inp, shape_1)
  1673. return (oup,), (True,)
  1674. return layerPixelShuffle
  1675. def _layerPixelShuffle_traceable(inp, upscale_factor):
  1676. assert upscale_factor > 0, "upscale_factor should larger than 0"
  1677. assert inp.ndim >= 3, "the input dimension of pixel_shuffle should be larger than 3"
  1678. assert (
  1679. inp.shape[-3] % (upscale_factor ** 2) == 0
  1680. ), "the -3 dimension should be divided by (upscale_factor ** 2)"
  1681. _device = inp.device
  1682. _dtype = inp.dtype
  1683. shape_ori = inp.shape
  1684. high_dim = shape_ori[:-3]
  1685. square = upscale_factor ** 2
  1686. n = 1
  1687. for item in high_dim:
  1688. n *= item
  1689. shape_0 = (
  1690. n,
  1691. int(shape_ori[-3] / square),
  1692. upscale_factor,
  1693. upscale_factor,
  1694. shape_ori[-2],
  1695. shape_ori[-1],
  1696. )
  1697. shape_1 = (
  1698. *high_dim,
  1699. int(shape_ori[-3] / square),
  1700. shape_ori[-2] * upscale_factor,
  1701. shape_ori[-1] * upscale_factor,
  1702. )
  1703. dim_order = (0, 1, 4, 2, 5, 3)
  1704. layerPixelShuffle = _get_layerPixelShuffle(_device, _dtype, dim_order)
  1705. shape_0 = convert_single_value(shape_0, device=inp.device)
  1706. shape_1 = convert_single_value(shape_1, device=inp.device)
  1707. outvar, *_ = apply(layerPixelShuffle(), inp, shape_0, shape_1)
  1708. return outvar
  1709. def pixel_shuffle(inp: Tensor, upscale_factor: int) -> Tensor:
  1710. """
  1711. Rearranges elements in a tensor of shape `(..., C * r^2, H, W)` to a tensor of
  1712. shape `(..., C, H * r, W * r)`, where `r` is an upscale factor, where `...` is
  1713. zero or more batch dimensions.
  1714. :param inp: input tensor.
  1715. :param upscale_factor: upscale factor of pixel_shuffle.
  1716. :return: output tensor.
  1717. """
  1718. return pixel_shuffle_cpp(inp, upscale_factor, _layerPixelShuffle_traceable)
  1719. def region_restricted_conv(
  1720. inp: Tensor,
  1721. weight: Tensor,
  1722. rin: Tensor,
  1723. rout: Tensor,
  1724. bias: Optional[Tensor] = None,
  1725. stride: Union[int, Tuple[int, int, int]] = 1,
  1726. padding: Union[int, Tuple[int, int, int]] = 0,
  1727. dilation: Union[int, Tuple[int, int, int]] = 1,
  1728. groups: int = 1,
  1729. conv_mode: str = "cross_correlation",
  1730. compute_mode="default",
  1731. ) -> Tensor:
  1732. r"""Region Restricted convolution operation.
  1733. Refer to :class:`~.RegionRestrictedConv` for more information.
  1734. Args:
  1735. inp: feature map of the convolution operation.
  1736. weight: convolution kernel.
  1737. rin: input mask
  1738. rout: output mask
  1739. bias: bias added to the result of convolution (if given).
  1740. stride: stride of the 2D region restricted convolution operation. Default: 1
  1741. padding: size of the paddings added to the input on both sides of its
  1742. spatial dimensions. Only zero-padding is supported. Default: 0
  1743. dilation: dilation of the 2D convolution operation. Default: 1
  1744. groups: number of groups into which the input and output channels are divided,
  1745. so as to perform a ``grouped convolution``. When ``groups`` is not 1,
  1746. ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
  1747. and the shape of weight should be ``(groups, out_channel // groups,
  1748. in_channels // groups, depth, height, width)``. Default: 1
  1749. conv_mode: supports "cross_correlation". Default: "cross_correlation"
  1750. Returns:
  1751. output tensor.
  1752. """
  1753. assert conv_mode.lower() == "cross_correlation"
  1754. pad_h, pad_w = _expand_hw(padding)
  1755. stride_h, stride_w = _expand_hw(stride)
  1756. dilate_h, dilate_w = _expand_hw(dilation)
  1757. sparse_type = "dense" if groups == 1 else "group"
  1758. op = builtin.RegionRestrictedConvolution(
  1759. stride_h=stride_h,
  1760. stride_w=stride_w,
  1761. pad_h=pad_h,
  1762. pad_w=pad_w,
  1763. dilate_h=dilate_h,
  1764. dilate_w=dilate_w,
  1765. mode=conv_mode,
  1766. compute_mode=compute_mode,
  1767. sparse=sparse_type,
  1768. )
  1769. (output,) = apply(op, inp, weight, rin, rout)
  1770. if bias is not None:
  1771. output += bias
  1772. return output
  1773. from .quantized import conv_bias_activation # isort:skip
  1774. from .loss import * # isort:skip
  1775. from .vision import * # isort:skip