You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

math.py 18 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import functools
  11. import math
  12. import numbers
  13. from typing import Optional, Sequence, Tuple, Union
  14. from ..core.ops import builtin
  15. from ..core.ops._internal import param_defs as P
  16. from ..core.ops.special import Const
  17. from ..core.tensor import utils
  18. from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
  19. from ..tensor import Tensor
  20. from .elemwise import clip, exp, log, log1p
  21. from .tensor import reshape, squeeze
  22. __all__ = [
  23. "argmax",
  24. "argmin",
  25. "argsort",
  26. "isinf",
  27. "isnan",
  28. "max",
  29. "mean",
  30. "min",
  31. "norm",
  32. "normalize",
  33. "prod",
  34. "sign",
  35. "sort",
  36. "std",
  37. "sum",
  38. "topk",
  39. "var",
  40. ]
  41. def isnan(inp: Tensor) -> Tensor:
  42. r"""
  43. Returns a new tensor representing if each element is ``NaN`` or not.
  44. :param inp: input tensor.
  45. :return: result tensor.
  46. Examples:
  47. .. testcode::
  48. from megengine import tensor
  49. import megengine.functional as F
  50. x = tensor([1, float("nan"), 0])
  51. print(F.isnan(x).numpy())
  52. Outputs:
  53. .. testoutput::
  54. [False True False]
  55. """
  56. return inp != inp
  57. def isinf(inp: Tensor) -> Tensor:
  58. r"""
  59. Returns a new tensor representing if each element is ``Inf`` or not.
  60. :param inp: input tensor.
  61. :return: result tensor.
  62. Examples:
  63. .. testcode::
  64. from megengine import tensor
  65. import megengine.functional as F
  66. x = tensor([1, float("inf"), 0])
  67. print(F.isinf(x).numpy())
  68. Outputs:
  69. .. testoutput::
  70. [False True False]
  71. """
  72. return abs(inp).astype("float32") == float("inf")
  73. def sign(inp: Tensor):
  74. r"""
  75. Returns a new tensor representing the sign of each element in input tensor.
  76. :param: input tensor.
  77. :return: the sign of input tensor.
  78. Examples:
  79. .. testcode::
  80. from megengine import tensor
  81. import megengine.functional as F
  82. x = tensor([1, -1, 0])
  83. print(F.sign(x).numpy())
  84. Outputs:
  85. .. testoutput::
  86. [ 1 -1 0]
  87. """
  88. return (inp > 0).astype(inp.dtype) - (inp < 0).astype(inp.dtype)
  89. def sum(
  90. inp: Tensor,
  91. axis: Optional[Union[int, Sequence[int]]] = None,
  92. keepdims: bool = False,
  93. ) -> Tensor:
  94. r"""
  95. Returns the sum of input tensor along given axis. If axis is a list of dimensions,
  96. reduce over all of them.
  97. :param inp: input tensor.
  98. :param axis: dimension to reduce. If None, all dimensions will be reduced.
  99. Default: None
  100. :param keepdims: whether the output tensor has axis retained or not.
  101. Default: False
  102. :return: output tensor.
  103. Examples:
  104. .. testcode::
  105. import numpy as np
  106. from megengine import tensor
  107. import megengine.functional as F
  108. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  109. out = F.sum(x)
  110. print(out.numpy())
  111. Outputs:
  112. .. testoutput::
  113. [21]
  114. """
  115. return inp.sum(axis=axis, keepdims=keepdims)
  116. def prod(
  117. inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None, keepdims=False
  118. ) -> Tensor:
  119. r"""
  120. Returns the product of input tensor along given axis. If axis is a list of dimensions,
  121. reduce over all of them.
  122. :param inp: input tensor.
  123. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  124. :param keepdims: whether the output tensor has axis retained or not. Default: False
  125. :return: output tensor.
  126. Examples:
  127. .. testcode::
  128. import numpy as np
  129. from megengine import tensor
  130. import megengine.functional as F
  131. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  132. out = F.prod(x)
  133. print(out.numpy())
  134. Outputs:
  135. .. testoutput::
  136. [720]
  137. """
  138. return inp.prod(axis=axis, keepdims=keepdims)
  139. def mean(
  140. inp: Tensor,
  141. axis: Optional[Union[int, Sequence[int]]] = None,
  142. keepdims: bool = False,
  143. ) -> Tensor:
  144. """
  145. Returns the mean value of input tensor along
  146. given axis. If axis is a list of dimensions,
  147. reduce over all of them.
  148. :param inp: input tensor.
  149. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  150. :param keepdims: whether the output tensor has axis retained or not. Default: False
  151. :return: output tensor.
  152. Examples:
  153. .. testcode::
  154. import numpy as np
  155. from megengine import tensor
  156. import megengine.functional as F
  157. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  158. out = F.mean(x)
  159. print(out.numpy())
  160. Outputs:
  161. .. testoutput::
  162. [3.5]
  163. """
  164. return inp.mean(axis=axis, keepdims=keepdims)
  165. def var(
  166. inp: Tensor,
  167. axis: Optional[Union[int, Sequence[int]]] = None,
  168. keepdims: bool = False,
  169. ) -> Tensor:
  170. """
  171. Returns the variance value of input tensor along
  172. given axis. If axis is a list of dimensions,
  173. reduce over all of them.
  174. :param inp: input tensor.
  175. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  176. :param keepdims: whether the output tensor has axis retained or not. Default: False
  177. :return: output tensor.
  178. Examples:
  179. .. testcode::
  180. import numpy as np
  181. from megengine import tensor
  182. import megengine.functional as F
  183. data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
  184. out = F.var(data)
  185. print(out.numpy().round(decimals=4))
  186. Outputs:
  187. .. testoutput::
  188. [2.9167]
  189. """
  190. if axis is None:
  191. m = mean(inp, axis=axis, keepdims=False)
  192. else:
  193. m = mean(inp, axis=axis, keepdims=True)
  194. v = inp - m
  195. return mean(v ** 2, axis=axis, keepdims=keepdims)
  196. def std(
  197. inp: Tensor,
  198. axis: Optional[Union[int, Sequence[int]]] = None,
  199. keepdims: bool = False,
  200. ) -> Tensor:
  201. """
  202. Returns the standard deviation of input tensor along
  203. given axis. If axis is a list of dimensions,
  204. reduce over all of them.
  205. :param inp: input tensor.
  206. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  207. :param keepdims: whether the output tensor has axis retained or not. Default: False
  208. :return: output tensor.
  209. Examples:
  210. .. testcode::
  211. import numpy as np
  212. from megengine import tensor
  213. import megengine.functional as F
  214. data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
  215. out = F.std(data, axis=1)
  216. print(out.numpy().round(decimals=4))
  217. Outputs:
  218. .. testoutput::
  219. [0.8165 0.8165]
  220. """
  221. return var(inp, axis=axis, keepdims=keepdims) ** 0.5
  222. def min(
  223. inp: Tensor,
  224. axis: Optional[Union[int, Sequence[int]]] = None,
  225. keepdims: bool = False,
  226. ) -> Tensor:
  227. r"""
  228. Returns the min value of input tensor along
  229. given axis. If axis is a list of dimensions,
  230. reduce over all of them.
  231. :param inp: input tensor.
  232. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  233. :param keepdims: whether the output tensor has axis retained or not. Default: False
  234. :return: output tensor.
  235. Examples:
  236. .. testcode::
  237. import numpy as np
  238. from megengine import tensor
  239. import megengine.functional as F
  240. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  241. out = F.min(x)
  242. print(out.numpy())
  243. Outputs:
  244. .. testoutput::
  245. [1]
  246. """
  247. return inp.min(axis=axis, keepdims=keepdims)
  248. def max(
  249. inp: Tensor,
  250. axis: Optional[Union[int, Sequence[int]]] = None,
  251. keepdims: bool = False,
  252. ) -> Tensor:
  253. r"""
  254. Returns the max value of the input tensor along
  255. given axis. If axis is a list of dimensions,
  256. reduce over all of them.
  257. :param inp: input tensor.
  258. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  259. :param keepdims: whether the output tensor has axis retained or not. Default: False
  260. :return: output tensor.
  261. Examples:
  262. .. testcode::
  263. import numpy as np
  264. from megengine import tensor
  265. import megengine.functional as F
  266. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  267. out = F.max(x)
  268. print(out.numpy())
  269. Outputs:
  270. .. testoutput::
  271. [6]
  272. """
  273. return inp.max(axis=axis, keepdims=keepdims)
  274. def norm(
  275. inp: Tensor, ord: float = None, axis: int = None, keepdims=False,
  276. ):
  277. """
  278. Calculates ``p``-norm of input tensor along
  279. given axis.
  280. :param inp: input tensor.
  281. :param ord: power of value applied to inp. Default: 2
  282. :param axis: dimension to reduce. If None, input must be a vector. Default: None
  283. :param keepdims: whether the output tensor has axis retained or not. Default: False
  284. :return: output tensor.
  285. Examples:
  286. .. testcode::
  287. import numpy as np
  288. from megengine import tensor
  289. import megengine.functional as F
  290. x = tensor(np.arange(-3, 3, dtype=np.float32))
  291. out = F.norm(x)
  292. print(out.numpy().round(decimals=4))
  293. Outputs:
  294. .. testoutput::
  295. [4.3589]
  296. """
  297. if axis is None:
  298. if inp.ndim != 1:
  299. raise TypeError("axis is required unless input is a vector")
  300. if ord is None:
  301. ord = 2
  302. if ord == 0:
  303. return sum(inp != 0, axis=axis, keepdims=keepdims)
  304. if ord == math.inf:
  305. return max(abs(inp))
  306. if ord == -math.inf:
  307. return min(abs(inp))
  308. return sum(abs(inp) ** ord, axis=axis, keepdims=keepdims) ** (1.0 / ord)
  309. def argmin(
  310. inp: Tensor,
  311. axis: Optional[Union[int, Sequence[int]]] = None,
  312. keepdims: bool = False,
  313. ) -> Tensor:
  314. r"""
  315. Returns the indices of the minimum values along
  316. given axis. If axis is a list of dimensions,
  317. reduce over all of them.
  318. :param inp: input tensor.
  319. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  320. :param keepdims: whether the output tensor has axis retained or not. Default: False
  321. :return: output tensor.
  322. Examples:
  323. .. testcode::
  324. import numpy as np
  325. from megengine import tensor
  326. import megengine.functional as F
  327. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  328. out = F.argmin(x)
  329. print(out.numpy())
  330. Outputs:
  331. .. testoutput::
  332. [0]
  333. """
  334. if isinstance(axis, collections.abc.Iterable):
  335. axis = list(axis)
  336. axis.sort(reverse=True)
  337. for ai in axis:
  338. op = builtin.Argmin(axis=ai)
  339. (inp,) = apply(op, inp)
  340. if not keepdims:
  341. inp = squeeze(inp, ai)
  342. return inp
  343. if axis is None:
  344. assert not keepdims, "can not set axis=None and keepdims=True"
  345. inp = inp.flatten()
  346. axis = 0
  347. op = builtin.Argmin(axis=axis)
  348. (result,) = apply(op, inp)
  349. if not keepdims:
  350. result = squeeze(result, axis)
  351. return result
  352. def argmax(
  353. inp: Tensor,
  354. axis: Optional[Union[int, Sequence[int]]] = None,
  355. keepdims: bool = False,
  356. ) -> Tensor:
  357. r"""
  358. Returns the indices of the maximum values along
  359. given axis. If axis is a list of dimensions,
  360. reduce over all of them.
  361. :param inp: input tensor.
  362. :param axis: dimension to reduce. If None, all dimensions will be reduced. Default: None
  363. :param keepdims: whether the output tensor has axis retained or not. Default: False
  364. :return: output tensor.
  365. Examples:
  366. .. testcode::
  367. import numpy as np
  368. from megengine import tensor
  369. import megengine.functional as F
  370. x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
  371. out = F.argmax(x)
  372. print(out.numpy())
  373. Outputs:
  374. .. testoutput::
  375. [5]
  376. """
  377. if isinstance(axis, collections.abc.Iterable):
  378. axis = list(axis)
  379. axis.sort(reverse=True)
  380. for ai in axis:
  381. op = builtin.Argmax(axis=ai)
  382. (inp,) = apply(op, inp)
  383. if not keepdims:
  384. inp = squeeze(inp, ai)
  385. return inp
  386. if axis is None:
  387. assert not keepdims, "can not set axis=None and keepdims=True"
  388. inp = inp.flatten()
  389. axis = 0
  390. op = builtin.Argmax(axis=axis)
  391. (result,) = apply(op, inp)
  392. if not keepdims:
  393. result = squeeze(result, axis)
  394. return result
  395. def normalize(
  396. inp: Tensor, ord: float = None, axis: int = None, eps: float = 1e-12,
  397. ) -> Tensor:
  398. r"""
  399. Performs :math:`L_p` normalization of input tensor along
  400. given axis.
  401. For a tensor of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
  402. :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`axis` is transformed as:
  403. .. math::
  404. v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.
  405. :param inp: input tensor.
  406. :param ord: power of value applied to input tensor. Default: 2
  407. :param axis: dimension to reduce.If None, input must be a vector. Default: None
  408. :param eps: a small value to avoid division by zero. Default: 1e-12
  409. :return: normalized output tensor.
  410. """
  411. if axis is None:
  412. return inp / clip(norm(inp, ord, axis), lower=eps)
  413. else:
  414. return inp / clip(norm(inp, ord, axis, keepdims=True), lower=eps)
  415. def argsort(inp: Tensor, descending: bool = False) -> Tensor:
  416. r"""
  417. Returns the indices that would sort the input tensor.
  418. :param inp: input tensor. If it's 2d, the result would be array of indices show how to sort each row in the input tensor.
  419. :param descending: sort in descending order, where the largest comes first. Default: False
  420. :return: indices of int32 indicates how to sort the input.
  421. Examples:
  422. .. testcode::
  423. import numpy as np
  424. from megengine import tensor
  425. import megengine.functional as F
  426. x = tensor(np.array([1,2], dtype=np.float32))
  427. indices = F.argsort(x)
  428. print(indices.numpy())
  429. Outputs:
  430. .. testoutput::
  431. [0 1]
  432. """
  433. assert len(inp.shape) <= 2, "Input should be 1d or 2d"
  434. if descending:
  435. order = P.Argsort.Order.DESCENDING
  436. else:
  437. order = P.Argsort.Order.ASCENDING
  438. op = builtin.Argsort(order=order)
  439. if len(inp.shape) == 1:
  440. inp = inp.reshape(1, -1)
  441. _, result = apply(op, inp)
  442. return result[0]
  443. _, result = apply(op, inp)
  444. return result
  445. def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]:
  446. r"""
  447. Returns sorted tensor and the indices would sort the input tensor.
  448. :param inp: input tensor. If it's 2d, the result would be sorted by row.
  449. :param descending: sort in descending order, where the largest comes first. Default: False
  450. :return: tuple of two tensors `(sorted_tensor, indices_of_int32)`.
  451. Examples:
  452. .. testcode::
  453. import numpy as np
  454. from megengine import tensor
  455. import megengine.functional as F
  456. x = tensor(np.array([1,2], dtype=np.float32))
  457. out, indices = F.sort(x)
  458. print(out.numpy())
  459. Outputs:
  460. .. testoutput::
  461. [1. 2.]
  462. """
  463. assert len(inp.shape) <= 2, "Input should be 1d or 2d"
  464. if descending:
  465. order = P.Argsort.Order.DESCENDING
  466. else:
  467. order = P.Argsort.Order.ASCENDING
  468. op = builtin.Argsort(order=order)
  469. if len(inp.shape) == 1:
  470. inp = inp.reshape(1, -1)
  471. tns, ind = apply(op, inp)
  472. return tns[0], ind[0]
  473. tns, ind = apply(op, inp)
  474. return tns, ind
  475. def topk(
  476. inp: Tensor,
  477. k: int,
  478. descending: bool = False,
  479. kth_only: bool = False,
  480. no_sort: bool = False,
  481. ) -> Tuple[Tensor, Tensor]:
  482. r"""
  483. Selects the ``Top-K``(by default) smallest elements of 2d matrix by row.
  484. :param inp: input tensor. If input tensor is 2d, each row will be sorted.
  485. :param k: number of elements needed.
  486. :param descending: if True, return the largest elements instead. Default: False
  487. :param kth_only: if True, only the k-th element will be returned. Default: False
  488. :param no_sort: if True, the returned elements can be unordered. Default: False
  489. :return: tuple of two tensors `(topk_tensor, indices_of_int32)`.
  490. Examples:
  491. .. testcode::
  492. import numpy as np
  493. from megengine import tensor
  494. import megengine.functional as F
  495. x = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32))
  496. top, indices = F.topk(x, 5)
  497. print(top.numpy(), indices.numpy())
  498. Outputs:
  499. .. testoutput::
  500. [1. 2. 3. 4. 5.] [7 0 6 1 5]
  501. """
  502. if descending:
  503. inp = -inp
  504. Mode = P.TopK.Mode
  505. if kth_only:
  506. mode = Mode.KTH_ONLY
  507. elif no_sort:
  508. mode = Mode.VALUE_IDX_NOSORT
  509. else:
  510. mode = Mode.VALUE_IDX_SORTED
  511. op = builtin.TopK(mode=mode)
  512. if not isinstance(k, (TensorBase, TensorWrapperBase)):
  513. (k,) = Const(k, dtype="int32", device=inp.device)(inp)
  514. if len(inp.shape) == 1:
  515. inp = inp.reshape(1, -1)
  516. res = apply(op, inp, k)
  517. if kth_only:
  518. tns = res[0]
  519. else:
  520. tns, ind = res[0][0], res[1][0]
  521. else:
  522. res = apply(op, inp, k)
  523. if kth_only:
  524. tns = res
  525. else:
  526. tns, ind = res[0], res[1]
  527. if descending:
  528. tns = -tns
  529. return tns, ind

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。