You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

elemwise.py 5.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order
  10. import functools
  11. import megengine._internal as mgb
  12. from ..core.graph import _use_default_if_none
  13. from ..core.tensor import Tensor, wrap_io_tensor
  14. __all__ = [
  15. "abs",
  16. "arccos",
  17. "add",
  18. "arcsin",
  19. "ceil",
  20. "cos",
  21. "divide",
  22. "equal",
  23. "exp",
  24. "greater",
  25. "greater_equal",
  26. "floor",
  27. "less",
  28. "less_equal",
  29. "log",
  30. "maximum",
  31. "minimum",
  32. "mod",
  33. "multiply",
  34. "power",
  35. "relu",
  36. "round",
  37. "sigmoid",
  38. "sin",
  39. "subtract",
  40. "tanh",
  41. ]
  42. def _elemwise(mode): # DONT export
  43. """Decorator helps to wrap megbrain element-wise oprs"""
  44. def elemwise_decorator(func):
  45. @functools.wraps(func)
  46. @wrap_io_tensor
  47. def elemwise_func(*inputs) -> Tensor:
  48. if all(isinstance(i, (int, float)) for i in inputs):
  49. device, comp_graph = _use_default_if_none(None, None)
  50. ret = mgb.opr.elemwise(
  51. *inputs, mode=mode, comp_node=device, comp_graph=comp_graph
  52. )
  53. return ret.inferred_value[0]
  54. return mgb.opr.elemwise(*inputs, mode=mode)
  55. return elemwise_func
  56. return elemwise_decorator
  57. @_elemwise("ABS")
  58. def abs(x):
  59. """Calculate the absolute value element-wise."""
  60. @_elemwise("ACOS")
  61. def arccos(x):
  62. """Inverse cosine, element-wise."""
  63. @_elemwise("ADD")
  64. def add(x, y):
  65. """Element-wise addition."""
  66. @_elemwise("ASIN")
  67. def arcsin(x):
  68. """Inverse sine, element-wise."""
  69. @_elemwise("CEIL")
  70. def ceil(x):
  71. """Return the ceil of the input, element-wise."""
  72. @_elemwise("COS")
  73. def cos(x):
  74. """Cosine, element-wise."""
  75. @_elemwise("TRUE_DIV")
  76. def divide(x, y):
  77. """Return (x / y) element-wise."""
  78. @_elemwise("EQ")
  79. def equal(x, y):
  80. """Return (x == y) element-wise."""
  81. @_elemwise("EXP")
  82. def exp(x):
  83. """Calculate the exponential element-wise"""
  84. @_elemwise("FLOOR")
  85. def floor(x):
  86. """Return the floor of the input, element-wise"""
  87. def greater(x, y):
  88. """Return (x > y) element-wise."""
  89. return less(y, x)
  90. def greater_equal(x, y):
  91. """Return (x >= y) element-wise"""
  92. return less_equal(y, x)
  93. @_elemwise("LT")
  94. def less(x, y):
  95. """Return (x < y) element-wise."""
  96. @_elemwise("LEQ")
  97. def less_equal(x, y):
  98. """Return (x =< y) element-wise."""
  99. @_elemwise("LOG")
  100. def log(x):
  101. """Natural logarithm (base `e`), element-wise."""
  102. @_elemwise("MAX")
  103. def maximum(x, y):
  104. """Element-wise maximum of array elements."""
  105. @_elemwise("MIN")
  106. def minimum(x, y):
  107. """Element-wise minimum of array elements."""
  108. @_elemwise("MOD")
  109. def mod(x, y):
  110. """Return element-wise remainder of division."""
  111. @_elemwise("MUL")
  112. def multiply(x, y):
  113. """Element-wise multiplication."""
  114. @_elemwise("POW")
  115. def power(x, y):
  116. """First tensor elements raised to powers from second tensor (x ** y), element-wise."""
  117. @_elemwise("RELU")
  118. def relu(x):
  119. """Return `max(x, 0)` element-wise."""
  120. @_elemwise("ROUND")
  121. def round(x):
  122. """Round tensor to int element-wise."""
  123. @_elemwise("SIGMOID")
  124. def sigmoid(x):
  125. """Return 1 / ( 1 + exp( -x ) ) element-wise."""
  126. @_elemwise("SIN")
  127. def sin(x):
  128. """Sine, element-wise."""
  129. @_elemwise("SUB")
  130. def subtract(x, y):
  131. """Subtract arguments element-wise"""
  132. @_elemwise("TANH")
  133. def tanh(x):
  134. """Compute hyperbolic tangent element-wise."""
  135. @wrap_io_tensor
  136. def clamp(inp: Tensor, lower=None, upper=None) -> Tensor:
  137. r"""
  138. Clamp all elements in :attr:`inp` into the range `[` :attr:`lower`, :attr:`upper` `]` and return
  139. a resulting tensor:
  140. .. math::
  141. y_i = \begin{cases}
  142. \text{lower} & \text{if } x_i < \text{lower} \\
  143. x_i & \text{if } \text{lower} \leq x_i \leq \text{upper} \\
  144. \text{upper} & \text{if } x_i > \text{upper}
  145. \end{cases}
  146. :param inp: the input tensor.
  147. :param lower: lower-bound of the range to be clamped to
  148. :param upper: upper-bound of the range to be clamped to
  149. Example:
  150. .. testcode::
  151. import numpy as np
  152. from megengine import tensor
  153. import megengine.functional as F
  154. a = tensor(np.arange(5).astype(np.int32))
  155. print(F.clamp(a, 2, 4).numpy())
  156. print(F.clamp(a, lower=3).numpy())
  157. print(F.clamp(a, upper=3).numpy())
  158. .. testoutput::
  159. [2 2 2 3 4]
  160. [3 3 3 3 4]
  161. [0 1 2 3 3]
  162. """
  163. assert lower or upper, "At least one of 'lower' or 'upper' must not be None"
  164. if lower:
  165. if upper:
  166. assert lower <= upper, "clamp lower bound is bigger that upper bound"
  167. return minimum(maximum(inp, lower), upper)
  168. else:
  169. return maximum(inp, lower)
  170. else:
  171. return minimum(inp, upper)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台