
math.py 17 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import functools
import math
import numbers
from typing import Optional, Sequence, Tuple, Union

from ..core.ops import builtin
from ..core.ops._internal import param_defs as P
from ..core.tensor import utils
from ..core.tensor.core import apply
from ..tensor import Tensor
from .elemwise import clamp, exp, log, log1p
from .tensor import add_axis, remove_axis, reshape

__all__ = [
    "argmax",
    "argmin",
    "argsort",
    "isinf",
    "isnan",
    "max",
    "mean",
    "min",
    "norm",
    "normalize",
    "prod",
    "sign",
    "sort",
    "std",
    "sum",
    "topk",
    "var",
]
def isnan(inp: Tensor) -> Tensor:
    r"""Returns a new tensor representing if each element is ``NaN`` or not.

    :param inp: input tensor.
    :return: a new tensor representing if each element in inp is NaN or not.

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        x = tensor([1, float("nan"), 0])
        print(F.isnan(x).numpy())

    Outputs:

    .. testoutput::

        [False  True False]

    """
    return inp != inp


def isinf(inp: Tensor) -> Tensor:
    r"""Returns a new tensor representing if each element is ``Inf`` or not.

    :param inp: input tensor.
    :return: a new tensor representing if each element in inp is Inf or not.

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        x = tensor([1, float("inf"), 0])
        print(F.isinf(x).numpy())

    Outputs:

    .. testoutput::

        [False  True False]

    """
    return abs(inp).astype("float32") == float("inf")


def sign(inp: Tensor):
    r"""Returns a new tensor representing the sign of each element in the input tensor.

    :param inp: input tensor.
    :return: the sign of the input tensor.

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        x = tensor([1, -1, 0])
        print(F.sign(x).numpy())

    Outputs:

    .. testoutput::

        [ 1 -1  0]

    """
    return (inp > 0).astype(inp.dtype) - (inp < 0).astype(inp.dtype)
def sum(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    r"""Returns the sum of input tensor along given axis. If axis is a list of dimensions,
    reduce over all of them.

    :param inp: input tensor.
    :param axis: dimension to reduce. If None, all the dimensions will be reduced.
        Default: None
    :param keepdims: whether the output tensor has axis retained or not.
        Default: False
    :return: output tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        out = F.sum(x)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [21]

    """
    return inp.sum(axis=axis, keepdims=keepdims)


def prod(
    inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None, keepdims=False
) -> Tensor:
    r"""Returns the product of input tensor along given axis. If axis is a list of dimensions,
    reduce over all of them.

    :param inp: input tensor.
    :param axis: dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: whether the output tensor has axis retained or not. Default: False
    :return: output tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        out = F.prod(x)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [720]

    """
    return inp.prod(axis=axis, keepdims=keepdims)
def mean(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    """Returns the mean value of input tensor along
    given axis. If axis is a list of dimensions,
    reduce over all of them.

    :param inp: input tensor.
    :param axis: dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: whether the output tensor has axis retained or not. Default: False
    :return: output tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        out = F.mean(x)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [3.5]

    """
    return inp.astype("float32").mean(axis=axis, keepdims=keepdims)


def var(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    """Returns the variance value of input tensor along
    given axis. If axis is a list of dimensions,
    reduce over all of them.

    :param inp: input tensor.
    :param axis: dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: whether the output tensor has axis retained or not. Default: False
    :return: output tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
        out = F.var(data)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [2.9167]

    """
    if axis is None:
        m = mean(inp, axis=axis, keepdims=False)
    else:
        m = mean(inp, axis=axis, keepdims=True)
    v = inp - m
    return mean(v ** 2, axis=axis, keepdims=keepdims)


def std(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    """Returns the standard deviation of input tensor along
    given axis. If axis is a list of dimensions,
    reduce over all of them.

    :param inp: input tensor.
    :param axis: dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: whether the output tensor has axis retained or not. Default: False
    :return: output tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
        out = F.std(data, axis=1)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [0.8165 0.8165]

    """
    return var(inp, axis=axis, keepdims=keepdims) ** 0.5
def min(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    r"""Returns the min value of input tensor along
    given axis. If axis is a list of dimensions,
    reduce over all of them.

    :param inp: input tensor.
    :param axis: dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: whether the output tensor has axis retained or not. Default: False
    :return: output tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        out = F.min(x)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [1]

    """
    return inp.min(axis=axis, keepdims=keepdims)


def max(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    r"""Returns the max value of the input tensor along
    given axis. If axis is a list of dimensions,
    reduce over all of them.

    :param inp: input tensor.
    :param axis: dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: whether the output tensor has axis retained or not. Default: False
    :return: output tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        out = F.max(x)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [6]

    """
    return inp.max(axis=axis, keepdims=keepdims)


def norm(
    inp: Tensor,
    p: int = 2,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims=False,
):
    """Calculates ``p``-norm of input tensor along
    given axis. If axis is a list of dimensions,
    reduce over all of them.

    :param inp: input tensor.
    :param p: power of value applied to inp. Default: 2
    :param axis: dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: whether the output tensor has axis retained or not. Default: False
    :return: output tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(-3, 3, dtype=np.float32).reshape(2, 3))
        out = F.norm(x)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [4.3589]

    """
    if p == 0:
        return sum(inp != 0, axis=axis, keepdims=keepdims)
    if p == math.inf:
        return max(abs(inp))
    if p == -math.inf:
        return min(abs(inp))
    return sum(abs(inp) ** p, axis=axis, keepdims=keepdims) ** (1.0 / p)
def argmin(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    r"""Returns the indices of the minimum values along
    given axis. If axis is a list of dimensions,
    reduce over all of them.

    :param inp: input tensor.
    :param axis: dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: whether the output tensor has axis retained or not. Default: False
    :return: output tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        out = F.argmin(x)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [0]

    """
    if isinstance(axis, collections.abc.Iterable):
        axis = list(axis)
        axis.sort(reverse=True)
        for ai in axis:
            op = builtin.Argmin(axis=ai)
            (inp,) = apply(op, inp)
            if not keepdims:
                inp = remove_axis(inp, ai)
        return inp

    if axis is None:
        assert not keepdims, "can not set axis=None and keepdims=True"
        inp = inp.flatten()
        axis = 0

    op = builtin.Argmin(axis=axis)
    (result,) = apply(op, inp)
    if not keepdims:
        result = remove_axis(result, axis)
    return result


def argmax(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    r"""Returns the indices of the maximum values along
    given axis. If axis is a list of dimensions,
    reduce over all of them.

    :param inp: input tensor.
    :param axis: dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: whether the output tensor has axis retained or not. Default: False
    :return: output tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        out = F.argmax(x)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [5]

    """
    if isinstance(axis, collections.abc.Iterable):
        axis = list(axis)
        axis.sort(reverse=True)
        for ai in axis:
            op = builtin.Argmax(axis=ai)
            (inp,) = apply(op, inp)
            if not keepdims:
                inp = remove_axis(inp, ai)
        return inp

    if axis is None:
        assert not keepdims, "can not set axis=None and keepdims=True"
        inp = inp.flatten()
        axis = 0

    op = builtin.Argmax(axis=axis)
    (result,) = apply(op, inp)
    if not keepdims:
        result = remove_axis(result, axis)
    return result
def normalize(
    inp: Tensor,
    p: int = 2,
    axis: Optional[Union[int, Sequence[int]]] = None,
    eps: float = 1e-12,
) -> Tensor:
    r"""Performs :math:`L_p` normalization of input tensor along
    given axis. If axis is a list of dimensions,
    reduce over all of them.

    For a tensor inp of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
    :math:`n_{dim}`-element vector :math:`v` along dimension :attr:`axis` is transformed as:

    .. math::

        v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.

    :param inp: input tensor.
    :param p: power of value applied to inp. Default: 2
    :param axis: dimension to reduce. If None, all the dimensions will be reduced
        to calculate the norm. Default: None
    :param eps: a small value to avoid division by zero. Default: 1e-12
    :return: normalized output tensor.
    """
    if axis is None:
        return inp / clamp(norm(inp, p, axis), lower=eps)
    else:
        return inp / clamp(norm(inp, p, axis, keepdims=True), lower=eps)
def argsort(inp: Tensor, descending: bool = False) -> Tensor:
    r"""Returns the indices that would sort the input tensor. If the input is 2d,
    each row is sorted independently.

    :param inp: input tensor, if 2d, each row will be sorted.
    :param descending: sort in descending order, where the largest comes first. Default: False
    :return: indices tensor of dtype int32.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.array([1, 2], dtype=np.float32))
        indices = F.argsort(x)
        print(indices.numpy())

    Outputs:

    .. testoutput::

        [0 1]

    """
    assert len(inp.shape) <= 2, "Input should be 1d or 2d"
    if descending:
        order = P.Argsort.Order.DESCENDING
    else:
        order = P.Argsort.Order.ASCENDING

    op = builtin.Argsort(order=order)
    if len(inp.shape) == 1:
        inp = inp.reshape(1, -1)
        _, result = apply(op, inp)
        return result[0]
    _, result = apply(op, inp)
    return result


def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]:
    r"""Sorts the input tensor (2d matrices are sorted by row) and returns both
    the sorted tensor and the indices.

    :param inp: input tensor, if 2d, each row will be sorted.
    :param descending: sort in descending order, where the largest comes first. Default: False
    :return: tuple of two tensors `(sorted_tensor, indices_of_int32)`.
    """
    assert len(inp.shape) <= 2, "Input should be 1d or 2d"
    if descending:
        order = P.Argsort.Order.DESCENDING
    else:
        order = P.Argsort.Order.ASCENDING

    op = builtin.Argsort(order=order)
    if len(inp.shape) == 1:
        inp = inp.reshape(1, -1)
        tns, ind = apply(op, inp)
        return tns[0], ind[0]
    tns, ind = apply(op, inp)
    return tns, ind
def topk(
    inp: Tensor,
    k: int,
    descending: bool = False,
    kth_only: bool = False,
    no_sort: bool = False,
) -> Tuple[Tensor, Tensor]:
    r"""Selects the ``Top-K`` (by default the k smallest) elements of a 2d matrix by row.

    :param inp: input tensor, if 2d, each row will be sorted.
    :param k: number of elements needed.
    :param descending: if true, return the largest elements instead. Default: False
    :param kth_only: if true, only the k-th element will be returned. Default: False
    :param no_sort: if true, the returned elements can be unordered. Default: False
    :return: tuple of two tensors `(topk_tensor, indices_of_int32)`.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32))
        top, indices = F.topk(x, 5)
        print(top.numpy(), indices.numpy())

    Outputs:

    .. testoutput::

        [1. 2. 3. 4. 5.] [7 0 6 1 5]

    """
    if descending:
        inp = -inp

    Mode = P.TopK.Mode
    if kth_only:
        mode = Mode.KTH_ONLY
    elif no_sort:
        mode = Mode.VALUE_IDX_NOSORT
    else:
        mode = Mode.VALUE_IDX_SORTED
    op = builtin.TopK(mode=mode)

    if len(inp.shape) == 1:
        inp = inp.reshape(1, -1)
        res = apply(op, inp, Tensor(k, dtype="int32"))
        if kth_only:
            tns = res[0]
            ind = None  # KTH_ONLY mode yields values only, no indices
        else:
            tns, ind = res[0][0], res[1][0]
    else:
        res = apply(op, inp, Tensor(k, dtype="int32"))
        if kth_only:
            tns = res
            ind = None  # KTH_ONLY mode yields values only, no indices
        else:
            tns, ind = res[0], res[1]

    if descending:
        tns = -tns
    return tns, ind
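
For quick reference, a minimal usage sketch exercising a few of the reduction and sorting routines defined above; it reuses the same `megengine` / `megengine.functional as F` imports as the docstring examples and is illustrative only, not part of the module:

import numpy as np
from megengine import tensor
import megengine.functional as F

x = tensor(np.arange(-3, 3, dtype=np.float32).reshape(2, 3))

# Reductions along an axis; keepdims preserves the reduced dimension.
row_norm = F.norm(x, p=2, axis=1, keepdims=True)   # shape (2, 1)
unit_rows = F.normalize(x, p=2, axis=1)            # each row rescaled to unit L2 norm

# Sorting and top-k selection on a 1d tensor.
v = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32))
sorted_v, order = F.sort(v, descending=True)       # sorted values and int32 indices
top3, idx3 = F.topk(v, 3)                          # three smallest values and their indices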

The MegEngine package bundles the CUDA environment needed to run code on the GPU, so there is no separate CPU or GPU build. If you want to run GPU programs, make sure the machine has GPU hardware and a working driver installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
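
As a rough sketch of how one might check that the bundled CUDA environment can actually be used before running GPU code (the helper names `megengine.is_cuda_available` and `megengine.set_default_device` are assumptions about the installed MegEngine version, not something stated on this page):

import megengine as mge

# Pick a device depending on whether a usable GPU (with driver) is present.
# Both helper names are assumptions about the installed MegEngine version.
if mge.is_cuda_available():
    mge.set_default_device("gpu0")   # subsequent tensors/ops run on the first GPU
else:
    mge.set_default_device("cpu0")   # fall back to the CPU

x = mge.tensor([1.0, 2.0, 3.0])
print(x.device)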