You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

test_elemwise.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. # -*- coding: utf-8 -*-
  2. import numpy as np
  3. import pytest
  4. import megengine.autodiff as ad
  5. import megengine.functional as F
  6. import megengine.functional.elemwise as elemwise
  7. from megengine import tensor
  8. from megengine.core.tensor import dtype
  9. from megengine.functional.elemwise import Elemwise
  10. from megengine.jit import trace
  11. def test_abs():
  12. np.testing.assert_allclose(
  13. F.abs(tensor([-3.0, -4.0, -5.0])).numpy(),
  14. np.abs(np.array([-3.0, -4.0, -5.0], dtype=np.float32)),
  15. )
  16. np.testing.assert_allclose(F.abs(-3.0).numpy(), np.abs(np.float32(-3.0)))
  17. def test_elemwise_mode_string():
  18. for key, mode in vars(Elemwise.Mode).items():
  19. if isinstance(mode, Elemwise.Mode):
  20. assert key == mode
  21. assert Elemwise(mode=key) == Elemwise(mode=mode)
  22. def test_multiply():
  23. np.testing.assert_allclose(
  24. F.mul(-3.0, -4.0).numpy(), np.multiply(np.float32(-3.0), np.float32(-4.0))
  25. )
  26. np.testing.assert_allclose(
  27. F.mul(tensor([3.0, 4.0]), 4.0).numpy(),
  28. np.multiply(np.array([3.0, 4.0], dtype=np.float32), 4.0),
  29. )
  30. np.testing.assert_allclose(
  31. F.mul(4.0, tensor([3.0, 4.0])).numpy(),
  32. np.multiply(4.0, np.array([3.0, 4.0], dtype=np.float32)),
  33. )
  34. np.testing.assert_allclose(
  35. F.mul(tensor([3.0, 4.0]), tensor([3.0, 4.0])).numpy(),
  36. np.multiply(
  37. np.array([3.0, 4.0], dtype=np.float32),
  38. np.array([3.0, 4.0], dtype=np.float32),
  39. ),
  40. )
  41. def test_div():
  42. np.testing.assert_allclose(
  43. F.div(tensor([3.0, 4.0]), 2).numpy(),
  44. np.divide(np.array([3, 4], dtype=np.float32), 2),
  45. )
  46. np.testing.assert_allclose(
  47. (tensor([3, 4]) / 2).numpy(), np.divide(np.array([3, 4], dtype=np.float32), 2),
  48. )
  49. np.testing.assert_allclose(
  50. F.floor_div(tensor([-5.0, -7.0]), 2).numpy(),
  51. np.floor_divide(np.array([-5.0, -7.0], dtype=np.float32), 2),
  52. )
  53. np.testing.assert_allclose(
  54. (tensor([-5, -7]) // 2).numpy(),
  55. np.floor_divide(np.array([-5, -7], dtype=np.int32), 2),
  56. )
  57. np.testing.assert_allclose(
  58. (tensor([[5, 4, 3], [4, 2, 6]]) // [1, 2, 1]).numpy(),
  59. np.floor_divide(np.array([[5, 4, 3], [4, 2, 6]], dtype=np.int32), [1, 2, 1]),
  60. )
  61. def test_clamp():
  62. """Fix an issue when `lower` or `upper` is 0, it will be recognized as `False` and
  63. `F.clip` will fall into wrong conditions unexpectedly.
  64. """
  65. x = np.linspace(-6, 6, dtype="float32")
  66. np.testing.assert_allclose(
  67. F.clip(tensor(x) + 3, 0, 6).numpy(), np.clip(x + 3, 0, 6)
  68. )
  69. np.testing.assert_allclose(
  70. F.clip(tensor(x) - 3, -6, 0).numpy(), np.clip(x - 3, -6, 0)
  71. )
  72. def test_isnan():
  73. for case in [[1, float("nan"), 0]]:
  74. np.testing.assert_allclose(F.isnan(tensor(case)).numpy(), np.isnan(case))
  75. def test_isinf():
  76. for case in [[1, float("inf"), 0]]:
  77. np.testing.assert_allclose(F.isinf(tensor(case)).numpy(), np.isinf(case))
  78. def test_sign():
  79. for case in [[1, -1, 0]]:
  80. x = tensor(case)
  81. np.testing.assert_allclose(F.sign(x).numpy(), np.sign(case).astype(x.dtype))
  82. def test_cosh():
  83. np.random.seed(42)
  84. x = np.random.randn(100).astype("float32")
  85. y_np = np.cosh(x)
  86. y_mge = F.cosh(tensor(x)).numpy()
  87. np.testing.assert_allclose(y_np, y_mge, rtol=1e-5)
  88. def test_sinh():
  89. np.random.seed(42)
  90. x = np.random.randn(100).astype("float32")
  91. y_np = np.sinh(x)
  92. y_mge = F.sinh(tensor(x)).numpy()
  93. np.testing.assert_allclose(y_np, y_mge, rtol=1e-5)
  94. def test_asinh():
  95. np.random.seed(42)
  96. x = np.random.randn(100).astype("float32")
  97. y_np = np.arcsinh(x)
  98. y_mge = F.asinh(tensor(x)).numpy()
  99. np.testing.assert_almost_equal(y_np, y_mge, decimal=5)
  100. def test_acosh():
  101. x = np.arange(0, 10000).astype("float32") / 100 + 1
  102. y_np = np.arccosh(x)
  103. y_mge = F.acosh(tensor(x)).numpy()
  104. np.testing.assert_almost_equal(y_np, y_mge, decimal=6)
  105. def test_atanh():
  106. np.random.seed(42)
  107. x = np.random.rand(100).astype("float32") * 2 - 1
  108. y_np = np.arctanh(x)
  109. y_mge = F.atanh(tensor(x)).numpy()
  110. np.testing.assert_almost_equal(y_np, y_mge, decimal=5)
  111. def test_hswish():
  112. np.random.seed(42)
  113. x = np.random.randn(100).astype("float32")
  114. y_np = x * np.minimum(np.maximum(x + 3, 0), 6) / 6
  115. y_mge = F.hswish(tensor(x)).numpy()
  116. np.testing.assert_almost_equal(y_np, y_mge, decimal=6)
  117. def test_silu():
  118. x = np.array([-1.5, 0.0, 1.0, 1.5]).astype("float32")
  119. y_np = x / (1 + np.exp(-x))
  120. y_mge = F.silu(tensor(x)).numpy()
  121. np.testing.assert_almost_equal(y_np, y_mge, decimal=6)
  122. def test_hsigmoid():
  123. np.random.seed(42)
  124. x = np.random.randn(100).astype("float32")
  125. y_np = np.minimum(np.maximum(x + 3, 0), 6) / 6
  126. y_mge = F.hsigmoid(tensor(x)).numpy()
  127. np.testing.assert_almost_equal(y_np, y_mge, decimal=6)
  128. def test_logical_oprs():
  129. x = np.array([[True, False], [False, True]])
  130. y = np.array([[True, True], [False, False]])
  131. xx = tensor(x)
  132. yy = tensor(y)
  133. np.testing.assert_equal(~x, (F.logical_not(xx)).numpy())
  134. np.testing.assert_equal(x & y, F.logical_and(xx, yy).numpy())
  135. np.testing.assert_equal(x | y, F.logical_or(xx, yy).numpy())
  136. np.testing.assert_equal(x ^ y, F.logical_xor(xx, yy).numpy())
  137. def test_logaddexp():
  138. x = np.random.randn(2, 100)
  139. y = np.random.randn(2, 100)
  140. xx = tensor(x)
  141. yy = tensor(y)
  142. out_np = np.log(np.exp(x) + np.exp(y))
  143. out_mge = F.logaddexp(xx, yy)
  144. np.testing.assert_almost_equal(out_np, out_mge.numpy(), decimal=6)
  145. def test_qadd():
  146. inp_scale = 0.5
  147. outp_scale = 0.2
  148. x = np.arange(6).reshape(2, 3).astype("float32")
  149. y = np.arange(6).reshape(2, 3).astype("float32")
  150. x = tensor(x, dtype=dtype.qint8(inp_scale))
  151. y = tensor(y, dtype=dtype.qint8(inp_scale))
  152. result_mge = F.elemwise._elemwise_multi_type(
  153. x, y, mode="qadd", dtype=dtype.qint8(outp_scale)
  154. )
  155. result_mge = result_mge.astype("float32").numpy()
  156. result_expect = x.astype("float32").numpy() + y.astype("float32").numpy()
  157. np.testing.assert_almost_equal(result_mge, result_expect, decimal=6)
  158. def test_int32_input():
  159. x = tensor(np.array([1, 2, 3, 4, 5]), dtype="int32")
  160. for op_name in elemwise.__all__:
  161. op = getattr(elemwise, op_name)
  162. nargs = op.__code__.co_argcount
  163. if op_name == "clip":
  164. inp = (x, 0, 1)
  165. elif op_name.endswith("_shift"):
  166. inp = (x, 1)
  167. elif op_name.startswith("logical_"):
  168. continue
  169. else:
  170. inp = (x,) * nargs
  171. y = op(*inp)
  172. y.numpy()
@pytest.mark.parametrize("is_trace", [True, False])
def test_empty_tensor(is_trace):
    """Every public elemwise op must handle empty tensors (0-size shapes)
    without crashing, both eagerly and under trace (symbolic and imperative).

    Ops are bucketed into unary/binary by arity, with special wrappers for
    ops that need non-float arguments.
    """
    binary_func = []
    unary_func = []
    for op_name in elemwise.__all__:
        op = getattr(elemwise, op_name)
        nargs = op.__code__.co_argcount
        if op_name == "clip":
            # clip needs bounds; bind them so it can be driven as unary
            unary_func.append(["clip", lambda x, f=op: f(x, lower=0, upper=1)])
        elif op_name.endswith("_shift"):
            # shift ops need an integer tensor plus a scalar shift amount
            # (f=op binds the current op at definition time — avoids the
            # late-binding closure pitfall)
            unary_func.append(
                [op_name, lambda x, f=op: f(tensor(x.numpy(), dtype="int32"), 1)]
            )
        elif op_name.startswith("logical_"):  # logical_xxx op only accept boolean type
            if nargs == 1:
                unary_func.append(
                    [op_name, lambda x, f=op: f(tensor(x.numpy(), dtype="bool"))]
                )
            else:
                assert nargs == 2
                binary_func.append(
                    [
                        op_name,
                        lambda x, y, f=op: f(
                            tensor(x.numpy(), dtype="bool"),
                            tensor(y.numpy(), dtype="bool"),
                        ),
                    ]
                )
        elif nargs == 1:
            unary_func.append([op_name, op])
        elif nargs == 2:
            binary_func.append([op_name, op])
        else:
            raise NotImplementedError("nargs {}".format(nargs))

    def run_test(func, args, ref_shape, is_trace, sym=False):
        # Run one op on the given args and check the output shape matches
        # numpy broadcasting; traced funcs are run 3x to exercise the
        # compiled path after the first (tracing) call.
        args = [tensor(t, dtype="float32") for t in args]
        if is_trace:
            func = trace(symbolic=sym)(func)
            for _ in range(3):
                out = func(*args)
                assert out.numpy().shape == ref_shape
        else:
            out = func(*args)
            assert out.numpy().shape == ref_shape, out.numpy().shape

    # 1-d empty, 3-d empty (0-size middle axis), and a plain scalar
    inps = [
        np.array([]).astype("float32"),
        np.random.randn(2, 0, 3).astype("float32"),
        123,
    ]
    for op_name, op in unary_func:
        if is_trace:
            for sym in [True, False]:
                run_test(op, [inps[0],], inps[0].shape, True, sym)
                run_test(op, [inps[1],], inps[1].shape, True, sym)
        else:
            run_test(op, [inps[0],], inps[0].shape, False)
            run_test(op, [inps[1],], inps[1].shape, False)
    for op_name, op in binary_func:
        # expected shape comes from numpy broadcasting of the same operands
        if is_trace:
            for sym in [True, False]:
                run_test(op, [inps[0], inps[0]], (inps[0] + inps[0]).shape, True, sym)
                run_test(op, [inps[1], inps[1]], (inps[1] + inps[1]).shape, True, sym)
                run_test(op, [inps[0], inps[2]], (inps[0] + inps[2]).shape, True, sym)
                run_test(op, [inps[1], inps[2]], (inps[1] + inps[2]).shape, True, sym)
        else:
            run_test(op, [inps[0], inps[0]], (inps[0] + inps[0]).shape, False)
            run_test(op, [inps[1], inps[1]], (inps[1] + inps[1]).shape, False)
            run_test(op, [inps[0], inps[2]], (inps[0] + inps[2]).shape, False)
            run_test(op, [inps[1], inps[2]], (inps[1] + inps[2]).shape, False)
  243. @pytest.mark.parametrize("is_trace", [True, False])
  244. def test_maximum_grad_consistency(is_trace):
  245. def f(x):
  246. with ad.GradManager() as gm:
  247. gm.attach(x)
  248. gm.backward(F.maximum(x, x))
  249. dx = x.grad
  250. x.grad = None
  251. return dx
  252. def run(f):
  253. x = F.arange(10)
  254. for i in range(3):
  255. np.testing.assert_equal(f(x).numpy(), np.ones(10))
  256. if is_trace:
  257. for symbolic in [False, True]:
  258. run(trace(symbolic=symbolic)(f))
  259. else:
  260. run(f)