You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

math.py 22 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719
  1. # -*- coding: utf-8 -*-
  2. import collections
  3. import math
  4. from typing import Iterable, Optional, Sequence, Tuple, Union
  5. from ..core._imperative_rt.core2 import Const, apply
  6. from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
  7. from ..core.ops import builtin
  8. from ..core.tensor.array_method import _matmul
  9. from ..core.tensor.utils import _normalize_axis
  10. from ..tensor import Tensor
  11. from ..utils.deprecation import deprecated_kwargs_default
  12. from .elemwise import _elemwise_multi_type, clip
  13. from .tensor import expand_dims, squeeze
  14. __all__ = [
  15. "argmax",
  16. "argmin",
  17. "argsort",
  18. "dot",
  19. "isinf",
  20. "isnan",
  21. "matinv",
  22. "matmul",
  23. "max",
  24. "mean",
  25. "min",
  26. "norm",
  27. "normalize",
  28. "prod",
  29. "sign",
  30. "sort",
  31. "std",
  32. "sum",
  33. "svd",
  34. "topk",
  35. "var",
  36. ]
  37. def isnan(inp: Tensor) -> Tensor:
  38. r"""Returns a new tensor representing if each element is ``NaN`` or not.
  39. Args:
  40. inp: input tensor.
  41. Returns:
  42. result tensor.
  43. Examples:
  44. >>> x = Tensor([1, float("nan"), 0])
  45. >>> F.isnan(x).numpy()
  46. array([False, True, False])
  47. """
  48. return _elemwise_multi_type(inp, mode="isnan", dtype="bool")
  49. def isinf(inp: Tensor) -> Tensor:
  50. r"""Returns a new tensor representing if each element is ``Inf`` or not.
  51. Args:
  52. inp: input tensor.
  53. Returns:
  54. result tensor.
  55. Examples:
  56. >>> x = Tensor([1, float("inf"), 0])
  57. >>> F.isinf(x).numpy()
  58. array([False, True, False])
  59. """
  60. return _elemwise_multi_type(inp, mode="isinf", dtype="bool")
  61. def sign(inp: Tensor):
  62. r"""Returns a new tensor representing the sign of each element in input tensor.
  63. Args:
  64. inp: Tensor:
  65. Returns:
  66. the sign of input tensor.
  67. Examples:
  68. >>> x = Tensor([1, -1, 0])
  69. >>> F.sign(x)
  70. Tensor([ 1 -1 0], dtype=int32, device=xpux:0)
  71. """
  72. return (inp > 0).astype(inp.dtype) - (inp < 0).astype(inp.dtype)
  73. def sum(
  74. inp: Tensor,
  75. axis: Optional[Union[int, Sequence[int]]] = None,
  76. keepdims: bool = False,
  77. ) -> Tensor:
  78. r"""Returns the sum of input tensor along given axis. If axis is a list of dimensions,
  79. reduce over all of them.
  80. Args:
  81. inp: input tensor.
  82. axis: dimension to reduce. If None, all dimensions will be reduced.
  83. Default: None
  84. keepdims: whether the output tensor has axis retained or not.
  85. Default: False
  86. Returns:
  87. output tensor.
  88. Examples:
  89. >>> import numpy as np
  90. >>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  91. >>> F.sum(x)
  92. Tensor(21, dtype=int32, device=xpux:0)
  93. """
  94. return inp.sum(axis=axis, keepdims=keepdims)
  95. def prod(
  96. inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None, keepdims=False
  97. ) -> Tensor:
  98. r"""Returns the product of input tensor along given axis. If axis is a list of dimensions,
  99. reduce over all of them.
  100. Args:
  101. inp: input tensor.
  102. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  103. keepdims: whether the output tensor has axis retained or not. Default: False
  104. Returns:
  105. output tensor.
  106. Examples:
  107. >>> import numpy as np
  108. >>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  109. >>> F.prod(x)
  110. Tensor(720, dtype=int32, device=xpux:0)
  111. """
  112. return inp.prod(axis=axis, keepdims=keepdims)
  113. def mean(
  114. inp: Tensor,
  115. axis: Optional[Union[int, Sequence[int]]] = None,
  116. keepdims: bool = False,
  117. ) -> Tensor:
  118. r"""Returns the mean value of input tensor along
  119. given axis. If axis is a list of dimensions,
  120. reduce over all of them.
  121. Args:
  122. inp: input tensor.
  123. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  124. keepdims: whether the output tensor has axis retained or not. Default: False
  125. Returns:
  126. output tensor.
  127. Examples:
  128. >>> import numpy as np
  129. >>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  130. >>> out = F.mean(x)
  131. >>> out.numpy()
  132. array(3.5, dtype=float32)
  133. """
  134. return inp.mean(axis=axis, keepdims=keepdims)
  135. def var(
  136. inp: Tensor,
  137. axis: Optional[Union[int, Sequence[int]]] = None,
  138. keepdims: bool = False,
  139. ) -> Tensor:
  140. r"""Returns the variance value of input tensor along
  141. given axis. If axis is a list of dimensions,
  142. reduce over all of them.
  143. Args:
  144. inp: input tensor.
  145. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  146. keepdims: whether the output tensor has axis retained or not. Default: False
  147. Returns:
  148. output tensor.
  149. Examples:
  150. >>> import numpy as np
  151. >>> data = Tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
  152. >>> out = F.var(data)
  153. >>> out.numpy().round(decimals=4)
  154. 2.9167
  155. """
  156. if axis is None:
  157. m = mean(inp, axis=axis, keepdims=False)
  158. else:
  159. m = mean(inp, axis=axis, keepdims=True)
  160. v = inp - m
  161. return mean(v ** 2, axis=axis, keepdims=keepdims)
  162. def std(
  163. inp: Tensor,
  164. axis: Optional[Union[int, Sequence[int]]] = None,
  165. keepdims: bool = False,
  166. ) -> Tensor:
  167. r"""Returns the standard deviation of input tensor along
  168. given axis. If axis is a list of dimensions,
  169. reduce over all of them.
  170. Args:
  171. inp: input tensor.
  172. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  173. keepdims: whether the output tensor has axis retained or not. Default: False
  174. Returns:
  175. output tensor.
  176. Examples:
  177. >>> import numpy as np
  178. >>> data = Tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
  179. >>> out = F.std(data, axis=1)
  180. >>> out.numpy().round(decimals=4)
  181. array([0.8165, 0.8165], dtype=float32)
  182. """
  183. return var(inp, axis=axis, keepdims=keepdims) ** 0.5
  184. def min(
  185. inp: Tensor,
  186. axis: Optional[Union[int, Sequence[int]]] = None,
  187. keepdims: bool = False,
  188. ) -> Tensor:
  189. r"""Returns the min value of input tensor along
  190. given axis. If axis is a list of dimensions,
  191. reduce over all of them.
  192. Args:
  193. inp: input tensor.
  194. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  195. keepdims: whether the output tensor has axis retained or not. Default: False
  196. Returns:
  197. output tensor.
  198. Examples:
  199. >>> import numpy as np
  200. >>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  201. >>> F.min(x)
  202. Tensor(1, dtype=int32, device=xpux:0)
  203. """
  204. return inp.min(axis=axis, keepdims=keepdims)
  205. def max(
  206. inp: Tensor,
  207. axis: Optional[Union[int, Sequence[int]]] = None,
  208. keepdims: bool = False,
  209. ) -> Tensor:
  210. r"""Returns the max value of the input tensor along
  211. given axis. If axis is a list of dimensions,
  212. reduce over all of them.
  213. Args:
  214. inp: input tensor.
  215. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  216. keepdims: whether the output tensor has axis retained or not. Default: False
  217. Returns:
  218. output tensor.
  219. Examples:
  220. >>> import numpy as np
  221. >>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  222. >>> F.max(x)
  223. Tensor(6, dtype=int32, device=xpux:0)
  224. """
  225. return inp.max(axis=axis, keepdims=keepdims)
  226. def norm(
  227. inp: Tensor, ord: float = None, axis: int = None, keepdims=False,
  228. ):
  229. r"""Calculates ``p``-norm of input tensor along
  230. given axis.
  231. Args:
  232. inp: input tensor.
  233. ord: power of value applied to inp. Default: 2
  234. axis: dimension to reduce. If None, input must be a vector. Default: None
  235. keepdims: whether the output tensor has axis retained or not. Default: False
  236. Returns:
  237. output tensor.
  238. Examples:
  239. >>> import numpy as np
  240. >>> x = Tensor(np.arange(-3, 3, dtype=np.float32))
  241. >>> out = F.norm(x)
  242. >>> out.numpy().round(decimals=4)
  243. 4.3589
  244. """
  245. if axis is None:
  246. if inp.ndim != 1:
  247. raise TypeError("axis is required unless input is a vector")
  248. if ord is None:
  249. ord = 2
  250. if ord == 0:
  251. return sum(inp != 0, axis=axis, keepdims=keepdims)
  252. if ord == math.inf:
  253. return max(abs(inp))
  254. if ord == -math.inf:
  255. return min(abs(inp))
  256. return sum(abs(inp) ** ord, axis=axis, keepdims=keepdims) ** (1.0 / ord)
  257. def argmin(
  258. inp: Tensor,
  259. axis: Optional[Union[int, Sequence[int]]] = None,
  260. keepdims: bool = False,
  261. ) -> Tensor:
  262. r"""Returns the indices of the minimum values along
  263. given axis. If axis is a list of dimensions,
  264. reduce over all of them.
  265. Args:
  266. inp: input tensor.
  267. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  268. keepdims: whether the output tensor has axis retained or not. Default: False
  269. Returns:
  270. output tensor.
  271. Examples:
  272. >>> import numpy as np
  273. >>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  274. >>> F.argmin(x)
  275. Tensor(0, dtype=int32, device=xpux:0)
  276. """
  277. if axis is None:
  278. assert not keepdims, "can not set axis=None and keepdims=True"
  279. inp = inp.flatten()
  280. axis = 0
  281. axis = _normalize_axis(inp.ndim, axis, reverse=True)
  282. if isinstance(axis, collections.abc.Iterable):
  283. for ai in axis:
  284. op = builtin.Argmin(axis=ai)
  285. (inp,) = apply(op, inp)
  286. if not keepdims:
  287. inp = squeeze(inp, ai)
  288. return inp
  289. op = builtin.Argmin(axis=axis)
  290. (result,) = apply(op, inp)
  291. if not keepdims:
  292. result = squeeze(result, axis)
  293. return result
  294. def argmax(
  295. inp: Tensor,
  296. axis: Optional[Union[int, Sequence[int]]] = None,
  297. keepdims: bool = False,
  298. ) -> Tensor:
  299. r"""Returns the indices of the maximum values along
  300. given axis. If axis is a list of dimensions,
  301. reduce over all of them.
  302. Args:
  303. inp: input tensor.
  304. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  305. keepdims: whether the output tensor has axis retained or not. Default: False
  306. Returns:
  307. output tensor.
  308. Examples:
  309. >>> import numpy as np
  310. >>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  311. >>> F.argmax(x)
  312. Tensor(5, dtype=int32, device=xpux:0)
  313. """
  314. if axis is None:
  315. assert not keepdims, "can not set axis=None and keepdims=True"
  316. inp = inp.flatten()
  317. axis = 0
  318. axis = _normalize_axis(inp.ndim, axis, reverse=True)
  319. if isinstance(axis, collections.abc.Iterable):
  320. for ai in axis:
  321. op = builtin.Argmax(axis=ai)
  322. (inp,) = apply(op, inp)
  323. if not keepdims:
  324. inp = squeeze(inp, ai)
  325. return inp
  326. op = builtin.Argmax(axis=axis)
  327. (result,) = apply(op, inp)
  328. if not keepdims:
  329. result = squeeze(result, axis)
  330. return result
  331. def normalize(
  332. inp: Tensor, ord: float = None, axis: int = None, eps: float = 1e-12,
  333. ) -> Tensor:
  334. r"""Performs :math:`L_p` normalization of input tensor along
  335. given axis.
  336. For a tensor of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
  337. :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`axis` is transformed as:
  338. .. math::
  339. v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.
  340. Args:
  341. inp: input tensor.
  342. ord: power of value applied to input tensor. Default: 2
  343. axis: dimension to reduce.If None, input must be a vector. Default: None
  344. eps: a small value to avoid division by zero. Default: 1e-12
  345. Returns:
  346. normalized output tensor.
  347. """
  348. if axis is None:
  349. return inp / clip(norm(inp, ord, axis), lower=eps)
  350. else:
  351. return inp / clip(norm(inp, ord, axis, keepdims=True), lower=eps)
  352. def argsort(inp: Tensor, descending: bool = False) -> Tensor:
  353. r"""Returns the indices that would sort the input tensor.
  354. Args:
  355. inp: input tensor. If it's 2d, the result would be array of indices show how to sort each row in the input tensor.
  356. descending: sort in descending order, where the largest comes first. Default: False
  357. inp: Tensor:
  358. descending: bool:
  359. Returns:
  360. indices of int32 indicates how to sort the input.
  361. Examples:
  362. >>> import numpy as np
  363. >>> x = Tensor(np.array([1,2], dtype=np.float32))
  364. >>> F.argsort(x)
  365. Tensor([0 1], dtype=int32, device=xpux:0)
  366. """
  367. assert len(inp.shape) <= 2, "Input should be 1d or 2d"
  368. if descending:
  369. order = "descending"
  370. else:
  371. order = "ascending"
  372. op = builtin.Argsort(order=order)
  373. if len(inp.shape) == 1:
  374. inp = inp.reshape(1, -1)
  375. _, result = apply(op, inp)
  376. return result[0]
  377. _, result = apply(op, inp)
  378. return result
  379. def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]:
  380. r"""Returns sorted tensor and the indices would sort the input tensor.
  381. Args:
  382. inp: input tensor. If it's 2d, the result would be sorted by row.
  383. descending: sort in descending order, where the largest comes first. Default: False
  384. Returns:
  385. tuple of two tensors `(sorted_tensor, indices_of_int32)`.
  386. Examples:
  387. >>> import numpy as np
  388. >>> x = Tensor(np.array([1,2], dtype=np.float32))
  389. >>> out, indices = F.sort(x)
  390. >>> out.numpy()
  391. array([1., 2.], dtype=float32)
  392. """
  393. assert len(inp.shape) <= 2, "Input should be 1d or 2d"
  394. if descending:
  395. order = "descending"
  396. else:
  397. order = "ascending"
  398. op = builtin.Argsort(order=order)
  399. if len(inp.shape) == 1:
  400. inp = inp.reshape(1, -1)
  401. tns, ind = apply(op, inp)
  402. return tns[0], ind[0]
  403. tns, ind = apply(op, inp)
  404. return tns, ind
  405. @deprecated_kwargs_default("1.12", "descending", 3)
  406. def topk(
  407. inp: Tensor,
  408. k: int,
  409. descending: bool = False,
  410. kth_only: bool = False,
  411. no_sort: bool = False,
  412. ) -> Tuple[Tensor, Tensor]:
  413. r"""Selects the ``Top-K`` (by default) smallest elements of 2d matrix by row.
  414. Args:
  415. inp: input tensor. If input tensor is 2d, each row will be sorted.
  416. k: number of elements needed.
  417. descending: if True, return the largest elements instead. Default: False
  418. kth_only: if True, only the k-th element will be returned. Default: False
  419. no_sort: if True, the returned elements can be unordered. Default: False
  420. Returns:
  421. tuple of two tensors ``(topk_tensor, indices_of_int32)``
  422. Examples:
  423. >>> import numpy as np
  424. >>> x = Tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32))
  425. >>> top, indices = F.topk(x, 5, descending=False)
  426. >>> print(top.numpy(), indices.numpy())
  427. [1. 2. 3. 4. 5.] [7 0 6 1 5]
  428. """
  429. if descending:
  430. k = -k
  431. if kth_only:
  432. mode = "kth_only"
  433. elif no_sort:
  434. mode = "value_idx_nosort"
  435. else:
  436. mode = "value_idx_sorted"
  437. op = builtin.TopK(mode=mode)
  438. if not isinstance(k, Tensor):
  439. k = Const(k, "int32", inp.device)
  440. if len(inp.shape) == 1:
  441. if kth_only:
  442. (tns,) = apply(op, expand_dims(inp, 0), k)
  443. # FIXME:
  444. # could use a dedicated kernel
  445. # gradient may be routed to other indices if k-th value is not unique
  446. ind = argmax((tns == inp).astype("int8"))
  447. tns = squeeze(tns, 0)
  448. else:
  449. tns, ind = apply(op, expand_dims(inp, 0), k)
  450. tns = squeeze(tns, 0)
  451. ind = squeeze(ind, 0)
  452. else:
  453. if kth_only:
  454. (tns,) = apply(op, inp, k)
  455. # FIXME: same as above
  456. ind = argmax((expand_dims(tns, 1) == inp).astype("int8"), 1)
  457. else:
  458. tns, ind = apply(op, inp, k)
  459. return tns, ind
  460. def matinv(inp: Tensor) -> Tensor:
  461. r"""Computes the inverse of a batch of matrices; input must has shape [..., n, n].
  462. Args:
  463. inp: input tensor.
  464. Returns:
  465. output tensor.
  466. Examples:
  467. >>> import numpy as np
  468. >>> data = Tensor([[1.0, 0.0], [1.0, 1.0]])
  469. >>> out = F.matinv(data)
  470. >>> out.numpy()
  471. array([[ 1., 0.],
  472. [-1., 1.]], dtype=float32)
  473. """
  474. (result,) = apply(builtin.MatrixInverse(), inp)
  475. return result
  476. def matmul(
  477. inp1: Tensor,
  478. inp2: Tensor,
  479. transpose_a=False,
  480. transpose_b=False,
  481. compute_mode="default",
  482. ) -> Tensor:
  483. r"""Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``.
  484. With different inputs dim, this function behaves differently:
  485. * Both 1-D tensor, simply forward to ``dot``.
  486. * Both 2-D tensor, normal matrix multiplication.
  487. * If one input tensor is 1-D, matrix vector multiplication.
  488. * If at least one tensor are 3-dimensional or >3-dimensional, the other tensor should have dim >= 2,
  489. the batched matrix-matrix is returned, and the tensor with smaller dimension will be broadcasted.
  490. For example:
  491. * inp1: `(n, k, m)`, inp2: `(n, m, p)`, return: `(n, k, p)`
  492. * inp1: `(n, k, m)`, inp2: `(m, p)`, return: `(n, k, p)`
  493. * inp1: `(n, j, k, m)`, inp2: `(n, j, m, p)`, return: `(n, j, k, p)`
  494. Args:
  495. inp1: first matrix to be multiplied.
  496. inp2: second matrix to be multiplied.
  497. Returns:
  498. output tensor.
  499. Examples:
  500. >>> import numpy as np
  501. >>> data1 = Tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  502. >>> data2 = Tensor(np.arange(0, 6, dtype=np.float32).reshape(3, 2))
  503. >>> out = F.matmul(data1, data2)
  504. >>> out.numpy()
  505. array([[10., 13.],
  506. [28., 40.]], dtype=float32)
  507. """
  508. return _matmul(inp1, inp2, transpose_a, transpose_b, compute_mode)
  509. def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
  510. r"""Computes dot-product of two vectors ``inp1`` and ``inp2``.
  511. inputs must be 1-dimensional or scalar. A scalar input is automatically broadcasted.
  512. Refer to :func:`~.matmul` for more general usage.
  513. Args:
  514. inp1: first vector.
  515. inp2: second vector.
  516. Returns:
  517. output value.
  518. Examples:
  519. >>> import numpy as np
  520. >>> data1 = Tensor(np.arange(0, 6, dtype=np.float32))
  521. >>> data2 = Tensor(np.arange(0, 6, dtype=np.float32))
  522. >>> out = F.dot(data1, data2)
  523. >>> out.numpy()
  524. array(55., dtype=float32)
  525. """
  526. op = builtin.Dot()
  527. assert (
  528. inp1.ndim <= 1 and inp2.ndim <= 1
  529. ), "Input tensors for dot must be 1-dimensional or scalar"
  530. (result,) = apply(op, inp1, inp2)
  531. return result
  532. def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
  533. r"""Computes the singular value decomposition of a matrix (or a stack of matrices) ``inp``.
  534. Let :math:`X` be the input matrix (or a stack of input matrices), the output should satisfies:
  535. .. math::
  536. X = U * diag(S) * Vh
  537. where ``U`` is a matrix (or stack of vectors) with orthonormal columns, ``S`` is a vector of
  538. non-negative numbers (or stack of vectors), and ``Vh`` is a matrix (or a stack of matrices)
  539. with orthonormal rows.
  540. Args:
  541. inp (Tensor): A input real tensor having the shape ``(..., M, N)`` with ``inp.ndim >= 2`` .
  542. full_matrices (bool, optional): If ``False`` , ``U`` and ``Vh`` have the shapes ``(..., M, K)``
  543. and ``(..., K, N)`` , respectively, where ``K = min(M, N)`` . If ``True`` , the shapes
  544. are ``(..., M, M)`` and ``(..., N, N)`` , respectively. Default: ``False`` .
  545. compute_uv (bool, optional): Whether or not to compute ``U`` and ``Vh`` in addition to ``S`` . Default: ``True`` .
  546. Note:
  547. * naive does not support ``full_matrices`` and ``compute_uv`` as ``True`` .
  548. Returns:
  549. Returns a tuple ( ``U`` , ``S`` , ``Vh`` ), which are SVD factors ``U`` , ``S``, ``Vh`` of input matrix ``inp``.
  550. ( ``U`` , ``Vh`` only returned when ``compute_uv`` is True). ``U`` contains matrices orthonormal columns
  551. (i.e., the columns are left singular vectors). If ``full_matrices`` is ``True`` , the array must have shape
  552. ``(..., M, M)`` . If ``full_matrices`` is ``False`` , the array must have shape ``(..., M, K)`` , where ``K = min(M, N)`` .
  553. Examples:
  554. >>> import numpy as np
  555. >>> x = Tensor(np.random.randn(9, 6))
  556. >>> y = Tensor(np.random.randn(2, 7, 8, 3))
  557. >>> U, S, Vh = F.svd(x, full_matrices=False)
  558. >>> print(U._tuple_shape, S._tuple_shape, Vh._tuple_shape)
  559. (9, 6) (6,) (6, 6)
  560. >>> u, s, vh = F.svd(y, full_matrices=False)
  561. >>> print(u._tuple_shape, s._tuple_shape, vh._tuple_shape)
  562. (2, 7, 8, 3) (2, 7, 3) (2, 7, 3, 3)
  563. """
  564. op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv)
  565. U, S, Vh = apply(op, inp)
  566. return U, S, Vh
  567. def _check_non_finite(inps: Iterable[Tensor], scale=1.0) -> Tensor:
  568. r"""Check whether input contains infinite or nan value.
  569. Args:
  570. inps: tensors to be checked.
  571. Returns:
  572. a int32 scalar tensor, 0 for False and 1 for True.
  573. """
  574. if isinstance(inps, Tensor):
  575. inps = [inps]
  576. op = builtin.CheckNonFinite(scale=scale)
  577. oups = apply(op, *inps)
  578. out = oups[-1]
  579. for i in range(len(inps)):
  580. inps[i]._reset(oups[i])
  581. return out