You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_tensor.py 24 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import os
  10. import platform
  11. import numpy as np
  12. import pytest
  13. from utils import get_var_value, make_tensor, opr_test
  14. import megengine.functional as F
  15. from megengine import tensor
  16. from megengine.core._trace_option import use_symbolic_shape
  17. from megengine.core.tensor import megbrain_graph as G
  18. from megengine.core.tensor.utils import astensor1d
  19. from megengine.jit import trace
  20. from megengine.utils.network import Network, set_symbolic_shape
  21. from megengine.utils.network_node import VarNode
  22. def test_eye():
  23. dtype = np.float32
  24. cases = [{"input": [10, 20]}, {"input": [30]}]
  25. for case in cases:
  26. np.testing.assert_allclose(
  27. F.eye(case["input"], dtype=dtype).numpy(),
  28. np.eye(*case["input"]).astype(dtype),
  29. )
  30. np.testing.assert_allclose(
  31. F.eye(*case["input"], dtype=dtype).numpy(),
  32. np.eye(*case["input"]).astype(dtype),
  33. )
  34. np.testing.assert_allclose(
  35. F.eye(tensor(case["input"]), dtype=dtype).numpy(),
  36. np.eye(*case["input"]).astype(dtype),
  37. )
  38. @pytest.mark.parametrize("is_varnode", [True, False])
  39. def test_concat(is_varnode):
  40. if is_varnode:
  41. network = Network()
  42. else:
  43. network = None
  44. def get_data_shape(length: int):
  45. return (length, 2, 3)
  46. data1 = np.random.random(get_data_shape(5)).astype("float32")
  47. data2 = np.random.random(get_data_shape(6)).astype("float32")
  48. data3 = np.random.random(get_data_shape(7)).astype("float32")
  49. def run(data1, data2):
  50. return F.concat([data1, data2])
  51. cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]
  52. opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network)
  53. @pytest.mark.parametrize("is_varnode", [True, False])
  54. def test_condtake(is_varnode):
  55. if is_varnode:
  56. network = Network()
  57. else:
  58. network = None
  59. x = np.array([[1, 2, 3], [4, 5, 6]]).astype("float32")
  60. y = np.array([[True, False, True], [False, True, True]])
  61. xx = make_tensor(x, network)
  62. yy = make_tensor(y, network)
  63. val, idx = F.cond_take(yy, xx)
  64. if is_varnode:
  65. np.testing.assert_equal(get_var_value(val), x[y])
  66. np.testing.assert_equal(get_var_value(idx), np.where(y.reshape(-1))[0])
  67. else:
  68. np.testing.assert_equal(val.numpy(), x[y])
  69. np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])
  70. @pytest.mark.parametrize("is_varnode", [True, False])
  71. def test_concat_device(is_varnode):
  72. if is_varnode:
  73. network = Network()
  74. else:
  75. network = None
  76. data1 = make_tensor(np.random.random((3, 2, 2)).astype("float32"), network, "cpu0")
  77. data2 = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu1")
  78. out = F.concat([data1, data2], device="cpu0")
  79. assert str(out.device).split(":")[0] == "cpu0"
  80. @pytest.mark.parametrize("is_varnode", [True, False])
  81. def test_stack(is_varnode):
  82. if is_varnode:
  83. network = Network()
  84. else:
  85. network = None
  86. data1 = np.random.random((3, 2, 2)).astype("float32")
  87. data2 = np.random.random((3, 2, 2)).astype("float32")
  88. data3 = np.random.random((3, 2, 2)).astype("float32")
  89. cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]
  90. for ai in range(3):
  91. def run(data1, data2):
  92. return F.stack([data1, data2], axis=ai)
  93. opr_test(
  94. cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai), network=network
  95. )
  96. @pytest.mark.parametrize("is_varnode", [True, False])
  97. def test_split_basic(is_varnode):
  98. if is_varnode:
  99. network = Network()
  100. saved_symbolic_shape = set_symbolic_shape(False)
  101. else:
  102. network = None
  103. data = np.random.random((2, 3, 4, 5)).astype(np.float32)
  104. inp = make_tensor(data, network)
  105. mge_out0 = F.split(inp, 2, axis=3)
  106. mge_out1 = F.split(inp, [3], axis=3)
  107. np_out = np.split(data, [3, 5], axis=3)
  108. assert len(mge_out0) == 2
  109. assert len(mge_out1) == 2
  110. np.testing.assert_equal(mge_out0[0].numpy(), np_out[0])
  111. np.testing.assert_equal(mge_out1[0].numpy(), np_out[0])
  112. np.testing.assert_equal(mge_out0[1].numpy(), np_out[1])
  113. np.testing.assert_equal(mge_out1[1].numpy(), np_out[1])
  114. try:
  115. F.split(inp, 4)
  116. assert False
  117. except ValueError as e:
  118. pass
  119. try:
  120. F.split(inp, [3, 2, 5], axis=3)
  121. assert False
  122. except ValueError as e:
  123. assert str(e) == "Invalid nsplits_or_secions: [3, 2, 5]"
  124. if is_varnode:
  125. set_symbolic_shape(saved_symbolic_shape)
  126. @pytest.mark.parametrize("symbolic", [None, False, True])
  127. def test_split(symbolic):
  128. inp1 = np.random.random((3, 4, 5, 6)).astype(np.float32)
  129. inp2 = np.random.random((0, 4, 5, 6)).astype(np.float32)
  130. def ref(inp, nsplits_or_sections, axis):
  131. return np.split(inp, nsplits_or_sections, axis)
  132. def func(inp, nsplits_or_sections, axis):
  133. return F.split(inp, nsplits_or_sections, axis)
  134. cases = [
  135. (inp1, 2, 3),
  136. (inp1, [3], 3),
  137. (inp1, [3, 3, 5], 3),
  138. (inp2, 2, 3),
  139. (inp2, [3], 3),
  140. (inp2, [3, 3, 5], 3),
  141. ]
  142. for case in cases:
  143. if symbolic is None:
  144. fn = func
  145. else:
  146. fn = trace(symbolic=symbolic)(func)
  147. for i in range(3 if symbolic is not None else 1):
  148. ref_out = ref(*case)
  149. out = fn(tensor(case[0]), case[1], case[2])
  150. assert len(ref_out) == len(out)
  151. for idx in range(len(ref_out)):
  152. np.testing.assert_equal(ref_out[idx], out[idx].numpy())
  153. @pytest.mark.parametrize("is_varnode", [True, False])
  154. def test_reshape(is_varnode):
  155. if is_varnode:
  156. network = Network()
  157. else:
  158. network = None
  159. x = np.arange(6, dtype="float32")
  160. xx = make_tensor(x, network)
  161. y = x.reshape(1, 2, 3)
  162. for shape in [
  163. (1, 2, 3),
  164. (1, -1, 3),
  165. (1, make_tensor(-1, network), 3),
  166. np.array([1, -1, 3], dtype="int32"),
  167. make_tensor([1, -1, 3], network),
  168. ]:
  169. yy = F.reshape(xx, shape)
  170. np.testing.assert_equal(yy.numpy(), y)
  171. @pytest.mark.parametrize("is_trace", [True, False])
  172. def test_reshape_on_empty_tensor(is_trace):
  173. input1_shape = (100, 0, 1)
  174. output1_shape = (100, 0, 10)
  175. data1 = tensor(np.random.random(input1_shape).astype(np.float32))
  176. input2_shape = (10, 0)
  177. output2_shape = (0,)
  178. data2 = tensor(np.random.random(input2_shape).astype(np.float32))
  179. input3_shape = (10, 0, 10)
  180. output3_shape = (0, 1, 2, 3)
  181. data3 = tensor(np.random.random(input3_shape).astype(np.float32))
  182. def comp(out, target_shp):
  183. assert out._tuple_shape == target_shp
  184. def func(x, shp):
  185. return F.reshape(x, shp)
  186. cases = [
  187. [data1, output1_shape],
  188. [data2, output2_shape],
  189. [data3, output3_shape],
  190. ]
  191. def test(func, inp, comp, target_shp):
  192. out = func(inp, target_shp)
  193. comp(out, target_shp)
  194. if is_trace:
  195. for symbolic in [False, True]:
  196. for inp, target_shp in cases:
  197. func_traced = trace(symbolic=symbolic)(func)
  198. test(func_traced, inp, comp, target_shp)
  199. test(func_traced, inp, comp, target_shp)
  200. test(func_traced, inp, comp, target_shp)
  201. else:
  202. for inp, target_shp in cases:
  203. test(func, inp, comp, target_shp)
  204. @pytest.mark.parametrize("is_varnode", [True, False])
  205. def test_reshape_shape_inference(is_varnode):
  206. if is_varnode:
  207. network = Network()
  208. saved_symbolic_shape = set_symbolic_shape(False)
  209. else:
  210. network = None
  211. x_shape_known = make_tensor([1, 2, 3, 4], network)
  212. x_shape_unknown = F.broadcast_to(
  213. make_tensor([1.0], network), shape=make_tensor([1, 1, 1, 1], network).sum()
  214. )
  215. tshp_unknown = astensor1d(
  216. (make_tensor([2], network), make_tensor([2], network)), x_shape_known
  217. )
  218. tshp_known = astensor1d((2, 2), x_shape_known)
  219. tshp_known_unspec = astensor1d((2, -1), x_shape_known)
  220. def check_shape(output, target):
  221. source = output.shape
  222. if isinstance(source, tensor):
  223. source = source.numpy()
  224. np.testing.assert_equal(source, target)
  225. def func(x, target_shape):
  226. return x.reshape(target_shape)
  227. cases = [
  228. {"input": [x_shape_known, tshp_unknown], "output": [(2, 2),]},
  229. {"input": [x_shape_unknown, tshp_unknown], "output": [(2, 2),]},
  230. {"input": [x_shape_known, tshp_known], "output": [(2, 2),]},
  231. {"input": [x_shape_known, tshp_known_unspec], "output": [(2, 2),]},
  232. {"input": [x_shape_unknown, tshp_known], "output": [(2, 2),]},
  233. {"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]},
  234. ]
  235. opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network)
  236. if is_varnode:
  237. set_symbolic_shape(saved_symbolic_shape)
  238. @pytest.mark.parametrize("is_varnode", [True, False])
  239. def test_squeeze(is_varnode):
  240. if is_varnode:
  241. network = Network()
  242. saved_symbolic_shape = set_symbolic_shape(False)
  243. else:
  244. network = None
  245. x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1)
  246. xx = make_tensor(x, network)
  247. for axis in [None, 3, -4, (3, -4)]:
  248. y = np.squeeze(x, axis)
  249. yy = F.squeeze(xx, axis)
  250. np.testing.assert_equal(y, yy.numpy())
  251. if is_varnode:
  252. set_symbolic_shape(saved_symbolic_shape)
  253. @pytest.mark.parametrize("is_varnode", [True, False])
  254. def test_expand_dims(is_varnode):
  255. if is_varnode:
  256. network = Network()
  257. else:
  258. network = None
  259. x = np.arange(6, dtype="float32").reshape(2, 3)
  260. xx = make_tensor(x, network)
  261. for axis in [2, -3, (3, -4), (1, -4)]:
  262. y = np.expand_dims(x, axis)
  263. yy = F.expand_dims(xx, axis)
  264. np.testing.assert_equal(y, yy.numpy())
  265. def test_expand_dims_for_scalar():
  266. x = np.array(1, dtype="float32")
  267. xx = make_tensor(x, None)
  268. for axis in [0, -1, (0, 1), (-1, -2), (0, -1)]:
  269. y = np.expand_dims(x, axis)
  270. yy = F.expand_dims(xx, axis)
  271. np.testing.assert_equal(y, yy.numpy())
  272. for axis in [1, -2, (1, 2), (-2, -3)]:
  273. np.testing.assert_raises(np.AxisError, np.expand_dims, x, axis)
  274. np.testing.assert_raises(AssertionError, F.expand_dims, xx, axis)
  275. @pytest.mark.parametrize("is_varnode", [True, False])
  276. def test_elemwise_dtype_promotion(is_varnode):
  277. if is_varnode:
  278. network = Network()
  279. else:
  280. network = None
  281. x = np.random.rand(2, 3).astype("float32")
  282. y = np.random.rand(1, 3).astype("float16")
  283. xx = make_tensor(x, network)
  284. yy = make_tensor(y, network)
  285. z = xx * yy
  286. np.testing.assert_equal(z.numpy(), x * y)
  287. z = xx + y
  288. np.testing.assert_equal(z.numpy(), x + y)
  289. z = x - yy
  290. np.testing.assert_equal(z.numpy(), x - y)
  291. @pytest.mark.parametrize("is_varnode", [True, False])
  292. def test_linspace(is_varnode):
  293. if is_varnode:
  294. network = Network()
  295. else:
  296. network = None
  297. cases = [
  298. {"input": [1, 9, 9]},
  299. {"input": [3, 10, 8]},
  300. ]
  301. opr_test(
  302. cases,
  303. F.linspace,
  304. ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
  305. network=network,
  306. )
  307. cases = [
  308. {"input": [9, 1, 9]},
  309. {"input": [10, 3, 8]},
  310. ]
  311. opr_test(
  312. cases,
  313. F.linspace,
  314. ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
  315. network=network,
  316. )
  317. cases = [
  318. {"input": [1, make_tensor(9, network), 9]},
  319. {"input": [make_tensor(1, network), 9, make_tensor(9, network)]},
  320. ]
  321. opr_test(
  322. cases,
  323. F.linspace,
  324. ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32),
  325. network=network,
  326. )
  327. @pytest.mark.parametrize("is_varnode", [True, False])
  328. def test_arange(is_varnode):
  329. if is_varnode:
  330. network = Network()
  331. else:
  332. network = None
  333. cases = [
  334. {"input": [1, 9, 1]},
  335. {"input": [2, 10, 2]},
  336. ]
  337. opr_test(
  338. cases,
  339. F.arange,
  340. ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
  341. network=network,
  342. )
  343. cases = [
  344. {"input": [9, 1, -1]},
  345. {"input": [10, 2, -2]},
  346. ]
  347. opr_test(
  348. cases,
  349. F.arange,
  350. ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
  351. network=network,
  352. )
  353. cases = [
  354. {"input": [9.3, 1.2, -0.5]},
  355. {"input": [10.3, 2.1, -1.7]},
  356. ]
  357. opr_test(
  358. cases,
  359. F.arange,
  360. ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
  361. network=network,
  362. )
  363. @pytest.mark.parametrize("is_varnode", [True, False])
  364. def test_round(is_varnode):
  365. if is_varnode:
  366. network = Network()
  367. else:
  368. network = None
  369. data1_shape = (15,)
  370. data2_shape = (25,)
  371. data1 = np.random.random(data1_shape).astype(np.float32)
  372. data2 = np.random.random(data2_shape).astype(np.float32)
  373. cases = [{"input": data1}, {"input": data2}]
  374. opr_test(cases, F.round, ref_fn=np.round, network=network)
  375. @pytest.mark.parametrize("is_varnode", [True, False])
  376. def test_flatten(is_varnode):
  377. if is_varnode:
  378. network = Network()
  379. else:
  380. network = None
  381. data0_shape = (2, 3, 4, 5)
  382. data1_shape = (4, 5, 6, 7)
  383. data0 = np.random.random(data0_shape).astype(np.float32)
  384. data1 = np.random.random(data1_shape).astype(np.float32)
  385. def compare_fn(x, y):
  386. assert x._tuple_shape[0] == y
  387. output0 = (2 * 3 * 4 * 5,)
  388. output1 = (4 * 5 * 6 * 7,)
  389. cases = [
  390. {"input": data0, "output": output0},
  391. {"input": data1, "output": output1},
  392. ]
  393. opr_test(cases, F.flatten, compare_fn=compare_fn, network=network)
  394. output0 = (2, 3 * 4 * 5)
  395. output1 = (4, 5 * 6 * 7)
  396. cases = [
  397. {"input": data0, "output": output0},
  398. {"input": data1, "output": output1},
  399. ]
  400. opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, network=network)
  401. output0 = (2, 3, 4 * 5)
  402. output1 = (4, 5, 6 * 7)
  403. cases = [
  404. {"input": data0, "output": output0},
  405. {"input": data1, "output": output1},
  406. ]
  407. opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2, network=network)
  408. output0 = (2, 3 * 4, 5)
  409. output1 = (4, 5 * 6, 7)
  410. cases = [
  411. {"input": data0, "output": output0},
  412. {"input": data1, "output": output1},
  413. ]
  414. opr_test(
  415. cases,
  416. F.flatten,
  417. compare_fn=compare_fn,
  418. start_axis=1,
  419. end_axis=2,
  420. network=network,
  421. )
  422. @pytest.mark.parametrize("is_varnode", [True, False])
  423. def test_broadcast(is_varnode):
  424. if is_varnode:
  425. network = Network()
  426. else:
  427. network = None
  428. input1_shape = (20, 30)
  429. output1_shape = (30, 20, 30)
  430. data1 = np.random.random(input1_shape).astype(np.float32)
  431. input2_shape = (10, 1)
  432. output2_shape = (20, 10, 20)
  433. data2 = np.random.random(input2_shape).astype(np.float32)
  434. input3_shape = (10, 10)
  435. output3_shape = (10, 10)
  436. data3 = np.random.random(input3_shape).astype(np.float32)
  437. def compare_fn(x, y):
  438. assert x._tuple_shape[0] == y
  439. cases = [
  440. {"input": [data1, output1_shape], "output": output1_shape},
  441. {"input": [data2, output2_shape], "output": output2_shape},
  442. {"input": [data3, output3_shape], "output": output3_shape},
  443. ]
  444. opr_test(cases, F.broadcast_to, compare_fn=compare_fn, network=network)
  445. x = F.ones((2, 1, 3))
  446. with pytest.raises(RuntimeError):
  447. F.broadcast_to(x, (2, 3, 4))
  448. with pytest.raises(RuntimeError):
  449. F.broadcast_to(x, (4, 1, 3))
  450. with pytest.raises(RuntimeError):
  451. F.broadcast_to(x, (1, 3))
  452. @pytest.mark.parametrize("is_trace", [True, False])
  453. def test_broadcast_on_empty_tensor(is_trace):
  454. input1_shape = (100, 0, 1)
  455. output1_shape = (100, 0, 10)
  456. data1 = tensor(np.random.random(input1_shape).astype(np.float32))
  457. input2_shape = (10, 0)
  458. output2_shape = (10, 10, 0)
  459. data2 = tensor(np.random.random(input2_shape).astype(np.float32))
  460. input3_shape = (0, 0, 1, 10)
  461. output3_shape = (10, 0, 0, 10, 10)
  462. data3 = tensor(np.random.random(input3_shape).astype(np.float32))
  463. def comp(out, target_shp):
  464. assert out._tuple_shape == target_shp
  465. def func(x, shp):
  466. return F.broadcast_to(x, shp)
  467. cases = [
  468. [data1, output1_shape],
  469. [data2, output2_shape],
  470. [data3, output3_shape],
  471. ]
  472. def test(func, inp, comp, target_shp):
  473. out = func(inp, target_shp)
  474. comp(out, target_shp)
  475. if is_trace:
  476. for symbolic in [False, True]:
  477. for inp, target_shp in cases:
  478. func_traced = trace(symbolic=symbolic)(func)
  479. test(func_traced, inp, comp, target_shp)
  480. test(func_traced, inp, comp, target_shp)
  481. test(func_traced, inp, comp, target_shp)
  482. else:
  483. for inp, target_shp in cases:
  484. test(func, inp, comp, target_shp)
  485. @pytest.mark.parametrize("is_varnode", [True, False])
  486. def test_utils_astensor1d(is_varnode):
  487. if is_varnode:
  488. network = Network()
  489. else:
  490. network = None
  491. reference = make_tensor(0, network)
  492. # literal
  493. x = [1, 2, 3]
  494. for dtype in [None, "float32"]:
  495. xx = astensor1d(x, reference, dtype=dtype)
  496. assert isinstance(xx, type(reference))
  497. np.testing.assert_equal(xx.numpy(), x)
  498. # numpy array
  499. x = np.asarray([1, 2, 3], dtype="int32")
  500. for dtype in [None, "float32"]:
  501. xx = astensor1d(x, reference, dtype=dtype)
  502. assert isinstance(xx, type(reference))
  503. np.testing.assert_equal(xx.numpy(), x.astype(dtype) if dtype else x)
  504. # tensor
  505. x = make_tensor([1, 2, 3], network)
  506. for dtype in [None, "float32"]:
  507. xx = astensor1d(x, reference, dtype=dtype)
  508. assert isinstance(xx, type(reference))
  509. np.testing.assert_equal(xx.numpy(), x.numpy())
  510. # mixed
  511. x = [1, make_tensor(2, network), 3]
  512. for dtype in [None, "float32"]:
  513. xx = astensor1d(x, reference, dtype=dtype)
  514. assert isinstance(xx, type(reference))
  515. np.testing.assert_equal(xx.numpy(), [1, 2, 3])
  516. def test_device():
  517. x = tensor([1, 2, 3], dtype="float32")
  518. y1 = F.eye(x.shape, dtype="float32")
  519. y2 = F.eye(x.shape, dtype="float32", device=None)
  520. np.testing.assert_almost_equal(y1.numpy(), y2.numpy())
  521. y3 = F.eye(x.shape, dtype="float32", device="xpux")
  522. y4 = F.eye(x.shape, dtype="float32", device=x.device)
  523. np.testing.assert_almost_equal(y3.numpy(), y4.numpy())
  524. y5 = F.full((3, 2), 4, device=x.device)
  525. y6 = F.full((3, 2), 4, device="xpux")
  526. np.testing.assert_almost_equal(y5.numpy(), y6.numpy())
  527. @pytest.mark.parametrize("is_varnode", [True, False])
  528. def test_identity(is_varnode):
  529. if is_varnode:
  530. network = Network()
  531. else:
  532. network = None
  533. x = make_tensor(np.random.random((5, 10)).astype(np.float32), network)
  534. y = F.copy(x)
  535. np.testing.assert_equal(y.numpy(), x)
  536. def copy_test(dst, src, network):
  537. data = np.random.random((2, 3)).astype(np.float32)
  538. x = make_tensor(data, device=src, network=network)
  539. y = F.copy(x, dst)
  540. assert np.allclose(data, y.numpy())
  541. if network is None:
  542. z = x.to(dst)
  543. assert np.allclose(data, z.numpy())
  544. @pytest.mark.require_ngpu(1)
  545. @pytest.mark.parametrize("is_varnode", [True, False])
  546. def test_copy_h2d(is_varnode):
  547. if is_varnode:
  548. network = Network()
  549. else:
  550. network = None
  551. copy_test("cpu0", "gpu0", network=network)
  552. @pytest.mark.require_ngpu(1)
  553. @pytest.mark.parametrize("is_varnode", [True, False])
  554. def test_copy_d2h(is_varnode):
  555. if is_varnode:
  556. network = Network()
  557. else:
  558. network = None
  559. copy_test("gpu0", "cpu0", network=network)
  560. @pytest.mark.require_ngpu(2)
  561. @pytest.mark.parametrize("is_varnode", [True, False])
  562. def test_copy_d2d(is_varnode):
  563. if is_varnode:
  564. network = Network()
  565. else:
  566. network = None
  567. copy_test("gpu0", "gpu1", network=network)
  568. copy_test("gpu0:0", "gpu0:1", network=network)
  569. @pytest.mark.require_ngpu(2)
  570. @pytest.mark.parametrize(
  571. "shape, device_src, device_dst",
  572. [
  573. ((0,), "cpu0", "cpu0"),
  574. ((10, 0), "cpu0", "cpu1"),
  575. ((2, 0, 3), "cpu0", "gpu0"),
  576. ((1, 0, 1, 0), "gpu0", "cpu0"),
  577. ((2, 3, 4, 5, 0), "gpu0", "gpu1"),
  578. ],
  579. )
  580. @pytest.mark.parametrize("is_symbolic", [None, True, False])
  581. def test_copy_empty(shape, device_src, device_dst, is_symbolic):
  582. inp = tensor(np.random.randn(*shape).astype("float32"), device=device_src)
  583. def func(inp):
  584. return F.copy(inp, device_dst)
  585. if is_symbolic is not None:
  586. func = trace(symbolic=is_symbolic)(func)
  587. for _ in range(3):
  588. out = func(inp)
  589. assert out.numpy().shape == shape
  590. assert out.device == device_dst
  591. if is_symbolic is None:
  592. break
  593. @pytest.mark.parametrize(
  594. "shape, repeats, axis",
  595. [
  596. ((2,), 2, 0),
  597. ((2, 3, 4, 5), 3, 0),
  598. ((2, 3, 4, 5), 4, 3),
  599. ((2,), 2, None),
  600. ((2, 3, 4, 5), 3, None),
  601. ((), 1, None),
  602. ((), 10, None),
  603. ],
  604. )
  605. @pytest.mark.parametrize("is_varnode", [True, False])
  606. def test_repeat(shape, repeats, axis, is_varnode):
  607. if is_varnode:
  608. network = Network()
  609. else:
  610. network = None
  611. def repeat_func(inp):
  612. return F.repeat(inp=inp, repeats=repeats, axis=axis)
  613. if shape != ():
  614. cases = [
  615. {"input": np.random.randn(*shape).astype("float32")},
  616. ]
  617. else:
  618. cases = [{"input": np.array(1.23)}]
  619. opr_test(
  620. cases,
  621. repeat_func,
  622. ref_fn=lambda inp: np.repeat(inp, repeats, axis),
  623. network=network,
  624. )
  625. @pytest.mark.parametrize(
  626. "shape, reps",
  627. [
  628. ((2,), (2,)),
  629. ((2, 3, 4, 5), (1, 1, 1, 1)),
  630. ((2, 3, 4, 5), (1, 2, 3, 4)),
  631. ((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
  632. ],
  633. )
  634. @pytest.mark.parametrize("is_varnode", [True])
  635. def test_tile(shape, reps, is_varnode):
  636. if is_varnode:
  637. network = Network()
  638. else:
  639. network = None
  640. def tile_func(inp):
  641. return F.tile(inp=inp, reps=reps)
  642. cases = [{"input": np.random.randn(*shape).astype("float32")}]
  643. opr_test(cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps), network=network)
  644. @pytest.mark.parametrize(
  645. "shape, shifts, axis",
  646. [
  647. ((2, 3), 0, None),
  648. ((2, 3), 1, 0),
  649. ((2, 3), 100, 0),
  650. ((2, 3), -100, 0),
  651. ((2, 3, 4, 5), (-1, 1), (0, 1)),
  652. ((2, 3, 4, 5), (-2, 1, 2), (1, 2, 3)),
  653. ],
  654. )
  655. @pytest.mark.parametrize("is_varnode", [True, False])
  656. def test_roll(shape, shifts, axis, is_varnode):
  657. if is_varnode:
  658. network = Network()
  659. else:
  660. network = None
  661. inp = np.random.randn(*shape).astype("float32")
  662. def func(inp):
  663. return F.roll(inp, shifts, axis)
  664. cases = [
  665. {"input": inp},
  666. ]
  667. opr_test(
  668. cases, func, ref_fn=lambda inp: np.roll(inp, shifts, axis), network=network
  669. )
  670. @pytest.mark.parametrize(
  671. "shape, shifts, axis", [((10, 0), 5, 1), ((10, 0), -10, 1),],
  672. )
  673. @pytest.mark.parametrize("is_symbolic", [None, True, False])
  674. def test_roll_empty_tensor(shape, shifts, axis, is_symbolic):
  675. inp = tensor(np.random.randn(*shape).astype("float32"))
  676. def func(inp):
  677. return F.roll(inp, shifts, axis)
  678. if is_symbolic is not None:
  679. func = trace(symbolic=is_symbolic)(func)
  680. out_ref = np.roll(inp.numpy(), shifts, axis)
  681. for _ in range(3):
  682. out = F.roll(inp, shifts, axis)
  683. np.testing.assert_equal(out.numpy(), out_ref)
  684. if is_symbolic is None:
  685. break

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。