
test_functional.py 17 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import itertools

import numpy as np
import pytest
from utils import opr_test

import megengine.core.ops.builtin as builtin
import megengine.core.tensor.dtype as dtype
import megengine.functional as F
from megengine import Parameter, Tensor, is_cuda_available, tensor
from megengine.core._trace_option import use_tensor_shape
from megengine.core.autodiff.grad import Grad
from megengine.core.tensor.utils import make_shape_tuple


def test_where():
    maskv0 = np.array([[1, 0], [0, 1]], dtype=np.bool_)
    xv0 = np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32)
    yv0 = np.array([[5, 6], [7, 8]], dtype=np.float32)

    maskv1 = np.array([[1, 0, 1], [1, 0, 0], [1, 1, 0]], dtype=np.bool_)
    xv1 = np.array([[1, np.inf, 2], [0, np.nan, 4], [1, 5, 7]], dtype=np.float32)
    yv1 = np.array([[5, 6, 9], [2, 7, 8], [2, 1, 9]], dtype=np.float32)

    cases = [
        {"input": [maskv0, xv0, yv0]},
        {"input": [maskv1, xv1, yv1]},
    ]
    opr_test(cases, F.where, ref_fn=np.where)

    maskv2 = np.array([1, 1, 1], dtype=np.bool_)
    xv2 = np.array([1, 3, 2], dtype=np.float32)
    yv2 = np.array([5, 6, 9], dtype=np.float32)

    maskv3 = np.array([0, 0, 0], dtype=np.bool_)
    xv3 = np.array([1, 3, 2], dtype=np.float32)
    yv3 = np.array([5, 6, 9], dtype=np.float32)

    cases = [
        {"input": [maskv2, xv2, yv2]},
        {"input": [maskv3, xv3, yv3]},
    ]
    opr_test(cases, F.where, ref_fn=np.where)


def test_dropout():
    data = tensor(np.ones(10, dtype=np.float32))
    out = F.dropout(data, 1.0 / 3.0, training=False)
    assert out.numpy().sum() >= 0.0


def test_matmul():
    shape1 = 3
    shape2 = 3
    shape3 = (3, 5)
    shape4 = (5, 6)
    data1 = np.random.random(shape1).astype("float32")
    data2 = np.random.random(shape2).astype("float32")
    data3 = np.random.random(shape3).astype("float32")
    data4 = np.random.random(shape4).astype("float32")

    cases = [
        {"input": [data1, data2]},
        {"input": [data2, data3]},
        {"input": [data3, data4]},
    ]
    opr_test(cases, F.matmul, ref_fn=np.matmul)

    batch_size = 10
    shape1 = (batch_size, 2, 3)
    shape2 = (batch_size, 3, 4)
    shape3 = (batch_size, 10, 4, 5)
    data1 = np.random.random(shape1).astype("float32")
    data2 = np.random.random(shape2).astype("float32")
    data3 = np.random.random(shape3).astype("float32")

    cases = [{"input": [data1, data2]}, {"input": [data2, data3]}]
    for i in range(0, batch_size):

        def compare_fn(x, y):
            # compare the i-th batch slice of the result against the reference
            np.testing.assert_allclose(x.numpy()[i, ...], y, atol=1e-5)

        opr_test(
            cases,
            F.matmul,
            compare_fn=compare_fn,
            ref_fn=lambda x, y: np.matmul(x[i, ...], y[i, ...]),
        )


def test_interpolate():
    def linear_interpolate():
        inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))

        out = F.interpolate(inp, scale_factor=2.0, mode="LINEAR")
        out2 = F.interpolate(inp, 4, mode="LINEAR")

        np.testing.assert_allclose(
            out.numpy(), np.array([[[1.0, 1.25, 1.75, 2.0]]], dtype=np.float32)
        )
        np.testing.assert_allclose(
            out2.numpy(), np.array([[[1.0, 1.25, 1.75, 2.0]]], dtype=np.float32)
        )

    def many_batch_interpolate():
        inp = tensor(np.arange(1, 9, dtype=np.float32).reshape(2, 1, 2, 2))

        out = F.interpolate(inp, [4, 4])
        out2 = F.interpolate(inp, scale_factor=2.0)

        np.testing.assert_allclose(out.numpy(), out2.numpy())

    def assign_corner_interpolate():
        inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))

        out = F.interpolate(inp, [4, 4], align_corners=True)
        out2 = F.interpolate(inp, scale_factor=2.0, align_corners=True)

        np.testing.assert_allclose(out.numpy(), out2.numpy())

    def error_shape_linear_interpolate():
        # LINEAR mode expects 3-dimensional (N, C, W) input
        inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))

        with pytest.raises(ValueError):
            F.interpolate(inp, scale_factor=2.0, mode="LINEAR")

    def inappropriate_scale_linear_interpolate():
        # LINEAR mode accepts a single scale factor, not a pair
        inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))

        with pytest.raises(ValueError):
            F.interpolate(inp, scale_factor=[2.0, 3.0], mode="LINEAR")

    linear_interpolate()
    many_batch_interpolate()
    assign_corner_interpolate()
    error_shape_linear_interpolate()
    inappropriate_scale_linear_interpolate()


def _save_to(self, name="grad"):
    # returns a Grad callback that stashes the incoming gradient
    # as an attribute on the given object
    def callback(tensor, grad):
        setattr(self, name, grad)

    return callback


def _gen_roi_inp():
    inp_feat = np.random.randn(2, 32, 256, 256)
    rois = np.zeros((4, 5))
    # column 0 is the batch index; columns 1:3 hold (x0, y0), columns 3: hold (x1, y1)
    rois[:, 0] = [0, 0, 1, 1]
    rois[:, 1:3] = np.random.rand(4, 2) * 100
    rois[:, 3:] = np.random.rand(4, 2) * 100 + 150

    inp_feat = tensor(inp_feat)
    rois = tensor(rois)
    return inp_feat, rois


def test_roi_align():
    inp_feat, rois = _gen_roi_inp()
    grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat))

    output_shape = (7, 7)
    out_feat = F.roi_align(
        inp_feat,
        rois,
        output_shape=output_shape,
        mode="average",
        spatial_scale=1.0 / 4,
        sample_points=2,
        aligned=True,
    )
    assert make_shape_tuple(out_feat.shape) == (
        rois.shape[0],
        inp_feat.shape[1],
        *output_shape,
    )

    grad(out_feat, tensor(F.ones_like(out_feat)))
    assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape)


def test_roi_pooling():
    inp_feat, rois = _gen_roi_inp()
    grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat))

    output_shape = (7, 7)
    out_feat = F.roi_pooling(
        inp_feat, rois, output_shape=output_shape, mode="max", scale=1.0 / 4,
    )
    assert make_shape_tuple(out_feat.shape) == (
        rois.shape[0],
        inp_feat.shape[1],
        *output_shape,
    )

    grad(out_feat, tensor(F.ones_like(out_feat)))
    assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape)


def test_adaptive_avg_pool2d():
    inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4))
    oshp = (2, 2)
    grad = Grad().wrt(inp, callback=_save_to(inp))

    outp = F.adaptive_avg_pool2d(inp, oshp)
    assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp)
    np.testing.assert_equal(
        outp.numpy(), np.array([[[[2.5, 4.5], [10.5, 12.5]]]], dtype=np.float32)
    )

    grad(outp, tensor(F.ones_like(outp)))
    assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape)
    np.testing.assert_equal(
        inp.grad.numpy(),
        np.array(
            [
                [
                    [
                        [0.25, 0.25, 0.25, 0.25],
                        [0.25, 0.25, 0.25, 0.25],
                        [0.25, 0.25, 0.25, 0.25],
                        [0.25, 0.25, 0.25, 0.25],
                    ]
                ]
            ],
            dtype=np.float32,
        ),
    )


def test_adaptive_max_pool2d():
    inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4))
    oshp = (2, 2)
    grad = Grad().wrt(inp, callback=_save_to(inp))

    outp = F.adaptive_max_pool2d(inp, oshp)
    assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp)
    np.testing.assert_equal(
        outp.numpy(), np.array([[[[5, 7], [13, 15]]]], dtype=np.float32)
    )

    grad(outp, tensor(F.ones_like(outp)))
    assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape)
    np.testing.assert_equal(
        inp.grad.numpy(),
        np.array(
            [
                [
                    [
                        [0.0, 0.0, 0.0, 0.0],
                        [0.0, 1.0, 0.0, 1.0],
                        [0.0, 0.0, 0.0, 0.0],
                        [0.0, 1.0, 0.0, 1.0],
                    ]
                ]
            ],
            dtype=np.float32,
        ),
    )


def test_one_hot():
    def onehot_low_dimension():
        inp = tensor(np.arange(1, 4, dtype=np.int32))
        out = F.one_hot(inp, num_classes=4)

        np.testing.assert_allclose(
            out.numpy(), np.eye(4, dtype=np.int32)[np.arange(1, 4, dtype=np.int32)]
        )

    def onehot_high_dimension():
        arr = np.array(
            [[3, 2, 4, 4, 2, 4, 0, 4, 4, 1], [4, 1, 1, 3, 2, 2, 4, 2, 4, 3]],
            dtype=np.int32,
        )
        inp = tensor(arr)
        out = F.one_hot(inp, 10)

        np.testing.assert_allclose(out.numpy(), np.eye(10, dtype=np.int32)[arr])

    onehot_low_dimension()
    onehot_high_dimension()


def test_add_update():
    shape = (2, 3)
    v = np.random.random(shape).astype(np.float32)
    b = Tensor(v)

    # add_update accumulates in place, so a second call adds on top of the first
    u = F.add_update(b, 1)
    np.testing.assert_allclose(u.numpy(), v + 1, atol=1e-6)

    u = F.add_update(b, 1)
    np.testing.assert_allclose(u.numpy(), v + 2, atol=1e-6)

    # general form: dest = alpha * dest + beta * delta + bias
    x = np.ones((2, 2), dtype=np.float32)
    y = x * 0.5
    dest = tensor(x)
    delta = tensor(y)
    r = F.add_update(dest, delta, alpha=0.9, beta=0.1, bias=0.1)
    np.testing.assert_allclose(r.numpy(), x * 0.9 + y * 0.1 + 0.1, atol=1e-6)


def test_add_update_params():
    b = np.random.random((2, 3)).astype(np.float32)
    y = Tensor(b)

    # @jit.trace
    def f(x):
        return F.add_update(y, x)

    f(np.zeros((2, 3)).astype(np.float32))

    z = Tensor(np.zeros((2, 3)).astype(np.float32))
    F.add_update(y, z, beta=0.1)

    res = f(np.ones((2, 3)).astype(np.float32))
    np.testing.assert_allclose(res.numpy(), b + 1)


def test_binary_cross_entropy():
    data1_shape = (2, 2)
    label1_shape = (2, 2)
    data2_shape = (2, 3)
    label2_shape = (2, 3)

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    def compare_fn(x, y):
        np.testing.assert_allclose(x.numpy(), y, atol=5e-4)

    np.random.seed(123)
    data1 = sigmoid(np.random.uniform(size=data1_shape).astype(np.float32))
    label1 = np.random.uniform(size=label1_shape).astype(np.float32)
    expect1 = np.array([0.6361], dtype=np.float32)

    np.random.seed(123)
    data2 = sigmoid(np.random.uniform(size=data2_shape).astype(np.float32))
    label2 = np.random.uniform(size=label2_shape).astype(np.float32)
    expect2 = np.array([0.6750], dtype=np.float32)

    cases = [
        {"input": [data1, label1], "output": expect1},
        {"input": [data2, label2], "output": expect2},
    ]
    opr_test(cases, F.binary_cross_entropy, compare_fn=compare_fn)


def test_hinge_loss():
    np.random.seed(123)
    # case with L1 norm; labels are drawn from {-1, +1} and the
    # reference is max(0, 1 - pred * label) summed per row, then averaged
    cases = []
    for shape in [(2, 2), (2, 3)]:
        data = np.random.uniform(size=shape).astype(np.float32)
        label = 2 * np.random.randint(0, 2, size=shape).astype(np.float32) - 1
        expect = np.clip(1 - data * label, 0, np.inf).sum(axis=1).mean()
        cases.append({"input": [data, label], "output": expect})
    opr_test(cases, F.hinge_loss)

    # cases with L2 norm
    cases = []
    for shape in [(2, 2), (2, 3)]:
        data = np.random.uniform(size=shape).astype(np.float32)
        label = 2 * np.random.randint(0, 2, size=shape).astype(np.float32) - 1
        expect = ((np.clip(1 - data * label, 0, np.inf) ** 2).sum(axis=1)).mean()
        cases.append({"input": [data, label], "output": expect})

    def hinge_loss_with_l2_norm(pred, label):
        return F.hinge_loss(pred, label, "L2")

    opr_test(cases, hinge_loss_with_l2_norm)


def test_nms():
    x = np.array(
        [
            [0, 0, 100, 100],
            [10, 10, 100, 100],
            [50, 50, 100, 100],
            [100, 100, 150, 150],
        ],
        dtype=np.float32,
    )
    inp = tensor(x)
    scores = tensor([0.5, 0.8, 0.9, 0.6], dtype=np.float32)
    result = F.nms(inp, scores=scores, iou_thresh=0.5)
    np.testing.assert_equal(result.numpy(), np.array([2, 1, 3], dtype=np.int32))


def test_batched_nms():
    x = np.array(
        [
            [0, 0, 100, 100],
            [0.5, 0.5, 1.5, 1.5],
            [20, 20, 100, 100],
            [0.5, 0.5, 1.0, 1.0],
            [10, 10, 100, 100],
            [0.5, 0.5, 1.0, 1.0],
        ],
        dtype=np.float32,
    )
    inp = tensor(x)
    scores = tensor([0.6, 0.9, 0.5, 0.6, 0.8, 0.7], dtype=np.float32)
    idxs = tensor([0, 1, 0, 1, 0, 1], dtype=np.int32)
    results = F.batched_nms(inp, scores=scores, idxs=idxs, iou_thresh=0.5)
    np.testing.assert_equal(results.numpy(), np.array([1, 4, 5], dtype=np.int32))


@pytest.mark.skip(reason="cuda does not support nchw int8")
def test_conv_bias():
    inp_scale = 1.5
    w_scale = 2.5
    outp_scale = 1.5
    inp_dtype = dtype.qint8(inp_scale)
    w_dtype = dtype.qint8(w_scale)
    b_dtype = dtype.qint32(inp_scale * w_scale)
    out_dtype = dtype.qint8(outp_scale)

    def run(
        N,
        IC,
        OC,
        IH,
        IW,
        KH,
        KW,
        PH,
        PW,
        SH,
        SW,
        has_bias=True,
        nonlinear_mode="IDENTITY",
    ):
        inp_v = np.random.normal(size=(N, IC, IH, IW))
        w_v = np.random.normal(size=(OC, IC, KH, KW))
        b_v = np.random.normal(size=(1, OC, 1, 1))
        inp_scale = dtype.get_scale(inp_dtype)
        w_scale = dtype.get_scale(w_dtype)
        b_scale = dtype.get_scale(b_dtype)

        inpv = dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype)
        wv = dtype.convert_to_qint8(w_v * w_scale, w_dtype)
        bv = dtype.convert_to_qint32(b_v * b_scale, b_dtype)

        inp_int8 = tensor(inpv, dtype=inp_dtype)
        w_int8 = Parameter(wv, dtype=w_dtype)
        b_int32 = Parameter(bv, dtype=b_dtype)

        inp_fp32 = inp_int8.astype("float32")
        w_fp32 = w_int8.astype("float32")
        b_fp32 = b_int32.astype("float32")

        def convert_to_nchw4(var):
            var = F.reshape(
                var, (var.shape[0], var.shape[1] // 4, 4, var.shape[2], var.shape[3])
            )
            var = F.transpose(var, (0, 1, 3, 4, 2))
            return var

        def run_conv2d(inp, w, b):
            O = F.conv2d(
                inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW),
            )
            if nonlinear_mode == "RELU":
                return F.relu(O)
            else:
                return O

        def run_conv_bias(inp, w, b, format="NCHW"):
            b = b if has_bias else Parameter(np.zeros_like(b.numpy()))
            if format == "NCHW4":
                inp = convert_to_nchw4(inp)
                w = convert_to_nchw4(w)
                b = convert_to_nchw4(b)
            return F.conv_bias_activation(
                inp,
                w,
                b,
                stride=(SH, SW),
                padding=(PH, PW),
                format=format,
                dtype=out_dtype,
                nonlinear_mode=nonlinear_mode,
            )

        format = "NCHW4" if is_cuda_available() else "NCHW"

        expected = run_conv2d(inp_fp32, w_fp32, b_fp32)
        expected = expected.astype(out_dtype).astype("float32")
        result = run_conv_bias(inp_int8, w_int8, b_int32, format=format).astype(
            "float32"
        )
        if format == "NCHW4":
            result = F.transpose(result, (0, 1, 4, 2, 3))
        expected = F.flatten(expected)
        result = F.flatten(result)
        np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=outp_scale)

    run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1, False)
    run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1, False)
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False)

    run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1)
    run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1)
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2)

    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "RELU")
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU")


def test_zero_stride_numpy_array():
    # smoke test: constructing a tensor from a numpy array with a
    # zero-stride axis (from np.newaxis) and running conv2d should not raise
    inp = np.random.randn(3, 224, 224).astype(np.float32)
    inp = inp[np.newaxis, :]

    inp = tensor(inp, dtype=np.float32)
    weight = tensor(np.random.randn(16, 3, 3, 3), dtype=np.float32)
    out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)


def test_condtake():
    x = np.array([[1, 2, 3], [4, 5, 6]])
    y = np.array([[True, False, True], [False, True, True]])
    xx = tensor(x)
    yy = tensor(y)
    val, idx = F.cond_take(yy, xx)
    np.testing.assert_equal(val.numpy(), x[y])
    np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])


def test_condtake_is_same():
    op1 = builtin.CondTake()
    op2 = builtin.CondTake()
    assert op1 == op2


def test_nms_is_same():
    op1 = builtin.NMSKeep(0.7, 100)
    op2 = builtin.NMSKeep(0.7, 100)
    op3 = builtin.NMSKeep(0.8, 100)
    op4 = builtin.NMSKeep(0.7, 200)
    assert op1 == op2
    assert op1 != op3
    assert op1 != op4
    assert op3 != op4

The MegEngine package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build. To run GPU programs, make sure the machine has a GPU and that its driver is installed. If you would like to try deep-learning development on a cloud GPU compute platform, you are welcome to visit the MegStudio platform.
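
As a quick sanity check that the bundled CUDA environment actually sees a GPU, a minimal sketch (reusing the `is_cuda_available` helper that the test file above imports from `megengine`) might look like this:

import numpy as np
import megengine
import megengine.functional as F
from megengine import tensor

# Report whether MegEngine can see a CUDA device.
if megengine.is_cuda_available():
    print("GPU detected: ops below run on the CUDA device")
else:
    print("no GPU detected: running on CPU")

# The same op runs unchanged on either device.
x = tensor([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
print(F.matmul(x, x).numpy())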