You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

math.py 25 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import math
  11. from typing import Optional, Sequence, Tuple, Union
  12. from ..core._imperative_rt.core2 import apply, dtype_promotion
  13. from ..core._trace_option import use_symbolic_shape
  14. from ..core.ops import builtin
  15. from ..core.ops.special import Const
  16. from ..core.tensor import amp
  17. from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar
  18. from ..tensor import Tensor
  19. from .debug_param import get_execution_strategy
  20. from .elemwise import clip
  21. from .tensor import broadcast_to, concat, expand_dims, squeeze
  22. __all__ = [
  23. "argmax",
  24. "argmin",
  25. "argsort",
  26. "dot",
  27. "isinf",
  28. "isnan",
  29. "matinv",
  30. "matmul",
  31. "max",
  32. "mean",
  33. "min",
  34. "norm",
  35. "normalize",
  36. "prod",
  37. "sign",
  38. "sort",
  39. "std",
  40. "sum",
  41. "svd",
  42. "topk",
  43. "var",
  44. ]
  45. def isnan(inp: Tensor) -> Tensor:
  46. r"""
  47. Returns a new tensor representing if each element is ``NaN`` or not.
  48. :param inp: input tensor.
  49. :return: result tensor.
  50. Examples:
  51. .. testcode::
  52. from megengine import tensor
  53. import megengine.functional as F
  54. x = tensor([1, float("nan"), 0])
  55. print(F.isnan(x).numpy())
  56. Outputs:
  57. .. testoutput::
  58. [False True False]
  59. """
  60. return inp != inp
  61. def isinf(inp: Tensor) -> Tensor:
  62. r"""
  63. Returns a new tensor representing if each element is ``Inf`` or not.
  64. :param inp: input tensor.
  65. :return: result tensor.
  66. Examples:
  67. .. testcode::
  68. from megengine import tensor
  69. import megengine.functional as F
  70. x = tensor([1, float("inf"), 0])
  71. print(F.isinf(x).numpy())
  72. Outputs:
  73. .. testoutput::
  74. [False True False]
  75. """
  76. return abs(inp).astype("float32") == float("inf")
  77. def sign(inp: Tensor):
  78. r"""
  79. Returns a new tensor representing the sign of each element in input tensor.
  80. :param: input tensor.
  81. :return: the sign of input tensor.
  82. Examples:
  83. .. testcode::
  84. from megengine import tensor
  85. import megengine.functional as F
  86. x = tensor([1, -1, 0])
  87. print(F.sign(x).numpy())
  88. Outputs:
  89. .. testoutput::
  90. [ 1 -1 0]
  91. """
  92. return (inp > 0).astype(inp.dtype) - (inp < 0).astype(inp.dtype)
  93. def sum(
  94. inp: Tensor,
  95. axis: Optional[Union[int, Sequence[int]]] = None,
  96. keepdims: bool = False,
  97. ) -> Tensor:
  98. r"""
  99. Returns the sum of input tensor along given axis. If axis is a list of dimensions,
  100. reduce over all of them.
  101. :param inp: input tensor.
  102. :param axis: dimension to reduce. If None, all dimensions will be reduced.
  103. Default: None
  104. :param keepdims: whether the output tensor has axis retained or not.
  105. Default: False
  106. :return: output tensor.
  107. Examples:
  108. .. testcode::
  109. import numpy as np
  110. from megengine import tensor
  111. import megengine.functional as F
  112. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  113. out = F.sum(x)
  114. print(out.numpy())
  115. Outputs:
  116. .. testoutput::
  117. 21
  118. """
  119. return inp.sum(axis=axis, keepdims=keepdims)
  120. def prod(
  121. inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None, keepdims=False
  122. ) -> Tensor:
  123. r"""
  124. Returns the product of input tensor along given axis. If axis is a list of dimensions,
  125. reduce over all of them.
  126. :param inp: input tensor.
  127. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  128. :param keepdims: whether the output tensor has axis retained or not. Default: False
  129. :return: output tensor.
  130. Examples:
  131. .. testcode::
  132. import numpy as np
  133. from megengine import tensor
  134. import megengine.functional as F
  135. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  136. out = F.prod(x)
  137. print(out.numpy())
  138. Outputs:
  139. .. testoutput::
  140. 720
  141. """
  142. return inp.prod(axis=axis, keepdims=keepdims)
  143. def mean(
  144. inp: Tensor,
  145. axis: Optional[Union[int, Sequence[int]]] = None,
  146. keepdims: bool = False,
  147. ) -> Tensor:
  148. """
  149. Returns the mean value of input tensor along
  150. given axis. If axis is a list of dimensions,
  151. reduce over all of them.
  152. :param inp: input tensor.
  153. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  154. :param keepdims: whether the output tensor has axis retained or not. Default: False
  155. :return: output tensor.
  156. Examples:
  157. .. testcode::
  158. import numpy as np
  159. from megengine import tensor
  160. import megengine.functional as F
  161. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  162. out = F.mean(x)
  163. print(out.numpy())
  164. Outputs:
  165. .. testoutput::
  166. 3.5
  167. """
  168. return inp.mean(axis=axis, keepdims=keepdims)
  169. def var(
  170. inp: Tensor,
  171. axis: Optional[Union[int, Sequence[int]]] = None,
  172. keepdims: bool = False,
  173. ) -> Tensor:
  174. """
  175. Returns the variance value of input tensor along
  176. given axis. If axis is a list of dimensions,
  177. reduce over all of them.
  178. :param inp: input tensor.
  179. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  180. :param keepdims: whether the output tensor has axis retained or not. Default: False
  181. :return: output tensor.
  182. Examples:
  183. .. testcode::
  184. import numpy as np
  185. from megengine import tensor
  186. import megengine.functional as F
  187. data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
  188. out = F.var(data)
  189. print(out.numpy().round(decimals=4))
  190. Outputs:
  191. .. testoutput::
  192. 2.9167
  193. """
  194. if axis is None:
  195. m = mean(inp, axis=axis, keepdims=False)
  196. else:
  197. m = mean(inp, axis=axis, keepdims=True)
  198. v = inp - m
  199. return mean(v ** 2, axis=axis, keepdims=keepdims)
  200. def std(
  201. inp: Tensor,
  202. axis: Optional[Union[int, Sequence[int]]] = None,
  203. keepdims: bool = False,
  204. ) -> Tensor:
  205. """
  206. Returns the standard deviation of input tensor along
  207. given axis. If axis is a list of dimensions,
  208. reduce over all of them.
  209. :param inp: input tensor.
  210. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  211. :param keepdims: whether the output tensor has axis retained or not. Default: False
  212. :return: output tensor.
  213. Examples:
  214. .. testcode::
  215. import numpy as np
  216. from megengine import tensor
  217. import megengine.functional as F
  218. data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
  219. out = F.std(data, axis=1)
  220. print(out.numpy().round(decimals=4))
  221. Outputs:
  222. .. testoutput::
  223. [0.8165 0.8165]
  224. """
  225. return var(inp, axis=axis, keepdims=keepdims) ** 0.5
  226. def min(
  227. inp: Tensor,
  228. axis: Optional[Union[int, Sequence[int]]] = None,
  229. keepdims: bool = False,
  230. ) -> Tensor:
  231. r"""
  232. Returns the min value of input tensor along
  233. given axis. If axis is a list of dimensions,
  234. reduce over all of them.
  235. :param inp: input tensor.
  236. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  237. :param keepdims: whether the output tensor has axis retained or not. Default: False
  238. :return: output tensor.
  239. Examples:
  240. .. testcode::
  241. import numpy as np
  242. from megengine import tensor
  243. import megengine.functional as F
  244. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  245. out = F.min(x)
  246. print(out.numpy())
  247. Outputs:
  248. .. testoutput::
  249. 1
  250. """
  251. return inp.min(axis=axis, keepdims=keepdims)
  252. def max(
  253. inp: Tensor,
  254. axis: Optional[Union[int, Sequence[int]]] = None,
  255. keepdims: bool = False,
  256. ) -> Tensor:
  257. r"""
  258. Returns the max value of the input tensor along
  259. given axis. If axis is a list of dimensions,
  260. reduce over all of them.
  261. :param inp: input tensor.
  262. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  263. :param keepdims: whether the output tensor has axis retained or not. Default: False
  264. :return: output tensor.
  265. Examples:
  266. .. testcode::
  267. import numpy as np
  268. from megengine import tensor
  269. import megengine.functional as F
  270. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  271. out = F.max(x)
  272. print(out.numpy())
  273. Outputs:
  274. .. testoutput::
  275. 6
  276. """
  277. return inp.max(axis=axis, keepdims=keepdims)
  278. def norm(
  279. inp: Tensor, ord: float = None, axis: int = None, keepdims=False,
  280. ):
  281. """
  282. Calculates ``p``-norm of input tensor along
  283. given axis.
  284. :param inp: input tensor.
  285. :param ord: power of value applied to inp. Default: 2
  286. :param axis: dimension to reduce. If None, input must be a vector. Default: None
  287. :param keepdims: whether the output tensor has axis retained or not. Default: False
  288. :return: output tensor.
  289. Examples:
  290. .. testcode::
  291. import numpy as np
  292. from megengine import tensor
  293. import megengine.functional as F
  294. x = tensor(np.arange(-3, 3, dtype=np.float32))
  295. out = F.norm(x)
  296. print(out.numpy().round(decimals=4))
  297. Outputs:
  298. .. testoutput::
  299. 4.3589
  300. """
  301. if axis is None:
  302. if inp.ndim != 1:
  303. raise TypeError("axis is required unless input is a vector")
  304. if ord is None:
  305. ord = 2
  306. if ord == 0:
  307. return sum(inp != 0, axis=axis, keepdims=keepdims)
  308. if ord == math.inf:
  309. return max(abs(inp))
  310. if ord == -math.inf:
  311. return min(abs(inp))
  312. return sum(abs(inp) ** ord, axis=axis, keepdims=keepdims) ** (1.0 / ord)
  313. def argmin(
  314. inp: Tensor,
  315. axis: Optional[Union[int, Sequence[int]]] = None,
  316. keepdims: bool = False,
  317. ) -> Tensor:
  318. r"""
  319. Returns the indices of the minimum values along
  320. given axis. If axis is a list of dimensions,
  321. reduce over all of them.
  322. :param inp: input tensor.
  323. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  324. :param keepdims: whether the output tensor has axis retained or not. Default: False
  325. :return: output tensor.
  326. Examples:
  327. .. testcode::
  328. import numpy as np
  329. from megengine import tensor
  330. import megengine.functional as F
  331. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  332. out = F.argmin(x)
  333. print(out.numpy())
  334. Outputs:
  335. .. testoutput::
  336. 0
  337. """
  338. if axis is None:
  339. assert not keepdims, "can not set axis=None and keepdims=True"
  340. inp = inp.flatten()
  341. axis = 0
  342. axis = _normalize_axis(inp.ndim, axis, reverse=True)
  343. if isinstance(axis, collections.abc.Iterable):
  344. for ai in axis:
  345. op = builtin.Argmin(axis=ai)
  346. (inp,) = apply(op, inp)
  347. if not keepdims:
  348. inp = squeeze(inp, ai)
  349. return inp
  350. op = builtin.Argmin(axis=axis)
  351. (result,) = apply(op, inp)
  352. if not keepdims:
  353. result = squeeze(result, axis)
  354. return result
  355. def argmax(
  356. inp: Tensor,
  357. axis: Optional[Union[int, Sequence[int]]] = None,
  358. keepdims: bool = False,
  359. ) -> Tensor:
  360. r"""
  361. Returns the indices of the maximum values along
  362. given axis. If axis is a list of dimensions,
  363. reduce over all of them.
  364. :param inp: input tensor.
  365. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  366. :param keepdims: whether the output tensor has axis retained or not. Default: False
  367. :return: output tensor.
  368. Examples:
  369. .. testcode::
  370. import numpy as np
  371. from megengine import tensor
  372. import megengine.functional as F
  373. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  374. out = F.argmax(x)
  375. print(out.numpy())
  376. Outputs:
  377. .. testoutput::
  378. 5
  379. """
  380. if axis is None:
  381. assert not keepdims, "can not set axis=None and keepdims=True"
  382. inp = inp.flatten()
  383. axis = 0
  384. axis = _normalize_axis(inp.ndim, axis, reverse=True)
  385. if isinstance(axis, collections.abc.Iterable):
  386. for ai in axis:
  387. op = builtin.Argmax(axis=ai)
  388. (inp,) = apply(op, inp)
  389. if not keepdims:
  390. inp = squeeze(inp, ai)
  391. return inp
  392. op = builtin.Argmax(axis=axis)
  393. (result,) = apply(op, inp)
  394. if not keepdims:
  395. result = squeeze(result, axis)
  396. return result
  397. def normalize(
  398. inp: Tensor, ord: float = None, axis: int = None, eps: float = 1e-12,
  399. ) -> Tensor:
  400. r"""
  401. Performs :math:`L_p` normalization of input tensor along
  402. given axis.
  403. For a tensor of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
  404. :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`axis` is transformed as:
  405. .. math::
  406. v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.
  407. :param inp: input tensor.
  408. :param ord: power of value applied to input tensor. Default: 2
  409. :param axis: dimension to reduce.If None, input must be a vector. Default: None
  410. :param eps: a small value to avoid division by zero. Default: 1e-12
  411. :return: normalized output tensor.
  412. """
  413. if axis is None:
  414. return inp / clip(norm(inp, ord, axis), lower=eps)
  415. else:
  416. return inp / clip(norm(inp, ord, axis, keepdims=True), lower=eps)
  417. def argsort(inp: Tensor, descending: bool = False) -> Tensor:
  418. r"""
  419. Returns the indices that would sort the input tensor.
  420. :param inp: input tensor. If it's 2d, the result would be array of indices show how to sort each row in the input tensor.
  421. :param descending: sort in descending order, where the largest comes first. Default: False
  422. :return: indices of int32 indicates how to sort the input.
  423. Examples:
  424. .. testcode::
  425. import numpy as np
  426. from megengine import tensor
  427. import megengine.functional as F
  428. x = tensor(np.array([1,2], dtype=np.float32))
  429. indices = F.argsort(x)
  430. print(indices.numpy())
  431. Outputs:
  432. .. testoutput::
  433. [0 1]
  434. """
  435. assert len(inp.shape) <= 2, "Input should be 1d or 2d"
  436. if descending:
  437. order = "descending"
  438. else:
  439. order = "ascending"
  440. op = builtin.Argsort(order=order)
  441. if len(inp.shape) == 1:
  442. inp = inp.reshape(1, -1)
  443. _, result = apply(op, inp)
  444. return result[0]
  445. _, result = apply(op, inp)
  446. return result
  447. def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]:
  448. r"""
  449. Returns sorted tensor and the indices would sort the input tensor.
  450. :param inp: input tensor. If it's 2d, the result would be sorted by row.
  451. :param descending: sort in descending order, where the largest comes first. Default: False
  452. :return: tuple of two tensors `(sorted_tensor, indices_of_int32)`.
  453. Examples:
  454. .. testcode::
  455. import numpy as np
  456. from megengine import tensor
  457. import megengine.functional as F
  458. x = tensor(np.array([1,2], dtype=np.float32))
  459. out, indices = F.sort(x)
  460. print(out.numpy())
  461. Outputs:
  462. .. testoutput::
  463. [1. 2.]
  464. """
  465. assert len(inp.shape) <= 2, "Input should be 1d or 2d"
  466. if descending:
  467. order = "descending"
  468. else:
  469. order = "ascending"
  470. op = builtin.Argsort(order=order)
  471. if len(inp.shape) == 1:
  472. inp = inp.reshape(1, -1)
  473. tns, ind = apply(op, inp)
  474. return tns[0], ind[0]
  475. tns, ind = apply(op, inp)
  476. return tns, ind
def topk(
    inp: Tensor,
    k: int,
    descending: bool = False,
    kth_only: bool = False,
    no_sort: bool = False,
) -> Tuple[Tensor, Tensor]:
    r"""
    Selects the ``k`` smallest elements (or the ``k`` largest ones when
    ``descending`` is True) of a 1-D tensor, or of each row of a 2-D tensor.

    :param inp: input tensor. If input tensor is 2d, each row is processed independently.
    :param k: number of elements needed. May be an int or a Tensor.
    :param descending: if True, return the largest elements instead. Default: False
    :param kth_only: if True, only the k-th element (and an index where it
        occurs) is returned. Default: False
    :param no_sort: if True, the returned elements can be unordered. Default: False
    :return: tuple of two tensors ``(topk_tensor, indices_of_int32)``

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32))
        top, indices = F.topk(x, 5)
        print(top.numpy(), indices.numpy())

    Outputs:

    .. testoutput::

        [1. 2. 3. 4. 5.] [7 0 6 1 5]
    """
    # Descending selection is requested from the TopK op by negating k.
    if descending:
        k = -k

    # kth_only takes precedence over no_sort.
    if kth_only:
        mode = "kth_only"
    elif no_sort:
        mode = "value_idx_nosort"
    else:
        mode = "value_idx_sorted"
    op = builtin.TopK(mode=mode)

    # The op takes k as an int32 tensor input on the same device as inp.
    if not isinstance(k, Tensor):
        (k,) = Const(k, dtype="int32", device=inp.device)()

    if len(inp.shape) == 1:
        # The op is applied to 2-D input: add a leading row dim, strip it afterwards.
        if kth_only:
            # kth_only mode yields values only, so the index is recovered by
            # locating an element equal to the k-th value.
            (tns,) = apply(op, expand_dims(inp, 0), k)
            # FIXME:
            # could use a dedicated kernel
            # gradient may be routed to other indices if k-th value is not unique
            ind = argmax((tns == inp).astype("int8"))
            tns = squeeze(tns, 0)
        else:
            tns, ind = apply(op, expand_dims(inp, 0), k)
            tns = squeeze(tns, 0)
            ind = squeeze(ind, 0)
    else:
        if kth_only:
            (tns,) = apply(op, inp, k)
            # FIXME: same as above
            # Compare each row's k-th value against that row to find its position.
            ind = argmax((expand_dims(tns, 1) == inp).astype("int8"), 1)
        else:
            tns, ind = apply(op, inp, k)
    return tns, ind
  535. def matinv(inp: Tensor) -> Tensor:
  536. """
  537. Computes the inverse of a batch of matrices; input must has shape [..., n, n].
  538. :param inp: input tensor.
  539. :return: output tensor.
  540. Examples:
  541. .. testcode::
  542. import numpy as np
  543. from megengine import tensor
  544. import megengine.functional as F
  545. data = tensor([[1.0, 0.0], [1.0, 1.0]])
  546. out = F.matinv(data)
  547. print(out.numpy())
  548. Outputs:
  549. .. testoutput::
  550. [[ 1. 0.]
  551. [-1. 1.]]
  552. """
  553. (result,) = apply(builtin.MatrixInverse(), inp)
  554. return result
def matmul(
    inp1: Tensor,
    inp2: Tensor,
    transpose_a=False,
    transpose_b=False,
    compute_mode="default",
    format="default",
) -> Tensor:
    """
    Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``.

    With different inputs dim, this function behaves differently:

    - Both 1-D tensor, simply forward to ``dot`` (the transpose flags,
      ``compute_mode`` and ``format`` are not used in this case).
    - Both 2-D tensor, normal matrix multiplication.
    - If one input tensor is 1-D, matrix vector multiplication.
    - If at least one tensor is 3-dimensional or higher, the other tensor should have dim >= 2,
      the batched matrix-matrix product is returned, and the tensor with smaller dimension
      will be broadcast. For example:

        - inp1: `(n, k, m)`, inp2: `(n, m, p)`, return: `(n, k, p)`
        - inp1: `(n, k, m)`, inp2: `(m, p)`, return: `(n, k, p)`
        - inp1: `(n, j, k, m)`, inp2: `(n, j, m, p)`, return: `(n, j, k, p)`

    :param inp1: first matrix to be multiplied.
    :param inp2: second matrix to be multiplied.
    :param transpose_a: forwarded as ``transposeA`` to the underlying
        MatrixMul / BatchedMatrixMul op (presumably transposes the last two
        dims of ``inp1``). Default: False
    :param transpose_b: forwarded as ``transposeB`` to the underlying op
        (presumably transposes the last two dims of ``inp2``). Default: False
    :param compute_mode: forwarded to the underlying op; forced to
        ``"float32"`` when AMP is enabled. Default: "default"
    :param format: forwarded to the underlying op. Default: "default"
    :return: output tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
        data2 = tensor(np.arange(0, 6, dtype=np.float32).reshape(3, 2))
        out = F.matmul(data1, data2)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[10. 13.]
         [28. 40.]]
    """
    # Under AMP, use float32 compute mode and let amp's cast_tensors choose the
    # input dtype. Otherwise promote both inputs to a common dtype.
    if amp._enabled:
        compute_mode = "float32"
        inp1, inp2 = cast_tensors(inp1, inp2)
    else:
        dtype = dtype_promotion(inp1, inp2)
        if inp1.dtype != dtype:
            inp1 = inp1.astype(dtype)
        if inp2.dtype != dtype:
            inp2 = inp2.astype(dtype)

    remove_row, remove_col = False, False
    dim1, dim2 = inp1.ndim, inp2.ndim
    # handle dim=1 cases, dot and matrix-vector multiplication
    if dim1 == 1 and dim2 == 1:
        return dot(inp1, inp2)
    # the underlying matmul op requires input dims to be at least 2
    # NOTE(review): the lift to (1, n) / (n, 1) happens before transpose_a /
    # transpose_b are applied by the op, so a transpose flag on a 1-D operand
    # transposes the lifted matrix. Confirm that callers never combine them.
    if dim1 == 1:
        inp1 = expand_dims(inp1, 0)
        dim1 = 2
        remove_row = True
    if dim2 == 1:
        inp2 = expand_dims(inp2, 1)
        dim2 = 2
        remove_col = True

    batch_shape = None
    shape1 = inp1.shape
    shape2 = inp2.shape

    maxdim = dim1 if dim1 > dim2 else dim2
    if dim1 >= 3 or dim2 >= 3:
        # Broadcast the lower-rank operand's batch dims to match the other one.
        # NOTE(review): when dim1 == dim2 no broadcasting is done, so leading
        # (batch) dims of equal-rank inputs must already agree.
        if use_symbolic_shape():
            # Shapes are tensors here, so they are combined with concat().
            if dim1 > dim2:
                shape2 = concat([shape1[:-2], shape2[-2:]])
                inp2 = broadcast_to(inp2, shape2)
            if dim1 < dim2:
                shape1 = concat([shape2[:-2], shape1[-2:]])
                inp1 = broadcast_to(inp1, shape1)
            if maxdim > 3:
                batch_shape = shape1[:-2]
                # compress inputs to 3d
                # (prod here is this module's prod over the batch dims)
                (inp1,) = apply(
                    builtin.Reshape(), inp1, concat([prod(shape1[:-2]), shape1[-2:]])
                )
                (inp2,) = apply(
                    builtin.Reshape(), inp2, concat([prod(shape2[:-2]), shape2[-2:]])
                )
        else:
            # Shapes are plain tuples here, so they are combined with +.
            if dim1 > dim2:
                shape2 = shape1[:-2] + shape2[-2:]
                inp2 = broadcast_to(inp2, shape2)
            if dim1 < dim2:
                shape1 = shape2[:-2] + shape1[-2:]
                inp1 = broadcast_to(inp1, shape1)
            if maxdim > 3:
                batch_shape = shape1[:-2]
                # compress inputs to 3d
                inp1 = inp1.reshape((-1, shape1[-2], shape1[-1]))
                inp2 = inp2.reshape((-1, shape2[-2], shape2[-1]))

        op = builtin.BatchedMatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=get_execution_strategy(),
        )
    else:
        op = builtin.MatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=get_execution_strategy(),
        )

    (result,) = apply(op, inp1, inp2)
    # Restore the original batch dims that were flattened into one above.
    if maxdim > 3:
        if use_symbolic_shape():
            (result,) = apply(
                builtin.Reshape(), result, concat([batch_shape, result.shape[-2:]])
            )
        else:
            result = result.reshape(batch_shape + result.shape[-2:])
    # Undo the 1-D -> 2-D lifts so vector inputs give vector outputs.
    if remove_row:
        result = squeeze(result, axis=-2)
    if remove_col:
        result = squeeze(result, axis=-1)
    return result
  676. def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
  677. """
  678. Computes dot-product of two vectors ``inp1`` and ``inp2``.
  679. inputs must be 1-dimensional or scalar. A scalar input is automatically broadcasted.
  680. Refer to :func:`~.matmul` for more general usage.
  681. :param inp1: first vector.
  682. :param inp2: second vector.
  683. :return: output value.
  684. Examples:
  685. .. testcode::
  686. import numpy as np
  687. from megengine import tensor
  688. import megengine.functional as F
  689. data1 = tensor(np.arange(0, 6, dtype=np.float32))
  690. data2 = tensor(np.arange(0, 6, dtype=np.float32))
  691. out = F.dot(data1, data2)
  692. print(out.numpy())
  693. Outputs:
  694. .. testoutput::
  695. 55.
  696. """
  697. op = builtin.Dot()
  698. assert (
  699. inp1.ndim <= 1 and inp2.ndim <= 1
  700. ), "Input tensors for dot must be 1-dimensional or scalar"
  701. (result,) = apply(op, inp1, inp2)
  702. setscalar(result)
  703. return result
  704. def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
  705. """
  706. Computes the singular value decompositions of input matrix.
  707. :param inp: input matrix, must has shape `[..., M, N]`.
  708. :return: output matrices, `(U, sigma, V)`.
  709. Examples:
  710. .. testcode::
  711. import numpy as np
  712. from megengine import tensor
  713. import megengine.functional as F
  714. x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2,3))
  715. _, y, _ = F.svd(x)
  716. print(y.numpy().round(decimals=3))
  717. Outputs:
  718. .. testoutput::
  719. [7.348 1. ]
  720. """
  721. op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv)
  722. U, sigma, V = apply(op, inp)
  723. return U, sigma, V
  724. def _has_inf(inp: Tensor) -> Tensor:
  725. """
  726. Check whether input contains infinite value.
  727. :param inp: a tensor to be checked.
  728. :return: a int32 scalar tensor, 0 for False and 1 for True.
  729. """
  730. op = builtin.CheckHasInf()
  731. (oup,) = apply(op, inp.reshape(-1).astype("float32"))
  732. oup._setscalar()
  733. return oup

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。