You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_functional.py 28 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import itertools
  10. from functools import partial
  11. import numpy as np
  12. import pytest
  13. from utils import opr_test
  14. import megengine.core.ops.builtin as builtin
  15. import megengine.core.tensor.dtype as dtype
  16. import megengine.functional as F
  17. from megengine import Parameter, Tensor, is_cuda_available, tensor
  18. from megengine.core._trace_option import use_symbolic_shape
  19. from megengine.core.autodiff.grad import Grad
  20. from megengine.core.tensor.utils import make_shape_tuple
  21. from megengine.distributed.helper import get_device_count_by_fork
  22. from megengine.jit import trace
  23. def test_where():
  24. maskv0 = np.array([[1, 0], [0, 1]], dtype=np.bool_)
  25. xv0 = np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32)
  26. yv0 = np.array([[5, 6], [7, 8]], dtype=np.float32)
  27. maskv1 = np.array([[1, 0, 1], [1, 0, 0], [1, 1, 0]], dtype=np.bool_)
  28. xv1 = np.array([[1, np.inf, 2], [0, np.nan, 4], [1, 5, 7]], dtype=np.float32)
  29. yv1 = np.array([[5, 6, 9], [2, 7, 8], [2, 1, 9]], dtype=np.float32)
  30. cases = [
  31. {"input": [maskv0, xv0, yv0]},
  32. {"input": [maskv1, xv1, yv1]},
  33. ]
  34. opr_test(cases, F.where, ref_fn=np.where, test_trace=False)
  35. maskv2 = np.array([1, 1, 1], dtype=np.bool_)
  36. xv2 = np.array([1, 3, 2], dtype=np.float32)
  37. yv2 = np.array([5, 6, 9], dtype=np.float32)
  38. maskv3 = np.array([0, 0, 0], dtype=np.bool_)
  39. xv3 = np.array([1, 3, 2], dtype=np.float32)
  40. yv3 = np.array([5, 6, 9], dtype=np.float32)
  41. cases = [
  42. {"input": [maskv2, xv2, yv2]},
  43. {"input": [maskv3, xv3, yv3]},
  44. ]
  45. opr_test(cases, F.where, ref_fn=np.where, test_trace=False)
  46. def test_dropout():
  47. data = tensor(np.ones(10, dtype=np.float32))
  48. out = F.dropout(data, 1.0 / 3.0, training=False)
  49. assert out.numpy().sum() >= 0.0
  50. def test_matinv():
  51. shape1 = (5, 5)
  52. shape2 = (3, 9, 9)
  53. data1 = np.random.random(shape1).astype("float32")
  54. data2 = np.random.random(shape2).astype("float32")
  55. # make matrix diagonally dominant for numerical stability
  56. data1 += (np.eye(shape1[0]) * shape1[0]).astype("float32")
  57. data2 += np.broadcast_to((np.eye(shape2[1]) * shape2[1]).astype("float32"), shape2)
  58. cases = [
  59. {"input": data1},
  60. {"input": data2},
  61. ]
  62. opr_test(
  63. cases,
  64. F.matinv,
  65. compare_fn=lambda x, y: np.testing.assert_allclose(x.numpy(), y, rtol=1e-5),
  66. ref_fn=np.linalg.inv,
  67. )
  68. def test_matmul():
  69. shape1 = 3
  70. shape2 = 3
  71. shape3 = (3, 5)
  72. shape4 = (5, 6)
  73. data1 = np.random.random(shape1).astype("float32")
  74. data2 = np.random.random(shape2).astype("float32")
  75. data3 = np.random.random(shape3).astype("float32")
  76. data4 = np.random.random(shape4).astype("float32")
  77. cases = [
  78. {"input": [data1, data2]},
  79. {"input": [data2, data3]},
  80. {"input": [data3, data4]},
  81. ]
  82. opr_test(cases, F.matmul, ref_fn=np.matmul)
  83. batch_size = 10
  84. shape1 = (2,)
  85. shape2 = (batch_size, 2, 3)
  86. shape3 = (batch_size, 3, 4)
  87. shape4 = (batch_size, 10, 4, 2)
  88. shape5 = (batch_size, 10, 2, 4)
  89. data1 = np.random.random(shape1).astype("float32")
  90. data2 = np.random.random(shape2).astype("float32")
  91. data3 = np.random.random(shape3).astype("float32")
  92. data4 = np.random.random(shape4).astype("float32")
  93. data5 = np.random.random(shape5).astype("float32")
  94. cases = [
  95. {"input": [data1, data2]},
  96. {"input": [data2, data3]},
  97. {"input": [data3, data4]},
  98. {"input": [data4, data5]},
  99. ]
  100. opr_test(cases, F.matmul, ref_fn=np.matmul)
  101. opr_test(
  102. [{"input": [data1, data4]}],
  103. F.matmul,
  104. ref_fn=lambda x, y: np.matmul(x, y.transpose(0, 1, 3, 2)),
  105. transpose_b=True,
  106. )
  107. opr_test(
  108. [{"input": [data3, data2]}],
  109. F.matmul,
  110. ref_fn=lambda x, y: np.matmul(x.transpose(0, 2, 1), y.transpose(0, 2, 1)),
  111. transpose_a=True,
  112. transpose_b=True,
  113. )
  114. def test_interpolate():
  115. def linear_interpolate():
  116. inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))
  117. out = F.vision.interpolate(inp, scale_factor=2.0, mode="linear")
  118. out2 = F.vision.interpolate(inp, 4, mode="linear")
  119. np.testing.assert_allclose(
  120. out.numpy(), np.array([[[1.0, 1.25, 1.75, 2.0]]], dtype=np.float32)
  121. )
  122. np.testing.assert_allclose(
  123. out2.numpy(), np.array([[[1.0, 1.25, 1.75, 2.0]]], dtype=np.float32)
  124. )
  125. def many_batch_interpolate():
  126. inp = tensor(np.arange(1, 9, dtype=np.float32).reshape(2, 1, 2, 2))
  127. out = F.vision.interpolate(inp, [4, 4])
  128. out2 = F.vision.interpolate(inp, scale_factor=2.0)
  129. np.testing.assert_allclose(out.numpy(), out2.numpy())
  130. def assign_corner_interpolate():
  131. inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))
  132. out = F.vision.interpolate(inp, [4, 4], align_corners=True)
  133. out2 = F.vision.interpolate(inp, scale_factor=2.0, align_corners=True)
  134. np.testing.assert_allclose(out.numpy(), out2.numpy())
  135. def error_shape_linear_interpolate():
  136. inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))
  137. with pytest.raises(ValueError):
  138. F.vision.interpolate(inp, scale_factor=2.0, mode="linear")
  139. def inappropriate_scale_linear_interpolate():
  140. inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))
  141. with pytest.raises(ValueError):
  142. F.vision.interpolate(inp, scale_factor=[2.0, 3.0], mode="linear")
  143. linear_interpolate()
  144. many_batch_interpolate()
  145. assign_corner_interpolate()
  146. error_shape_linear_interpolate()
  147. inappropriate_scale_linear_interpolate()
  148. def _save_to(self, name="grad"):
  149. def callback(grad):
  150. setattr(self, name, grad)
  151. return callback
  152. def _gen_roi_inp():
  153. inp_feat = np.random.randn(2, 32, 256, 256)
  154. rois = np.zeros((4, 5))
  155. rois[:, 0] = [0, 0, 1, 1]
  156. rois[:, 1:3] = np.random.rand(4, 2) * 100
  157. rois[:, 3:] = np.random.rand(4, 2) * 100 + 150
  158. inp_feat = tensor(inp_feat)
  159. rois = tensor(rois)
  160. return inp_feat, rois
  161. def test_roi_align():
  162. inp_feat, rois = _gen_roi_inp()
  163. grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat))
  164. output_shape = (7, 7)
  165. out_feat = F.vision.roi_align(
  166. inp_feat,
  167. rois,
  168. output_shape=output_shape,
  169. mode="average",
  170. spatial_scale=1.0 / 4,
  171. sample_points=2,
  172. aligned=True,
  173. )
  174. assert make_shape_tuple(out_feat.shape) == (
  175. rois.shape[0],
  176. inp_feat.shape[1],
  177. *output_shape,
  178. )
  179. grad(out_feat, tensor(F.ones_like(out_feat)))
  180. assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape)
  181. def _gen_correlation(random=True, constant=1, image_shape=(2, 1, 160, 160)):
  182. if random:
  183. inp_feat1 = np.random.randn(
  184. image_shape[0], image_shape[1], image_shape[2], image_shape[3]
  185. )
  186. inp_feat2 = np.random.randn(
  187. image_shape[0], image_shape[1], image_shape[2], image_shape[3]
  188. )
  189. else:
  190. inp_feat1 = np.ones(image_shape) * constant
  191. inp_feat2 = np.ones(image_shape) * constant
  192. return tensor(inp_feat1), tensor(inp_feat2)
  193. def test_correlation():
  194. ##test case 0 check the grad shape
  195. data1, data2 = _gen_correlation()
  196. grad = Grad().wrt(data1, callback=_save_to(data1))
  197. out_feat = F.vision.correlation(
  198. data1,
  199. data2,
  200. kernel_size=5,
  201. max_displacement=4,
  202. stride1=2,
  203. stride2=2,
  204. pad_size=2,
  205. is_multiply=True,
  206. )
  207. grad(out_feat, tensor(F.ones_like(out_feat)))
  208. assert make_shape_tuple(data1.grad.shape) == make_shape_tuple(data1.shape)
  209. ##test case 1 from https://github.com/NVIDIA/flownet2-pytorch/issues/194
  210. data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3))
  211. out_feat = F.vision.correlation(
  212. data1,
  213. data2,
  214. kernel_size=3,
  215. max_displacement=0,
  216. stride1=1,
  217. stride2=1,
  218. pad_size=0,
  219. is_multiply=True,
  220. )
  221. assert abs(out_feat.sum() - 1) < 1e-9
  222. ##test case 2 check same image subduction
  223. data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3))
  224. out_feat = F.vision.correlation(
  225. data1,
  226. data2,
  227. kernel_size=3,
  228. max_displacement=0,
  229. stride1=1,
  230. stride2=1,
  231. pad_size=0,
  232. is_multiply=False,
  233. )
  234. assert out_feat.sum() < 1e-9
  235. ##test case 3 check same image subduction
  236. data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3))
  237. out_feat = F.vision.correlation(
  238. data1,
  239. data2,
  240. kernel_size=3,
  241. max_displacement=0,
  242. stride1=1,
  243. stride2=1,
  244. pad_size=0,
  245. is_multiply=False,
  246. )
  247. assert out_feat.sum() < 1e-9
  248. ##test case 4 check correlation
  249. data1, _ = _gen_correlation(
  250. random=False, image_shape=(1, 1, 220, 220), constant=2.0
  251. )
  252. _, data2 = _gen_correlation(
  253. random=False, image_shape=(1, 1, 220, 220), constant=1.0
  254. )
  255. out_feat = F.vision.correlation(
  256. data1,
  257. data2,
  258. kernel_size=3,
  259. max_displacement=2,
  260. stride1=1,
  261. stride2=2,
  262. pad_size=0,
  263. is_multiply=False,
  264. )
  265. assert abs(out_feat.mean() - 1) < 1e-9
  266. def test_roi_pooling():
  267. inp_feat, rois = _gen_roi_inp()
  268. grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat))
  269. output_shape = (7, 7)
  270. out_feat = F.vision.roi_pooling(
  271. inp_feat, rois, output_shape=output_shape, mode="max", scale=1.0 / 4,
  272. )
  273. assert make_shape_tuple(out_feat.shape) == (
  274. rois.shape[0],
  275. inp_feat.shape[1],
  276. *output_shape,
  277. )
  278. grad(out_feat, tensor(F.ones_like(out_feat)))
  279. assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape)
  280. def test_adaptive_avg_pool2d():
  281. inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4))
  282. oshp = (2, 2)
  283. grad = Grad().wrt(inp, callback=_save_to(inp))
  284. outp = F.adaptive_avg_pool2d(inp, oshp,)
  285. assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
  286. np.testing.assert_equal(
  287. outp.numpy(), np.array([[[[2.5, 4.5], [10.5, 12.5]]]], dtype=np.float32)
  288. )
  289. grad(outp, tensor(F.ones_like(outp)))
  290. assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape)
  291. np.testing.assert_equal(
  292. inp.grad.numpy(),
  293. np.array(
  294. [
  295. [
  296. [
  297. [0.25, 0.25, 0.25, 0.25],
  298. [0.25, 0.25, 0.25, 0.25],
  299. [0.25, 0.25, 0.25, 0.25],
  300. [0.25, 0.25, 0.25, 0.25],
  301. ]
  302. ]
  303. ],
  304. dtype=np.float32,
  305. ),
  306. )
  307. def test_adaptive_max_pool2d():
  308. inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4))
  309. oshp = (2, 2)
  310. grad = Grad().wrt(inp, callback=_save_to(inp))
  311. outp = F.adaptive_max_pool2d(inp, oshp,)
  312. assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
  313. np.testing.assert_equal(
  314. outp.numpy(), np.array([[[[5, 7], [13, 15]]]], dtype=np.float32)
  315. )
  316. grad(outp, tensor(F.ones_like(outp)))
  317. assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape)
  318. np.testing.assert_equal(
  319. inp.grad.numpy(),
  320. np.array(
  321. [
  322. [
  323. [
  324. [0.0, 0.0, 0.0, 0.0],
  325. [0.0, 1.0, 0.0, 1.0],
  326. [0.0, 0.0, 0.0, 0.0],
  327. [0.0, 1.0, 0.0, 1.0],
  328. ]
  329. ]
  330. ],
  331. dtype=np.float32,
  332. ),
  333. )
  334. def test_one_hot():
  335. def onehot_low_dimension():
  336. inp = tensor(np.arange(1, 4, dtype=np.int32))
  337. out = F.one_hot(inp, num_classes=4)
  338. np.testing.assert_allclose(
  339. out.numpy(), np.eye(4, dtype=np.int32)[np.arange(1, 4, dtype=np.int32)]
  340. )
  341. def onehot_high_dimension():
  342. arr = np.array(
  343. [[3, 2, 4, 4, 2, 4, 0, 4, 4, 1], [4, 1, 1, 3, 2, 2, 4, 2, 4, 3]],
  344. dtype=np.int32,
  345. )
  346. inp = tensor(arr)
  347. out = F.one_hot(inp, 10)
  348. np.testing.assert_allclose(out.numpy(), np.eye(10, dtype=np.int32)[arr])
  349. onehot_low_dimension()
  350. onehot_high_dimension()
  351. def test_interpolate_fastpath():
  352. # check shape
  353. test_cases = [
  354. [(1, 1, 10, 10), (5, 5)],
  355. [(1, 3, 10, 10), (20, 20)],
  356. [(10, 1, 10, 10), (1, 1)],
  357. # [(10, 10, 1, 1), (10, 10)], # FIXME, it causes random CI failure
  358. ]
  359. for inp_shape, target_shape in test_cases:
  360. x = tensor(np.random.randn(*inp_shape), dtype=np.float32)
  361. out = F.vision.interpolate(x, target_shape, mode="bilinear")
  362. assert out.shape[0] == x.shape[0] and out.shape[1] == x.shape[1]
  363. assert out.shape[2] == target_shape[0] and out.shape[3] == target_shape[1]
  364. # check value
  365. x = tensor(np.ones((3, 3, 10, 10)), dtype=np.float32)
  366. out = F.vision.interpolate(x, (15, 5), mode="bilinear")
  367. np.testing.assert_equal(out.numpy(), np.ones((3, 3, 15, 5)).astype(np.float32))
  368. np_x = np.arange(32)
  369. x = tensor(np_x).astype(np.float32).reshape(1, 1, 32, 1)
  370. out = F.vision.interpolate(x, (1, 1), mode="bilinear")
  371. np.testing.assert_equal(out.item(), np_x.mean())
  372. def test_warp_perspective():
  373. inp_shape = (1, 1, 4, 4)
  374. x = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
  375. M_shape = (1, 3, 3)
  376. # M defines a translation: dst(1, 1, h, w) = rst(1, 1, h+1, w+1)
  377. M = tensor(
  378. np.array(
  379. [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32
  380. ).reshape(M_shape)
  381. )
  382. outp = F.vision.warp_perspective(x, M, (2, 2))
  383. np.testing.assert_equal(
  384. outp.numpy(), np.array([[[[5.0, 6.0], [9.0, 10.0]]]], dtype=np.float32)
  385. )
  386. def test_warp_perspective_mat_idx():
  387. inp_shape = (2, 1, 4, 4)
  388. x = tensor(np.arange(32, dtype=np.float32).reshape(inp_shape))
  389. M_shape = (1, 3, 3)
  390. # M defines a translation: dst(1, 1, h, w) = rst(1, 1, h+1, w+1)
  391. M = tensor(
  392. np.array(
  393. [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32
  394. ).reshape(M_shape)
  395. )
  396. M = F.concat([M,] * 4, 0)
  397. outp = F.vision.warp_perspective(x, M, (2, 2), mat_idx=[0, 1, 1, 0])
  398. np.testing.assert_equal(
  399. outp.numpy(),
  400. np.array(
  401. [
  402. [[[5.0, 6.0], [9.0, 10.0]]],
  403. [[[21.0, 22.0], [25.0, 26.0]]],
  404. [[[21.0, 22.0], [25.0, 26.0]]],
  405. [[[5.0, 6.0], [9.0, 10.0]]],
  406. ],
  407. dtype=np.float32,
  408. ),
  409. )
  410. def test_warp_affine():
  411. inp_shape = (1, 3, 3, 3)
  412. x = tensor(np.arange(27, dtype=np.float32).reshape(inp_shape))
  413. weightv = [[[1.26666667, 0.6, -83.33333333], [-0.33333333, 1, 66.66666667]]]
  414. outp = F.vision.warp_affine(x, tensor(weightv), (2, 2), border_mode="wrap")
  415. res = np.array(
  416. [
  417. [
  418. [[7.875, 8.875, 9.875], [8.90625, 9.90625, 10.90625]],
  419. [[18.75, 19.75, 20.75], [14.90625, 15.90625, 16.90625]],
  420. ]
  421. ],
  422. dtype=np.float32,
  423. )
  424. if not is_cuda_available():
  425. np.testing.assert_almost_equal(outp.numpy(), res, 5)
  426. def test_remap():
  427. inp_shape = (1, 1, 4, 4)
  428. inp = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
  429. map_xy_shape = (1, 2, 2, 2)
  430. map_xy = tensor(
  431. np.array(
  432. [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [0.0, 1.0]]], dtype=np.float32
  433. ).reshape(map_xy_shape)
  434. )
  435. outp = F.vision.remap(inp, map_xy)
  436. np.testing.assert_equal(
  437. outp.numpy(), np.array([[[[1.0, 4.0], [4.0, 4.0]]]], dtype=np.float32)
  438. )
  439. def test_binary_cross_entropy():
  440. data1_shape = (2, 2)
  441. label1_shape = (2, 2)
  442. data2_shape = (2, 3)
  443. label2_shape = (2, 3)
  444. def sigmoid(x):
  445. return 1 / (1 + np.exp(-x))
  446. def compare_fn(x, y):
  447. np.testing.assert_allclose(x.numpy(), y, atol=5e-4)
  448. np.random.seed(123)
  449. data1 = np.random.uniform(size=data1_shape).astype(np.float32)
  450. label1 = np.random.uniform(size=label1_shape).astype(np.float32)
  451. expect1 = np.array([0.6361], dtype=np.float32)
  452. np.random.seed(123)
  453. data2 = np.random.uniform(size=data2_shape).astype(np.float32)
  454. label2 = np.random.uniform(size=label2_shape).astype(np.float32)
  455. expect2 = np.array([0.6750], dtype=np.float32)
  456. cases = [
  457. {"input": [data1, label1], "output": expect1,},
  458. {"input": [data2, label2], "output": expect2,},
  459. ]
  460. opr_test(cases, F.nn.binary_cross_entropy, compare_fn=compare_fn)
  461. cases = [
  462. {"input": [sigmoid(data1), label1], "output": expect1,},
  463. {"input": [sigmoid(data2), label2], "output": expect2,},
  464. ]
  465. opr_test(
  466. cases,
  467. partial(F.nn.binary_cross_entropy, with_logits=False),
  468. compare_fn=compare_fn,
  469. )
  470. def test_hinge_loss():
  471. np.random.seed(123)
  472. # case with L1 norm
  473. cases = []
  474. for shape in [(2, 2), (2, 3)]:
  475. data = np.random.uniform(size=shape).astype(np.float32)
  476. label = 2 * np.random.randint(0, 1, size=shape).astype(np.float32) - 1
  477. expect = np.clip(0, np.inf, 1 - data * label).sum(axis=1).mean()
  478. cases.append({"input": [data, label], "output": expect})
  479. opr_test(cases, F.nn.hinge_loss)
  480. # cases with L2 norm
  481. cases = []
  482. for shape in [(2, 2), (2, 3)]:
  483. data = np.random.uniform(size=shape).astype(np.float32)
  484. label = 2 * np.random.randint(0, 1, size=shape).astype(np.float32) - 1
  485. expect = ((np.clip(0, np.inf, 1 - data * label) ** 2).sum(axis=1)).mean()
  486. cases.append({"input": [data, label], "output": expect})
  487. def hinge_loss_with_l2_norm(pred, label):
  488. return F.nn.hinge_loss(pred, label, "L2")
  489. opr_test(cases, hinge_loss_with_l2_norm)
  490. def test_nms():
  491. x = np.array(
  492. [
  493. [0, 0, 100, 100],
  494. [10, 10, 100, 100],
  495. [50, 50, 100, 100],
  496. [100, 100, 150, 150],
  497. ],
  498. dtype=np.float32,
  499. )
  500. inp = tensor(x)
  501. scores = tensor([0.5, 0.8, 0.9, 0.6], dtype=np.float32)
  502. result = F.vision.nms(inp, scores=scores, iou_thresh=0.5)
  503. np.testing.assert_equal(result.numpy(), np.array([2, 1, 3], dtype=np.int32))
  504. @pytest.mark.skipif(
  505. get_device_count_by_fork("gpu") > 0, reason="cuda does not support nchw int8"
  506. )
  507. def test_conv_bias():
  508. inp_scale = 1.5
  509. w_scale = 2.5
  510. outp_scale = 1.5
  511. inp_dtype = dtype.qint8(inp_scale)
  512. w_dtype = dtype.qint8(w_scale)
  513. b_dtype = dtype.qint32(inp_scale * w_scale)
  514. out_dtype = dtype.qint8(outp_scale)
  515. def run(
  516. N,
  517. IC,
  518. OC,
  519. IH,
  520. IW,
  521. KH,
  522. KW,
  523. PH,
  524. PW,
  525. SH,
  526. SW,
  527. has_bias=True,
  528. nonlinear_mode="identity",
  529. ):
  530. inp_v = np.random.normal(size=(N, IC, IH, IW))
  531. w_v = np.random.normal(size=(OC, IC, KH, KW))
  532. b_v = np.random.normal(size=(1, OC, 1, 1))
  533. inp_scale = dtype.get_scale(inp_dtype)
  534. w_scale = dtype.get_scale(w_dtype)
  535. b_scale = dtype.get_scale(b_dtype)
  536. inpv = dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype)
  537. wv = dtype.convert_to_qint8(w_v * w_scale, w_dtype)
  538. bv = dtype.convert_to_qint32(b_v * b_scale, b_dtype)
  539. inp_int8 = tensor(inpv, dtype=inp_dtype)
  540. w_int8 = Parameter(wv, dtype=w_dtype)
  541. b_int32 = Parameter(bv, dtype=b_dtype)
  542. inp_fp32 = inp_int8.astype("float32")
  543. w_fp32 = w_int8.astype("float32")
  544. b_fp32 = b_int32.astype("float32")
  545. def convert_to_nchw4(var):
  546. var = F.reshape(
  547. var, (var.shape[0], var.shape[1] // 4, 4, var.shape[2], var.shape[3])
  548. )
  549. var = F.transpose(var, (0, 1, 3, 4, 2))
  550. return var
  551. def run_conv2d(inp, w, b):
  552. O = F.conv2d(
  553. inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW),
  554. )
  555. if nonlinear_mode == "relu":
  556. return F.relu(O)
  557. else:
  558. return O
  559. def run_conv_bias(inp, w, b, format="NCHW"):
  560. b = b if has_bias else Parameter(np.zeros_like(b.numpy()))
  561. if format == "NCHW4":
  562. inp = convert_to_nchw4(inp)
  563. w = convert_to_nchw4(w)
  564. b = convert_to_nchw4(b)
  565. return F.quantized.conv_bias_activation(
  566. inp,
  567. w,
  568. b,
  569. stride=(SH, SW),
  570. padding=(PH, PW),
  571. dtype=out_dtype,
  572. nonlinear_mode=nonlinear_mode,
  573. )
  574. format = "NCHW4" if is_cuda_available() else "NCHW"
  575. expected = run_conv2d(inp_fp32, w_fp32, b_fp32)
  576. expected = expected.astype(out_dtype).astype("float32")
  577. result = run_conv_bias(inp_int8, w_int8, b_int32, format=format).astype(
  578. "float32"
  579. )
  580. if format == "NCHW4":
  581. result = F.transpose(result, (0, 1, 4, 2, 3))
  582. expected = F.flatten(expected)
  583. result = F.flatten(result)
  584. np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=outp_scale)
  585. run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1, False)
  586. run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1, False)
  587. run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False)
  588. run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1)
  589. run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1)
  590. run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2)
  591. run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "relu")
  592. run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu")
  593. @pytest.mark.skipif(
  594. get_device_count_by_fork("gpu") > 0, reason="no int8 algorithm on cuda"
  595. )
  596. def test_batch_conv_bias():
  597. inp_scale = 1.5
  598. w_scale = 2.5
  599. outp_scale = 1.5
  600. inp_dtype = dtype.qint8(inp_scale)
  601. w_dtype = dtype.qint8(w_scale)
  602. b_dtype = dtype.qint32(inp_scale * w_scale)
  603. out_dtype = dtype.qint8(outp_scale)
  604. def run(
  605. N, IC, OC, IH, IW, KH, KW, PH, PW, SH, SW, has_bias=True,
  606. ):
  607. inp_v = np.random.normal(size=(N, IC, IH, IW))
  608. w_v = np.random.normal(size=(N, OC, IC, KH, KW))
  609. b_v = np.random.normal(size=(1, OC, 1, 1))
  610. inp_scale = dtype.get_scale(inp_dtype)
  611. w_scale = dtype.get_scale(w_dtype)
  612. b_scale = dtype.get_scale(b_dtype)
  613. inpv = dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype)
  614. wv = dtype.convert_to_qint8(w_v * w_scale, w_dtype)
  615. bv = dtype.convert_to_qint32(b_v * b_scale, b_dtype)
  616. inp_int8 = tensor(inpv, dtype=inp_dtype)
  617. w_int8 = Parameter(wv, dtype=w_dtype)
  618. b_int32 = Parameter(bv, dtype=b_dtype)
  619. inp_fp32 = inp_int8.astype("float32")
  620. w_fp32 = w_int8.astype("float32")
  621. b_fp32 = b_int32.astype("float32")
  622. def run_batch_conv_bias(inp, w, b):
  623. b = b if has_bias else Parameter(np.zeros_like(b.numpy()))
  624. result = F.quantized.batch_conv_bias_activation(
  625. inp, w, b, stride=(SH, SW), padding=(PH, PW), dtype=out_dtype,
  626. )
  627. return result.astype("float32")
  628. expected = F.conv2d(inp_fp32, w_fp32[0], b_fp32 if has_bias else None)[0]
  629. expected = expected.astype(out_dtype).astype("float32")
  630. expected = F.flatten(expected)
  631. result = run_batch_conv_bias(inp_int8, w_int8, b_int32)
  632. result = F.flatten(result)
  633. np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=outp_scale)
  634. run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)
  635. def test_conv2d_zero_stride_numpy_array():
  636. inp = np.random.randn(3, 224, 224).astype(np.float32)
  637. inp = inp[np.newaxis, :]
  638. inp = tensor(inp, dtype=np.float32)
  639. weight = tensor(np.random.randn(16, 3, 3, 3), dtype=np.float32)
  640. out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)
  641. def test_conv3d_zero_stride_numpy_array():
  642. inp = np.random.randn(3, 224, 224, 224).astype(np.float32)
  643. inp = inp[np.newaxis, :]
  644. inp = tensor(inp, dtype=np.float32)
  645. weight = tensor(np.random.randn(16, 3, 3, 3, 3), dtype=np.float32)
  646. out = F.conv3d(inp, weight, None, (2, 2, 2), (3, 3, 3), (1, 1, 1), 1)
  647. out.numpy()
  648. def test_conv1d():
  649. inp = tensor(np.ones((16,), dtype=np.float32).reshape(2, 2, 4))
  650. weight = tensor(np.ones((12,), dtype=np.float32).reshape(3, 2, 2))
  651. out = F.conv1d(inp, weight, None, 2, 0, 1, 1)
  652. np.testing.assert_equal(
  653. out.numpy(),
  654. np.array(
  655. [[[4, 4], [4, 4], [4, 4]], [[4, 4], [4, 4], [4, 4]]], dtype=np.float32
  656. ),
  657. )
  658. def test_conv3d():
  659. inp = tensor(np.ones((256,), dtype=np.float32).reshape(2, 2, 4, 4, 4))
  660. weight = tensor(np.ones((48,), dtype=np.float32).reshape(3, 2, 2, 2, 2))
  661. out = F.conv3d(inp, weight, None, 2, 0, 1, 1)
  662. print(out.numpy().shape)
  663. np.testing.assert_equal(
  664. out.numpy(), np.ones((2, 3, 2, 2, 2), dtype=np.float32) * 16
  665. )
  666. def test_condtake():
  667. x = np.array([[1, 2, 3], [4, 5, 6]])
  668. y = np.array([[True, False, True], [False, True, True]])
  669. xx = tensor(x)
  670. yy = tensor(y)
  671. val, idx = F.cond_take(yy, xx)
  672. np.testing.assert_equal(val.numpy(), x[y])
  673. np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])
  674. def test_condtake_is_same():
  675. op1 = builtin.CondTake()
  676. op2 = builtin.CondTake()
  677. assert op1 == op2
  678. def test_nms_is_same():
  679. op1 = builtin.NMSKeep(0.7, 100)
  680. op2 = builtin.NMSKeep(0.7, 100)
  681. op3 = builtin.NMSKeep(0.8, 100)
  682. op4 = builtin.NMSKeep(0.7, 200)
  683. assert op1 == op2
  684. assert op1 != op3
  685. assert op1 != op4
  686. assert op3 != op4
  687. def test_argmxx_on_inf():
  688. def run_argmax():
  689. x = F.zeros((100, 100))
  690. x[:] = -float("inf")
  691. idxs = F.argmax(x, axis=0)
  692. return idxs
  693. def run_argmin():
  694. x = F.zeros((100, 100))
  695. x[:] = float("inf")
  696. idxs = F.argmin(x, axis=0)
  697. return idxs
  698. assert all(run_argmax() >= 0)
  699. assert all(run_argmin() >= 0)
  700. def test_deformable_psroi_pooling():
  701. inp = np.random.random((1, 256, 64, 64)).astype("float32")
  702. rois = np.random.random((1, 5)).astype("float32")
  703. trans = np.random.random((24, 2, 7, 7)).astype("float32")
  704. pooled_h = 7
  705. pooled_w = 7
  706. sample_per_part = 4
  707. no_trans = False
  708. part_size = 7
  709. spatial_scale = 1.0 / 64
  710. trans_std = 0.1
  711. y = F.deformable_psroi_pooling(
  712. tensor(inp),
  713. tensor(rois),
  714. tensor(trans),
  715. no_trans,
  716. part_size,
  717. pooled_h,
  718. pooled_w,
  719. sample_per_part,
  720. spatial_scale,
  721. trans_std,
  722. )
  723. def test_cvt_color():
  724. def rgb2gray(rgb):
  725. return np.dot(rgb[..., :3], [0.299, 0.587, 0.114])
  726. inp = np.random.randn(3, 3, 3, 3).astype(np.float32)
  727. out = np.expand_dims(rgb2gray(inp), 3).astype(np.float32)
  728. x = tensor(inp)
  729. y = F.vision.cvt_color(x, mode="RGB2GRAY")
  730. np.testing.assert_allclose(y.numpy(), out, atol=1e-5)
  731. @pytest.mark.parametrize("val", [2, [2,], [2, 3]])
  732. def test_ones(val):
  733. shp = tensor(val)
  734. np_shp = np.array(val)
  735. np.testing.assert_equal(F.ones(shp), np.ones(np_shp))
  736. def test_assert_equal():
  737. shape = (2, 3, 4, 5)
  738. x = F.ones(shape, dtype=np.float32)
  739. y = F.zeros(shape, dtype=np.float32) + 1.00001
  740. z = F.utils._assert_equal(x, y)
  741. def test_assert_not_equal():
  742. shape = (2, 3, 4, 5)
  743. x = F.ones(shape, dtype=np.float32)
  744. y = F.zeros(shape, dtype=np.float32) + 1.1
  745. with pytest.raises(RuntimeError):
  746. z = F.utils._assert_equal(x, y)
  747. def test_neg_axis():
  748. x = tensor(np.random.normal(0, 1, (32, 5)))
  749. y = F.argmax(x, axis=-1)
  750. yy = F.argmax(x, axis=1)
  751. np.testing.assert_equal(y.numpy(), yy.numpy())
  752. y = F.argmax(x, axis=(-1, -2))
  753. yy = F.argmax(x, axis=(0, 1))
  754. np.testing.assert_equal(y.numpy(), yy.numpy())
  755. y = F.argmin(x, axis=(-1, -2))
  756. yy = F.argmin(x, axis=(0, 1))
  757. np.testing.assert_equal(y.numpy(), yy.numpy())

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。