You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

elemwise.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order
  10. import functools
  11. from ..core.ops import builtin
  12. from ..core.tensor import megbrain_graph, utils
  13. from ..core.tensor.core import apply
  14. from ..device import get_default_device
  15. from ..tensor import Tensor
  16. __all__ = [
  17. "abs",
  18. "add",
  19. "acos",
  20. "asin",
  21. "atan",
  22. "atan2",
  23. "asinh",
  24. "acosh",
  25. "atanh",
  26. "bitwise_and", # TODO
  27. "bitwise_not", # TODO
  28. "bitwise_or", # TODO
  29. "bitwise_xor", # TODO
  30. "ceil",
  31. "clamp",
  32. "cos",
  33. "cosh",
  34. "div",
  35. "eq",
  36. "exp",
  37. "expm1",
  38. "floor",
  39. "floor_div",
  40. "gt",
  41. "ge",
  42. "hswish",
  43. "hsigmoid",
  44. "left_shift",
  45. "lt",
  46. "le",
  47. "log",
  48. "log1p",
  49. "logical_and",
  50. "logical_not",
  51. "logical_or",
  52. "logical_xor",
  53. "maximum",
  54. "minimum",
  55. "mod",
  56. "mul",
  57. "neg",
  58. "ne",
  59. "pow",
  60. "relu",
  61. "relu6",
  62. "right_shift",
  63. "round",
  64. "sigmoid",
  65. "sin",
  66. "sinh",
  67. "sqrt",
  68. "square",
  69. "sub",
  70. "tan",
  71. "tanh",
  72. "fast_tanh",
  73. ]
  74. def _elwise(*args, mode):
  75. op = builtin.Elemwise(mode=mode)
  76. tensor_args = list(
  77. filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args)
  78. )
  79. if len(tensor_args) == 0:
  80. dtype = utils.dtype_promotion(args)
  81. first_arg = Tensor(args[0], dtype=dtype, device=get_default_device())
  82. args = utils.convert_inputs(first_arg, *args[1:])
  83. else:
  84. args = utils.convert_inputs(*args)
  85. if mode in ("true_div", "exp", "pow", "log", "expm1", "log1p"):
  86. args = tuple(map(lambda x: x.astype("float32"), args))
  87. (result,) = apply(op, *args)
  88. return result
  89. def _logical(*args, mode):
  90. op = builtin.CondExecPredLogical(mode=mode)
  91. args = utils.convert_inputs(*args)
  92. (result,) = apply(op, *args)
  93. return result
  94. def _elemwise_multi_type(*args, mode, **kwargs):
  95. op = builtin.ElemwiseMultiType(mode=mode, **kwargs)
  96. args = utils.convert_inputs(*args)
  97. (result,) = apply(op, *args)
  98. return result
  99. # math operations
  100. def add(x, y):
  101. """Element-wise addition.
  102. At least one operand should be tensor.
  103. same for sub/mul/div/floor_div/pow/mod/atan2/eq/ne/lt/le/gt/ge/maximum/minmium.
  104. """
  105. return _elwise(x, y, mode="add")
  106. def sub(x, y):
  107. """Element-wise subtract."""
  108. return _elwise(x, y, mode="sub")
  109. def mul(x, y):
  110. """Element-wise multiplication."""
  111. return _elwise(x, y, mode="mul")
  112. def div(x, y):
  113. """Element-wise (x / y)."""
  114. return _elwise(x, y, mode="true_div")
  115. def floor_div(x, y):
  116. """Element-wise floor(x / y)."""
  117. return _elwise(x, y, mode="floor_divide")
  118. def neg(x):
  119. """Element-wise negation."""
  120. return _elwise(x, mode="negate")
  121. def pow(x, y):
  122. """Element-wise power."""
  123. return _elwise(x, y, mode="pow")
  124. def mod(x, y):
  125. """Element-wise remainder of division."""
  126. return _elwise(x, y, mode="mod")
  127. def abs(x):
  128. """Element-wise absolute value."""
  129. return _elwise(x, mode="abs")
  130. def exp(x):
  131. """Element-wise exponential."""
  132. return _elwise(x, mode="exp")
  133. def expm1(x):
  134. """Element-wise exp(x)-1."""
  135. return _elwise(x, mode="expm1")
  136. def log(x):
  137. """Element-wise logarithm (base `e`)."""
  138. return _elwise(x, mode="log")
  139. def log1p(x):
  140. """Element-wise log(x+1) (base `e`)."""
  141. return _elwise(x, mode="log1p")
  142. def sqrt(inp: Tensor) -> Tensor:
  143. """
  144. Return a new tensor with the square-root of the elements of ``inp``.
  145. For negative value, return nan.
  146. :param inp: The input tensor
  147. :return: The computed tensor
  148. Examples:
  149. .. testcode::
  150. import numpy as np
  151. import megengine as mge
  152. import megengine.functional as F
  153. data = mge.tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  154. out = F.sqrt(data)
  155. print(out.numpy())
  156. Outputs:
  157. .. testoutput::
  158. [[0. 1. 1.4142]
  159. [1.7321 2. 2.2361 ]]
  160. """
  161. return inp ** 0.5
  162. def square(inp: Tensor) -> Tensor:
  163. """
  164. Return a new tensor with the square of the elements of ``inp``
  165. :param inp: The input tensor
  166. :return: The computed tensor
  167. Examples:
  168. .. testcode::
  169. import numpy as np
  170. import megengine as mge
  171. import megengine.functional as F
  172. data = mge.tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  173. out = F.square(data)
  174. print(out.numpy())
  175. Outputs:
  176. .. testoutput::
  177. [[0. 1. 4.]
  178. [9. 16. 25.]]
  179. """
  180. return inp ** 2
  181. def round(x):
  182. """Round tensor to int element-wise."""
  183. return _elwise(x, mode="round")
  184. def ceil(x):
  185. """Return the ceil of the input, element-wise."""
  186. return _elwise(x, mode="ceil")
  187. def floor(x):
  188. """Calculate the floor element-wise"""
  189. return _elwise(x, mode="floor")
  190. # trigonometric functions
  191. def cos(x):
  192. """Cosine, element-wise."""
  193. return _elwise(x, mode="cos")
  194. def sin(x):
  195. """Sine, element-wise."""
  196. return _elwise(x, mode="sin")
  197. def tan(x):
  198. return sin(x) / cos(x)
  199. def acos(x):
  200. """Inverse cosine, element-wise."""
  201. return _elwise(x, mode="acos")
  202. def asin(x):
  203. """Inverse sine, element-wise."""
  204. return _elwise(x, mode="asin")
  205. def atan(x):
  206. return _elwise(x, 1, mode="atan2")
  207. def atan2(y, x):
  208. return _elwise(y, x, mode="atan2")
  209. def cosh(x):
  210. r"""Compute element-wise hyperbolic cosine."""
  211. return 0.5 * (exp(x) + exp(-x))
  212. def sinh(x):
  213. r"""Compute element-wise hyperbolic sine."""
  214. u = expm1(x)
  215. return 0.5 * u / (u + 1) * (u + 2)
  216. def tanh(x):
  217. r"""Compute element-wise hyperbolic tangent."""
  218. return _elwise(x, mode="tanh")
  219. def asinh(x):
  220. r"""Compute element-wise inverse hyperbolic sine."""
  221. return log(x + (x ** 2 + 1) ** 0.5)
  222. def acosh(x):
  223. r"""Compute element-wise inverse hyperbolic cosine."""
  224. return log(x + (x ** 2 - 1) ** 0.5)
  225. def atanh(x):
  226. r"""Compute element-wise inverse hyperbolic tangent."""
  227. return log1p(2 * x / (1 - x)) / 2
  228. def fast_tanh(x):
  229. r"""Compute element-wise fast tanh; this is an approximation:
  230. .. math::
  231. \text{fast_tanh}(x) = x * (27. + x * x) / (27. + 9. * x * x)
  232. """
  233. return _elwise(x, mode="fast_tanh")
  234. # bit-twiddling functions
  235. def left_shift(x, y):
  236. return _elwise(x, y, mode="shl")
  237. def right_shift(x, y):
  238. return _elwise(x, y, mode="shl")
  239. def bitwise_and(x, y):
  240. raise NotImplementedError
  241. def bitwise_not(x):
  242. raise NotImplementedError
  243. def bitwise_or(x, y):
  244. raise NotImplementedError
  245. def bitwise_xor(x, y):
  246. raise NotImplementedError
  247. # logical functions
  248. def logical_and(x, y):
  249. return _elwise(x, y, mode="AND")
  250. def logical_not(x):
  251. return _elwise(x, mode="NOT")
  252. def logical_or(x, y):
  253. return _elwise(x, y, mode="OR")
  254. def logical_xor(x, y):
  255. return _elwise(x, y, mode="XOR")
  256. # comparison functions
  257. def eq(x, y):
  258. """Return (x == y) element-wise."""
  259. return _elwise(x, y, mode="eq")
  260. def ne(x, y):
  261. return x != y
  262. def lt(x, y):
  263. """Return (x < y) element-wise."""
  264. return _elwise(x, y, mode="lt")
  265. def le(x, y):
  266. """Return (x =< y) element-wise."""
  267. return _elwise(x, y, mode="leq")
  268. def gt(x, y):
  269. """Return (x > y) element-wise."""
  270. return _elwise(y, x, mode="lt")
  271. def ge(x, y):
  272. """Return (x >= y) element-wise"""
  273. return _elwise(y, x, mode="leq")
  274. def hswish(x):
  275. """Return x * relu6(x + 3) / 6 element-wise"""
  276. return _elwise(x, mode="h_swish")
  277. def hsigmoid(x):
  278. """Return relu6(x + 3) / 6 element-wise"""
  279. return relu6(x + 3) / 6
  280. def relu(x):
  281. """Return `max(x, 0)` element-wise."""
  282. return _elwise(x, mode="relu")
  283. def relu6(x):
  284. """Return min(max(x, 0), 6) element-wise."""
  285. return minimum(maximum(x, 0), 6)
  286. def sigmoid(x):
  287. """Return 1 / ( 1 + exp( -x ) ) element-wise."""
  288. return _elwise(x, mode="sigmoid")
  289. def maximum(x, y):
  290. """Element-wise maximum of array elements."""
  291. return _elwise(x, y, mode="max")
  292. def minimum(x, y):
  293. """Element-wise minimum of array elements."""
  294. return _elwise(x, y, mode="min")
  295. def clamp(inp: Tensor, lower=None, upper=None) -> Tensor:
  296. r"""
  297. Clamp all elements in :attr:`inp` into the range `[` :attr:`lower`, :attr:`upper` `]` and return
  298. a resulting tensor:
  299. .. math::
  300. y_i = \begin{cases}
  301. \text{lower} & \text{if } x_i < \text{lower} \\
  302. x_i & \text{if } \text{lower} \leq x_i \leq \text{upper} \\
  303. \text{upper} & \text{if } x_i > \text{upper}
  304. \end{cases}
  305. :param inp: the input tensor.
  306. :param lower: lower-bound of the range to be clamped to
  307. :param upper: upper-bound of the range to be clamped to
  308. Example:
  309. .. testcode::
  310. import numpy as np
  311. from megengine import tensor
  312. import megengine.functional as F
  313. a = tensor(np.arange(5).astype(np.int32))
  314. print(F.clamp(a, 2, 4).numpy())
  315. print(F.clamp(a, lower=3).numpy())
  316. print(F.clamp(a, upper=3).numpy())
  317. .. testoutput::
  318. [2 2 2 3 4]
  319. [3 3 3 3 4]
  320. [0 1 2 3 3]
  321. """
  322. assert (
  323. lower is not None or upper is not None
  324. ), "At least one of 'lower' or 'upper' must not be None"
  325. if lower is not None:
  326. if upper is not None:
  327. assert lower <= upper, "clamp lower bound is bigger that upper bound"
  328. return minimum(maximum(inp, lower), upper)
  329. else:
  330. return maximum(inp, lower)
  331. else:
  332. return minimum(inp, upper)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。