
math.py 17 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import functools
import math
import numbers
from typing import Optional, Sequence, Tuple, Union

from ..core.ops import builtin
from ..core.ops._internal import param_defs as P
from ..core.tensor import utils
from ..core.tensor.core import apply
from ..tensor import Tensor
from .elemwise import clamp, exp, log, log1p
from .tensor import remove_axis, reshape

__all__ = [
    "argmax",
    "argmin",
    "argsort",
    "isinf",
    "isnan",
    "max",
    "mean",
    "min",
    "norm",
    "normalize",
    "prod",
    "sign",
    "sort",
    "std",
    "sum",
    "topk",
    "var",
]

def isnan(inp: Tensor) -> Tensor:
    r"""Returns a new tensor representing if each element is NaN or not.

    :param inp: the input tensor.
    :return: a new tensor representing if each element in :attr:`inp` is NaN or not.

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        x = tensor([1, float("nan"), 0])
        print(F.isnan(x).numpy())

    Outputs:

    .. testoutput::

        [False True False]

    """
    return inp != inp


def isinf(inp: Tensor) -> Tensor:
    r"""Returns a new tensor representing if each element is Inf or not.

    :param inp: the input tensor.
    :return: a new tensor representing if each element in :attr:`inp` is Inf or not.

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        x = tensor([1, float("inf"), 0])
        print(F.isinf(x).numpy())

    Outputs:

    .. testoutput::

        [False True False]

    """
    return abs(inp).astype("float32") == float("inf")


def sign(inp: Tensor):
    r"""Returns the sign of each element in the input tensor.

    :param inp: the input tensor.
    :return: a sign tensor.

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        x = tensor([1, -1, 0])
        print(F.sign(x).numpy())

    Outputs:

    .. testoutput::

        [ 1 -1 0]

    """
    return (inp > 0).astype(inp.dtype) - (inp < 0).astype(inp.dtype)

def sum(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    r"""Returns the sum of the ``inp`` tensor along the given ``axis``.

    :param inp: The input tensor.
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced.
        Default: None
    :param keepdims: Whether the output tensor has ``axis`` retained or not.
        Default: False
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        out = F.sum(data)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [21]

    """
    return inp.sum(axis=axis, keepdims=keepdims)

def prod(
    inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None, keepdims=False
) -> Tensor:
    r"""
    Returns the element product of input tensor along given *axis*.

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: ``None``
    :param keepdims: Whether the output tensor has *axis* retained or not. Default: ``False``
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        out = F.prod(data)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [720]

    """
    return inp.prod(axis=axis, keepdims=keepdims)

def mean(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    """Returns the mean value of the ``inp`` tensor along
    the given ``axis``. If axis is a list of dimensions,
    reduce over all of them.

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has ``axis`` retained or not. Default: False
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        out = F.mean(data)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [3.5]

    """
    return inp.astype("float32").mean(axis=axis, keepdims=keepdims)

def median(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    raise NotImplementedError


def var(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    """Returns the variance value of input tensor along
    given ``axis``. If axis is a list of dimensions,
    reduce over all of them.

    :param inp: The input tensor.
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: ``None``.
    :param keepdims: Whether the output tensor has ``axis`` retained or not. Default: ``False``.
    :return: The output tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
        out = F.var(data)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [2.9167]

    """
    if axis is None:
        m = mean(inp, axis=axis, keepdims=False)
    else:
        m = mean(inp, axis=axis, keepdims=True)
    v = inp - m
    return mean(v ** 2, axis=axis, keepdims=keepdims)


def std(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    """Returns the standard deviation of input tensor along
    given ``axis``. If axis is a list of dimensions,
    reduce over all of them.

    :param inp: The input tensor.
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: ``None``.
    :param keepdims: Whether the output tensor has ``axis`` retained or not. Default: ``False``.
    :return: The output tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
        out = F.std(data, axis=1)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [0.8165 0.8165]

    """
    return var(inp, axis=axis, keepdims=keepdims) ** 0.5

def min(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    r"""
    Returns the min value of input tensor along given *axis*.

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has *axis* retained or not. Default: False
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
        y = F.min(x)
        print(y.numpy())

    Outputs:

    .. testoutput::

        [1]

    """
    return inp.min(axis=axis, keepdims=keepdims)


def max(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    r"""Returns the max value of the input tensor along given *axis*.

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has *axis* retained or not. Default: False
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
        y = F.max(x)
        print(y.numpy())

    Outputs:

    .. testoutput::

        [6]

    """
    return inp.max(axis=axis, keepdims=keepdims)

def norm(
    inp: Tensor,
    p: int = 2,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims=False,
):
    """Calculates the ``p``-norm of the input tensor along the given axis.

    :param inp: The input tensor
    :param p: the order of the norm to compute. Default: 2
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has ``axis`` retained or not. Default: False
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(-3, 3, dtype=np.float32).reshape(2,3))
        y = F.norm(x)
        print(y.numpy())

    Outputs:

    .. testoutput::

        [4.3589]

    """
    if p == 0:
        return sum(inp != 0, axis=axis, keepdims=keepdims)
    if p == math.inf:
        # the infinity norms below reduce over all elements regardless of ``axis``
        return max(abs(inp))
    if p == -math.inf:
        return min(abs(inp))
    return sum(abs(inp) ** p, axis=axis, keepdims=keepdims) ** (1.0 / p)

def argmin(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    r"""Returns the indices of the minimum values along an axis.

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has *axis* retained or not. Default: False
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
        y = F.argmin(x)
        print(y.numpy())

    Outputs:

    .. testoutput::

        [0]

    """
    if isinstance(axis, collections.abc.Iterable):
        axis = list(axis)
        # reduce the axes from the last to the first, so removing an
        # already-reduced axis does not shift the remaining axis indices
        axis.sort(reverse=True)
        for ai in axis:
            op = builtin.Argmin(axis=ai)
            (inp,) = apply(op, inp)
            if not keepdims:
                inp = remove_axis(inp, ai)
        return inp

    if axis is None:
        assert not keepdims, "can not set axis=None and keepdims=True"
        inp = inp.flatten()
        axis = 0
    op = builtin.Argmin(axis=axis)
    (result,) = apply(op, inp)
    if not keepdims:
        result = remove_axis(result, axis)
    return result


def argmax(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    r"""Returns the indices of the maximum values along an axis.

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has *axis* retained or not. Default: False
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
        y = F.argmax(x)
        print(y.numpy())

    Outputs:

    .. testoutput::

        [5]

    """
    if isinstance(axis, collections.abc.Iterable):
        axis = list(axis)
        # same reversed-axis order as in argmin, for the same reason
        axis.sort(reverse=True)
        for ai in axis:
            op = builtin.Argmax(axis=ai)
            (inp,) = apply(op, inp)
            if not keepdims:
                inp = remove_axis(inp, ai)
        return inp

    if axis is None:
        assert not keepdims, "can not set axis=None and keepdims=True"
        inp = inp.flatten()
        axis = 0
    op = builtin.Argmax(axis=axis)
    (result,) = apply(op, inp)
    if not keepdims:
        result = remove_axis(result, axis)
    return result

def normalize(
    inp: Tensor,
    p: int = 2,
    axis: Optional[Union[int, Sequence[int]]] = None,
    eps: float = 1e-12,
) -> Tensor:
    r"""Performs :math:`L_p` normalization of the input tensor along the given axis.

    For a tensor :attr:`inp` of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
    :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`axis` is transformed as:

    .. math::

        v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.

    :param inp: the input tensor
    :param p: the order of the norm used for normalization. Default: 2
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced
        to calculate the norm. Default: None
    :param eps: a small value to avoid division by zero. Default: 1e-12
    :return: the normalized output tensor
    """
    if axis is None:
        return inp / clamp(norm(inp, p, axis), lower=eps)
    else:
        return inp / clamp(norm(inp, p, axis, keepdims=True), lower=eps)

def argsort(inp: Tensor, descending: bool = False) -> Tensor:
    r"""
    Returns the indices that would sort the input tensor; for a 2d input, each row is sorted independently.

    :param inp: The input tensor, if 2d, each row will be sorted
    :param descending: Sort in descending order, where the largest comes first. Default: ``False``
    :return: the indices of int32 type that sort ``inp``

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.array([1,2], dtype=np.float32))
        indices = F.argsort(data)
        print(indices.numpy())

    Outputs:

    .. testoutput::

        [0 1]

    """
    assert len(inp.shape) <= 2, "Input should be 1d or 2d"
    if descending:
        order = P.Argsort.Order.DESCENDING
    else:
        order = P.Argsort.Order.ASCENDING
    op = builtin.Argsort(order=order)
    if len(inp.shape) == 1:
        # sort a 1d input as a single row, then drop the leading dimension
        inp = inp.reshape(1, -1)
        _, result = apply(op, inp)
        return result[0]
    _, result = apply(op, inp)
    return result

def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]:
    r"""
    Sorts the input tensor; for a 2d input, each row is sorted independently.

    :param inp: The input tensor, if 2d, each row will be sorted
    :param descending: Sort in descending order, where the largest comes first. Default: ``False``
    :return: Tuple of two tensors (sorted_tensor, indices_of_int32)
    """
    assert len(inp.shape) <= 2, "Input should be 1d or 2d"
    if descending:
        order = P.Argsort.Order.DESCENDING
    else:
        order = P.Argsort.Order.ASCENDING
    op = builtin.Argsort(order=order)
    if len(inp.shape) == 1:
        # sort a 1d input as a single row, then drop the leading dimension
        inp = inp.reshape(1, -1)
        tns, ind = apply(op, inp)
        return tns[0], ind[0]
    tns, ind = apply(op, inp)
    return tns, ind

def topk(
    inp: Tensor,
    k: int,
    descending: bool = False,
    kth_only: bool = False,
    no_sort: bool = False,
) -> Tuple[Tensor, Tensor]:
    r"""
    Selects the Top-K (by default, the K smallest) elements of a 2d matrix by row.

    :param inp: The input tensor, if 2d, each row will be sorted
    :param k: The number of elements needed
    :param descending: If true, return the largest elements instead. Default: ``False``
    :param kth_only: If true, only the k-th element will be returned. Default: ``False``
    :param no_sort: If true, the returned elements can be unordered. Default: ``False``
    :return: Tuple of two tensors (topk_tensor, indices_of_int32)

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32))
        top, indices = F.topk(data, 5)
        print(top.numpy(), indices.numpy())

    Outputs:

    .. testoutput::

        [1. 2. 3. 4. 5.] [7 0 6 1 5]

    """
    if descending:
        # negate so that the k largest elements become the k smallest
        inp = -inp

    Mode = P.TopK.Mode
    if kth_only:
        mode = Mode.KTH_ONLY
    elif no_sort:
        mode = Mode.VALUE_IDX_NOSORT
    else:
        mode = Mode.VALUE_IDX_SORTED
    op = builtin.TopK(mode=mode)

    ind = None  # in ``kth_only`` mode the operator produces values only
    if len(inp.shape) == 1:
        inp = inp.reshape(1, -1)
        res = apply(op, inp, Tensor(k, dtype="int32"))
        if kth_only:
            tns = res[0]
        else:
            tns, ind = res[0][0], res[1][0]
    else:
        res = apply(op, inp, Tensor(k, dtype="int32"))
        if kth_only:
            tns = res[0]
        else:
            tns, ind = res[0], res[1]
    if descending:
        tns = -tns
    return tns, ind
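
As a quick end-to-end illustration of how the reductions above compose, here is a minimal usage sketch. It only reuses functions defined in this file together with the import style from the docstring examples; the values noted in the comments are indicative rather than verified output.

import numpy as np
from megengine import tensor
import megengine.functional as F

x = tensor(np.arange(-3, 3, dtype=np.float32).reshape(2, 3))

# L2 norm over all elements, as in the norm() docstring: sqrt(9+4+1+0+1+4) = 4.3589
print(F.norm(x).numpy())

# Row-wise L2 normalization: each row is divided by max(||row||_2, eps),
# so the per-row norms afterwards should be approximately [1. 1.]
y = F.normalize(x, p=2, axis=1)
print(F.norm(y, axis=1).numpy())

# Top-3 smallest elements of each row and their indices
data = tensor(np.array([[2, 4, 6], [8, 7, 5]], dtype=np.float32))
top, idx = F.topk(data, 3)
print(top.numpy(), idx.numpy())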

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose from. If you want to run GPU programs, make sure the machine has a GPU device and that the driver is installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
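
A minimal sketch of checking whether the bundled CUDA environment can actually see a GPU before running GPU code; it assumes the top-level helper megengine.is_cuda_available(), which is not part of the file shown above.

import megengine as mge

# Assumed helper: is_cuda_available() reports whether a usable CUDA device
# (GPU plus driver) was detected; otherwise fall back to CPU execution.
if mge.is_cuda_available():
    print("CUDA device detected; GPU execution is possible")
else:
    print("No usable CUDA device; running on CPU")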