You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_functional.py 32 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import itertools
  10. import platform
  11. from functools import partial
  12. import numpy as np
  13. import pytest
  14. from utils import opr_test
  15. import megengine.amp as amp
  16. import megengine.core.ops.builtin as builtin
  17. import megengine.core.tensor.dtype as dtype
  18. import megengine.functional as F
  19. from megengine import Parameter, Tensor, is_cuda_available, tensor
  20. from megengine.core._trace_option import use_symbolic_shape
  21. from megengine.core.autodiff.grad import Grad
  22. from megengine.core.tensor.utils import make_shape_tuple
  23. from megengine.distributed.helper import get_device_count_by_fork
  24. from megengine.jit import trace
  25. def test_where():
  26. maskv0 = np.array([[1, 0], [0, 1]], dtype=np.bool_)
  27. xv0 = np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32)
  28. yv0 = np.array([[5, 6], [7, 8]], dtype=np.float32)
  29. maskv1 = np.array([[1, 0, 1], [1, 0, 0], [1, 1, 0]], dtype=np.bool_)
  30. xv1 = np.array([[1, np.inf, 2], [0, np.nan, 4], [1, 5, 7]], dtype=np.float32)
  31. yv1 = np.array([[5, 6, 9], [2, 7, 8], [2, 1, 9]], dtype=np.float32)
  32. cases = [
  33. {"input": [maskv0, xv0, yv0]},
  34. {"input": [maskv1, xv1, yv1]},
  35. ]
  36. opr_test(cases, F.where, ref_fn=np.where, test_trace=False)
  37. maskv2 = np.array([1, 1, 1], dtype=np.bool_)
  38. xv2 = np.array([1, 3, 2], dtype=np.float32)
  39. yv2 = np.array([5, 6, 9], dtype=np.float32)
  40. maskv3 = np.array([0, 0, 0], dtype=np.bool_)
  41. xv3 = np.array([1, 3, 2], dtype=np.float32)
  42. yv3 = np.array([5, 6, 9], dtype=np.float32)
  43. cases = [
  44. {"input": [maskv2, xv2, yv2]},
  45. {"input": [maskv3, xv3, yv3]},
  46. ]
  47. opr_test(cases, F.where, ref_fn=np.where, test_trace=False)
  48. def test_dropout():
  49. data = tensor(np.ones(10, dtype=np.float32))
  50. out = F.dropout(data, 1.0 / 3.0, training=False)
  51. assert out.numpy().sum() >= 0.0
  52. def test_matinv():
  53. shape1 = (5, 5)
  54. shape2 = (3, 9, 9)
  55. data1 = np.random.random(shape1).astype("float32")
  56. data2 = np.random.random(shape2).astype("float32")
  57. # make matrix diagonally dominant for numerical stability
  58. data1 += (np.eye(shape1[0]) * shape1[0]).astype("float32")
  59. data2 += np.broadcast_to((np.eye(shape2[1]) * shape2[1]).astype("float32"), shape2)
  60. cases = [
  61. {"input": data1},
  62. {"input": data2},
  63. ]
  64. opr_test(
  65. cases,
  66. F.matinv,
  67. compare_fn=lambda x, y: np.testing.assert_allclose(x.numpy(), y, rtol=1e-5),
  68. ref_fn=np.linalg.inv,
  69. )
  70. def test_matmul():
  71. shape1 = 3
  72. shape2 = 3
  73. shape3 = (3, 5)
  74. shape4 = (5, 6)
  75. data1 = np.random.random(shape1).astype("float32")
  76. data2 = np.random.random(shape2).astype("float32")
  77. data3 = np.random.random(shape3).astype("float32")
  78. data4 = np.random.random(shape4).astype("float32")
  79. cases = [
  80. {"input": [data1, data2]},
  81. {"input": [data2, data3]},
  82. {"input": [data3, data4]},
  83. ]
  84. opr_test(cases, F.matmul, ref_fn=np.matmul)
  85. batch_size = 10
  86. shape1 = (2,)
  87. shape2 = (batch_size, 2, 3)
  88. shape3 = (batch_size, 3, 4)
  89. shape4 = (batch_size, 10, 4, 2)
  90. shape5 = (batch_size, 10, 2, 4)
  91. data1 = np.random.random(shape1).astype("float32")
  92. data2 = np.random.random(shape2).astype("float32")
  93. data3 = np.random.random(shape3).astype("float32")
  94. data4 = np.random.random(shape4).astype("float32")
  95. data5 = np.random.random(shape5).astype("float32")
  96. cases = [
  97. {"input": [data1, data2]},
  98. {"input": [data2, data3]},
  99. {"input": [data3, data4]},
  100. {"input": [data4, data5]},
  101. ]
  102. opr_test(cases, F.matmul, ref_fn=np.matmul)
  103. opr_test(
  104. [{"input": [data1, data4]}],
  105. F.matmul,
  106. ref_fn=lambda x, y: np.matmul(x, y.transpose(0, 1, 3, 2)),
  107. transpose_b=True,
  108. )
  109. opr_test(
  110. [{"input": [data3, data2]}],
  111. F.matmul,
  112. ref_fn=lambda x, y: np.matmul(x.transpose(0, 2, 1), y.transpose(0, 2, 1)),
  113. transpose_a=True,
  114. transpose_b=True,
  115. )
  116. def test_interpolate():
  117. def linear_interpolate():
  118. inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))
  119. out = F.vision.interpolate(inp, scale_factor=2.0, mode="linear")
  120. out2 = F.vision.interpolate(inp, 4, mode="linear")
  121. np.testing.assert_allclose(
  122. out.numpy(), np.array([[[1.0, 1.25, 1.75, 2.0]]], dtype=np.float32)
  123. )
  124. np.testing.assert_allclose(
  125. out2.numpy(), np.array([[[1.0, 1.25, 1.75, 2.0]]], dtype=np.float32)
  126. )
  127. def many_batch_interpolate():
  128. inp = tensor(np.arange(1, 9, dtype=np.float32).reshape(2, 1, 2, 2))
  129. out = F.vision.interpolate(inp, [4, 4])
  130. out2 = F.vision.interpolate(inp, scale_factor=2.0)
  131. np.testing.assert_allclose(out.numpy(), out2.numpy())
  132. def assign_corner_interpolate():
  133. inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))
  134. out = F.vision.interpolate(inp, [4, 4], align_corners=True)
  135. out2 = F.vision.interpolate(inp, scale_factor=2.0, align_corners=True)
  136. np.testing.assert_allclose(out.numpy(), out2.numpy())
  137. def error_shape_linear_interpolate():
  138. inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))
  139. with pytest.raises(ValueError):
  140. F.vision.interpolate(inp, scale_factor=2.0, mode="linear")
  141. def inappropriate_scale_linear_interpolate():
  142. inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))
  143. with pytest.raises(ValueError):
  144. F.vision.interpolate(inp, scale_factor=[2.0, 3.0], mode="linear")
  145. linear_interpolate()
  146. many_batch_interpolate()
  147. assign_corner_interpolate()
  148. error_shape_linear_interpolate()
  149. inappropriate_scale_linear_interpolate()
  150. def _save_to(self, name="grad"):
  151. def callback(grad):
  152. setattr(self, name, grad)
  153. return callback
  154. def _gen_roi_inp():
  155. inp_feat = np.random.randn(2, 32, 256, 256)
  156. rois = np.zeros((4, 5))
  157. rois[:, 0] = [0, 0, 1, 1]
  158. rois[:, 1:3] = np.random.rand(4, 2) * 100
  159. rois[:, 3:] = np.random.rand(4, 2) * 100 + 150
  160. inp_feat = tensor(inp_feat)
  161. rois = tensor(rois)
  162. return inp_feat, rois
  163. def test_roi_align():
  164. inp_feat, rois = _gen_roi_inp()
  165. grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat))
  166. output_shape = (7, 7)
  167. out_feat = F.vision.roi_align(
  168. inp_feat,
  169. rois,
  170. output_shape=output_shape,
  171. mode="average",
  172. spatial_scale=1.0 / 4,
  173. sample_points=2,
  174. aligned=True,
  175. )
  176. assert make_shape_tuple(out_feat.shape) == (
  177. rois.shape[0],
  178. inp_feat.shape[1],
  179. *output_shape,
  180. )
  181. grad(out_feat, tensor(F.ones_like(out_feat)))
  182. assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape)
  183. def _gen_correlation(random=True, constant=1, image_shape=(2, 1, 160, 160)):
  184. if random:
  185. inp_feat1 = np.random.randn(
  186. image_shape[0], image_shape[1], image_shape[2], image_shape[3]
  187. )
  188. inp_feat2 = np.random.randn(
  189. image_shape[0], image_shape[1], image_shape[2], image_shape[3]
  190. )
  191. else:
  192. inp_feat1 = np.ones(image_shape) * constant
  193. inp_feat2 = np.ones(image_shape) * constant
  194. return tensor(inp_feat1), tensor(inp_feat2)
  195. def test_correlation():
  196. ##test case 0 check the grad shape
  197. data1, data2 = _gen_correlation()
  198. grad = Grad().wrt(data1, callback=_save_to(data1))
  199. out_feat = F.vision.correlation(
  200. data1,
  201. data2,
  202. kernel_size=5,
  203. max_displacement=4,
  204. stride1=2,
  205. stride2=2,
  206. pad_size=2,
  207. is_multiply=True,
  208. )
  209. grad(out_feat, tensor(F.ones_like(out_feat)))
  210. assert make_shape_tuple(data1.grad.shape) == make_shape_tuple(data1.shape)
  211. ##test case 1 from https://github.com/NVIDIA/flownet2-pytorch/issues/194
  212. data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3))
  213. out_feat = F.vision.correlation(
  214. data1,
  215. data2,
  216. kernel_size=3,
  217. max_displacement=0,
  218. stride1=1,
  219. stride2=1,
  220. pad_size=0,
  221. is_multiply=True,
  222. )
  223. assert abs(out_feat.sum() - 1) < 1e-9
  224. ##test case 2 check same image subduction
  225. data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3))
  226. out_feat = F.vision.correlation(
  227. data1,
  228. data2,
  229. kernel_size=3,
  230. max_displacement=0,
  231. stride1=1,
  232. stride2=1,
  233. pad_size=0,
  234. is_multiply=False,
  235. )
  236. assert out_feat.sum() < 1e-9
  237. ##test case 3 check same image subduction
  238. data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3))
  239. out_feat = F.vision.correlation(
  240. data1,
  241. data2,
  242. kernel_size=3,
  243. max_displacement=0,
  244. stride1=1,
  245. stride2=1,
  246. pad_size=0,
  247. is_multiply=False,
  248. )
  249. assert out_feat.sum() < 1e-9
  250. ##test case 4 check correlation
  251. data1, _ = _gen_correlation(
  252. random=False, image_shape=(1, 1, 220, 220), constant=2.0
  253. )
  254. _, data2 = _gen_correlation(
  255. random=False, image_shape=(1, 1, 220, 220), constant=1.0
  256. )
  257. out_feat = F.vision.correlation(
  258. data1,
  259. data2,
  260. kernel_size=3,
  261. max_displacement=2,
  262. stride1=1,
  263. stride2=2,
  264. pad_size=0,
  265. is_multiply=False,
  266. )
  267. assert abs(out_feat.mean() - 1) < 1e-9
  268. def test_roi_pooling():
  269. inp_feat, rois = _gen_roi_inp()
  270. grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat))
  271. output_shape = (7, 7)
  272. out_feat = F.vision.roi_pooling(
  273. inp_feat, rois, output_shape=output_shape, mode="max", scale=1.0 / 4,
  274. )
  275. assert make_shape_tuple(out_feat.shape) == (
  276. rois.shape[0],
  277. inp_feat.shape[1],
  278. *output_shape,
  279. )
  280. grad(out_feat, tensor(F.ones_like(out_feat)))
  281. assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape)
  282. def test_adaptive_avg_pool2d():
  283. inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4))
  284. oshp = (2, 2)
  285. grad = Grad().wrt(inp, callback=_save_to(inp))
  286. outp = F.adaptive_avg_pool2d(inp, oshp,)
  287. assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
  288. np.testing.assert_equal(
  289. outp.numpy(), np.array([[[[2.5, 4.5], [10.5, 12.5]]]], dtype=np.float32)
  290. )
  291. grad(outp, tensor(F.ones_like(outp)))
  292. assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape)
  293. np.testing.assert_equal(
  294. inp.grad.numpy(),
  295. np.array(
  296. [
  297. [
  298. [
  299. [0.25, 0.25, 0.25, 0.25],
  300. [0.25, 0.25, 0.25, 0.25],
  301. [0.25, 0.25, 0.25, 0.25],
  302. [0.25, 0.25, 0.25, 0.25],
  303. ]
  304. ]
  305. ],
  306. dtype=np.float32,
  307. ),
  308. )
  309. def test_adaptive_max_pool2d():
  310. inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4))
  311. oshp = (2, 2)
  312. grad = Grad().wrt(inp, callback=_save_to(inp))
  313. outp = F.adaptive_max_pool2d(inp, oshp,)
  314. assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
  315. np.testing.assert_equal(
  316. outp.numpy(), np.array([[[[5, 7], [13, 15]]]], dtype=np.float32)
  317. )
  318. grad(outp, tensor(F.ones_like(outp)))
  319. assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape)
  320. np.testing.assert_equal(
  321. inp.grad.numpy(),
  322. np.array(
  323. [
  324. [
  325. [
  326. [0.0, 0.0, 0.0, 0.0],
  327. [0.0, 1.0, 0.0, 1.0],
  328. [0.0, 0.0, 0.0, 0.0],
  329. [0.0, 1.0, 0.0, 1.0],
  330. ]
  331. ]
  332. ],
  333. dtype=np.float32,
  334. ),
  335. )
  336. def test_one_hot():
  337. def onehot_low_dimension():
  338. inp = tensor(np.arange(1, 4, dtype=np.int32))
  339. out = F.one_hot(inp, num_classes=4)
  340. np.testing.assert_allclose(
  341. out.numpy(), np.eye(4, dtype=np.int32)[np.arange(1, 4, dtype=np.int32)]
  342. )
  343. def onehot_high_dimension():
  344. arr = np.array(
  345. [[3, 2, 4, 4, 2, 4, 0, 4, 4, 1], [4, 1, 1, 3, 2, 2, 4, 2, 4, 3]],
  346. dtype=np.int32,
  347. )
  348. inp = tensor(arr)
  349. out = F.one_hot(inp, 10)
  350. np.testing.assert_allclose(out.numpy(), np.eye(10, dtype=np.int32)[arr])
  351. onehot_low_dimension()
  352. onehot_high_dimension()
  353. def test_interpolate_fastpath():
  354. # check shape
  355. test_cases = [
  356. [(1, 1, 10, 10), (5, 5)],
  357. [(1, 3, 10, 10), (20, 20)],
  358. [(10, 1, 10, 10), (1, 1)],
  359. # [(10, 10, 1, 1), (10, 10)], # FIXME, it causes random CI failure
  360. ]
  361. for inp_shape, target_shape in test_cases:
  362. x = tensor(np.random.randn(*inp_shape), dtype=np.float32)
  363. out = F.vision.interpolate(x, target_shape, mode="bilinear")
  364. assert out.shape[0] == x.shape[0] and out.shape[1] == x.shape[1]
  365. assert out.shape[2] == target_shape[0] and out.shape[3] == target_shape[1]
  366. # check value
  367. x = tensor(np.ones((3, 3, 10, 10)), dtype=np.float32)
  368. out = F.vision.interpolate(x, (15, 5), mode="bilinear")
  369. np.testing.assert_equal(out.numpy(), np.ones((3, 3, 15, 5)).astype(np.float32))
  370. np_x = np.arange(32)
  371. x = tensor(np_x).astype(np.float32).reshape(1, 1, 32, 1)
  372. out = F.vision.interpolate(x, (1, 1), mode="bilinear")
  373. np.testing.assert_equal(out.item(), np_x.mean())
  374. @pytest.mark.parametrize("dt", [np.float32, np.int8, np.uint8, np.float16])
  375. def test_warp_perspective(dt):
  376. inp_shape = (1, 1, 4, 4)
  377. x = tensor(np.arange(16, dtype=dt).reshape(inp_shape))
  378. M_shape = (1, 3, 3)
  379. # M defines a translation: dst(1, 1, h, w) = rst(1, 1, h+1, w+1)
  380. M = tensor(
  381. np.array(
  382. [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32
  383. ).reshape(M_shape)
  384. )
  385. outp = F.vision.warp_perspective(x, M, (2, 2))
  386. np.testing.assert_equal(outp.numpy(), np.array([[[[5, 6], [9, 10]]]], dtype=dt))
  387. @pytest.mark.parametrize("dt", [np.float32, np.int8, np.uint8, np.float16])
  388. def test_warp_perspective_mat_idx(dt):
  389. inp_shape = (2, 1, 4, 4)
  390. x = tensor(np.arange(32, dtype=dt).reshape(inp_shape))
  391. M_shape = (1, 3, 3)
  392. # M defines a translation: dst(1, 1, h, w) = rst(1, 1, h+1, w+1)
  393. M = tensor(
  394. np.array(
  395. [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32
  396. ).reshape(M_shape)
  397. )
  398. M = F.concat([M,] * 4, 0)
  399. outp = F.vision.warp_perspective(x, M, (2, 2), mat_idx=[0, 1, 1, 0])
  400. np.testing.assert_equal(
  401. outp.numpy(),
  402. np.array(
  403. [
  404. [[[5, 6], [9, 10]]],
  405. [[[21, 22], [25, 26]]],
  406. [[[21, 22], [25, 26]]],
  407. [[[5, 6], [9, 10]]],
  408. ],
  409. dtype=dt,
  410. ),
  411. )
  412. def test_warp_affine():
  413. inp_shape = (1, 3, 3, 3)
  414. x = tensor(np.arange(27, dtype=np.float32).reshape(inp_shape))
  415. weightv = [[[1.26666667, 0.6, -83.33333333], [-0.33333333, 1, 66.66666667]]]
  416. outp = F.vision.warp_affine(x, tensor(weightv), (2, 2), border_mode="wrap")
  417. res = np.array(
  418. [
  419. [
  420. [[7.875, 8.875, 9.875], [8.90625, 9.90625, 10.90625]],
  421. [[18.75, 19.75, 20.75], [14.90625, 15.90625, 16.90625]],
  422. ]
  423. ],
  424. dtype=np.float32,
  425. )
  426. if not is_cuda_available():
  427. np.testing.assert_almost_equal(outp.numpy(), res, 5)
  428. def test_remap():
  429. inp_shape = (1, 1, 4, 4)
  430. inp = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
  431. map_xy_shape = (1, 2, 2, 2)
  432. map_xy = tensor(
  433. np.array(
  434. [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [0.0, 1.0]]], dtype=np.float32
  435. ).reshape(map_xy_shape)
  436. )
  437. outp = F.vision.remap(inp, map_xy)
  438. np.testing.assert_equal(
  439. outp.numpy(), np.array([[[[1.0, 4.0], [4.0, 4.0]]]], dtype=np.float32)
  440. )
  441. def test_binary_cross_entropy():
  442. data1_shape = (2, 2)
  443. label1_shape = (2, 2)
  444. data2_shape = (2, 3)
  445. label2_shape = (2, 3)
  446. def sigmoid(x):
  447. return 1 / (1 + np.exp(-x))
  448. def compare_fn(x, y):
  449. np.testing.assert_allclose(x.numpy(), y, atol=5e-4)
  450. np.random.seed(123)
  451. data1 = np.random.uniform(size=data1_shape).astype(np.float32)
  452. label1 = np.random.uniform(size=label1_shape).astype(np.float32)
  453. expect1 = np.array([0.6361], dtype=np.float32)
  454. np.random.seed(123)
  455. data2 = np.random.uniform(size=data2_shape).astype(np.float32)
  456. label2 = np.random.uniform(size=label2_shape).astype(np.float32)
  457. expect2 = np.array([0.6750], dtype=np.float32)
  458. cases = [
  459. {"input": [data1, label1], "output": expect1,},
  460. {"input": [data2, label2], "output": expect2,},
  461. ]
  462. opr_test(cases, F.nn.binary_cross_entropy, compare_fn=compare_fn)
  463. cases = [
  464. {"input": [sigmoid(data1), label1], "output": expect1,},
  465. {"input": [sigmoid(data2), label2], "output": expect2,},
  466. ]
  467. opr_test(
  468. cases,
  469. partial(F.nn.binary_cross_entropy, with_logits=False),
  470. compare_fn=compare_fn,
  471. )
  472. def test_hinge_loss():
  473. np.random.seed(123)
  474. # case with L1 norm
  475. cases = []
  476. for shape in [(2, 2), (2, 3)]:
  477. data = np.random.uniform(size=shape).astype(np.float32)
  478. label = 2 * np.random.randint(0, 1, size=shape).astype(np.float32) - 1
  479. expect = np.clip(0, np.inf, 1 - data * label).sum(axis=1).mean()
  480. cases.append({"input": [data, label], "output": expect})
  481. opr_test(cases, F.nn.hinge_loss)
  482. # cases with L2 norm
  483. cases = []
  484. for shape in [(2, 2), (2, 3)]:
  485. data = np.random.uniform(size=shape).astype(np.float32)
  486. label = 2 * np.random.randint(0, 1, size=shape).astype(np.float32) - 1
  487. expect = ((np.clip(0, np.inf, 1 - data * label) ** 2).sum(axis=1)).mean()
  488. cases.append({"input": [data, label], "output": expect})
  489. def hinge_loss_with_l2_norm(pred, label):
  490. return F.nn.hinge_loss(pred, label, "L2")
  491. opr_test(cases, hinge_loss_with_l2_norm)
  492. def test_nms():
  493. x = np.array(
  494. [
  495. [0, 0, 100, 100],
  496. [10, 10, 100, 100],
  497. [50, 50, 100, 100],
  498. [100, 100, 150, 150],
  499. ],
  500. dtype=np.float32,
  501. )
  502. inp = tensor(x)
  503. scores = tensor([0.5, 0.8, 0.9, 0.6], dtype=np.float32)
  504. result = F.vision.nms(inp, scores=scores, iou_thresh=0.5)
  505. np.testing.assert_equal(result.numpy(), np.array([2, 1, 3], dtype=np.int32))
  506. @pytest.mark.skipif(
  507. get_device_count_by_fork("gpu") > 0, reason="cuda does not support nchw int8"
  508. )
  509. def test_conv_bias():
  510. inp_scale = 1.5
  511. w_scale = 2.5
  512. outp_scale = 1.5
  513. inp_dtype = dtype.qint8(inp_scale)
  514. w_dtype = dtype.qint8(w_scale)
  515. b_dtype = dtype.qint32(inp_scale * w_scale)
  516. out_dtype = dtype.qint8(outp_scale)
  517. def run(
  518. N,
  519. IC,
  520. OC,
  521. IH,
  522. IW,
  523. KH,
  524. KW,
  525. PH,
  526. PW,
  527. SH,
  528. SW,
  529. has_bias=True,
  530. nonlinear_mode="identity",
  531. ):
  532. inp_v = np.random.normal(size=(N, IC, IH, IW))
  533. w_v = np.random.normal(size=(OC, IC, KH, KW))
  534. b_v = np.random.normal(size=(1, OC, 1, 1))
  535. inp_scale = dtype.get_scale(inp_dtype)
  536. w_scale = dtype.get_scale(w_dtype)
  537. b_scale = dtype.get_scale(b_dtype)
  538. inpv = dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype)
  539. wv = dtype.convert_to_qint8(w_v * w_scale, w_dtype)
  540. bv = dtype.convert_to_qint32(b_v * b_scale, b_dtype)
  541. inp_int8 = tensor(inpv, dtype=inp_dtype)
  542. w_int8 = Parameter(wv, dtype=w_dtype)
  543. b_int32 = Parameter(bv, dtype=b_dtype)
  544. inp_fp32 = inp_int8.astype("float32")
  545. w_fp32 = w_int8.astype("float32")
  546. b_fp32 = b_int32.astype("float32")
  547. def convert_to_nchw4(var):
  548. var = F.reshape(
  549. var, (var.shape[0], var.shape[1] // 4, 4, var.shape[2], var.shape[3])
  550. )
  551. var = F.transpose(var, (0, 1, 3, 4, 2))
  552. return var
  553. def run_conv2d(inp, w, b):
  554. O = F.conv2d(
  555. inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW),
  556. )
  557. if nonlinear_mode == "relu":
  558. return F.relu(O)
  559. else:
  560. return O
  561. def run_conv_bias(inp, w, b, format="NCHW"):
  562. b = b if has_bias else Parameter(np.zeros_like(b.numpy()))
  563. if format == "NCHW4":
  564. inp = convert_to_nchw4(inp)
  565. w = convert_to_nchw4(w)
  566. b = convert_to_nchw4(b)
  567. return F.quantized.conv_bias_activation(
  568. inp,
  569. w,
  570. b,
  571. stride=(SH, SW),
  572. padding=(PH, PW),
  573. dtype=out_dtype,
  574. nonlinear_mode=nonlinear_mode,
  575. )
  576. format = "NCHW4" if is_cuda_available() else "NCHW"
  577. expected = run_conv2d(inp_fp32, w_fp32, b_fp32)
  578. expected = expected.astype(out_dtype).astype("float32")
  579. result = run_conv_bias(inp_int8, w_int8, b_int32, format=format).astype(
  580. "float32"
  581. )
  582. if format == "NCHW4":
  583. result = F.transpose(result, (0, 1, 4, 2, 3))
  584. expected = F.flatten(expected)
  585. result = F.flatten(result)
  586. np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=outp_scale)
  587. run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1, False)
  588. run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1, False)
  589. run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False)
  590. run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1)
  591. run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1)
  592. run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2)
  593. run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "relu")
  594. run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu")
  595. @pytest.mark.skipif(
  596. get_device_count_by_fork("gpu") > 0, reason="no int8 algorithm on cuda"
  597. )
  598. def test_batch_conv_bias():
  599. inp_scale = 1.5
  600. w_scale = 2.5
  601. outp_scale = 1.5
  602. inp_dtype = dtype.qint8(inp_scale)
  603. w_dtype = dtype.qint8(w_scale)
  604. b_dtype = dtype.qint32(inp_scale * w_scale)
  605. out_dtype = dtype.qint8(outp_scale)
  606. def run(
  607. N, IC, OC, IH, IW, KH, KW, PH, PW, SH, SW, has_bias=True,
  608. ):
  609. inp_v = np.random.normal(size=(N, IC, IH, IW))
  610. w_v = np.random.normal(size=(N, OC, IC, KH, KW))
  611. b_v = np.random.normal(size=(1, OC, 1, 1))
  612. inp_scale = dtype.get_scale(inp_dtype)
  613. w_scale = dtype.get_scale(w_dtype)
  614. b_scale = dtype.get_scale(b_dtype)
  615. inpv = dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype)
  616. wv = dtype.convert_to_qint8(w_v * w_scale, w_dtype)
  617. bv = dtype.convert_to_qint32(b_v * b_scale, b_dtype)
  618. inp_int8 = tensor(inpv, dtype=inp_dtype)
  619. w_int8 = Parameter(wv, dtype=w_dtype)
  620. b_int32 = Parameter(bv, dtype=b_dtype)
  621. inp_fp32 = inp_int8.astype("float32")
  622. w_fp32 = w_int8.astype("float32")
  623. b_fp32 = b_int32.astype("float32")
  624. def run_batch_conv_bias(inp, w, b):
  625. b = b if has_bias else Parameter(np.zeros_like(b.numpy()))
  626. result = F.quantized.batch_conv_bias_activation(
  627. inp, w, b, stride=(SH, SW), padding=(PH, PW), dtype=out_dtype,
  628. )
  629. return result.astype("float32")
  630. expected = F.conv2d(inp_fp32, w_fp32[0], b_fp32 if has_bias else None)[0]
  631. expected = expected.astype(out_dtype).astype("float32")
  632. expected = F.flatten(expected)
  633. result = run_batch_conv_bias(inp_int8, w_int8, b_int32)
  634. result = F.flatten(result)
  635. np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=outp_scale)
  636. run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)
  637. def test_conv2d_io16c32():
  638. amp.enabled = True
  639. inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)
  640. weight = tensor(np.random.randn(64, 3, 7, 7), dtype=np.float32)
  641. out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)
  642. amp.enabled = False
  643. expected = F.conv2d(
  644. inp.astype("float16"),
  645. weight.astype("float16"),
  646. None,
  647. (2, 2),
  648. (3, 3),
  649. (1, 1),
  650. 1,
  651. compute_mode="float32",
  652. )
  653. assert out.dtype == np.float16
  654. assert expected.dtype == np.float16
  655. np.testing.assert_allclose(out.numpy(), expected.numpy())
  656. def test_conv2d_zero_stride_numpy_array():
  657. inp = np.random.randn(3, 224, 224).astype(np.float32)
  658. inp = inp[np.newaxis, :]
  659. inp = tensor(inp, dtype=np.float32)
  660. weight = tensor(np.random.randn(16, 3, 3, 3), dtype=np.float32)
  661. out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)
  662. def test_conv3d_zero_stride_numpy_array():
  663. inp = np.random.randn(3, 224, 224, 224).astype(np.float32)
  664. inp = inp[np.newaxis, :]
  665. inp = tensor(inp, dtype=np.float32)
  666. weight = tensor(np.random.randn(16, 3, 3, 3, 3), dtype=np.float32)
  667. out = F.conv3d(inp, weight, None, (2, 2, 2), (3, 3, 3), (1, 1, 1), 1)
  668. out.numpy()
  669. def test_conv1d():
  670. inp = tensor(np.ones((2, 2, 4), dtype=np.float32))
  671. weight = tensor(np.ones((3, 2, 2), dtype=np.float32))
  672. out = F.conv1d(inp, weight, None, 2, 0, 1, 1)
  673. np.testing.assert_equal(
  674. out.numpy(),
  675. np.array(
  676. [[[4, 4], [4, 4], [4, 4]], [[4, 4], [4, 4], [4, 4]]], dtype=np.float32
  677. ),
  678. )
  679. def test_batchnorm2d_io16c32():
  680. amp.enabled = True
  681. inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)
  682. weight = tensor(np.ones((1, 3, 1, 1)), dtype=np.float32)
  683. bias = tensor(np.zeros((1, 3, 1, 1)), dtype=np.float32)
  684. out = F.batch_norm(inp, weight=weight, bias=bias, training=True, inplace=False)
  685. amp.enabled = False
  686. expected = F.batch_norm(
  687. inp.astype("float16"),
  688. weight=weight,
  689. bias=bias,
  690. training=True,
  691. inplace=False,
  692. compute_mode="float32",
  693. )
  694. assert out.dtype == np.float16
  695. assert expected.dtype == np.float16
  696. np.testing.assert_allclose(out.numpy(), expected.numpy())
  697. def test_conv3d():
  698. inp = tensor(np.ones((2, 2, 4, 4, 4), dtype=np.float32))
  699. weight = tensor(np.ones((3, 2, 2, 2, 2), dtype=np.float32))
  700. out = F.conv3d(inp, weight, None, 2, 0, 1, 1)
  701. print(out.numpy().shape)
  702. np.testing.assert_equal(
  703. out.numpy(), np.ones((2, 3, 2, 2, 2), dtype=np.float32) * 16
  704. )
  705. def test_condtake():
  706. x = np.array([[1, 2, 3], [4, 5, 6]])
  707. y = np.array([[True, False, True], [False, True, True]])
  708. xx = tensor(x)
  709. yy = tensor(y)
  710. val, idx = F.cond_take(yy, xx)
  711. np.testing.assert_equal(val.numpy(), x[y])
  712. np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])
  713. def test_condtake_is_same():
  714. op1 = builtin.CondTake()
  715. op2 = builtin.CondTake()
  716. assert op1 == op2
  717. def test_nms_is_same():
  718. op1 = builtin.NMSKeep(0.7, 100)
  719. op2 = builtin.NMSKeep(0.7, 100)
  720. op3 = builtin.NMSKeep(0.8, 100)
  721. op4 = builtin.NMSKeep(0.7, 200)
  722. assert op1 == op2
  723. assert op1 != op3
  724. assert op1 != op4
  725. assert op3 != op4
  726. def test_argmxx_on_inf():
  727. def run_argmax():
  728. x = F.zeros((100, 100))
  729. x[:] = -float("inf")
  730. idxs = F.argmax(x, axis=0)
  731. return idxs
  732. def run_argmin():
  733. x = F.zeros((100, 100))
  734. x[:] = float("inf")
  735. idxs = F.argmin(x, axis=0)
  736. return idxs
  737. assert all(run_argmax() >= 0)
  738. assert all(run_argmin() >= 0)
  739. def test_deformable_psroi_pooling():
  740. inp = np.random.random((1, 256, 64, 64)).astype("float32")
  741. rois = np.random.random((1, 5)).astype("float32")
  742. trans = np.random.random((24, 2, 7, 7)).astype("float32")
  743. pooled_h = 7
  744. pooled_w = 7
  745. sample_per_part = 4
  746. no_trans = False
  747. part_size = 7
  748. spatial_scale = 1.0 / 64
  749. trans_std = 0.1
  750. y = F.deformable_psroi_pooling(
  751. tensor(inp),
  752. tensor(rois),
  753. tensor(trans),
  754. no_trans,
  755. part_size,
  756. pooled_h,
  757. pooled_w,
  758. sample_per_part,
  759. spatial_scale,
  760. trans_std,
  761. )
  762. def test_cvt_color():
  763. def rgb2gray(rgb):
  764. return np.dot(rgb[..., :3], [0.299, 0.587, 0.114])
  765. inp = np.random.randn(3, 3, 3, 3).astype(np.float32)
  766. out = np.expand_dims(rgb2gray(inp), 3).astype(np.float32)
  767. x = tensor(inp)
  768. y = F.vision.cvt_color(x, mode="RGB2GRAY")
  769. np.testing.assert_allclose(y.numpy(), out, atol=1e-5)
  770. @pytest.mark.parametrize("val", [2, [2,], [2, 3]])
  771. def test_ones(val):
  772. shp = tensor(val)
  773. np_shp = np.array(val)
  774. np.testing.assert_equal(F.ones(shp), np.ones(np_shp))
  775. def test_assert_equal():
  776. shape = (2, 3, 4, 5)
  777. x = F.ones(shape, dtype=np.float32)
  778. y = F.zeros(shape, dtype=np.float32) + 1.00001
  779. z = F.utils._assert_equal(x, y)
  780. def test_assert_not_equal():
  781. shape = (2, 3, 4, 5)
  782. x = F.ones(shape, dtype=np.float32)
  783. y = F.zeros(shape, dtype=np.float32) + 1.1
  784. with pytest.raises(RuntimeError):
  785. z = F.utils._assert_equal(x, y)
  786. def test_neg_axis():
  787. x = tensor(np.random.normal(0, 1, (32, 5)))
  788. y = F.argmax(x, axis=-1)
  789. yy = F.argmax(x, axis=1)
  790. np.testing.assert_equal(y.numpy(), yy.numpy())
  791. y = F.argmax(x, axis=(-1, -2))
  792. yy = F.argmax(x, axis=(0, 1))
  793. np.testing.assert_equal(y.numpy(), yy.numpy())
  794. y = F.argmin(x, axis=(-1, -2))
  795. yy = F.argmin(x, axis=(0, 1))
  796. np.testing.assert_equal(y.numpy(), yy.numpy())
  797. def test_sliding_window():
  798. N, C, H, W = 2, 3, 7, 8
  799. inp = np.random.normal(size=(N, C, H, W))
  800. ph, pw = 1, 2
  801. sh, sw = 2, 1
  802. wh, ww = 3, 2
  803. dh, dw = 1, 3
  804. s = lambda i, p, s, d, w: (i + p * 2 - (w - 1) * d - 1) // s + 1
  805. inp_pad = np.zeros((N, C, H + ph * 2, W + pw * 2))
  806. inp_pad[:, :, ph : H + ph, pw : W + pw] = inp
  807. gt_out = np.empty(
  808. (N, C, s(H, ph, sh, dh, wh), s(W, pw, sw, dw, ww), wh, ww), dtype=np.float32
  809. )
  810. for n, c, oh, ow in itertools.product(*map(range, gt_out.shape[:4])):
  811. ih, iw = oh * sh, ow * sw
  812. gt_out[n, c, oh, ow, :] = inp_pad[
  813. n, c, ih : ih + (wh - 1) * dh + 1 : dh, iw : iw + (ww - 1) * dw + 1 : dw
  814. ]
  815. out = F.sliding_window(
  816. tensor(inp), (wh, ww), padding=(ph, pw), stride=(sh, sw), dilation=(dh, dw)
  817. )
  818. np.testing.assert_equal(gt_out, out.numpy())
  819. def test_sliding_window_transpose():
  820. N, C, H, W = 2, 3, 7, 8
  821. ph, pw = 1, 2
  822. sh, sw = 2, 1
  823. wh, ww = 3, 2
  824. dh, dw = 1, 3
  825. s = lambda i, p, s, d, w: (i + p * 2 - (w - 1) * d - 1) // s + 1
  826. inp = np.random.normal(
  827. size=(N, C, s(H, ph, sh, dh, wh), s(W, pw, sw, dw, ww), wh, ww)
  828. ).astype(np.float32)
  829. gt_out = np.zeros((N, C, H, W), dtype=np.float32)
  830. for n, c in itertools.product(*map(range, inp.shape[:2])):
  831. oh = 0
  832. for ih in range(-ph, H + ph - dh * (wh - 1), sh):
  833. ow = 0
  834. for iw in range(-pw, W + pw - dw * (ww - 1), sw):
  835. for kh, kw in itertools.product(*map(range, inp.shape[-2:])):
  836. ih2 = ih + dh * kh
  837. iw2 = iw + dw * kw
  838. if ih2 >= 0 and ih2 < H and iw2 >= 0 and iw2 < W:
  839. gt_out[n, c, ih2, iw2] += inp[n, c, oh, ow, kh, kw]
  840. ow += 1
  841. oh += 1
  842. out = F.sliding_window_transpose(
  843. tensor(inp),
  844. (H, W),
  845. (wh, ww),
  846. padding=(ph, pw),
  847. stride=(sh, sw),
  848. dilation=(dh, dw),
  849. )
  850. np.testing.assert_equal(gt_out, out.numpy())

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，无需区分 CPU 版和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并已安装好驱动。如果你想体验在云端 GPU 算力平台上进行深度学习开发，欢迎访问 MegStudio 平台。