You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

math.py 23 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745
  1. # -*- coding: utf-8 -*-
  2. import collections
  3. import math
  4. from typing import Iterable, Optional, Sequence, Tuple, Union
  5. from ..core._imperative_rt.core2 import Const, apply
  6. from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
  7. from ..core.ops import builtin
  8. from ..core.tensor.array_method import _matmul
  9. from ..core.tensor.utils import _normalize_axis
  10. from ..tensor import Tensor
  11. from ..utils.deprecation import deprecated_kwargs_default
  12. from .elemwise import _elemwise_multi_type, clip
  13. from .tensor import expand_dims, squeeze
  14. __all__ = [
  15. "argmax",
  16. "argmin",
  17. "argsort",
  18. "dot",
  19. "isinf",
  20. "isnan",
  21. "matinv",
  22. "matmul",
  23. "max",
  24. "mean",
  25. "min",
  26. "norm",
  27. "normalize",
  28. "prod",
  29. "sign",
  30. "sort",
  31. "std",
  32. "sum",
  33. "svd",
  34. "topk",
  35. "var",
  36. ]
  37. # TODO: Should be moved to elemwise - logical functions
  38. def isnan(inp: Tensor) -> Tensor:
  39. r"""Element-wise ``NaN`` check.
  40. Tests each element :math:`x_i` of the input tensor :math:`x` to determine whether the element is ``NaN``.
  41. Args:
  42. inp: input tensor. Should have a numeric data type.
  43. Returns:
  44. a tensor containing test results.
  45. An element out is ``True`` if :math:`x_i` is ``NaN`` and ``False`` otherwise.
  46. The returned array should have a data type of bool.
  47. Examples:
  48. >>> x = Tensor([1, float("nan"), 0])
  49. >>> F.isnan(x)
  50. Tensor([False True False], dtype=bool, device=xpux:0)
  51. """
  52. return _elemwise_multi_type(inp, mode="isnan", dtype="bool")
  53. def isinf(inp: Tensor) -> Tensor:
  54. r"""Element-wise ``infinity`` check.
  55. Tests each element :math:`x_i` of the input tensor :math:`x` to determine
  56. whether the element is if equal to positive or negative infinity.
  57. Args:
  58. inp: input tensor. Should have a numeric data type.
  59. Returns:
  60. a tensor containing test results.
  61. An element out is ``True`` if :math:`x_i` is either positive or negative infinity and ``False`` otherwise.
  62. The returned array should have a data type of bool.
  63. Examples:
  64. >>> x = Tensor([1, float("inf"), 0])
  65. >>> F.isinf(x)
  66. Tensor([False True False], dtype=bool, device=xpux:0)
  67. """
  68. return _elemwise_multi_type(inp, mode="isinf", dtype="bool")
  69. # TODO: Should be moved to elemwise - arithmetic operations
  70. def sign(x: Tensor):
  71. r"""Element-wise sign.
  72. Returns an indication of the sign of a number for each element :math:`x_i` of the input tensor :math:`x`.
  73. Args:
  74. inp: input tensor. Should have a numeric data type.
  75. Returns:
  76. a tensor containing the evaluated result for each element in :math:`x`.
  77. The returned array must have the same data type as :math:`x`.
  78. Examples:
  79. Element-wise sign:
  80. >>> x = Tensor([1, -1, 0])
  81. >>> F.sign(x)
  82. Tensor([ 1 -1 0], dtype=int32, device=xpux:0)
  83. """
  84. return (x > 0).astype(x.dtype) - (x < 0).astype(x.dtype)
  85. # statistical functions
  86. def sum(
  87. inp: Tensor,
  88. axis: Optional[Union[int, Sequence[int]]] = None,
  89. keepdims: bool = False,
  90. ) -> Tensor:
  91. r"""Returns the sum of input tensor along given axis. If axis is a list of dimensions,
  92. reduce over all of them.
  93. Args:
  94. inp: input tensor.
  95. axis: dimension to reduce. If None, all dimensions will be reduced.
  96. Default: None
  97. keepdims: whether the output tensor has axis retained or not.
  98. Default: False
  99. Returns:
  100. output tensor.
  101. Examples:
  102. >>> import numpy as np
  103. >>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  104. >>> F.sum(x)
  105. Tensor(21, dtype=int32, device=xpux:0)
  106. """
  107. return inp.sum(axis=axis, keepdims=keepdims)
  108. def prod(
  109. inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None, keepdims=False
  110. ) -> Tensor:
  111. r"""Returns the product of input tensor along given axis. If axis is a list of dimensions,
  112. reduce over all of them.
  113. Args:
  114. inp: input tensor.
  115. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  116. keepdims: whether the output tensor has axis retained or not. Default: False
  117. Returns:
  118. output tensor.
  119. Examples:
  120. >>> import numpy as np
  121. >>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  122. >>> F.prod(x)
  123. Tensor(720, dtype=int32, device=xpux:0)
  124. """
  125. return inp.prod(axis=axis, keepdims=keepdims)
  126. def mean(
  127. inp: Tensor,
  128. axis: Optional[Union[int, Sequence[int]]] = None,
  129. keepdims: bool = False,
  130. ) -> Tensor:
  131. r"""Returns the mean value of input tensor along
  132. given axis. If axis is a list of dimensions,
  133. reduce over all of them.
  134. Args:
  135. inp: input tensor.
  136. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  137. keepdims: whether the output tensor has axis retained or not. Default: False
  138. Returns:
  139. output tensor.
  140. Examples:
  141. >>> import numpy as np
  142. >>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  143. >>> out = F.mean(x)
  144. >>> out.numpy()
  145. array(3.5, dtype=float32)
  146. """
  147. return inp.mean(axis=axis, keepdims=keepdims)
  148. def var(
  149. inp: Tensor,
  150. axis: Optional[Union[int, Sequence[int]]] = None,
  151. keepdims: bool = False,
  152. ) -> Tensor:
  153. r"""Returns the variance value of input tensor along
  154. given axis. If axis is a list of dimensions,
  155. reduce over all of them.
  156. Args:
  157. inp: input tensor.
  158. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  159. keepdims: whether the output tensor has axis retained or not. Default: False
  160. Returns:
  161. output tensor.
  162. Examples:
  163. >>> import numpy as np
  164. >>> data = Tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
  165. >>> out = F.var(data)
  166. >>> out.numpy().round(decimals=4)
  167. 2.9167
  168. """
  169. if axis is None:
  170. m = mean(inp, axis=axis, keepdims=False)
  171. else:
  172. m = mean(inp, axis=axis, keepdims=True)
  173. v = inp - m
  174. return mean(v ** 2, axis=axis, keepdims=keepdims)
  175. def std(
  176. inp: Tensor,
  177. axis: Optional[Union[int, Sequence[int]]] = None,
  178. keepdims: bool = False,
  179. ) -> Tensor:
  180. r"""Returns the standard deviation of input tensor along
  181. given axis. If axis is a list of dimensions,
  182. reduce over all of them.
  183. Args:
  184. inp: input tensor.
  185. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  186. keepdims: whether the output tensor has axis retained or not. Default: False
  187. Returns:
  188. output tensor.
  189. Examples:
  190. >>> import numpy as np
  191. >>> data = Tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
  192. >>> out = F.std(data, axis=1)
  193. >>> out.numpy().round(decimals=4)
  194. array([0.8165, 0.8165], dtype=float32)
  195. """
  196. return var(inp, axis=axis, keepdims=keepdims) ** 0.5
  197. def min(
  198. inp: Tensor,
  199. axis: Optional[Union[int, Sequence[int]]] = None,
  200. keepdims: bool = False,
  201. ) -> Tensor:
  202. r"""Returns the min value of input tensor along
  203. given axis. If axis is a list of dimensions,
  204. reduce over all of them.
  205. Args:
  206. inp: input tensor.
  207. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  208. keepdims: whether the output tensor has axis retained or not. Default: False
  209. Returns:
  210. output tensor.
  211. Examples:
  212. >>> import numpy as np
  213. >>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  214. >>> F.min(x)
  215. Tensor(1, dtype=int32, device=xpux:0)
  216. """
  217. return inp.min(axis=axis, keepdims=keepdims)
  218. def max(
  219. inp: Tensor,
  220. axis: Optional[Union[int, Sequence[int]]] = None,
  221. keepdims: bool = False,
  222. ) -> Tensor:
  223. r"""Returns the max value of the input tensor along
  224. given axis. If axis is a list of dimensions,
  225. reduce over all of them.
  226. Args:
  227. inp: input tensor.
  228. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  229. keepdims: whether the output tensor has axis retained or not. Default: False
  230. Returns:
  231. output tensor.
  232. Examples:
  233. >>> import numpy as np
  234. >>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  235. >>> F.max(x)
  236. Tensor(6, dtype=int32, device=xpux:0)
  237. """
  238. return inp.max(axis=axis, keepdims=keepdims)
  239. def norm(
  240. inp: Tensor, ord: float = None, axis: int = None, keepdims=False,
  241. ):
  242. r"""Calculates ``p``-norm of input tensor along
  243. given axis.
  244. Args:
  245. inp: input tensor.
  246. ord: power of value applied to inp. Default: 2
  247. axis: dimension to reduce. If None, input must be a vector. Default: None
  248. keepdims: whether the output tensor has axis retained or not. Default: False
  249. Returns:
  250. output tensor.
  251. Examples:
  252. >>> import numpy as np
  253. >>> x = Tensor(np.arange(-3, 3, dtype=np.float32))
  254. >>> out = F.norm(x)
  255. >>> out.numpy().round(decimals=4)
  256. 4.3589
  257. """
  258. if axis is None:
  259. if inp.ndim != 1:
  260. raise TypeError("axis is required unless input is a vector")
  261. if ord is None:
  262. ord = 2
  263. if ord == 0:
  264. return sum(inp != 0, axis=axis, keepdims=keepdims)
  265. if ord == math.inf:
  266. return max(abs(inp))
  267. if ord == -math.inf:
  268. return min(abs(inp))
  269. return sum(abs(inp) ** ord, axis=axis, keepdims=keepdims) ** (1.0 / ord)
  270. def argmin(
  271. inp: Tensor,
  272. axis: Optional[Union[int, Sequence[int]]] = None,
  273. keepdims: bool = False,
  274. ) -> Tensor:
  275. r"""Returns the indices of the minimum values along
  276. given axis. If axis is a list of dimensions,
  277. reduce over all of them.
  278. Args:
  279. inp: input tensor.
  280. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  281. keepdims: whether the output tensor has axis retained or not. Default: False
  282. Returns:
  283. output tensor.
  284. Examples:
  285. >>> import numpy as np
  286. >>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  287. >>> F.argmin(x)
  288. Tensor(0, dtype=int32, device=xpux:0)
  289. """
  290. if axis is None:
  291. assert not keepdims, "can not set axis=None and keepdims=True"
  292. inp = inp.flatten()
  293. axis = 0
  294. axis = _normalize_axis(inp.ndim, axis, reverse=True)
  295. if isinstance(axis, collections.abc.Iterable):
  296. for ai in axis:
  297. op = builtin.Argmin(axis=ai)
  298. (inp,) = apply(op, inp)
  299. if not keepdims:
  300. inp = squeeze(inp, ai)
  301. return inp
  302. op = builtin.Argmin(axis=axis)
  303. (result,) = apply(op, inp)
  304. if not keepdims:
  305. result = squeeze(result, axis)
  306. return result
  307. def argmax(
  308. inp: Tensor,
  309. axis: Optional[Union[int, Sequence[int]]] = None,
  310. keepdims: bool = False,
  311. ) -> Tensor:
  312. r"""Returns the indices of the maximum values along
  313. given axis. If axis is a list of dimensions,
  314. reduce over all of them.
  315. Args:
  316. inp: input tensor.
  317. axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  318. keepdims: whether the output tensor has axis retained or not. Default: False
  319. Returns:
  320. output tensor.
  321. Examples:
  322. >>> import numpy as np
  323. >>> x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  324. >>> F.argmax(x)
  325. Tensor(5, dtype=int32, device=xpux:0)
  326. """
  327. if axis is None:
  328. assert not keepdims, "can not set axis=None and keepdims=True"
  329. inp = inp.flatten()
  330. axis = 0
  331. axis = _normalize_axis(inp.ndim, axis, reverse=True)
  332. if isinstance(axis, collections.abc.Iterable):
  333. for ai in axis:
  334. op = builtin.Argmax(axis=ai)
  335. (inp,) = apply(op, inp)
  336. if not keepdims:
  337. inp = squeeze(inp, ai)
  338. return inp
  339. op = builtin.Argmax(axis=axis)
  340. (result,) = apply(op, inp)
  341. if not keepdims:
  342. result = squeeze(result, axis)
  343. return result
  344. def normalize(
  345. inp: Tensor, ord: float = None, axis: int = None, eps: float = 1e-12,
  346. ) -> Tensor:
  347. r"""Performs :math:`L_p` normalization of input tensor along
  348. given axis.
  349. For a tensor of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
  350. :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`axis` is transformed as:
  351. .. math::
  352. v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.
  353. Args:
  354. inp: input tensor.
  355. ord: power of value applied to input tensor. Default: 2
  356. axis: dimension to reduce.If None, input must be a vector. Default: None
  357. eps: a small value to avoid division by zero. Default: 1e-12
  358. Returns:
  359. normalized output tensor.
  360. """
  361. if axis is None:
  362. return inp / clip(norm(inp, ord, axis), lower=eps)
  363. else:
  364. return inp / clip(norm(inp, ord, axis, keepdims=True), lower=eps)
  365. def argsort(inp: Tensor, descending: bool = False) -> Tensor:
  366. r"""Returns the indices that would sort the input tensor.
  367. Args:
  368. inp: input tensor. If it's 2d, the result would be array of indices show how to sort each row in the input tensor.
  369. descending: sort in descending order, where the largest comes first. Default: False
  370. inp: Tensor:
  371. descending: bool:
  372. Returns:
  373. indices of int32 indicates how to sort the input.
  374. Examples:
  375. >>> import numpy as np
  376. >>> x = Tensor(np.array([1,2], dtype=np.float32))
  377. >>> F.argsort(x)
  378. Tensor([0 1], dtype=int32, device=xpux:0)
  379. """
  380. assert len(inp.shape) <= 2, "Input should be 1d or 2d"
  381. if descending:
  382. order = "descending"
  383. else:
  384. order = "ascending"
  385. op = builtin.Argsort(order=order)
  386. if len(inp.shape) == 1:
  387. inp = inp.reshape(1, -1)
  388. _, result = apply(op, inp)
  389. return result[0]
  390. _, result = apply(op, inp)
  391. return result
  392. def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]:
  393. r"""Returns sorted tensor and the indices would sort the input tensor.
  394. Args:
  395. inp: input tensor. If it's 2d, the result would be sorted by row.
  396. descending: sort in descending order, where the largest comes first. Default: False
  397. Returns:
  398. tuple of two tensors `(sorted_tensor, indices_of_int32)`.
  399. Examples:
  400. >>> import numpy as np
  401. >>> x = Tensor(np.array([1,2], dtype=np.float32))
  402. >>> out, indices = F.sort(x)
  403. >>> out.numpy()
  404. array([1., 2.], dtype=float32)
  405. """
  406. assert len(inp.shape) <= 2, "Input should be 1d or 2d"
  407. if descending:
  408. order = "descending"
  409. else:
  410. order = "ascending"
  411. op = builtin.Argsort(order=order)
  412. if len(inp.shape) == 1:
  413. inp = inp.reshape(1, -1)
  414. tns, ind = apply(op, inp)
  415. return tns[0], ind[0]
  416. tns, ind = apply(op, inp)
  417. return tns, ind
  418. @deprecated_kwargs_default("1.12", "descending", 3)
  419. def topk(
  420. inp: Tensor,
  421. k: int,
  422. descending: bool = False,
  423. kth_only: bool = False,
  424. no_sort: bool = False,
  425. ) -> Tuple[Tensor, Tensor]:
  426. r"""Selects the ``Top-K`` (by default) smallest elements of 2d matrix by row.
  427. Args:
  428. inp: input tensor. If input tensor is 2d, each row will be sorted.
  429. k: number of elements needed.
  430. descending: if True, return the largest elements instead. Default: False
  431. kth_only: if True, only the k-th element will be returned. Default: False
  432. no_sort: if True, the returned elements can be unordered. Default: False
  433. Returns:
  434. tuple of two tensors ``(topk_tensor, indices_of_int32)``
  435. Examples:
  436. >>> import numpy as np
  437. >>> x = Tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32))
  438. >>> top, indices = F.topk(x, 5, descending=False)
  439. >>> print(top.numpy(), indices.numpy())
  440. [1. 2. 3. 4. 5.] [7 0 6 1 5]
  441. """
  442. if descending:
  443. k = -k
  444. if kth_only:
  445. mode = "kth_only"
  446. elif no_sort:
  447. mode = "value_idx_nosort"
  448. else:
  449. mode = "value_idx_sorted"
  450. op = builtin.TopK(mode=mode)
  451. if not isinstance(k, Tensor):
  452. k = Const(k, "int32", inp.device)
  453. if len(inp.shape) == 1:
  454. if kth_only:
  455. (tns,) = apply(op, expand_dims(inp, 0), k)
  456. # FIXME:
  457. # could use a dedicated kernel
  458. # gradient may be routed to other indices if k-th value is not unique
  459. ind = argmax((tns == inp).astype("int8"))
  460. tns = squeeze(tns, 0)
  461. else:
  462. tns, ind = apply(op, expand_dims(inp, 0), k)
  463. tns = squeeze(tns, 0)
  464. ind = squeeze(ind, 0)
  465. else:
  466. if kth_only:
  467. (tns,) = apply(op, inp, k)
  468. # FIXME: same as above
  469. ind = argmax((expand_dims(tns, 1) == inp).astype("int8"), 1)
  470. else:
  471. tns, ind = apply(op, inp, k)
  472. return tns, ind
  473. def matinv(inp: Tensor) -> Tensor:
  474. r"""Computes the inverse of a batch of matrices; input must has shape [..., n, n].
  475. Args:
  476. inp: input tensor.
  477. Returns:
  478. output tensor.
  479. Examples:
  480. >>> import numpy as np
  481. >>> data = Tensor([[1.0, 0.0], [1.0, 1.0]])
  482. >>> out = F.matinv(data)
  483. >>> out.numpy()
  484. array([[ 1., 0.],
  485. [-1., 1.]], dtype=float32)
  486. """
  487. (result,) = apply(builtin.MatrixInverse(), inp)
  488. return result
  489. def matmul(
  490. inp1: Tensor,
  491. inp2: Tensor,
  492. transpose_a=False,
  493. transpose_b=False,
  494. compute_mode="default",
  495. ) -> Tensor:
  496. r"""Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``.
  497. With different inputs dim, this function behaves differently:
  498. * Both 1-D tensor, simply forward to ``dot``.
  499. * Both 2-D tensor, normal matrix multiplication.
  500. * If one input tensor is 1-D, matrix vector multiplication.
  501. * If at least one tensor are 3-dimensional or >3-dimensional, the other tensor should have dim >= 2,
  502. the batched matrix-matrix is returned, and the tensor with smaller dimension will be broadcasted.
  503. For example:
  504. * inp1: `(n, k, m)`, inp2: `(n, m, p)`, return: `(n, k, p)`
  505. * inp1: `(n, k, m)`, inp2: `(m, p)`, return: `(n, k, p)`
  506. * inp1: `(n, j, k, m)`, inp2: `(n, j, m, p)`, return: `(n, j, k, p)`
  507. Args:
  508. inp1: first matrix to be multiplied.
  509. inp2: second matrix to be multiplied.
  510. Returns:
  511. output tensor.
  512. Examples:
  513. >>> import numpy as np
  514. >>> data1 = Tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  515. >>> data2 = Tensor(np.arange(0, 6, dtype=np.float32).reshape(3, 2))
  516. >>> out = F.matmul(data1, data2)
  517. >>> out.numpy()
  518. array([[10., 13.],
  519. [28., 40.]], dtype=float32)
  520. """
  521. return _matmul(inp1, inp2, transpose_a, transpose_b, compute_mode)
  522. def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
  523. r"""Computes dot-product of two vectors ``inp1`` and ``inp2``.
  524. inputs must be 1-dimensional or scalar. A scalar input is automatically broadcasted.
  525. Refer to :func:`~.matmul` for more general usage.
  526. Args:
  527. inp1: first vector.
  528. inp2: second vector.
  529. Returns:
  530. output value.
  531. Examples:
  532. >>> import numpy as np
  533. >>> data1 = Tensor(np.arange(0, 6, dtype=np.float32))
  534. >>> data2 = Tensor(np.arange(0, 6, dtype=np.float32))
  535. >>> out = F.dot(data1, data2)
  536. >>> out.numpy()
  537. array(55., dtype=float32)
  538. """
  539. op = builtin.Dot()
  540. assert (
  541. inp1.ndim <= 1 and inp2.ndim <= 1
  542. ), "Input tensors for dot must be 1-dimensional or scalar"
  543. (result,) = apply(op, inp1, inp2)
  544. return result
  545. def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
  546. r"""Computes the singular value decomposition of a matrix (or a stack of matrices) ``inp``.
  547. Let :math:`X` be the input matrix (or a stack of input matrices), the output should satisfies:
  548. .. math::
  549. X = U * diag(S) * Vh
  550. where ``U`` is a matrix (or stack of vectors) with orthonormal columns, ``S`` is a vector of
  551. non-negative numbers (or stack of vectors), and ``Vh`` is a matrix (or a stack of matrices)
  552. with orthonormal rows.
  553. Args:
  554. inp (Tensor): A input real tensor having the shape ``(..., M, N)`` with ``inp.ndim >= 2`` .
  555. full_matrices (bool, optional): If ``False`` , ``U`` and ``Vh`` have the shapes ``(..., M, K)``
  556. and ``(..., K, N)`` , respectively, where ``K = min(M, N)`` . If ``True`` , the shapes
  557. are ``(..., M, M)`` and ``(..., N, N)`` , respectively. Default: ``False`` .
  558. compute_uv (bool, optional): Whether or not to compute ``U`` and ``Vh`` in addition to ``S`` . Default: ``True`` .
  559. Note:
  560. * naive does not support ``full_matrices`` and ``compute_uv`` as ``True`` .
  561. Returns:
  562. Returns a tuple ( ``U`` , ``S`` , ``Vh`` ), which are SVD factors ``U`` , ``S``, ``Vh`` of input matrix ``inp``.
  563. ( ``U`` , ``Vh`` only returned when ``compute_uv`` is True). ``U`` contains matrices orthonormal columns
  564. (i.e., the columns are left singular vectors). If ``full_matrices`` is ``True`` , the array must have shape
  565. ``(..., M, M)`` . If ``full_matrices`` is ``False`` , the array must have shape ``(..., M, K)`` , where ``K = min(M, N)`` .
  566. Examples:
  567. >>> import numpy as np
  568. >>> x = Tensor(np.random.randn(9, 6))
  569. >>> y = Tensor(np.random.randn(2, 7, 8, 3))
  570. >>> U, S, Vh = F.svd(x, full_matrices=False)
  571. >>> print(U._tuple_shape, S._tuple_shape, Vh._tuple_shape)
  572. (9, 6) (6,) (6, 6)
  573. >>> u, s, vh = F.svd(y, full_matrices=False)
  574. >>> print(u._tuple_shape, s._tuple_shape, vh._tuple_shape)
  575. (2, 7, 8, 3) (2, 7, 3) (2, 7, 3, 3)
  576. """
  577. op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv)
  578. U, S, Vh = apply(op, inp)
  579. return U, S, Vh
  580. def _check_non_finite(inps: Iterable[Tensor], scale=1.0) -> Tensor:
  581. r"""Check whether input contains infinite or nan value.
  582. Args:
  583. inps: tensors to be checked.
  584. Returns:
  585. a int32 scalar tensor, 0 for False and 1 for True.
  586. """
  587. if isinstance(inps, Tensor):
  588. inps = [inps]
  589. op = builtin.CheckNonFinite(scale=scale)
  590. oups = apply(op, *inps)
  591. out = oups[-1]
  592. for i in range(len(inps)):
  593. inps[i]._reset(oups[i])
  594. return out