You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

elemwise.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order
  10. import functools
  11. from ..core.ops import builtin
  12. from ..core.tensor import megbrain_graph, utils
  13. from ..core.tensor.core import apply
  14. from ..device import get_default_device
  15. from ..tensor import Tensor
  16. __all__ = [
  17. "abs",
  18. "add",
  19. "acos",
  20. "asin",
  21. "atan",
  22. "atan2",
  23. "asinh",
  24. "acosh",
  25. "atanh",
  26. "ceil",
  27. "clamp",
  28. "cos",
  29. "cosh",
  30. "div",
  31. "eq",
  32. "exp",
  33. "expm1",
  34. "fast_tanh",
  35. "floor",
  36. "floor_div",
  37. "gt",
  38. "ge",
  39. "hswish",
  40. "hsigmoid",
  41. "left_shift",
  42. "lt",
  43. "le",
  44. "log",
  45. "log1p",
  46. "logical_and",
  47. "logical_not",
  48. "logical_or",
  49. "logical_xor",
  50. "maximum",
  51. "minimum",
  52. "mod",
  53. "mul",
  54. "neg",
  55. "ne",
  56. "pow",
  57. "relu",
  58. "relu6",
  59. "right_shift",
  60. "round",
  61. "sigmoid",
  62. "sin",
  63. "sinh",
  64. "sqrt",
  65. "square",
  66. "sub",
  67. "tan",
  68. "tanh",
  69. ]
  70. def _elwise(*args, mode):
  71. op = builtin.Elemwise(mode=mode)
  72. tensor_args = list(
  73. filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args)
  74. )
  75. if len(tensor_args) == 0:
  76. dtype = utils.dtype_promotion(args)
  77. first_arg = Tensor(args[0], dtype=dtype, device=get_default_device())
  78. args = utils.convert_inputs(first_arg, *args[1:])
  79. else:
  80. args = utils.convert_inputs(*args)
  81. if mode in ("true_div", "exp", "pow", "log", "expm1", "log1p"):
  82. args = tuple(map(lambda x: x.astype("float32"), args))
  83. (result,) = apply(op, *args)
  84. return result
  85. def _logical(*args, mode):
  86. op = builtin.CondExecPredLogical(mode=mode)
  87. args = utils.convert_inputs(*args)
  88. (result,) = apply(op, *args)
  89. return result
  90. def _elemwise_multi_type(*args, mode, **kwargs):
  91. op = builtin.ElemwiseMultiType(mode=mode, **kwargs)
  92. args = utils.convert_inputs(*args)
  93. (result,) = apply(op, *args)
  94. return result
  95. # math operations
  96. def add(x, y):
  97. """Element-wise addition.
  98. At least one operand should be tensor.
  99. Same for sub/mul/div/floor_div/pow/mod/atan2/eq/ne/lt/le/gt/ge/maximum/minmium.
  100. :param x: input tensor.
  101. :return: computed tensor.
  102. Examples:
  103. .. testcode::
  104. import numpy as np
  105. from megengine import tensor
  106. import megengine.functional as F
  107. x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  108. y = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  109. out = F.add(x, y)
  110. print(out.numpy())
  111. Outputs:
  112. .. testoutput::
  113. [[ 0. 2. 4.]
  114. [ 6. 8. 10.]]
  115. """
  116. return _elwise(x, y, mode="add")
  117. def sub(x, y):
  118. """Element-wise subtraction."""
  119. return _elwise(x, y, mode="sub")
  120. def mul(x, y):
  121. """Element-wise multiplication."""
  122. return _elwise(x, y, mode="mul")
  123. def div(x, y):
  124. """Element-wise (x / y)."""
  125. return _elwise(x, y, mode="true_div")
  126. def floor_div(x, y):
  127. """Element-wise floor(x / y)."""
  128. return _elwise(x, y, mode="floor_divide")
  129. def neg(x):
  130. """Element-wise negation."""
  131. return _elwise(x, mode="negate")
  132. def pow(x, y):
  133. """Element-wise power."""
  134. return _elwise(x, y, mode="pow")
  135. def mod(x, y):
  136. """Element-wise remainder of division."""
  137. return _elwise(x, y, mode="mod")
  138. def abs(x):
  139. """Element-wise absolute value."""
  140. return _elwise(x, mode="abs")
  141. def exp(x):
  142. """Element-wise exponential."""
  143. return _elwise(x, mode="exp")
  144. def expm1(x):
  145. """Element-wise exp(x)-1."""
  146. return _elwise(x, mode="expm1")
  147. def log(x):
  148. """Element-wise logarithm (base `e`)."""
  149. return _elwise(x, mode="log")
  150. def log1p(x):
  151. """Element-wise log(x+1) (base `e`)."""
  152. return _elwise(x, mode="log1p")
  153. def sqrt(x: Tensor) -> Tensor:
  154. """Element-wise sqrt.
  155. For negative input value, return ``NaN``.
  156. :param x: input tensor.
  157. :return: computed tensor.
  158. Examples:
  159. .. testcode::
  160. import numpy as np
  161. from megengine import tensor
  162. import megengine.functional as F
  163. x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  164. out = F.sqrt(x)
  165. print(out.numpy())
  166. Outputs:
  167. .. testoutput::
  168. [[0. 1. 1.4142]
  169. [1.7321 2. 2.2361]]
  170. """
  171. return x ** 0.5
  172. def square(x: Tensor) -> Tensor:
  173. """
  174. Return a new tensor with the square of the elements of input tensor.
  175. :param inp: The input tensor
  176. :return: The computed tensor
  177. Examples:
  178. .. testcode::
  179. import numpy as np
  180. import megengine as mge
  181. import megengine.functional as F
  182. data = mge.tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  183. out = F.square(data)
  184. print(out.numpy())
  185. Outputs:
  186. .. testoutput::
  187. [[ 0. 1. 4.]
  188. [ 9. 16. 25.]]
  189. """
  190. return x ** 2
  191. def round(x):
  192. """Element-wise rounding to int."""
  193. return _elwise(x, mode="round")
  194. def ceil(x):
  195. """Element-wise ceiling."""
  196. return _elwise(x, mode="ceil")
  197. def floor(x):
  198. """Element-wise floor."""
  199. return _elwise(x, mode="floor")
  200. def maximum(x, y):
  201. """Element-wise maximum of array elements."""
  202. return _elwise(x, y, mode="max")
  203. def minimum(x, y):
  204. """Element-wise minimum of array elements."""
  205. return _elwise(x, y, mode="min")
  206. # trigonometric functions
  207. def cos(x):
  208. """Element-wise cosine.
  209. :param x: input tensor.
  210. :return: computed tensor.
  211. Examples:
  212. .. testcode::
  213. import numpy as np
  214. from megengine import tensor
  215. import megengine.functional as F
  216. x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  217. out = F.cos(x)
  218. print(out.numpy())
  219. Outputs:
  220. .. testoutput::
  221. [[ 1. 0.5403 -0.4161]
  222. [-0.99 -0.6536 0.2837]]
  223. """
  224. return _elwise(x, mode="cos")
  225. def sin(x):
  226. """Element-wise sine."""
  227. return _elwise(x, mode="sin")
  228. def tan(x):
  229. """Element-wise tangent."""
  230. return sin(x) / cos(x)
  231. def acos(x):
  232. """Element-wise inverse cosine."""
  233. return _elwise(x, mode="acos")
  234. def asin(x):
  235. """Element-wise inverse sine."""
  236. return _elwise(x, mode="asin")
  237. def atan(x):
  238. """Element-wise inverse tangent."""
  239. return _elwise(x, 1, mode="atan2")
  240. def atan2(y, x):
  241. """Element-wise 2-argument arctangent."""
  242. return _elwise(y, x, mode="atan2")
  243. def cosh(x):
  244. r"""Element-wise hyperbolic cosine."""
  245. return 0.5 * (exp(x) + exp(-x))
  246. def sinh(x):
  247. r"""Element-wise hyperbolic sine."""
  248. u = expm1(x)
  249. return 0.5 * u / (u + 1) * (u + 2)
  250. def tanh(x):
  251. r"""Element-wise hyperbolic tangent."""
  252. return _elwise(x, mode="tanh")
  253. def asinh(x):
  254. r"""Element-wise inverse hyperbolic sine."""
  255. return log(x + (x ** 2 + 1) ** 0.5)
  256. def acosh(x):
  257. r"""Element-wise inverse hyperbolic cosine."""
  258. return log(x + (x ** 2 - 1) ** 0.5)
  259. def atanh(x):
  260. r"""Element-wise inverse hyperbolic tangent."""
  261. return log1p(2 * x / (1 - x)) / 2
  262. def fast_tanh(x):
  263. r"""Element-wise fast tanh; this is an approximation:
  264. .. math::
  265. \text{fast_tanh}(x) = x * (27. + x * x) / (27. + 9. * x * x)
  266. """
  267. return _elwise(x, mode="fast_tanh")
  268. # bit-twiddling functions
  269. def left_shift(x, y):
  270. """Element-wise bitwise binary: x << y.
  271. :param x: input tensor, should be int.
  272. :param y: how many bits to be left-shifted.
  273. :return: computed tensor.
  274. Examples:
  275. .. testcode::
  276. import numpy as np
  277. from megengine import tensor
  278. import megengine.functional as F
  279. x = tensor(np.arange(0, 6, dtype=np.int32).reshape(2, 3))
  280. out = F.left_shift(x, 2)
  281. print(out.numpy())
  282. Outputs:
  283. .. testoutput::
  284. [[ 0 4 8]
  285. [12 16 20]]
  286. """
  287. return _elwise(x, y, mode="shl")
  288. def right_shift(x, y):
  289. """Element-wise bitwise binary: x >> y."""
  290. return _elwise(x, y, mode="shr")
  291. # logical functions
  292. def logical_and(x, y):
  293. """Element-wise logical and: x && y."""
  294. return _elwise(x, y, mode="AND")
  295. def logical_not(x):
  296. """Element-wise logical not: ~x."""
  297. return _elwise(x, mode="NOT")
  298. def logical_or(x, y):
  299. """Element-wise logical or: x || y."""
  300. return _elwise(x, y, mode="OR")
  301. def logical_xor(x, y):
  302. """Element-wise logical xor: x ^ y."""
  303. return _elwise(x, y, mode="XOR")
  304. # comparison functions
  305. def eq(x, y):
  306. """Element-wise (x == y).
  307. :param x: input tensor 1.
  308. :param y: input tensor 2.
  309. :return: computed tensor.
  310. Examples:
  311. .. testcode::
  312. import numpy as np
  313. from megengine import tensor
  314. import megengine.functional as F
  315. x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  316. y = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  317. out = F.eq(x, y)
  318. print(out.numpy())
  319. Outputs:
  320. .. testoutput::
  321. [[1. 1. 1.]
  322. [1. 1. 1.]]
  323. """
  324. return _elwise(x, y, mode="eq")
  325. def ne(x, y):
  326. """Element-wise (x != y)."""
  327. return x != y
  328. def lt(x, y):
  329. """Element-wise (x < y)."""
  330. return _elwise(x, y, mode="lt")
  331. def le(x, y):
  332. """Element-wise (x <= y)."""
  333. return _elwise(x, y, mode="leq")
  334. def gt(x, y):
  335. """Element-wise (x > y)."""
  336. return _elwise(y, x, mode="lt")
  337. def ge(x, y):
  338. """Element-wise (x >= y)."""
  339. return _elwise(y, x, mode="leq")
  340. # other functions
  341. def hswish(x):
  342. """Element-wise x * relu6(x + 3) / 6.
  343. :param x: input tensor.
  344. :return: computed tensor.
  345. Example:
  346. .. testcode::
  347. import numpy as np
  348. from megengine import tensor
  349. import megengine.functional as F
  350. x = tensor(np.arange(5).astype(np.float32))
  351. out = F.hswish(x)
  352. print(out.numpy())
  353. .. testoutput::
  354. [0. 0.6667 1.6667 3. 4. ]
  355. """
  356. return _elwise(x, mode="h_swish")
  357. def hsigmoid(x):
  358. """Element-wise relu6(x + 3) / 6."""
  359. return relu6(x + 3) / 6
  360. def relu(x):
  361. """Element-wise `max(x, 0)`."""
  362. return _elwise(x, mode="relu")
  363. def relu6(x):
  364. """Element-wise min(max(x, 0), 6)."""
  365. return minimum(maximum(x, 0), 6)
  366. def sigmoid(x):
  367. """Element-wise 1 / ( 1 + exp( -x ) )."""
  368. return _elwise(x, mode="sigmoid")
  369. def clamp(x: Tensor, lower=None, upper=None) -> Tensor:
  370. r"""Clamps all elements in input tensor into the range `[` :attr:`lower`, :attr:`upper` `]` and returns
  371. a resulting tensor:
  372. .. math::
  373. y_i = \begin{cases}
  374. \text{lower} & \text{if } x_i < \text{lower} \\
  375. x_i & \text{if } \text{lower} \leq x_i \leq \text{upper} \\
  376. \text{upper} & \text{if } x_i > \text{upper}
  377. \end{cases}
  378. :param x: input tensor.
  379. :param lower: lower-bound of the range to be clamped to.
  380. :param upper: upper-bound of the range to be clamped to.
  381. :return: output clamped tensor.
  382. Examples:
  383. .. testcode::
  384. import numpy as np
  385. from megengine import tensor
  386. import megengine.functional as F
  387. a = tensor(np.arange(5).astype(np.int32))
  388. print(F.clamp(a, 2, 4).numpy())
  389. print(F.clamp(a, lower=3).numpy())
  390. print(F.clamp(a, upper=3).numpy())
  391. Outputs:
  392. .. testoutput::
  393. [2 2 2 3 4]
  394. [3 3 3 3 4]
  395. [0 1 2 3 3]
  396. """
  397. assert (
  398. lower is not None or upper is not None
  399. ), "At least one of 'lower' or 'upper' must not be None"
  400. if lower is not None:
  401. if upper is not None:
  402. assert lower <= upper, "clamp lower bound is bigger that upper bound"
  403. return minimum(maximum(x, lower), upper)
  404. else:
  405. return maximum(x, lower)
  406. else:
  407. return minimum(x, upper)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。