You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

math.py 32 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import math
  11. from functools import lru_cache
  12. from typing import Iterable, Optional, Sequence, Tuple, Union
  13. from ..core import _config
  14. from ..core._imperative_rt.core2 import apply, dtype_promotion
  15. from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
  16. from ..core._trace_option import use_symbolic_shape
  17. from ..core.ops import builtin
  18. from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt
  19. from ..core.ops.special import Const
  20. from ..core.tensor import amp
  21. from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar, subgraph
  22. from ..jit import exclude_from_trace
  23. from ..tensor import Tensor
  24. from ..utils.deprecation import deprecated_kwargs_default
  25. from .debug_param import get_execution_strategy
  26. from .elemwise import clip, minimum
  27. from .tensor import broadcast_to, concat, expand_dims, squeeze
# Names exported by ``from <this module> import *``; every entry is a
# function defined in this module.
__all__ = [
    "argmax",
    "argmin",
    "argsort",
    "dot",
    "isinf",
    "isnan",
    "matinv",
    "matmul",
    "max",
    "mean",
    "min",
    "norm",
    "normalize",
    "prod",
    "sign",
    "sort",
    "std",
    "sum",
    "svd",
    "topk",
    "var",
]
  51. def isnan(inp: Tensor) -> Tensor:
  52. r"""Returns a new tensor representing if each element is ``NaN`` or not.
  53. Args:
  54. inp: input tensor.
  55. Returns:
  56. result tensor.
  57. Examples:
  58. .. testcode::
  59. from megengine import tensor
  60. import megengine.functional as F
  61. x = tensor([1, float("nan"), 0])
  62. print(F.isnan(x).numpy())
  63. Outputs:
  64. .. testoutput::
  65. [False True False]
  66. """
  67. return inp != inp
  68. def isinf(inp: Tensor) -> Tensor:
  69. r"""Returns a new tensor representing if each element is ``Inf`` or not.
  70. Args:
  71. inp: input tensor.
  72. Returns:
  73. result tensor.
  74. Examples:
  75. .. testcode::
  76. from megengine import tensor
  77. import megengine.functional as F
  78. x = tensor([1, float("inf"), 0])
  79. print(F.isinf(x).numpy())
  80. Outputs:
  81. .. testoutput::
  82. [False True False]
  83. """
  84. return abs(inp).astype("float32") == float("inf")
  85. def sign(inp: Tensor):
  86. r"""Returns a new tensor representing the sign of each element in input tensor.
  87. Args:
  88. inp: Tensor:
  89. Returns:
  90. the sign of input tensor.
  91. Examples:
  92. .. testcode::
  93. from megengine import tensor
  94. import megengine.functional as F
  95. x = tensor([1, -1, 0])
  96. print(F.sign(x).numpy())
  97. Outputs:
  98. .. testoutput::
  99. [ 1 -1 0]
  100. """
  101. return (inp > 0).astype(inp.dtype) - (inp < 0).astype(inp.dtype)
  102. def sum(
  103. inp: Tensor,
  104. axis: Optional[Union[int, Sequence[int]]] = None,
  105. keepdims: bool = False,
  106. ) -> Tensor:
  107. r"""Returns the sum of input tensor along given axis. If axis is a list of dimensions,
  108. reduce over all of them.
  109. Args:
  110. inp: input tensor.
  111. axis: dimension to reduce. If None, all dimensions will be reduced.
  112. Default: None
  113. keepdims: whether the output tensor has axis retained or not.
  114. Default: False
  115. Returns:
  116. output tensor.
  117. Examples:
  118. .. testcode::
  119. import numpy as np
  120. from megengine import tensor
  121. import megengine.functional as F
  122. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  123. out = F.sum(x)
  124. print(out.numpy())
  125. Outputs:
  126. .. testoutput::
  127. 21
  128. """
  129. return inp.sum(axis=axis, keepdims=keepdims)
  130. def prod(
  131. inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None, keepdims=False
  132. ) -> Tensor:
  133. r"""Returns the product of input tensor along given axis. If axis is a list of dimensions,
  134. reduce over all of them.
  135. Args:
  136. inp: input tensor.
  137. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  138. keepdims: whether the output tensor has axis retained or not. Default: False
  139. Returns:
  140. output tensor.
  141. Examples:
  142. .. testcode::
  143. import numpy as np
  144. from megengine import tensor
  145. import megengine.functional as F
  146. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  147. out = F.prod(x)
  148. print(out.numpy())
  149. Outputs:
  150. .. testoutput::
  151. 720
  152. """
  153. return inp.prod(axis=axis, keepdims=keepdims)
  154. def mean(
  155. inp: Tensor,
  156. axis: Optional[Union[int, Sequence[int]]] = None,
  157. keepdims: bool = False,
  158. ) -> Tensor:
  159. r"""Returns the mean value of input tensor along
  160. given axis. If axis is a list of dimensions,
  161. reduce over all of them.
  162. Args:
  163. inp: input tensor.
  164. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  165. keepdims: whether the output tensor has axis retained or not. Default: False
  166. Returns:
  167. output tensor.
  168. Examples:
  169. .. testcode::
  170. import numpy as np
  171. from megengine import tensor
  172. import megengine.functional as F
  173. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  174. out = F.mean(x)
  175. print(out.numpy())
  176. Outputs:
  177. .. testoutput::
  178. 3.5
  179. """
  180. return inp.mean(axis=axis, keepdims=keepdims)
  181. def var(
  182. inp: Tensor,
  183. axis: Optional[Union[int, Sequence[int]]] = None,
  184. keepdims: bool = False,
  185. ) -> Tensor:
  186. r"""Returns the variance value of input tensor along
  187. given axis. If axis is a list of dimensions,
  188. reduce over all of them.
  189. Args:
  190. inp: input tensor.
  191. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  192. keepdims: whether the output tensor has axis retained or not. Default: False
  193. Returns:
  194. output tensor.
  195. Examples:
  196. .. testcode::
  197. import numpy as np
  198. from megengine import tensor
  199. import megengine.functional as F
  200. data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
  201. out = F.var(data)
  202. print(out.numpy().round(decimals=4))
  203. Outputs:
  204. .. testoutput::
  205. 2.9167
  206. """
  207. if axis is None:
  208. m = mean(inp, axis=axis, keepdims=False)
  209. else:
  210. m = mean(inp, axis=axis, keepdims=True)
  211. v = inp - m
  212. return mean(v ** 2, axis=axis, keepdims=keepdims)
  213. def std(
  214. inp: Tensor,
  215. axis: Optional[Union[int, Sequence[int]]] = None,
  216. keepdims: bool = False,
  217. ) -> Tensor:
  218. r"""Returns the standard deviation of input tensor along
  219. given axis. If axis is a list of dimensions,
  220. reduce over all of them.
  221. Args:
  222. inp: input tensor.
  223. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  224. keepdims: whether the output tensor has axis retained or not. Default: False
  225. Returns:
  226. output tensor.
  227. Examples:
  228. .. testcode::
  229. import numpy as np
  230. from megengine import tensor
  231. import megengine.functional as F
  232. data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
  233. out = F.std(data, axis=1)
  234. print(out.numpy().round(decimals=4))
  235. Outputs:
  236. .. testoutput::
  237. [0.8165 0.8165]
  238. """
  239. return var(inp, axis=axis, keepdims=keepdims) ** 0.5
  240. def min(
  241. inp: Tensor,
  242. axis: Optional[Union[int, Sequence[int]]] = None,
  243. keepdims: bool = False,
  244. ) -> Tensor:
  245. r"""Returns the min value of input tensor along
  246. given axis. If axis is a list of dimensions,
  247. reduce over all of them.
  248. Args:
  249. inp: input tensor.
  250. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  251. keepdims: whether the output tensor has axis retained or not. Default: False
  252. Returns:
  253. output tensor.
  254. Examples:
  255. .. testcode::
  256. import numpy as np
  257. from megengine import tensor
  258. import megengine.functional as F
  259. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  260. out = F.min(x)
  261. print(out.numpy())
  262. Outputs:
  263. .. testoutput::
  264. 1
  265. """
  266. return inp.min(axis=axis, keepdims=keepdims)
  267. def max(
  268. inp: Tensor,
  269. axis: Optional[Union[int, Sequence[int]]] = None,
  270. keepdims: bool = False,
  271. ) -> Tensor:
  272. r"""Returns the max value of the input tensor along
  273. given axis. If axis is a list of dimensions,
  274. reduce over all of them.
  275. Args:
  276. inp: input tensor.
  277. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  278. keepdims: whether the output tensor has axis retained or not. Default: False
  279. Returns:
  280. output tensor.
  281. Examples:
  282. .. testcode::
  283. import numpy as np
  284. from megengine import tensor
  285. import megengine.functional as F
  286. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  287. out = F.max(x)
  288. print(out.numpy())
  289. Outputs:
  290. .. testoutput::
  291. 6
  292. """
  293. return inp.max(axis=axis, keepdims=keepdims)
  294. def norm(
  295. inp: Tensor, ord: float = None, axis: int = None, keepdims=False,
  296. ):
  297. r"""Calculates ``p``-norm of input tensor along
  298. given axis.
  299. Args:
  300. inp: input tensor.
  301. ord: power of value applied to inp. Default: 2
  302. axis: dimension to reduce. If None, input must be a vector. Default: None
  303. keepdims: whether the output tensor has axis retained or not. Default: False
  304. Returns:
  305. output tensor.
  306. Examples:
  307. .. testcode::
  308. import numpy as np
  309. from megengine import tensor
  310. import megengine.functional as F
  311. x = tensor(np.arange(-3, 3, dtype=np.float32))
  312. out = F.norm(x)
  313. print(out.numpy().round(decimals=4))
  314. Outputs:
  315. .. testoutput::
  316. 4.3589
  317. """
  318. if axis is None:
  319. if inp.ndim != 1:
  320. raise TypeError("axis is required unless input is a vector")
  321. if ord is None:
  322. ord = 2
  323. if ord == 0:
  324. return sum(inp != 0, axis=axis, keepdims=keepdims)
  325. if ord == math.inf:
  326. return max(abs(inp))
  327. if ord == -math.inf:
  328. return min(abs(inp))
  329. return sum(abs(inp) ** ord, axis=axis, keepdims=keepdims) ** (1.0 / ord)
  330. def argmin(
  331. inp: Tensor,
  332. axis: Optional[Union[int, Sequence[int]]] = None,
  333. keepdims: bool = False,
  334. ) -> Tensor:
  335. r"""Returns the indices of the minimum values along
  336. given axis. If axis is a list of dimensions,
  337. reduce over all of them.
  338. Args:
  339. inp: input tensor.
  340. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  341. keepdims: whether the output tensor has axis retained or not. Default: False
  342. Returns:
  343. output tensor.
  344. Examples:
  345. .. testcode::
  346. import numpy as np
  347. from megengine import tensor
  348. import megengine.functional as F
  349. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  350. out = F.argmin(x)
  351. print(out.numpy())
  352. Outputs:
  353. .. testoutput::
  354. 0
  355. """
  356. if axis is None:
  357. assert not keepdims, "can not set axis=None and keepdims=True"
  358. inp = inp.flatten()
  359. axis = 0
  360. axis = _normalize_axis(inp.ndim, axis, reverse=True)
  361. if isinstance(axis, collections.abc.Iterable):
  362. for ai in axis:
  363. op = builtin.Argmin(axis=ai)
  364. (inp,) = apply(op, inp)
  365. if not keepdims:
  366. inp = squeeze(inp, ai)
  367. return inp
  368. op = builtin.Argmin(axis=axis)
  369. (result,) = apply(op, inp)
  370. if not keepdims:
  371. result = squeeze(result, axis)
  372. return result
  373. def argmax(
  374. inp: Tensor,
  375. axis: Optional[Union[int, Sequence[int]]] = None,
  376. keepdims: bool = False,
  377. ) -> Tensor:
  378. r"""Returns the indices of the maximum values along
  379. given axis. If axis is a list of dimensions,
  380. reduce over all of them.
  381. Args:
  382. inp: input tensor.
  383. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  384. keepdims: whether the output tensor has axis retained or not. Default: False
  385. Returns:
  386. output tensor.
  387. Examples:
  388. .. testcode::
  389. import numpy as np
  390. from megengine import tensor
  391. import megengine.functional as F
  392. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  393. out = F.argmax(x)
  394. print(out.numpy())
  395. Outputs:
  396. .. testoutput::
  397. 5
  398. """
  399. if axis is None:
  400. assert not keepdims, "can not set axis=None and keepdims=True"
  401. inp = inp.flatten()
  402. axis = 0
  403. axis = _normalize_axis(inp.ndim, axis, reverse=True)
  404. if isinstance(axis, collections.abc.Iterable):
  405. for ai in axis:
  406. op = builtin.Argmax(axis=ai)
  407. (inp,) = apply(op, inp)
  408. if not keepdims:
  409. inp = squeeze(inp, ai)
  410. return inp
  411. op = builtin.Argmax(axis=axis)
  412. (result,) = apply(op, inp)
  413. if not keepdims:
  414. result = squeeze(result, axis)
  415. return result
  416. def normalize(
  417. inp: Tensor, ord: float = None, axis: int = None, eps: float = 1e-12,
  418. ) -> Tensor:
  419. r"""Performs :math:`L_p` normalization of input tensor along
  420. given axis.
  421. For a tensor of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
  422. :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`axis` is transformed as:
  423. .. math::
  424. v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.
  425. Args:
  426. inp: input tensor.
  427. ord: power of value applied to input tensor. Default: 2
  428. axis: dimension to reduce.If None, input must be a vector. Default: None
  429. eps: a small value to avoid division by zero. Default: 1e-12
  430. Returns:
  431. normalized output tensor.
  432. """
  433. if axis is None:
  434. return inp / clip(norm(inp, ord, axis), lower=eps)
  435. else:
  436. return inp / clip(norm(inp, ord, axis, keepdims=True), lower=eps)
  437. def argsort(inp: Tensor, descending: bool = False) -> Tensor:
  438. r"""Returns the indices that would sort the input tensor.
  439. Args:
  440. inp: input tensor. If it's 2d, the result would be array of indices show how to sort each row in the input tensor.
  441. descending: sort in descending order, where the largest comes first. Default: False
  442. inp: Tensor:
  443. descending: bool:
  444. Returns:
  445. indices of int32 indicates how to sort the input.
  446. Examples:
  447. .. testcode::
  448. import numpy as np
  449. from megengine import tensor
  450. import megengine.functional as F
  451. x = tensor(np.array([1,2], dtype=np.float32))
  452. indices = F.argsort(x)
  453. print(indices.numpy())
  454. Outputs:
  455. .. testoutput::
  456. [0 1]
  457. """
  458. assert len(inp.shape) <= 2, "Input should be 1d or 2d"
  459. if descending:
  460. order = "descending"
  461. else:
  462. order = "ascending"
  463. op = builtin.Argsort(order=order)
  464. if len(inp.shape) == 1:
  465. inp = inp.reshape(1, -1)
  466. _, result = apply(op, inp)
  467. return result[0]
  468. _, result = apply(op, inp)
  469. return result
  470. def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]:
  471. r"""Returns sorted tensor and the indices would sort the input tensor.
  472. Args:
  473. inp: input tensor. If it's 2d, the result would be sorted by row.
  474. descending: sort in descending order, where the largest comes first. Default: False
  475. Returns:
  476. tuple of two tensors `(sorted_tensor, indices_of_int32)`.
  477. Examples:
  478. .. testcode::
  479. import numpy as np
  480. from megengine import tensor
  481. import megengine.functional as F
  482. x = tensor(np.array([1,2], dtype=np.float32))
  483. out, indices = F.sort(x)
  484. print(out.numpy())
  485. Outputs:
  486. .. testoutput::
  487. [1. 2.]
  488. """
  489. assert len(inp.shape) <= 2, "Input should be 1d or 2d"
  490. if descending:
  491. order = "descending"
  492. else:
  493. order = "ascending"
  494. op = builtin.Argsort(order=order)
  495. if len(inp.shape) == 1:
  496. inp = inp.reshape(1, -1)
  497. tns, ind = apply(op, inp)
  498. return tns[0], ind[0]
  499. tns, ind = apply(op, inp)
  500. return tns, ind
  501. @deprecated_kwargs_default("1.12", "descending", 3)
  502. def topk(
  503. inp: Tensor,
  504. k: int,
  505. descending: bool = False,
  506. kth_only: bool = False,
  507. no_sort: bool = False,
  508. ) -> Tuple[Tensor, Tensor]:
  509. r"""Selects the ``Top-K`` (by default) smallest elements of 2d matrix by row.
  510. Args:
  511. inp: input tensor. If input tensor is 2d, each row will be sorted.
  512. k: number of elements needed.
  513. descending: if True, return the largest elements instead. Default: False
  514. kth_only: if True, only the k-th element will be returned. Default: False
  515. no_sort: if True, the returned elements can be unordered. Default: False
  516. Returns:
  517. tuple of two tensors ``(topk_tensor, indices_of_int32)``
  518. Examples:
  519. .. testcode::
  520. import numpy as np
  521. from megengine import tensor
  522. import megengine.functional as F
  523. x = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32))
  524. top, indices = F.topk(x, 5, descending=False)
  525. print(top.numpy(), indices.numpy())
  526. Outputs:
  527. .. testoutput::
  528. [1. 2. 3. 4. 5.] [7 0 6 1 5]
  529. """
  530. if descending:
  531. k = -k
  532. if kth_only:
  533. mode = "kth_only"
  534. elif no_sort:
  535. mode = "value_idx_nosort"
  536. else:
  537. mode = "value_idx_sorted"
  538. op = builtin.TopK(mode=mode)
  539. if not isinstance(k, Tensor):
  540. (k,) = Const(k, dtype="int32", device=inp.device)()
  541. if len(inp.shape) == 1:
  542. if kth_only:
  543. (tns,) = apply(op, expand_dims(inp, 0), k)
  544. # FIXME:
  545. # could use a dedicated kernel
  546. # gradient may be routed to other indices if k-th value is not unique
  547. ind = argmax((tns == inp).astype("int8"))
  548. tns = squeeze(tns, 0)
  549. else:
  550. tns, ind = apply(op, expand_dims(inp, 0), k)
  551. tns = squeeze(tns, 0)
  552. ind = squeeze(ind, 0)
  553. else:
  554. if kth_only:
  555. (tns,) = apply(op, inp, k)
  556. # FIXME: same as above
  557. ind = argmax((expand_dims(tns, 1) == inp).astype("int8"), 1)
  558. else:
  559. tns, ind = apply(op, inp, k)
  560. return tns, ind
  561. def matinv(inp: Tensor) -> Tensor:
  562. r"""Computes the inverse of a batch of matrices; input must has shape [..., n, n].
  563. Args:
  564. inp: input tensor.
  565. Returns:
  566. output tensor.
  567. Examples:
  568. .. testcode::
  569. import numpy as np
  570. from megengine import tensor
  571. import megengine.functional as F
  572. data = tensor([[1.0, 0.0], [1.0, 1.0]])
  573. out = F.matinv(data)
  574. print(out.numpy())
  575. Outputs:
  576. .. testoutput::
  577. [[ 1. 0.]
  578. [-1. 1.]]
  579. """
  580. (result,) = apply(builtin.MatrixInverse(), inp)
  581. return result
  582. class _Hashable:
  583. def __init__(self, value) -> None:
  584. self.value = value
  585. def __hash__(self) -> int:
  586. return hash(str(self.value))
  587. def __eq__(self, o: object) -> bool:
  588. if not isinstance(o, _Hashable):
  589. return False
  590. return self.value == o.value
@lru_cache(maxsize=None)
def _get_extentedMatrixMulOp(
    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
    r"""Builds (and caches per argument combination) a subgraph op wrapping
    ``builtin.MatrixMul`` so that it also accepts 1-D and >2-D operands.

    Args:
        device: device used for the subgraph and for the ``Concat`` ops.
        dtype: dtype passed to the ``subgraph`` decorator.
        dim1: number of dimensions of the first input.
        dim2: number of dimensions of the second input.
        transpose_a: forwarded to ``MatrixMul`` as ``transposeA``.
        transpose_b: forwarded to ``MatrixMul`` as ``transposeB``.
        compute_mode: forwarded to ``MatrixMul``.
        format: forwarded to ``MatrixMul``.
        strategy: wrapper object whose ``.value`` is forwarded to
            ``MatrixMul`` as ``strategy``; it must be hashable because the
            factory is memoized with ``lru_cache``.

    Returns:
        the object produced by the ``subgraph`` decorator.
    """

    # NOTE(review): the positional ``2`` is presumably the number of inputs
    # (the body asserts ``len(inputs) == 2``) -- confirm against ``subgraph``.
    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2)
    def extentedMatrixMulOp(inputs, f, c):
        # ``f(op, *args)`` adds an op application to the subgraph and
        # ``c(value, dtype)`` adds a constant, as used throughout this body.
        assert len(inputs) == 2
        inp1, inp2 = inputs
        # Local copies: the ranks are adjusted below when a 1-D operand is
        # promoted to 2-D.
        _dim1, _dim2 = dim1, dim2

        def build_shape_head(shape, idx=-1):
            # shape[:idx]
            return f(
                builtin.Subtensor(items=[[0, False, True, False, False]]),
                shape,
                c(idx, "int32"),
            )

        def build_shape_tail(shape, idx=-1):
            # shape[idx:]
            return f(
                builtin.Subtensor(items=[[0, True, False, False, False]]),
                shape,
                c(idx, "int32"),
            )

        # A 1-D operand is treated as a matrix by adding an axis; the flags
        # record which axis has to be removed again from the result.
        remove_row, remove_col = False, False
        if _dim1 == 1:
            _dim1 = 2
            remove_row = True
        if _dim2 == 1:
            _dim2 = 2
            remove_col = True

        if remove_row:
            # vector -> row matrix
            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
        if remove_col:
            # vector -> column matrix
            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)

        shape1 = f(GetVarShape(), inp1)
        shape2 = f(GetVarShape(), inp2)
        # Operands with more than two dimensions are flattened to 2-D:
        # the new shape is (prod(shape[:-1]), shape[-1]).
        if _dim1 > 2:
            inp1 = f(
                builtin.Reshape(),
                inp1,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)),
                    build_shape_tail(shape1),
                ),
            )
        # NOTE(review): ``matmul`` in this file only dispatches here when
        # ``dim2 <= 2``, so this branch looks unreachable from that caller.
        if _dim2 > 2:
            inp2 = f(
                builtin.Reshape(),
                inp2,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)),
                    build_shape_tail(shape2),
                ),
            )
        op = builtin.MatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=strategy.value,
        )
        result = f(op, inp1, inp2)
        result_shape = f(GetVarShape(), result)
        # Undo the flattening: the leading dimensions of the original operand
        # followed by the last dimension of the 2-D product.
        if _dim1 > 2:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    build_shape_head(shape1),
                    build_shape_tail(result_shape),
                ),
            )
        if _dim2 > 2:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    build_shape_head(shape2),
                    build_shape_tail(result_shape),
                ),
            )
        maxdim = _dim1 if _dim1 > _dim2 else _dim2
        # Drop the axis that was added for a 1-D operand.
        if remove_row:
            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
        if remove_col:
            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
        # NOTE(review): the second tuple presumably flags, per output, whether
        # it participates in gradient computation -- confirm against ``subgraph``.
        return (result,), (True,)

    return extentedMatrixMulOp
@lru_cache(maxsize=None)
def _get_extentedBatchedMatrixMulOp(
    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
    r"""Builds (and caches per argument combination) a subgraph op wrapping
    ``builtin.BatchedMatrixMul`` with broadcasting of the batch dimensions.

    Args:
        device: device used for the subgraph and for the ``Concat`` ops.
        dtype: dtype passed to the ``subgraph`` decorator.
        dim1: number of dimensions of the first input.
        dim2: number of dimensions of the second input.
        transpose_a: forwarded to ``BatchedMatrixMul`` as ``transposeA``.
        transpose_b: forwarded to ``BatchedMatrixMul`` as ``transposeB``.
        compute_mode: forwarded to ``BatchedMatrixMul``.
        format: forwarded to ``BatchedMatrixMul``.
        strategy: wrapper object whose ``.value`` is forwarded to
            ``BatchedMatrixMul`` as ``strategy``; it must be hashable because
            the factory is memoized with ``lru_cache``.

    Returns:
        the object produced by the ``subgraph`` decorator.
    """

    # NOTE(review): the positional ``2`` is presumably the number of inputs
    # (the body asserts ``len(inputs) == 2``) -- confirm against ``subgraph``.
    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2)
    def extentedBatchedMatrixMulOp(inputs, f, c):
        # ``f(op, *args)`` adds an op application to the subgraph and
        # ``c(value, dtype)`` adds a constant, as used throughout this body.
        assert len(inputs) == 2
        inp1, inp2 = inputs
        # Local copies: the ranks are adjusted below when a 1-D operand is
        # promoted to 2-D.
        _dim1, _dim2 = dim1, dim2

        def build_shape_head(shape, idx=-2):
            # shape[:idx]
            return f(
                builtin.Subtensor(items=[[0, False, True, False, False]]),
                shape,
                c(idx, "int32"),
            )

        def build_shape_tail(shape, idx=-2):
            # shape[idx:]
            return f(
                builtin.Subtensor(items=[[0, True, False, False, False]]),
                shape,
                c(idx, "int32"),
            )

        # A 1-D operand is treated as a matrix by adding an axis; the flags
        # record which axis has to be removed again from the result.
        remove_row, remove_col = False, False
        if _dim1 == 1:
            _dim1 = 2
            remove_row = True
        if _dim2 == 1:
            _dim2 = 2
            remove_col = True

        if remove_row:
            # vector -> row matrix
            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
        if remove_col:
            # vector -> column matrix
            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
        shape1 = f(GetVarShape(), inp1)
        shape2 = f(GetVarShape(), inp2)

        maxdim = _dim1 if _dim1 > _dim2 else _dim2
        # Exactly one of the next three branches runs and defines
        # ``batch_shape`` (every dimension except the trailing two).
        if _dim1 > _dim2:
            # broadcast: prepend the missing leading dims of inp1 to inp2
            shape2 = f(
                builtin.Concat(axis=0, comp_node=device),
                build_shape_head(shape1, idx=-_dim2),  # shape1[:-_dim2]
                shape2,
            )
            inp2 = f(builtin.Broadcast(), inp2, shape2)
            batch_shape = build_shape_head(shape1)
        if _dim2 > _dim1:
            # broadcast: prepend the missing leading dims of inp2 to inp1
            shape1 = f(
                builtin.Concat(axis=0, comp_node=device),
                build_shape_head(shape2, idx=-_dim1),  # shape2[:-_dim1]
                shape1,
            )
            inp1 = f(builtin.Broadcast(), inp1, shape1)
            batch_shape = build_shape_head(shape2)
        if _dim1 == _dim2:
            batch_shape = build_shape_head(shape1)

        # compress inputs to 3d: (prod(batch_shape), rows, cols)
        if maxdim > 3:
            inp1 = f(
                builtin.Reshape(),
                inp1,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
                    build_shape_tail(shape1),
                ),
            )
            inp2 = f(
                builtin.Reshape(),
                inp2,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
                    build_shape_tail(shape2),
                ),
            )
        op = builtin.BatchedMatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=strategy.value,
        )
        result = f(op, inp1, inp2)

        # Restore the original batch dimensions in front of the trailing two
        # dimensions of the 3-D product.
        if maxdim > 3:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    batch_shape,
                    build_shape_tail(f(GetVarShape(), result)),
                ),
            )
        # Drop the axis that was added for a 1-D operand.
        if remove_row:
            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
        if remove_col:
            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
        # NOTE(review): the second tuple presumably flags, per output, whether
        # it participates in gradient computation -- confirm against ``subgraph``.
        return (result,), (True,)

    return extentedBatchedMatrixMulOp
  784. def matmul(
  785. inp1: Tensor,
  786. inp2: Tensor,
  787. transpose_a=False,
  788. transpose_b=False,
  789. compute_mode="default",
  790. format="default",
  791. ) -> Tensor:
  792. r"""Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``.
  793. With different inputs dim, this function behaves differently:
  794. * Both 1-D tensor, simply forward to ``dot``.
  795. * Both 2-D tensor, normal matrix multiplication.
  796. * If one input tensor is 1-D, matrix vector multiplication.
  797. * If at least one tensor are 3-dimensional or >3-dimensional, the other tensor should have dim >= 2,
  798. the batched matrix-matrix is returned, and the tensor with smaller dimension will be broadcasted.
  799. For example:
  800. * inp1: `(n, k, m)`, inp2: `(n, m, p)`, return: `(n, k, p)`
  801. * inp1: `(n, k, m)`, inp2: `(m, p)`, return: `(n, k, p)`
  802. * inp1: `(n, j, k, m)`, inp2: `(n, j, m, p)`, return: `(n, j, k, p)`
  803. Args:
  804. inp1: first matrix to be multiplied.
  805. inp2: second matrix to be multiplied.
  806. Returns:
  807. output tensor.
  808. Examples:
  809. .. testcode::
  810. import numpy as np
  811. from megengine import tensor
  812. import megengine.functional as F
  813. data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  814. data2 = tensor(np.arange(0, 6, dtype=np.float32).reshape(3, 2))
  815. out = F.matmul(data1, data2)
  816. print(out.numpy())
  817. Outputs:
  818. .. testoutput::
  819. [[10. 13.]
  820. [28. 40.]]
  821. """
  822. if amp._enabled:
  823. compute_mode = "float32"
  824. inp1, inp2 = cast_tensors(inp1, inp2)
  825. else:
  826. dtype = dtype_promotion(inp1, inp2)
  827. if inp1.dtype != dtype:
  828. inp1 = inp1.astype(dtype)
  829. if inp2.dtype != dtype:
  830. inp2 = inp2.astype(dtype)
  831. dim1, dim2 = inp1.ndim, inp2.ndim
  832. assert dim1 > 0 and dim2 > 0
  833. maxdim = dim1 if dim1 > dim2 else dim2
  834. compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
  835. if dim1 == 1 and dim2 == 1: # dispatch to Dot
  836. return dot(inp1, inp2)
  837. elif maxdim <= 2 or dim2 <= 2: # dispath to MatrixMul
  838. extentedMatrixMulOp = _get_extentedMatrixMulOp(
  839. inp1.device,
  840. inp1.dtype,
  841. dim1,
  842. dim2,
  843. transpose_a,
  844. transpose_b,
  845. compute_mode,
  846. format,
  847. strategy=_Hashable(get_execution_strategy()),
  848. )
  849. (result,) = apply(extentedMatrixMulOp(), inp1, inp2)
  850. return result
  851. else: # dispath to BatchedMatrixMul
  852. extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
  853. inp1.device,
  854. inp1.dtype,
  855. dim1,
  856. dim2,
  857. transpose_a,
  858. transpose_b,
  859. compute_mode,
  860. format,
  861. strategy=_Hashable(get_execution_strategy()),
  862. )
  863. (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
  864. return result
  865. def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
  866. r"""Computes dot-product of two vectors ``inp1`` and ``inp2``.
  867. inputs must be 1-dimensional or scalar. A scalar input is automatically broadcasted.
  868. Refer to :func:`~.matmul` for more general usage.
  869. Args:
  870. inp1: first vector.
  871. inp2: second vector.
  872. Returns:
  873. output value.
  874. Examples:
  875. .. testcode::
  876. import numpy as np
  877. from megengine import tensor
  878. import megengine.functional as F
  879. data1 = tensor(np.arange(0, 6, dtype=np.float32))
  880. data2 = tensor(np.arange(0, 6, dtype=np.float32))
  881. out = F.dot(data1, data2)
  882. print(out.numpy())
  883. Outputs:
  884. .. testoutput::
  885. 55.
  886. """
  887. op = builtin.Dot()
  888. assert (
  889. inp1.ndim <= 1 and inp2.ndim <= 1
  890. ), "Input tensors for dot must be 1-dimensional or scalar"
  891. (result,) = apply(op, inp1, inp2)
  892. setscalar(result)
  893. return result
  894. def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
  895. r"""Computes the singular value decompositions of input matrix.
  896. Args:
  897. inp: input matrix, must has shape `[..., M, N]`.
  898. Returns:
  899. output matrices, `(U, sigma, V)`.
  900. Examples:
  901. .. testcode::
  902. import numpy as np
  903. from megengine import tensor
  904. import megengine.functional as F
  905. x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2,3))
  906. _, y, _ = F.svd(x)
  907. print(y.numpy().round(decimals=3))
  908. Outputs:
  909. .. testoutput::
  910. [7.348 1. ]
  911. """
  912. op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv)
  913. U, sigma, V = apply(op, inp)
  914. return U, sigma, V
  915. def _check_non_finite(inps: Iterable[Tensor], scale=1.0) -> Tensor:
  916. r"""Check whether input contains infinite or nan value.
  917. Args:
  918. inp: a tensor to be checked.
  919. Returns:
  920. a int32 scalar tensor, 0 for False and 1 for True.
  921. """
  922. op = builtin.CheckNonFinite(scale=scale)
  923. oups = apply(op, *inps)
  924. out = oups[-1]
  925. for i in range(len(inps)):
  926. inps[i]._reset(oups[i])
  927. out._setscalar()
  928. return out