
test_network_node.py 20 kB

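"""Serialization round-trip tests for MegEngine graph utilities.

Each test traces a small forward function with ``trace(symbolic=True,
capture_as_const=True)``, dumps it, reloads the dump through
``megengine.utils.network.Network``, dumps the loaded network again, and runs
the re-dumped model with ``GraphInference``, asserting that output values and
dtypes match the eager results (see ``check_pygraph_dump`` below).
"""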
import io
import os
import platform

import numpy as np
import pytest

import megengine.core.tensor.dtype as dtype
import megengine.core.tensor.megbrain_graph as G
import megengine.functional as F
import megengine.module as M
import megengine.random as rand
from megengine.core._imperative_rt.core2 import apply
from megengine.core._wrap import Device
from megengine.core.ops import builtin
from megengine.device import is_cuda_available
from megengine.distributed.helper import get_device_count_by_fork
from megengine.functional.external import tensorrt_runtime_opr
from megengine.jit.tracing import trace
from megengine.tensor import Tensor
from megengine.utils.comp_graph_tools import GraphInference
from megengine.utils.network import Network as Net
def check_pygraph_dump(trace_func, inp_data, expect_results):
    """Round-trip ``trace_func`` through dump -> Network.load -> dump and check
    that GraphInference on the re-dumped model reproduces the expected values
    and dtypes."""
    orig_model = io.BytesIO()
    inp_size = len(inp_data)
    out_size = len(expect_results)
    arg_names = ["arg_{}".format(i) for i in range(inp_size)]
    output_names = ["out_{}".format(i) for i in range(out_size)]
    trace_func.dump(
        orig_model,
        arg_names=arg_names,
        output_names=output_names,
        optimize_for_inference=False,
    )
    orig_model.seek(0)

    net = Net.load(orig_model)
    file = io.BytesIO()
    net.dump(file, optimize_for_inference=False)
    file.seek(0)

    graph = GraphInference(file)
    inp_dict = dict([(arg_names[i], inp_data[i].numpy()) for i in range(inp_size)])
    results = graph.run(inp_dict=inp_dict)

    for ind, tensor in enumerate(expect_results):
        np.testing.assert_equal(tensor.numpy(), results[output_names[ind]])
        assert tensor.dtype == results[output_names[ind]].dtype
def test_elemwise():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(x, y):
        z1 = x * y
        z2 = x + y
        z3 = z1 / z2
        z3 = z3 ** 3
        return z3

    x = Tensor([1.0, 2.0])
    y = Tensor([3.0, 5.0])
    result = fwd(x, y)
    check_pygraph_dump(fwd, [x, y], [result])


def test_reduce():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        x = data.sum(axis=2)
        x = x.mean(axis=1)
        return x

    data = Tensor(np.random.random((1, 32, 32)))
    result = fwd(data)
    check_pygraph_dump(fwd, [data], [result])


def test_typecvt():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        return data.astype(dtype.qint8(0.8))

    x = Tensor(np.random.random((2, 3)) * 255)
    result = fwd(x)
    check_pygraph_dump(fwd, [x], [result])


def test_matinv():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        return F.matinv(data)

    data = Tensor(np.random.random((5, 5)))
    result = fwd(data)
    check_pygraph_dump(fwd, [data], [result])


def test_matmul():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(data1, data2):
        return F.matmul(data1, data2)

    data1 = Tensor(np.random.random((32, 64)))
    data2 = Tensor(np.random.random((64, 16)))
    result = fwd(data1, data2)
    check_pygraph_dump(fwd, [data1, data2], [result])


def test_batchmatmul():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(x, y):
        return F.matmul(x, y)

    x = Tensor(np.random.random((3, 3, 5)))
    y = Tensor(np.random.random((3, 5, 3)))
    result = fwd(x, y)
    check_pygraph_dump(fwd, [x, y], [result])


def test_dot():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(x, y):
        return F.dot(x, y)

    x = Tensor([1.0, 2.0, 3.0])
    y = Tensor([3.0, 4.0, 5.0])
    result = fwd(x, y)
    check_pygraph_dump(fwd, [x, y], [result])


def test_svd():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        _, out, _ = F.svd(data)
        return out

    input = Tensor(np.random.random((1, 1, 3, 3)))
    result = fwd(input)
    check_pygraph_dump(fwd, [input], [result])
def test_conv():
    conv = M.Conv2d(3, 32, 3)

    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        return conv(data)

    data = Tensor(np.random.random((1, 3, 32, 32)))
    result = fwd(data)
    check_pygraph_dump(fwd, [data], [result])


def test_deformable_conv():
    if not is_cuda_available():
        return
    conv = M.DeformableConv2d(3, 32, 3)

    @trace(symbolic=True, capture_as_const=True)
    def fwd(data, offset, mask):
        return conv(data, offset, mask)

    data = Tensor(np.random.random((1, 3, 32, 32)))
    offset = Tensor(np.ones((32, 3 * 3 * 2, 30, 30)).astype("int32") * 5)
    mask = Tensor(np.ones((32, 3 * 3, 30, 30)).astype("int32"))
    out = fwd(data, offset, mask)
    check_pygraph_dump(fwd, [data, offset, mask], [out])


def test_convtranspose():
    deconv = M.ConvTranspose2d(32, 32, 3)

    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        return deconv(data)

    data = Tensor(np.random.random((1, 32, 32, 32)))
    result = fwd(data)
    check_pygraph_dump(fwd, [data], [result])


@pytest.mark.skip(reason="pytest aborted")
def test_grouplocal():
    n = M.LocalConv2d(3, 32, 32, 32, 3)

    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        return n(data)

    input = Tensor(np.random.random((1, 3, 32, 32)))
    result = fwd(input)
    check_pygraph_dump(fwd, [input], [result])
def test_pooling():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        out = F.max_pool2d(data, 2, 2)
        out = F.avg_pool2d(out, 2, 2)
        return out

    data = Tensor(np.random.random((1, 3, 64, 64)))
    result = fwd(data)
    check_pygraph_dump(fwd, [data], [result])


def test_adaptivepooling():
    pool1 = M.AdaptiveMaxPool2d((2, 2))
    pool2 = M.AdaptiveAvgPool2d((2, 2))

    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        out = pool1(data)
        out = pool2(out)
        return out

    input = Tensor(np.random.random((1, 3, 32, 32)))
    result = fwd(input)
    check_pygraph_dump(fwd, [input], [result])


def test_roipooling():
    inp = Tensor(np.random.random((1, 1, 128, 128)))
    rois = Tensor(np.random.random((4, 5)))

    @trace(symbolic=True, capture_as_const=True)
    def fwd(inp, rois):
        return F.vision.roi_pooling(inp, rois, (2, 2), scale=2.0)

    output = fwd(inp, rois)
    check_pygraph_dump(fwd, [inp, rois], [output])


def test_deformable_ps_roi_pooling():
    inp = Tensor(np.random.random((1, 256, 64, 64)).astype("float32"))
    rois = Tensor(np.random.random((1, 5)).astype("float32"))
    trans = Tensor(np.random.random((24, 2, 7, 7)).astype("float32"))

    pooled_h = 7
    pooled_w = 7
    sample_per_part = 4
    no_trans = False
    part_size = 7
    spatial_scale = 1.0 / 64
    trans_std = 0.1

    @trace(symbolic=True, capture_as_const=True)
    def fwd(inp, rois, trans):
        y = F.deformable_psroi_pooling(
            inp,
            rois,
            trans,
            no_trans,
            part_size,
            pooled_h,
            pooled_w,
            sample_per_part,
            spatial_scale,
            trans_std,
        )
        return y

    result = fwd(inp, rois, trans)
    check_pygraph_dump(fwd, [inp, rois, trans], [result])
@pytest.mark.skipif(
    get_device_count_by_fork("gpu") > 0,
    reason="does not support int8 when gpu compute capability less than 6.1",
)
def test_convbias():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(inp, weight, bias):
        return F.quantized.conv_bias_activation(
            inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="RELU"
        )

    inp = Tensor(np.random.random((1, 3, 64, 64)), dtype=dtype.qint8(scale=1.0))
    weight = Tensor(np.random.random((32, 3, 3, 3)), dtype=dtype.qint8(scale=1.0))
    bias = Tensor(np.random.random((1, 32, 1, 1)), dtype=dtype.qint32(scale=1.0))
    result = fwd(inp, weight, bias)
    check_pygraph_dump(fwd, [inp, weight, bias], [result])


def test_batch_convbias():
    if is_cuda_available():
        return

    @trace(symbolic=True, capture_as_const=True)
    def fwd(inp, weight, bias):
        return F.quantized.batch_conv_bias_activation(
            inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="RELU"
        )

    inp = Tensor(np.random.random((1, 3, 64, 64)), dtype=dtype.qint8(scale=1.0))
    weight = Tensor(np.random.random((1, 32, 3, 3, 3)), dtype=dtype.qint8(scale=1.0))
    bias = Tensor(np.random.random((1, 32, 1, 1)), dtype=dtype.qint32(scale=1.0))
    result = fwd(inp, weight, bias)
    check_pygraph_dump(fwd, [inp, weight, bias], [result])


def test_batchnorm():
    bn = M.BatchNorm2d(32)
    bn.eval()

    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        return bn(data)

    data = Tensor(np.random.random((1, 32, 32, 32)))
    result = fwd(data)
    check_pygraph_dump(fwd, [data], [result])
def test_roialign():
    inp = Tensor(np.random.randn(1, 1, 128, 128))
    rois = Tensor(np.random.random((4, 5)))

    @trace(symbolic=True, capture_as_const=True)
    def fwd(inp, rois):
        return F.vision.roi_align(inp, rois, (2, 2))

    output = fwd(inp, rois)
    check_pygraph_dump(fwd, [inp, rois], [output])


def test_warpperspective():
    inp_shape = (1, 1, 4, 4)
    x = Tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
    M_shape = (1, 3, 3)
    # M defines a translation: dst(1, 1, h, w) = rst(1, 1, h+1, w+1)
    M = Tensor(
        np.array(
            [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32
        ).reshape(M_shape)
    )

    @trace(symbolic=True, capture_as_const=True)
    def fwd(x, M):
        return F.vision.warp_perspective(x, M, (2, 2))

    result = fwd(x, M)
    check_pygraph_dump(fwd, [x, M], [result])


def test_warpaffine():
    inp_shape = (1, 3, 3, 3)
    x = Tensor(np.arange(27, dtype=np.float32).reshape(inp_shape))
    weightv = Tensor([[[1.26666667, 0.6, -83.33333333], [-0.33333333, 1, 66.66666667]]])

    @trace(symbolic=True, capture_as_const=True)
    def fwd(x, weightv):
        return F.vision.warp_affine(x, weightv, (2, 2), border_mode="WRAP")

    outp = fwd(x, weightv)
    check_pygraph_dump(fwd, [x, weightv], [outp])


def test_remap():
    inp_shape = (1, 1, 4, 4)
    inp = Tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
    map_xy_shape = (1, 2, 2, 2)
    map_xy = Tensor(
        np.array(
            [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [0.0, 1.0]]], dtype=np.float32
        ).reshape(map_xy_shape)
    )

    @trace(symbolic=True, capture_as_const=True)
    def fwd(inp, map_xy):
        return F.vision.remap(inp, map_xy)

    out = fwd(inp, map_xy)
    check_pygraph_dump(fwd, [inp, map_xy], [out])


def test_resize():
    x = Tensor(np.random.randn(10, 3, 32, 32))

    @trace(symbolic=True, capture_as_const=True)
    def fwd(x):
        return F.vision.interpolate(x, size=(16, 16), mode="BILINEAR")

    out = fwd(x)
    check_pygraph_dump(fwd, [x], [out])
def test_index_onehot():
    src = Tensor([[1.0, 2.0]])
    index = Tensor([0])

    @trace(symbolic=True, capture_as_const=True)
    def fwd(src, index):
        return F.indexing_one_hot(src, index)

    out = fwd(src, index)
    check_pygraph_dump(fwd, [src, index], [out])


def test_set_onehot():
    x = Tensor(np.arange(1, 4, dtype=np.int32))

    @trace(symbolic=True, capture_as_const=True)
    def fwd(x):
        return F.one_hot(x, num_classes=4)

    out = fwd(x)
    check_pygraph_dump(fwd, [x], [out])


def test_copy():
    x = Tensor([1, 2, 3])

    @trace(symbolic=True, capture_as_const=True)
    def fwd(x):
        return x.to("cpu0:0")

    o = fwd(x)
    check_pygraph_dump(fwd, [x], [o])


def test_argsort():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        return F.argsort(data, True)

    data = Tensor([1.0, 2.0, 3.0, 5.0])
    result = fwd(data)
    check_pygraph_dump(fwd, [data], [result])


def test_argmax_min():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        return F.argmax(data), F.argmin(data)

    data = Tensor(np.random.random((10, 10)))
    result = fwd(data)
    check_pygraph_dump(fwd, [data], result)


def test_condtake():
    mask = Tensor(np.array([[True, False], [False, True]], dtype=np.bool_))
    x = Tensor(np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32))

    @trace(symbolic=True, capture_as_const=True)
    def fwd(mask, x):
        v, index = F.cond_take(mask, x)
        return v, index

    v, index = fwd(mask, x)
    check_pygraph_dump(fwd, [mask, x], [v, index])


def test_topk():
    x = Tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32))

    @trace(symbolic=True, capture_as_const=True)
    def fwd(x):
        top, indices = F.topk(x, 5)
        return top, indices

    top, indices = fwd(x)
    check_pygraph_dump(fwd, [x], [top, indices])
def test_nvof():
    if not is_cuda_available():
        return
    src_shape = (4, 5, 224, 224, 4)
    src = np.random.randint(0, 255, src_shape).astype("uint8")
    src = Tensor(src)

    @trace(symbolic=True, capture_as_const=True)
    def fwd(src):
        return F.nn.nvof(src, precision=1)

    result = fwd(src)
    check_pygraph_dump(fwd, [src], [result])


def test_random():
    @trace(symbolic=True, capture_as_const=True)
    def fwd():
        x = rand.uniform(size=(2, 2))
        y = rand.normal(size=(1, 3, 3, 3))
        return x, y

    x, y = fwd()
    check_pygraph_dump(fwd, [], [x, y])


def test_tensor_gen():
    @trace(symbolic=True, capture_as_const=True)
    def fwd():
        a = F.linspace(3, 10, 3, device=Device("xpux").to_c())
        b = F.eye(3, device=Device("xpux").to_c())
        return a, b

    a, b = fwd()
    check_pygraph_dump(fwd, [], [a, b])


def test_getvarshape():
    op = builtin.GetVarShape(axis=1)

    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        return apply(op, data)[0]

    data = Tensor(np.random.random((1, 2, 3, 4)))
    result = fwd(data)
    check_pygraph_dump(fwd, [data], [result])


def test_concat():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(data1, data2):
        return F.concat([data1, data2], axis=1)

    x = Tensor(np.random.random((2, 3)))
    y = Tensor(np.random.random((2, 5)))
    result = fwd(x, y)
    check_pygraph_dump(fwd, [x, y], [result])


def test_broadcast():
    inp = Tensor([[1], [2], [3], [4]])

    @trace(symbolic=True, capture_as_const=True)
    def fwd(inp):
        return F.broadcast_to(inp, (4, 4))

    out = fwd(inp)
    check_pygraph_dump(fwd, [inp], [out])
def test_identity():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        return F.copy(data)

    data = Tensor([1.0, 2.0])
    result = fwd(data)
    check_pygraph_dump(fwd, [data], [result])


@pytest.mark.skip(reason="advance indexing trace error")
def test_nms():
    x = np.zeros((100, 4))
    np.random.seed(42)
    x[:, :2] = np.random.rand(100, 2) * 20
    x[:, 2:] = np.random.rand(100, 2) * 20 + 100
    scores = Tensor(np.random.rand(100))
    inp = Tensor(x)

    @trace(symbolic=True, capture_as_const=True)
    def fwd(inp, scores):
        return F.nn.nms(inp, scores, iou_thresh=0.7, max_output=3)

    result = fwd(inp, scores)
    check_pygraph_dump(fwd, [inp, scores], [result])


def test_dimshuffle():
    inp = Tensor([1, 2, 3, 4])

    @trace(symbolic=True, capture_as_const=True)
    def fwd(inp):
        return inp.T

    out = fwd(inp)
    check_pygraph_dump(fwd, [inp], [out])


def test_reshape():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        return data.reshape((1, 8))

    data = Tensor(np.random.random((1, 2, 2, 2)))
    result = fwd(data)
    check_pygraph_dump(fwd, [data], [result])


def test_add_remove_axis():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        x = F.expand_dims(data, [0, 0])
        y = F.squeeze(x, 0)
        return y

    data = Tensor([1.0, 2.0])
    result = fwd(data)
    check_pygraph_dump(fwd, [data], [result])
@pytest.mark.parametrize("mode", ["get", "set", "inc"])
def test_subtensor(mode):
    items = [[0, True, True, True, False], [1, False, False, False, True]]
    data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random(2))]
    if mode == "get":
        op = builtin.Subtensor(items)
        data = data[:1]
    if mode == "set":
        op = builtin.SetSubtensor(items)
    if mode == "inc":
        op = builtin.IncrSubtensor(items)
    tensors = [Tensor(0), Tensor(4), Tensor(2), Tensor(3)]

    @trace(symbolic=True, capture_as_const=True)
    def fwd(*tensors):
        return apply(op, *tensors)[0]

    result = fwd(*data, *tensors)
    check_pygraph_dump(fwd, data + tensors, [result])


@pytest.mark.parametrize("mode", ["get", "set", "inc"])
def test_advance_indexing(mode):
    items = [[0, False, False, False, True]]
    tensors = [Tensor([0, 4, 2])]
    data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random((3, 5)))]
    if mode == "get":
        op = builtin.IndexingMultiAxisVec(items)
        data = data[:1]
    if mode == "set":
        op = builtin.IndexingSetMultiAxisVec(items)
    if mode == "inc":
        op = builtin.IndexingIncrMultiAxisVec(items)

    @trace(symbolic=True, capture_as_const=True)
    def fwd(*tensors):
        return apply(op, *tensors)[0]

    result = fwd(*data, *tensors)
    check_pygraph_dump(fwd, data + tensors, [result])


@pytest.mark.parametrize("mode", ["get", "set", "inc"])
def test_mesh_indexing(mode):
    items = [[0, True, True, True, False], [1, False, False, False, True]]
    tensors = [Tensor(0), Tensor(5), Tensor(2), Tensor([1, 3])]
    data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random((3, 2)))]
    if mode == "get":
        op = builtin.IndexingMultiAxisVec(items)
        data = data[:1]
    if mode == "set":
        op = builtin.IndexingSetMultiAxisVec(items)
    if mode == "inc":
        op = builtin.IndexingIncrMultiAxisVec(items)

    @trace(symbolic=True, capture_as_const=True)
    def fwd(*tensors):
        return apply(op, *tensors)[0]

    result = fwd(*data, *tensors)
    check_pygraph_dump(fwd, data + tensors, [result])


@pytest.mark.parametrize("mode", ["get", "set", "inc"])
def test_batch_mesh_indexing(mode):
    items = [[1, False, False, False, True], [2, False, False, False, True]]
    tensors = [Tensor([[0, 2], [0, 2]]), Tensor([[0, 1, 2], [1, 2, 3]])]
    data = [Tensor(np.random.random((2, 3, 4))), Tensor(np.random.random((2, 2, 3)))]
    if mode == "get":
        op = builtin.BatchedMeshIndexing(items)
        data = data[:1]
    if mode == "set":
        op = builtin.BatchedSetMeshIndexing(items)
    if mode == "inc":
        op = builtin.BatchedIncrMeshIndexing(items)

    @trace(symbolic=True, capture_as_const=True)
    def fwd(*tensors):
        return apply(op, *tensors)[0]

    result = fwd(*data, *tensors)
    check_pygraph_dump(fwd, data + tensors, [result])
@pytest.mark.skip(reason="tmp skip")
def test_assert_equal():
    g = G.Graph()
    inp1 = g.make_h2d(dtype=np.float32, device="xpux")
    inp2 = g.make_h2d(dtype=np.float32, device="xpux")
    op = builtin.AssertEqual(maxerr=1e-5)
    out = G.apply_normal_varnode(op, inp1._node, inp2._node)[0]
    print(out)
    g.compile(out)

    file = io.BytesIO()
    out_model = G.dump_graph([out])
    file.write(out_model[0])
    file.seek(0)
    net = Net.load(file)

    dump_file = io.BytesIO()
    net.dump(dump_file)
    dump_file.seek(0)
    g = GraphInference(dump_file)
    g.run(np.array([1.0, 2.0]), np.array([1.0, 2.0]))


def test_elemwise_multitype():
    op = builtin.ElemwiseMultiType(mode="QADD", dtype=dtype.qint32(2.0))

    @trace(symbolic=True, capture_as_const=True)
    def fwd(x, y):
        return apply(op, x, y)[0]

    x = Tensor(np.random.random(10) * 10, dtype=dtype.qint8(2.0))
    y = Tensor(np.random.random(10) * 10, dtype=dtype.qint8(2.0))
    result = fwd(x, y)
    check_pygraph_dump(fwd, [x, y], [result])


def test_cvtcolor():
    inp = np.random.randn(3, 3, 3, 3).astype(np.float32)
    x = Tensor(inp)

    @trace(symbolic=True, capture_as_const=True)
    def fwd(inp):
        return F.vision.cvt_color(inp, mode="RGB2GRAY")

    result = fwd(x)
    check_pygraph_dump(fwd, [x], [result])
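If you want to run this file on its own rather than through the test suite, a minimal sketch is shown below. It assumes pytest is installed; the CUDA-only cases above already guard or skip themselves when no GPU is available.

if __name__ == "__main__":
    import sys

    # Run every test in this module verbosely; pytest.main returns an exit code.
    sys.exit(pytest.main([__file__, "-v"]))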
