You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

math.py 24 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import functools
  11. import math
  12. import numbers
  13. from typing import Optional, Sequence, Tuple, Union
  14. from ..core._imperative_rt.core2 import apply
  15. from ..core._trace_option import use_symbolic_shape
  16. from ..core.ops import builtin
  17. from ..core.ops.special import Const
  18. from ..core.tensor import utils
  19. from ..tensor import Tensor
  20. from .debug_param import get_conv_execution_strategy
  21. from .elemwise import clip, exp, log, log1p
  22. from .tensor import broadcast_to, concat, expand_dims, reshape, squeeze
  23. __all__ = [
  24. "argmax",
  25. "argmin",
  26. "argsort",
  27. "dot",
  28. "isinf",
  29. "isnan",
  30. "matinv",
  31. "matmul",
  32. "max",
  33. "mean",
  34. "min",
  35. "norm",
  36. "normalize",
  37. "prod",
  38. "sign",
  39. "sort",
  40. "std",
  41. "sum",
  42. "svd",
  43. "topk",
  44. "var",
  45. ]
  46. def isnan(inp: Tensor) -> Tensor:
  47. r"""
  48. Returns a new tensor representing if each element is ``NaN`` or not.
  49. :param inp: input tensor.
  50. :return: result tensor.
  51. Examples:
  52. .. testcode::
  53. from megengine import tensor
  54. import megengine.functional as F
  55. x = tensor([1, float("nan"), 0])
  56. print(F.isnan(x).numpy())
  57. Outputs:
  58. .. testoutput::
  59. [False True False]
  60. """
  61. return inp != inp
  62. def isinf(inp: Tensor) -> Tensor:
  63. r"""
  64. Returns a new tensor representing if each element is ``Inf`` or not.
  65. :param inp: input tensor.
  66. :return: result tensor.
  67. Examples:
  68. .. testcode::
  69. from megengine import tensor
  70. import megengine.functional as F
  71. x = tensor([1, float("inf"), 0])
  72. print(F.isinf(x).numpy())
  73. Outputs:
  74. .. testoutput::
  75. [False True False]
  76. """
  77. return abs(inp).astype("float32") == float("inf")
  78. def sign(inp: Tensor):
  79. r"""
  80. Returns a new tensor representing the sign of each element in input tensor.
  81. :param: input tensor.
  82. :return: the sign of input tensor.
  83. Examples:
  84. .. testcode::
  85. from megengine import tensor
  86. import megengine.functional as F
  87. x = tensor([1, -1, 0])
  88. print(F.sign(x).numpy())
  89. Outputs:
  90. .. testoutput::
  91. [ 1 -1 0]
  92. """
  93. return (inp > 0).astype(inp.dtype) - (inp < 0).astype(inp.dtype)
  94. def sum(
  95. inp: Tensor,
  96. axis: Optional[Union[int, Sequence[int]]] = None,
  97. keepdims: bool = False,
  98. ) -> Tensor:
  99. r"""
  100. Returns the sum of input tensor along given axis. If axis is a list of dimensions,
  101. reduce over all of them.
  102. :param inp: input tensor.
  103. :param axis: dimension to reduce. If None, all dimensions will be reduced.
  104. Default: None
  105. :param keepdims: whether the output tensor has axis retained or not.
  106. Default: False
  107. :return: output tensor.
  108. Examples:
  109. .. testcode::
  110. import numpy as np
  111. from megengine import tensor
  112. import megengine.functional as F
  113. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  114. out = F.sum(x)
  115. print(out.numpy())
  116. Outputs:
  117. .. testoutput::
  118. 21
  119. """
  120. return inp.sum(axis=axis, keepdims=keepdims)
  121. def prod(
  122. inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None, keepdims=False
  123. ) -> Tensor:
  124. r"""
  125. Returns the product of input tensor along given axis. If axis is a list of dimensions,
  126. reduce over all of them.
  127. :param inp: input tensor.
  128. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  129. :param keepdims: whether the output tensor has axis retained or not. Default: False
  130. :return: output tensor.
  131. Examples:
  132. .. testcode::
  133. import numpy as np
  134. from megengine import tensor
  135. import megengine.functional as F
  136. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  137. out = F.prod(x)
  138. print(out.numpy())
  139. Outputs:
  140. .. testoutput::
  141. 720
  142. """
  143. return inp.prod(axis=axis, keepdims=keepdims)
  144. def mean(
  145. inp: Tensor,
  146. axis: Optional[Union[int, Sequence[int]]] = None,
  147. keepdims: bool = False,
  148. ) -> Tensor:
  149. """
  150. Returns the mean value of input tensor along
  151. given axis. If axis is a list of dimensions,
  152. reduce over all of them.
  153. :param inp: input tensor.
  154. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  155. :param keepdims: whether the output tensor has axis retained or not. Default: False
  156. :return: output tensor.
  157. Examples:
  158. .. testcode::
  159. import numpy as np
  160. from megengine import tensor
  161. import megengine.functional as F
  162. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  163. out = F.mean(x)
  164. print(out.numpy())
  165. Outputs:
  166. .. testoutput::
  167. 3.5
  168. """
  169. return inp.mean(axis=axis, keepdims=keepdims)
  170. def var(
  171. inp: Tensor,
  172. axis: Optional[Union[int, Sequence[int]]] = None,
  173. keepdims: bool = False,
  174. ) -> Tensor:
  175. """
  176. Returns the variance value of input tensor along
  177. given axis. If axis is a list of dimensions,
  178. reduce over all of them.
  179. :param inp: input tensor.
  180. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  181. :param keepdims: whether the output tensor has axis retained or not. Default: False
  182. :return: output tensor.
  183. Examples:
  184. .. testcode::
  185. import numpy as np
  186. from megengine import tensor
  187. import megengine.functional as F
  188. data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
  189. out = F.var(data)
  190. print(out.numpy().round(decimals=4))
  191. Outputs:
  192. .. testoutput::
  193. 2.9167
  194. """
  195. if axis is None:
  196. m = mean(inp, axis=axis, keepdims=False)
  197. else:
  198. m = mean(inp, axis=axis, keepdims=True)
  199. v = inp - m
  200. return mean(v ** 2, axis=axis, keepdims=keepdims)
  201. def std(
  202. inp: Tensor,
  203. axis: Optional[Union[int, Sequence[int]]] = None,
  204. keepdims: bool = False,
  205. ) -> Tensor:
  206. """
  207. Returns the standard deviation of input tensor along
  208. given axis. If axis is a list of dimensions,
  209. reduce over all of them.
  210. :param inp: input tensor.
  211. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  212. :param keepdims: whether the output tensor has axis retained or not. Default: False
  213. :return: output tensor.
  214. Examples:
  215. .. testcode::
  216. import numpy as np
  217. from megengine import tensor
  218. import megengine.functional as F
  219. data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
  220. out = F.std(data, axis=1)
  221. print(out.numpy().round(decimals=4))
  222. Outputs:
  223. .. testoutput::
  224. [0.8165 0.8165]
  225. """
  226. return var(inp, axis=axis, keepdims=keepdims) ** 0.5
  227. def min(
  228. inp: Tensor,
  229. axis: Optional[Union[int, Sequence[int]]] = None,
  230. keepdims: bool = False,
  231. ) -> Tensor:
  232. r"""
  233. Returns the min value of input tensor along
  234. given axis. If axis is a list of dimensions,
  235. reduce over all of them.
  236. :param inp: input tensor.
  237. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  238. :param keepdims: whether the output tensor has axis retained or not. Default: False
  239. :return: output tensor.
  240. Examples:
  241. .. testcode::
  242. import numpy as np
  243. from megengine import tensor
  244. import megengine.functional as F
  245. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  246. out = F.min(x)
  247. print(out.numpy())
  248. Outputs:
  249. .. testoutput::
  250. 1
  251. """
  252. return inp.min(axis=axis, keepdims=keepdims)
  253. def max(
  254. inp: Tensor,
  255. axis: Optional[Union[int, Sequence[int]]] = None,
  256. keepdims: bool = False,
  257. ) -> Tensor:
  258. r"""
  259. Returns the max value of the input tensor along
  260. given axis. If axis is a list of dimensions,
  261. reduce over all of them.
  262. :param inp: input tensor.
  263. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  264. :param keepdims: whether the output tensor has axis retained or not. Default: False
  265. :return: output tensor.
  266. Examples:
  267. .. testcode::
  268. import numpy as np
  269. from megengine import tensor
  270. import megengine.functional as F
  271. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  272. out = F.max(x)
  273. print(out.numpy())
  274. Outputs:
  275. .. testoutput::
  276. 6
  277. """
  278. return inp.max(axis=axis, keepdims=keepdims)
  279. def norm(
  280. inp: Tensor, ord: float = None, axis: int = None, keepdims=False,
  281. ):
  282. """
  283. Calculates ``p``-norm of input tensor along
  284. given axis.
  285. :param inp: input tensor.
  286. :param ord: power of value applied to inp. Default: 2
  287. :param axis: dimension to reduce. If None, input must be a vector. Default: None
  288. :param keepdims: whether the output tensor has axis retained or not. Default: False
  289. :return: output tensor.
  290. Examples:
  291. .. testcode::
  292. import numpy as np
  293. from megengine import tensor
  294. import megengine.functional as F
  295. x = tensor(np.arange(-3, 3, dtype=np.float32))
  296. out = F.norm(x)
  297. print(out.numpy().round(decimals=4))
  298. Outputs:
  299. .. testoutput::
  300. 4.3589
  301. """
  302. if axis is None:
  303. if inp.ndim != 1:
  304. raise TypeError("axis is required unless input is a vector")
  305. if ord is None:
  306. ord = 2
  307. if ord == 0:
  308. return sum(inp != 0, axis=axis, keepdims=keepdims)
  309. if ord == math.inf:
  310. return max(abs(inp))
  311. if ord == -math.inf:
  312. return min(abs(inp))
  313. return sum(abs(inp) ** ord, axis=axis, keepdims=keepdims) ** (1.0 / ord)
  314. def argmin(
  315. inp: Tensor,
  316. axis: Optional[Union[int, Sequence[int]]] = None,
  317. keepdims: bool = False,
  318. ) -> Tensor:
  319. r"""
  320. Returns the indices of the minimum values along
  321. given axis. If axis is a list of dimensions,
  322. reduce over all of them.
  323. :param inp: input tensor.
  324. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  325. :param keepdims: whether the output tensor has axis retained or not. Default: False
  326. :return: output tensor.
  327. Examples:
  328. .. testcode::
  329. import numpy as np
  330. from megengine import tensor
  331. import megengine.functional as F
  332. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  333. out = F.argmin(x)
  334. print(out.numpy())
  335. Outputs:
  336. .. testoutput::
  337. 0
  338. """
  339. if isinstance(axis, collections.abc.Iterable):
  340. axis = list(axis)
  341. axis.sort(reverse=True)
  342. for ai in axis:
  343. op = builtin.Argmin(axis=ai)
  344. (inp,) = apply(op, inp)
  345. if not keepdims:
  346. inp = squeeze(inp, ai)
  347. return inp
  348. if axis is None:
  349. assert not keepdims, "can not set axis=None and keepdims=True"
  350. inp = inp.flatten()
  351. axis = 0
  352. op = builtin.Argmin(axis=axis)
  353. (result,) = apply(op, inp)
  354. if not keepdims:
  355. result = squeeze(result, axis)
  356. return result
  357. def argmax(
  358. inp: Tensor,
  359. axis: Optional[Union[int, Sequence[int]]] = None,
  360. keepdims: bool = False,
  361. ) -> Tensor:
  362. r"""
  363. Returns the indices of the maximum values along
  364. given axis. If axis is a list of dimensions,
  365. reduce over all of them.
  366. :param inp: input tensor.
  367. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  368. :param keepdims: whether the output tensor has axis retained or not. Default: False
  369. :return: output tensor.
  370. Examples:
  371. .. testcode::
  372. import numpy as np
  373. from megengine import tensor
  374. import megengine.functional as F
  375. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  376. out = F.argmax(x)
  377. print(out.numpy())
  378. Outputs:
  379. .. testoutput::
  380. 5
  381. """
  382. if isinstance(axis, collections.abc.Iterable):
  383. axis = list(axis)
  384. axis.sort(reverse=True)
  385. for ai in axis:
  386. op = builtin.Argmax(axis=ai)
  387. (inp,) = apply(op, inp)
  388. if not keepdims:
  389. inp = squeeze(inp, ai)
  390. return inp
  391. if axis is None:
  392. assert not keepdims, "can not set axis=None and keepdims=True"
  393. inp = inp.flatten()
  394. axis = 0
  395. op = builtin.Argmax(axis=axis)
  396. (result,) = apply(op, inp)
  397. if not keepdims:
  398. result = squeeze(result, axis)
  399. return result
  400. def normalize(
  401. inp: Tensor, ord: float = None, axis: int = None, eps: float = 1e-12,
  402. ) -> Tensor:
  403. r"""
  404. Performs :math:`L_p` normalization of input tensor along
  405. given axis.
  406. For a tensor of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
  407. :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`axis` is transformed as:
  408. .. math::
  409. v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.
  410. :param inp: input tensor.
  411. :param ord: power of value applied to input tensor. Default: 2
  412. :param axis: dimension to reduce.If None, input must be a vector. Default: None
  413. :param eps: a small value to avoid division by zero. Default: 1e-12
  414. :return: normalized output tensor.
  415. """
  416. if axis is None:
  417. return inp / clip(norm(inp, ord, axis), lower=eps)
  418. else:
  419. return inp / clip(norm(inp, ord, axis, keepdims=True), lower=eps)
  420. def argsort(inp: Tensor, descending: bool = False) -> Tensor:
  421. r"""
  422. Returns the indices that would sort the input tensor.
  423. :param inp: input tensor. If it's 2d, the result would be array of indices show how to sort each row in the input tensor.
  424. :param descending: sort in descending order, where the largest comes first. Default: False
  425. :return: indices of int32 indicates how to sort the input.
  426. Examples:
  427. .. testcode::
  428. import numpy as np
  429. from megengine import tensor
  430. import megengine.functional as F
  431. x = tensor(np.array([1,2], dtype=np.float32))
  432. indices = F.argsort(x)
  433. print(indices.numpy())
  434. Outputs:
  435. .. testoutput::
  436. [0 1]
  437. """
  438. assert len(inp.shape) <= 2, "Input should be 1d or 2d"
  439. if descending:
  440. order = "DESCENDING"
  441. else:
  442. order = "ASCENDING"
  443. op = builtin.Argsort(order=order)
  444. if len(inp.shape) == 1:
  445. inp = inp.reshape(1, -1)
  446. _, result = apply(op, inp)
  447. return result[0]
  448. _, result = apply(op, inp)
  449. return result
  450. def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]:
  451. r"""
  452. Returns sorted tensor and the indices would sort the input tensor.
  453. :param inp: input tensor. If it's 2d, the result would be sorted by row.
  454. :param descending: sort in descending order, where the largest comes first. Default: False
  455. :return: tuple of two tensors `(sorted_tensor, indices_of_int32)`.
  456. Examples:
  457. .. testcode::
  458. import numpy as np
  459. from megengine import tensor
  460. import megengine.functional as F
  461. x = tensor(np.array([1,2], dtype=np.float32))
  462. out, indices = F.sort(x)
  463. print(out.numpy())
  464. Outputs:
  465. .. testoutput::
  466. [1. 2.]
  467. """
  468. assert len(inp.shape) <= 2, "Input should be 1d or 2d"
  469. if descending:
  470. order = "DESCENDING"
  471. else:
  472. order = "ASCENDING"
  473. op = builtin.Argsort(order=order)
  474. if len(inp.shape) == 1:
  475. inp = inp.reshape(1, -1)
  476. tns, ind = apply(op, inp)
  477. return tns[0], ind[0]
  478. tns, ind = apply(op, inp)
  479. return tns, ind
  480. def topk(
  481. inp: Tensor,
  482. k: int,
  483. descending: bool = False,
  484. kth_only: bool = False,
  485. no_sort: bool = False,
  486. ) -> Tuple[Tensor, Tensor]:
  487. r"""
  488. Selects the ``Top-K`` (by default) smallest elements of 2d matrix by row.
  489. :param inp: input tensor. If input tensor is 2d, each row will be sorted.
  490. :param k: number of elements needed.
  491. :param descending: if True, return the largest elements instead. Default: False
  492. :param kth_only: if True, only the k-th element will be returned. Default: False
  493. :param no_sort: if True, the returned elements can be unordered. Default: False
  494. :return: tuple of two tensors `(topk_tensor, indices_of_int32)`.
  495. Examples:
  496. .. testcode::
  497. import numpy as np
  498. from megengine import tensor
  499. import megengine.functional as F
  500. x = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32))
  501. top, indices = F.topk(x, 5)
  502. print(top.numpy(), indices.numpy())
  503. Outputs:
  504. .. testoutput::
  505. [1. 2. 3. 4. 5.] [7 0 6 1 5]
  506. """
  507. if descending:
  508. inp = -inp
  509. if kth_only:
  510. mode = "KTH_ONLY"
  511. elif no_sort:
  512. mode = "VALUE_IDX_NOSORT"
  513. else:
  514. mode = "VALUE_IDX_SORTED"
  515. op = builtin.TopK(mode=mode)
  516. if not isinstance(k, Tensor):
  517. (k,) = Const(k, dtype="int32", device=inp.device)()
  518. if len(inp.shape) == 1:
  519. inp = inp.reshape(1, -1)
  520. res = apply(op, inp, k)
  521. if kth_only:
  522. tns = res[0]
  523. else:
  524. tns, ind = res[0][0], res[1][0]
  525. else:
  526. res = apply(op, inp, k)
  527. if kth_only:
  528. tns = res
  529. else:
  530. tns, ind = res[0], res[1]
  531. if descending:
  532. tns = -tns
  533. return tns, ind
  534. def matinv(inp: Tensor) -> Tensor:
  535. """
  536. Computes the inverse of a batch of matrices; input must has shape [..., n, n].
  537. :param inp: input tensor.
  538. :return: output tensor.
  539. Examples:
  540. .. testcode::
  541. import numpy as np
  542. from megengine import tensor
  543. import megengine.functional as F
  544. data = tensor([[1.0, 0.0], [1.0, 1.0]])
  545. out = F.matinv(data)
  546. print(out.numpy())
  547. Outputs:
  548. .. testoutput::
  549. [[ 1. 0.]
  550. [-1. 1.]]
  551. """
  552. (result,) = apply(builtin.MatrixInverse(), inp)
  553. return result
def matmul(
    inp1: Tensor,
    inp2: Tensor,
    transpose_a=False,
    transpose_b=False,
    compute_mode="DEFAULT",
    format="DEFAULT",
) -> Tensor:
    """
    Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``.

    With different inputs dim, this function behaves differently:

    - Both 1-D tensor, simply forward to ``dot``.
    - Both 2-D tensor, normal matrix multiplication.
    - If one input tensor is 1-D, matrix vector multiplication.
    - If at least one tensor is 3-dimensional or >3-dimensional, the other tensor
      should have dim >= 2 and the batched matrix-matrix product is returned. The
      tensor with the smaller dimension is broadcast to the other's batch shape.
      For example:

        - inp1: `(n, k, m)`, inp2: `(n, m, p)`, return: `(n, k, p)`
        - inp1: `(n, k, m)`, inp2: `(m, p)`, return: `(n, k, p)`
        - inp1: `(n, j, k, m)`, inp2: `(n, j, m, p)`, return: `(n, j, k, p)`

    :param inp1: first matrix to be multiplied.
    :param inp2: second matrix to be multiplied.
    :param transpose_a: forwarded as ``transposeA`` to the underlying matrix-mul op.
        Default: False
    :param transpose_b: forwarded as ``transposeB`` to the underlying matrix-mul op.
        Default: False
    :param compute_mode: forwarded to the underlying matrix-mul op. Default: "DEFAULT"
    :param format: forwarded to the underlying matrix-mul op. Default: "DEFAULT"
    :return: output tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
        data2 = tensor(np.arange(0, 6, dtype=np.float32).reshape(3, 2))
        out = F.matmul(data1, data2)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[10. 13.]
         [28. 40.]]

    """
    remove_row, remove_col = False, False
    # presumably converts both operands to Tensors of a common dtype/device
    inp1, inp2 = utils.convert_inputs(inp1, inp2)
    dim1, dim2 = inp1.ndim, inp2.ndim
    # handle dim=1 cases, dot and matrix-vector multiplication
    if dim1 == 1 and dim2 == 1:
        return dot(inp1, inp2)
    # the underlying matmul op requires input dims to be at least 2
    # A 1-D left operand becomes a single row and a 1-D right operand becomes a
    # single column; the flags record which helper axis to squeeze at the end.
    if dim1 == 1:
        inp1 = expand_dims(inp1, 0)
        dim1 = 2
        remove_row = True
    if dim2 == 1:
        inp2 = expand_dims(inp2, 1)
        dim2 = 2
        remove_col = True
    # batch_shape is only set when inputs have >3 dims and get flattened to 3-D;
    # it is used afterwards to restore the leading batch dimensions.
    batch_shape = None
    shape1 = inp1.shape
    shape2 = inp2.shape
    maxdim = dim1 if dim1 > dim2 else dim2
    if dim1 >= 3 or dim2 >= 3:
        # Batched path. The lower-rank operand is broadcast to the other's
        # batch dims. NOTE(review): when dim1 == dim2 the batch dims are not
        # broadcast here, so they presumably must already match.
        if use_symbolic_shape():
            # shapes are tensors here, so they are combined with concat/Reshape ops
            if dim1 > dim2:
                shape2 = concat([shape1[:-2], shape2[-2:]])
                inp2 = broadcast_to(inp2, shape2)
            if dim1 < dim2:
                shape1 = concat([shape2[:-2], shape1[-2:]])
                inp1 = broadcast_to(inp1, shape1)
            if maxdim > 3:
                batch_shape = shape1[:-2]
                # compress inputs to 3d
                (inp1,) = apply(
                    builtin.Reshape(), inp1, concat([prod(shape1[:-2]), shape1[-2:]])
                )
                (inp2,) = apply(
                    builtin.Reshape(), inp2, concat([prod(shape2[:-2]), shape2[-2:]])
                )
        else:
            # shapes are concatenated with ``+`` here (tuple-like in this mode)
            if dim1 > dim2:
                shape2 = shape1[:-2] + shape2[-2:]
                inp2 = broadcast_to(inp2, shape2)
            if dim1 < dim2:
                shape1 = shape2[:-2] + shape1[-2:]
                inp1 = broadcast_to(inp1, shape1)
            if maxdim > 3:
                batch_shape = shape1[:-2]
                # compress inputs to 3d
                inp1 = inp1.reshape((-1, shape1[-2], shape1[-1]))
                inp2 = inp2.reshape((-1, shape2[-2], shape2[-1]))
        op = builtin.BatchedMatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=get_conv_execution_strategy(),
        )
    else:
        # plain 2-D x 2-D product
        op = builtin.MatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=get_conv_execution_strategy(),
        )
    (result,) = apply(op, inp1, inp2)
    # undo the 3-D flattening by restoring the saved leading batch dimensions
    if maxdim > 3:
        if use_symbolic_shape():
            (result,) = apply(
                builtin.Reshape(), result, concat([batch_shape, result.shape[-2:]])
            )
        else:
            result = result.reshape(batch_shape + result.shape[-2:])
    # drop the helper axes added for 1-D operands
    if remove_row:
        result = squeeze(result, axis=-2)
    if remove_col:
        result = squeeze(result, axis=-1)
    return result
  665. def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
  666. """
  667. Computes dot-product of two vectors ``inp1`` and ``inp2``.
  668. inputs must be 1-dimensional or scalar. A scalar input is automatically broadcasted.
  669. Refer to :func:`~.matmul` for more general usage.
  670. :param inp1: first vector.
  671. :param inp2: second vector.
  672. :return: output value.
  673. Examples:
  674. .. testcode::
  675. import numpy as np
  676. from megengine import tensor
  677. import megengine.functional as F
  678. data1 = tensor(np.arange(0, 6, dtype=np.float32))
  679. data2 = tensor(np.arange(0, 6, dtype=np.float32))
  680. out = F.dot(data1, data2)
  681. print(out.numpy())
  682. Outputs:
  683. .. testoutput::
  684. 55.
  685. """
  686. op = builtin.Dot()
  687. inp1, inp2 = utils.convert_inputs(inp1, inp2)
  688. assert (
  689. inp1.ndim <= 1 and inp2.ndim <= 1
  690. ), "Input tensors for dot must be 1-dimensional or scalar"
  691. (result,) = apply(op, inp1, inp2)
  692. utils.setscalar(result)
  693. return result
  694. def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
  695. """
  696. Computes the singular value decompositions of input matrix.
  697. :param inp: input matrix, must has shape `[..., M, N]`.
  698. :return: output matrices, `(U, sigma, V)`.
  699. Examples:
  700. .. testcode::
  701. import numpy as np
  702. from megengine import tensor
  703. import megengine.functional as F
  704. x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2,3))
  705. _, y, _ = F.svd(x)
  706. print(y.numpy().round(decimals=3))
  707. Outputs:
  708. .. testoutput::
  709. [7.348 1. ]
  710. """
  711. op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv)
  712. U, sigma, V = apply(op, inp)
  713. return U, sigma, V

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台