You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

elemwise.py 6.3 kB

(line-number gutter: lines 1–299)
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order
  10. import functools
  11. import megengine._internal as mgb
  12. from ..core.graph import _use_default_if_none
  13. from ..core.tensor import Tensor, wrap_io_tensor
  14. __all__ = [
  15. "abs",
  16. "arccos",
  17. "add",
  18. "arcsin",
  19. "clamp",
  20. "ceil",
  21. "cos",
  22. "divide",
  23. "equal",
  24. "exp",
  25. "greater",
  26. "greater_equal",
  27. "floor",
  28. "isinf",
  29. "isnan",
  30. "less",
  31. "less_equal",
  32. "log",
  33. "maximum",
  34. "minimum",
  35. "mod",
  36. "multiply",
  37. "power",
  38. "relu",
  39. "round",
  40. "sigmoid",
  41. "sin",
  42. "subtract",
  43. "tanh",
  44. ]
  45. def _elemwise(mode): # DONT export
  46. """Decorator helps to wrap megbrain element-wise oprs"""
  47. def elemwise_decorator(func):
  48. @functools.wraps(func)
  49. @wrap_io_tensor
  50. def elemwise_func(*inputs) -> Tensor:
  51. if all(isinstance(i, (int, float)) for i in inputs):
  52. device, comp_graph = _use_default_if_none(None, None)
  53. ret = mgb.opr.elemwise(
  54. *inputs, mode=mode, comp_node=device, comp_graph=comp_graph
  55. )
  56. return ret.inferred_value[0]
  57. return mgb.opr.elemwise(*inputs, mode=mode)
  58. return elemwise_func
  59. return elemwise_decorator
  60. @_elemwise("ABS")
  61. def abs(x):
  62. """Calculate the absolute value element-wise."""
  63. @_elemwise("ACOS")
  64. def arccos(x):
  65. """Inverse cosine, element-wise."""
  66. @_elemwise("ADD")
  67. def add(x, y):
  68. """Element-wise addition."""
  69. @_elemwise("ASIN")
  70. def arcsin(x):
  71. """Inverse sine, element-wise."""
  72. @_elemwise("CEIL")
  73. def ceil(x):
  74. """Return the ceil of the input, element-wise."""
  75. @_elemwise("COS")
  76. def cos(x):
  77. """Cosine, element-wise."""
  78. @_elemwise("TRUE_DIV")
  79. def divide(x, y):
  80. """Return (x / y) element-wise."""
  81. @_elemwise("EQ")
  82. def equal(x, y):
  83. """Return (x == y) element-wise."""
  84. @_elemwise("EXP")
  85. def exp(x):
  86. """Calculate the exponential element-wise"""
  87. @_elemwise("FLOOR")
  88. def floor(x):
  89. """Return the floor of the input, element-wise"""
  90. def greater(x, y):
  91. """Return (x > y) element-wise."""
  92. return less(y, x)
  93. def greater_equal(x, y):
  94. """Return (x >= y) element-wise"""
  95. return less_equal(y, x)
  96. @_elemwise("LT")
  97. def less(x, y):
  98. """Return (x < y) element-wise."""
  99. @_elemwise("LEQ")
  100. def less_equal(x, y):
  101. """Return (x =< y) element-wise."""
  102. @_elemwise("LOG")
  103. def log(x):
  104. """Natural logarithm (base `e`), element-wise."""
  105. @_elemwise("MAX")
  106. def maximum(x, y):
  107. """Element-wise maximum of array elements."""
  108. @_elemwise("MIN")
  109. def minimum(x, y):
  110. """Element-wise minimum of array elements."""
  111. @_elemwise("MOD")
  112. def mod(x, y):
  113. """Return element-wise remainder of division."""
  114. @_elemwise("MUL")
  115. def multiply(x, y):
  116. """Element-wise multiplication."""
  117. @_elemwise("POW")
  118. def power(x, y):
  119. """First tensor elements raised to powers from second tensor (x ** y), element-wise."""
  120. @_elemwise("RELU")
  121. def relu(x):
  122. """Return `max(x, 0)` element-wise."""
  123. @_elemwise("ROUND")
  124. def round(x):
  125. """Round tensor to int element-wise."""
  126. @_elemwise("SIGMOID")
  127. def sigmoid(x):
  128. """Return 1 / ( 1 + exp( -x ) ) element-wise."""
  129. @_elemwise("SIN")
  130. def sin(x):
  131. """Sine, element-wise."""
  132. @_elemwise("SUB")
  133. def subtract(x, y):
  134. """Subtract arguments element-wise"""
  135. @_elemwise("TANH")
  136. def tanh(x):
  137. """Compute hyperbolic tangent element-wise."""
  138. @wrap_io_tensor
  139. def clamp(inp: Tensor, lower=None, upper=None) -> Tensor:
  140. r"""
  141. Clamp all elements in :attr:`inp` into the range `[` :attr:`lower`, :attr:`upper` `]` and return
  142. a resulting tensor:
  143. .. math::
  144. y_i = \begin{cases}
  145. \text{lower} & \text{if } x_i < \text{lower} \\
  146. x_i & \text{if } \text{lower} \leq x_i \leq \text{upper} \\
  147. \text{upper} & \text{if } x_i > \text{upper}
  148. \end{cases}
  149. :param inp: the input tensor.
  150. :param lower: lower-bound of the range to be clamped to
  151. :param upper: upper-bound of the range to be clamped to
  152. Example:
  153. .. testcode::
  154. import numpy as np
  155. from megengine import tensor
  156. import megengine.functional as F
  157. a = tensor(np.arange(5).astype(np.int32))
  158. print(F.clamp(a, 2, 4).numpy())
  159. print(F.clamp(a, lower=3).numpy())
  160. print(F.clamp(a, upper=3).numpy())
  161. .. testoutput::
  162. [2 2 2 3 4]
  163. [3 3 3 3 4]
  164. [0 1 2 3 3]
  165. """
  166. assert (
  167. lower is not None or upper is not None
  168. ), "At least one of 'lower' or 'upper' must not be None"
  169. if lower is not None:
  170. if upper is not None:
  171. assert lower <= upper, "clamp lower bound is bigger that upper bound"
  172. return minimum(maximum(inp, lower), upper)
  173. else:
  174. return maximum(inp, lower)
  175. else:
  176. return minimum(inp, upper)
  177. def isnan(inp: Tensor) -> Tensor:
  178. r"""Returns a new tensor representing if each element is NaN or not.
  179. :param: inp
  180. :return: a new tensor representing if each element in :attr:`inp` is NaN or not.
  181. Examples:
  182. .. testcode::
  183. from megengine import tensor
  184. import megengine.functional as F
  185. x = tensor([1, float("nan"), 0])
  186. print(F.isnan(x))
  187. .. testoutput::
  188. Tensor([0 1 0], dtype=uint8)
  189. """
  190. return (inp != inp).astype("uint8")
  191. def isinf(inp: Tensor) -> Tensor:
  192. r"""Returns a new tensor representing if each element is Inf or not.
  193. :param: inp
  194. :return: a new tensor representing if each element in :attr:`inp` is Inf or not.
  195. Examples:
  196. .. testcode::
  197. from megengine import tensor
  198. import megengine.functional as F
  199. x = tensor([1, float("inf"), 0])
  200. print(F.isinf(x))
  201. .. testoutput::
  202. Tensor([0 1 0], dtype=uint8)
  203. """
  204. return (abs(inp) == float("inf")).astype("uint8")

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。