You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_tensor.py 22 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import os
  10. import platform
  11. import numpy as np
  12. import pytest
  13. from utils import get_var_value, make_tensor, opr_test
  14. import megengine.functional as F
  15. from megengine import tensor
  16. from megengine.core._trace_option import use_symbolic_shape
  17. from megengine.core.tensor import megbrain_graph as G
  18. from megengine.core.tensor.utils import astensor1d
  19. from megengine.distributed.helper import get_device_count_by_fork
  20. from megengine.jit import trace
  21. from megengine.utils.network import Network, set_symbolic_shape
  22. from megengine.utils.network_node import VarNode
  23. def test_eye():
  24. dtype = np.float32
  25. cases = [{"input": [10, 20]}, {"input": [30]}]
  26. for case in cases:
  27. np.testing.assert_allclose(
  28. F.eye(case["input"], dtype=dtype).numpy(),
  29. np.eye(*case["input"]).astype(dtype),
  30. )
  31. np.testing.assert_allclose(
  32. F.eye(*case["input"], dtype=dtype).numpy(),
  33. np.eye(*case["input"]).astype(dtype),
  34. )
  35. np.testing.assert_allclose(
  36. F.eye(tensor(case["input"]), dtype=dtype).numpy(),
  37. np.eye(*case["input"]).astype(dtype),
  38. )
  39. @pytest.mark.parametrize("is_varnode", [True, False])
  40. def test_concat(is_varnode):
  41. if is_varnode:
  42. network = Network()
  43. else:
  44. network = None
  45. def get_data_shape(length: int):
  46. return (length, 2, 3)
  47. data1 = np.random.random(get_data_shape(5)).astype("float32")
  48. data2 = np.random.random(get_data_shape(6)).astype("float32")
  49. data3 = np.random.random(get_data_shape(7)).astype("float32")
  50. def run(data1, data2):
  51. return F.concat([data1, data2])
  52. cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]
  53. opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network)
  54. @pytest.mark.parametrize("is_varnode", [True, False])
  55. def test_condtake(is_varnode):
  56. if is_varnode:
  57. network = Network()
  58. else:
  59. network = None
  60. x = np.array([[1, 2, 3], [4, 5, 6]]).astype("float32")
  61. y = np.array([[True, False, True], [False, True, True]])
  62. xx = make_tensor(x, network)
  63. yy = make_tensor(y, network)
  64. val, idx = F.cond_take(yy, xx)
  65. if is_varnode:
  66. np.testing.assert_equal(get_var_value(val), x[y])
  67. np.testing.assert_equal(get_var_value(idx), np.where(y.reshape(-1))[0])
  68. else:
  69. np.testing.assert_equal(val.numpy(), x[y])
  70. np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])
  71. @pytest.mark.parametrize("is_varnode", [True, False])
  72. def test_concat_device(is_varnode):
  73. if is_varnode:
  74. network = Network()
  75. else:
  76. network = None
  77. data1 = make_tensor(np.random.random((3, 2, 2)).astype("float32"), network, "cpu0")
  78. data2 = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu1")
  79. out = F.concat([data1, data2], device="cpu0")
  80. assert str(out.device).split(":")[0] == "cpu0"
  81. @pytest.mark.parametrize("is_varnode", [True, False])
  82. def test_stack(is_varnode):
  83. if is_varnode:
  84. network = Network()
  85. else:
  86. network = None
  87. data1 = np.random.random((3, 2, 2)).astype("float32")
  88. data2 = np.random.random((3, 2, 2)).astype("float32")
  89. data3 = np.random.random((3, 2, 2)).astype("float32")
  90. cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]
  91. for ai in range(3):
  92. def run(data1, data2):
  93. return F.stack([data1, data2], axis=ai)
  94. opr_test(
  95. cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai), network=network
  96. )
  97. @pytest.mark.parametrize("is_varnode", [True, False])
  98. def test_split(is_varnode):
  99. if is_varnode:
  100. network = Network()
  101. saved_symbolic_shape = set_symbolic_shape(False)
  102. else:
  103. network = None
  104. data = np.random.random((2, 3, 4, 5)).astype(np.float32)
  105. inp = make_tensor(data, network)
  106. mge_out0 = F.split(inp, 2, axis=3)
  107. mge_out1 = F.split(inp, [3], axis=3)
  108. np_out = np.split(data, [3, 5], axis=3)
  109. assert len(mge_out0) == 2
  110. assert len(mge_out1) == 2
  111. np.testing.assert_equal(mge_out0[0].numpy(), np_out[0])
  112. np.testing.assert_equal(mge_out1[0].numpy(), np_out[0])
  113. np.testing.assert_equal(mge_out0[1].numpy(), np_out[1])
  114. np.testing.assert_equal(mge_out1[1].numpy(), np_out[1])
  115. try:
  116. F.split(inp, 4)
  117. assert False
  118. except ValueError as e:
  119. pass
  120. try:
  121. F.split(inp, [3, 3, 5], axis=3)
  122. assert False
  123. except ValueError as e:
  124. assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]"
  125. if is_varnode:
  126. set_symbolic_shape(saved_symbolic_shape)
  127. @pytest.mark.parametrize("is_varnode", [True, False])
  128. def test_reshape(is_varnode):
  129. if is_varnode:
  130. network = Network()
  131. else:
  132. network = None
  133. x = np.arange(6, dtype="float32")
  134. xx = make_tensor(x, network)
  135. y = x.reshape(1, 2, 3)
  136. for shape in [
  137. (1, 2, 3),
  138. (1, -1, 3),
  139. (1, make_tensor(-1, network), 3),
  140. np.array([1, -1, 3], dtype="int32"),
  141. make_tensor([1, -1, 3], network),
  142. ]:
  143. yy = F.reshape(xx, shape)
  144. np.testing.assert_equal(yy.numpy(), y)
  145. @pytest.mark.parametrize("is_trace", [True, False])
  146. def test_reshape_on_empty_tensor(is_trace):
  147. input1_shape = (100, 0, 1)
  148. output1_shape = (100, 0, 10)
  149. data1 = tensor(np.random.random(input1_shape).astype(np.float32))
  150. input2_shape = (10, 0)
  151. output2_shape = (0,)
  152. data2 = tensor(np.random.random(input2_shape).astype(np.float32))
  153. input3_shape = (10, 0, 10)
  154. output3_shape = (0, 1, 2, 3)
  155. data3 = tensor(np.random.random(input3_shape).astype(np.float32))
  156. def comp(out, target_shp):
  157. assert out._tuple_shape == target_shp
  158. def func(x, shp):
  159. return F.reshape(x, shp)
  160. cases = [
  161. [data1, output1_shape],
  162. [data2, output2_shape],
  163. [data3, output3_shape],
  164. ]
  165. def test(func, inp, comp, target_shp):
  166. out = func(inp, target_shp)
  167. comp(out, target_shp)
  168. if is_trace:
  169. for symbolic in [False, True]:
  170. for inp, target_shp in cases:
  171. func_traced = trace(symbolic=symbolic)(func)
  172. test(func_traced, inp, comp, target_shp)
  173. test(func_traced, inp, comp, target_shp)
  174. test(func_traced, inp, comp, target_shp)
  175. else:
  176. for inp, target_shp in cases:
  177. test(func, inp, comp, target_shp)
  178. @pytest.mark.parametrize("is_varnode", [True, False])
  179. def test_reshape_shape_inference(is_varnode):
  180. if is_varnode:
  181. network = Network()
  182. saved_symbolic_shape = set_symbolic_shape(False)
  183. else:
  184. network = None
  185. x_shape_known = make_tensor([1, 2, 3, 4], network)
  186. x_shape_unknown = F.broadcast_to(
  187. make_tensor([1.0], network), shape=make_tensor([1, 1, 1, 1], network).sum()
  188. )
  189. tshp_unknown = astensor1d(
  190. (make_tensor([2], network), make_tensor([2], network)), x_shape_known
  191. )
  192. tshp_known = astensor1d((2, 2), x_shape_known)
  193. tshp_known_unspec = astensor1d((2, -1), x_shape_known)
  194. def check_shape(output, target):
  195. source = output.shape
  196. if isinstance(source, tensor):
  197. source = source.numpy()
  198. np.testing.assert_equal(source, target)
  199. def func(x, target_shape):
  200. return x.reshape(target_shape)
  201. cases = [
  202. {"input": [x_shape_known, tshp_unknown], "output": [(2, 2),]},
  203. {"input": [x_shape_unknown, tshp_unknown], "output": [(2, 2),]},
  204. {"input": [x_shape_known, tshp_known], "output": [(2, 2),]},
  205. {"input": [x_shape_known, tshp_known_unspec], "output": [(2, 2),]},
  206. {"input": [x_shape_unknown, tshp_known], "output": [(2, 2),]},
  207. {"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]},
  208. ]
  209. opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network)
  210. if is_varnode:
  211. set_symbolic_shape(saved_symbolic_shape)
  212. @pytest.mark.parametrize("is_varnode", [True, False])
  213. def test_squeeze(is_varnode):
  214. if is_varnode:
  215. network = Network()
  216. saved_symbolic_shape = set_symbolic_shape(False)
  217. else:
  218. network = None
  219. x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1)
  220. xx = make_tensor(x, network)
  221. for axis in [None, 3, -4, (3, -4)]:
  222. y = np.squeeze(x, axis)
  223. yy = F.squeeze(xx, axis)
  224. np.testing.assert_equal(y, yy.numpy())
  225. if is_varnode:
  226. set_symbolic_shape(saved_symbolic_shape)
  227. @pytest.mark.parametrize("is_varnode", [True, False])
  228. def test_expand_dims(is_varnode):
  229. if is_varnode:
  230. network = Network()
  231. else:
  232. network = None
  233. x = np.arange(6, dtype="float32").reshape(2, 3)
  234. xx = make_tensor(x, network)
  235. for axis in [2, -3, (3, -4), (1, -4)]:
  236. y = np.expand_dims(x, axis)
  237. yy = F.expand_dims(xx, axis)
  238. np.testing.assert_equal(y, yy.numpy())
  239. def test_expand_dims_for_scalar():
  240. x = np.array(1, dtype="float32")
  241. xx = make_tensor(x, None)
  242. for axis in [0, -1, (0, 1), (-1, -2), (0, -1)]:
  243. y = np.expand_dims(x, axis)
  244. yy = F.expand_dims(xx, axis)
  245. np.testing.assert_equal(y, yy.numpy())
  246. for axis in [1, -2, (1, 2), (-2, -3)]:
  247. np.testing.assert_raises(np.AxisError, np.expand_dims, x, axis)
  248. np.testing.assert_raises(AssertionError, F.expand_dims, xx, axis)
  249. @pytest.mark.parametrize("is_varnode", [True, False])
  250. def test_elemwise_dtype_promotion(is_varnode):
  251. if is_varnode:
  252. network = Network()
  253. else:
  254. network = None
  255. x = np.random.rand(2, 3).astype("float32")
  256. y = np.random.rand(1, 3).astype("float16")
  257. xx = make_tensor(x, network)
  258. yy = make_tensor(y, network)
  259. z = xx * yy
  260. np.testing.assert_equal(z.numpy(), x * y)
  261. z = xx + y
  262. np.testing.assert_equal(z.numpy(), x + y)
  263. z = x - yy
  264. np.testing.assert_equal(z.numpy(), x - y)
  265. @pytest.mark.parametrize("is_varnode", [True, False])
  266. def test_linspace(is_varnode):
  267. if is_varnode:
  268. network = Network()
  269. else:
  270. network = None
  271. cases = [
  272. {"input": [1, 9, 9]},
  273. {"input": [3, 10, 8]},
  274. ]
  275. opr_test(
  276. cases,
  277. F.linspace,
  278. ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
  279. network=network,
  280. )
  281. cases = [
  282. {"input": [9, 1, 9]},
  283. {"input": [10, 3, 8]},
  284. ]
  285. opr_test(
  286. cases,
  287. F.linspace,
  288. ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
  289. network=network,
  290. )
  291. cases = [
  292. {"input": [1, make_tensor(9, network), 9]},
  293. {"input": [make_tensor(1, network), 9, make_tensor(9, network)]},
  294. ]
  295. opr_test(
  296. cases,
  297. F.linspace,
  298. ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32),
  299. network=network,
  300. )
  301. @pytest.mark.parametrize("is_varnode", [True, False])
  302. def test_arange(is_varnode):
  303. if is_varnode:
  304. network = Network()
  305. else:
  306. network = None
  307. cases = [
  308. {"input": [1, 9, 1]},
  309. {"input": [2, 10, 2]},
  310. ]
  311. opr_test(
  312. cases,
  313. F.arange,
  314. ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
  315. network=network,
  316. )
  317. cases = [
  318. {"input": [9, 1, -1]},
  319. {"input": [10, 2, -2]},
  320. ]
  321. opr_test(
  322. cases,
  323. F.arange,
  324. ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
  325. network=network,
  326. )
  327. cases = [
  328. {"input": [9.3, 1.2, -0.5]},
  329. {"input": [10.3, 2.1, -1.7]},
  330. ]
  331. opr_test(
  332. cases,
  333. F.arange,
  334. ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
  335. network=network,
  336. )
  337. @pytest.mark.parametrize("is_varnode", [True, False])
  338. def test_round(is_varnode):
  339. if is_varnode:
  340. network = Network()
  341. else:
  342. network = None
  343. data1_shape = (15,)
  344. data2_shape = (25,)
  345. data1 = np.random.random(data1_shape).astype(np.float32)
  346. data2 = np.random.random(data2_shape).astype(np.float32)
  347. cases = [{"input": data1}, {"input": data2}]
  348. opr_test(cases, F.round, ref_fn=np.round, network=network)
  349. @pytest.mark.parametrize("is_varnode", [True, False])
  350. def test_flatten(is_varnode):
  351. if is_varnode:
  352. network = Network()
  353. else:
  354. network = None
  355. data0_shape = (2, 3, 4, 5)
  356. data1_shape = (4, 5, 6, 7)
  357. data0 = np.random.random(data0_shape).astype(np.float32)
  358. data1 = np.random.random(data1_shape).astype(np.float32)
  359. def compare_fn(x, y):
  360. assert x._tuple_shape[0] == y
  361. output0 = (2 * 3 * 4 * 5,)
  362. output1 = (4 * 5 * 6 * 7,)
  363. cases = [
  364. {"input": data0, "output": output0},
  365. {"input": data1, "output": output1},
  366. ]
  367. opr_test(cases, F.flatten, compare_fn=compare_fn, network=network)
  368. output0 = (2, 3 * 4 * 5)
  369. output1 = (4, 5 * 6 * 7)
  370. cases = [
  371. {"input": data0, "output": output0},
  372. {"input": data1, "output": output1},
  373. ]
  374. opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, network=network)
  375. output0 = (2, 3, 4 * 5)
  376. output1 = (4, 5, 6 * 7)
  377. cases = [
  378. {"input": data0, "output": output0},
  379. {"input": data1, "output": output1},
  380. ]
  381. opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2, network=network)
  382. output0 = (2, 3 * 4, 5)
  383. output1 = (4, 5 * 6, 7)
  384. cases = [
  385. {"input": data0, "output": output0},
  386. {"input": data1, "output": output1},
  387. ]
  388. opr_test(
  389. cases,
  390. F.flatten,
  391. compare_fn=compare_fn,
  392. start_axis=1,
  393. end_axis=2,
  394. network=network,
  395. )
  396. @pytest.mark.parametrize("is_varnode", [True, False])
  397. def test_broadcast(is_varnode):
  398. if is_varnode:
  399. network = Network()
  400. else:
  401. network = None
  402. input1_shape = (20, 30)
  403. output1_shape = (30, 20, 30)
  404. data1 = np.random.random(input1_shape).astype(np.float32)
  405. input2_shape = (10, 1)
  406. output2_shape = (20, 10, 20)
  407. data2 = np.random.random(input2_shape).astype(np.float32)
  408. input3_shape = (10, 10)
  409. output3_shape = (10, 10)
  410. data3 = np.random.random(input3_shape).astype(np.float32)
  411. def compare_fn(x, y):
  412. assert x._tuple_shape[0] == y
  413. cases = [
  414. {"input": [data1, output1_shape], "output": output1_shape},
  415. {"input": [data2, output2_shape], "output": output2_shape},
  416. {"input": [data3, output3_shape], "output": output3_shape},
  417. ]
  418. opr_test(cases, F.broadcast_to, compare_fn=compare_fn, network=network)
  419. x = F.ones((2, 1, 3))
  420. with pytest.raises(RuntimeError):
  421. F.broadcast_to(x, (2, 3, 4))
  422. with pytest.raises(RuntimeError):
  423. F.broadcast_to(x, (4, 1, 3))
  424. with pytest.raises(RuntimeError):
  425. F.broadcast_to(x, (1, 3))
  426. @pytest.mark.parametrize("is_trace", [True, False])
  427. def test_broadcast_on_empty_tensor(is_trace):
  428. input1_shape = (100, 0, 1)
  429. output1_shape = (100, 0, 10)
  430. data1 = tensor(np.random.random(input1_shape).astype(np.float32))
  431. input2_shape = (10, 0)
  432. output2_shape = (10, 10, 0)
  433. data2 = tensor(np.random.random(input2_shape).astype(np.float32))
  434. input3_shape = (0, 0, 1, 10)
  435. output3_shape = (10, 0, 0, 10, 10)
  436. data3 = tensor(np.random.random(input3_shape).astype(np.float32))
  437. def comp(out, target_shp):
  438. assert out._tuple_shape == target_shp
  439. def func(x, shp):
  440. return F.broadcast_to(x, shp)
  441. cases = [
  442. [data1, output1_shape],
  443. [data2, output2_shape],
  444. [data3, output3_shape],
  445. ]
  446. def test(func, inp, comp, target_shp):
  447. out = func(inp, target_shp)
  448. comp(out, target_shp)
  449. if is_trace:
  450. for symbolic in [False, True]:
  451. for inp, target_shp in cases:
  452. func_traced = trace(symbolic=symbolic)(func)
  453. test(func_traced, inp, comp, target_shp)
  454. test(func_traced, inp, comp, target_shp)
  455. test(func_traced, inp, comp, target_shp)
  456. else:
  457. for inp, target_shp in cases:
  458. test(func, inp, comp, target_shp)
  459. @pytest.mark.parametrize("is_varnode", [True, False])
  460. def test_utils_astensor1d(is_varnode):
  461. if is_varnode:
  462. network = Network()
  463. else:
  464. network = None
  465. reference = make_tensor(0, network)
  466. # literal
  467. x = [1, 2, 3]
  468. for dtype in [None, "float32"]:
  469. xx = astensor1d(x, reference, dtype=dtype)
  470. assert isinstance(xx, type(reference))
  471. np.testing.assert_equal(xx.numpy(), x)
  472. # numpy array
  473. x = np.asarray([1, 2, 3], dtype="int32")
  474. for dtype in [None, "float32"]:
  475. xx = astensor1d(x, reference, dtype=dtype)
  476. assert isinstance(xx, type(reference))
  477. np.testing.assert_equal(xx.numpy(), x.astype(dtype) if dtype else x)
  478. # tensor
  479. x = make_tensor([1, 2, 3], network)
  480. for dtype in [None, "float32"]:
  481. xx = astensor1d(x, reference, dtype=dtype)
  482. assert isinstance(xx, type(reference))
  483. np.testing.assert_equal(xx.numpy(), x.numpy())
  484. # mixed
  485. x = [1, make_tensor(2, network), 3]
  486. for dtype in [None, "float32"]:
  487. xx = astensor1d(x, reference, dtype=dtype)
  488. assert isinstance(xx, type(reference))
  489. np.testing.assert_equal(xx.numpy(), [1, 2, 3])
  490. def test_device():
  491. x = tensor([1, 2, 3], dtype="float32")
  492. y1 = F.eye(x.shape, dtype="float32")
  493. y2 = F.eye(x.shape, dtype="float32", device=None)
  494. np.testing.assert_almost_equal(y1.numpy(), y2.numpy())
  495. y3 = F.eye(x.shape, dtype="float32", device="xpux")
  496. y4 = F.eye(x.shape, dtype="float32", device=x.device)
  497. np.testing.assert_almost_equal(y3.numpy(), y4.numpy())
  498. y5 = F.full((3, 2), 4, device=x.device)
  499. y6 = F.full((3, 2), 4, device="xpux")
  500. np.testing.assert_almost_equal(y5.numpy(), y6.numpy())
  501. @pytest.mark.parametrize("is_varnode", [True, False])
  502. def test_identity(is_varnode):
  503. if is_varnode:
  504. network = Network()
  505. else:
  506. network = None
  507. x = make_tensor(np.random.random((5, 10)).astype(np.float32), network)
  508. y = F.copy(x)
  509. np.testing.assert_equal(y.numpy(), x)
  510. def copy_test(dst, src, network):
  511. data = np.random.random((2, 3)).astype(np.float32)
  512. x = make_tensor(data, device=src, network=network)
  513. y = F.copy(x, dst)
  514. assert np.allclose(data, y.numpy())
  515. if network is None:
  516. z = x.to(dst)
  517. assert np.allclose(data, z.numpy())
  518. @pytest.mark.require_ngpu(1)
  519. @pytest.mark.parametrize("is_varnode", [True, False])
  520. def test_copy_h2d(is_varnode):
  521. if is_varnode:
  522. network = Network()
  523. else:
  524. network = None
  525. copy_test("cpu0", "gpu0", network=network)
  526. @pytest.mark.require_ngpu(1)
  527. @pytest.mark.parametrize("is_varnode", [True, False])
  528. def test_copy_d2h(is_varnode):
  529. if is_varnode:
  530. network = Network()
  531. else:
  532. network = None
  533. copy_test("gpu0", "cpu0", network=network)
  534. @pytest.mark.require_ngpu(2)
  535. @pytest.mark.parametrize("is_varnode", [True, False])
  536. def test_copy_d2d(is_varnode):
  537. if is_varnode:
  538. network = Network()
  539. else:
  540. network = None
  541. copy_test("gpu0", "gpu1", network=network)
  542. copy_test("gpu0:0", "gpu0:1", network=network)
  543. @pytest.mark.parametrize(
  544. "shape, repeats, axis",
  545. [
  546. ((2,), 2, 0),
  547. ((2, 3, 4, 5), 3, 0),
  548. ((2, 3, 4, 5), 4, 3),
  549. ((2,), 2, None),
  550. ((2, 3, 4, 5), 3, None),
  551. ((), 1, None),
  552. ((), 10, None),
  553. ],
  554. )
  555. @pytest.mark.parametrize("is_varnode", [True, False])
  556. def test_repeat(shape, repeats, axis, is_varnode):
  557. if is_varnode:
  558. network = Network()
  559. else:
  560. network = None
  561. def repeat_func(inp):
  562. return F.repeat(inp=inp, repeats=repeats, axis=axis)
  563. if shape != ():
  564. cases = [
  565. {"input": np.random.randn(*shape).astype("float32")},
  566. ]
  567. else:
  568. cases = [{"input": np.array(1.23)}]
  569. opr_test(
  570. cases,
  571. repeat_func,
  572. ref_fn=lambda inp: np.repeat(inp, repeats, axis),
  573. network=network,
  574. )
  575. @pytest.mark.parametrize(
  576. "shape, reps",
  577. [
  578. ((2,), (2,)),
  579. ((2, 3, 4, 5), (1, 1, 1, 1)),
  580. ((2, 3, 4, 5), (1, 2, 3, 4)),
  581. ((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
  582. ],
  583. )
  584. @pytest.mark.parametrize("is_varnode", [True])
  585. def test_tile(shape, reps, is_varnode):
  586. if is_varnode:
  587. network = Network()
  588. else:
  589. network = None
  590. def tile_func(inp):
  591. return F.tile(inp=inp, reps=reps)
  592. cases = [{"input": np.random.randn(*shape).astype("float32")}]
  593. opr_test(cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps), network=network)
  594. @pytest.mark.parametrize(
  595. "shape, shifts, axis",
  596. [
  597. ((2, 3), 0, None),
  598. ((2, 3), 1, 0),
  599. ((2, 3, 4, 5), (-1, 1), (0, 1)),
  600. ((2, 3, 4, 5), (-2, 1, 2), (1, 2, 3)),
  601. ],
  602. )
  603. @pytest.mark.parametrize("is_varnode", [True, False])
  604. def test_roll(shape, shifts, axis, is_varnode):
  605. if is_varnode:
  606. network = Network()
  607. else:
  608. network = None
  609. inp = np.random.randn(*shape).astype("float32")
  610. def func(inp):
  611. return F.roll(inp, shifts, axis)
  612. cases = [
  613. {"input": inp},
  614. ]
  615. opr_test(
  616. cases, func, ref_fn=lambda inp: np.roll(inp, shifts, axis), network=network
  617. )

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台