
math.py 9.4 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
import numbers
from typing import Optional, Sequence, Union

import megengine._internal as mgb

from ..core import Tensor, wrap_io_tensor
from .elemwise import clamp, exp, isinf, log
from .tensor import remove_axis, where, zeros_like


@wrap_io_tensor
def sum(inp: Tensor, axis: Optional[int] = None, keepdims: bool = False) -> Tensor:
    r"""Returns the sum of each row of the ``inp`` tensor in the given ``axis``.

    :param inp: The input tensor.
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced.
        Default: None
    :param keepdims: Whether the output tensor has ``axis`` retained or not.
        Default: False
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        out = F.sum(data)
        print(out.numpy())

    .. testoutput::

        [21]

    """
    return mgb.opr.reduce_(inp, "SUM", axis, keepdims)


@wrap_io_tensor
def prod(inp: Tensor, axis: Optional[int] = None, keepdims=False) -> Tensor:
    r"""
    Returns the element product of the input tensor along the given *axis*.

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: ``None``
    :param keepdims: Whether the output tensor has *axis* retained or not. Default: ``False``
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        out = F.prod(data)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [720]

    """
    return mgb.opr.reduce_(inp, "PRODUCT", axis, keepdims)


@wrap_io_tensor
def mean(inp: Tensor, axis: Optional[int] = None, keepdims: bool = False) -> Tensor:
    """Returns the mean value of each row of the ``inp`` tensor in
    the given ``axis``. If axis is a list of dimensions,
    reduce over all of them.

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has ``axis`` retained or not. Default: False

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        out = F.mean(data)
        print(out.numpy())

    .. testoutput::

        [3.5]

    """
    return mgb.opr.mean(inp, axis, keepdims)


@wrap_io_tensor
def min(inp: Tensor, axis: Optional[int] = None, keepdims: bool = False) -> Tensor:
    r"""
    Returns the min value of the input tensor along the given *axis*.

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has *axis* retained or not. Default: False
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        y = F.min(x)
        print(y.numpy())

    Outputs:

    .. testoutput::

        [1]

    """
    return mgb.opr.reduce_(inp, "MIN", axis, keepdims)


@wrap_io_tensor
def max(inp: Tensor, axis: Optional[int] = None, keepdims: bool = False) -> Tensor:
    r"""Returns the max value of the input tensor along the given *axis*.

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has *axis* retained or not. Default: False
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        y = F.max(x)
        print(y.numpy())

    .. testoutput::

        [6]

    """
    return mgb.opr.reduce_(inp, "MAX", axis, keepdims)


@wrap_io_tensor
def sqrt(inp: Tensor) -> Tensor:
    """
    Returns a new tensor with the square root of the elements of ``inp``.

    :param inp: The input tensor
    :return: The computed tensor

    Examples:

    .. testcode::

        import numpy as np
        import megengine as mge
        import megengine.functional as F

        data = mge.tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
        out = F.sqrt(data)
        print(out.numpy())

    Outputs:

    .. testoutput::

        [[0.     1.     1.4142]
         [1.7321 2.     2.2361]]

    """
    return mgb.opr.sqrt(inp)


def norm(inp: Tensor, p: int = 2, axis: Optional[int] = None, keepdims=False):
    """Calculates the ``p``-norm of the input tensor along the given axis.

    :param inp: The input tensor
    :param p: the order of the norm. Default: 2
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has ``axis`` retained or not. Default: False
    :return: The output tensor
    """
    if axis is None:
        inp = inp.reshape(-1)
    return (inp ** p).sum(axis=axis, keepdims=keepdims) ** (1.0 / p)
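

# A minimal usage sketch for ``norm`` (added for illustration; the helper name
# ``_norm_usage_example`` is hypothetical and the value in the comment is computed
# by hand rather than taken from MegEngine output).
def _norm_usage_example():
    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    x = tensor(np.arange(1, 7, dtype=np.float32))
    # With axis=None the whole tensor is flattened, so this is the L2 norm of
    # [1, 2, 3, 4, 5, 6]: sqrt(1 + 4 + 9 + 16 + 25 + 36) = sqrt(91) ≈ 9.539
    return F.norm(x, p=2)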


@wrap_io_tensor
def argmin(inp: Tensor, axis: Optional[int] = None, keepdims: bool = False) -> Tensor:
    r"""Returns the indices of the minimum values along an axis.

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has *axis* retained or not. Default: False
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        y = F.argmin(x)
        print(y.numpy())

    .. testoutput::

        [0]

    """
    return mgb.opr.argmin(inp, axis, keepdims)


@wrap_io_tensor
def argmax(inp: Tensor, axis: Optional[int] = None, keepdims: bool = False) -> Tensor:
    r"""Returns the indices of the maximum values along an axis.

    :param inp: The input tensor
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None
    :param keepdims: Whether the output tensor has *axis* retained or not. Default: False
    :return: The output tensor

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
        y = F.argmax(x)
        print(y.numpy())

    .. testoutput::

        [5]

    """
    return mgb.opr.argmax(inp, axis, keepdims)


def normalize(
    inp: Tensor, p: int = 2, axis: Optional[int] = None, eps: float = 1e-12
) -> Tensor:
    r"""Performs :math:`L_p` normalization of the input tensor along the given axis.

    For a tensor :attr:`inp` of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
    :math:`n_{dim}`-element vector :math:`v` along dimension :attr:`axis` is transformed as:

    .. math::

        v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.

    :param inp: the input tensor
    :param p: the order of the norm. Default: 2
    :param axis: The dimension to reduce. If None, all the dimensions will be reduced
        to calculate the norm. Default: None
    :param eps: a small value to avoid division by zero. Default: 1e-12
    :return: the normalized output tensor
    """
    if axis is None:
        return inp / clamp(norm(inp, p), lower=eps)
    else:
        return inp / clamp(norm(inp, p, axis, keepdims=True), lower=eps)
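

# A minimal usage sketch for ``normalize`` (added for illustration; the helper name
# ``_normalize_usage_example`` is hypothetical). By the formula above, every vector
# along ``axis`` is divided by its own L2 norm, so each row comes out with unit norm.
def _normalize_usage_example():
    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    x = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
    # Row-wise normalization: row [1, 2, 3] is divided by sqrt(14),
    # row [4, 5, 6] by sqrt(77).
    return F.normalize(x, p=2, axis=1)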


def logsumexp(inp: Tensor, axis: Union[int, Sequence[int]], keepdims: bool = False):
    r"""
    Computes the log of the sum of exponentials of the inputs along the given
    :attr:`axis`. The computation is numerically stabilized.

    .. math::

        \mathsf{logsumexp}(x_1, \dots, x_n) = \log(\exp(x_1) + \cdots + \exp(x_n))

    :param inp: The input tensor.
    :param axis: Axis over which the sum is taken. It can be a single axis or a list of axes.
    :param keepdims: whether to retain :attr:`axis` or not for the output tensor.
    """
    if isinstance(axis, numbers.Integral):
        axis = (axis,)

    # Subtract the per-axis maximum before exponentiating so that exp() cannot
    # overflow; infinite maxima are replaced by zero to keep the shift finite.
    max_value = inp
    for dim in axis:
        max_value = max_value.max(axis=dim, keepdims=True)
    max_value = where(
        isinf(max_value).astype("int32"), zeros_like(max_value), max_value
    )

    x = exp(inp - max_value)
    for dim in axis:
        x = x.sum(axis=dim, keepdims=True)
    x = max_value + log(x)

    # Drop the reduced axes (highest index first) if they should not be kept.
    if not keepdims:
        axis = sorted(axis, reverse=True)
        for i in axis:
            x = remove_axis(x, axis=i)
    return x
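

# A minimal usage sketch for ``logsumexp`` (added for illustration; the helper name
# ``_logsumexp_usage_example`` is hypothetical). It uses inputs large enough that a
# naive log(sum(exp(x))) would overflow in float32, which is exactly the case the
# max-subtraction above is meant to handle; the stabilized result is 1000 + log(3).
def _logsumexp_usage_example():
    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    x = tensor(np.array([[1000.0, 1000.0, 1000.0]], dtype=np.float32))
    return F.logsumexp(x, axis=1)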

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose from. To run GPU programs, make sure the machine has a GPU installed along with a working driver. If you would like to try deep learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
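
As a quick sanity check, the snippet below prints whether MegEngine can see a usable GPU; this is a minimal sketch that assumes the installed MegEngine version exposes megengine.is_cuda_available().

    import megengine as mge

    # True means a CUDA-capable GPU and its driver were detected; otherwise
    # MegEngine falls back to running on the CPU.
    print(mge.is_cuda_available())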