You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

math.py 25 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import math
  11. from typing import Optional, Sequence, Tuple, Union
  12. from ..core._imperative_rt.core2 import apply
  13. from ..core._trace_option import use_symbolic_shape
  14. from ..core.ops import builtin
  15. from ..core.ops.special import Const
  16. from ..core.tensor import amp
  17. from ..core.tensor.utils import _normalize_axis, cast_tensors, convert_inputs, setscalar
  18. from ..tensor import Tensor
  19. from .debug_param import get_execution_strategy
  20. from .elemwise import clip, exp, log, log1p
  21. from .tensor import broadcast_to, concat, expand_dims, reshape, squeeze
  22. __all__ = [
  23. "argmax",
  24. "argmin",
  25. "argsort",
  26. "dot",
  27. "isinf",
  28. "isnan",
  29. "matinv",
  30. "matmul",
  31. "max",
  32. "mean",
  33. "min",
  34. "norm",
  35. "normalize",
  36. "prod",
  37. "sign",
  38. "sort",
  39. "std",
  40. "sum",
  41. "svd",
  42. "topk",
  43. "var",
  44. ]
  45. def isnan(inp: Tensor) -> Tensor:
  46. r"""
  47. Returns a new tensor representing if each element is ``NaN`` or not.
  48. :param inp: input tensor.
  49. :return: result tensor.
  50. Examples:
  51. .. testcode::
  52. from megengine import tensor
  53. import megengine.functional as F
  54. x = tensor([1, float("nan"), 0])
  55. print(F.isnan(x).numpy())
  56. Outputs:
  57. .. testoutput::
  58. [False True False]
  59. """
  60. return inp != inp
  61. def isinf(inp: Tensor) -> Tensor:
  62. r"""
  63. Returns a new tensor representing if each element is ``Inf`` or not.
  64. :param inp: input tensor.
  65. :return: result tensor.
  66. Examples:
  67. .. testcode::
  68. from megengine import tensor
  69. import megengine.functional as F
  70. x = tensor([1, float("inf"), 0])
  71. print(F.isinf(x).numpy())
  72. Outputs:
  73. .. testoutput::
  74. [False True False]
  75. """
  76. return abs(inp).astype("float32") == float("inf")
  77. def sign(inp: Tensor):
  78. r"""
  79. Returns a new tensor representing the sign of each element in input tensor.
  80. :param: input tensor.
  81. :return: the sign of input tensor.
  82. Examples:
  83. .. testcode::
  84. from megengine import tensor
  85. import megengine.functional as F
  86. x = tensor([1, -1, 0])
  87. print(F.sign(x).numpy())
  88. Outputs:
  89. .. testoutput::
  90. [ 1 -1 0]
  91. """
  92. return (inp > 0).astype(inp.dtype) - (inp < 0).astype(inp.dtype)
  93. def sum(
  94. inp: Tensor,
  95. axis: Optional[Union[int, Sequence[int]]] = None,
  96. keepdims: bool = False,
  97. ) -> Tensor:
  98. r"""
  99. Returns the sum of input tensor along given axis. If axis is a list of dimensions,
  100. reduce over all of them.
  101. :param inp: input tensor.
  102. :param axis: dimension to reduce. If None, all dimensions will be reduced.
  103. Default: None
  104. :param keepdims: whether the output tensor has axis retained or not.
  105. Default: False
  106. :return: output tensor.
  107. Examples:
  108. .. testcode::
  109. import numpy as np
  110. from megengine import tensor
  111. import megengine.functional as F
  112. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  113. out = F.sum(x)
  114. print(out.numpy())
  115. Outputs:
  116. .. testoutput::
  117. 21
  118. """
  119. return inp.sum(axis=axis, keepdims=keepdims)
  120. def prod(
  121. inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None, keepdims=False
  122. ) -> Tensor:
  123. r"""
  124. Returns the product of input tensor along given axis. If axis is a list of dimensions,
  125. reduce over all of them.
  126. :param inp: input tensor.
  127. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  128. :param keepdims: whether the output tensor has axis retained or not. Default: False
  129. :return: output tensor.
  130. Examples:
  131. .. testcode::
  132. import numpy as np
  133. from megengine import tensor
  134. import megengine.functional as F
  135. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  136. out = F.prod(x)
  137. print(out.numpy())
  138. Outputs:
  139. .. testoutput::
  140. 720
  141. """
  142. return inp.prod(axis=axis, keepdims=keepdims)
  143. def mean(
  144. inp: Tensor,
  145. axis: Optional[Union[int, Sequence[int]]] = None,
  146. keepdims: bool = False,
  147. ) -> Tensor:
  148. """
  149. Returns the mean value of input tensor along
  150. given axis. If axis is a list of dimensions,
  151. reduce over all of them.
  152. :param inp: input tensor.
  153. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  154. :param keepdims: whether the output tensor has axis retained or not. Default: False
  155. :return: output tensor.
  156. Examples:
  157. .. testcode::
  158. import numpy as np
  159. from megengine import tensor
  160. import megengine.functional as F
  161. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  162. out = F.mean(x)
  163. print(out.numpy())
  164. Outputs:
  165. .. testoutput::
  166. 3.5
  167. """
  168. return inp.mean(axis=axis, keepdims=keepdims)
  169. def var(
  170. inp: Tensor,
  171. axis: Optional[Union[int, Sequence[int]]] = None,
  172. keepdims: bool = False,
  173. ) -> Tensor:
  174. """
  175. Returns the variance value of input tensor along
  176. given axis. If axis is a list of dimensions,
  177. reduce over all of them.
  178. :param inp: input tensor.
  179. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  180. :param keepdims: whether the output tensor has axis retained or not. Default: False
  181. :return: output tensor.
  182. Examples:
  183. .. testcode::
  184. import numpy as np
  185. from megengine import tensor
  186. import megengine.functional as F
  187. data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
  188. out = F.var(data)
  189. print(out.numpy().round(decimals=4))
  190. Outputs:
  191. .. testoutput::
  192. 2.9167
  193. """
  194. if axis is None:
  195. m = mean(inp, axis=axis, keepdims=False)
  196. else:
  197. m = mean(inp, axis=axis, keepdims=True)
  198. v = inp - m
  199. return mean(v ** 2, axis=axis, keepdims=keepdims)
  200. def std(
  201. inp: Tensor,
  202. axis: Optional[Union[int, Sequence[int]]] = None,
  203. keepdims: bool = False,
  204. ) -> Tensor:
  205. """
  206. Returns the standard deviation of input tensor along
  207. given axis. If axis is a list of dimensions,
  208. reduce over all of them.
  209. :param inp: input tensor.
  210. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  211. :param keepdims: whether the output tensor has axis retained or not. Default: False
  212. :return: output tensor.
  213. Examples:
  214. .. testcode::
  215. import numpy as np
  216. from megengine import tensor
  217. import megengine.functional as F
  218. data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
  219. out = F.std(data, axis=1)
  220. print(out.numpy().round(decimals=4))
  221. Outputs:
  222. .. testoutput::
  223. [0.8165 0.8165]
  224. """
  225. return var(inp, axis=axis, keepdims=keepdims) ** 0.5
  226. def min(
  227. inp: Tensor,
  228. axis: Optional[Union[int, Sequence[int]]] = None,
  229. keepdims: bool = False,
  230. ) -> Tensor:
  231. r"""
  232. Returns the min value of input tensor along
  233. given axis. If axis is a list of dimensions,
  234. reduce over all of them.
  235. :param inp: input tensor.
  236. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  237. :param keepdims: whether the output tensor has axis retained or not. Default: False
  238. :return: output tensor.
  239. Examples:
  240. .. testcode::
  241. import numpy as np
  242. from megengine import tensor
  243. import megengine.functional as F
  244. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  245. out = F.min(x)
  246. print(out.numpy())
  247. Outputs:
  248. .. testoutput::
  249. 1
  250. """
  251. return inp.min(axis=axis, keepdims=keepdims)
  252. def max(
  253. inp: Tensor,
  254. axis: Optional[Union[int, Sequence[int]]] = None,
  255. keepdims: bool = False,
  256. ) -> Tensor:
  257. r"""
  258. Returns the max value of the input tensor along
  259. given axis. If axis is a list of dimensions,
  260. reduce over all of them.
  261. :param inp: input tensor.
  262. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  263. :param keepdims: whether the output tensor has axis retained or not. Default: False
  264. :return: output tensor.
  265. Examples:
  266. .. testcode::
  267. import numpy as np
  268. from megengine import tensor
  269. import megengine.functional as F
  270. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  271. out = F.max(x)
  272. print(out.numpy())
  273. Outputs:
  274. .. testoutput::
  275. 6
  276. """
  277. return inp.max(axis=axis, keepdims=keepdims)
  278. def norm(
  279. inp: Tensor, ord: float = None, axis: int = None, keepdims=False,
  280. ):
  281. """
  282. Calculates ``p``-norm of input tensor along
  283. given axis.
  284. :param inp: input tensor.
  285. :param ord: power of value applied to inp. Default: 2
  286. :param axis: dimension to reduce. If None, input must be a vector. Default: None
  287. :param keepdims: whether the output tensor has axis retained or not. Default: False
  288. :return: output tensor.
  289. Examples:
  290. .. testcode::
  291. import numpy as np
  292. from megengine import tensor
  293. import megengine.functional as F
  294. x = tensor(np.arange(-3, 3, dtype=np.float32))
  295. out = F.norm(x)
  296. print(out.numpy().round(decimals=4))
  297. Outputs:
  298. .. testoutput::
  299. 4.3589
  300. """
  301. if axis is None:
  302. if inp.ndim != 1:
  303. raise TypeError("axis is required unless input is a vector")
  304. if ord is None:
  305. ord = 2
  306. if ord == 0:
  307. return sum(inp != 0, axis=axis, keepdims=keepdims)
  308. if ord == math.inf:
  309. return max(abs(inp))
  310. if ord == -math.inf:
  311. return min(abs(inp))
  312. return sum(abs(inp) ** ord, axis=axis, keepdims=keepdims) ** (1.0 / ord)
  313. def argmin(
  314. inp: Tensor,
  315. axis: Optional[Union[int, Sequence[int]]] = None,
  316. keepdims: bool = False,
  317. ) -> Tensor:
  318. r"""
  319. Returns the indices of the minimum values along
  320. given axis. If axis is a list of dimensions,
  321. reduce over all of them.
  322. :param inp: input tensor.
  323. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  324. :param keepdims: whether the output tensor has axis retained or not. Default: False
  325. :return: output tensor.
  326. Examples:
  327. .. testcode::
  328. import numpy as np
  329. from megengine import tensor
  330. import megengine.functional as F
  331. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  332. out = F.argmin(x)
  333. print(out.numpy())
  334. Outputs:
  335. .. testoutput::
  336. 0
  337. """
  338. if axis is None:
  339. assert not keepdims, "can not set axis=None and keepdims=True"
  340. inp = inp.flatten()
  341. axis = 0
  342. axis = _normalize_axis(inp.ndim, axis, reverse=True)
  343. if isinstance(axis, collections.abc.Iterable):
  344. for ai in axis:
  345. op = builtin.Argmin(axis=ai)
  346. (inp,) = apply(op, inp)
  347. if not keepdims:
  348. inp = squeeze(inp, ai)
  349. return inp
  350. op = builtin.Argmin(axis=axis)
  351. (result,) = apply(op, inp)
  352. if not keepdims:
  353. result = squeeze(result, axis)
  354. return result
  355. def argmax(
  356. inp: Tensor,
  357. axis: Optional[Union[int, Sequence[int]]] = None,
  358. keepdims: bool = False,
  359. ) -> Tensor:
  360. r"""
  361. Returns the indices of the maximum values along
  362. given axis. If axis is a list of dimensions,
  363. reduce over all of them.
  364. :param inp: input tensor.
  365. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  366. :param keepdims: whether the output tensor has axis retained or not. Default: False
  367. :return: output tensor.
  368. Examples:
  369. .. testcode::
  370. import numpy as np
  371. from megengine import tensor
  372. import megengine.functional as F
  373. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  374. out = F.argmax(x)
  375. print(out.numpy())
  376. Outputs:
  377. .. testoutput::
  378. 5
  379. """
  380. if axis is None:
  381. assert not keepdims, "can not set axis=None and keepdims=True"
  382. inp = inp.flatten()
  383. axis = 0
  384. axis = _normalize_axis(inp.ndim, axis, reverse=True)
  385. if isinstance(axis, collections.abc.Iterable):
  386. for ai in axis:
  387. op = builtin.Argmax(axis=ai)
  388. (inp,) = apply(op, inp)
  389. if not keepdims:
  390. inp = squeeze(inp, ai)
  391. return inp
  392. op = builtin.Argmax(axis=axis)
  393. (result,) = apply(op, inp)
  394. if not keepdims:
  395. result = squeeze(result, axis)
  396. return result
  397. def normalize(
  398. inp: Tensor, ord: float = None, axis: int = None, eps: float = 1e-12,
  399. ) -> Tensor:
  400. r"""
  401. Performs :math:`L_p` normalization of input tensor along
  402. given axis.
  403. For a tensor of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
  404. :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`axis` is transformed as:
  405. .. math::
  406. v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.
  407. :param inp: input tensor.
  408. :param ord: power of value applied to input tensor. Default: 2
  409. :param axis: dimension to reduce.If None, input must be a vector. Default: None
  410. :param eps: a small value to avoid division by zero. Default: 1e-12
  411. :return: normalized output tensor.
  412. """
  413. if axis is None:
  414. return inp / clip(norm(inp, ord, axis), lower=eps)
  415. else:
  416. return inp / clip(norm(inp, ord, axis, keepdims=True), lower=eps)
  417. def argsort(inp: Tensor, descending: bool = False) -> Tensor:
  418. r"""
  419. Returns the indices that would sort the input tensor.
  420. :param inp: input tensor. If it's 2d, the result would be array of indices show how to sort each row in the input tensor.
  421. :param descending: sort in descending order, where the largest comes first. Default: False
  422. :return: indices of int32 indicates how to sort the input.
  423. Examples:
  424. .. testcode::
  425. import numpy as np
  426. from megengine import tensor
  427. import megengine.functional as F
  428. x = tensor(np.array([1,2], dtype=np.float32))
  429. indices = F.argsort(x)
  430. print(indices.numpy())
  431. Outputs:
  432. .. testoutput::
  433. [0 1]
  434. """
  435. assert len(inp.shape) <= 2, "Input should be 1d or 2d"
  436. if descending:
  437. order = "descending"
  438. else:
  439. order = "ascending"
  440. op = builtin.Argsort(order=order)
  441. if len(inp.shape) == 1:
  442. inp = inp.reshape(1, -1)
  443. _, result = apply(op, inp)
  444. return result[0]
  445. _, result = apply(op, inp)
  446. return result
  447. def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]:
  448. r"""
  449. Returns sorted tensor and the indices would sort the input tensor.
  450. :param inp: input tensor. If it's 2d, the result would be sorted by row.
  451. :param descending: sort in descending order, where the largest comes first. Default: False
  452. :return: tuple of two tensors `(sorted_tensor, indices_of_int32)`.
  453. Examples:
  454. .. testcode::
  455. import numpy as np
  456. from megengine import tensor
  457. import megengine.functional as F
  458. x = tensor(np.array([1,2], dtype=np.float32))
  459. out, indices = F.sort(x)
  460. print(out.numpy())
  461. Outputs:
  462. .. testoutput::
  463. [1. 2.]
  464. """
  465. assert len(inp.shape) <= 2, "Input should be 1d or 2d"
  466. if descending:
  467. order = "descending"
  468. else:
  469. order = "ascending"
  470. op = builtin.Argsort(order=order)
  471. if len(inp.shape) == 1:
  472. inp = inp.reshape(1, -1)
  473. tns, ind = apply(op, inp)
  474. return tns[0], ind[0]
  475. tns, ind = apply(op, inp)
  476. return tns, ind
  477. def topk(
  478. inp: Tensor,
  479. k: int,
  480. descending: bool = False,
  481. kth_only: bool = False,
  482. no_sort: bool = False,
  483. ) -> Tuple[Tensor, Tensor]:
  484. r"""
  485. Selects the ``Top-K`` (by default) smallest elements of 2d matrix by row.
  486. :param inp: input tensor. If input tensor is 2d, each row will be sorted.
  487. :param k: number of elements needed.
  488. :param descending: if True, return the largest elements instead. Default: False
  489. :param kth_only: if True, only the k-th element will be returned. Default: False
  490. :param no_sort: if True, the returned elements can be unordered. Default: False
  491. :return: tuple of two tensors `(topk_tensor, indices_of_int32)`.
  492. Examples:
  493. .. testcode::
  494. import numpy as np
  495. from megengine import tensor
  496. import megengine.functional as F
  497. x = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32))
  498. top, indices = F.topk(x, 5)
  499. print(top.numpy(), indices.numpy())
  500. Outputs:
  501. .. testoutput::
  502. [1. 2. 3. 4. 5.] [7 0 6 1 5]
  503. """
  504. if descending:
  505. inp = -inp
  506. if kth_only:
  507. mode = "kth_only"
  508. elif no_sort:
  509. mode = "value_idx_nosort"
  510. else:
  511. mode = "value_idx_sorted"
  512. op = builtin.TopK(mode=mode)
  513. if not isinstance(k, Tensor):
  514. (k,) = Const(k, dtype="int32", device=inp.device)()
  515. if len(inp.shape) == 1:
  516. inp = inp.reshape(1, -1)
  517. res = apply(op, inp, k)
  518. if kth_only:
  519. tns = res[0]
  520. else:
  521. tns, ind = res[0][0], res[1][0]
  522. else:
  523. res = apply(op, inp, k)
  524. if kth_only:
  525. tns = res
  526. else:
  527. tns, ind = res[0], res[1]
  528. if descending:
  529. tns = -tns
  530. return tns, ind
  531. def matinv(inp: Tensor) -> Tensor:
  532. """
  533. Computes the inverse of a batch of matrices; input must has shape [..., n, n].
  534. :param inp: input tensor.
  535. :return: output tensor.
  536. Examples:
  537. .. testcode::
  538. import numpy as np
  539. from megengine import tensor
  540. import megengine.functional as F
  541. data = tensor([[1.0, 0.0], [1.0, 1.0]])
  542. out = F.matinv(data)
  543. print(out.numpy())
  544. Outputs:
  545. .. testoutput::
  546. [[ 1. 0.]
  547. [-1. 1.]]
  548. """
  549. (result,) = apply(builtin.MatrixInverse(), inp)
  550. return result
def matmul(
    inp1: Tensor,
    inp2: Tensor,
    transpose_a=False,
    transpose_b=False,
    compute_mode="default",
    format="default",
) -> Tensor:
    """
    Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``.

    With different inputs dim, this function behaves differently:

    - Both 1-D tensor, simply forward to ``dot``.
    - Both 2-D tensor, normal matrix multiplication.
    - If one input tensor is 1-D, matrix vector multiplication.
    - If at least one tensor is 3-dimensional or >3-dimensional, the other tensor should have dim >= 2,
      the batched matrix-matrix is returned, and the tensor with smaller dimension will be broadcasted.
      For example:

        - inp1: `(n, k, m)`, inp2: `(n, m, p)`, return: `(n, k, p)`
        - inp1: `(n, k, m)`, inp2: `(m, p)`, return: `(n, k, p)`
        - inp1: `(n, j, k, m)`, inp2: `(n, j, m, p)`, return: `(n, j, k, p)`

    :param inp1: first matrix to be multiplied.
    :param inp2: second matrix to be multiplied.
    :param transpose_a: forwarded to the matmul op as ``transposeA``. Default: False
    :param transpose_b: forwarded to the matmul op as ``transposeB``. Default: False
    :param compute_mode: forwarded to the matmul op; overridden with
        ``"float32"`` when AMP is enabled. Default: "default"
    :param format: forwarded to the matmul op. Default: "default"
    :return: output tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
        data2 = tensor(np.arange(0, 6, dtype=np.float32).reshape(3, 2))
        out = F.matmul(data1, data2)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[10. 13.]
         [28. 40.]]
    """
    # Under AMP the operands are cast and accumulation is forced to float32;
    # otherwise the inputs are only converted to tensors.
    if amp._enabled:
        compute_mode = "float32"
        inp1, inp2 = cast_tensors(inp1, inp2)
    else:
        inp1, inp2 = convert_inputs(inp1, inp2)

    remove_row, remove_col = False, False
    dim1, dim2 = inp1.ndim, inp2.ndim
    # handle dim=1 cases, dot and matrix-vector multiplication
    if dim1 == 1 and dim2 == 1:
        return dot(inp1, inp2)
    # the underlying matmul op requires input dims to be at least 2
    # (the inserted dims are squeezed out again at the end)
    if dim1 == 1:
        inp1 = expand_dims(inp1, 0)
        dim1 = 2
        remove_row = True
    if dim2 == 1:
        inp2 = expand_dims(inp2, 1)
        dim2 = 2
        remove_col = True

    batch_shape = None
    shape1 = inp1.shape
    shape2 = inp2.shape

    maxdim = dim1 if dim1 > dim2 else dim2
    if dim1 >= 3 or dim2 >= 3:
        # Batched case. Shapes are tensors under symbolic shape and tuples
        # otherwise, hence the two parallel branches (concat vs. tuple "+").
        if use_symbolic_shape():
            # Broadcast the lower-rank operand's batch dims to match the other.
            if dim1 > dim2:
                shape2 = concat([shape1[:-2], shape2[-2:]])
                inp2 = broadcast_to(inp2, shape2)
            if dim1 < dim2:
                shape1 = concat([shape2[:-2], shape1[-2:]])
                inp1 = broadcast_to(inp1, shape1)
            if maxdim > 3:
                batch_shape = shape1[:-2]
                # compress inputs to 3d
                (inp1,) = apply(
                    builtin.Reshape(), inp1, concat([prod(shape1[:-2]), shape1[-2:]])
                )
                (inp2,) = apply(
                    builtin.Reshape(), inp2, concat([prod(shape2[:-2]), shape2[-2:]])
                )
        else:
            # Broadcast the lower-rank operand's batch dims to match the other.
            if dim1 > dim2:
                shape2 = shape1[:-2] + shape2[-2:]
                inp2 = broadcast_to(inp2, shape2)
            if dim1 < dim2:
                shape1 = shape2[:-2] + shape1[-2:]
                inp1 = broadcast_to(inp1, shape1)
            if maxdim > 3:
                batch_shape = shape1[:-2]
                # compress inputs to 3d
                inp1 = inp1.reshape((-1, shape1[-2], shape1[-1]))
                inp2 = inp2.reshape((-1, shape2[-2], shape2[-1]))

        op = builtin.BatchedMatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=get_execution_strategy(),
        )
    else:
        op = builtin.MatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=get_execution_strategy(),
        )

    (result,) = apply(op, inp1, inp2)
    # Restore the batch dims that were folded into one before the op.
    if maxdim > 3:
        if use_symbolic_shape():
            (result,) = apply(
                builtin.Reshape(), result, concat([batch_shape, result.shape[-2:]])
            )
        else:
            result = result.reshape(batch_shape + result.shape[-2:])
    if remove_row:
        result = squeeze(result, axis=-2)
    if remove_col:
        result = squeeze(result, axis=-1)
    return result
  668. def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
  669. """
  670. Computes dot-product of two vectors ``inp1`` and ``inp2``.
  671. inputs must be 1-dimensional or scalar. A scalar input is automatically broadcasted.
  672. Refer to :func:`~.matmul` for more general usage.
  673. :param inp1: first vector.
  674. :param inp2: second vector.
  675. :return: output value.
  676. Examples:
  677. .. testcode::
  678. import numpy as np
  679. from megengine import tensor
  680. import megengine.functional as F
  681. data1 = tensor(np.arange(0, 6, dtype=np.float32))
  682. data2 = tensor(np.arange(0, 6, dtype=np.float32))
  683. out = F.dot(data1, data2)
  684. print(out.numpy())
  685. Outputs:
  686. .. testoutput::
  687. 55.
  688. """
  689. op = builtin.Dot()
  690. inp1, inp2 = convert_inputs(inp1, inp2)
  691. assert (
  692. inp1.ndim <= 1 and inp2.ndim <= 1
  693. ), "Input tensors for dot must be 1-dimensional or scalar"
  694. (result,) = apply(op, inp1, inp2)
  695. setscalar(result)
  696. return result
  697. def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
  698. """
  699. Computes the singular value decompositions of input matrix.
  700. :param inp: input matrix, must has shape `[..., M, N]`.
  701. :return: output matrices, `(U, sigma, V)`.
  702. Examples:
  703. .. testcode::
  704. import numpy as np
  705. from megengine import tensor
  706. import megengine.functional as F
  707. x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2,3))
  708. _, y, _ = F.svd(x)
  709. print(y.numpy().round(decimals=3))
  710. Outputs:
  711. .. testoutput::
  712. [7.348 1. ]
  713. """
  714. op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv)
  715. U, sigma, V = apply(op, inp)
  716. return U, sigma, V
  717. def _has_inf(inp: Tensor) -> Tensor:
  718. """
  719. Check whether input contains infinite value.
  720. :param inp: a tensor to be checked.
  721. :return: a int32 scalar tensor, 0 for False and 1 for True.
  722. """
  723. op = builtin.CheckHasInf()
  724. (oup,) = apply(op, inp.reshape(-1).astype("float32"))
  725. oup._setscalar()
  726. return oup

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台