You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

math.py 22 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import math
  11. from typing import Iterable, Optional, Sequence, Tuple, Union
  12. from ..core._imperative_rt.core2 import Const, apply
  13. from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
  14. from ..core.ops import builtin
  15. from ..core.tensor.array_method import _matmul
  16. from ..core.tensor.utils import _normalize_axis
  17. from ..tensor import Tensor
  18. from ..utils.deprecation import deprecated_kwargs_default
  19. from .elemwise import clip
  20. from .tensor import expand_dims, squeeze
  21. __all__ = [
  22. "argmax",
  23. "argmin",
  24. "argsort",
  25. "dot",
  26. "isinf",
  27. "isnan",
  28. "matinv",
  29. "matmul",
  30. "max",
  31. "mean",
  32. "min",
  33. "norm",
  34. "normalize",
  35. "prod",
  36. "sign",
  37. "sort",
  38. "std",
  39. "sum",
  40. "svd",
  41. "topk",
  42. "var",
  43. ]
  44. def isnan(inp: Tensor) -> Tensor:
  45. r"""Returns a new tensor representing if each element is ``NaN`` or not.
  46. Args:
  47. inp: input tensor.
  48. Returns:
  49. result tensor.
  50. Examples:
  51. >>> x = Tensor([1, float("nan"), 0])
  52. >>> F.isnan(x).numpy()
  53. array([False, True, False])
  54. """
  55. return inp != inp
  56. def isinf(inp: Tensor) -> Tensor:
  57. r"""Returns a new tensor representing if each element is ``Inf`` or not.
  58. Args:
  59. inp: input tensor.
  60. Returns:
  61. result tensor.
  62. Examples:
  63. >>> x = Tensor([1, float("inf"), 0])
  64. >>> F.isinf(x).numpy()
  65. array([False, True, False])
  66. """
  67. return abs(inp).astype("float32") == float("inf")
  68. def sign(inp: Tensor):
  69. r"""Returns a new tensor representing the sign of each element in input tensor.
  70. Args:
  71. inp: Tensor:
  72. Returns:
  73. the sign of input tensor.
  74. Examples:
  75. >>> x = Tensor([1, -1, 0])
  76. >>> F.sign(x)
  77. Tensor([ 1 -1 0], dtype=int32, device=xpux:0)
  78. """
  79. return (inp > 0).astype(inp.dtype) - (inp < 0).astype(inp.dtype)
  80. def sum(
  81. inp: Tensor,
  82. axis: Optional[Union[int, Sequence[int]]] = None,
  83. keepdims: bool = False,
  84. ) -> Tensor:
  85. r"""Returns the sum of input tensor along given axis. If axis is a list of dimensions,
  86. reduce over all of them.
  87. Args:
  88. inp: input tensor.
  89. axis: dimension to reduce. If None, all dimensions will be reduced.
  90. Default: None
  91. keepdims: whether the output tensor has axis retained or not.
  92. Default: False
  93. Returns:
  94. output tensor.
  95. Examples:
  96. >>> import numpy as np
  97. >>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  98. >>> F.sum(x)
  99. Tensor(21, dtype=int32, device=xpux:0)
  100. """
  101. return inp.sum(axis=axis, keepdims=keepdims)
  102. def prod(
  103. inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None, keepdims=False
  104. ) -> Tensor:
  105. r"""Returns the product of input tensor along given axis. If axis is a list of dimensions,
  106. reduce over all of them.
  107. Args:
  108. inp: input tensor.
  109. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  110. keepdims: whether the output tensor has axis retained or not. Default: False
  111. Returns:
  112. output tensor.
  113. Examples:
  114. >>> import numpy as np
  115. >>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  116. >>> F.prod(x)
  117. Tensor(720, dtype=int32, device=xpux:0)
  118. """
  119. return inp.prod(axis=axis, keepdims=keepdims)
  120. def mean(
  121. inp: Tensor,
  122. axis: Optional[Union[int, Sequence[int]]] = None,
  123. keepdims: bool = False,
  124. ) -> Tensor:
  125. r"""Returns the mean value of input tensor along
  126. given axis. If axis is a list of dimensions,
  127. reduce over all of them.
  128. Args:
  129. inp: input tensor.
  130. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  131. keepdims: whether the output tensor has axis retained or not. Default: False
  132. Returns:
  133. output tensor.
  134. Examples:
  135. >>> import numpy as np
  136. >>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  137. >>> out = F.mean(x)
  138. >>> out.numpy()
  139. array(3.5, dtype=float32)
  140. """
  141. return inp.mean(axis=axis, keepdims=keepdims)
  142. def var(
  143. inp: Tensor,
  144. axis: Optional[Union[int, Sequence[int]]] = None,
  145. keepdims: bool = False,
  146. ) -> Tensor:
  147. r"""Returns the variance value of input tensor along
  148. given axis. If axis is a list of dimensions,
  149. reduce over all of them.
  150. Args:
  151. inp: input tensor.
  152. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  153. keepdims: whether the output tensor has axis retained or not. Default: False
  154. Returns:
  155. output tensor.
  156. Examples:
  157. >>> import numpy as np
  158. >>> data = Tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
  159. >>> out = F.var(data)
  160. >>> out.numpy().round(decimals=4)
  161. 2.9167
  162. """
  163. if axis is None:
  164. m = mean(inp, axis=axis, keepdims=False)
  165. else:
  166. m = mean(inp, axis=axis, keepdims=True)
  167. v = inp - m
  168. return mean(v ** 2, axis=axis, keepdims=keepdims)
  169. def std(
  170. inp: Tensor,
  171. axis: Optional[Union[int, Sequence[int]]] = None,
  172. keepdims: bool = False,
  173. ) -> Tensor:
  174. r"""Returns the standard deviation of input tensor along
  175. given axis. If axis is a list of dimensions,
  176. reduce over all of them.
  177. Args:
  178. inp: input tensor.
  179. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  180. keepdims: whether the output tensor has axis retained or not. Default: False
  181. Returns:
  182. output tensor.
  183. Examples:
  184. >>> import numpy as np
  185. >>> data = Tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
  186. >>> out = F.std(data, axis=1)
  187. >>> out.numpy().round(decimals=4)
  188. array([0.8165, 0.8165], dtype=float32)
  189. """
  190. return var(inp, axis=axis, keepdims=keepdims) ** 0.5
  191. def min(
  192. inp: Tensor,
  193. axis: Optional[Union[int, Sequence[int]]] = None,
  194. keepdims: bool = False,
  195. ) -> Tensor:
  196. r"""Returns the min value of input tensor along
  197. given axis. If axis is a list of dimensions,
  198. reduce over all of them.
  199. Args:
  200. inp: input tensor.
  201. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  202. keepdims: whether the output tensor has axis retained or not. Default: False
  203. Returns:
  204. output tensor.
  205. Examples:
  206. >>> import numpy as np
  207. >>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  208. >>> F.min(x)
  209. Tensor(1, dtype=int32, device=xpux:0)
  210. """
  211. return inp.min(axis=axis, keepdims=keepdims)
  212. def max(
  213. inp: Tensor,
  214. axis: Optional[Union[int, Sequence[int]]] = None,
  215. keepdims: bool = False,
  216. ) -> Tensor:
  217. r"""Returns the max value of the input tensor along
  218. given axis. If axis is a list of dimensions,
  219. reduce over all of them.
  220. Args:
  221. inp: input tensor.
  222. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  223. keepdims: whether the output tensor has axis retained or not. Default: False
  224. Returns:
  225. output tensor.
  226. Examples:
  227. >>> import numpy as np
  228. >>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  229. >>> F.max(x)
  230. Tensor(6, dtype=int32, device=xpux:0)
  231. """
  232. return inp.max(axis=axis, keepdims=keepdims)
  233. def norm(
  234. inp: Tensor, ord: float = None, axis: int = None, keepdims=False,
  235. ):
  236. r"""Calculates ``p``-norm of input tensor along
  237. given axis.
  238. Args:
  239. inp: input tensor.
  240. ord: power of value applied to inp. Default: 2
  241. axis: dimension to reduce. If None, input must be a vector. Default: None
  242. keepdims: whether the output tensor has axis retained or not. Default: False
  243. Returns:
  244. output tensor.
  245. Examples:
  246. >>> import numpy as np
  247. >>> x = Tensor(np.arange(-3, 3, dtype=np.float32))
  248. >>> out = F.norm(x)
  249. >>> out.numpy().round(decimals=4)
  250. 4.3589
  251. """
  252. if axis is None:
  253. if inp.ndim != 1:
  254. raise TypeError("axis is required unless input is a vector")
  255. if ord is None:
  256. ord = 2
  257. if ord == 0:
  258. return sum(inp != 0, axis=axis, keepdims=keepdims)
  259. if ord == math.inf:
  260. return max(abs(inp))
  261. if ord == -math.inf:
  262. return min(abs(inp))
  263. return sum(abs(inp) ** ord, axis=axis, keepdims=keepdims) ** (1.0 / ord)
  264. def argmin(
  265. inp: Tensor,
  266. axis: Optional[Union[int, Sequence[int]]] = None,
  267. keepdims: bool = False,
  268. ) -> Tensor:
  269. r"""Returns the indices of the minimum values along
  270. given axis. If axis is a list of dimensions,
  271. reduce over all of them.
  272. Args:
  273. inp: input tensor.
  274. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  275. keepdims: whether the output tensor has axis retained or not. Default: False
  276. Returns:
  277. output tensor.
  278. Examples:
  279. >>> import numpy as np
  280. >>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  281. >>> F.argmin(x)
  282. Tensor(0, dtype=int32, device=xpux:0)
  283. """
  284. if axis is None:
  285. assert not keepdims, "can not set axis=None and keepdims=True"
  286. inp = inp.flatten()
  287. axis = 0
  288. axis = _normalize_axis(inp.ndim, axis, reverse=True)
  289. if isinstance(axis, collections.abc.Iterable):
  290. for ai in axis:
  291. op = builtin.Argmin(axis=ai)
  292. (inp,) = apply(op, inp)
  293. if not keepdims:
  294. inp = squeeze(inp, ai)
  295. return inp
  296. op = builtin.Argmin(axis=axis)
  297. (result,) = apply(op, inp)
  298. if not keepdims:
  299. result = squeeze(result, axis)
  300. return result
  301. def argmax(
  302. inp: Tensor,
  303. axis: Optional[Union[int, Sequence[int]]] = None,
  304. keepdims: bool = False,
  305. ) -> Tensor:
  306. r"""Returns the indices of the maximum values along
  307. given axis. If axis is a list of dimensions,
  308. reduce over all of them.
  309. Args:
  310. inp: input tensor.
  311. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  312. keepdims: whether the output tensor has axis retained or not. Default: False
  313. Returns:
  314. output tensor.
  315. Examples:
  316. >>> import numpy as np
  317. >>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  318. >>> F.argmax(x)
  319. Tensor(5, dtype=int32, device=xpux:0)
  320. """
  321. if axis is None:
  322. assert not keepdims, "can not set axis=None and keepdims=True"
  323. inp = inp.flatten()
  324. axis = 0
  325. axis = _normalize_axis(inp.ndim, axis, reverse=True)
  326. if isinstance(axis, collections.abc.Iterable):
  327. for ai in axis:
  328. op = builtin.Argmax(axis=ai)
  329. (inp,) = apply(op, inp)
  330. if not keepdims:
  331. inp = squeeze(inp, ai)
  332. return inp
  333. op = builtin.Argmax(axis=axis)
  334. (result,) = apply(op, inp)
  335. if not keepdims:
  336. result = squeeze(result, axis)
  337. return result
  338. def normalize(
  339. inp: Tensor, ord: float = None, axis: int = None, eps: float = 1e-12,
  340. ) -> Tensor:
  341. r"""Performs :math:`L_p` normalization of input tensor along
  342. given axis.
  343. For a tensor of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
  344. :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`axis` is transformed as:
  345. .. math::
  346. v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.
  347. Args:
  348. inp: input tensor.
  349. ord: power of value applied to input tensor. Default: 2
  350. axis: dimension to reduce.If None, input must be a vector. Default: None
  351. eps: a small value to avoid division by zero. Default: 1e-12
  352. Returns:
  353. normalized output tensor.
  354. """
  355. if axis is None:
  356. return inp / clip(norm(inp, ord, axis), lower=eps)
  357. else:
  358. return inp / clip(norm(inp, ord, axis, keepdims=True), lower=eps)
  359. def argsort(inp: Tensor, descending: bool = False) -> Tensor:
  360. r"""Returns the indices that would sort the input tensor.
  361. Args:
  362. inp: input tensor. If it's 2d, the result would be array of indices show how to sort each row in the input tensor.
  363. descending: sort in descending order, where the largest comes first. Default: False
  364. inp: Tensor:
  365. descending: bool:
  366. Returns:
  367. indices of int32 indicates how to sort the input.
  368. Examples:
  369. >>> import numpy as np
  370. >>> x = Tensor(np.array([1,2], dtype=np.float32))
  371. >>> F.argsort(x)
  372. Tensor([0 1], dtype=int32, device=xpux:0)
  373. """
  374. assert len(inp.shape) <= 2, "Input should be 1d or 2d"
  375. if descending:
  376. order = "descending"
  377. else:
  378. order = "ascending"
  379. op = builtin.Argsort(order=order)
  380. if len(inp.shape) == 1:
  381. inp = inp.reshape(1, -1)
  382. _, result = apply(op, inp)
  383. return result[0]
  384. _, result = apply(op, inp)
  385. return result
  386. def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]:
  387. r"""Returns sorted tensor and the indices would sort the input tensor.
  388. Args:
  389. inp: input tensor. If it's 2d, the result would be sorted by row.
  390. descending: sort in descending order, where the largest comes first. Default: False
  391. Returns:
  392. tuple of two tensors `(sorted_tensor, indices_of_int32)`.
  393. Examples:
  394. >>> import numpy as np
  395. >>> x = Tensor(np.array([1,2], dtype=np.float32))
  396. >>> out, indices = F.sort(x)
  397. >>> out.numpy()
  398. array([1., 2.], dtype=float32)
  399. """
  400. assert len(inp.shape) <= 2, "Input should be 1d or 2d"
  401. if descending:
  402. order = "descending"
  403. else:
  404. order = "ascending"
  405. op = builtin.Argsort(order=order)
  406. if len(inp.shape) == 1:
  407. inp = inp.reshape(1, -1)
  408. tns, ind = apply(op, inp)
  409. return tns[0], ind[0]
  410. tns, ind = apply(op, inp)
  411. return tns, ind
  412. @deprecated_kwargs_default("1.12", "descending", 3)
  413. def topk(
  414. inp: Tensor,
  415. k: int,
  416. descending: bool = False,
  417. kth_only: bool = False,
  418. no_sort: bool = False,
  419. ) -> Tuple[Tensor, Tensor]:
  420. r"""Selects the ``Top-K`` (by default) smallest elements of 2d matrix by row.
  421. Args:
  422. inp: input tensor. If input tensor is 2d, each row will be sorted.
  423. k: number of elements needed.
  424. descending: if True, return the largest elements instead. Default: False
  425. kth_only: if True, only the k-th element will be returned. Default: False
  426. no_sort: if True, the returned elements can be unordered. Default: False
  427. Returns:
  428. tuple of two tensors ``(topk_tensor, indices_of_int32)``
  429. Examples:
  430. >>> import numpy as np
  431. >>> x = Tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32))
  432. >>> top, indices = F.topk(x, 5, descending=False)
  433. >>> print(top.numpy(), indices.numpy())
  434. [1. 2. 3. 4. 5.] [7 0 6 1 5]
  435. """
  436. if descending:
  437. k = -k
  438. if kth_only:
  439. mode = "kth_only"
  440. elif no_sort:
  441. mode = "value_idx_nosort"
  442. else:
  443. mode = "value_idx_sorted"
  444. op = builtin.TopK(mode=mode)
  445. if not isinstance(k, Tensor):
  446. k = Const(k, "int32", inp.device, None)
  447. if len(inp.shape) == 1:
  448. if kth_only:
  449. (tns,) = apply(op, expand_dims(inp, 0), k)
  450. # FIXME:
  451. # could use a dedicated kernel
  452. # gradient may be routed to other indices if k-th value is not unique
  453. ind = argmax((tns == inp).astype("int8"))
  454. tns = squeeze(tns, 0)
  455. else:
  456. tns, ind = apply(op, expand_dims(inp, 0), k)
  457. tns = squeeze(tns, 0)
  458. ind = squeeze(ind, 0)
  459. else:
  460. if kth_only:
  461. (tns,) = apply(op, inp, k)
  462. # FIXME: same as above
  463. ind = argmax((expand_dims(tns, 1) == inp).astype("int8"), 1)
  464. else:
  465. tns, ind = apply(op, inp, k)
  466. return tns, ind
  467. def matinv(inp: Tensor) -> Tensor:
  468. r"""Computes the inverse of a batch of matrices; input must has shape [..., n, n].
  469. Args:
  470. inp: input tensor.
  471. Returns:
  472. output tensor.
  473. Examples:
  474. >>> import numpy as np
  475. >>> data = Tensor([[1.0, 0.0], [1.0, 1.0]])
  476. >>> out = F.matinv(data)
  477. >>> out.numpy()
  478. array([[ 1., 0.],
  479. [-1., 1.]], dtype=float32)
  480. """
  481. (result,) = apply(builtin.MatrixInverse(), inp)
  482. return result
  483. def matmul(
  484. inp1: Tensor,
  485. inp2: Tensor,
  486. transpose_a=False,
  487. transpose_b=False,
  488. compute_mode="default",
  489. format="default",
  490. ) -> Tensor:
  491. r"""Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``.
  492. With different inputs dim, this function behaves differently:
  493. * Both 1-D tensor, simply forward to ``dot``.
  494. * Both 2-D tensor, normal matrix multiplication.
  495. * If one input tensor is 1-D, matrix vector multiplication.
  496. * If at least one tensor are 3-dimensional or >3-dimensional, the other tensor should have dim >= 2,
  497. the batched matrix-matrix is returned, and the tensor with smaller dimension will be broadcasted.
  498. For example:
  499. * inp1: `(n, k, m)`, inp2: `(n, m, p)`, return: `(n, k, p)`
  500. * inp1: `(n, k, m)`, inp2: `(m, p)`, return: `(n, k, p)`
  501. * inp1: `(n, j, k, m)`, inp2: `(n, j, m, p)`, return: `(n, j, k, p)`
  502. Args:
  503. inp1: first matrix to be multiplied.
  504. inp2: second matrix to be multiplied.
  505. Returns:
  506. output tensor.
  507. Examples:
  508. >>> import numpy as np
  509. >>> data1 = Tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  510. >>> data2 = Tensor(np.arange(0, 6, dtype=np.float32).reshape(3, 2))
  511. >>> out = F.matmul(data1, data2)
  512. >>> out.numpy()
  513. array([[10., 13.],
  514. [28., 40.]], dtype=float32)
  515. """
  516. return _matmul(inp1, inp2, transpose_a, transpose_b, compute_mode, format)
  517. def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
  518. r"""Computes dot-product of two vectors ``inp1`` and ``inp2``.
  519. inputs must be 1-dimensional or scalar. A scalar input is automatically broadcasted.
  520. Refer to :func:`~.matmul` for more general usage.
  521. Args:
  522. inp1: first vector.
  523. inp2: second vector.
  524. Returns:
  525. output value.
  526. Examples:
  527. >>> import numpy as np
  528. >>> data1 = Tensor(np.arange(0, 6, dtype=np.float32))
  529. >>> data2 = Tensor(np.arange(0, 6, dtype=np.float32))
  530. >>> out = F.dot(data1, data2)
  531. >>> out.numpy()
  532. array(55., dtype=float32)
  533. """
  534. op = builtin.Dot()
  535. assert (
  536. inp1.ndim <= 1 and inp2.ndim <= 1
  537. ), "Input tensors for dot must be 1-dimensional or scalar"
  538. (result,) = apply(op, inp1, inp2)
  539. return result
  540. def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
  541. r"""Returns a singular value decomposition ``A = USVh`` of a matrix (or a stack of matrices) ``x`` , where ``U`` is a matrix (or a stack of matrices) with orthonormal columns, ``S`` is a vector of non-negative numbers (or stack of vectors), and ``Vh`` is a matrix (or a stack of matrices) with orthonormal rows.
  542. Args:
  543. x (Tensor): A input real tensor having the shape ``(..., M, N)`` with ``x.ndim >= 2`` .
  544. full_matrices (bool, optional): If ``False`` , ``U`` and ``Vh`` have the shapes ``(..., M, K)`` and ``(..., K, N)`` , respectively, where ``K = min(M, N)`` . If ``True`` , the shapes are ``(..., M, M)`` and ``(..., N, N)`` , respectively. Default: ``False`` .
  545. compute_uv (bool, optional): Whether or not to compute ``U`` and ``Vh`` in addition to ``S`` . Default: ``True`` .
  546. Note:
  547. * naive does not support ``full_matrices`` and ``compute_uv`` as ``True`` .
  548. Returns:
  549. Returns a tuple ( ``U`` , ``S`` , ``Vh`` ), which are SVD factors ``U`` , ``S``, ``Vh`` of input matrix ``x``. ( ``U`` , ``Vh`` only returned when ``compute_uv`` is True).
  550. ``U`` contains matrices orthonormal columns (i.e., the columns are left singular vectors). If ``full_matrices`` is ``True`` , the array must have shape ``(..., M, M)`` . If ``full_matrices`` is ``False`` , the array must have shape ``(..., M, K)`` , where ``K = min(M, N)`` .
  551. Examples:
  552. >>> import numpy as np
  553. >>> x = Tensor(np.random.randn(9, 6))
  554. >>> y = Tensor(np.random.randn(2, 7, 8, 3))
  555. >>> U, S, Vh = F.svd(x, full_matrices=False)
  556. >>> print(U._tuple_shape, S._tuple_shape, Vh._tuple_shape)
  557. (9, 6) (6,) (6, 6)
  558. >>> u, s, vh = F.svd(y, full_matrices=False)
  559. >>> print(u._tuple_shape, s._tuple_shape, vh._tuple_shape)
  560. (2, 7, 8, 3) (2, 7, 3) (2, 7, 3, 3)
  561. """
  562. op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv)
  563. U, S, Vh = apply(op, inp)
  564. return U, S, Vh
  565. def _check_non_finite(inps: Iterable[Tensor], scale=1.0) -> Tensor:
  566. r"""Check whether input contains infinite or nan value.
  567. Args:
  568. inp: a tensor to be checked.
  569. Returns:
  570. a int32 scalar tensor, 0 for False and 1 for True.
  571. """
  572. op = builtin.CheckNonFinite(scale=scale)
  573. oups = apply(op, *inps)
  574. out = oups[-1]
  575. for i in range(len(inps)):
  576. inps[i]._reset(oups[i])
  577. return out