You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_functional.py 55 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670
  1. # -*- coding: utf-8 -*-
  2. import itertools
  3. import platform
  4. from functools import partial
  5. import numpy as np
  6. import pytest
  7. from utils import opr_test
  8. import megengine.amp as amp
  9. import megengine.config as config
  10. import megengine.core.ops.builtin as builtin
  11. import megengine.core.tensor.dtype as dtype
  12. import megengine.functional as F
  13. import megengine.jit as jit
  14. from megengine import Parameter, Tensor, is_cuda_available, tensor
  15. from megengine.autodiff import GradManager
  16. from megengine.core._trace_option import use_symbolic_shape
  17. from megengine.core.autodiff.grad import Grad
  18. from megengine.core.tensor.utils import make_shape_tuple
  19. from megengine.device import get_device_count
  20. from megengine.jit.tracing import trace
  21. from megengine.module import ConvTranspose2d, ConvTranspose3d, LayerNorm
  22. _assert_allclose = partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6)
  23. def test_where():
  24. maskv0 = np.array([[1, 0], [0, 1]], dtype=np.bool_)
  25. xv0 = np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32)
  26. yv0 = np.array([[5, 6], [7, 8]], dtype=np.float32)
  27. maskv1 = np.array([[1, 0, 1], [1, 0, 0], [1, 1, 0]], dtype=np.bool_)
  28. xv1 = np.array([[1, np.inf, 2], [0, np.nan, 4], [1, 5, 7]], dtype=np.float32)
  29. yv1 = np.array([[5, 6, 9], [2, 7, 8], [2, 1, 9]], dtype=np.float32)
  30. maskv2 = np.array([1, 1, 1], dtype=np.bool_)
  31. xv2 = np.array([1, 3, 2], dtype=np.float32)
  32. yv2 = np.array([5, 6, 9], dtype=np.float32)
  33. maskv3 = np.array([0, 0, 0], dtype=np.bool_)
  34. xv3 = np.array([1, 3, 2], dtype=np.float32)
  35. yv3 = np.array([5, 6, 9], dtype=np.float32)
  36. maskv4 = np.array(1, dtype=np.bool_)
  37. xv4 = np.array(1, dtype=np.float32)
  38. yv4 = np.array(0, dtype=np.float32)
  39. cases = [
  40. {"input": [maskv0, xv0, yv0]},
  41. {"input": [maskv1, xv1, yv1]},
  42. {"input": [maskv2, xv2, yv2]},
  43. {"input": [maskv3, xv3, yv3]},
  44. {"input": [maskv4, xv4, yv4]},
  45. ]
  46. opr_test(cases, F.where, ref_fn=np.where, test_trace=True)
  47. def test_dropout():
  48. from megengine.autodiff import GradManager
  49. from megengine.core._imperative_rt.ops import set_global_rng_seed
  50. def test_dropout_with_shape(shape, rate):
  51. data = tensor(np.ones(shape, dtype=np.float32))
  52. gm = GradManager().attach([data])
  53. with gm:
  54. out = F.nn.dropout(data, rate, training=True)
  55. gm.backward(out, tensor(np.ones(shape, dtype=np.float32)))
  56. if len(shape) != 0:
  57. assert not out.numpy().all()
  58. np.testing.assert_allclose(out.numpy(), data.grad.numpy(), 1e-7, 1e-7)
  59. def test_multiple_dropout(shape, rate):
  60. data = tensor(np.ones(shape, dtype=np.float32))
  61. gm = GradManager().attach([data])
  62. with gm:
  63. out1 = F.nn.dropout(data, rate, training=True)
  64. out2 = F.nn.dropout(out1, rate, training=True)
  65. out3 = F.nn.dropout(out2, rate, training=True)
  66. gm.backward(out3, tensor(np.ones(shape, dtype=np.float32)))
  67. np.testing.assert_allclose(out3.numpy(), data.grad.numpy(), 1e-7, 1e-7)
  68. def test_dropout_seed(shape, rate):
  69. data = tensor(np.random.randn(*shape), dtype="float32")
  70. set_global_rng_seed(111)
  71. out1 = F.nn.dropout(data, rate, training=True)
  72. out2 = F.nn.dropout(data, rate, training=True)
  73. assert not (out1.numpy() == out2.numpy()).all()
  74. set_global_rng_seed(111)
  75. out3 = F.nn.dropout(data, rate, training=True)
  76. assert (out1.numpy() == out3.numpy()).all()
  77. set_global_rng_seed(222)
  78. out4 = F.nn.dropout(data, rate, training=True)
  79. assert not (out1.numpy() == out4.numpy()).all()
  80. test_dropout_with_shape([], 0.4)
  81. test_dropout_with_shape([13, 17, 63, 21], 0.4)
  82. test_dropout_with_shape([16, 32, 64], 0.3)
  83. test_multiple_dropout([1024], 0.2)
  84. test_dropout_seed([16, 32], 0.2)
  85. def test_matinv():
  86. shape1 = (5, 5)
  87. shape2 = (3, 9, 9)
  88. data1 = np.random.random(shape1).astype("float32")
  89. data2 = np.random.random(shape2).astype("float32")
  90. # make matrix diagonally dominant for numerical stability
  91. data1 += (np.eye(shape1[0]) * shape1[0]).astype("float32")
  92. data2 += np.broadcast_to((np.eye(shape2[1]) * shape2[1]).astype("float32"), shape2)
  93. cases = [
  94. {"input": data1},
  95. {"input": data2},
  96. ]
  97. opr_test(
  98. cases,
  99. F.matinv,
  100. compare_fn=lambda x, y: np.testing.assert_allclose(x.numpy(), y, rtol=1e-4),
  101. ref_fn=np.linalg.inv,
  102. )
  103. def test_matmul():
  104. shape1 = 3
  105. shape2 = 3
  106. shape3 = (3, 5)
  107. shape4 = (5, 6)
  108. data1 = np.random.random(shape1).astype("float32")
  109. data2 = np.random.random(shape2).astype("float32")
  110. data3 = np.random.random(shape3).astype("float32")
  111. data4 = np.random.random(shape4).astype("float32")
  112. cases = [
  113. {"input": [data1, data2]},
  114. {"input": [data2, data3]},
  115. {"input": [data3, data4]},
  116. ]
  117. opr_test(cases, F.matmul, ref_fn=np.matmul)
  118. batch_size = 10
  119. shape1 = (2,)
  120. shape2 = (batch_size, 2, 3)
  121. shape3 = (batch_size, 3, 4)
  122. shape4 = (batch_size, 10, 4, 2)
  123. shape5 = (batch_size, 10, 2, 4)
  124. data1 = np.random.random(shape1).astype("float32")
  125. data2 = np.random.random(shape2).astype("float32")
  126. data3 = np.random.random(shape3).astype("float32")
  127. data4 = np.random.random(shape4).astype("float32")
  128. data5 = np.random.random(shape5).astype("float32")
  129. cases = [
  130. {"input": [data1, data2]},
  131. {"input": [data2, data3]},
  132. {"input": [data3, data4]},
  133. {"input": [data4, data5]},
  134. ]
  135. opr_test(cases, F.matmul, ref_fn=np.matmul)
  136. opr_test(
  137. [{"input": [data1, data4]}],
  138. F.matmul,
  139. ref_fn=lambda x, y: np.matmul(x, y.transpose(0, 1, 3, 2)),
  140. transpose_b=True,
  141. )
  142. opr_test(
  143. [{"input": [data3, data2]}],
  144. F.matmul,
  145. ref_fn=lambda x, y: np.matmul(x.transpose(0, 2, 1), y.transpose(0, 2, 1)),
  146. transpose_a=True,
  147. transpose_b=True,
  148. )
  149. @pytest.mark.parametrize(
  150. "shape_a, shape_b", [((0,), (0,)), ((10, 0), (0, 10)), ((3, 10, 0), (3, 0, 10)),],
  151. )
  152. @pytest.mark.parametrize("is_symbolic", [None, True, False])
  153. def test_matmul_empty_tensor(shape_a, shape_b, is_symbolic):
  154. def func(a, b):
  155. return F.matmul(a, b)
  156. if is_symbolic is not None:
  157. func = jit.trace(symbolic=is_symbolic)(func)
  158. a = tensor(np.random.randn(*shape_a))
  159. b = tensor(np.random.randn(*shape_b))
  160. for _ in range(3):
  161. out = func(a, b)
  162. assert np.all(out.numpy() == 0)
  163. if is_symbolic is None:
  164. break
  165. def test_interpolate():
  166. def linear_interpolate():
  167. inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))
  168. test_func = lambda inp: F.vision.interpolate(
  169. inp, scale_factor=2.0, mode="linear"
  170. )
  171. ref_func = lambda inp: F.vision.interpolate(inp, 4, mode="linear").numpy()
  172. cases = [{"input": inp}]
  173. opr_test(cases, test_func, ref_fn=ref_func, test_trace=True)
  174. def many_batch_interpolate():
  175. inp = tensor(np.arange(1, 9, dtype=np.float32).reshape(2, 1, 2, 2))
  176. test_func = lambda inp: F.vision.interpolate(inp, scale_factor=2.0)
  177. ref_func = lambda inp: F.vision.interpolate(inp, [4, 4]).numpy()
  178. cases = [{"input": inp}]
  179. opr_test(cases, test_func, ref_fn=ref_func, test_trace=True)
  180. def assign_corner_interpolate():
  181. inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))
  182. test_func = lambda inp: F.vision.interpolate(inp, [4, 4])
  183. ref_func = lambda inp: F.vision.interpolate(inp, scale_factor=2.0).numpy()
  184. cases = [{"input": inp}]
  185. opr_test(cases, test_func, ref_fn=ref_func, test_trace=True)
  186. def error_shape_linear_interpolate():
  187. inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2))
  188. with pytest.raises(ValueError):
  189. F.vision.interpolate(inp, scale_factor=2.0, mode="linear")
  190. def inappropriate_scale_linear_interpolate():
  191. inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))
  192. with pytest.raises(ValueError):
  193. F.vision.interpolate(inp, scale_factor=[2.0, 3.0], mode="linear")
  194. linear_interpolate()
  195. many_batch_interpolate()
  196. assign_corner_interpolate()
  197. error_shape_linear_interpolate()
  198. # inappropriate_scale_linear_interpolate()
  199. def _save_to(self, name="grad"):
  200. def callback(grad):
  201. setattr(self, name, grad)
  202. return callback
  203. def _gen_roi_inp():
  204. inp_feat = np.random.randn(2, 32, 256, 256)
  205. rois = np.zeros((4, 5))
  206. rois[:, 0] = [0, 0, 1, 1]
  207. rois[:, 1:3] = np.random.rand(4, 2) * 100
  208. rois[:, 3:] = np.random.rand(4, 2) * 100 + 150
  209. inp_feat = tensor(inp_feat)
  210. rois = tensor(rois)
  211. return inp_feat, rois
  212. def test_roi_align():
  213. inp_feat, rois = _gen_roi_inp()
  214. with Grad() as grad:
  215. grad.wrt(inp_feat, callback=_save_to(inp_feat))
  216. output_shape = (7, 7)
  217. out_feat = F.vision.roi_align(
  218. inp_feat,
  219. rois,
  220. output_shape=output_shape,
  221. mode="average",
  222. spatial_scale=1.0 / 4,
  223. sample_points=2,
  224. aligned=True,
  225. )
  226. assert make_shape_tuple(out_feat.shape) == (
  227. rois.shape[0],
  228. inp_feat.shape[1],
  229. *output_shape,
  230. )
  231. grad(out_feat, tensor(F.ones_like(out_feat)))
  232. assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape)
  233. @pytest.mark.parametrize("shapes", [((2, 0, 26, 26), (4, 5)), ((2, 3, 26, 26), (0, 5))])
  234. @pytest.mark.parametrize("is_tracing", [False, True])
  235. def test_roi_align_empty(shapes, is_tracing):
  236. inp_feat = tensor(np.random.randn(*(shapes[0])))
  237. rois = tensor(np.random.random(shapes[1]))
  238. output_shape = (7, 7)
  239. def func(inp, rois):
  240. out_feat = F.vision.roi_align(
  241. inp_feat,
  242. rois,
  243. output_shape=output_shape,
  244. mode="average",
  245. spatial_scale=1.0 / 4,
  246. sample_points=2,
  247. aligned=True,
  248. )
  249. return out_feat
  250. if is_tracing:
  251. func = jit.trace(func)
  252. for _ in range(3):
  253. out_feat = func(inp_feat, rois)
  254. assert make_shape_tuple(out_feat.shape) == (
  255. rois.shape[0],
  256. inp_feat.shape[1],
  257. *output_shape,
  258. )
  259. def _gen_correlation(random=True, constant=1, image_shape=(2, 1, 160, 160)):
  260. if random:
  261. inp_feat1 = np.random.randn(
  262. image_shape[0], image_shape[1], image_shape[2], image_shape[3]
  263. )
  264. inp_feat2 = np.random.randn(
  265. image_shape[0], image_shape[1], image_shape[2], image_shape[3]
  266. )
  267. else:
  268. inp_feat1 = np.ones(image_shape) * constant
  269. inp_feat2 = np.ones(image_shape) * constant
  270. return tensor(inp_feat1), tensor(inp_feat2)
  271. def test_correlation():
  272. ##test case 0 check the grad shape
  273. data1, data2 = _gen_correlation()
  274. with Grad() as grad:
  275. grad.wrt(data1, callback=_save_to(data1))
  276. out_feat = F.vision.correlation(
  277. data1,
  278. data2,
  279. kernel_size=5,
  280. max_displacement=4,
  281. stride1=2,
  282. stride2=2,
  283. pad_size=2,
  284. is_multiply=True,
  285. )
  286. grad(out_feat, tensor(F.ones_like(out_feat)))
  287. assert make_shape_tuple(data1.grad.shape) == make_shape_tuple(data1.shape)
  288. ##test case 1 from https://github.com/NVIDIA/flownet2-pytorch/issues/194
  289. data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3))
  290. out_feat = F.vision.correlation(
  291. data1,
  292. data2,
  293. kernel_size=3,
  294. max_displacement=0,
  295. stride1=1,
  296. stride2=1,
  297. pad_size=0,
  298. is_multiply=True,
  299. )
  300. assert abs(out_feat.sum() - 1) < 1e-9
  301. ##test case 2 check same image subduction
  302. data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3))
  303. out_feat = F.vision.correlation(
  304. data1,
  305. data2,
  306. kernel_size=3,
  307. max_displacement=0,
  308. stride1=1,
  309. stride2=1,
  310. pad_size=0,
  311. is_multiply=False,
  312. )
  313. assert out_feat.sum() < 1e-9
  314. ##test case 3 check same image subduction
  315. data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3))
  316. out_feat = F.vision.correlation(
  317. data1,
  318. data2,
  319. kernel_size=3,
  320. max_displacement=0,
  321. stride1=1,
  322. stride2=1,
  323. pad_size=0,
  324. is_multiply=False,
  325. )
  326. assert out_feat.sum() < 1e-9
  327. ##test case 4 check correlation
  328. data1, _ = _gen_correlation(
  329. random=False, image_shape=(1, 1, 220, 220), constant=2.0
  330. )
  331. _, data2 = _gen_correlation(
  332. random=False, image_shape=(1, 1, 220, 220), constant=1.0
  333. )
  334. out_feat = F.vision.correlation(
  335. data1,
  336. data2,
  337. kernel_size=3,
  338. max_displacement=2,
  339. stride1=1,
  340. stride2=2,
  341. pad_size=0,
  342. is_multiply=False,
  343. )
  344. assert abs(out_feat.mean() - 1) < 1e-9
  345. def test_roi_pooling():
  346. inp_feat, rois = _gen_roi_inp()
  347. with Grad() as grad:
  348. grad.wrt(inp_feat, callback=_save_to(inp_feat))
  349. output_shape = (7, 7)
  350. out_feat = F.vision.roi_pooling(
  351. inp_feat, rois, output_shape=output_shape, mode="max", scale=1.0 / 4,
  352. )
  353. assert make_shape_tuple(out_feat.shape) == (
  354. rois.shape[0],
  355. inp_feat.shape[1],
  356. *output_shape,
  357. )
  358. grad(out_feat, tensor(F.ones_like(out_feat)))
  359. assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape)
  360. def test_adaptive_avg_pool2d():
  361. inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4))
  362. oshp = (2, 2)
  363. with Grad() as grad:
  364. grad.wrt(inp, callback=_save_to(inp))
  365. outp = F.adaptive_avg_pool2d(inp, oshp,)
  366. assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
  367. np.testing.assert_equal(
  368. outp.numpy(), np.array([[[[2.5, 4.5], [10.5, 12.5]]]], dtype=np.float32)
  369. )
  370. grad(outp, tensor(F.ones_like(outp)))
  371. assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape)
  372. np.testing.assert_equal(
  373. inp.grad.numpy(),
  374. np.array(
  375. [
  376. [
  377. [
  378. [0.25, 0.25, 0.25, 0.25],
  379. [0.25, 0.25, 0.25, 0.25],
  380. [0.25, 0.25, 0.25, 0.25],
  381. [0.25, 0.25, 0.25, 0.25],
  382. ]
  383. ]
  384. ],
  385. dtype=np.float32,
  386. ),
  387. )
  388. def test_adaptive_max_pool2d():
  389. inp = tensor(np.arange(0, 16, dtype=np.float32).reshape(1, 1, 4, 4))
  390. oshp = (2, 2)
  391. with Grad() as grad:
  392. grad.wrt(inp, callback=_save_to(inp))
  393. outp = F.adaptive_max_pool2d(inp, oshp,)
  394. assert make_shape_tuple(outp.shape) == (inp.shape[0], inp.shape[1], *oshp,)
  395. np.testing.assert_equal(
  396. outp.numpy(), np.array([[[[5, 7], [13, 15]]]], dtype=np.float32)
  397. )
  398. grad(outp, tensor(F.ones_like(outp)))
  399. assert make_shape_tuple(inp.grad.shape) == make_shape_tuple(inp.shape)
  400. np.testing.assert_equal(
  401. inp.grad.numpy(),
  402. np.array(
  403. [
  404. [
  405. [
  406. [0.0, 0.0, 0.0, 0.0],
  407. [0.0, 1.0, 0.0, 1.0],
  408. [0.0, 0.0, 0.0, 0.0],
  409. [0.0, 1.0, 0.0, 1.0],
  410. ]
  411. ]
  412. ],
  413. dtype=np.float32,
  414. ),
  415. )
  416. def test_one_hot():
  417. def onehot_low_dimension():
  418. inp = tensor(np.arange(1, 4, dtype=np.int32))
  419. out = F.one_hot(inp, num_classes=4)
  420. np.testing.assert_allclose(
  421. out.numpy(), np.eye(4, dtype=np.int32)[np.arange(1, 4, dtype=np.int32)]
  422. )
  423. def onehot_high_dimension():
  424. arr = np.array(
  425. [[3, 2, 4, 4, 2, 4, 0, 4, 4, 1], [4, 1, 1, 3, 2, 2, 4, 2, 4, 3]],
  426. dtype=np.int32,
  427. )
  428. inp = tensor(arr)
  429. out = F.one_hot(inp, 10)
  430. np.testing.assert_allclose(out.numpy(), np.eye(10, dtype=np.int32)[arr])
  431. onehot_low_dimension()
  432. onehot_high_dimension()
  433. def test_interpolate_fastpath():
  434. # check shape
  435. test_cases = [
  436. [(1, 1, 10, 10), (5, 5)],
  437. [(1, 3, 10, 10), (20, 20)],
  438. [(10, 1, 10, 10), (1, 1)],
  439. [(10, 10, 1, 1), (10, 10)],
  440. ]
  441. for inp_shape, target_shape in test_cases:
  442. x = tensor(np.random.randn(*inp_shape), dtype=np.float32)
  443. out = F.vision.interpolate(x, target_shape, mode="bilinear")
  444. assert out.shape[0] == x.shape[0] and out.shape[1] == x.shape[1]
  445. assert out.shape[2] == target_shape[0] and out.shape[3] == target_shape[1]
  446. # check value
  447. x = tensor(np.ones((3, 3, 10, 10)), dtype=np.float32)
  448. out = F.vision.interpolate(x, (15, 5), mode="bilinear")
  449. np.testing.assert_equal(out.numpy(), np.ones((3, 3, 15, 5)).astype(np.float32))
  450. np_x = np.arange(32)
  451. x = tensor(np_x).astype(np.float32).reshape(1, 1, 32, 1)
  452. out = F.vision.interpolate(x, (1, 1), mode="bilinear")
  453. np.testing.assert_equal(out.item(), np_x.mean())
  454. @pytest.mark.parametrize("dt", [np.float32, np.int8, np.uint8, np.float16])
  455. def test_warp_perspective(dt):
  456. inp_shape = (1, 1, 4, 4)
  457. x = tensor(np.arange(16, dtype=dt).reshape(inp_shape))
  458. M_shape = (1, 3, 3)
  459. # M defines a translation: dst(1, 1, h, w) = rst(1, 1, h+1, w+1)
  460. M = tensor(
  461. np.array(
  462. [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32
  463. ).reshape(M_shape)
  464. )
  465. outp = F.vision.warp_perspective(x, M, (2, 2))
  466. np.testing.assert_equal(outp.numpy(), np.array([[[[5, 6], [9, 10]]]], dtype=dt))
  467. def test_warp_affine_grad():
  468. dy_np = np.arange(1, 10, dtype=np.float32).reshape(1, 1, 3, 3)
  469. x_np = np.arange(1, 10, dtype=np.float32).reshape(1, 1, 3, 3)
  470. mat_np_affine = np.array([[[0.5, 0, 0], [0, 0.5, 0],]]).astype("float32")
  471. mat_np_perspective = np.array([[[0.5, 0, 0], [0, 0.5, 0], [0, 0, 1]]]).astype(
  472. "float32"
  473. )
  474. dmat_affine = Tensor(np.ones((1, 2, 3), dtype=np.float32))
  475. dy_affine = Tensor(dy_np)
  476. x_affine = Tensor(x_np)
  477. mat_affine = Tensor(mat_np_affine)
  478. target_shape_affine = x_affine.shape[2:]
  479. dmat_perspective = Tensor(np.ones((1, 3, 3), dtype=np.float32))
  480. dy_perspective = Tensor(dy_np)
  481. x_perspective = Tensor(x_np)
  482. mat_perspective = Tensor(mat_np_perspective)
  483. target_shape_perspective = x_perspective.shape[2:]
  484. gm = GradManager().attach([x_affine, mat_affine, x_perspective, mat_perspective])
  485. with gm:
  486. y_affine = F.warp_affine(
  487. x_affine, mat_affine, target_shape_affine, format="NCHW"
  488. )
  489. y_perspective = F.warp_perspective(
  490. x_perspective, mat_perspective, target_shape_perspective
  491. )
  492. gm.backward([y_affine, y_perspective], [dy_affine, dy_perspective])
  493. np.testing.assert_allclose(
  494. x_affine.grad.numpy(), x_perspective.grad.numpy(), rtol=1e-5, atol=1e-5
  495. )
  496. np.testing.assert_allclose(
  497. mat_affine.grad.numpy(),
  498. mat_perspective.grad.numpy()[0:1, 0:2, 0:3],
  499. rtol=1e-5,
  500. atol=1e-5,
  501. )
  502. @pytest.mark.parametrize("dt", [np.float32, np.int8, np.uint8, np.float16])
  503. def test_warp_perspective_mat_idx(dt):
  504. inp_shape = (2, 1, 4, 4)
  505. x = tensor(np.arange(32, dtype=dt).reshape(inp_shape))
  506. M_shape = (1, 3, 3)
  507. # M defines a translation: dst(1, 1, h, w) = rst(1, 1, h+1, w+1)
  508. M = tensor(
  509. np.array(
  510. [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32
  511. ).reshape(M_shape)
  512. )
  513. M = F.concat([M,] * 4, 0)
  514. outp = F.vision.warp_perspective(x, M, (2, 2), mat_idx=[0, 1, 1, 0])
  515. np.testing.assert_equal(
  516. outp.numpy(),
  517. np.array(
  518. [
  519. [[[5, 6], [9, 10]]],
  520. [[[21, 22], [25, 26]]],
  521. [[[21, 22], [25, 26]]],
  522. [[[5, 6], [9, 10]]],
  523. ],
  524. dtype=dt,
  525. ),
  526. )
  527. def test_warp_affine():
  528. inp_shape = (1, 3, 3, 3)
  529. x = tensor(np.arange(27, dtype=np.float32).reshape(inp_shape))
  530. weightv = [[[1.26666667, 0.6, -83.33333333], [-0.33333333, 1, 66.66666667]]]
  531. outp = F.vision.warp_affine(x, tensor(weightv), (2, 2), border_mode="wrap")
  532. res = np.array(
  533. [
  534. [
  535. [[7.875, 8.875, 9.875], [8.90625, 9.90625, 10.90625]],
  536. [[18.75, 19.75, 20.75], [14.90625, 15.90625, 16.90625]],
  537. ]
  538. ],
  539. dtype=np.float32,
  540. )
  541. if not is_cuda_available():
  542. np.testing.assert_almost_equal(outp.numpy(), res, 5)
  543. def test_remap():
  544. inp_shape = (1, 1, 4, 4)
  545. inp = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
  546. map_xy_shape = (1, 2, 2, 2)
  547. map_xy = tensor(
  548. np.array(
  549. [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [0.0, 1.0]]], dtype=np.float32
  550. ).reshape(map_xy_shape)
  551. )
  552. outp = F.vision.remap(inp, map_xy)
  553. np.testing.assert_equal(
  554. outp.numpy(), np.array([[[[1.0, 4.0], [4.0, 4.0]]]], dtype=np.float32)
  555. )
  556. def test_binary_cross_entropy():
  557. data1_shape = (2, 2)
  558. label1_shape = (2, 2)
  559. data2_shape = (2, 3)
  560. label2_shape = (2, 3)
  561. def sigmoid(x):
  562. return 1 / (1 + np.exp(-x))
  563. def compare_fn(x, y):
  564. np.testing.assert_allclose(x.numpy(), y, atol=5e-4)
  565. np.random.seed(123)
  566. data1 = np.random.uniform(size=data1_shape).astype(np.float32)
  567. label1 = np.random.uniform(size=label1_shape).astype(np.float32)
  568. expect1 = np.array(0.6361, dtype=np.float32)
  569. np.random.seed(123)
  570. data2 = np.random.uniform(size=data2_shape).astype(np.float32)
  571. label2 = np.random.uniform(size=label2_shape).astype(np.float32)
  572. expect2 = np.array(0.6750, dtype=np.float32)
  573. cases = [
  574. {"input": [data1, label1], "output": expect1,},
  575. {"input": [data2, label2], "output": expect2,},
  576. ]
  577. opr_test(cases, F.nn.binary_cross_entropy, compare_fn=compare_fn)
  578. cases = [
  579. {"input": [sigmoid(data1), label1], "output": expect1,},
  580. {"input": [sigmoid(data2), label2], "output": expect2,},
  581. ]
  582. opr_test(
  583. cases,
  584. partial(F.nn.binary_cross_entropy, with_logits=False),
  585. compare_fn=compare_fn,
  586. )
  587. def test_hinge_loss():
  588. np.random.seed(123)
  589. # case with L1 norm
  590. cases = []
  591. for shape in [(2, 2), (2, 3)]:
  592. data = np.random.uniform(size=shape).astype(np.float32)
  593. label = 2 * np.random.randint(0, 1, size=shape).astype(np.float32) - 1
  594. expect = np.clip(0, np.inf, 1 - data * label).sum(axis=1).mean()
  595. cases.append({"input": [data, label], "output": expect})
  596. opr_test(cases, F.nn.hinge_loss)
  597. # cases with L2 norm
  598. cases = []
  599. for shape in [(2, 2), (2, 3)]:
  600. data = np.random.uniform(size=shape).astype(np.float32)
  601. label = 2 * np.random.randint(0, 1, size=shape).astype(np.float32) - 1
  602. expect = ((np.clip(0, np.inf, 1 - data * label) ** 2).sum(axis=1)).mean()
  603. cases.append({"input": [data, label], "output": expect})
  604. def hinge_loss_with_l2_norm(pred, label):
  605. return F.nn.hinge_loss(pred, label, "L2")
  606. opr_test(cases, hinge_loss_with_l2_norm)
  607. @pytest.mark.parametrize("is_symbolic", [None, False, True])
  608. def test_nms(is_symbolic):
  609. def fn(inp, scores):
  610. return F.vision.nms(
  611. inp,
  612. scores=scores,
  613. iou_thresh=0.5,
  614. max_output=None if is_symbolic is None else 4,
  615. )
  616. if is_symbolic is not None:
  617. fn = jit.trace(symbolic=is_symbolic)(fn)
  618. x = np.array(
  619. [
  620. [0, 0, 100, 100],
  621. [10, 10, 100, 100],
  622. [50, 50, 100, 100],
  623. [100, 100, 150, 150],
  624. ],
  625. dtype=np.float32,
  626. )
  627. inp = tensor(x)
  628. scores = tensor([0.5, 0.8, 0.9, 0.6], dtype=np.float32)
  629. for _ in range(3):
  630. result = fn(inp, scores=scores)
  631. np.testing.assert_equal(result.numpy(), np.array([2, 1, 3], dtype=np.int32))
  632. x = np.array([], dtype=np.float32,).reshape(0, 4)
  633. inp = tensor(x)
  634. scores = tensor([], dtype=np.float32)
  635. for _ in range(3):
  636. result = fn(inp, scores=scores)
  637. np.testing.assert_equal(result.numpy(), np.array([], dtype=np.int32))
  638. @pytest.mark.skipif(
  639. get_device_count("gpu") > 0, reason="cuda does not support nchw int8"
  640. )
  641. def test_conv_bias():
  642. inp_scale = 1.5
  643. w_scale = 2.5
  644. outp_scale = 1.5
  645. inp_dtype = dtype.qint8(inp_scale)
  646. w_dtype = dtype.qint8(w_scale)
  647. b_dtype = dtype.qint32(inp_scale * w_scale)
  648. out_dtype = dtype.qint8(outp_scale)
  649. def run(
  650. N,
  651. IC,
  652. OC,
  653. IH,
  654. IW,
  655. KH,
  656. KW,
  657. PH,
  658. PW,
  659. SH,
  660. SW,
  661. has_bias=True,
  662. nonlinear_mode="identity",
  663. ):
  664. inp_v = np.random.normal(size=(N, IC, IH, IW))
  665. w_v = np.random.normal(size=(OC, IC, KH, KW))
  666. b_v = np.random.normal(size=(1, OC, 1, 1))
  667. inp_scale = dtype.get_scale(inp_dtype)
  668. w_scale = dtype.get_scale(w_dtype)
  669. b_scale = dtype.get_scale(b_dtype)
  670. inpv = dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype)
  671. wv = dtype.convert_to_qint8(w_v * w_scale, w_dtype)
  672. bv = dtype.convert_to_qint32(b_v * b_scale, b_dtype)
  673. inp_int8 = tensor(inpv, dtype=inp_dtype)
  674. w_int8 = Parameter(wv, dtype=w_dtype)
  675. b_int32 = Parameter(bv, dtype=b_dtype)
  676. inp_fp32 = inp_int8.astype("float32")
  677. w_fp32 = w_int8.astype("float32")
  678. b_fp32 = b_int32.astype("float32")
  679. def convert_to_nchw4(var):
  680. var = F.reshape(
  681. var, (var.shape[0], var.shape[1] // 4, 4, var.shape[2], var.shape[3])
  682. )
  683. var = F.transpose(var, (0, 1, 3, 4, 2))
  684. return var
  685. def run_conv2d(inp, w, b):
  686. O = F.conv2d(
  687. inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW),
  688. )
  689. if nonlinear_mode == "relu":
  690. return F.relu(O)
  691. else:
  692. return O
  693. def run_conv_bias(inp, w, b, format="NCHW"):
  694. b = b if has_bias else Parameter(np.zeros_like(b.numpy()))
  695. if format == "NCHW4":
  696. inp = convert_to_nchw4(inp)
  697. w = convert_to_nchw4(w)
  698. b = convert_to_nchw4(b)
  699. return F.quantized.conv_bias_activation(
  700. inp,
  701. w,
  702. b,
  703. stride=(SH, SW),
  704. padding=(PH, PW),
  705. dtype=out_dtype,
  706. nonlinear_mode=nonlinear_mode,
  707. )
  708. format = "NCHW4" if is_cuda_available() else "NCHW"
  709. expected = run_conv2d(inp_fp32, w_fp32, b_fp32)
  710. expected = expected.astype(out_dtype).astype("float32")
  711. result = run_conv_bias(inp_int8, w_int8, b_int32, format=format).astype(
  712. "float32"
  713. )
  714. if format == "NCHW4":
  715. result = F.transpose(result, (0, 1, 4, 2, 3))
  716. expected = F.flatten(expected)
  717. result = F.flatten(result)
  718. np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=outp_scale)
  719. run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1, False)
  720. run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1, False)
  721. run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False)
  722. run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1)
  723. run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1)
  724. run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2)
  725. run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "relu")
  726. run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu")
  727. @pytest.mark.skipif(get_device_count("gpu") > 0, reason="no int8 algorithm on cuda")
  728. def test_batch_conv_bias():
  729. inp_scale = 1.5
  730. w_scale = 2.5
  731. outp_scale = 1.5
  732. inp_dtype = dtype.qint8(inp_scale)
  733. w_dtype = dtype.qint8(w_scale)
  734. b_dtype = dtype.qint32(inp_scale * w_scale)
  735. out_dtype = dtype.qint8(outp_scale)
  736. def run(
  737. N, IC, OC, IH, IW, KH, KW, PH, PW, SH, SW, has_bias=True,
  738. ):
  739. inp_v = np.random.normal(size=(N, IC, IH, IW))
  740. w_v = np.random.normal(size=(N, OC, IC, KH, KW))
  741. b_v = np.random.normal(size=(1, OC, 1, 1))
  742. inp_scale = dtype.get_scale(inp_dtype)
  743. w_scale = dtype.get_scale(w_dtype)
  744. b_scale = dtype.get_scale(b_dtype)
  745. inpv = dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype)
  746. wv = dtype.convert_to_qint8(w_v * w_scale, w_dtype)
  747. bv = dtype.convert_to_qint32(b_v * b_scale, b_dtype)
  748. inp_int8 = tensor(inpv, dtype=inp_dtype)
  749. w_int8 = Parameter(wv, dtype=w_dtype)
  750. b_int32 = Parameter(bv, dtype=b_dtype)
  751. inp_fp32 = inp_int8.astype("float32")
  752. w_fp32 = w_int8.astype("float32")
  753. b_fp32 = b_int32.astype("float32")
  754. def run_batch_conv_bias(inp, w, b):
  755. b = b if has_bias else Parameter(np.zeros_like(b.numpy()))
  756. result = F.quantized.batch_conv_bias_activation(
  757. inp, w, b, stride=(SH, SW), padding=(PH, PW), dtype=out_dtype,
  758. )
  759. return result.astype("float32")
  760. expected = F.conv2d(inp_fp32, w_fp32[0], b_fp32 if has_bias else None)[0]
  761. expected = expected.astype(out_dtype).astype("float32")
  762. expected = F.flatten(expected)
  763. result = run_batch_conv_bias(inp_int8, w_int8, b_int32)
  764. result = F.flatten(result)
  765. np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=outp_scale)
  766. run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)
  767. @pytest.mark.parametrize("bias", [True, False])
  768. def test_region_restricted_conv_forward_backward_naive(bias):
  769. import megengine as mge
  770. import megengine.module as M
  771. from megengine.autodiff import GradManager
  772. handle = "cpu0"
  773. src_1 = np.arange(8).reshape(1, 2, 2, 2).astype(np.float32)
  774. filter_1 = np.arange(8).reshape(2, 1, 1, 2, 2).astype(np.float32)
  775. rin_1 = np.array([1, 1, 1, 1]).reshape(1, 2, 2).astype(np.int32)
  776. rout_1 = np.array([1]).reshape(1, 1, 1).astype(np.int32)
  777. cpu_src = tensor(src_1, device=handle)
  778. cpu_filter = tensor(filter_1, device=handle)
  779. gm = GradManager().attach([cpu_src, cpu_filter])
  780. cpu_bias = (
  781. tensor(np.ones((1, 2, 1, 1), dtype=np.float32), device=handle) if bias else None
  782. )
  783. with gm:
  784. cpu_out = F.region_restricted_conv(
  785. cpu_src,
  786. cpu_filter,
  787. tensor(rin_1, device=handle),
  788. tensor(rout_1, device=handle),
  789. bias=cpu_bias,
  790. groups=2,
  791. )
  792. gm.backward(cpu_out, tensor(np.ones((1, 2, 1, 1)), device=handle))
  793. if cpu_bias is not None:
  794. cpu_out = cpu_out - cpu_bias
  795. np.testing.assert_allclose(cpu_out, np.array([14, 126]).reshape(1, 2, 1, 1))
  796. np.testing.assert_allclose(
  797. cpu_src.grad, np.array([0, 1, 2, 3, 4, 5, 6, 7]).reshape(1, 2, 2, 2)
  798. )
  799. np.testing.assert_allclose(
  800. cpu_filter.grad, np.array([0, 1, 2, 3, 4, 5, 6, 7]).reshape(2, 1, 1, 2, 2)
  801. )
  802. @pytest.mark.skipif(
  803. not is_cuda_available(), reason="rrconv cuda kernel requires cuda available"
  804. )
  805. @pytest.mark.parametrize("bias, groups", [(True, 1), (True, 3), (False, 1), (False, 3)])
  806. def test_region_restricted_conv_forward_backward_cuda(bias, groups):
  807. import megengine as mge
  808. import megengine.module as M
  809. from megengine.autodiff import GradManager
  810. # params
  811. handle = "gpu0"
  812. N = 1
  813. GROUP = groups
  814. FH = FW = 2
  815. IH = IW = 2
  816. OH = OW = 1
  817. ICPG = OCPG = 1
  818. grad_shape = (N, GROUP * ICPG, IH, IW)
  819. src_shape = grad_shape
  820. filter_shape = (GROUP, OCPG, ICPG, FH, FW)
  821. diff_shape = (N, GROUP * OCPG, OH, OW)
  822. rin_shape = (N, IH, IW)
  823. rout_shape = (N, OH, OW)
  824. def reduce(shape):
  825. mul = 1
  826. for x in shape:
  827. mul *= x
  828. return mul
  829. def get_groundtruth():
  830. src = tensor(
  831. np.arange(reduce(src_shape)).reshape(src_shape).astype(np.float32),
  832. device="cpu0",
  833. )
  834. filter = tensor(np.ones(filter_shape).astype(np.float32), device="cpu0")
  835. rin = tensor(np.ones(rin_shape).astype(np.int32), device="cpu0")
  836. rout = tensor(np.ones(rout_shape).astype(np.int32), device="cpu0")
  837. bias_cpu = (
  838. tensor(np.ones((1, GROUP * OCPG, 1, 1)).astype(np.float32), device="cpu0")
  839. if bias
  840. else None
  841. )
  842. gm = GradManager().attach([src, filter])
  843. with gm:
  844. expected_out = F.region_restricted_conv(
  845. src, filter, rin, rout, bias=bias_cpu, groups=GROUP
  846. )
  847. gm.backward(
  848. expected_out,
  849. tensor(np.ones(diff_shape, dtype=np.float32), device="cpu0"),
  850. )
  851. return src, filter, expected_out
  852. expected_src, expected_filter, expected_out = get_groundtruth()
  853. src = tensor(
  854. np.arange(reduce(src_shape)).reshape(src_shape).astype(np.float32),
  855. device=handle,
  856. )
  857. filter = tensor(np.ones(filter_shape).astype(np.float32), device=handle)
  858. rin = tensor(np.ones(rin_shape).astype(np.int32), device=handle)
  859. rout = tensor(np.ones(rout_shape).astype(np.int32), device=handle)
  860. bias_gpu = (
  861. tensor(np.ones((1, GROUP * OCPG, 1, 1)).astype(np.float32), device=handle)
  862. if bias
  863. else None
  864. )
  865. gm = GradManager().attach([src, filter])
  866. with gm:
  867. gpu_out = F.region_restricted_conv(
  868. src, filter, rin, rout, bias=bias_gpu, groups=GROUP
  869. )
  870. gm.backward(gpu_out, tensor(np.ones(diff_shape), device=handle))
  871. np.testing.assert_allclose(src.grad, expected_src.grad)
  872. np.testing.assert_allclose(filter.grad, expected_filter.grad)
  873. np.testing.assert_allclose(gpu_out, expected_out)
  874. @pytest.mark.skipif(
  875. not is_cuda_available(), reason="rrconv cuda kernel requires cuda available"
  876. )
  877. @pytest.mark.parametrize("bias, groups", [(True, 1), (True, 3), (False, 1), (False, 3)])
  878. def test_region_restricted_conv_forward_backward_uint8(bias, groups):
  879. import megengine as mge
  880. import megengine.module as M
  881. from megengine.autodiff import GradManager
  882. # params
  883. handle = "gpu0"
  884. N = 1
  885. GROUP = groups
  886. FH = FW = 1
  887. IH = IW = 4
  888. OH = OW = 4
  889. ICPG = OCPG = 1
  890. grad_shape = (N, GROUP * ICPG, IH, IW)
  891. src_shape = grad_shape
  892. filter_shape = (GROUP, OCPG, ICPG, FH, FW)
  893. diff_shape = (N, GROUP * OCPG, OH, OW)
  894. rin_shape = (N, IH, IW)
  895. rout_shape = (N, OH, OW)
  896. def reduce(shape):
  897. mul = 1
  898. for x in shape:
  899. mul *= x
  900. return mul
  901. def get_groundtruth():
  902. src = tensor(
  903. np.arange(reduce(src_shape)).reshape(src_shape).astype(np.float32),
  904. device="cpu0",
  905. )
  906. filter = tensor(np.ones(filter_shape).astype(np.float32), device="cpu0")
  907. rin = tensor(np.ones(rin_shape).astype(np.int32), device="cpu0")
  908. rout = tensor(np.ones(rout_shape).astype(np.int32), device="cpu0")
  909. bias_cpu = (
  910. tensor(np.ones((1, GROUP * OCPG, 1, 1)).astype(np.float32), device="cpu0")
  911. if bias
  912. else None
  913. )
  914. gm = GradManager().attach([src, filter])
  915. with gm:
  916. expected_out = F.region_restricted_conv(
  917. src, filter, rin, rout, bias=bias_cpu, groups=GROUP
  918. )
  919. gm.backward(
  920. expected_out,
  921. tensor(np.ones(diff_shape, dtype=np.float32), device="cpu0"),
  922. )
  923. return src, filter, expected_out
  924. expected_src, expected_filter, expected_out = get_groundtruth()
  925. # forward and dgrad/wgrad
  926. src = tensor(
  927. np.arange(reduce(src_shape)).reshape(src_shape).astype(np.float32),
  928. device=handle,
  929. )
  930. filter = tensor(np.ones(filter_shape).astype(np.float32), device=handle)
  931. rin = tensor(np.ones(rin_shape).astype(np.uint8), device=handle)
  932. rout = tensor(np.ones(rout_shape).astype(np.uint8), device=handle)
  933. bias_gpu = (
  934. tensor(np.ones((1, GROUP * OCPG, 1, 1)).astype(np.float32), device=handle)
  935. if bias
  936. else None
  937. )
  938. gm = GradManager().attach([src, filter])
  939. with gm:
  940. gpu_out = F.region_restricted_conv(
  941. src, filter, rin, rout, bias=bias_gpu, groups=GROUP
  942. )
  943. gm.backward(
  944. gpu_out, tensor(np.ones(diff_shape, dtype=np.float32), device=handle)
  945. )
  946. # assert uint8 gpu result close to cpu result
  947. np.testing.assert_allclose(src.grad, expected_src.grad)
  948. np.testing.assert_allclose(filter.grad, expected_filter.grad)
  949. np.testing.assert_allclose(gpu_out, expected_out)
  950. def test_conv2d_autocast():
  951. """check amp's result is equal to manually converted result"""
  952. amp.enabled = True
  953. inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float32)
  954. weight = tensor(np.random.randn(64, 3, 7, 7), dtype=np.float32)
  955. out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)
  956. amp.enabled = False
  957. expected = F.conv2d(
  958. inp.astype("float16"),
  959. weight.astype("float16"),
  960. None,
  961. (2, 2),
  962. (3, 3),
  963. (1, 1),
  964. 1,
  965. compute_mode="float32",
  966. )
  967. assert out.dtype == np.float16
  968. assert expected.dtype == np.float16
  969. np.testing.assert_allclose(out.numpy(), expected.numpy())
  970. def test_conv2d_zero_stride_numpy_array():
  971. inp = np.random.randn(3, 224, 224).astype(np.float32)
  972. inp = inp[np.newaxis, :]
  973. inp = tensor(inp, dtype=np.float32)
  974. weight = tensor(np.random.randn(16, 3, 3, 3), dtype=np.float32)
  975. out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)
  976. def test_conv3d_zero_stride_numpy_array():
  977. inp = np.random.randn(3, 224, 224, 224).astype(np.float32)
  978. inp = inp[np.newaxis, :]
  979. inp = tensor(inp, dtype=np.float32)
  980. weight = tensor(np.random.randn(16, 3, 3, 3, 3), dtype=np.float32)
  981. out = F.conv3d(inp, weight, None, (2, 2, 2), (3, 3, 3), (1, 1, 1), 1)
  982. out.numpy()
  983. @pytest.mark.parametrize("bias", [True, False])
  984. def test_conv1d(bias):
  985. inp = tensor(np.ones((2, 2, 4), dtype=np.float32))
  986. weight = tensor(np.ones((3, 2, 2), dtype=np.float32))
  987. bias = tensor(np.ones((1, 3, 1), dtype=np.float32)) if bias else None
  988. out = F.conv1d(inp, weight, bias, 2, 0, 1, 1)
  989. np.testing.assert_equal(
  990. out.numpy(),
  991. np.array([[[5, 5], [5, 5], [5, 5]], [[5, 5], [5, 5], [5, 5]]], dtype=np.float32)
  992. if bias is not None
  993. else np.array(
  994. [[[4, 4], [4, 4], [4, 4]], [[4, 4], [4, 4], [4, 4]]], dtype=np.float32
  995. ),
  996. )
  997. def test_batchnorm2d_autocast():
  998. """check amp's result is equal to manually converted result"""
  999. amp.enabled = True
  1000. tshape = (1, 3, 224, 224)
  1001. pshape = (1, 3, 1, 1)
  1002. inp = tensor(np.random.randn(*tshape), dtype=np.float32)
  1003. weight = tensor(np.ones(pshape, dtype=np.float32))
  1004. bias = tensor(np.zeros(pshape, dtype=np.float32))
  1005. out = F.batch_norm(inp, weight=weight, bias=bias, training=True, inplace=False)
  1006. amp.enabled = False
  1007. expected = F.batch_norm(
  1008. inp.astype("float16"), weight=weight, bias=bias, training=True, inplace=False,
  1009. )
  1010. assert out.dtype == np.float16
  1011. assert expected.dtype == np.float16
  1012. np.testing.assert_allclose(out.numpy(), expected.numpy())
  1013. @pytest.mark.parametrize("bias", [True, False])
  1014. def test_conv3d(bias):
  1015. inp = tensor(np.ones((2, 2, 4, 4, 4), dtype=np.float32))
  1016. weight = tensor(np.ones((3, 2, 2, 2, 2), dtype=np.float32))
  1017. bias = tensor(np.ones((1, 3, 1, 1, 1), dtype=np.float32)) if bias else None
  1018. out = F.conv3d(inp, weight, bias, 2, 0, 1, 1)
  1019. target = np.ones((2, 3, 2, 2, 2), dtype=np.float32) * 16
  1020. target = target + 1 if bias is not None else target
  1021. np.testing.assert_equal(out.numpy(), target)
  1022. def test_condtake():
  1023. x = np.array([[1, 2, 3], [4, 5, 6]])
  1024. y = np.array([[True, False, True], [False, True, True]])
  1025. xx = tensor(x)
  1026. yy = tensor(y)
  1027. val, idx = F.cond_take(yy, xx)
  1028. np.testing.assert_equal(val.numpy(), x[y])
  1029. np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])
  1030. @pytest.mark.parametrize("is_symbolic", [None, False, True])
  1031. def test_condtake(is_symbolic):
  1032. shapes = [
  1033. (3, 3, 3),
  1034. (0,),
  1035. (3, 0, 3),
  1036. ]
  1037. def fn(mask, data):
  1038. return F.cond_take(mask, data)
  1039. if is_symbolic is not None:
  1040. fn = jit.trace(symbolic=is_symbolic)(fn)
  1041. for shp in shapes:
  1042. x_np = np.random.randn(*shp).astype("float32")
  1043. mask_np = x_np > 0
  1044. x = tensor(x_np)
  1045. mask = tensor(mask_np)
  1046. ref_out = x_np[mask_np]
  1047. ref_idx = mask_np.flatten().nonzero()[0]
  1048. for i in range(3):
  1049. out, idx = fn(mask, x)
  1050. np.testing.assert_equal(out.numpy(), ref_out)
  1051. np.testing.assert_equal(idx.numpy(), ref_idx)
  1052. if is_symbolic is None:
  1053. break
  1054. def test_condtake_is_same():
  1055. op1 = builtin.CondTake()
  1056. op2 = builtin.CondTake()
  1057. assert op1 == op2
  1058. def test_nms_is_same():
  1059. op1 = builtin.NMSKeep(0.7, 100)
  1060. op2 = builtin.NMSKeep(0.7, 100)
  1061. op3 = builtin.NMSKeep(0.8, 100)
  1062. op4 = builtin.NMSKeep(0.7, 200)
  1063. assert op1 == op2
  1064. assert op1 != op3
  1065. assert op1 != op4
  1066. assert op3 != op4
  1067. def test_argmxx_on_inf():
  1068. def run_argmax():
  1069. x = F.zeros((100, 100))
  1070. x[:] = -float("inf")
  1071. idxs = F.argmax(x, axis=0)
  1072. return idxs
  1073. def run_argmin():
  1074. x = F.zeros((100, 100))
  1075. x[:] = float("inf")
  1076. idxs = F.argmin(x, axis=0)
  1077. return idxs
  1078. assert all(run_argmax() >= 0)
  1079. assert all(run_argmin() >= 0)
  1080. def test_deformable_psroi_pooling():
  1081. inp = np.random.random((1, 256, 64, 64)).astype("float32")
  1082. rois = np.random.random((1, 5)).astype("float32")
  1083. trans = np.random.random((24, 2, 7, 7)).astype("float32")
  1084. pooled_h = 7
  1085. pooled_w = 7
  1086. sample_per_part = 4
  1087. no_trans = False
  1088. part_size = 7
  1089. spatial_scale = 1.0 / 64
  1090. trans_std = 0.1
  1091. y = F.deformable_psroi_pooling(
  1092. tensor(inp),
  1093. tensor(rois),
  1094. tensor(trans),
  1095. no_trans,
  1096. part_size,
  1097. pooled_h,
  1098. pooled_w,
  1099. sample_per_part,
  1100. spatial_scale,
  1101. trans_std,
  1102. )
  1103. def test_cvt_color():
  1104. def rgb2gray(rgb):
  1105. return np.dot(rgb[..., :3], [0.299, 0.587, 0.114])
  1106. def bgr2gray(bgr):
  1107. return np.dot(bgr[..., :3], [0.114, 0.587, 0.299])
  1108. inp = np.random.randn(3, 3, 3, 3).astype(np.float32)
  1109. out = np.expand_dims(rgb2gray(inp), 3).astype(np.float32)
  1110. x = tensor(inp)
  1111. y = F.vision.cvt_color(x, mode="RGB2GRAY")
  1112. np.testing.assert_allclose(y.numpy(), out, atol=1e-5)
  1113. out1 = np.expand_dims(bgr2gray(inp), 3).astype(np.float32)
  1114. y1 = F.vision.cvt_color(x, mode="BGR2GRAY")
  1115. np.testing.assert_allclose(y1.numpy(), out1, atol=1e-5)
  1116. @pytest.mark.parametrize("val", [2, [2,], [2, 3]])
  1117. def test_ones(val):
  1118. shp = tensor(val)
  1119. np_shp = np.array(val)
  1120. np.testing.assert_equal(F.ones(shp), np.ones(np_shp))
  1121. def test_assert_equal():
  1122. shape = (2, 3, 4, 5)
  1123. x = F.ones(shape, dtype=np.float32)
  1124. y = F.zeros(shape, dtype=np.float32) + 1.00001
  1125. z = F.utils._assert_equal(x, y)
  1126. def test_assert_not_equal():
  1127. shape = (2, 3, 4, 5)
  1128. x = F.ones(shape, dtype=np.float32)
  1129. y = F.zeros(shape, dtype=np.float32) + 1.1
  1130. with pytest.raises(RuntimeError):
  1131. z = F.utils._assert_equal(x, y)
  1132. def test_neg_axis():
  1133. x = tensor(np.random.normal(0, 1, (32, 5)))
  1134. y = F.argmax(x, axis=-1)
  1135. yy = F.argmax(x, axis=1)
  1136. np.testing.assert_equal(y.numpy(), yy.numpy())
  1137. y = F.argmax(x, axis=(-1, -2))
  1138. yy = F.argmax(x, axis=(0, 1))
  1139. np.testing.assert_equal(y.numpy(), yy.numpy())
  1140. y = F.argmin(x, axis=(-1, -2))
  1141. yy = F.argmin(x, axis=(0, 1))
  1142. np.testing.assert_equal(y.numpy(), yy.numpy())
  1143. def test_sliding_window():
  1144. N, C, H, W = 2, 3, 7, 8
  1145. inp = np.random.normal(size=(N, C, H, W))
  1146. ph, pw = 1, 2
  1147. sh, sw = 2, 1
  1148. wh, ww = 3, 2
  1149. dh, dw = 1, 3
  1150. s = lambda i, p, s, d, w: (i + p * 2 - (w - 1) * d - 1) // s + 1
  1151. inp_pad = np.zeros((N, C, H + ph * 2, W + pw * 2))
  1152. inp_pad[:, :, ph : H + ph, pw : W + pw] = inp
  1153. gt_out = np.empty(
  1154. (N, C, s(H, ph, sh, dh, wh), s(W, pw, sw, dw, ww), wh, ww), dtype=np.float32
  1155. )
  1156. for n, c, oh, ow in itertools.product(*map(range, gt_out.shape[:4])):
  1157. ih, iw = oh * sh, ow * sw
  1158. gt_out[n, c, oh, ow, :] = inp_pad[
  1159. n, c, ih : ih + (wh - 1) * dh + 1 : dh, iw : iw + (ww - 1) * dw + 1 : dw
  1160. ]
  1161. out = F.sliding_window(
  1162. tensor(inp), (wh, ww), padding=(ph, pw), stride=(sh, sw), dilation=(dh, dw)
  1163. )
  1164. np.testing.assert_equal(gt_out, out.numpy())
  1165. def test_sliding_window_transpose():
  1166. N, C, H, W = 2, 3, 7, 8
  1167. ph, pw = 1, 2
  1168. sh, sw = 2, 1
  1169. wh, ww = 3, 2
  1170. dh, dw = 1, 3
  1171. s = lambda i, p, s, d, w: (i + p * 2 - (w - 1) * d - 1) // s + 1
  1172. inp = np.random.normal(
  1173. size=(N, C, s(H, ph, sh, dh, wh), s(W, pw, sw, dw, ww), wh, ww)
  1174. ).astype(np.float32)
  1175. gt_out = np.zeros((N, C, H, W), dtype=np.float32)
  1176. for n, c in itertools.product(*map(range, inp.shape[:2])):
  1177. oh = 0
  1178. for ih in range(-ph, H + ph - dh * (wh - 1), sh):
  1179. ow = 0
  1180. for iw in range(-pw, W + pw - dw * (ww - 1), sw):
  1181. for kh, kw in itertools.product(*map(range, inp.shape[-2:])):
  1182. ih2 = ih + dh * kh
  1183. iw2 = iw + dw * kw
  1184. if ih2 >= 0 and ih2 < H and iw2 >= 0 and iw2 < W:
  1185. gt_out[n, c, ih2, iw2] += inp[n, c, oh, ow, kh, kw]
  1186. ow += 1
  1187. oh += 1
  1188. out = F.sliding_window_transpose(
  1189. tensor(inp),
  1190. (H, W),
  1191. (wh, ww),
  1192. padding=(ph, pw),
  1193. stride=(sh, sw),
  1194. dilation=(dh, dw),
  1195. )
  1196. np.testing.assert_equal(gt_out, out.numpy())
  1197. def test_pad():
  1198. src = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
  1199. dst = np.pad(src, ((2, 2), (2, 2)), "constant")
  1200. res = F.nn.pad(tensor(src), ((2, 2), (2, 2)), "CONSTANT")
  1201. np.testing.assert_allclose(res, dst, atol=1e-5)
  1202. dst = np.pad(src, ((2, 2), (2, 2)), "constant", constant_values=3)
  1203. res = F.nn.pad(tensor(src), ((2, 2), (2, 2)), "CONSTANT", constant_value=3)
  1204. np.testing.assert_allclose(res, dst, atol=1e-5)
  1205. dst = np.pad(src, ((2, 2), (2, 2)), "edge")
  1206. res = F.nn.pad(tensor(src), ((2, 2), (2, 2)), "EDGE")
  1207. np.testing.assert_allclose(res, dst, atol=1e-5)
  1208. dst = np.pad(src, ((2, 2), (2, 2)), "reflect")
  1209. res = F.nn.pad(tensor(src), ((2, 2), (2, 2)), "REFLECT")
  1210. np.testing.assert_allclose(res, dst, atol=1e-5)
  1211. def pixel_shuffle(data, r):
  1212. high_dim = data.shape[:-3]
  1213. data = data.reshape(-1, data.shape[-3], data.shape[-2], data.shape[-1])
  1214. inn, ic, ih, iw = data.shape
  1215. res = np.zeros((inn, int(ic / (r * r)), ih * r, iw * r))
  1216. for n in range(inn):
  1217. for c in range(ic):
  1218. for h in range(ih):
  1219. for w in range(iw):
  1220. res[
  1221. n,
  1222. int(c / r / r),
  1223. h * r + int((c % (r * r)) / r),
  1224. w * r + c % r,
  1225. ] = data[n, c, h, w]
  1226. if len(high_dim) > 0:
  1227. res = res.reshape((*high_dim, int(ic / r / r), ih * r, iw * r))
  1228. else:
  1229. res = res[0]
  1230. return res
  1231. def test_pixel_shuffle():
  1232. # ndim = 3
  1233. inp = np.arange(16 * 3 * 3).reshape(16, 3, 3)
  1234. out = F.pixel_shuffle(tensor(inp), upscale_factor=4)
  1235. golden = pixel_shuffle(inp, 4)
  1236. np.testing.assert_equal(out.numpy(), golden)
  1237. inp_float = np.float32(inp)
  1238. out = F.pixel_shuffle(tensor(inp_float), upscale_factor=2)
  1239. golden = pixel_shuffle(inp_float, 2)
  1240. np.testing.assert_equal(out.numpy(), golden)
  1241. # ndim = 4
  1242. inp = np.arange(3 * 18 * 3 * 3).reshape(3, 18, 3, 3)
  1243. out = F.pixel_shuffle(tensor(inp), upscale_factor=3)
  1244. golden = pixel_shuffle(inp, 3)
  1245. np.testing.assert_equal(out.numpy(), golden)
  1246. inp_float = np.float32(inp)
  1247. out = F.pixel_shuffle(tensor(inp_float), upscale_factor=3)
  1248. golden = pixel_shuffle(inp_float, 3)
  1249. np.testing.assert_equal(out.numpy(), golden)
  1250. # ndim = 5
  1251. inp = np.arange(5 * 3 * 20 * 3 * 4).reshape(5, 3, 20, 3, 4)
  1252. out = F.pixel_shuffle(tensor(inp), upscale_factor=2)
  1253. golden = pixel_shuffle(inp, 2)
  1254. np.testing.assert_equal(out.numpy(), golden)
  1255. inp_float = np.float32(inp)
  1256. out = F.pixel_shuffle(tensor(inp_float), upscale_factor=2)
  1257. golden = pixel_shuffle(inp_float, 2)
  1258. np.testing.assert_equal(out.numpy(), golden)
  1259. # ndim = 6
  1260. inp = np.arange(6 * 5 * 3 * 25 * 3 * 4).reshape(6, 5, 3, 25, 3, 4)
  1261. out = F.pixel_shuffle(tensor(inp), upscale_factor=5)
  1262. golden = pixel_shuffle(inp, 5)
  1263. np.testing.assert_equal(out.numpy(), golden)
  1264. inp_float = np.float32(inp)
  1265. out = F.pixel_shuffle(tensor(inp_float), upscale_factor=5)
  1266. golden = pixel_shuffle(inp_float, 5)
  1267. np.testing.assert_equal(out.numpy(), golden)
  1268. # ndim = 7
  1269. inp = np.arange(2 * 3 * 5 * 3 * 20 * 3 * 4).reshape(2, 3, 5, 3, 20, 3, 4)
  1270. out = F.pixel_shuffle(tensor(inp), upscale_factor=2)
  1271. golden = pixel_shuffle(inp, 2)
  1272. np.testing.assert_equal(out.numpy(), golden)
  1273. inp_float = np.float32(inp)
  1274. out = F.pixel_shuffle(tensor(inp_float), upscale_factor=2)
  1275. golden = pixel_shuffle(inp_float, 2)
  1276. np.testing.assert_equal(out.numpy(), golden)
  1277. @pytest.mark.parametrize("type", ["int32", "float32"])
  1278. @pytest.mark.parametrize("is_symbolic", [False, True])
  1279. def test_pixel_shuffle_symbolic(is_symbolic, type):
  1280. def fn(inp, upscale_factor):
  1281. return F.pixel_shuffle(inp, upscale_factor=upscale_factor)
  1282. if is_symbolic is not None:
  1283. fn = jit.trace(symbolic=is_symbolic)(fn)
  1284. inp = tensor(np.arange(3 * 4 * 5 * 5).reshape(3, 4, 5, 5).astype(type))
  1285. golden = pixel_shuffle(inp, 2)
  1286. for _ in range(3):
  1287. out = fn(inp, 2)
  1288. np.testing.assert_equal(out.numpy(), golden)
  1289. if is_symbolic is None:
  1290. break
  1291. def test_set_conv2d_config():
  1292. """check setting config by contextmanager is equal to manually converted result"""
  1293. config._compute_mode = "float32"
  1294. inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float16)
  1295. weight = tensor(np.random.randn(64, 3, 7, 7), dtype=np.float16)
  1296. config_out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)
  1297. config._compute_mode = "default"
  1298. with config._override(compute_mode="float32"):
  1299. context_out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1)
  1300. expected = F.conv2d(
  1301. inp, weight, None, (2, 2), (3, 3), (1, 1), 1, compute_mode="float32",
  1302. )
  1303. np.testing.assert_allclose(config_out.numpy(), expected.numpy())
  1304. np.testing.assert_allclose(context_out.numpy(), expected.numpy())
  1305. @pytest.mark.parametrize("stride", [(1, 1)])
  1306. @pytest.mark.parametrize("padding", [(1, 1)])
  1307. @pytest.mark.parametrize("dilation", [(1, 1)])
  1308. @pytest.mark.parametrize("ksize", [(3, 3)])
  1309. @pytest.mark.parametrize("groups", [1, 2])
  1310. def test_local_conv2d(stride, padding, dilation, ksize, groups):
  1311. batch_size, in_channels, out_channels = 2, 4, 8
  1312. input_height, input_width = 10, 10
  1313. output_height = (input_height + padding[0] * 2 - ksize[0]) // stride[0] + 1
  1314. output_width = (input_width + padding[1] * 2 - ksize[1]) // stride[1] + 1
  1315. def local_conv2d_np(data, weight, stride, padding, dialtion):
  1316. # naive calculation use numpy
  1317. # only test output_height == input_height, output_width == input_width
  1318. data = np.pad(data, ((0, 0), (0, 0), (1, 1), (1, 1)))
  1319. expected = np.zeros(
  1320. (batch_size, out_channels, output_height, output_width), dtype=np.float32,
  1321. )
  1322. ic_group_size = in_channels // groups
  1323. oc_group_size = out_channels // groups
  1324. for n, oc, oh, ow in itertools.product(
  1325. *map(range, [batch_size, out_channels, output_height, output_width])
  1326. ):
  1327. ih, iw = oh * stride[0], ow * stride[1]
  1328. g_id = oc // oc_group_size
  1329. expected[n, oc, ih, iw] = np.sum(
  1330. data[
  1331. n,
  1332. g_id * ic_group_size : (g_id + 1) * ic_group_size,
  1333. ih : ih + ksize[0],
  1334. iw : iw + ksize[1],
  1335. ]
  1336. * weight[g_id, oh, ow, :, :, :, oc % oc_group_size]
  1337. )
  1338. return expected
  1339. data = np.random.rand(batch_size, in_channels, input_height, input_width).astype(
  1340. "float32"
  1341. )
  1342. weight = np.random.rand(
  1343. groups,
  1344. output_height,
  1345. output_width,
  1346. in_channels // groups,
  1347. *ksize,
  1348. out_channels // groups,
  1349. ).astype("float32")
  1350. output = F.local_conv2d(
  1351. tensor(data),
  1352. tensor(weight),
  1353. None,
  1354. stride=stride,
  1355. padding=padding,
  1356. dilation=dilation,
  1357. )
  1358. ref = local_conv2d_np(data, weight, stride, padding, dilation)
  1359. np.testing.assert_almost_equal(output.numpy(), ref, 5)
  1360. def test_conv_transpose2d():
  1361. m = ConvTranspose2d(
  1362. 16, 33, (3, 5), output_padding=(1, 2), stride=(2, 3), padding=(4, 2)
  1363. )
  1364. @trace(symbolic=True)
  1365. def fwd(inp: Tensor):
  1366. return m(inp)
  1367. input = Tensor(np.random.rand(20, 16, 50, 100))
  1368. output = fwd(input)
  1369. output_shape = Tensor(output.shape)
  1370. np.testing.assert_equal(
  1371. output_shape.numpy(), np.array([20, 33, 94, 300], dtype=np.int32)
  1372. )
  1373. def test_conv_transpose3d():
  1374. m = ConvTranspose3d(
  1375. 16, 33, (3, 5, 2), output_padding=(2, 1, 1), stride=(3, 2, 2), padding=(0, 4, 2)
  1376. )
  1377. @trace(symbolic=True)
  1378. def fwd(inp: Tensor):
  1379. return m(inp)
  1380. input = Tensor(np.random.rand(20, 16, 10, 50, 100))
  1381. output = fwd(input)
  1382. output_shape = Tensor(output.shape)
  1383. np.testing.assert_equal(
  1384. output_shape.numpy(), np.array([20, 33, 32, 96, 197], dtype=np.int32)
  1385. )
  1386. @pytest.mark.skip(reason="pytest aborted")
  1387. def test_softmax():
  1388. def np_softmax(x):
  1389. return np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True)
  1390. data = (np.random.random(size=(1, 16, 224, 224)).astype(np.float32) - 0.5) * 100
  1391. desired = np_softmax(data[:, :3, 0, 0])
  1392. data = Tensor(data)
  1393. data = data[:, :3, 0, 0]
  1394. actual = F.softmax(data)
  1395. np.testing.assert_allclose(actual.numpy(), desired, rtol=1e-5)