
math.py 18 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections.abc
import functools
import math
import numbers
from typing import Optional, Sequence, Tuple, Union

from ..core.ops import builtin
from ..core.ops._internal import param_defs as P
from ..core.tensor import utils
from ..core.tensor.core import apply
from ..tensor import Tensor
from .elemwise import clamp, exp, log, log1p
from .tensor import remove_axis, reshape

__all__ = [
    "all",  # TODO
    "all_close",  # TODO
    "any",  # TODO
    "argmax",
    "argmin",
    "argsort",
    "isinf",
    "isnan",  # TODO
    "max",
    "mean",
    "median",  # TODO
    "min",
    "norm",
    "normalize",
    "prod",
    "sign",  # TODO
    "sort",
    "std",
    "sum",
    "topk",
    "unique",  # TODO
    "var",
]


def all(inp):
    raise NotImplementedError


def all_close(inp):
    raise NotImplementedError


def any(inp):
    raise NotImplementedError


def unique(inp):
    raise NotImplementedError


def isnan(inp: Tensor) -> Tensor:
    r"""Returns a new tensor representing if each element is NaN or not.

    :param inp: the input tensor.
    :return: a new tensor representing if each element in :attr:`inp` is NaN or not.

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        x = tensor([1, float("nan"), 0])
        print(F.isnan(x))

    .. testoutput::

        Tensor([0 1 0], dtype=uint8)

    """
    raise NotImplementedError
    # return (inp != inp).astype("uint8")


def isinf(inp: Tensor) -> Tensor:
    r"""Returns a new tensor representing if each element is Inf or not.

    :param inp: the input tensor.
    :return: a new tensor representing if each element in :attr:`inp` is Inf or not.

    Examples:

    .. testcode::

        from megengine import tensor
        import megengine.functional as F

        x = tensor([1, float("inf"), 0])
        print(F.isinf(x))

    .. testoutput::

        Tensor([0 1 0], dtype=uint8)

    """
    return (abs(inp).astype("float32") == float("inf")).astype("uint8")


def sign(inp: Tensor):
    raise NotImplementedError


def _reduce(
    data,
    *,
    mode,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False
):
    (data,) = utils.convert_inputs(data)
    if axis is None:
        # reduce over all elements: flatten first, then reduce along axis 0
        data = data.reshape(-1)
        assert not keepdims, "can not set axis=None and keepdims=True"
        op = builtin.Reduce(mode=mode, axis=0)
        (result,) = apply(op, data)
    elif isinstance(axis, collections.abc.Iterable):
        # reduce over multiple axes one by one, from the innermost outwards
        axis = list(axis)
        axis.sort(reverse=True)
        for ai in axis:
            op = builtin.Reduce(mode=mode, axis=ai)
            (data,) = apply(op, data)
            if not keepdims:
                data = remove_axis(data, ai)
        result = data
    else:
        op = builtin.Reduce(mode=mode, axis=axis)
        (result,) = apply(op, data)
        if not keepdims:
            result = remove_axis(result, axis)
    return result


def sum(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    r"""Returns the sum of the ``inp`` tensor along the given ``axis``.

    :param inp: The input tensor.
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced.
        Default: None
    :param keepdims: Whether the output tensor has ``axis`` retained or not.
        Default: False
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        out = F.sum(data)
        print(out.numpy())

    .. testoutput::

        [21]

    """
    return _reduce(inp, mode="SUM", axis=axis, keepdims=keepdims)


def prod(
    inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None, keepdims=False
) -> Tensor:
    r"""
    Returns the product of the elements of the input tensor along the given *axis*.

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: ``None``
    :param keepdims: Whether the output tensor has *axis* retained or not. Default: ``False``
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        out = F.prod(data)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [720]

    """
    return _reduce(inp, mode="PRODUCT", axis=axis, keepdims=keepdims)


def mean(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    """Returns the mean value of the ``inp`` tensor along
    the given ``axis``. If axis is a list of dimensions,
    reduce over all of them.

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has ``axis`` retained or not. Default: False

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        out = F.mean(data)
        print(out.numpy())

    .. testoutput::

        [3.5]

    """
    return _reduce(inp, mode="MEAN", axis=axis, keepdims=keepdims)


def median(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    raise NotImplementedError


def var(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    """Returns the variance of the input tensor along
    the given ``axis``. If axis is a list of dimensions,
    reduce over all of them.

    :param inp: The input tensor.
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: ``None``.
    :param keepdims: Whether the output tensor has ``axis`` retained or not. Default: ``False``.
    :return: The output tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
        out = F.var(data)
        print(out.numpy())

    .. testoutput::

        [2.9166667]

    """
    if axis is None:
        m = mean(inp, axis=axis, keepdims=False)
    else:
        m = mean(inp, axis=axis, keepdims=True)
    v = inp - m
    return mean(v ** 2, axis=axis, keepdims=keepdims)


def std(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    """Returns the standard deviation of the input tensor along
    the given ``axis``. If axis is a list of dimensions,
    reduce over all of them.

    :param inp: The input tensor.
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: ``None``.
    :param keepdims: Whether the output tensor has ``axis`` retained or not. Default: ``False``.
    :return: The output tensor.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
        out = F.std(data, axis=1)
        print(out.numpy())

    .. testoutput::

        [0.8164966 0.8164966]

    """
    return var(inp, axis=axis, keepdims=keepdims) ** 0.5


def min(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    r"""
    Returns the min value of the input tensor along the given *axis*.

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has *axis* retained or not. Default: False
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        y = F.min(x)
        print(y.numpy())

    Outputs:

    .. testoutput::

        [1]

    """
    return _reduce(inp, mode="MIN", axis=axis, keepdims=keepdims)


def max(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    r"""Returns the max value of the input tensor along the given *axis*.

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has *axis* retained or not. Default: False
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        y = F.max(x)
        print(y.numpy())

    .. testoutput::

        [6]

    """
    return _reduce(inp, mode="MAX", axis=axis, keepdims=keepdims)


def norm(
    inp: Tensor,
    p: int = 2,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims=False,
):
    """Calculates the ``p``-norm of the input tensor along the given axis.

    :param inp: The input tensor
    :param p: the order of the norm to compute. Default: 2
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has ``axis`` retained or not. Default: False
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(-3, 3, dtype=np.float32).reshape(2, 3))
        y = F.norm(x)
        print(y.numpy())

    .. testoutput::

        [4.358899]

    """
    if p == 0:
        return sum(inp != 0, axis=axis, keepdims=keepdims)
    if p == math.inf:
        return max(abs(inp))
    if p == -math.inf:
        return min(abs(inp))
    return sum(abs(inp) ** p, axis=axis, keepdims=keepdims) ** (1.0 / p)


def argmin(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    r"""Returns the indices of the minimum values along an axis.

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has *axis* retained or not. Default: False
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        y = F.argmin(x)
        print(y.numpy())

    .. testoutput::

        [0]

    """
    if isinstance(axis, collections.abc.Iterable):
        # reduce over multiple axes one by one, from the innermost outwards
        axis = list(axis)
        axis.sort(reverse=True)
        for ai in axis:
            op = builtin.Argmin(axis=ai)
            (inp,) = apply(op, inp)
            if not keepdims:
                inp = remove_axis(inp, ai)
        return inp

    if axis is None:
        assert not keepdims, "can not set axis=None and keepdims=True"
        inp = inp.flatten()
        axis = 0

    op = builtin.Argmin(axis=axis)
    (result,) = apply(op, inp)
    if not keepdims:
        result = remove_axis(result, axis)
    return result


def argmax(
    inp: Tensor,
    axis: Optional[Union[int, Sequence[int]]] = None,
    keepdims: bool = False,
) -> Tensor:
    r"""Returns the indices of the maximum values along an axis.

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has *axis* retained or not. Default: False
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        y = F.argmax(x)
        print(y.numpy())

    .. testoutput::

        [5]

    """
    if isinstance(axis, collections.abc.Iterable):
        # reduce over multiple axes one by one, from the innermost outwards
        axis = list(axis)
        axis.sort(reverse=True)
        for ai in axis:
            op = builtin.Argmax(axis=ai)
            (inp,) = apply(op, inp)
            if not keepdims:
                inp = remove_axis(inp, ai)
        return inp

    if axis is None:
        assert not keepdims, "can not set axis=None and keepdims=True"
        inp = inp.flatten()
        axis = 0

    op = builtin.Argmax(axis=axis)
    (result,) = apply(op, inp)
    if not keepdims:
        result = remove_axis(result, axis)
    return result


def normalize(
    inp: Tensor,
    p: int = 2,
    axis: Optional[Union[int, Sequence[int]]] = None,
    eps: float = 1e-12,
) -> Tensor:
    r"""Performs :math:`L_p` normalization of the input tensor along the given axis.

    For a tensor :attr:`inp` of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
    :math:`n_{dim}`-element vector :math:`v` along dimension :attr:`axis` is transformed as:

    .. math::

        v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.

    :param inp: the input tensor
    :param p: the order of the norm used for normalization. Default: 2
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced
        to calculate the norm. Default: None
    :param eps: a small value to avoid division by zero. Default: 1e-12
    :return: the normalized output tensor
    """
    if axis is None:
        return inp / clamp(norm(inp, p, axis), lower=eps)
    else:
        return inp / clamp(norm(inp, p, axis, keepdims=True), lower=eps)


def argsort(inp: Tensor, descending: bool = False) -> Tensor:
    r"""
    Sort each row of a 2d matrix (or a 1d vector) and return the indices that would sort it.

    :param inp: The input tensor; if 2d, each row will be sorted
    :param descending: Sort in descending order, where the largest comes first. Default: ``False``
    :return: The indices tensor of int32

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.array([1, 2], dtype=np.float32))
        indices = F.argsort(data)
        print(indices.numpy())

    Outputs:

    .. testoutput::

        [0 1]

    """
    assert len(inp.shape) <= 2, "Input should be 1d or 2d"
    if descending:
        order = P.Argsort.Order.DESCENDING
    else:
        order = P.Argsort.Order.ASCENDING

    op = builtin.Argsort(order=order)
    if len(inp.shape) == 1:
        inp = inp.reshape(1, -1)
        _, result = apply(op, inp)
        return result[0]
    _, result = apply(op, inp)
    return result


def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]:
    r"""
    Sort the target 2d matrix (or 1d vector) by row, returning both the sorted tensor and the indices.

    :param inp: The input tensor; if 2d, each row will be sorted
    :param descending: Sort in descending order, where the largest comes first. Default: ``False``
    :return: Tuple of two tensors (sorted_tensor, indices_of_int32)
    """
    assert len(inp.shape) <= 2, "Input should be 1d or 2d"
    if descending:
        order = P.Argsort.Order.DESCENDING
    else:
        order = P.Argsort.Order.ASCENDING

    op = builtin.Argsort(order=order)
    if len(inp.shape) == 1:
        inp = inp.reshape(1, -1)
        tns, ind = apply(op, inp)
        return tns[0], ind[0]
    tns, ind = apply(op, inp)
    return tns, ind


def topk(
    inp: Tensor,
    k: int,
    descending: bool = False,
    kth_only: bool = False,
    no_sort: bool = False,
) -> Tuple[Tensor, Tensor]:
    r"""
    Selects the Top-K (by default the smallest) elements of a 2d matrix by row.

    :param inp: The input tensor; if 2d, each row will be processed
    :param k: The number of elements needed
    :param descending: If true, return the largest elements instead. Default: ``False``
    :param kth_only: If true, only the k-th element will be returned. Default: ``False``
    :param no_sort: If true, the returned elements can be unordered. Default: ``False``
    :return: Tuple of two tensors (topk_tensor, indices_of_int32)

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32))
        top, indices = F.topk(data, 5)
        print(top.numpy(), indices.numpy())

    Outputs:

    .. testoutput::

        [1. 2. 3. 4. 5.] [7 0 6 1 5]

    """
    if descending:
        inp = -inp

    Mode = P.TopK.Mode
    if kth_only:
        mode = Mode.KTH_ONLY
    elif no_sort:
        mode = Mode.VALUE_IDX_NOSORT
    else:
        mode = Mode.VALUE_IDX_SORTED
    op = builtin.TopK(mode=mode)

    if len(inp.shape) == 1:
        inp = inp.reshape(1, -1)
        res = apply(op, inp, Tensor(k, dtype="int32"))
        if kth_only:
            # KTH_ONLY mode produces only values, no indices
            tns, ind = res[0], None
        else:
            tns, ind = res[0][0], res[1][0]
    else:
        res = apply(op, inp, Tensor(k, dtype="int32"))
        if kth_only:
            # KTH_ONLY mode produces only values, no indices
            tns, ind = res, None
        else:
            tns, ind = res[0], res[1]

    if descending:
        tns = -tns
    return tns, ind

The MegEngine package ships with the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. If you want to run GPU programs, make sure the machine has a GPU device and that the driver is properly installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
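As a quick sanity check after installation, something like the following can be used (a minimal sketch; it assumes megengine.is_cuda_available() is present in the installed version and reuses the reduction ops from math.py above):

import numpy as np

import megengine as mge
import megengine.functional as F
from megengine import tensor

# Check whether the bundled CUDA environment can find a usable GPU.
if mge.is_cuda_available():
    print("CUDA device detected; GPU execution is available.")
else:
    print("No CUDA device detected; running on CPU.")

# Small smoke test with the reduction ops defined in math.py.
x = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
print(F.sum(x).numpy())   # sum over all elements
print(F.mean(x).numpy())  # mean over all elements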