You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_network_node.py 22 kB

Line numbers 1–759 (line-number gutter run together during extraction).
  1. import io
  2. import os
  3. import platform
  4. import numpy as np
  5. import pytest
  6. import megengine as mge
  7. import megengine.core.tensor.dtype as dtype
  8. import megengine.core.tensor.megbrain_graph as G
  9. import megengine.functional as F
  10. import megengine.module as M
  11. import megengine.random as rand
  12. from megengine.core._imperative_rt.core2 import apply
  13. from megengine.core._wrap import Device
  14. from megengine.core.ops import builtin
  15. from megengine.device import (
  16. get_cuda_compute_capability,
  17. get_device_count,
  18. is_cuda_available,
  19. )
  20. from megengine.functional.external import tensorrt_runtime_opr
  21. from megengine.jit.tracing import trace
  22. from megengine.tensor import Tensor
  23. from megengine.utils.comp_graph_tools import GraphInference
  24. from megengine.utils.network import Network as Net
  25. def check_pygraph_dump(trace_func, inp_data, expect_results, max_err=None):
  26. orig_model = io.BytesIO()
  27. inp_size = len(inp_data)
  28. out_size = len(expect_results)
  29. arg_names = ["arg_{}".format(i) for i in range(inp_size)]
  30. output_names = ["out_{}".format(i) for i in range(out_size)]
  31. trace_func.dump(
  32. orig_model,
  33. arg_names=arg_names,
  34. output_names=output_names,
  35. optimize_for_inference=False,
  36. )
  37. orig_model.seek(0)
  38. net = Net.load(orig_model)
  39. file = io.BytesIO()
  40. net.dump(file, optimize_for_inference=False)
  41. file.seek(0)
  42. graph = GraphInference(file)
  43. inp_dict = dict([(arg_names[i], inp_data[i].numpy()) for i in range(inp_size)])
  44. results = graph.run(inp_dict=inp_dict)
  45. for ind, tensor in enumerate(expect_results):
  46. if max_err:
  47. np.testing.assert_almost_equal(
  48. tensor.numpy(), results[output_names[ind]], max_err
  49. )
  50. else:
  51. np.testing.assert_equal(tensor.numpy(), results[output_names[ind]])
  52. assert tensor.dtype == results[output_names[ind]].dtype
  53. def test_elemwise():
  54. @trace(symbolic=True, capture_as_const=True)
  55. def fwd(x, y):
  56. z1 = x * y
  57. z2 = x + y
  58. z3 = z1 / z2
  59. z3 = z3 ** 3
  60. return z3
  61. x = Tensor([1.0, 2.0])
  62. y = Tensor([3.0, 5.0])
  63. result = fwd(x, y)
  64. check_pygraph_dump(fwd, [x, y], [result])
  65. def test_reduce():
  66. @trace(symbolic=True, capture_as_const=True)
  67. def fwd(data):
  68. x = data.sum(axis=2)
  69. x = x.mean(axis=1)
  70. return x
  71. data = Tensor(np.random.random((1, 32, 32)))
  72. result = fwd(data)
  73. check_pygraph_dump(fwd, [data], [result])
  74. def test_typecvt():
  75. @trace(symbolic=True, capture_as_const=True)
  76. def fwd(data):
  77. return data.astype(dtype.qint8(0.8))
  78. x = Tensor(np.random.random((2, 3)) * 255)
  79. result = fwd(x)
  80. check_pygraph_dump(fwd, [x], [result])
  81. def test_matinv():
  82. @trace(symbolic=True, capture_as_const=True)
  83. def fwd(data):
  84. return F.matinv(data)
  85. data = Tensor(np.random.random((5, 5)))
  86. result = fwd(data)
  87. check_pygraph_dump(fwd, [data], [result])
  88. @pytest.mark.parametrize(
  89. "benchmark_kernel, max_err", [(False, None), (True, 1e-5)],
  90. )
  91. def test_matmul(monkeypatch, benchmark_kernel, max_err):
  92. if get_device_count("gpu") == 0 and benchmark_kernel:
  93. return
  94. monkeypatch.setenv("MGE_FASTRUN_CACHE_TYPE", "MEMORY")
  95. old1, old2 = (
  96. mge.config.benchmark_kernel,
  97. mge.config.deterministic_kernel,
  98. )
  99. mge.config.benchmark_kernel = benchmark_kernel
  100. mge.config.deterministic_kernel = True
  101. @trace(symbolic=True, capture_as_const=True)
  102. def fwd(data1, data2):
  103. return F.matmul(data1, data2)
  104. data1 = Tensor(np.random.random((32, 64)))
  105. data2 = Tensor(np.random.random((64, 16)))
  106. result = fwd(data1, data2)
  107. check_pygraph_dump(fwd, [data1, data2], [result], max_err=max_err)
  108. mge.config.benchmark_kernel = old1
  109. mge.config.deterministic_kernel = old2
  110. monkeypatch.delenv("MGE_FASTRUN_CACHE_TYPE", raising=False)
  111. def test_batchmatmul():
  112. @trace(symbolic=True, capture_as_const=True)
  113. def fwd(x, y):
  114. return F.matmul(x, y)
  115. x = Tensor(np.random.random((3, 3, 5)))
  116. y = Tensor(np.random.random((3, 5, 3)))
  117. result = fwd(x, y)
  118. check_pygraph_dump(fwd, [x, y], [result])
  119. def test_dot():
  120. @trace(symbolic=True, capture_as_const=True)
  121. def fwd(x, y):
  122. return F.dot(x, y)
  123. x = Tensor([1.0, 2.0, 3.0])
  124. y = Tensor([3.0, 4.0, 5.0])
  125. result = fwd(x, y)
  126. check_pygraph_dump(fwd, [x, y], [result])
  127. def test_svd():
  128. @trace(symbolic=True, capture_as_const=True)
  129. def fwd(data):
  130. _, out, _ = F.svd(data)
  131. return out
  132. input = Tensor(np.random.random((1, 1, 3, 3)))
  133. result = fwd(input)
  134. check_pygraph_dump(fwd, [input], [result])
  135. def test_conv():
  136. conv = M.Conv2d(3, 32, 3)
  137. @trace(symbolic=True, capture_as_const=True)
  138. def fwd(data):
  139. return conv(data)
  140. data = Tensor(np.random.random((1, 3, 32, 32)))
  141. result = fwd(data)
  142. check_pygraph_dump(fwd, [data], [result])
  143. def test_deformable_conv():
  144. if not is_cuda_available():
  145. return
  146. conv = M.DeformableConv2d(3, 32, 3)
  147. @trace(symbolic=True, capture_as_const=True)
  148. def fwd(data, offset, mask):
  149. return conv(data, offset, mask)
  150. data = Tensor(np.random.random((1, 3, 32, 32)))
  151. offset = Tensor(np.ones((32, 3 * 3 * 2, 30, 30)).astype("int32") * 5)
  152. mask = Tensor(np.ones((32, 3 * 3, 30, 30)).astype("int32"))
  153. out = fwd(data, offset, mask)
  154. check_pygraph_dump(fwd, [data, offset, mask], [out])
  155. def test_convtranspose():
  156. deconv = M.ConvTranspose2d(32, 32, 3)
  157. @trace(symbolic=True, capture_as_const=True)
  158. def fwd(data):
  159. return deconv(data)
  160. data = Tensor(np.random.random((1, 32, 32, 32)))
  161. result = fwd(data)
  162. # cu111 has 1e-7 diff
  163. check_pygraph_dump(fwd, [data], [result], 5)
  164. @pytest.mark.skip(reason="pytest aborted")
  165. def test_grouplocal():
  166. n = M.LocalConv2d(3, 32, 32, 32, 3)
  167. @trace(symbolic=True, capture_as_const=True)
  168. def fwd(data):
  169. return n(data)
  170. input = Tensor(np.random.random((1, 3, 32, 32)))
  171. result = fwd(input)
  172. check_pygraph_dump(fwd, [input], [result])
  173. def test_pooling():
  174. @trace(symbolic=True, capture_as_const=True)
  175. def fwd(data):
  176. out = F.max_pool2d(data, 2, 2)
  177. out = F.avg_pool2d(out, 2, 2)
  178. return out
  179. data = Tensor(np.random.random((1, 3, 64, 64)))
  180. result = fwd(data)
  181. check_pygraph_dump(fwd, [data], [result])
  182. def test_adaptivepooling():
  183. pool1 = M.AdaptiveMaxPool2d((2, 2))
  184. pool2 = M.AdaptiveAvgPool2d((2, 2))
  185. @trace(symbolic=True, capture_as_const=True)
  186. def fwd(data):
  187. out = pool1(data)
  188. out = pool2(out)
  189. return out
  190. input = Tensor(np.random.random((1, 3, 32, 32)))
  191. result = fwd(input)
  192. check_pygraph_dump(fwd, [input], [result])
  193. def test_roipooling():
  194. inp = Tensor(np.random.random((1, 1, 128, 128)))
  195. rois = Tensor(np.random.random((4, 5)))
  196. @trace(symbolic=True, capture_as_const=True)
  197. def fwd(inp, rois):
  198. return F.vision.roi_pooling(inp, rois, (2, 2), scale=2.0)
  199. output = fwd(inp, rois)
  200. check_pygraph_dump(fwd, [inp, rois], [output])
  201. def test_deformable_ps_roi_pooling():
  202. inp = Tensor(np.random.random((1, 256, 64, 64)).astype("float32"))
  203. rois = Tensor(np.random.random((1, 5)).astype("float32"))
  204. trans = Tensor(np.random.random((24, 2, 7, 7)).astype("float32"))
  205. pooled_h = 7
  206. pooled_w = 7
  207. sample_per_part = 4
  208. no_trans = False
  209. part_size = 7
  210. spatial_scale = 1.0 / 64
  211. trans_std = 0.1
  212. @trace(symbolic=True, capture_as_const=True)
  213. def fwd(inp, rois, trans):
  214. y = F.deformable_psroi_pooling(
  215. inp,
  216. rois,
  217. trans,
  218. no_trans,
  219. part_size,
  220. pooled_h,
  221. pooled_w,
  222. sample_per_part,
  223. spatial_scale,
  224. trans_std,
  225. )
  226. return y
  227. result = fwd(inp, rois, trans)
  228. check_pygraph_dump(fwd, [inp, rois, trans], [result])
  229. @pytest.mark.skipif(
  230. get_device_count("gpu") > 0 and get_cuda_compute_capability(0) < 61,
  231. reason="does not support int8 when gpu compute capability less than 6.1",
  232. )
  233. def test_convbias():
  234. @trace(symbolic=True, capture_as_const=True)
  235. def fwd(inp, weight, bias):
  236. return F.quantized.conv_bias_activation(
  237. inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="relu"
  238. )
  239. inp = Tensor(np.random.random((1, 3, 64, 64)), dtype=dtype.qint8(scale=1.0))
  240. weight = Tensor(np.random.random((32, 3, 3, 3)), dtype=dtype.qint8(scale=1.0))
  241. bias = Tensor(np.random.random((1, 32, 1, 1)), dtype=dtype.qint32(scale=1.0))
  242. result = fwd(inp, weight, bias)
  243. check_pygraph_dump(fwd, [inp, weight, bias], [result])
  244. @pytest.mark.skip(reason="does not support int4 when cuda version is lower than 10.2")
  245. def test_conv_bias_int4():
  246. @trace(symbolic=True, capture_as_const=True)
  247. def fwd(inp, weight, bias):
  248. return F.quantized.conv_bias_activation(
  249. inp,
  250. weight,
  251. bias,
  252. dtype=dtype.quint4(scale=1.0, zero_point=0),
  253. nonlinear_mode="relu",
  254. )
  255. inp = Tensor(
  256. np.random.random((1, 3, 64, 64)), dtype=dtype.quint4(scale=1.0, zero_point=0)
  257. )
  258. weight = Tensor(np.random.random((32, 3, 3, 3)), dtype=dtype.qint4(scale=1.0))
  259. bias = Tensor(np.random.random((1, 32, 1, 1)), dtype=dtype.qint32(scale=1.0))
  260. result = fwd(inp, weight, bias)
  261. check_pygraph_dump(fwd, [inp, weight, bias], [result])
  262. def test_batch_convbias():
  263. if is_cuda_available():
  264. return
  265. @trace(symbolic=True, capture_as_const=True)
  266. def fwd(inp, weight, bias):
  267. return F.quantized.batch_conv_bias_activation(
  268. inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="relu"
  269. )
  270. inp = Tensor(np.random.random((1, 3, 64, 64)), dtype=dtype.qint8(scale=1.0))
  271. weight = Tensor(np.random.random((1, 32, 3, 3, 3)), dtype=dtype.qint8(scale=1.0))
  272. bias = Tensor(np.random.random((1, 32, 1, 1)), dtype=dtype.qint32(scale=1.0))
  273. result = fwd(inp, weight, bias)
  274. check_pygraph_dump(fwd, [inp, weight, bias], [result])
  275. def test_batchnorm():
  276. bn = M.BatchNorm2d(32)
  277. bn.eval()
  278. @trace(symbolic=True, capture_as_const=True)
  279. def fwd(data):
  280. return bn(data)
  281. data = Tensor(np.random.random((1, 32, 32, 32)))
  282. result = fwd(data)
  283. check_pygraph_dump(fwd, [data], [result])
  284. def test_roialign():
  285. inp = Tensor(np.random.randn(1, 1, 128, 128))
  286. rois = Tensor(np.random.random((4, 5)))
  287. @trace(symbolic=True, capture_as_const=True)
  288. def fwd(inp, rois):
  289. return F.vision.roi_align(inp, rois, (2, 2))
  290. output = fwd(inp, rois)
  291. check_pygraph_dump(fwd, [inp, rois], [output])
  292. def test_warpperspective():
  293. inp_shape = (1, 1, 4, 4)
  294. x = Tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
  295. M_shape = (1, 3, 3)
  296. # M defines a translation: dst(1, 1, h, w) = rst(1, 1, h+1, w+1)
  297. M = Tensor(
  298. np.array(
  299. [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32
  300. ).reshape(M_shape)
  301. )
  302. @trace(symbolic=True, capture_as_const=True)
  303. def fwd(x, M):
  304. return F.vision.warp_perspective(x, M, (2, 2))
  305. result = fwd(x, M)
  306. check_pygraph_dump(fwd, [x, M], [result])
  307. def test_warpaffine():
  308. inp_shape = (1, 3, 3, 3)
  309. x = Tensor(np.arange(27, dtype=np.float32).reshape(inp_shape))
  310. weightv = Tensor([[[1.26666667, 0.6, -83.33333333], [-0.33333333, 1, 66.66666667]]])
  311. @trace(symbolic=True, capture_as_const=True)
  312. def fwd(x, weightv):
  313. return F.vision.warp_affine(x, weightv, (2, 2), border_mode="wrap")
  314. outp = fwd(x, weightv)
  315. check_pygraph_dump(fwd, [x, weightv], [outp])
  316. def test_remap():
  317. inp_shape = (1, 1, 4, 4)
  318. inp = Tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
  319. map_xy_shape = (1, 2, 2, 2)
  320. map_xy = Tensor(
  321. np.array(
  322. [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [0.0, 1.0]]], dtype=np.float32
  323. ).reshape(map_xy_shape)
  324. )
  325. @trace(symbolic=True, capture_as_const=True)
  326. def fwd(inp, map_xy):
  327. return F.vision.remap(inp, map_xy)
  328. out = fwd(inp, map_xy)
  329. check_pygraph_dump(fwd, [inp, map_xy], [out])
  330. def test_resize():
  331. x = Tensor(np.random.randn(10, 3, 32, 32))
  332. @trace(symbolic=True, capture_as_const=True)
  333. def fwd(x):
  334. return F.vision.interpolate(x, size=(16, 16), mode="bilinear")
  335. out = fwd(x)
  336. check_pygraph_dump(fwd, [x], [out])
  337. def test_index_onehot():
  338. src = Tensor([[1.0, 2.0]])
  339. index = Tensor([0])
  340. @trace(symbolic=True, capture_as_const=True)
  341. def fwd(src, index):
  342. return F.indexing_one_hot(src, index)
  343. out = fwd(src, index)
  344. check_pygraph_dump(fwd, [src, index], [out])
  345. def test_set_onehot():
  346. x = Tensor(np.arange(1, 4, dtype=np.int32))
  347. @trace(symbolic=True, capture_as_const=True)
  348. def fwd(x):
  349. return F.one_hot(x, num_classes=4)
  350. out = fwd(x)
  351. check_pygraph_dump(fwd, [x], [out])
  352. def test_copy():
  353. x = Tensor([1, 2, 3])
  354. @trace(symbolic=True, capture_as_const=True)
  355. def fwd(x):
  356. return x.to("cpu0:0")
  357. o = fwd(x)
  358. check_pygraph_dump(fwd, [x], [o])
  359. def test_argsort():
  360. @trace(symbolic=True, capture_as_const=True)
  361. def fwd(data):
  362. return F.argsort(data, True)
  363. data = Tensor([1.0, 2.0, 3.0, 5.0])
  364. result = fwd(data)
  365. check_pygraph_dump(fwd, [data], [result])
  366. def test_argmax_min():
  367. @trace(symbolic=True, capture_as_const=True)
  368. def fwd(data):
  369. return F.argmax(data), F.argmin(data)
  370. data = Tensor(np.random.random((10, 10)))
  371. result = fwd(data)
  372. check_pygraph_dump(fwd, [data], result)
  373. def test_condtake():
  374. mask = Tensor(np.array([[True, False], [False, True]], dtype=np.bool_))
  375. x = Tensor(np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32))
  376. @trace(symbolic=True, capture_as_const=True)
  377. def fwd(mask, x):
  378. v, index = F.cond_take(mask, x)
  379. return v, index
  380. v, index = fwd(mask, x)
  381. check_pygraph_dump(fwd, [mask, x], [v, index])
  382. def test_topk():
  383. x = Tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32))
  384. @trace(symbolic=True, capture_as_const=True)
  385. def fwd(x):
  386. top, indices = F.topk(x, 5)
  387. return top, indices
  388. top, indices = fwd(x)
  389. check_pygraph_dump(fwd, [x], [top, indices])
  390. def test_random():
  391. @trace(symbolic=True, capture_as_const=True)
  392. def fwd():
  393. x = rand.uniform(size=(2, 2))
  394. y = rand.normal(size=(1, 3, 3, 3))
  395. return x, y
  396. x, y = fwd()
  397. check_pygraph_dump(fwd, [], [x, y])
  398. def test_tensor_gen():
  399. @trace(symbolic=True, capture_as_const=True)
  400. def fwd():
  401. a = F.linspace(3, 10, 3, device=Device("xpux").to_c())
  402. b = F.eye(3, device=Device("xpux").to_c())
  403. return a, b
  404. a, b = fwd()
  405. check_pygraph_dump(fwd, [], [a, b])
  406. def test_getvarshape():
  407. op = builtin.GetVarShape(axis=1)
  408. @trace(symbolic=True, capture_as_const=True)
  409. def fwd(data):
  410. return apply(op, data)[0]
  411. data = Tensor(np.random.random((1, 2, 3, 4)))
  412. result = fwd(data)
  413. check_pygraph_dump(fwd, [data], [result])
  414. def test_concat():
  415. @trace(symbolic=True, capture_as_const=True)
  416. def fwd(data1, data2):
  417. return F.concat([data1, data2], axis=1)
  418. x = Tensor(np.random.random((2, 3)))
  419. y = Tensor(np.random.random((2, 5)))
  420. result = fwd(x, y)
  421. check_pygraph_dump(fwd, [x, y], [result])
  422. def test_broadcast():
  423. inp = Tensor([[1], [2], [3], [4]])
  424. @trace(symbolic=True, capture_as_const=True)
  425. def fwd(inp):
  426. return F.broadcast_to(inp, (4, 4))
  427. out = fwd(inp)
  428. check_pygraph_dump(fwd, [inp], [out])
  429. def test_identity():
  430. @trace(symbolic=True, capture_as_const=True)
  431. def fwd(data):
  432. return F.copy(data)
  433. data = Tensor([1.0, 2.0])
  434. result = fwd(data)
  435. check_pygraph_dump(fwd, [data], [result])
  436. @pytest.mark.skip(reason="advance indexing trace error")
  437. def test_nms():
  438. x = np.zeros((100, 4))
  439. np.random.seed(42)
  440. x[:, :2] = np.random.rand(100, 2) * 20
  441. x[:, 2:] = np.random.rand(100, 2) * 20 + 100
  442. scores = Tensor(np.random.rand(100))
  443. inp = Tensor(x)
  444. @trace(symbolic=True, capture_as_const=True)
  445. def fwd(inp, scores):
  446. return F.nn.nms(inp, scores, iou_thresh=0.7, max_output=3)
  447. result = fwd(inp, scores)
  448. check_pygraph_dump(fwd, [inp, scores], [result])
  449. def test_dimshuffle():
  450. inp = Tensor([1, 2, 3, 4])
  451. @trace(symbolic=True, capture_as_const=True)
  452. def fwd(inp):
  453. return inp.T
  454. out = fwd(inp)
  455. check_pygraph_dump(fwd, [inp], [out])
  456. def test_reshape():
  457. @trace(symbolic=True, capture_as_const=True)
  458. def fwd(data):
  459. return data.reshape((1, 8))
  460. data = Tensor(np.random.random((1, 2, 2, 2)))
  461. result = fwd(data)
  462. check_pygraph_dump(fwd, [data], [result])
  463. def test_add_remove_axis():
  464. @trace(symbolic=True, capture_as_const=True)
  465. def fwd(data):
  466. x = F.expand_dims(data, [0, 0])
  467. y = F.squeeze(x, 0)
  468. return y
  469. data = Tensor([1.0, 2.0])
  470. result = fwd(data)
  471. check_pygraph_dump(fwd, [data], [result])
  472. @pytest.mark.parametrize("mode", ["get", "set", "inc"])
  473. def test_subtensor(mode):
  474. items = [[0, True, True, True, False], [1, False, False, False, True]]
  475. data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random(2))]
  476. if mode == "get":
  477. op = builtin.Subtensor(items)
  478. data = data[:1]
  479. if mode == "set":
  480. op = builtin.SetSubtensor(items)
  481. if mode == "inc":
  482. op = builtin.IncrSubtensor(items)
  483. tensors = [Tensor(0), Tensor(4), Tensor(2), Tensor(3)]
  484. @trace(symbolic=True, capture_as_const=True)
  485. def fwd(*tensors):
  486. return apply(op, *tensors)[0]
  487. result = fwd(*data, *tensors)
  488. check_pygraph_dump(fwd, data + tensors, [result])
  489. @pytest.mark.parametrize("mode", ["get", "set", "inc"])
  490. def test_advance_indexing(mode):
  491. items = [[0, False, False, False, True]]
  492. tensors = [Tensor([0, 4, 2])]
  493. data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random((3, 5)))]
  494. if mode == "get":
  495. op = builtin.IndexingMultiAxisVec(items)
  496. data = data[:1]
  497. if mode == "set":
  498. op = builtin.IndexingSetMultiAxisVec(items)
  499. if mode == "inc":
  500. op = builtin.IndexingIncrMultiAxisVec(items)
  501. @trace(symbolic=True, capture_as_const=True)
  502. def fwd(*tensors):
  503. return apply(op, *tensors)[0]
  504. result = fwd(*data, *tensors)
  505. check_pygraph_dump(fwd, data + tensors, [result])
  506. @pytest.mark.parametrize("mode", ["get", "set", "inc"])
  507. def test_mesh_indexing(mode):
  508. items = [[0, True, True, True, False], [1, False, False, False, True]]
  509. tensors = [Tensor(0), Tensor(5), Tensor(2), Tensor([1, 3])]
  510. data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random((3, 2)))]
  511. if mode == "get":
  512. op = builtin.IndexingMultiAxisVec(items)
  513. data = data[:1]
  514. if mode == "set":
  515. op = builtin.IndexingSetMultiAxisVec(items)
  516. if mode == "inc":
  517. op = builtin.IndexingIncrMultiAxisVec(items)
  518. @trace(symbolic=True, capture_as_const=True)
  519. def fwd(*tensors):
  520. return apply(op, *tensors)[0]
  521. result = fwd(*data, *tensors)
  522. check_pygraph_dump(fwd, data + tensors, [result])
  523. @pytest.mark.parametrize("mode", ["get", "set", "inc"])
  524. def test_batch_mesh_indexing(mode):
  525. items = [[1, False, False, False, True], [2, False, False, False, True]]
  526. tensors = [Tensor([[0, 2], [0, 2]]), Tensor([[0, 1, 2], [1, 2, 3]])]
  527. data = [Tensor(np.random.random((2, 3, 4))), Tensor(np.random.random((2, 2, 3)))]
  528. if mode == "get":
  529. op = builtin.BatchedMeshIndexing(items)
  530. data = data[:1]
  531. if mode == "set":
  532. op = builtin.BatchedSetMeshIndexing(items)
  533. if mode == "inc":
  534. op = builtin.BatchedIncrMeshIndexing(items)
  535. @trace(symbolic=True, capture_as_const=True)
  536. def fwd(*tensors):
  537. return apply(op, *tensors)[0]
  538. result = fwd(*data, *tensors)
  539. check_pygraph_dump(fwd, data + tensors, [result])
  540. @pytest.mark.skip(reason="tmp skip")
  541. def test_assert_equal():
  542. g = G.Graph()
  543. inp1 = g.make_h2d(dtype=np.float32, device="xpux")
  544. inp2 = g.make_h2d(dtype=np.float32, device="xpux")
  545. op = builtin.AssertEqual(maxerr=1e-5)
  546. out = G.apply_normal_varnode(op, inp1._node, inp2._node)[0]
  547. g.compile(out)
  548. file = io.BytesIO()
  549. out_model = G.dump_graph([out])
  550. file.write(out_model[0])
  551. file.seek(0)
  552. net = Net.load(file)
  553. dump_file = io.BytesIO()
  554. net.dump(dump_file)
  555. dump_file.seek(0)
  556. g = GraphInference(dump_file)
  557. g.run(np.array([1.0, 2.0]), np.array([1.0, 2.0]))
  558. def test_elemwise_multitype():
  559. op = builtin.ElemwiseMultiType(mode="qadd", dtype=dtype.qint32(2.0))
  560. @trace(symbolic=True, capture_as_const=True)
  561. def fwd(x, y):
  562. return apply(op, x, y)[0]
  563. x = Tensor(np.random.random(10) * 10, dtype=dtype.qint8(2.0))
  564. y = Tensor(np.random.random(10) * 10, dtype=dtype.qint8(2.0))
  565. result = fwd(x, y)
  566. check_pygraph_dump(fwd, [x, y], [result])
  567. def test_cvtcolor():
  568. inp = np.random.randn(3, 3, 3, 3).astype(np.float32)
  569. x = Tensor(inp)
  570. @trace(symbolic=True, capture_as_const=True)
  571. def fwd(inp):
  572. return F.vision.cvt_color(inp, mode="RGB2GRAY")
  573. result = fwd(x)
  574. check_pygraph_dump(fwd, [x], [result])