You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

test_functional.py 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import itertools
  10. from functools import partial
  11. import numpy as np
  12. import pytest
  13. from utils import opr_test
  14. import megengine.core.ops.builtin as builtin
  15. import megengine.core.tensor.dtype as dtype
  16. import megengine.functional as F
  17. from megengine import Parameter, Tensor, is_cuda_available, tensor
  18. from megengine.core._trace_option import use_symbolic_shape
  19. from megengine.core.autodiff.grad import Grad
  20. from megengine.core.tensor.utils import make_shape_tuple
  21. def test_where():
  22. maskv0 = np.array([[1, 0], [0, 1]], dtype=np.bool_)
  23. xv0 = np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32)
  24. yv0 = np.array([[5, 6], [7, 8]], dtype=np.float32)
  25. maskv1 = np.array([[1, 0, 1], [1, 0, 0], [1, 1, 0]], dtype=np.bool_)
  26. xv1 = np.array([[1, np.inf, 2], [0, np.nan, 4], [1, 5, 7]], dtype=np.float32)
  27. yv1 = np.array([[5, 6, 9], [2, 7, 8], [2, 1, 9]], dtype=np.float32)
  28. cases = [
  29. {"input": [maskv0, xv0, yv0]},
  30. {"input": [maskv1, xv1, yv1]},
  31. ]
  32. opr_test(cases, F.where, ref_fn=np.where)
  33. maskv2 = np.array([1, 1, 1], dtype=np.bool_)
  34. xv2 = np.array([1, 3, 2], dtype=np.float32)
  35. yv2 = np.array([5, 6, 9], dtype=np.float32)
  36. maskv3 = np.array([0, 0, 0], dtype=np.bool_)
  37. xv3 = np.array([1, 3, 2], dtype=np.float32)
  38. yv3 = np.array([5, 6, 9], dtype=np.float32)
  39. cases = [
  40. {"input": [maskv2, xv2, yv2]},
  41. {"input": [maskv3, xv3, yv3]},
  42. ]
  43. opr_test(cases, F.where, ref_fn=np.where)
  44. def test_dropout():
  45. data = tensor(np.ones(10, dtype=np.float32))
  46. out = F.dropout(data, 1.0 / 3.0, training=False)
  47. assert out.numpy().sum() >= 0.0
  48. def test_matmul():
  49. shape1 = 3
  50. shape2 = 3
  51. shape3 = (3, 5)
  52. shape4 = (5, 6)
  53. data1 = np.random.random(shape1).astype("float32")
  54. data2 = np.random.random(shape2).astype("float32")
  55. data3 = np.random.random(shape3).astype("float32")
  56. data4 = np.random.random(shape4).astype("float32")
  57. cases = [
  58. {"input": [data1, data2]},
  59. {"input": [data2, data3]},
  60. {"input": [data3, data4]},
  61. ]
  62. opr_test(cases, F.matmul, ref_fn=np.matmul)
  63. batch_size = 10
  64. shape1 = (batch_size, 2, 3)
  65. shape2 = (batch_size, 3, 4)
  66. shape3 = (batch_size, 10, 4, 5)
  67. data1 = np.random.random(shape1).astype("float32")
  68. data2 = np.random.random(shape2).astype("float32")
  69. data3 = np.random.random(shape3).astype("float32")
  70. cases = [{"input": [data1, data2]}, {"input": [data2, data3]}]
  71. for i in range(0, batch_size):
  72. def compare_fn(x, y):
  73. x.numpy()[i, ...] == y
  74. opr_test(
  75. cases,
  76. F.matmul,
  77. compare_fn=compare_fn,
  78. ref_fn=lambda x, y: np.matmul(x[i, ...], y[i, ...]),
  79. )
  80. def test_interpolate():
  81. def linear_interpolate():
  82. inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))
  83. out = F.nn.interpolate(inp, scale_factor=2.0, mode="LINEAR")
  84. out2 = F.nn.interpolate(inp, 4, mode="LINEAR")
  85. np.testing.assert_allclose(
  86. out.numpy(), np.array([[[1.0, 1.25, 1.75, 2.0]]], dtype=np.float32)
  87. )
  88. np.testing.assert_allclose(
  89. out2.numpy(), np.array([[[1.0, 1.25, 1.75, 2.0]]], dtype=np.float32)
  90. )
  91. def many_batch_interpolate():
  92. inp = tensor(np.arange(1, 9, dtype=np.float32).reshape(2, 1, 2, 2))
  93. out = F.nn.interpolate(inp, [4, 4])
  94. out2 = F.nn.interpolate(inp, scale_factor=2.0)
  95. np.testing.assert_allclose(out.numpy(), out2.numpy())
  96. def assign_corner_interpolate():
  97. inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))
  98. out = F.nn.interpolate(inp, [4, 4], align_corners=True)
  99. out2 = F.nn.interpolate(inp, scale_factor=2.0, align_corners=True)
  100. np.testing.assert_allclose(out.numpy(), out2.numpy())
  101. def error_shape_linear_interpolate():
  102. inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))
  103. with pytest.raises(ValueError):
  104. F.nn.interpolate(inp, scale_factor=2.0, mode="LINEAR")
  105. def inappropriate_scale_linear_interpolate():
  106. inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))
  107. with pytest.raises(ValueError):
  108. F.nn.interpolate(inp, scale_factor=[2.0, 3.0], mode="LINEAR")
  109. linear_interpolate()
  110. many_batch_interpolate()
  111. assign_corner_interpolate()
  112. error_shape_linear_interpolate()
  113. inappropriate_scale_linear_interpolate()
  114. def _save_to(self, name="grad"):
  115. def callback(tensor, grad):
  116. setattr(self, name, grad)
  117. return callback
  118. def _gen_roi_inp():
  119. inp_feat = np.random.randn(2, 32, 256, 256)
  120. rois = np.zeros((4, 5))
  121. rois[:, 0] = [0, 0, 1, 1]
  122. rois[:, 1:3] = np.random.rand(4, 2) * 100
  123. rois[:, 3:] = np.random.rand(4, 2) * 100 + 150
  124. inp_feat = tensor(inp_feat)
  125. rois = tensor(rois)
  126. return inp_feat, rois
  127. def test_roi_align():
  128. inp_feat, rois = _gen_roi_inp()
  129. grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat))
  130. output_shape = (7, 7)
  131. out_feat = F.nn.roi_align(
  132. inp_feat,
  133. rois,
  134. output_shape=output_shape,
  135. mode="average",
  136. spatial_scale=1.0 / 4,
  137. sample_points=2,
  138. aligned=True,
  139. )
  140. assert make_shape_tuple(out_feat.shape) == (
  141. rois.shape[0],
  142. inp_feat.shape[1],
  143. *output_shape,
  144. )
  145. grad(out_feat, tensor(F.ones_like(out_feat)))
  146. assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape)
  147. def test_roi_pooling():
  148. inp_feat, rois = _gen_roi_inp()
  149. grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat))
  150. output_shape = (7, 7)
  151. out_feat = F.nn.roi_pooling(
  152. inp_feat, rois, output_shape=output_shape, mode="max", scale=1.0 / 4,
  153. )
  154. assert make_shape_tuple(out_feat.shape) == (
  155. rois.shape[0],
  156. inp_feat.shape[1],
  157. *output_shape,
  158. )
  159. grad(out_feat, tensor(F.ones_like(out_feat)))
  160. assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape)
  161. def test_adaptive_avg_pool2d():
  162. inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4))
  163. oshp = (2, 2)
  164. grad = Grad().wrt(inp, callback=_save_to(inp))
  165. outp = F.adaptive_avg_pool2d(inp, oshp,)
  166. assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
  167. np.testing.assert_equal(
  168. outp.numpy(), np.array([[[[2.5, 4.5], [10.5, 12.5]]]], dtype=np.float32)
  169. )
  170. grad(outp, tensor(F.ones_like(outp)))
  171. assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape)
  172. np.testing.assert_equal(
  173. inp.grad.numpy(),
  174. np.array(
  175. [
  176. [
  177. [
  178. [0.25, 0.25, 0.25, 0.25],
  179. [0.25, 0.25, 0.25, 0.25],
  180. [0.25, 0.25, 0.25, 0.25],
  181. [0.25, 0.25, 0.25, 0.25],
  182. ]
  183. ]
  184. ],
  185. dtype=np.float32,
  186. ),
  187. )
  188. def test_adaptive_max_pool2d():
  189. inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4))
  190. oshp = (2, 2)
  191. grad = Grad().wrt(inp, callback=_save_to(inp))
  192. outp = F.adaptive_max_pool2d(inp, oshp,)
  193. assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
  194. np.testing.assert_equal(
  195. outp.numpy(), np.array([[[[5, 7], [13, 15]]]], dtype=np.float32)
  196. )
  197. grad(outp, tensor(F.ones_like(outp)))
  198. assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape)
  199. np.testing.assert_equal(
  200. inp.grad.numpy(),
  201. np.array(
  202. [
  203. [
  204. [
  205. [0.0, 0.0, 0.0, 0.0],
  206. [0.0, 1.0, 0.0, 1.0],
  207. [0.0, 0.0, 0.0, 0.0],
  208. [0.0, 1.0, 0.0, 1.0],
  209. ]
  210. ]
  211. ],
  212. dtype=np.float32,
  213. ),
  214. )
  215. def test_one_hot():
  216. def onehot_low_dimension():
  217. inp = tensor(np.arange(1, 4, dtype=np.int32))
  218. out = F.one_hot(inp, num_classes=4)
  219. np.testing.assert_allclose(
  220. out.numpy(), np.eye(4, dtype=np.int32)[np.arange(1, 4, dtype=np.int32)]
  221. )
  222. def onehot_high_dimension():
  223. arr = np.array(
  224. [[3, 2, 4, 4, 2, 4, 0, 4, 4, 1], [4, 1, 1, 3, 2, 2, 4, 2, 4, 3]],
  225. dtype=np.int32,
  226. )
  227. inp = tensor(arr)
  228. out = F.one_hot(inp, 10)
  229. np.testing.assert_allclose(out.numpy(), np.eye(10, dtype=np.int32)[arr])
  230. onehot_low_dimension()
  231. onehot_high_dimension()
  232. def test_binary_cross_entropy():
  233. data1_shape = (2, 2)
  234. label1_shape = (2, 2)
  235. data2_shape = (2, 3)
  236. label2_shape = (2, 3)
  237. def sigmoid(x):
  238. return 1 / (1 + np.exp(-x))
  239. def compare_fn(x, y):
  240. np.testing.assert_allclose(x.numpy(), y, atol=5e-4)
  241. np.random.seed(123)
  242. data1 = np.random.uniform(size=data1_shape).astype(np.float32)
  243. label1 = np.random.uniform(size=label1_shape).astype(np.float32)
  244. expect1 = np.array([0.6361], dtype=np.float32)
  245. np.random.seed(123)
  246. data2 = np.random.uniform(size=data2_shape).astype(np.float32)
  247. label2 = np.random.uniform(size=label2_shape).astype(np.float32)
  248. expect2 = np.array([0.6750], dtype=np.float32)
  249. cases = [
  250. {"input": [data1, label1], "output": expect1,},
  251. {"input": [data2, label2], "output": expect2,},
  252. ]
  253. opr_test(cases, F.nn.binary_cross_entropy, compare_fn=compare_fn)
  254. cases = [
  255. {"input": [sigmoid(data1), label1], "output": expect1,},
  256. {"input": [sigmoid(data2), label2], "output": expect2,},
  257. ]
  258. opr_test(
  259. cases,
  260. partial(F.nn.binary_cross_entropy, with_logits=False),
  261. compare_fn=compare_fn,
  262. )
  263. def test_hinge_loss():
  264. np.random.seed(123)
  265. # case with L1 norm
  266. cases = []
  267. for shape in [(2, 2), (2, 3)]:
  268. data = np.random.uniform(size=shape).astype(np.float32)
  269. label = 2 * np.random.randint(0, 1, size=shape).astype(np.float32) - 1
  270. expect = np.clip(0, np.inf, 1 - data * label).sum(axis=1).mean()
  271. cases.append({"input": [data, label], "output": expect})
  272. opr_test(cases, F.nn.hinge_loss)
  273. # cases with L2 norm
  274. cases = []
  275. for shape in [(2, 2), (2, 3)]:
  276. data = np.random.uniform(size=shape).astype(np.float32)
  277. label = 2 * np.random.randint(0, 1, size=shape).astype(np.float32) - 1
  278. expect = ((np.clip(0, np.inf, 1 - data * label) ** 2).sum(axis=1)).mean()
  279. cases.append({"input": [data, label], "output": expect})
  280. def hinge_loss_with_l2_norm(pred, label):
  281. return F.nn.hinge_loss(pred, label, "L2")
  282. opr_test(cases, hinge_loss_with_l2_norm)
  283. def test_nms():
  284. x = np.array(
  285. [
  286. [0, 0, 100, 100],
  287. [10, 10, 100, 100],
  288. [50, 50, 100, 100],
  289. [100, 100, 150, 150],
  290. ],
  291. dtype=np.float32,
  292. )
  293. inp = tensor(x)
  294. scores = tensor([0.5, 0.8, 0.9, 0.6], dtype=np.float32)
  295. result = F.nn.nms(inp, scores=scores, iou_thresh=0.5)
  296. np.testing.assert_equal(result.numpy(), np.array([2, 1, 3], dtype=np.int32))
  297. @pytest.mark.skip(reason="cuda does not support nchw int8")
  298. def test_conv_bias():
  299. inp_scale = 1.5
  300. w_scale = 2.5
  301. outp_scale = 1.5
  302. inp_dtype = dtype.qint8(inp_scale)
  303. w_dtype = dtype.qint8(w_scale)
  304. b_dtype = dtype.qint32(inp_scale * w_scale)
  305. out_dtype = dtype.qint8(outp_scale)
  306. def run(
  307. N,
  308. IC,
  309. OC,
  310. IH,
  311. IW,
  312. KH,
  313. KW,
  314. PH,
  315. PW,
  316. SH,
  317. SW,
  318. has_bias=True,
  319. nonlinear_mode="IDENTITY",
  320. ):
  321. inp_v = np.random.normal(size=(N, IC, IH, IW))
  322. w_v = np.random.normal(size=(OC, IC, KW, KW))
  323. b_v = np.random.normal(size=(1, OC, 1, 1))
  324. inp_scale = dtype.get_scale(inp_dtype)
  325. w_scale = dtype.get_scale(w_dtype)
  326. b_scale = dtype.get_scale(b_dtype)
  327. inpv = dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype)
  328. wv = dtype.convert_to_qint8(w_v * w_scale, w_dtype)
  329. bv = dtype.convert_to_qint32(b_v * b_scale, b_dtype)
  330. inp_int8 = tensor(inpv, dtype=inp_dtype)
  331. w_int8 = Parameter(wv, dtype=w_dtype)
  332. b_int32 = Parameter(bv, dtype=b_dtype)
  333. inp_fp32 = inp_int8.astype("float32")
  334. w_fp32 = w_int8.astype("float32")
  335. b_fp32 = b_int32.astype("float32")
  336. def convert_to_nchw4(var):
  337. var = F.reshape(
  338. var, (var.shape[0], var.shape[1] // 4, 4, var.shape[2], var.shape[3])
  339. )
  340. var = F.transpose(var, (0, 1, 3, 4, 2))
  341. return var
  342. def run_conv2d(inp, w, b):
  343. O = F.conv2d(
  344. inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW),
  345. )
  346. if nonlinear_mode == "RELU":
  347. return F.relu(O)
  348. else:
  349. return O
  350. def run_conv_bias(inp, w, b, format="NCHW"):
  351. b = b if has_bias else Parameter(np.zeros_like(b.numpy()))
  352. if format == "NCHW4":
  353. inp = convert_to_nchw4(inp)
  354. w = convert_to_nchw4(w)
  355. b = convert_to_nchw4(b)
  356. return F.nn.conv_bias_activation(
  357. inp,
  358. w,
  359. b,
  360. stride=(SH, SW),
  361. padding=(PH, PW),
  362. format=format,
  363. dtype=out_dtype,
  364. nonlinear_mode=nonlinear_mode,
  365. )
  366. format = "NCHW4" if is_cuda_available() else "NCHW"
  367. expected = run_conv2d(inp_fp32, w_fp32, b_fp32)
  368. expected = expected.astype(out_dtype).astype("float32")
  369. result = run_conv_bias(inp_int8, w_int8, b_int32, format=format).astype(
  370. "float32"
  371. )
  372. if format == "NCHW4":
  373. result = F.transpose(result, (0, 1, 4, 2, 3))
  374. expected = F.flatten(expected)
  375. result = F.flatten(result)
  376. np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=outp_scale)
  377. run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1, False)
  378. run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1, False)
  379. run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False)
  380. run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1)
  381. run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1)
  382. run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2)
  383. run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "RELU")
  384. run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU")
  385. def test_zero_stride_numpy_array():
  386. inp = np.random.randn(3, 224, 224).astype(np.float32)
  387. inp = inp[np.newaxis, :]
  388. inp = tensor(inp, dtype=np.float32)
  389. weight = tensor(np.random.randn(16, 3, 3, 3), dtype=np.float32)
  390. out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)
  391. def test_condtake():
  392. x = np.array([[1, 2, 3], [4, 5, 6]])
  393. y = np.array([[True, False, True], [False, True, True]])
  394. xx = tensor(x)
  395. yy = tensor(y)
  396. val, idx = F.cond_take(yy, xx)
  397. np.testing.assert_equal(val.numpy(), x[y])
  398. np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])
  399. def test_condtake_is_same():
  400. op1 = builtin.CondTake()
  401. op2 = builtin.CondTake()
  402. assert op1 == op2
  403. def test_nms_is_same():
  404. op1 = builtin.NMSKeep(0.7, 100)
  405. op2 = builtin.NMSKeep(0.7, 100)
  406. op3 = builtin.NMSKeep(0.8, 100)
  407. op4 = builtin.NMSKeep(0.7, 200)
  408. assert op1 == op2
  409. assert op1 != op3
  410. assert op1 != op4
  411. assert op3 != op4

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台