You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

elemwise.py 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order
  10. import functools
  11. import megengine._internal as mgb
  12. from ..core.graph import _use_default_if_none
  13. from ..core.tensor import Tensor, wrap_io_tensor
  14. __all__ = [
  15. "abs",
  16. "arccos",
  17. "add",
  18. "arcsin",
  19. "ceil",
  20. "cos",
  21. "divide",
  22. "equal",
  23. "exp",
  24. "greater",
  25. "greater_equal",
  26. "floor",
  27. "isinf",
  28. "isnan",
  29. "less",
  30. "less_equal",
  31. "log",
  32. "maximum",
  33. "minimum",
  34. "mod",
  35. "multiply",
  36. "power",
  37. "relu",
  38. "round",
  39. "sigmoid",
  40. "sin",
  41. "subtract",
  42. "tanh",
  43. ]
  44. def _elemwise(mode): # DONT export
  45. """Decorator helps to wrap megbrain element-wise oprs"""
  46. def elemwise_decorator(func):
  47. @functools.wraps(func)
  48. @wrap_io_tensor
  49. def elemwise_func(*inputs) -> Tensor:
  50. if all(isinstance(i, (int, float)) for i in inputs):
  51. device, comp_graph = _use_default_if_none(None, None)
  52. ret = mgb.opr.elemwise(
  53. *inputs, mode=mode, comp_node=device, comp_graph=comp_graph
  54. )
  55. return ret.inferred_value[0]
  56. return mgb.opr.elemwise(*inputs, mode=mode)
  57. return elemwise_func
  58. return elemwise_decorator
  59. @_elemwise("ABS")
  60. def abs(x):
  61. """Calculate the absolute value element-wise."""
  62. @_elemwise("ACOS")
  63. def arccos(x):
  64. """Inverse cosine, element-wise."""
  65. @_elemwise("ADD")
  66. def add(x, y):
  67. """Element-wise addition."""
  68. @_elemwise("ASIN")
  69. def arcsin(x):
  70. """Inverse sine, element-wise."""
  71. @_elemwise("CEIL")
  72. def ceil(x):
  73. """Return the ceil of the input, element-wise."""
  74. @_elemwise("COS")
  75. def cos(x):
  76. """Cosine, element-wise."""
  77. @_elemwise("TRUE_DIV")
  78. def divide(x, y):
  79. """Return (x / y) element-wise."""
  80. @_elemwise("EQ")
  81. def equal(x, y):
  82. """Return (x == y) element-wise."""
  83. @_elemwise("EXP")
  84. def exp(x):
  85. """Calculate the exponential element-wise"""
  86. @_elemwise("FLOOR")
  87. def floor(x):
  88. """Return the floor of the input, element-wise"""
  89. def greater(x, y):
  90. """Return (x > y) element-wise."""
  91. return less(y, x)
  92. def greater_equal(x, y):
  93. """Return (x >= y) element-wise"""
  94. return less_equal(y, x)
  95. @_elemwise("LT")
  96. def less(x, y):
  97. """Return (x < y) element-wise."""
  98. @_elemwise("LEQ")
  99. def less_equal(x, y):
  100. """Return (x =< y) element-wise."""
  101. @_elemwise("LOG")
  102. def log(x):
  103. """Natural logarithm (base `e`), element-wise."""
  104. @_elemwise("MAX")
  105. def maximum(x, y):
  106. """Element-wise maximum of array elements."""
  107. @_elemwise("MIN")
  108. def minimum(x, y):
  109. """Element-wise minimum of array elements."""
  110. @_elemwise("MOD")
  111. def mod(x, y):
  112. """Return element-wise remainder of division."""
  113. @_elemwise("MUL")
  114. def multiply(x, y):
  115. """Element-wise multiplication."""
  116. @_elemwise("POW")
  117. def power(x, y):
  118. """First tensor elements raised to powers from second tensor (x ** y), element-wise."""
  119. @_elemwise("RELU")
  120. def relu(x):
  121. """Return `max(x, 0)` element-wise."""
  122. @_elemwise("ROUND")
  123. def round(x):
  124. """Round tensor to int element-wise."""
  125. @_elemwise("SIGMOID")
  126. def sigmoid(x):
  127. """Return 1 / ( 1 + exp( -x ) ) element-wise."""
  128. @_elemwise("SIN")
  129. def sin(x):
  130. """Sine, element-wise."""
  131. @_elemwise("SUB")
  132. def subtract(x, y):
  133. """Subtract arguments element-wise"""
  134. @_elemwise("TANH")
  135. def tanh(x):
  136. """Compute hyperbolic tangent element-wise."""
  137. @wrap_io_tensor
  138. def clamp(inp: Tensor, lower=None, upper=None) -> Tensor:
  139. r"""
  140. Clamp all elements in :attr:`inp` into the range `[` :attr:`lower`, :attr:`upper` `]` and return
  141. a resulting tensor:
  142. .. math::
  143. y_i = \begin{cases}
  144. \text{lower} & \text{if } x_i < \text{lower} \\
  145. x_i & \text{if } \text{lower} \leq x_i \leq \text{upper} \\
  146. \text{upper} & \text{if } x_i > \text{upper}
  147. \end{cases}
  148. :param inp: the input tensor.
  149. :param lower: lower-bound of the range to be clamped to
  150. :param upper: upper-bound of the range to be clamped to
  151. Example:
  152. .. testcode::
  153. import numpy as np
  154. from megengine import tensor
  155. import megengine.functional as F
  156. a = tensor(np.arange(5).astype(np.int32))
  157. print(F.clamp(a, 2, 4).numpy())
  158. print(F.clamp(a, lower=3).numpy())
  159. print(F.clamp(a, upper=3).numpy())
  160. .. testoutput::
  161. [2 2 2 3 4]
  162. [3 3 3 3 4]
  163. [0 1 2 3 3]
  164. """
  165. assert (
  166. lower is not None or upper is not None
  167. ), "At least one of 'lower' or 'upper' must not be None"
  168. if lower is not None:
  169. if upper is not None:
  170. assert lower <= upper, "clamp lower bound is bigger that upper bound"
  171. return minimum(maximum(inp, lower), upper)
  172. else:
  173. return maximum(inp, lower)
  174. else:
  175. return minimum(inp, upper)
  176. def isnan(inp: Tensor) -> Tensor:
  177. r"""Returns a new tensor representing if each element is NaN or not.
  178. :param: inp
  179. :return: a new tensor representing if each element in :attr:`inp` is NaN or not.
  180. Examples:
  181. .. testcode::
  182. from megengine import tensor
  183. import megengine.functional as F
  184. x = tensor([1, float("nan"), 0])
  185. print(F.isnan(x))
  186. .. testoutput::
  187. Tensor([0 1 0], dtype=uint8)
  188. """
  189. return (inp != inp).astype("uint8")
  190. def isinf(inp: Tensor) -> Tensor:
  191. r"""Returns a new tensor representing if each element is Inf or not.
  192. :param: inp
  193. :return: a new tensor representing if each element in :attr:`inp` is Inf or not.
  194. Examples:
  195. .. testcode::
  196. from megengine import tensor
  197. import megengine.functional as F
  198. x = tensor([1, float("inf"), 0])
  199. print(F.isinf(x))
  200. .. testoutput::
  201. Tensor([0 1 0], dtype=uint8)
  202. """
  203. return (abs(inp) == float("inf")).astype("uint8")

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。