You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

elemwise.py 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order
  10. import numpy as np
  11. from ..core._imperative_rt.core2 import SymbolVar, apply
  12. from ..core.ops import builtin
  13. from ..core.ops.builtin import Elemwise
  14. from ..core.tensor import utils
  15. from ..core.tensor.array_method import _elwise_apply
  16. from ..core.tensor.utils import astype
  17. from ..device import get_default_device
  18. from ..jit.tracing import is_tracing
  19. from ..tensor import Tensor
  20. from ..utils.deprecation import deprecated_func
# Names exported by ``from <this module> import *``.
# NOTE(review): the deprecated aliases bound at the bottom of the file
# (sigmoid, hsigmoid, relu, relu6, hswish) are not listed here.
__all__ = [
    "abs",
    "add",
    "acos",
    "asin",
    "atan",
    "atan2",
    "asinh",
    "acosh",
    "atanh",
    "ceil",
    "clip",
    "cos",
    "cosh",
    "div",
    "equal",
    "exp",
    "expm1",
    "floor",
    "floor_div",
    "greater",
    "greater_equal",
    "left_shift",
    "less",
    "less_equal",
    "log",
    "log1p",
    "logical_and",
    "logical_not",
    "logical_or",
    "logical_xor",
    "maximum",
    "minimum",
    "mod",
    "mul",
    "neg",
    "not_equal",
    "pow",
    "right_shift",
    "round",
    "sin",
    "sinh",
    "sqrt",
    "square",
    "sub",
    "tan",
    "tanh",
]
  69. def _elwise(*args, mode):
  70. tensor_args = list(filter(lambda x: isinstance(x, (Tensor, SymbolVar)), args))
  71. if len(tensor_args) == 0:
  72. dtype = utils.dtype_promotion(args)
  73. first_arg = Tensor(args[0], dtype=dtype, device=get_default_device())
  74. args = utils.convert_inputs(first_arg, *args[1:])
  75. else:
  76. args = utils.convert_inputs(*args)
  77. if mode in (
  78. Elemwise.Mode.TRUE_DIV,
  79. Elemwise.Mode.EXP,
  80. Elemwise.Mode.POW,
  81. Elemwise.Mode.LOG,
  82. Elemwise.Mode.EXPM1,
  83. Elemwise.Mode.LOG1P,
  84. Elemwise.Mode.TANH,
  85. Elemwise.Mode.ACOS,
  86. Elemwise.Mode.ASIN,
  87. Elemwise.Mode.ATAN2,
  88. Elemwise.Mode.CEIL,
  89. Elemwise.Mode.COS,
  90. Elemwise.Mode.FLOOR,
  91. Elemwise.Mode.H_SWISH,
  92. Elemwise.Mode.ROUND,
  93. Elemwise.Mode.SIGMOID,
  94. Elemwise.Mode.SIN,
  95. ):
  96. if mode in (
  97. Elemwise.Mode.CEIL,
  98. Elemwise.Mode.FLOOR,
  99. Elemwise.Mode.ROUND,
  100. ) and np.issubdtype(args[0].dtype, np.integer):
  101. return args[0]
  102. args = tuple(map(lambda x: astype(x, "float32"), args))
  103. return _elwise_apply(args, mode)
  104. def _elemwise_multi_type(*args, mode, **kwargs):
  105. op = builtin.ElemwiseMultiType(mode=mode, **kwargs)
  106. args = utils.convert_inputs(*args)
  107. (result,) = apply(op, *args)
  108. return result
  109. # math operations
  110. def add(x, y):
  111. """
  112. Element-wise `addition`.
  113. At least one operand should be tensor.
  114. Same for sub/mul/div/floor_div/pow/mod/atan2/equal/not_equal/less/less_equal/greater/greater_equal/maximum/minmium.
  115. :param x: input tensor.
  116. :return: computed tensor.
  117. Examples:
  118. .. testcode::
  119. import numpy as np
  120. from megengine import tensor
  121. import megengine.functional as F
  122. x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  123. y = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  124. out = F.add(x, y)
  125. print(out.numpy())
  126. Outputs:
  127. .. testoutput::
  128. [[ 0. 2. 4.]
  129. [ 6. 8. 10.]]
  130. """
  131. return _elwise(x, y, mode=Elemwise.Mode.ADD)
  132. def sub(x, y):
  133. """Element-wise `subtraction`."""
  134. return _elwise(x, y, mode=Elemwise.Mode.SUB)
  135. def mul(x, y):
  136. """Element-wise `multiplication`."""
  137. return _elwise(x, y, mode=Elemwise.Mode.MUL)
  138. def div(x, y):
  139. """Element-wise `(x / y)`."""
  140. return _elwise(x, y, mode=Elemwise.Mode.TRUE_DIV)
  141. def floor_div(x, y):
  142. """Element-wise `floor(x / y)`."""
  143. return _elwise(x, y, mode=Elemwise.Mode.FLOOR_DIV)
  144. def neg(x):
  145. """Element-wise `negation`."""
  146. return _elwise(x, mode=Elemwise.Mode.NEGATE)
  147. def pow(x, y):
  148. """Element-wise `power`."""
  149. return _elwise(x, y, mode=Elemwise.Mode.POW)
  150. def mod(x, y):
  151. """Element-wise `remainder of division`."""
  152. return _elwise(x, y, mode=Elemwise.Mode.MOD)
  153. def abs(x):
  154. """Element-wise `absolute value`."""
  155. return _elwise(x, mode=Elemwise.Mode.ABS)
  156. def exp(x):
  157. """Element-wise `exponential`."""
  158. return _elwise(x, mode=Elemwise.Mode.EXP)
  159. def expm1(x):
  160. """Element-wise `exp(x)-1`."""
  161. return _elwise(x, mode=Elemwise.Mode.EXPM1)
  162. def log(x):
  163. """Element-wise `logarithm (base e)`."""
  164. return _elwise(x, mode=Elemwise.Mode.LOG)
  165. def log1p(x):
  166. """Element-wise `log(x+1) (base e)`."""
  167. return _elwise(x, mode=Elemwise.Mode.LOG1P)
  168. def sqrt(x: Tensor) -> Tensor:
  169. """
  170. Element-wise `sqrt`.
  171. Returns ``NaN`` for negative input value.
  172. :param x: input tensor.
  173. :return: computed tensor.
  174. Examples:
  175. .. testcode::
  176. import numpy as np
  177. from megengine import tensor
  178. import megengine.functional as F
  179. x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  180. out = F.sqrt(x)
  181. print(out.numpy().round(decimals=4))
  182. Outputs:
  183. .. testoutput::
  184. [[0. 1. 1.4142]
  185. [1.7321 2. 2.2361]]
  186. """
  187. return x ** 0.5
  188. def square(x: Tensor) -> Tensor:
  189. """
  190. Returns a new tensor with the square of the elements of input tensor.
  191. :param inp: input tensor.
  192. :return: computed tensor.
  193. Examples:
  194. .. testcode::
  195. import numpy as np
  196. import megengine as mge
  197. import megengine.functional as F
  198. data = mge.tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  199. out = F.square(data)
  200. print(out.numpy().round(decimals=4))
  201. Outputs:
  202. .. testoutput::
  203. [[ 0. 1. 4.]
  204. [ 9. 16. 25.]]
  205. """
  206. return x ** 2
  207. def round(x):
  208. """Element-wise `rounding to int`."""
  209. return _elwise(x, mode=Elemwise.Mode.ROUND)
  210. def ceil(x):
  211. """Element-wise `ceiling`."""
  212. return _elwise(x, mode=Elemwise.Mode.CEIL)
  213. def floor(x):
  214. """Element-wise `floor`."""
  215. return _elwise(x, mode=Elemwise.Mode.FLOOR)
  216. def maximum(x, y):
  217. """Element-wise `maximum of array elements`."""
  218. return _elwise(x, y, mode=Elemwise.Mode.MAX)
  219. def minimum(x, y):
  220. """Element-wise `minimum of array elements`."""
  221. return _elwise(x, y, mode=Elemwise.Mode.MIN)
  222. # trigonometric functions
  223. def cos(x):
  224. """
  225. Element-wise `cosine`.
  226. :param x: input tensor.
  227. :return: computed tensor.
  228. Examples:
  229. .. testcode::
  230. import numpy as np
  231. from megengine import tensor
  232. import megengine.functional as F
  233. x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  234. out = F.cos(x)
  235. print(out.numpy().round(decimals=4))
  236. Outputs:
  237. .. testoutput::
  238. [[ 1. 0.5403 -0.4161]
  239. [-0.99 -0.6536 0.2837]]
  240. """
  241. return _elwise(x, mode=Elemwise.Mode.COS)
  242. def sin(x):
  243. """Element-wise `sine`."""
  244. return _elwise(x, mode=Elemwise.Mode.SIN)
  245. def tan(x):
  246. """Element-wise `tangent`."""
  247. return sin(x) / cos(x)
  248. def acos(x):
  249. """Element-wise `inverse cosine`."""
  250. return _elwise(x, mode=Elemwise.Mode.ACOS)
  251. def asin(x):
  252. """Element-wise `inverse sine`."""
  253. return _elwise(x, mode=Elemwise.Mode.ASIN)
  254. def atan(x):
  255. """Element-wise `inverse tangent`."""
  256. return _elwise(x, 1, mode=Elemwise.Mode.ATAN2)
  257. def atan2(y, x):
  258. """Element-wise `2-argument arctangent`."""
  259. return _elwise(y, x, mode=Elemwise.Mode.ATAN2)
  260. def cosh(x):
  261. r"""Element-wise `hyperbolic cosine`."""
  262. return 0.5 * (exp(x) + exp(-x))
  263. def sinh(x):
  264. r"""Element-wise `hyperbolic sine`."""
  265. u = expm1(x)
  266. return 0.5 * u / (u + 1) * (u + 2)
  267. def tanh(x):
  268. r"""Element-wise `hyperbolic tangent`."""
  269. return _elwise(x, mode=Elemwise.Mode.TANH)
  270. def asinh(x):
  271. r"""Element-wise `inverse hyperbolic sine`."""
  272. return log(x + (x ** 2 + 1) ** 0.5)
  273. def acosh(x):
  274. r"""Element-wise `inverse hyperbolic cosine`."""
  275. return log(x + (x ** 2 - 1) ** 0.5)
  276. def atanh(x):
  277. r"""Element-wise `inverse hyperbolic tangent`."""
  278. return log1p(2 * x / (1 - x)) / 2
  279. # bit-twiddling functions
  280. def left_shift(x, y):
  281. """
  282. Element-wise `bitwise binary: x << y`.
  283. :param x: input tensor, should be int.
  284. :param y: how many bits to be left-shifted.
  285. :return: computed tensor.
  286. Examples:
  287. .. testcode::
  288. import numpy as np
  289. from megengine import tensor
  290. import megengine.functional as F
  291. x = tensor(np.arange(0, 6, dtype=np.int32).reshape(2, 3))
  292. out = F.left_shift(x, 2)
  293. print(out.numpy())
  294. Outputs:
  295. .. testoutput::
  296. [[ 0 4 8]
  297. [12 16 20]]
  298. """
  299. return _elwise(x, y, mode=Elemwise.Mode.SHL)
  300. def right_shift(x, y):
  301. """Element-wise `bitwise binary: x >> y`."""
  302. return _elwise(x, y, mode=Elemwise.Mode.SHR)
  303. # logical functions
  304. def logical_and(x, y):
  305. """Element-wise `logical and: x && y`."""
  306. return _elwise(x, y, mode=Elemwise.Mode.AND)
  307. def logical_not(x):
  308. """Element-wise `logical not: ~x`."""
  309. return _elwise(x, mode=Elemwise.Mode.NOT)
  310. def logical_or(x, y):
  311. """Element-wise `logical or: x || y`."""
  312. return _elwise(x, y, mode=Elemwise.Mode.OR)
  313. def logical_xor(x, y):
  314. """Element-wise `logical xor: x ^ y`."""
  315. return _elwise(x, y, mode=Elemwise.Mode.XOR)
  316. # comparison functions
  317. def equal(x, y):
  318. """
  319. Element-wise `(x == y)`.
  320. :param x: input tensor 1.
  321. :param y: input tensor 2.
  322. :return: computed tensor.
  323. Examples:
  324. .. testcode::
  325. import numpy as np
  326. from megengine import tensor
  327. import megengine.functional as F
  328. x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  329. y = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  330. out = F.equal(x, y)
  331. print(out.numpy())
  332. Outputs:
  333. .. testoutput::
  334. [[1. 1. 1.]
  335. [1. 1. 1.]]
  336. """
  337. return _elwise(x, y, mode=Elemwise.Mode.EQ)
  338. def not_equal(x, y):
  339. """Element-wise `(x != y)`."""
  340. return x != y
  341. def less(x, y):
  342. """Element-wise `(x < y)`."""
  343. return _elwise(x, y, mode=Elemwise.Mode.LT)
  344. def less_equal(x, y):
  345. """Element-wise `(x <= y)`."""
  346. return _elwise(x, y, mode=Elemwise.Mode.LEQ)
  347. def greater(x, y):
  348. """Element-wise `(x > y)`."""
  349. return _elwise(y, x, mode=Elemwise.Mode.LT)
  350. def greater_equal(x, y):
  351. """Element-wise `(x >= y)`."""
  352. return _elwise(y, x, mode=Elemwise.Mode.LEQ)
  353. # other functions
  354. def clip(x: Tensor, lower=None, upper=None) -> Tensor:
  355. r"""
  356. Clamps all elements in input tensor into the range `[` :attr:`lower`, :attr:`upper` `]` and returns
  357. a resulting tensor:
  358. .. math::
  359. y_i = \begin{cases}
  360. \text{lower} & \text{if } x_i < \text{lower} \\
  361. x_i & \text{if } \text{lower} \leq x_i \leq \text{upper} \\
  362. \text{upper} & \text{if } x_i > \text{upper}
  363. \end{cases}
  364. :param x: input tensor.
  365. :param lower: lower-bound of the range to be clamped to.
  366. :param upper: upper-bound of the range to be clamped to.
  367. :return: output clamped tensor.
  368. Examples:
  369. .. testcode::
  370. import numpy as np
  371. from megengine import tensor
  372. import megengine.functional as F
  373. a = tensor(np.arange(5).astype(np.int32))
  374. print(F.clip(a, 2, 4).numpy())
  375. print(F.clip(a, lower=3).numpy())
  376. print(F.clip(a, upper=3).numpy())
  377. Outputs:
  378. .. testoutput::
  379. [2 2 2 3 4]
  380. [3 3 3 3 4]
  381. [0 1 2 3 3]
  382. """
  383. assert (
  384. lower is not None or upper is not None
  385. ), "At least one of 'lower' or 'upper' must not be None"
  386. if lower is not None:
  387. if upper is not None:
  388. if not is_tracing():
  389. assert lower <= upper, "clip lower bound is bigger that upper bound"
  390. return minimum(maximum(x, lower), upper)
  391. else:
  392. return maximum(x, lower)
  393. else:
  394. return minimum(x, upper)
# Deprecated module-level aliases, each built by ``deprecated_func`` with the
# version string "1.3" and the module path "megengine.functional.nn".
# NOTE(review): presumably each alias forwards to the same-named function in
# ``megengine.functional.nn`` and emits a deprecation warning; the meaning of
# the trailing ``True`` flag is not visible here. Confirm against
# ``utils.deprecation``.
# None of these names is listed in ``__all__``.
sigmoid = deprecated_func("1.3", "megengine.functional.nn", "sigmoid", True)
hsigmoid = deprecated_func("1.3", "megengine.functional.nn", "hsigmoid", True)
relu = deprecated_func("1.3", "megengine.functional.nn", "relu", True)
relu6 = deprecated_func("1.3", "megengine.functional.nn", "relu6", True)
hswish = deprecated_func("1.3", "megengine.functional.nn", "hswish", True)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台