You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

test_math.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. # -*- coding: utf-8 -*-
  2. from functools import partial
  3. import numpy as np
  4. import pytest
  5. from utils import opr_test
  6. import megengine.functional as F
  7. from megengine import Tensor, jit, tensor
  8. from megengine.core._imperative_rt.core2 import apply
  9. from megengine.core.ops import builtin
  10. def common_test_reduce(opr, ref_opr):
  11. data1_shape = (5, 6, 7)
  12. data2_shape = (2, 9, 12)
  13. data1 = np.random.random(data1_shape).astype(np.float32)
  14. data2 = np.random.random(data2_shape).astype(np.float32)
  15. cases = [
  16. {"input": data1},
  17. {"input": data2},
  18. {"input": np.array([[[1, 2, np.nan, 4], [8, 6, 5, 2], [2, 3, 4, 5]]])},
  19. ]
  20. if opr not in (F.argmin, F.argmax):
  21. # test default axis
  22. opr_test(cases, opr, ref_fn=ref_opr)
  23. # test all axises in range of input shape
  24. for axis in range(-3, 3):
  25. # test keepdims False
  26. opr_test(cases, opr, ref_fn=lambda x: ref_opr(x, axis=axis), axis=axis)
  27. # test keepdims True
  28. opr_test(
  29. cases,
  30. opr,
  31. ref_fn=lambda x: ref_opr(x, axis=axis, keepdims=True),
  32. axis=axis,
  33. keepdims=True,
  34. )
  35. else:
  36. # test defaut axis
  37. opr_test(cases, opr, ref_fn=lambda x: ref_opr(x).astype(np.int32))
  38. # test all axises in range of input shape
  39. for axis in range(0, 3):
  40. opr_test(
  41. cases,
  42. opr,
  43. ref_fn=lambda x: ref_opr(x, axis=axis).astype(np.int32),
  44. axis=axis,
  45. )
  46. # test negative axis
  47. axis = axis - len(data1_shape)
  48. opr_test(
  49. cases,
  50. opr,
  51. ref_fn=lambda x: ref_opr(x, axis=axis).astype(np.int32),
  52. axis=axis,
  53. )
  54. def test_sum():
  55. common_test_reduce(opr=F.sum, ref_opr=np.sum)
  56. x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  57. y = F.sum(x, axis=-1)
  58. np.testing.assert_equal(y.numpy(), np.array([6, 15]).astype(np.int32))
  59. def test_prod():
  60. common_test_reduce(opr=F.prod, ref_opr=np.prod)
  61. x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  62. y = F.prod(x, axis=-2)
  63. np.testing.assert_equal(y.numpy(), np.array([4, 10, 18]).astype(np.int32))
  64. def test_mean():
  65. common_test_reduce(opr=F.mean, ref_opr=np.mean)
  66. x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  67. y = F.mean(x, axis=-2)
  68. np.testing.assert_equal(y.numpy(), np.array([2.5, 3.5, 4.5]).astype(np.float32))
  69. def test_var():
  70. common_test_reduce(opr=F.var, ref_opr=np.var)
  71. x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  72. y = F.var(x, axis=-2)
  73. np.testing.assert_equal(y.numpy(), np.array([2.25, 2.25, 2.25]).astype(np.float32))
  74. def test_std():
  75. common_test_reduce(opr=F.std, ref_opr=np.std)
  76. x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  77. y = F.std(x, axis=-2)
  78. np.testing.assert_equal(y.numpy(), np.array([1.5, 1.5, 1.5]).astype(np.float32))
  79. x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  80. y = F.std(x, axis=-2)
  81. np.testing.assert_equal(y.numpy(), np.array([1.5, 1.5, 1.5]).astype(np.float32))
  82. def test_min():
  83. common_test_reduce(opr=F.min, ref_opr=np.min)
  84. x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  85. y = F.min(x, axis=-1)
  86. np.testing.assert_equal(y.numpy(), np.array([1, 4]).astype(np.int32))
  87. def test_max():
  88. common_test_reduce(opr=F.max, ref_opr=np.max)
  89. x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  90. y = F.max(x, axis=-1)
  91. np.testing.assert_equal(y.numpy(), np.array([3, 6]).astype(np.int32))
  92. def test_argmin():
  93. common_test_reduce(opr=F.argmin, ref_opr=np.argmin)
  94. x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  95. y = F.argmin(x, axis=-1)
  96. np.testing.assert_equal(y.numpy(), np.array([0, 0]).astype(np.int32))
  97. def test_argmax():
  98. common_test_reduce(opr=F.argmax, ref_opr=np.argmax)
  99. x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  100. y = F.argmax(x, axis=-2)
  101. np.testing.assert_equal(y.numpy(), np.array([1, 1, 1]).astype(np.int32))
  102. def test_norm():
  103. x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  104. y = F.norm(x, axis=-1)
  105. np.testing.assert_equal(
  106. y.numpy().round(decimals=3), np.array([3.742, 8.775]).astype(np.float32)
  107. )
  108. def test_sqrt():
  109. d1_shape = (15,)
  110. d2_shape = (25,)
  111. d1 = np.random.random(d1_shape).astype(np.float32)
  112. d2 = np.random.random(d2_shape).astype(np.float32)
  113. cases = [{"input": d1}, {"input": d2}]
  114. opr_test(cases, F.sqrt, ref_fn=np.sqrt)
  115. def test_sort():
  116. data1_shape = (10, 3)
  117. data2_shape = (12, 2)
  118. data1 = np.random.random(data1_shape).astype(np.float32)
  119. data2 = np.random.random(data2_shape).astype(np.float32)
  120. output1 = [np.sort(data1), np.argsort(data1).astype(np.int32)]
  121. output2 = [np.sort(data2), np.argsort(data2).astype(np.int32)]
  122. cases = [
  123. {"input": data1, "output": output1},
  124. {"input": data2, "output": output2},
  125. ]
  126. opr_test(cases, F.sort)
  127. @pytest.mark.parametrize("is_symbolic", [None, False, True])
  128. def test_sort_empty(is_symbolic):
  129. data_shapes = [
  130. (0,),
  131. (10, 0),
  132. ]
  133. def fn(x):
  134. return F.sort(x)
  135. for shape in data_shapes:
  136. if is_symbolic is not None:
  137. fn_ = jit.trace(symbolic=is_symbolic)(fn)
  138. else:
  139. fn_ = fn
  140. data = np.random.random(shape).astype(np.float32)
  141. for _ in range(3):
  142. outs = fn_(Tensor(data))
  143. ref_outs = (np.sort(data), np.argsort(data))
  144. assert len(ref_outs) == len(outs)
  145. for i in range(len(outs)):
  146. np.testing.assert_equal(outs[i].numpy(), ref_outs[i])
  147. if is_symbolic is None:
  148. break
  149. def test_normalize():
  150. x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  151. y = F.normalize(x, axis=-1)
  152. np.testing.assert_equal(
  153. y.numpy().round(decimals=1),
  154. np.array([[0.3, 0.5, 0.8], [0.5, 0.6, 0.7]]).astype(np.float32),
  155. )
  156. cases = [
  157. {"input": np.random.random((2, 3, 12, 12)).astype(np.float32)} for i in range(2)
  158. ]
  159. def np_normalize(x, p=2, axis=None, eps=1e-12):
  160. if axis is None:
  161. norm = np.sum(x ** p) ** (1.0 / p)
  162. else:
  163. norm = np.sum(x ** p, axis=axis, keepdims=True) ** (1.0 / p)
  164. return x / np.clip(norm, a_min=eps, a_max=np.inf)
  165. # # Test L-2 norm along all dimensions
  166. # opr_test(cases, F.normalize, ref_fn=np_normalize)
  167. # # Test L-1 norm along all dimensions
  168. # opr_test(cases, partial(F.normalize, p=1), ref_fn=partial(np_normalize, p=1))
  169. # Test L-2 norm along the second dimension
  170. opr_test(cases, partial(F.normalize, axis=1), ref_fn=partial(np_normalize, axis=1))
  171. # Test some norm == 0
  172. cases[0]["input"][0, 0, 0, :] = 0
  173. cases[1]["input"][0, 0, 0, :] = 0
  174. opr_test(cases, partial(F.normalize, axis=3), ref_fn=partial(np_normalize, axis=3))
  175. def test_sum_neg_axis():
  176. shape = (2, 3)
  177. data = np.random.random(shape).astype(np.float32)
  178. for axis in (-1, -2, (-2, 1), (-1, 0)):
  179. get = F.sum(Tensor(data), axis=axis)
  180. ref = np.sum(data, axis=axis)
  181. np.testing.assert_allclose(get.numpy(), ref, rtol=1e-6)
  182. with pytest.raises(AssertionError):
  183. F.sum(Tensor(data), axis=(-1, 1))
  184. def test_builtin_reduce():
  185. shape = (2, 3, 3, 2)
  186. data = np.random.random(shape).astype(np.float32)
  187. for axis in (-1, -2, 0, 1):
  188. for keepdims in (True, False):
  189. op = builtin.Reduce(mode="sum", axis=axis, keepdim=keepdims)
  190. get = apply(op, tensor(data))[0]
  191. def_op = builtin.Reduce(mode="sum", axis=axis)
  192. def_get = apply(def_op, tensor(data))[0]
  193. ref = np.sum(data, axis=axis, keepdims=keepdims)
  194. np.testing.assert_allclose(get.numpy(), ref, rtol=1e-6)
  195. if keepdims == True:
  196. np.testing.assert_allclose(def_get.numpy(), ref, rtol=1e-6)
  197. def test_non_finite():
  198. shape = (32, 3, 32, 32)
  199. data = []
  200. for i in range(2):
  201. data.append(np.random.random(shape).astype(np.float32))
  202. tensorList = [Tensor(x) for x in data]
  203. rst = F.math._check_non_finite(tensorList, 0.7)
  204. np.testing.assert_equal(rst.numpy(), [0])
  205. for i in range(len(tensorList)):
  206. np.testing.assert_allclose(tensorList[i].numpy() / 0.7, data[i], rtol=1e-6)
  207. data[1][0][0][0][0] = float("inf")
  208. rst = F.math._check_non_finite([Tensor(x) for x in data], 0.7)
  209. np.testing.assert_equal(rst.numpy(), [1])
  210. data[1][0][0][0][0] = float("nan")
  211. rst = F.math._check_non_finite([Tensor(x) for x in data], 0.7)
  212. np.testing.assert_equal(rst.numpy(), [1])
  213. @pytest.mark.parametrize("descending", [True, False])
  214. @pytest.mark.parametrize("sorted", [True, False])
  215. @pytest.mark.parametrize("inp1d", [True, False])
  216. @pytest.mark.parametrize("kth_only", [True, False])
  217. def test_topk(descending, sorted, inp1d, kth_only):
  218. k = 3
  219. if inp1d:
  220. data = np.random.permutation(7)
  221. else:
  222. data = np.random.permutation(5 * 7).reshape(5, 7)
  223. data = data.astype(np.int32)
  224. def np_sort(x):
  225. if descending:
  226. return np.sort(x)[..., ::-1]
  227. return np.sort(x)
  228. res = F.topk(
  229. Tensor(data), k, descending=descending, no_sort=(not sorted), kth_only=kth_only
  230. )
  231. values, indices = res
  232. values = values.numpy()
  233. indices = indices.numpy()
  234. if kth_only:
  235. np.testing.assert_equal(
  236. values, np.take_along_axis(data, indices[..., None], -1).squeeze(-1)
  237. )
  238. np.testing.assert_equal(values, np_sort(data)[..., k - 1])
  239. else:
  240. np.testing.assert_equal(values, np.take_along_axis(data, indices, -1))
  241. if not sorted:
  242. values = np_sort(values)
  243. np.testing.assert_equal(values, np_sort(data)[..., :k])
  244. @pytest.mark.parametrize("is_trace", [True, False])
  245. def test_reduce_on_empty_tensor(is_trace):
  246. dtypes = [np.float32, np.int32, np.bool]
  247. inputs = [
  248. (np.random.random((0,)), None),
  249. (np.random.random((3, 0, 2)), 1),
  250. (np.random.random((10, 10, 0, 10)), 0),
  251. ]
  252. def run_test(fn, ref_fn, input, dtype, axis=None, symbolic=False):
  253. if is_trace:
  254. fn = jit.trace(symbolic=symbolic)(fn)
  255. for i in range(3):
  256. out = fn(Tensor(input, dtype=dtype), axis=axis).numpy()
  257. out_ref = ref_fn(input.astype(dtype), axis=axis)
  258. np.testing.assert_equal(out, out_ref)
  259. for dtype in dtypes:
  260. for inp, axis in inputs:
  261. run_test(F.sum, np.sum, inp, dtype, axis, True)
  262. run_test(F.sum, np.sum, inp, dtype, axis, False)
  263. run_test(F.prod, np.prod, inp, dtype, axis, True)
  264. run_test(F.prod, np.prod, inp, dtype, axis, False)