You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

elemwise.py 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order
  10. import functools
  11. from ..core.ops import builtin
  12. from ..core.ops.builtin import Elemwise
  13. from ..core.tensor import megbrain_graph, utils
  14. from ..core.tensor.core import apply
  15. from ..core.tensor.utils import isscalar, setscalar
  16. from ..device import get_default_device
  17. from ..jit.tracing import is_tracing
  18. from ..tensor import Tensor
  # Public element-wise operators exported by ``from ... import *``.
  __all__ = [
      "abs", "add", "acos", "asin", "atan", "atan2", "asinh", "acosh", "atanh",
      "ceil", "clip", "cos", "cosh", "div", "equal", "exp", "expm1", "floor",
      "floor_div", "greater", "greater_equal", "hswish", "hsigmoid",
      "left_shift", "less", "less_equal", "log", "log1p", "logical_and",
      "logical_not", "logical_or", "logical_xor", "maximum", "minimum", "mod",
      "mul", "neg", "not_equal", "pow", "relu", "relu6", "right_shift", "round",
      "sigmoid", "sin", "sinh", "sqrt", "square", "sub", "tan", "tanh",
  ]
  class _ElemwiseMode(Elemwise.Mode):
      """``Elemwise.Mode`` with case-insensitive lookup of members by name."""

      @classmethod
      def __normalize(cls, val):
          # Non-string values are passed through untouched.
          if not isinstance(val, str):
              return val
          # Build the upper-cased name -> member table lazily, once per class.
          if not hasattr(cls, "__member_upper_dict__"):
              cls.__member_upper_dict__ = {
                  name.upper(): member for name, member in cls.__members__.items()
              }
          # Unknown names are returned unchanged, so that cls(...) rejects them.
          return cls.__member_upper_dict__.get(val.upper(), val)

      @classmethod
      def convert(cls, val):
          """Turn a mode name or value into an instance of this class."""
          normalized = cls.__normalize(val)
          return normalized if isinstance(normalized, cls) else cls(normalized)
  def _elwise(*args, mode):
      """
      Apply a single ``Elemwise`` operator to ``args`` and return its output.

      Operands are converted with ``utils.convert_inputs``. When no operand is
      a ``Tensor`` or ``VarNode``, the first one is first materialized as a
      ``Tensor``. Its dtype comes from ``utils.dtype_promotion`` and it is
      placed on the default device.

      :param args: operands of the element-wise operator.
      :param mode: an ``Elemwise.Mode`` member, or its name as a string
          (matched case-insensitively by ``_ElemwiseMode.convert``).
      :return: the output tensor. It is marked as a scalar when every
          converted input is a scalar.
      """
      # Modes whose inputs are cast to float32 before the operator is applied.
      float_mode_names = ("TRUE_DIV", "EXP", "POW", "LOG", "EXPM1", "LOG1P")
      # Decide on the cast from the mode as passed in. Comparing the converted
      # enum value against lowercase strings never matches, which would
      # silently skip the cast.
      if isinstance(mode, str):
          cast_to_float = mode.upper() in float_mode_names
      else:
          cast_to_float = mode in tuple(
              getattr(Elemwise.Mode, name) for name in float_mode_names
          )

      mode = _ElemwiseMode.convert(mode)
      op = builtin.Elemwise(mode)

      has_tensor = any(
          isinstance(x, (Tensor, megbrain_graph.VarNode)) for x in args
      )
      if has_tensor:
          args = utils.convert_inputs(*args)
      else:
          dtype = utils.dtype_promotion(args)
          first_arg = Tensor(args[0], dtype=dtype, device=get_default_device())
          args = utils.convert_inputs(first_arg, *args[1:])

      if cast_to_float:
          args = tuple(x.astype("float32") for x in args)

      all_scalar = all(isscalar(x) for x in args)
      (result,) = apply(op, *args)
      if all_scalar:
          setscalar(result)
      return result
  def _elemwise_multi_type(*args, mode, **kwargs):
      """
      Apply a ``builtin.ElemwiseMultiType`` operator to ``args``.

      :param args: operands, converted with ``utils.convert_inputs``.
      :param mode: operator mode, forwarded to ``ElemwiseMultiType``.
      :param kwargs: extra keyword arguments for ``ElemwiseMultiType``.
      :return: the single output of the operator.
      """
      multi_op = builtin.ElemwiseMultiType(mode=mode, **kwargs)
      converted = utils.convert_inputs(*args)
      (out,) = apply(multi_op, *converted)
      return out
  116. # math operations
  def add(x, y):
      """
      Element-wise `addition`.

      At least one operand should be tensor.
      Same for sub/mul/div/floor_div/pow/mod/atan2/equal/not_equal/less/less_equal/greater/greater_equal/maximum/minimum.

      :param x: input tensor.
      :param y: input tensor.
      :return: computed tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
          y = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
          out = F.add(x, y)
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[ 0.  2.  4.]
           [ 6.  8. 10.]]

      """
      return _elwise(x, y, mode=Elemwise.Mode.ADD)
  def sub(x, y):
      """
      Element-wise `subtraction`: ``x - y``.

      :param x: first operand.
      :param y: operand subtracted from ``x``.
      :return: computed tensor.
      """
      return _elwise(x, y, mode=Elemwise.Mode.SUB)


  def mul(x, y):
      """
      Element-wise `multiplication`: ``x * y``.

      :param x: first operand.
      :param y: second operand.
      :return: computed tensor.
      """
      return _elwise(x, y, mode=Elemwise.Mode.MUL)


  def div(x, y):
      """
      Element-wise true division: ``x / y``.

      :param x: dividend.
      :param y: divisor.
      :return: computed tensor.
      """
      return _elwise(x, y, mode=Elemwise.Mode.TRUE_DIV)


  def floor_div(x, y):
      """
      Element-wise floor division: ``floor(x / y)``.

      :param x: dividend.
      :param y: divisor.
      :return: computed tensor.
      """
      return _elwise(x, y, mode=Elemwise.Mode.FLOOR_DIVIDE)


  def neg(x):
      """
      Element-wise `negation`: ``-x``.

      :param x: input tensor.
      :return: computed tensor.
      """
      return _elwise(x, mode=Elemwise.Mode.NEGATE)


  def pow(x, y):
      """
      Element-wise `power`: ``x`` raised to ``y``.

      :param x: base.
      :param y: exponent.
      :return: computed tensor.
      """
      return _elwise(x, y, mode=Elemwise.Mode.POW)


  def mod(x, y):
      """
      Element-wise remainder of ``x / y``.

      :param x: dividend.
      :param y: divisor.
      :return: computed tensor.
      """
      return _elwise(x, y, mode=Elemwise.Mode.MOD)


  def abs(x):
      """
      Element-wise `absolute value`.

      :param x: input tensor.
      :return: computed tensor.
      """
      return _elwise(x, mode=Elemwise.Mode.ABS)


  def exp(x):
      """
      Element-wise `exponential`: ``e ** x``.

      :param x: input tensor.
      :return: computed tensor.
      """
      return _elwise(x, mode=Elemwise.Mode.EXP)


  def expm1(x):
      """
      Element-wise ``exp(x) - 1``.

      :param x: input tensor.
      :return: computed tensor.
      """
      return _elwise(x, mode=Elemwise.Mode.EXPM1)


  def log(x):
      """
      Element-wise natural `logarithm` (base e).

      :param x: input tensor.
      :return: computed tensor.
      """
      return _elwise(x, mode=Elemwise.Mode.LOG)


  def log1p(x):
      """
      Element-wise ``log(x + 1)`` (base e).

      :param x: input tensor.
      :return: computed tensor.
      """
      return _elwise(x, mode=Elemwise.Mode.LOG1P)
  def sqrt(x: Tensor) -> Tensor:
      """
      Element-wise square root, computed as ``x ** 0.5``.

      Negative input values produce ``NaN``.

      :param x: input tensor.
      :return: computed tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
          out = F.sqrt(x)
          print(out.numpy().round(decimals=4))

      Outputs:

      .. testoutput::

          [[0.     1.     1.4142]
           [1.7321 2.     2.2361]]

      """
      half = 0.5
      return x ** half
  def square(x: Tensor) -> Tensor:
      """
      Returns a new tensor with the square of the elements of input tensor.

      Computed as ``x ** 2``.

      :param x: input tensor.
      :return: computed tensor.

      Examples:

      .. testcode::

          import numpy as np
          import megengine as mge
          import megengine.functional as F

          data = mge.tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
          out = F.square(data)
          print(out.numpy().round(decimals=4))

      Outputs:

      .. testoutput::

          [[ 0.  1.  4.]
           [ 9. 16. 25.]]

      """
      return x ** 2
  def round(x):
      """
      Element-wise rounding to the nearest integer value.

      :param x: input tensor.
      :return: computed tensor.
      """
      return _elwise(x, mode=Elemwise.Mode.ROUND)


  def ceil(x):
      """
      Element-wise `ceiling`.

      :param x: input tensor.
      :return: computed tensor.
      """
      return _elwise(x, mode=Elemwise.Mode.CEIL)


  def floor(x):
      """
      Element-wise `floor`.

      :param x: input tensor.
      :return: computed tensor.
      """
      return _elwise(x, mode=Elemwise.Mode.FLOOR)


  def maximum(x, y):
      """
      Element-wise maximum of ``x`` and ``y``.

      :param x: first operand.
      :param y: second operand.
      :return: computed tensor.
      """
      return _elwise(x, y, mode=Elemwise.Mode.MAX)


  def minimum(x, y):
      """
      Element-wise minimum of ``x`` and ``y``.

      :param x: first operand.
      :param y: second operand.
      :return: computed tensor.
      """
      return _elwise(x, y, mode=Elemwise.Mode.MIN)
  229. # trigonometric functions
  def cos(x):
      """
      Element-wise `cosine`.

      :param x: input tensor.
      :return: computed tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
          out = F.cos(x)
          print(out.numpy().round(decimals=4))

      Outputs:

      .. testoutput::

          [[ 1.      0.5403 -0.4161]
           [-0.99   -0.6536  0.2837]]

      """
      cos_mode = Elemwise.Mode.COS
      return _elwise(x, mode=cos_mode)
  def sin(x):
      """
      Element-wise `sine`.

      :param x: input tensor.
      :return: computed tensor.
      """
      return _elwise(x, mode=Elemwise.Mode.SIN)


  def tan(x):
      """
      Element-wise `tangent`, computed as ``sin(x) / cos(x)``.

      :param x: input tensor.
      :return: computed tensor.
      """
      return sin(x) / cos(x)


  def acos(x):
      """
      Element-wise `inverse cosine`.

      :param x: input tensor.
      :return: computed tensor.
      """
      return _elwise(x, mode=Elemwise.Mode.ACOS)


  def asin(x):
      """
      Element-wise `inverse sine`.

      :param x: input tensor.
      :return: computed tensor.
      """
      return _elwise(x, mode=Elemwise.Mode.ASIN)


  def atan(x):
      """
      Element-wise `inverse tangent`.

      :param x: input tensor.
      :return: computed tensor.
      """
      # atan(x) == atan2(x, 1), so the two-argument kernel is reused.
      return _elwise(x, 1, mode=Elemwise.Mode.ATAN2)


  def atan2(y, x):
      """
      Element-wise two-argument arctangent of ``y / x``.

      :param y: first operand (numerator).
      :param x: second operand (denominator).
      :return: computed tensor.
      """
      return _elwise(y, x, mode=Elemwise.Mode.ATAN2)


  def cosh(x):
      r"""
      Element-wise `hyperbolic cosine`: ``(exp(x) + exp(-x)) / 2``.

      :param x: input tensor.
      :return: computed tensor.
      """
      e_pos = exp(x)
      e_neg = exp(-x)
      return 0.5 * (e_pos + e_neg)


  def sinh(x):
      r"""
      Element-wise `hyperbolic sine`.

      :param x: input tensor.
      :return: computed tensor.
      """
      # With em1 = exp(x) - 1: sinh(x) = em1 * (em1 + 2) / (2 * (em1 + 1)).
      em1 = expm1(x)
      return 0.5 * em1 / (em1 + 1) * (em1 + 2)


  def tanh(x):
      r"""
      Element-wise `hyperbolic tangent`.

      :param x: input tensor.
      :return: computed tensor.
      """
      return _elwise(x, mode=Elemwise.Mode.TANH)


  def asinh(x):
      r"""
      Element-wise `inverse hyperbolic sine`: ``log(x + sqrt(x ** 2 + 1))``.

      :param x: input tensor.
      :return: computed tensor.
      """
      # NOTE(review): for large negative x, x + root cancels toward 0, so
      # precision may be lost there. Consider an odd-symmetric formulation.
      root = (x ** 2 + 1) ** 0.5
      return log(x + root)


  def acosh(x):
      r"""
      Element-wise `inverse hyperbolic cosine`: ``log(x + sqrt(x ** 2 - 1))``.

      :param x: input tensor.
      :return: computed tensor.
      """
      root = (x ** 2 - 1) ** 0.5
      return log(x + root)


  def atanh(x):
      r"""
      Element-wise `inverse hyperbolic tangent`: ``log1p(2x / (1 - x)) / 2``.

      :param x: input tensor.
      :return: computed tensor.
      """
      ratio = 2 * x / (1 - x)
      return log1p(ratio) / 2
  286. # bit-twiddling functions
  def left_shift(x, y):
      """
      Element-wise bitwise left shift: ``x << y``.

      :param x: input tensor, should be int.
      :param y: how many bits to be left-shifted.
      :return: computed tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.arange(0, 6, dtype=np.int32).reshape(2, 3))
          out = F.left_shift(x, 2)
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[ 0  4  8]
           [12 16 20]]

      """
      shl = Elemwise.Mode.SHL
      return _elwise(x, y, mode=shl)


  def right_shift(x, y):
      """
      Element-wise bitwise right shift: ``x >> y``.

      :param x: input tensor, should be int.
      :param y: how many bits to be right-shifted.
      :return: computed tensor.
      """
      shr = Elemwise.Mode.SHR
      return _elwise(x, y, mode=shr)
  310. # logical functions
  def logical_and(x, y):
      """
      Element-wise logical AND (``x && y``).

      :param x: first operand.
      :param y: second operand.
      :return: computed tensor.
      """
      return _elwise(x, y, mode=Elemwise.Mode.AND)


  def logical_not(x):
      """
      Element-wise logical NOT (``~x``).

      :param x: input tensor.
      :return: computed tensor.
      """
      return _elwise(x, mode=Elemwise.Mode.NOT)


  def logical_or(x, y):
      """
      Element-wise logical OR (``x || y``).

      :param x: first operand.
      :param y: second operand.
      :return: computed tensor.
      """
      return _elwise(x, y, mode=Elemwise.Mode.OR)


  def logical_xor(x, y):
      """
      Element-wise logical XOR (``x ^ y``).

      :param x: first operand.
      :param y: second operand.
      :return: computed tensor.
      """
      return _elwise(x, y, mode=Elemwise.Mode.XOR)
  323. # comparison functions
  def equal(x, y):
      """
      Element-wise equality test ``(x == y)``.

      :param x: input tensor 1.
      :param y: input tensor 2.
      :return: computed tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
          y = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
          out = F.equal(x, y)
          print(out.numpy())

      Outputs:

      .. testoutput::

          [[1. 1. 1.]
           [1. 1. 1.]]

      """
      eq_mode = Elemwise.Mode.EQ
      return _elwise(x, y, mode=eq_mode)
  def not_equal(x, y):
      """
      Element-wise inequality test ``(x != y)``.

      Delegates to the operands' ``!=`` operator.

      :param x: first operand.
      :param y: second operand.
      :return: result of ``x != y``.
      """
      return x != y


  def less(x, y):
      """
      Element-wise ``(x < y)``.

      :param x: first operand.
      :param y: second operand.
      :return: computed tensor.
      """
      return _elwise(x, y, mode=Elemwise.Mode.LT)


  def less_equal(x, y):
      """
      Element-wise ``(x <= y)``.

      :param x: first operand.
      :param y: second operand.
      :return: computed tensor.
      """
      return _elwise(x, y, mode=Elemwise.Mode.LEQ)


  def greater(x, y):
      """
      Element-wise ``(x > y)``.

      :param x: first operand.
      :param y: second operand.
      :return: computed tensor.
      """
      # x > y is evaluated as y < x.
      return _elwise(y, x, mode=Elemwise.Mode.LT)


  def greater_equal(x, y):
      """
      Element-wise ``(x >= y)``.

      :param x: first operand.
      :param y: second operand.
      :return: computed tensor.
      """
      # x >= y is evaluated as y <= x.
      return _elwise(y, x, mode=Elemwise.Mode.LEQ)
  360. # other functions
  def hswish(x):
      """
      Element-wise hard swish: ``x * relu6(x + 3) / 6``.

      :param x: input tensor.
      :return: computed tensor.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          x = tensor(np.arange(5).astype(np.float32))
          out = F.hswish(x)
          print(out.numpy().round(decimals=4))

      Outputs:

      .. testoutput::

          [0.     0.6667 1.6667 3.     4.    ]

      """
      swish_mode = Elemwise.Mode.H_SWISH
      return _elwise(x, mode=swish_mode)
  def hsigmoid(x):
      """
      Element-wise hard sigmoid: ``relu6(x + 3) / 6``.

      :param x: input tensor.
      :return: computed tensor.
      """
      return relu6(x + 3) / 6


  def relu(x):
      """
      Element-wise rectified linear unit: ``max(x, 0)``.

      :param x: input tensor.
      :return: computed tensor.
      """
      return _elwise(x, mode=Elemwise.Mode.RELU)


  def relu6(x):
      """
      Element-wise ``min(max(x, 0), 6)``.

      :param x: input tensor.
      :return: computed tensor.
      """
      lower_clamped = maximum(x, 0)
      return minimum(lower_clamped, 6)


  def sigmoid(x):
      """
      Element-wise logistic sigmoid: ``1 / (1 + exp(-x))``.

      :param x: input tensor.
      :return: computed tensor.
      """
      return _elwise(x, mode=Elemwise.Mode.SIGMOID)
  def clip(x: Tensor, lower=None, upper=None) -> Tensor:
      r"""
      Clamps all elements in input tensor into the range `[` :attr:`lower`, :attr:`upper` `]` and returns
      a resulting tensor:

      .. math::

          y_i = \begin{cases}
              \text{lower} & \text{if } x_i < \text{lower} \\
              x_i & \text{if } \text{lower} \leq x_i \leq \text{upper} \\
              \text{upper} & \text{if } x_i > \text{upper}
          \end{cases}

      :param x: input tensor.
      :param lower: lower-bound of the range to be clamped to; ``None`` means
          no lower bound.
      :param upper: upper-bound of the range to be clamped to; ``None`` means
          no upper bound.
      :return: output clamped tensor.
      :raises AssertionError: if both bounds are ``None``. Also raised when
          ``lower > upper``, but that check only runs when not tracing.

      Examples:

      .. testcode::

          import numpy as np
          from megengine import tensor
          import megengine.functional as F

          a = tensor(np.arange(5).astype(np.int32))
          print(F.clip(a, 2, 4).numpy())
          print(F.clip(a, lower=3).numpy())
          print(F.clip(a, upper=3).numpy())

      Outputs:

      .. testoutput::

          [2 2 2 3 4]
          [3 3 3 3 4]
          [0 1 2 3 3]

      """
      assert (
          lower is not None or upper is not None
      ), "At least one of 'lower' or 'upper' must not be None"
      # Only one bound given: a single maximum/minimum is enough.
      if lower is None:
          return minimum(x, upper)
      if upper is None:
          return maximum(x, lower)
      # The bound-order check is skipped while tracing.
      if not is_tracing():
          assert lower <= upper, "clip lower bound is bigger than upper bound"
      return minimum(maximum(x, lower), upper)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。