
test_functional.py

# -*- coding: utf-8 -*-
import itertools
import platform
from functools import partial

import numpy as np
import pytest
from utils import opr_test

import megengine.amp as amp
import megengine.config as config
import megengine.core.ops.builtin as builtin
import megengine.core.tensor.dtype as dtype
import megengine.functional as F
import megengine.jit as jit
from megengine import Parameter, Tensor, is_cuda_available, tensor
from megengine.core._trace_option import use_symbolic_shape
from megengine.core.autodiff.grad import Grad
from megengine.core.tensor.utils import make_shape_tuple
from megengine.device import get_device_count
from megengine.jit.tracing import trace
from megengine.module import ConvTranspose2d, ConvTranspose3d, LayerNorm

_assert_allclose = partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6)

def test_where():
    maskv0 = np.array([[1, 0], [0, 1]], dtype=np.bool_)
    xv0 = np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32)
    yv0 = np.array([[5, 6], [7, 8]], dtype=np.float32)

    maskv1 = np.array([[1, 0, 1], [1, 0, 0], [1, 1, 0]], dtype=np.bool_)
    xv1 = np.array([[1, np.inf, 2], [0, np.nan, 4], [1, 5, 7]], dtype=np.float32)
    yv1 = np.array([[5, 6, 9], [2, 7, 8], [2, 1, 9]], dtype=np.float32)

    maskv2 = np.array([1, 1, 1], dtype=np.bool_)
    xv2 = np.array([1, 3, 2], dtype=np.float32)
    yv2 = np.array([5, 6, 9], dtype=np.float32)

    maskv3 = np.array([0, 0, 0], dtype=np.bool_)
    xv3 = np.array([1, 3, 2], dtype=np.float32)
    yv3 = np.array([5, 6, 9], dtype=np.float32)

    maskv4 = np.array(1, dtype=np.bool_)
    xv4 = np.array(1, dtype=np.float32)
    yv4 = np.array(0, dtype=np.float32)

    cases = [
        {"input": [maskv0, xv0, yv0]},
        {"input": [maskv1, xv1, yv1]},
        {"input": [maskv2, xv2, yv2]},
        {"input": [maskv3, xv3, yv3]},
        {"input": [maskv4, xv4, yv4]},
    ]
    opr_test(cases, F.where, ref_fn=np.where, test_trace=True)

def test_dropout():
    from megengine.autodiff import GradManager
    from megengine.core._imperative_rt.ops import set_global_rng_seed

    def test_dropout_with_shape(shape, rate):
        data = tensor(np.ones(shape, dtype=np.float32))
        gm = GradManager().attach([data])
        with gm:
            out = F.nn.dropout(data, rate, training=True)
            gm.backward(out, tensor(np.ones(shape, dtype=np.float32)))
        if len(shape) != 0:
            assert not out.numpy().all()
        np.testing.assert_allclose(out.numpy(), data.grad.numpy(), 1e-7, 1e-7)

    def test_multiple_dropout(shape, rate):
        data = tensor(np.ones(shape, dtype=np.float32))
        gm = GradManager().attach([data])
        with gm:
            out1 = F.nn.dropout(data, rate, training=True)
            out2 = F.nn.dropout(out1, rate, training=True)
            out3 = F.nn.dropout(out2, rate, training=True)
            gm.backward(out3, tensor(np.ones(shape, dtype=np.float32)))
        np.testing.assert_allclose(out3.numpy(), data.grad.numpy(), 1e-7, 1e-7)

    def test_dropout_seed(shape, rate):
        data = tensor(np.random.randn(*shape), dtype="float32")
        set_global_rng_seed(111)
        out1 = F.nn.dropout(data, rate, training=True)
        out2 = F.nn.dropout(data, rate, training=True)
        assert not (out1.numpy() == out2.numpy()).all()

        set_global_rng_seed(111)
        out3 = F.nn.dropout(data, rate, training=True)
        assert (out1.numpy() == out3.numpy()).all()

        set_global_rng_seed(222)
        out4 = F.nn.dropout(data, rate, training=True)
        assert not (out1.numpy() == out4.numpy()).all()

    test_dropout_with_shape([], 0.4)
    test_dropout_with_shape([13, 17, 63, 21], 0.4)
    test_dropout_with_shape([16, 32, 64], 0.3)
    test_multiple_dropout([1024], 0.2)
    test_dropout_seed([16, 32], 0.2)

def test_matinv():
    shape1 = (5, 5)
    shape2 = (3, 9, 9)
    data1 = np.random.random(shape1).astype("float32")
    data2 = np.random.random(shape2).astype("float32")
    # make matrix diagonally dominant for numerical stability
    data1 += (np.eye(shape1[0]) * shape1[0]).astype("float32")
    data2 += np.broadcast_to((np.eye(shape2[1]) * shape2[1]).astype("float32"), shape2)

    cases = [
        {"input": data1},
        {"input": data2},
    ]

    opr_test(
        cases,
        F.matinv,
        compare_fn=lambda x, y: np.testing.assert_allclose(x.numpy(), y, rtol=1e-4),
        ref_fn=np.linalg.inv,
    )

def test_matmul():
    shape1 = 3
    shape2 = 3
    shape3 = (3, 5)
    shape4 = (5, 6)
    data1 = np.random.random(shape1).astype("float32")
    data2 = np.random.random(shape2).astype("float32")
    data3 = np.random.random(shape3).astype("float32")
    data4 = np.random.random(shape4).astype("float32")

    cases = [
        {"input": [data1, data2]},
        {"input": [data2, data3]},
        {"input": [data3, data4]},
    ]
    opr_test(cases, F.matmul, ref_fn=np.matmul)

    batch_size = 10
    shape1 = (2,)
    shape2 = (batch_size, 2, 3)
    shape3 = (batch_size, 3, 4)
    shape4 = (batch_size, 10, 4, 2)
    shape5 = (batch_size, 10, 2, 4)
    data1 = np.random.random(shape1).astype("float32")
    data2 = np.random.random(shape2).astype("float32")
    data3 = np.random.random(shape3).astype("float32")
    data4 = np.random.random(shape4).astype("float32")
    data5 = np.random.random(shape5).astype("float32")

    cases = [
        {"input": [data1, data2]},
        {"input": [data2, data3]},
        {"input": [data3, data4]},
        {"input": [data4, data5]},
    ]
    opr_test(cases, F.matmul, ref_fn=np.matmul)

    opr_test(
        [{"input": [data1, data4]}],
        F.matmul,
        ref_fn=lambda x, y: np.matmul(x, y.transpose(0, 1, 3, 2)),
        transpose_b=True,
    )

    opr_test(
        [{"input": [data3, data2]}],
        F.matmul,
        ref_fn=lambda x, y: np.matmul(x.transpose(0, 2, 1), y.transpose(0, 2, 1)),
        transpose_a=True,
        transpose_b=True,
    )

@pytest.mark.parametrize(
    "shape_a, shape_b", [((0,), (0,)), ((10, 0), (0, 10)), ((3, 10, 0), (3, 0, 10)),],
)
@pytest.mark.parametrize("is_symbolic", [None, True, False])
def test_matmul_empty_tensor(shape_a, shape_b, is_symbolic):
    def func(a, b):
        return F.matmul(a, b)

    if is_symbolic is not None:
        func = jit.trace(symbolic=is_symbolic)(func)

    a = tensor(np.random.randn(*shape_a))
    b = tensor(np.random.randn(*shape_b))
    for _ in range(3):
        out = func(a, b)
        assert np.all(out.numpy() == 0)
        if is_symbolic is None:
            break

def test_interpolate():
    def linear_interpolate():
        inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))

        test_func = lambda inp: F.vision.interpolate(
            inp, scale_factor=2.0, mode="linear"
        )
        ref_func = lambda inp: F.vision.interpolate(inp, 4, mode="linear").numpy()

        cases = [{"input": inp}]
        opr_test(cases, test_func, ref_fn=ref_func, test_trace=True)

    def many_batch_interpolate():
        inp = tensor(np.arange(1, 9, dtype=np.float32).reshape(2, 1, 2, 2))

        test_func = lambda inp: F.vision.interpolate(inp, scale_factor=2.0)
        ref_func = lambda inp: F.vision.interpolate(inp, [4, 4]).numpy()

        cases = [{"input": inp}]
        opr_test(cases, test_func, ref_fn=ref_func, test_trace=True)

    def assign_corner_interpolate():
        inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))

        test_func = lambda inp: F.vision.interpolate(inp, [4, 4])
        ref_func = lambda inp: F.vision.interpolate(inp, scale_factor=2.0).numpy()

        cases = [{"input": inp}]
        opr_test(cases, test_func, ref_fn=ref_func, test_trace=True)

    def error_shape_linear_interpolate():
        inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))

        with pytest.raises(ValueError):
            F.vision.interpolate(inp, scale_factor=2.0, mode="linear")

    def inappropriate_scale_linear_interpolate():
        inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))

        with pytest.raises(ValueError):
            F.vision.interpolate(inp, scale_factor=[2.0, 3.0], mode="linear")

    linear_interpolate()
    many_batch_interpolate()
    assign_corner_interpolate()
    error_shape_linear_interpolate()
    # inappropriate_scale_linear_interpolate()
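
# `_save_to` returns a callback for `Grad.wrt`: when the gradient of the watched
# tensor is computed, it is stored on the object passed as `self` under the
# attribute `name` (default "grad"), so tests can later read e.g. `inp_feat.grad`.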
def _save_to(self, name="grad"):
    def callback(grad):
        setattr(self, name, grad)

    return callback
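
# `_gen_roi_inp` builds a random (2, 32, 256, 256) feature map plus 4 ROIs of the
# form (batch_index, box coordinates): column 0 holds the batch index, columns 1:3
# a random corner in [0, 100) and columns 3:5 a corner in [150, 250) (presumably
# (x0, y0, x1, y1) order).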
def _gen_roi_inp():
    inp_feat = np.random.randn(2, 32, 256, 256)
    rois = np.zeros((4, 5))
    rois[:, 0] = [0, 0, 1, 1]
    rois[:, 1:3] = np.random.rand(4, 2) * 100
    rois[:, 3:] = np.random.rand(4, 2) * 100 + 150

    inp_feat = tensor(inp_feat)
    rois = tensor(rois)
    return inp_feat, rois

def test_roi_align():
    inp_feat, rois = _gen_roi_inp()
    with Grad() as grad:
        grad.wrt(inp_feat, callback=_save_to(inp_feat))

        output_shape = (7, 7)
        out_feat = F.vision.roi_align(
            inp_feat,
            rois,
            output_shape=output_shape,
            mode="average",
            spatial_scale=1.0 / 4,
            sample_points=2,
            aligned=True,
        )
        assert make_shape_tuple(out_feat.shape) == (
            rois.shape[0],
            inp_feat.shape[1],
            *output_shape,
        )

        grad(out_feat, tensor(F.ones_like(out_feat)))
    assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape)

@pytest.mark.parametrize("shapes", [((2, 0, 26, 26), (4, 5)), ((2, 3, 26, 26), (0, 5))])
@pytest.mark.parametrize("is_tracing", [False, True])
def test_roi_align_empty(shapes, is_tracing):
    inp_feat = tensor(np.random.randn(*(shapes[0])))
    rois = tensor(np.random.random(shapes[1]))
    output_shape = (7, 7)

    def func(inp, rois):
        out_feat = F.vision.roi_align(
            inp_feat,
            rois,
            output_shape=output_shape,
            mode="average",
            spatial_scale=1.0 / 4,
            sample_points=2,
            aligned=True,
        )
        return out_feat

    if is_tracing:
        func = jit.trace(func)

    for _ in range(3):
        out_feat = func(inp_feat, rois)
        assert make_shape_tuple(out_feat.shape) == (
            rois.shape[0],
            inp_feat.shape[1],
            *output_shape,
        )
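
# `_gen_correlation` returns a pair of feature-map tensors of shape `image_shape`,
# either standard-normal random (random=True) or filled with `constant`.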
def _gen_correlation(random=True, constant=1, image_shape=(2, 1, 160, 160)):
    if random:
        inp_feat1 = np.random.randn(
            image_shape[0], image_shape[1], image_shape[2], image_shape[3]
        )
        inp_feat2 = np.random.randn(
            image_shape[0], image_shape[1], image_shape[2], image_shape[3]
        )
    else:
        inp_feat1 = np.ones(image_shape) * constant
        inp_feat2 = np.ones(image_shape) * constant

    return tensor(inp_feat1), tensor(inp_feat2)

def test_correlation():
    ## test case 0: check the grad shape
    data1, data2 = _gen_correlation()

    with Grad() as grad:
        grad.wrt(data1, callback=_save_to(data1))

        out_feat = F.vision.correlation(
            data1,
            data2,
            kernel_size=5,
            max_displacement=4,
            stride1=2,
            stride2=2,
            pad_size=2,
            is_multiply=True,
        )

        grad(out_feat, tensor(F.ones_like(out_feat)))
    assert make_shape_tuple(data1.grad.shape) == make_shape_tuple(data1.shape)

    ## test case 1: from https://github.com/NVIDIA/flownet2-pytorch/issues/194
    data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3))

    out_feat = F.vision.correlation(
        data1,
        data2,
        kernel_size=3,
        max_displacement=0,
        stride1=1,
        stride2=1,
        pad_size=0,
        is_multiply=True,
    )
    assert abs(out_feat.sum() - 1) < 1e-9

    ## test case 2: check same-image subtraction
    data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3))

    out_feat = F.vision.correlation(
        data1,
        data2,
        kernel_size=3,
        max_displacement=0,
        stride1=1,
        stride2=1,
        pad_size=0,
        is_multiply=False,
    )
    assert out_feat.sum() < 1e-9

    ## test case 3: check same-image subtraction
    data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3))

    out_feat = F.vision.correlation(
        data1,
        data2,
        kernel_size=3,
        max_displacement=0,
        stride1=1,
        stride2=1,
        pad_size=0,
        is_multiply=False,
    )
    assert out_feat.sum() < 1e-9

    ## test case 4: check correlation
    data1, _ = _gen_correlation(
        random=False, image_shape=(1, 1, 220, 220), constant=2.0
    )
    _, data2 = _gen_correlation(
        random=False, image_shape=(1, 1, 220, 220), constant=1.0
    )

    out_feat = F.vision.correlation(
        data1,
        data2,
        kernel_size=3,
        max_displacement=2,
        stride1=1,
        stride2=2,
        pad_size=0,
        is_multiply=False,
    )
    assert abs(out_feat.mean() - 1) < 1e-9

def test_roi_pooling():
    inp_feat, rois = _gen_roi_inp()
    with Grad() as grad:
        grad.wrt(inp_feat, callback=_save_to(inp_feat))

        output_shape = (7, 7)
        out_feat = F.vision.roi_pooling(
            inp_feat, rois, output_shape=output_shape, mode="max", scale=1.0 / 4,
        )
        assert make_shape_tuple(out_feat.shape) == (
            rois.shape[0],
            inp_feat.shape[1],
            *output_shape,
        )

        grad(out_feat, tensor(F.ones_like(out_feat)))
    assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape)

def test_adaptive_avg_pool2d():
    inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4))
    oshp = (2, 2)
    with Grad() as grad:
        grad.wrt(inp, callback=_save_to(inp))

        outp = F.adaptive_avg_pool2d(inp, oshp,)
        assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
        np.testing.assert_equal(
            outp.numpy(), np.array([[[[2.5, 4.5], [10.5, 12.5]]]], dtype=np.float32)
        )

        grad(outp, tensor(F.ones_like(outp)))
    assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape)
    np.testing.assert_equal(
        inp.grad.numpy(),
        np.array(
            [
                [
                    [
                        [0.25, 0.25, 0.25, 0.25],
                        [0.25, 0.25, 0.25, 0.25],
                        [0.25, 0.25, 0.25, 0.25],
                        [0.25, 0.25, 0.25, 0.25],
                    ]
                ]
            ],
            dtype=np.float32,
        ),
    )

def test_adaptive_max_pool2d():
    inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4))
    oshp = (2, 2)
    with Grad() as grad:
        grad.wrt(inp, callback=_save_to(inp))

        outp = F.adaptive_max_pool2d(inp, oshp,)
        assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
        np.testing.assert_equal(
            outp.numpy(), np.array([[[[5, 7], [13, 15]]]], dtype=np.float32)
        )

        grad(outp, tensor(F.ones_like(outp)))
    assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape)
    np.testing.assert_equal(
        inp.grad.numpy(),
        np.array(
            [
                [
                    [
                        [0.0, 0.0, 0.0, 0.0],
                        [0.0, 1.0, 0.0, 1.0],
                        [0.0, 0.0, 0.0, 0.0],
                        [0.0, 1.0, 0.0, 1.0],
                    ]
                ]
            ],
            dtype=np.float32,
        ),
    )

def test_one_hot():
    def onehot_low_dimension():
        inp = tensor(np.arange(1, 4, dtype=np.int32))
        out = F.one_hot(inp, num_classes=4)

        np.testing.assert_allclose(
            out.numpy(), np.eye(4, dtype=np.int32)[np.arange(1, 4, dtype=np.int32)]
        )

    def onehot_high_dimension():
        arr = np.array(
            [[3, 2, 4, 4, 2, 4, 0, 4, 4, 1], [4, 1, 1, 3, 2, 2, 4, 2, 4, 3]],
            dtype=np.int32,
        )
        inp = tensor(arr)
        out = F.one_hot(inp, 10)

        np.testing.assert_allclose(out.numpy(), np.eye(10, dtype=np.int32)[arr])

    onehot_low_dimension()
    onehot_high_dimension()

def test_interpolate_fastpath():
    # check shape
    test_cases = [
        [(1, 1, 10, 10), (5, 5)],
        [(1, 3, 10, 10), (20, 20)],
        [(10, 1, 10, 10), (1, 1)],
        [(10, 10, 1, 1), (10, 10)],
    ]
    for inp_shape, target_shape in test_cases:
        x = tensor(np.random.randn(*inp_shape), dtype=np.float32)
        out = F.vision.interpolate(x, target_shape, mode="bilinear")
        assert out.shape[0] == x.shape[0] and out.shape[1] == x.shape[1]
        assert out.shape[2] == target_shape[0] and out.shape[3] == target_shape[1]

    # check value
    x = tensor(np.ones((3, 3, 10, 10)), dtype=np.float32)
    out = F.vision.interpolate(x, (15, 5), mode="bilinear")
    np.testing.assert_equal(out.numpy(), np.ones((3, 3, 15, 5)).astype(np.float32))

    np_x = np.arange(32)
    x = tensor(np_x).astype(np.float32).reshape(1, 1, 32, 1)
    out = F.vision.interpolate(x, (1, 1), mode="bilinear")
    np.testing.assert_equal(out.item(), np_x.mean())

@pytest.mark.parametrize("dt", [np.float32, np.int8, np.uint8, np.float16])
def test_warp_perspective(dt):
    inp_shape = (1, 1, 4, 4)
    x = tensor(np.arange(16, dtype=dt).reshape(inp_shape))
    M_shape = (1, 3, 3)
    # M defines a translation: dst(1, 1, h, w) = src(1, 1, h + 1, w + 1)
    M = tensor(
        np.array(
            [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32
        ).reshape(M_shape)
    )
    outp = F.vision.warp_perspective(x, M, (2, 2))
    np.testing.assert_equal(outp.numpy(), np.array([[[[5, 6], [9, 10]]]], dtype=dt))

@pytest.mark.parametrize("dt", [np.float32, np.int8, np.uint8, np.float16])
def test_warp_perspective_mat_idx(dt):
    inp_shape = (2, 1, 4, 4)
    x = tensor(np.arange(32, dtype=dt).reshape(inp_shape))
    M_shape = (1, 3, 3)
    # M defines a translation: dst(1, 1, h, w) = src(1, 1, h + 1, w + 1)
    M = tensor(
        np.array(
            [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32
        ).reshape(M_shape)
    )
    M = F.concat([M,] * 4, 0)
    outp = F.vision.warp_perspective(x, M, (2, 2), mat_idx=[0, 1, 1, 0])
    np.testing.assert_equal(
        outp.numpy(),
        np.array(
            [
                [[[5, 6], [9, 10]]],
                [[[21, 22], [25, 26]]],
                [[[21, 22], [25, 26]]],
                [[[5, 6], [9, 10]]],
            ],
            dtype=dt,
        ),
    )

def test_warp_affine():
    inp_shape = (1, 3, 3, 3)
    x = tensor(np.arange(27, dtype=np.float32).reshape(inp_shape))
    weightv = [[[1.26666667, 0.6, -83.33333333], [-0.33333333, 1, 66.66666667]]]
    outp = F.vision.warp_affine(x, tensor(weightv), (2, 2), border_mode="wrap")
    res = np.array(
        [
            [
                [[7.875, 8.875, 9.875], [8.90625, 9.90625, 10.90625]],
                [[18.75, 19.75, 20.75], [14.90625, 15.90625, 16.90625]],
            ]
        ],
        dtype=np.float32,
    )
    if not is_cuda_available():
        np.testing.assert_almost_equal(outp.numpy(), res, 5)

def test_remap():
    inp_shape = (1, 1, 4, 4)
    inp = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
    map_xy_shape = (1, 2, 2, 2)
    map_xy = tensor(
        np.array(
            [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [0.0, 1.0]]], dtype=np.float32
        ).reshape(map_xy_shape)
    )
    outp = F.vision.remap(inp, map_xy)
    np.testing.assert_equal(
        outp.numpy(), np.array([[[[1.0, 4.0], [4.0, 4.0]]]], dtype=np.float32)
    )

def test_binary_cross_entropy():
    data1_shape = (2, 2)
    label1_shape = (2, 2)
    data2_shape = (2, 3)
    label2_shape = (2, 3)

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    def compare_fn(x, y):
        np.testing.assert_allclose(x.numpy(), y, atol=5e-4)

    np.random.seed(123)
    data1 = np.random.uniform(size=data1_shape).astype(np.float32)
    label1 = np.random.uniform(size=label1_shape).astype(np.float32)
    expect1 = np.array(0.6361, dtype=np.float32)

    np.random.seed(123)
    data2 = np.random.uniform(size=data2_shape).astype(np.float32)
    label2 = np.random.uniform(size=label2_shape).astype(np.float32)
    expect2 = np.array(0.6750, dtype=np.float32)

    cases = [
        {"input": [data1, label1], "output": expect1,},
        {"input": [data2, label2], "output": expect2,},
    ]
    opr_test(cases, F.nn.binary_cross_entropy, compare_fn=compare_fn)

    cases = [
        {"input": [sigmoid(data1), label1], "output": expect1,},
        {"input": [sigmoid(data2), label2], "output": expect2,},
    ]
    opr_test(
        cases,
        partial(F.nn.binary_cross_entropy, with_logits=False),
        compare_fn=compare_fn,
    )

def test_hinge_loss():
    np.random.seed(123)
    # case with L1 norm
    cases = []
    for shape in [(2, 2), (2, 3)]:
        data = np.random.uniform(size=shape).astype(np.float32)
        label = 2 * np.random.randint(0, 1, size=shape).astype(np.float32) - 1
        expect = np.clip(0, np.inf, 1 - data * label).sum(axis=1).mean()
        cases.append({"input": [data, label], "output": expect})
    opr_test(cases, F.nn.hinge_loss)

    # cases with L2 norm
    cases = []
    for shape in [(2, 2), (2, 3)]:
        data = np.random.uniform(size=shape).astype(np.float32)
        label = 2 * np.random.randint(0, 1, size=shape).astype(np.float32) - 1
        expect = ((np.clip(0, np.inf, 1 - data * label) ** 2).sum(axis=1)).mean()
        cases.append({"input": [data, label], "output": expect})

    def hinge_loss_with_l2_norm(pred, label):
        return F.nn.hinge_loss(pred, label, "L2")

    opr_test(cases, hinge_loss_with_l2_norm)

@pytest.mark.parametrize("is_symbolic", [None, False, True])
def test_nms(is_symbolic):
    def fn(inp, scores):
        return F.vision.nms(
            inp,
            scores=scores,
            iou_thresh=0.5,
            max_output=None if is_symbolic is None else 4,
        )

    if is_symbolic is not None:
        fn = jit.trace(symbolic=is_symbolic)(fn)

    x = np.array(
        [
            [0, 0, 100, 100],
            [10, 10, 100, 100],
            [50, 50, 100, 100],
            [100, 100, 150, 150],
        ],
        dtype=np.float32,
    )
    inp = tensor(x)

    scores = tensor([0.5, 0.8, 0.9, 0.6], dtype=np.float32)
    for _ in range(3):
        result = fn(inp, scores=scores)
        np.testing.assert_equal(result.numpy(), np.array([2, 1, 3], dtype=np.int32))

    x = np.array([], dtype=np.float32,).reshape(0, 4)
    inp = tensor(x)
    scores = tensor([], dtype=np.float32)
    for _ in range(3):
        result = fn(inp, scores=scores)
        np.testing.assert_equal(result.numpy(), np.array([], dtype=np.int32))
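
# The quantized conv_bias test below builds qint8 inputs/weights and a qint32 bias,
# runs `F.quantized.conv_bias_activation`, and checks the result against an fp32
# `F.conv2d` reference that is round-tripped through the quantized output dtype;
# when CUDA is available the tensors are first rearranged into NCHW4 layout (the
# skip marker notes that CUDA does not support NCHW int8).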
@pytest.mark.skipif(
    get_device_count("gpu") > 0, reason="cuda does not support nchw int8"
)
def test_conv_bias():
    inp_scale = 1.5
    w_scale = 2.5
    outp_scale = 1.5
    inp_dtype = dtype.qint8(inp_scale)
    w_dtype = dtype.qint8(w_scale)
    b_dtype = dtype.qint32(inp_scale * w_scale)
    out_dtype = dtype.qint8(outp_scale)

    def run(
        N,
        IC,
        OC,
        IH,
        IW,
        KH,
        KW,
        PH,
        PW,
        SH,
        SW,
        has_bias=True,
        nonlinear_mode="identity",
    ):
        inp_v = np.random.normal(size=(N, IC, IH, IW))
        w_v = np.random.normal(size=(OC, IC, KH, KW))
        b_v = np.random.normal(size=(1, OC, 1, 1))
        inp_scale = dtype.get_scale(inp_dtype)
        w_scale = dtype.get_scale(w_dtype)
        b_scale = dtype.get_scale(b_dtype)

        inpv = dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype)
        wv = dtype.convert_to_qint8(w_v * w_scale, w_dtype)
        bv = dtype.convert_to_qint32(b_v * b_scale, b_dtype)

        inp_int8 = tensor(inpv, dtype=inp_dtype)
        w_int8 = Parameter(wv, dtype=w_dtype)
        b_int32 = Parameter(bv, dtype=b_dtype)

        inp_fp32 = inp_int8.astype("float32")
        w_fp32 = w_int8.astype("float32")
        b_fp32 = b_int32.astype("float32")

        def convert_to_nchw4(var):
            var = F.reshape(
                var, (var.shape[0], var.shape[1] // 4, 4, var.shape[2], var.shape[3])
            )
            var = F.transpose(var, (0, 1, 3, 4, 2))
            return var

        def run_conv2d(inp, w, b):
            O = F.conv2d(
                inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW),
            )
            if nonlinear_mode == "relu":
                return F.relu(O)
            else:
                return O

        def run_conv_bias(inp, w, b, format="NCHW"):
            b = b if has_bias else Parameter(np.zeros_like(b.numpy()))
            if format == "NCHW4":
                inp = convert_to_nchw4(inp)
                w = convert_to_nchw4(w)
                b = convert_to_nchw4(b)
            return F.quantized.conv_bias_activation(
                inp,
                w,
                b,
                stride=(SH, SW),
                padding=(PH, PW),
                dtype=out_dtype,
                nonlinear_mode=nonlinear_mode,
            )

        format = "NCHW4" if is_cuda_available() else "NCHW"

        expected = run_conv2d(inp_fp32, w_fp32, b_fp32)
        expected = expected.astype(out_dtype).astype("float32")
        result = run_conv_bias(inp_int8, w_int8, b_int32, format=format).astype(
            "float32"
        )
        if format == "NCHW4":
            result = F.transpose(result, (0, 1, 4, 2, 3))
        expected = F.flatten(expected)
        result = F.flatten(result)
        np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=outp_scale)

    run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1, False)
    run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1, False)
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False)

    run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1)
    run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1)
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2)

    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "relu")
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu")

@pytest.mark.skipif(get_device_count("gpu") > 0, reason="no int8 algorithm on cuda")
def test_batch_conv_bias():
    inp_scale = 1.5
    w_scale = 2.5
    outp_scale = 1.5
    inp_dtype = dtype.qint8(inp_scale)
    w_dtype = dtype.qint8(w_scale)
    b_dtype = dtype.qint32(inp_scale * w_scale)
    out_dtype = dtype.qint8(outp_scale)

    def run(
        N, IC, OC, IH, IW, KH, KW, PH, PW, SH, SW, has_bias=True,
    ):
        inp_v = np.random.normal(size=(N, IC, IH, IW))
        w_v = np.random.normal(size=(N, OC, IC, KH, KW))
        b_v = np.random.normal(size=(1, OC, 1, 1))
        inp_scale = dtype.get_scale(inp_dtype)
        w_scale = dtype.get_scale(w_dtype)
        b_scale = dtype.get_scale(b_dtype)

        inpv = dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype)
        wv = dtype.convert_to_qint8(w_v * w_scale, w_dtype)
        bv = dtype.convert_to_qint32(b_v * b_scale, b_dtype)

        inp_int8 = tensor(inpv, dtype=inp_dtype)
        w_int8 = Parameter(wv, dtype=w_dtype)
        b_int32 = Parameter(bv, dtype=b_dtype)

        inp_fp32 = inp_int8.astype("float32")
        w_fp32 = w_int8.astype("float32")
        b_fp32 = b_int32.astype("float32")

        def run_batch_conv_bias(inp, w, b):
            b = b if has_bias else Parameter(np.zeros_like(b.numpy()))
            result = F.quantized.batch_conv_bias_activation(
                inp, w, b, stride=(SH, SW), padding=(PH, PW), dtype=out_dtype,
            )
            return result.astype("float32")

        expected = F.conv2d(inp_fp32, w_fp32[0], b_fp32 if has_bias else None)[0]
        expected = expected.astype(out_dtype).astype("float32")
        expected = F.flatten(expected)

        result = run_batch_conv_bias(inp_int8, w_int8, b_int32)
        result = F.flatten(result)

        np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=outp_scale)

    run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)

def test_conv2d_autocast():
    """check that amp's result is equal to the manually converted result"""
    amp.enabled = True
    inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)
    weight = tensor(np.random.randn(64, 3, 7, 7), dtype=np.float32)
    out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)
    amp.enabled = False
    expected = F.conv2d(
        inp.astype("float16"),
        weight.astype("float16"),
        None,
        (2, 2),
        (3, 3),
        (1, 1),
        1,
        compute_mode="float32",
    )
    assert out.dtype == np.float16
    assert expected.dtype == np.float16
    np.testing.assert_allclose(out.numpy(), expected.numpy())

def test_conv2d_zero_stride_numpy_array():
    inp = np.random.randn(3, 224, 224).astype(np.float32)
    inp = inp[np.newaxis, :]

    inp = tensor(inp, dtype=np.float32)
    weight = tensor(np.random.randn(16, 3, 3, 3), dtype=np.float32)
    out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)

def test_conv3d_zero_stride_numpy_array():
    inp = np.random.randn(3, 224, 224, 224).astype(np.float32)
    inp = inp[np.newaxis, :]

    inp = tensor(inp, dtype=np.float32)
    weight = tensor(np.random.randn(16, 3, 3, 3, 3), dtype=np.float32)
    out = F.conv3d(inp, weight, None, (2, 2, 2), (3, 3, 3), (1, 1, 1), 1)
    out.numpy()

@pytest.mark.parametrize("bias", [True, False])
def test_conv1d(bias):
    inp = tensor(np.ones((2, 2, 4), dtype=np.float32))
    weight = tensor(np.ones((3, 2, 2), dtype=np.float32))
    bias = tensor(np.ones((1, 3, 1), dtype=np.float32)) if bias else None
    out = F.conv1d(inp, weight, bias, 2, 0, 1, 1)
    np.testing.assert_equal(
        out.numpy(),
        np.array([[[5, 5], [5, 5], [5, 5]], [[5, 5], [5, 5], [5, 5]]], dtype=np.float32)
        if bias is not None
        else np.array(
            [[[4, 4], [4, 4], [4, 4]], [[4, 4], [4, 4], [4, 4]]], dtype=np.float32
        ),
    )

def test_batchnorm2d_autocast():
    """check that amp's result is equal to the manually converted result"""
    amp.enabled = True
    tshape = (1, 3, 224, 224)
    pshape = (1, 3, 1, 1)
    inp = tensor(np.random.randn(*tshape), dtype=np.float32)
    weight = tensor(np.ones(pshape, dtype=np.float32))
    bias = tensor(np.zeros(pshape, dtype=np.float32))

    out = F.batch_norm(inp, weight=weight, bias=bias, training=True, inplace=False)

    amp.enabled = False
    expected = F.batch_norm(
        inp.astype("float16"), weight=weight, bias=bias, training=True, inplace=False,
    )
    assert out.dtype == np.float16
    assert expected.dtype == np.float16
    np.testing.assert_allclose(out.numpy(), expected.numpy())

@pytest.mark.parametrize("bias", [True, False])
def test_conv3d(bias):
    inp = tensor(np.ones((2, 2, 4, 4, 4), dtype=np.float32))
    weight = tensor(np.ones((3, 2, 2, 2, 2), dtype=np.float32))
    bias = tensor(np.ones((1, 3, 1, 1, 1), dtype=np.float32)) if bias else None
    out = F.conv3d(inp, weight, bias, 2, 0, 1, 1)
    target = np.ones((2, 3, 2, 2, 2), dtype=np.float32) * 16
    target = target + 1 if bias is not None else target
    np.testing.assert_equal(out.numpy(), target)

# renamed to avoid being shadowed by the parametrized test_condtake below
def test_condtake_basic():
    x = np.array([[1, 2, 3], [4, 5, 6]])
    y = np.array([[True, False, True], [False, True, True]])
    xx = tensor(x)
    yy = tensor(y)
    val, idx = F.cond_take(yy, xx)
    np.testing.assert_equal(val.numpy(), x[y])
    np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])

@pytest.mark.parametrize("is_symbolic", [None, False, True])
def test_condtake(is_symbolic):
    shapes = [
        (3, 3, 3),
        (0,),
        (3, 0, 3),
    ]

    def fn(mask, data):
        return F.cond_take(mask, data)

    if is_symbolic is not None:
        fn = jit.trace(symbolic=is_symbolic)(fn)

    for shp in shapes:
        x_np = np.random.randn(*shp).astype("float32")
        mask_np = x_np > 0
        x = tensor(x_np)
        mask = tensor(mask_np)
        ref_out = x_np[mask_np]
        ref_idx = mask_np.flatten().nonzero()[0]
        for i in range(3):
            out, idx = fn(mask, x)
            np.testing.assert_equal(out.numpy(), ref_out)
            np.testing.assert_equal(idx.numpy(), ref_idx)
            if is_symbolic is None:
                break

def test_condtake_is_same():
    op1 = builtin.CondTake()
    op2 = builtin.CondTake()
    assert op1 == op2

def test_nms_is_same():
    op1 = builtin.NMSKeep(0.7, 100)
    op2 = builtin.NMSKeep(0.7, 100)
    op3 = builtin.NMSKeep(0.8, 100)
    op4 = builtin.NMSKeep(0.7, 200)
    assert op1 == op2
    assert op1 != op3
    assert op1 != op4
    assert op3 != op4

def test_argmxx_on_inf():
    def run_argmax():
        x = F.zeros((100, 100))
        x[:] = -float("inf")
        idxs = F.argmax(x, axis=0)
        return idxs

    def run_argmin():
        x = F.zeros((100, 100))
        x[:] = float("inf")
        idxs = F.argmin(x, axis=0)
        return idxs

    assert all(run_argmax() >= 0)
    assert all(run_argmin() >= 0)

def test_deformable_psroi_pooling():
    inp = np.random.random((1, 256, 64, 64)).astype("float32")
    rois = np.random.random((1, 5)).astype("float32")
    trans = np.random.random((24, 2, 7, 7)).astype("float32")

    pooled_h = 7
    pooled_w = 7
    sample_per_part = 4
    no_trans = False
    part_size = 7
    spatial_scale = 1.0 / 64
    trans_std = 0.1

    y = F.deformable_psroi_pooling(
        tensor(inp),
        tensor(rois),
        tensor(trans),
        no_trans,
        part_size,
        pooled_h,
        pooled_w,
        sample_per_part,
        spatial_scale,
        trans_std,
    )

def test_cvt_color():
    def rgb2gray(rgb):
        return np.dot(rgb[..., :3], [0.299, 0.587, 0.114])

    def bgr2gray(bgr):
        return np.dot(bgr[..., :3], [0.114, 0.587, 0.299])

    inp = np.random.randn(3, 3, 3, 3).astype(np.float32)
    out = np.expand_dims(rgb2gray(inp), 3).astype(np.float32)
    x = tensor(inp)
    y = F.vision.cvt_color(x, mode="RGB2GRAY")
    np.testing.assert_allclose(y.numpy(), out, atol=1e-5)

    out1 = np.expand_dims(bgr2gray(inp), 3).astype(np.float32)
    y1 = F.vision.cvt_color(x, mode="BGR2GRAY")
    np.testing.assert_allclose(y1.numpy(), out1, atol=1e-5)

@pytest.mark.parametrize("val", [2, [2,], [2, 3]])
def test_ones(val):
    shp = tensor(val)
    np_shp = np.array(val)
    np.testing.assert_equal(F.ones(shp), np.ones(np_shp))

def test_assert_equal():
    shape = (2, 3, 4, 5)
    x = F.ones(shape, dtype=np.float32)
    y = F.zeros(shape, dtype=np.float32) + 1.00001
    z = F.utils._assert_equal(x, y)

def test_assert_not_equal():
    shape = (2, 3, 4, 5)
    x = F.ones(shape, dtype=np.float32)
    y = F.zeros(shape, dtype=np.float32) + 1.1
    with pytest.raises(RuntimeError):
        z = F.utils._assert_equal(x, y)

def test_neg_axis():
    x = tensor(np.random.normal(0, 1, (32, 5)))

    y = F.argmax(x, axis=-1)
    yy = F.argmax(x, axis=1)
    np.testing.assert_equal(y.numpy(), yy.numpy())

    y = F.argmax(x, axis=(-1, -2))
    yy = F.argmax(x, axis=(0, 1))
    np.testing.assert_equal(y.numpy(), yy.numpy())

    y = F.argmin(x, axis=(-1, -2))
    yy = F.argmin(x, axis=(0, 1))
    np.testing.assert_equal(y.numpy(), yy.numpy())
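
# In the sliding-window tests below, the reference output spatial size follows the
# usual convolution formula, computed by the local lambda `s`:
#   out = (in + 2 * pad - (window - 1) * dilation - 1) // stride + 1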
def test_sliding_window():
    N, C, H, W = 2, 3, 7, 8
    inp = np.random.normal(size=(N, C, H, W))
    ph, pw = 1, 2
    sh, sw = 2, 1
    wh, ww = 3, 2
    dh, dw = 1, 3
    s = lambda i, p, s, d, w: (i + p * 2 - (w - 1) * d - 1) // s + 1
    inp_pad = np.zeros((N, C, H + ph * 2, W + pw * 2))
    inp_pad[:, :, ph : H + ph, pw : W + pw] = inp
    gt_out = np.empty(
        (N, C, s(H, ph, sh, dh, wh), s(W, pw, sw, dw, ww), wh, ww), dtype=np.float32
    )
    for n, c, oh, ow in itertools.product(*map(range, gt_out.shape[:4])):
        ih, iw = oh * sh, ow * sw
        gt_out[n, c, oh, ow, :] = inp_pad[
            n, c, ih : ih + (wh - 1) * dh + 1 : dh, iw : iw + (ww - 1) * dw + 1 : dw
        ]

    out = F.sliding_window(
        tensor(inp), (wh, ww), padding=(ph, pw), stride=(sh, sw), dilation=(dh, dw)
    )
    np.testing.assert_equal(gt_out, out.numpy())

def test_sliding_window_transpose():
    N, C, H, W = 2, 3, 7, 8
    ph, pw = 1, 2
    sh, sw = 2, 1
    wh, ww = 3, 2
    dh, dw = 1, 3
    s = lambda i, p, s, d, w: (i + p * 2 - (w - 1) * d - 1) // s + 1
    inp = np.random.normal(
        size=(N, C, s(H, ph, sh, dh, wh), s(W, pw, sw, dw, ww), wh, ww)
    ).astype(np.float32)
    gt_out = np.zeros((N, C, H, W), dtype=np.float32)

    for n, c in itertools.product(*map(range, inp.shape[:2])):
        oh = 0
        for ih in range(-ph, H + ph - dh * (wh - 1), sh):
            ow = 0
            for iw in range(-pw, W + pw - dw * (ww - 1), sw):
                for kh, kw in itertools.product(*map(range, inp.shape[-2:])):
                    ih2 = ih + dh * kh
                    iw2 = iw + dw * kw
                    if ih2 >= 0 and ih2 < H and iw2 >= 0 and iw2 < W:
                        gt_out[n, c, ih2, iw2] += inp[n, c, oh, ow, kh, kw]
                ow += 1
            oh += 1

    out = F.sliding_window_transpose(
        tensor(inp),
        (H, W),
        (wh, ww),
        padding=(ph, pw),
        stride=(sh, sw),
        dilation=(dh, dw),
    )
    np.testing.assert_equal(gt_out, out.numpy())

def test_pad():
    src = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
    dst = np.pad(src, ((2, 2), (2, 2)), "constant")
    res = F.nn.pad(tensor(src), ((2, 2), (2, 2)), "CONSTANT")
    np.testing.assert_allclose(res, dst, atol=1e-5)

    dst = np.pad(src, ((2, 2), (2, 2)), "constant", constant_values=3)
    res = F.nn.pad(tensor(src), ((2, 2), (2, 2)), "CONSTANT", constant_value=3)
    np.testing.assert_allclose(res, dst, atol=1e-5)

    dst = np.pad(src, ((2, 2), (2, 2)), "edge")
    res = F.nn.pad(tensor(src), ((2, 2), (2, 2)), "EDGE")
    np.testing.assert_allclose(res, dst, atol=1e-5)

    dst = np.pad(src, ((2, 2), (2, 2)), "reflect")
    res = F.nn.pad(tensor(src), ((2, 2), (2, 2)), "REFLECT")
    np.testing.assert_allclose(res, dst, atol=1e-5)
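
# NumPy reference for pixel shuffle with upscale factor r: input channel c at
# position (h, w) maps to output position
#   (c // r**2, h * r + (c % r**2) // r, w * r + c % r),
# i.e. (..., C * r * r, H, W) -> (..., C, H * r, W * r) on the trailing 3 dims.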
def pixel_shuffle(data, r):
    high_dim = data.shape[:-3]
    data = data.reshape(-1, data.shape[-3], data.shape[-2], data.shape[-1])
    inn, ic, ih, iw = data.shape
    res = np.zeros((inn, int(ic / (r * r)), ih * r, iw * r))
    for n in range(inn):
        for c in range(ic):
            for h in range(ih):
                for w in range(iw):
                    res[
                        n,
                        int(c / r / r),
                        h * r + int((c % (r * r)) / r),
                        w * r + c % r,
                    ] = data[n, c, h, w]
    if len(high_dim) > 0:
        res = res.reshape((*high_dim, int(ic / r / r), ih * r, iw * r))
    else:
        res = res[0]
    return res

def test_pixel_shuffle():
    # ndim = 3
    inp = np.arange(16 * 3 * 3).reshape(16, 3, 3)
    out = F.pixel_shuffle(tensor(inp), upscale_factor=4)
    golden = pixel_shuffle(inp, 4)
    np.testing.assert_equal(out.numpy(), golden)
    inp_float = np.float32(inp)
    out = F.pixel_shuffle(tensor(inp_float), upscale_factor=2)
    golden = pixel_shuffle(inp_float, 2)
    np.testing.assert_equal(out.numpy(), golden)

    # ndim = 4
    inp = np.arange(3 * 18 * 3 * 3).reshape(3, 18, 3, 3)
    out = F.pixel_shuffle(tensor(inp), upscale_factor=3)
    golden = pixel_shuffle(inp, 3)
    np.testing.assert_equal(out.numpy(), golden)
    inp_float = np.float32(inp)
    out = F.pixel_shuffle(tensor(inp_float), upscale_factor=3)
    golden = pixel_shuffle(inp_float, 3)
    np.testing.assert_equal(out.numpy(), golden)

    # ndim = 5
    inp = np.arange(5 * 3 * 20 * 3 * 4).reshape(5, 3, 20, 3, 4)
    out = F.pixel_shuffle(tensor(inp), upscale_factor=2)
    golden = pixel_shuffle(inp, 2)
    np.testing.assert_equal(out.numpy(), golden)
    inp_float = np.float32(inp)
    out = F.pixel_shuffle(tensor(inp_float), upscale_factor=2)
    golden = pixel_shuffle(inp_float, 2)
    np.testing.assert_equal(out.numpy(), golden)

    # ndim = 6
    inp = np.arange(6 * 5 * 3 * 25 * 3 * 4).reshape(6, 5, 3, 25, 3, 4)
    out = F.pixel_shuffle(tensor(inp), upscale_factor=5)
    golden = pixel_shuffle(inp, 5)
    np.testing.assert_equal(out.numpy(), golden)
    inp_float = np.float32(inp)
    out = F.pixel_shuffle(tensor(inp_float), upscale_factor=5)
    golden = pixel_shuffle(inp_float, 5)
    np.testing.assert_equal(out.numpy(), golden)

    # ndim = 7
    inp = np.arange(2 * 3 * 5 * 3 * 20 * 3 * 4).reshape(2, 3, 5, 3, 20, 3, 4)
    out = F.pixel_shuffle(tensor(inp), upscale_factor=2)
    golden = pixel_shuffle(inp, 2)
    np.testing.assert_equal(out.numpy(), golden)
    inp_float = np.float32(inp)
    out = F.pixel_shuffle(tensor(inp_float), upscale_factor=2)
    golden = pixel_shuffle(inp_float, 2)
    np.testing.assert_equal(out.numpy(), golden)

@pytest.mark.parametrize("type", ["int32", "float32"])
@pytest.mark.parametrize("is_symbolic", [False, True])
def test_pixel_shuffle_symbolic(is_symbolic, type):
    def fn(inp, upscale_factor):
        return F.pixel_shuffle(inp, upscale_factor=upscale_factor)

    if is_symbolic is not None:
        fn = jit.trace(symbolic=is_symbolic)(fn)

    inp = tensor(np.arange(3 * 4 * 5 * 5).reshape(3, 4, 5, 5).astype(type))
    golden = pixel_shuffle(inp, 2)
    for _ in range(3):
        out = fn(inp, 2)
        np.testing.assert_equal(out.numpy(), golden)
        if is_symbolic is None:
            break

def test_set_conv2d_config():
    """check that setting config via the contextmanager matches the manually converted result"""
    config._compute_mode = "float32"
    inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float16)
    weight = tensor(np.random.randn(64, 3, 7, 7), dtype=np.float16)
    config_out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)
    config._compute_mode = "default"

    with config._override(compute_mode="float32"):
        context_out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)

    expected = F.conv2d(
        inp, weight, None, (2, 2), (3, 3), (1, 1), 1, compute_mode="float32",
    )
    np.testing.assert_allclose(config_out.numpy(), expected.numpy())
    np.testing.assert_allclose(context_out.numpy(), expected.numpy())

@pytest.mark.parametrize("stride", [(1, 1)])
@pytest.mark.parametrize("padding", [(1, 1)])
@pytest.mark.parametrize("dilation", [(1, 1)])
@pytest.mark.parametrize("ksize", [(3, 3)])
@pytest.mark.parametrize("groups", [1, 2])
def test_local_conv2d(stride, padding, dilation, ksize, groups):
    batch_size, in_channels, out_channels = 2, 4, 8
    input_height, input_width = 10, 10
    output_height = (input_height + padding[0] * 2 - ksize[0]) // stride[0] + 1
    output_width = (input_width + padding[1] * 2 - ksize[1]) // stride[1] + 1

    def local_conv2d_np(data, weight, stride, padding, dilation):
        # naive reference calculation using numpy
        # only tests output_height == input_height, output_width == input_width
        data = np.pad(data, ((0, 0), (0, 0), (1, 1), (1, 1)))
        expected = np.zeros(
            (batch_size, out_channels, output_height, output_width), dtype=np.float32,
        )
        ic_group_size = in_channels // groups
        oc_group_size = out_channels // groups
        for n, oc, oh, ow in itertools.product(
            *map(range, [batch_size, out_channels, output_height, output_width])
        ):
            ih, iw = oh * stride[0], ow * stride[1]
            g_id = oc // oc_group_size
            expected[n, oc, ih, iw] = np.sum(
                data[
                    n,
                    g_id * ic_group_size : (g_id + 1) * ic_group_size,
                    ih : ih + ksize[0],
                    iw : iw + ksize[1],
                ]
                * weight[g_id, oh, ow, :, :, :, oc % oc_group_size]
            )
        return expected

    data = np.random.rand(batch_size, in_channels, input_height, input_width).astype(
        "float32"
    )
    weight = np.random.rand(
        groups,
        output_height,
        output_width,
        in_channels // groups,
        *ksize,
        out_channels // groups,
    ).astype("float32")
    output = F.local_conv2d(
        tensor(data),
        tensor(weight),
        None,
        stride=stride,
        padding=padding,
        dilation=dilation,
    )
    ref = local_conv2d_np(data, weight, stride, padding, dilation)
    np.testing.assert_almost_equal(output.numpy(), ref, 5)
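
# The expected shapes below follow the transposed-convolution size formula
#   out = (in - 1) * stride - 2 * padding + kernel_size + output_padding,
# e.g. for test_conv_transpose2d: (50 - 1) * 2 - 2 * 4 + 3 + 1 = 94 and
# (100 - 1) * 3 - 2 * 2 + 5 + 2 = 300.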
def test_conv_transpose2d():
    m = ConvTranspose2d(
        16, 33, (3, 5), output_padding=(1, 2), stride=(2, 3), padding=(4, 2)
    )

    @trace(symbolic=True)
    def fwd(inp: Tensor):
        return m(inp)

    input = Tensor(np.random.rand(20, 16, 50, 100))
    output = fwd(input)
    output_shape = Tensor(output.shape)
    np.testing.assert_equal(
        output_shape.numpy(), np.array([20, 33, 94, 300], dtype=np.int32)
    )

def test_conv_transpose3d():
    m = ConvTranspose3d(
        16, 33, (3, 5, 2), output_padding=(2, 1, 1), stride=(3, 2, 2), padding=(0, 4, 2)
    )

    @trace(symbolic=True)
    def fwd(inp: Tensor):
        return m(inp)

    input = Tensor(np.random.rand(20, 16, 10, 50, 100))
    output = fwd(input)
    output_shape = Tensor(output.shape)
    np.testing.assert_equal(
        output_shape.numpy(), np.array([20, 33, 32, 96, 197], dtype=np.int32)
    )