You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

elemwise.py 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order
  10. import numpy as np
  11. from ..core._imperative_rt.core2 import SymbolVar, apply
  12. from ..core.ops import builtin
  13. from ..core.ops.builtin import Elemwise
  14. from ..core.tensor.array_method import _elwise
  15. from ..core.tensor.utils import convert_inputs
  16. from ..tensor import Tensor
  17. from ..utils.deprecation import deprecated_func
  18. from .tensor_cache import get_scalar_one
  19. __all__ = [
  20. "abs",
  21. "add",
  22. "acos",
  23. "asin",
  24. "atan",
  25. "atan2",
  26. "asinh",
  27. "acosh",
  28. "atanh",
  29. "ceil",
  30. "clip",
  31. "cos",
  32. "cosh",
  33. "div",
  34. "equal",
  35. "exp",
  36. "expm1",
  37. "floor",
  38. "floor_div",
  39. "greater",
  40. "greater_equal",
  41. "left_shift",
  42. "less",
  43. "less_equal",
  44. "log",
  45. "log1p",
  46. "logical_and",
  47. "logical_not",
  48. "logical_or",
  49. "logical_xor",
  50. "logaddexp",
  51. "maximum",
  52. "minimum",
  53. "mod",
  54. "mul",
  55. "neg",
  56. "not_equal",
  57. "pow",
  58. "right_shift",
  59. "round",
  60. "sin",
  61. "sinh",
  62. "sqrt",
  63. "square",
  64. "sub",
  65. "tan",
  66. "tanh",
  67. ]
  68. def _elemwise_multi_type(*args, mode, **kwargs):
  69. op = builtin.ElemwiseMultiType(mode=mode, **kwargs)
  70. args = convert_inputs(*args)
  71. (result,) = apply(op, *args)
  72. return result
  73. # math operations
  74. def add(x, y):
  75. r"""Element-wise `addition`.
  76. Examples:
  77. .. testcode::
  78. import numpy as np
  79. from megengine import tensor
  80. import megengine.functional as F
  81. x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  82. y = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  83. out = F.add(x, y)
  84. print(out.numpy())
  85. Outputs:
  86. .. testoutput::
  87. [[ 0. 2. 4.]
  88. [ 6. 8. 10.]]
  89. """
  90. return _elwise(x, y, mode=Elemwise.Mode.ADD)
  91. def sub(x, y):
  92. r"""Element-wise `sub`.
  93. Examples:
  94. .. testcode::
  95. import numpy as np
  96. from megengine import tensor
  97. import megengine.functional as F
  98. x = tensor(np.arange(1, 7, dtype=np.float32).reshape(2, 3))
  99. y = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  100. out = F.sub(x, y)
  101. print(out.numpy())
  102. Outputs:
  103. .. testoutput::
  104. [[1. 1. 1.]
  105. [1. 1. 1.]]
  106. """
  107. return _elwise(x, y, mode=Elemwise.Mode.SUB)
  108. def mul(x: Tensor, y: Tensor) -> Tensor:
  109. r"""Calculates the product for each element :math:`x_i` of the input tensor `x` with the respective element :math:`y_i` of the input tensor `y`.
  110. Note:
  111. * If either :math:`x_i` or :math:`y_i` is `NaN`, the result is `NaN`.
  112. * If :math:`x_i` is either `+infinity` or `-infinity` and :math:`y_i` is either `+0` or `-0`, the result is `NaN`.
  113. * If :math:`x_i` is either `+0` or `-0` and :math:`y_i` is either `+infinity` or `-infinity`, the result is `NaN`.
  114. * If :math:`x_i` and :math:`y_i` have different mathematical signs, the result has a negative mathematical sign, unless the result is `NaN`.
  115. * If :math:`x_i` is either `+infinity` or `-infinity` and :math:`y_i` is either `+infinity` or `-infinity`,
  116. the result is a signed infinity with the mathematical sign determined by the rule already stated above.
  117. * If :math:`x_i` is either `+infinity` or `-infinity` and :math:`y_i` is a nonzero finite number,
  118. the result is a signed infinity with the mathematical sign determined by the rule already stated above.
  119. * If :math:`x_i` is a nonzero finite number and :math:`y_i` is either `+infinity` or `-infinity`,
  120. the result is a signed infinity with the mathematical sign determined by the rule already stated above.
  121. * In the remaining cases, where neither `infinity` nor `NaN` is involved,
  122. the product must be computed and rounded to the nearest representable value according to IEEE 754-2019 and a supported rounding mode.
  123. If the magnitude is too large to represent, the result is an `infinity` of appropriate mathematical sign.
  124. If the magnitude is too small to represent, the result is a zero of appropriate mathematical sign.
  125. * Floating-point multiplication is not always associative due to finite precision.
  126. Args:
  127. x: first input tensor. Should have a numeric data type.
  128. y: second input tensor. Must be compatible with `x` (see :ref:`broadcasting-rule` ). Should have a numeric data type.
  129. Returns:
  130. A tensor containing the element-wise products. The returned array must have a data type determined by :ref:`dtype-promotion`.
  131. Examples:
  132. >>> F.mul(2, 3)
  133. Tensor(6, dtype=int32, device=xpux:0)
  134. >>> F.mul(2.0, 3.0)
  135. Tensor(6.0, device=xpux:0)
  136. >>> x = F.arange(6.0).reshape(2, 3)
  137. >>> y = F.arange(3.0)
  138. >>> F.mul(x, y)
  139. Tensor([[ 0. 1. 4.]
  140. [ 0. 4. 10.]], device=xpux:0)
  141. The `*` operator can be used as a shorthand for :func:`~.functional.mul` on tensors.
  142. >>> x = F.arange(6.0).reshape(2, 3)
  143. >>> y = F.arange(3.0)
  144. >>> x * y
  145. Tensor([[ 0. 1. 4.]
  146. [ 0. 4. 10.]], device=xpux:0)
  147. """
  148. return _elwise(x, y, mode=Elemwise.Mode.MUL)
  149. def div(x, y):
  150. r"""Element-wise `(x / y)`."""
  151. return _elwise(x, y, mode=Elemwise.Mode.TRUE_DIV)
  152. def floor_div(x, y):
  153. r"""Element-wise `floor(x / y)`."""
  154. return _elwise(x, y, mode=Elemwise.Mode.FLOOR_DIV)
  155. def neg(x):
  156. r"""Element-wise `negation`."""
  157. return _elwise(x, mode=Elemwise.Mode.NEGATE)
  158. def pow(x, y):
  159. r"""Element-wise `power`."""
  160. return _elwise(x, y, mode=Elemwise.Mode.POW)
  161. def mod(x, y):
  162. r"""Element-wise `remainder of division`."""
  163. return _elwise(x, y, mode=Elemwise.Mode.MOD)
  164. def abs(x):
  165. r"""Element-wise `absolute value`."""
  166. return _elwise(x, mode=Elemwise.Mode.ABS)
  167. def exp(x):
  168. r"""Element-wise `exponential`."""
  169. return _elwise(x, mode=Elemwise.Mode.EXP)
  170. def expm1(x):
  171. r"""Element-wise `exp(x)-1`."""
  172. return _elwise(x, mode=Elemwise.Mode.EXPM1)
  173. def log(x):
  174. r"""Element-wise `logarithm (base e)`."""
  175. return _elwise(x, mode=Elemwise.Mode.LOG)
  176. def log1p(x):
  177. r"""Element-wise `log(x+1) (base e)`."""
  178. return _elwise(x, mode=Elemwise.Mode.LOG1P)
  179. def sqrt(x: Tensor) -> Tensor:
  180. r"""Element-wise `sqrt`.
  181. Examples:
  182. .. testcode::
  183. import numpy as np
  184. from megengine import tensor
  185. import megengine.functional as F
  186. x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  187. out = F.sqrt(x)
  188. print(out.numpy().round(decimals=4))
  189. Outputs:
  190. .. testoutput::
  191. [[0. 1. 1.4142]
  192. [1.7321 2. 2.2361]]
  193. """
  194. return x ** 0.5
  195. def square(x: Tensor) -> Tensor:
  196. r"""Element-wise `square`.
  197. Examples:
  198. .. testcode::
  199. import numpy as np
  200. import megengine as mge
  201. import megengine.functional as F
  202. data = mge.tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  203. out = F.square(data)
  204. print(out.numpy().round(decimals=4))
  205. Outputs:
  206. .. testoutput::
  207. [[ 0. 1. 4.]
  208. [ 9. 16. 25.]]
  209. """
  210. return x ** 2
  211. def round(x):
  212. r"""Element-wise `rounding to int`."""
  213. return _elwise(x, mode=Elemwise.Mode.ROUND)
  214. def ceil(x):
  215. r"""Element-wise `ceiling`."""
  216. return _elwise(x, mode=Elemwise.Mode.CEIL)
  217. def floor(x):
  218. r"""Element-wise `floor`."""
  219. return _elwise(x, mode=Elemwise.Mode.FLOOR)
  220. def maximum(x, y):
  221. r"""Element-wise `maximum of array elements`."""
  222. return _elwise(x, y, mode=Elemwise.Mode.MAX)
  223. def minimum(x, y):
  224. r"""Element-wise `minimum of array elements`."""
  225. return _elwise(x, y, mode=Elemwise.Mode.MIN)
  226. # trigonometric functions
  227. def cos(x):
  228. r"""Element-wise `cosine`.
  229. Examples:
  230. .. testcode::
  231. import numpy as np
  232. from megengine import tensor
  233. import megengine.functional as F
  234. x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  235. out = F.cos(x)
  236. print(out.numpy().round(decimals=4))
  237. Outputs:
  238. .. testoutput::
  239. [[ 1. 0.5403 -0.4161]
  240. [-0.99 -0.6536 0.2837]]
  241. """
  242. return _elwise(x, mode=Elemwise.Mode.COS)
  243. def sin(x):
  244. r"""Element-wise `sine`."""
  245. return _elwise(x, mode=Elemwise.Mode.SIN)
  246. def tan(x):
  247. r"""Element-wise `tangent`."""
  248. return sin(x) / cos(x)
  249. def acos(x):
  250. r"""Element-wise `inverse cosine`."""
  251. return _elwise(x, mode=Elemwise.Mode.ACOS)
  252. def asin(x):
  253. r"""Element-wise `inverse sine`."""
  254. return _elwise(x, mode=Elemwise.Mode.ASIN)
  255. def atan(x):
  256. r"""Element-wise `inverse tangent`."""
  257. return _elwise(
  258. x,
  259. get_scalar_one("float32", x.device if isinstance(x, Tensor) else None),
  260. mode=Elemwise.Mode.ATAN2,
  261. )
  262. def atan2(y, x):
  263. r"""Element-wise `2-argument arctangent`."""
  264. return _elwise(y, x, mode=Elemwise.Mode.ATAN2)
  265. def cosh(x):
  266. r"""Element-wise `hyperbolic cosine`."""
  267. return 0.5 * (exp(x) + exp(-x))
  268. def sinh(x):
  269. r"""Element-wise `hyperbolic sine`."""
  270. u = expm1(x)
  271. return 0.5 * u / (u + 1) * (u + 2)
  272. def tanh(x):
  273. r"""Element-wise `hyperbolic tangent`."""
  274. return _elwise(x, mode=Elemwise.Mode.TANH)
  275. def asinh(x):
  276. r"""Element-wise `inverse hyperbolic sine`."""
  277. return log(x + (x ** 2 + 1) ** 0.5)
  278. def acosh(x):
  279. r"""Element-wise `inverse hyperbolic cosine`."""
  280. return log(x + (x ** 2 - 1) ** 0.5)
  281. def atanh(x):
  282. r"""Element-wise `inverse hyperbolic tangent`."""
  283. return log1p(2 * x / (1 - x)) / 2
  284. # bit-twiddling functions
  285. def left_shift(x, y):
  286. r"""Element-wise `bitwise binary: x << y`.
  287. Examples:
  288. .. testcode::
  289. import numpy as np
  290. from megengine import tensor
  291. import megengine.functional as F
  292. x = tensor(np.arange(0, 6, dtype=np.int32).reshape(2, 3))
  293. out = F.left_shift(x, 2)
  294. print(out.numpy())
  295. Outputs:
  296. .. testoutput::
  297. [[ 0 4 8]
  298. [12 16 20]]
  299. """
  300. return _elwise(x, y, mode=Elemwise.Mode.SHL)
  301. def right_shift(x, y):
  302. r"""Element-wise `bitwise binary: x >> y`."""
  303. return _elwise(x, y, mode=Elemwise.Mode.SHR)
  304. # logical functions
  305. def logical_and(x, y):
  306. r"""Element-wise `logical and: x && y`."""
  307. return _elwise(x, y, mode=Elemwise.Mode.AND)
  308. def logical_not(x):
  309. r"""Element-wise `logical not: ~x`."""
  310. return _elwise(x, mode=Elemwise.Mode.NOT)
  311. def logical_or(x, y):
  312. r"""Element-wise `logical or: x || y`."""
  313. return _elwise(x, y, mode=Elemwise.Mode.OR)
  314. def logical_xor(x, y):
  315. r"""Element-wise `logical xor: x ^ y`."""
  316. return _elwise(x, y, mode=Elemwise.Mode.XOR)
  317. def logaddexp(x: Tensor, y: Tensor) -> Tensor:
  318. r"""Element-wise `numerically stable log(exp(x) + exp(y)`
  319. """
  320. return _elwise(x, y, mode=Elemwise.Mode.LOG_SUM_EXP)
  321. # comparison functions
  322. def equal(x, y):
  323. r"""Element-wise `(x == y)`.
  324. Examples:
  325. .. testcode::
  326. import numpy as np
  327. from megengine import tensor
  328. import megengine.functional as F
  329. x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  330. y = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
  331. out = F.equal(x, y)
  332. print(out.numpy())
  333. Outputs:
  334. .. testoutput::
  335. [[1. 1. 1.]
  336. [1. 1. 1.]]
  337. """
  338. return _elwise(x, y, mode=Elemwise.Mode.EQ)
  339. def not_equal(x, y):
  340. r"""Element-wise `(x != y)`."""
  341. return x != y
  342. def less(x, y):
  343. r"""Element-wise `(x < y)`."""
  344. return _elwise(x, y, mode=Elemwise.Mode.LT)
  345. def less_equal(x, y):
  346. r"""Element-wise `(x <= y)`."""
  347. return _elwise(x, y, mode=Elemwise.Mode.LEQ)
  348. def greater(x, y):
  349. r"""Element-wise `(x > y)`."""
  350. return _elwise(y, x, mode=Elemwise.Mode.LT)
  351. def greater_equal(x, y):
  352. r"""Element-wise `(x >= y)`."""
  353. return _elwise(y, x, mode=Elemwise.Mode.LEQ)
  354. # other functions
  355. def clip(x: Tensor, lower=None, upper=None) -> Tensor:
  356. r"""Clamps all elements in input tensor into the range ``[ lower, upper ]`` and returns
  357. a resulting tensor:
  358. .. math::
  359. y_i = \begin{cases}
  360. \text{lower} & \text{if } x_i < \text{lower} \\
  361. x_i & \text{if } \text{lower} \leq x_i \leq \text{upper} \\
  362. \text{upper} & \text{if } x_i > \text{upper}
  363. \end{cases}
  364. Args:
  365. x: (Tensor): The input tensor.
  366. lower: (Numberic,optional): lower-bound of the range to be clamped to.
  367. upper: (Numberic,optional): upper-bound of the range to be clamped to.
  368. Note:
  369. * If both `lower` and `upper` are None, raises an AssertionError.
  370. * If `lower` is bigger than `upper`, the result is same as `clip(Tensor(), upper, upper)`.
  371. Returns:
  372. output clamped tensor. The result must have a data type determined by :ref:`dtype-promotion`.
  373. Examples:
  374. >>> import numpy as np
  375. >>> x = Tensor([0,1,2,3,4])
  376. >>> F.clip(x, 2, 4)
  377. Tensor([2 2 2 3 4], dtype=int32, device=xpux:0)
  378. >>> x = Tensor([0,1,2,3,4])
  379. >>> F.clip(x, 4, 3)
  380. Tensor([3 3 3 3 3], dtype=int32, device=xpux:0)
  381. >>> x = F.arange(5)
  382. >>> F.clip(x, lower=3)
  383. Tensor([3. 3. 3. 3. 4.], device=xpux:0)
  384. >>> x = F.arange(5, dtype=np.int32)
  385. >>> F.clip(x, upper=2.1)
  386. Tensor([0. 1. 2. 2.1 2.1], device=xpux:0)
  387. """
  388. assert (
  389. lower is not None or upper is not None
  390. ), "At least one of 'lower' or 'upper' must not be None"
  391. if lower is not None:
  392. if upper is not None:
  393. return minimum(maximum(x, lower), upper)
  394. else:
  395. return maximum(x, lower)
  396. else:
  397. return minimum(x, upper)
  398. sigmoid = deprecated_func("1.3", "megengine.functional.nn", "sigmoid", True)
  399. hsigmoid = deprecated_func("1.3", "megengine.functional.nn", "hsigmoid", True)
  400. relu = deprecated_func("1.3", "megengine.functional.nn", "relu", True)
  401. relu6 = deprecated_func("1.3", "megengine.functional.nn", "relu6", True)
  402. hswish = deprecated_func("1.3", "megengine.functional.nn", "hswish", True)