You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

elemwise.py 5.3 kB

(Line-number gutter: lines 1–244)
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order
  10. import functools
  11. import megengine._internal as mgb
  12. from ..core.graph import _use_default_if_none
  13. from ..core.tensor import Tensor, wrap_io_tensor
  14. __all__ = [
  15. "abs",
  16. "arccos",
  17. "add",
  18. "arcsin",
  19. "ceil",
  20. "cos",
  21. "divide",
  22. "equal",
  23. "exp",
  24. "greater",
  25. "greater_equal",
  26. "floor",
  27. "less",
  28. "less_equal",
  29. "log",
  30. "maximum",
  31. "minimum",
  32. "mod",
  33. "multiply",
  34. "power",
  35. "relu",
  36. "round",
  37. "sigmoid",
  38. "sin",
  39. "subtract",
  40. "tanh",
  41. ]
  42. def _elemwise(mode): # DONT export
  43. """Decorator helps to wrap megbrain element-wise oprs"""
  44. def elemwise_decorator(func):
  45. @functools.wraps(func)
  46. @wrap_io_tensor
  47. def elemwise_func(*inputs) -> Tensor:
  48. if all(isinstance(i, (int,float)) for i in inputs):
  49. device, comp_graph = _use_default_if_none(None, None)
  50. ret = mgb.opr.elemwise(*inputs,
  51. mode=mode,
  52. comp_node=device,
  53. comp_graph=comp_graph)
  54. return ret.inferred_value[0]
  55. return mgb.opr.elemwise(*inputs, mode=mode)
  56. return elemwise_func
  57. return elemwise_decorator
  58. @_elemwise("ABS")
  59. def abs(x):
  60. """Calculate the absolute value element-wise."""
  61. @_elemwise("ACOS")
  62. def arccos(x):
  63. """Inverse cosine, element-wise."""
  64. @_elemwise("ADD")
  65. def add(x, y):
  66. """Element-wise addition."""
  67. @_elemwise("ASIN")
  68. def arcsin(x):
  69. """Inverse sine, element-wise."""
  70. @_elemwise("CEIL")
  71. def ceil(x):
  72. """Return the ceil of the input, element-wise."""
  73. @_elemwise("COS")
  74. def cos(x):
  75. """Cosine, element-wise."""
  76. @_elemwise("TRUE_DIV")
  77. def divide(x, y):
  78. """Return (x / y) element-wise."""
  79. @_elemwise("EQ")
  80. def equal(x, y):
  81. """Return (x == y) element-wise."""
  82. @_elemwise("EXP")
  83. def exp(x):
  84. """Calculate the exponential element-wise"""
  85. @_elemwise("FLOOR")
  86. def floor(x):
  87. """Return the floor of the input, element-wise"""
  88. def greater(x, y):
  89. """Return (x > y) element-wise."""
  90. return less(y, x)
  91. def greater_equal(x, y):
  92. """Return (x >= y) element-wise"""
  93. return less_equal(y, x)
  94. @_elemwise("LT")
  95. def less(x, y):
  96. """Return (x < y) element-wise."""
  97. @_elemwise("LEQ")
  98. def less_equal(x, y):
  99. """Return (x =< y) element-wise."""
  100. @_elemwise("LOG")
  101. def log(x):
  102. """Natural logarithm (base `e`), element-wise."""
  103. @_elemwise("MAX")
  104. def maximum(x, y):
  105. """Element-wise maximum of array elements."""
  106. @_elemwise("MIN")
  107. def minimum(x, y):
  108. """Element-wise minimum of array elements."""
  109. @_elemwise("MOD")
  110. def mod(x, y):
  111. """Return element-wise remainder of division."""
  112. @_elemwise("MUL")
  113. def multiply(x, y):
  114. """Element-wise multiplication."""
  115. @_elemwise("POW")
  116. def power(x, y):
  117. """First tensor elements raised to powers from second tensor (x ** y), element-wise."""
  118. @_elemwise("RELU")
  119. def relu(x):
  120. """Return `max(x, 0)` element-wise."""
  121. @_elemwise("ROUND")
  122. def round(x):
  123. """Round tensor to int element-wise."""
  124. @_elemwise("SIGMOID")
  125. def sigmoid(x):
  126. """Return 1 / ( 1 + exp( -x ) ) element-wise."""
  127. @_elemwise("SIN")
  128. def sin(x):
  129. """Sine, element-wise."""
  130. @_elemwise("SUB")
  131. def subtract(x, y):
  132. """Subtract arguments element-wise"""
  133. @_elemwise("TANH")
  134. def tanh(x):
  135. """Compute hyperbolic tangent element-wise."""
  136. @wrap_io_tensor
  137. def clamp(inp: Tensor, lower=None, upper=None) -> Tensor:
  138. r"""
  139. Clamp all elements in :attr:`inp` into the range `[` :attr:`lower`, :attr:`upper` `]` and return
  140. a resulting tensor:
  141. .. math::
  142. y_i = \begin{cases}
  143. \text{lower} & \text{if } x_i < \text{lower} \\
  144. x_i & \text{if } \text{lower} \leq x_i \leq \text{upper} \\
  145. \text{upper} & \text{if } x_i > \text{upper}
  146. \end{cases}
  147. :param inp: the input tensor.
  148. :param lower: lower-bound of the range to be clamped to
  149. :param upper: upper-bound of the range to be clamped to
  150. Example:
  151. .. testcode::
  152. import numpy as np
  153. from megengine import tensor
  154. import megengine.functional as F
  155. a = tensor(np.arange(5).astype(np.int32))
  156. print(F.clamp(a, 2, 4).numpy())
  157. print(F.clamp(a, lower=3).numpy())
  158. print(F.clamp(a, upper=3).numpy())
  159. .. testoutput::
  160. [2 2 2 3 4]
  161. [3 3 3 3 4]
  162. [0 1 2 3 3]
  163. """
  164. assert lower or upper, "At least one of 'lower' or 'upper' must not be None"
  165. if lower:
  166. if upper:
  167. assert lower <= upper, "clamp lower bound is bigger that upper bound"
  168. return minimum(maximum(inp, lower), upper)
  169. else:
  170. return maximum(inp, lower)
  171. else:
  172. return minimum(inp, upper)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。

Contributors (1)