You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_tensor.py 24 kB

Lines 1–862
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import os
  10. import platform
  11. import numpy as np
  12. import pytest
  13. from utils import get_var_value, make_tensor, opr_test
  14. import megengine.functional as F
  15. from megengine import tensor
  16. from megengine.core._trace_option import use_symbolic_shape
  17. from megengine.core.tensor import megbrain_graph as G
  18. from megengine.core.tensor.utils import astensor1d
  19. from megengine.jit import trace
  20. from megengine.utils.network import Network, set_symbolic_shape
  21. from megengine.utils.network_node import VarNode
  22. def test_eye():
  23. dtype = np.float32
  24. cases = [{"input": [10, 20]}, {"input": [30]}]
  25. for case in cases:
  26. np.testing.assert_allclose(
  27. F.eye(case["input"], dtype=dtype).numpy(),
  28. np.eye(*case["input"]).astype(dtype),
  29. )
  30. np.testing.assert_allclose(
  31. F.eye(*case["input"], dtype=dtype).numpy(),
  32. np.eye(*case["input"]).astype(dtype),
  33. )
  34. np.testing.assert_allclose(
  35. F.eye(tensor(case["input"]), dtype=dtype).numpy(),
  36. np.eye(*case["input"]).astype(dtype),
  37. )
  38. def test_full():
  39. shape = (2, 3)
  40. values = [True, 4, 5.0]
  41. for value in values:
  42. np.testing.assert_allclose(F.full(shape, value).numpy(), np.full(shape, value))
  43. assert F.full(shape, value).dtype == tensor(value).dtype
  44. @pytest.mark.parametrize("is_varnode", [True, False])
  45. def test_concat(is_varnode):
  46. if is_varnode:
  47. network = Network()
  48. else:
  49. network = None
  50. def get_data_shape(length: int):
  51. return (length, 2, 3)
  52. data1 = np.random.random(get_data_shape(5)).astype("float32")
  53. data2 = np.random.random(get_data_shape(6)).astype("float32")
  54. data3 = np.random.random(get_data_shape(7)).astype("float32")
  55. def run(data1, data2):
  56. return F.concat([data1, data2])
  57. cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]
  58. opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network)
  59. @pytest.mark.parametrize("is_varnode", [True, False])
  60. def test_condtake(is_varnode):
  61. if is_varnode:
  62. network = Network()
  63. else:
  64. network = None
  65. x = np.array([[1, 2, 3], [4, 5, 6]]).astype("float32")
  66. y = np.array([[True, False, True], [False, True, True]])
  67. xx = make_tensor(x, network)
  68. yy = make_tensor(y, network)
  69. val, idx = F.cond_take(yy, xx)
  70. if is_varnode:
  71. np.testing.assert_equal(get_var_value(val), x[y])
  72. np.testing.assert_equal(get_var_value(idx), np.where(y.reshape(-1))[0])
  73. else:
  74. np.testing.assert_equal(val.numpy(), x[y])
  75. np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])
  76. @pytest.mark.parametrize("is_varnode", [True, False])
  77. def test_concat_device(is_varnode):
  78. if is_varnode:
  79. network = Network()
  80. else:
  81. network = None
  82. data1 = make_tensor(np.random.random((3, 2, 2)).astype("float32"), network, "cpu0")
  83. data2 = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu1")
  84. out = F.concat([data1, data2], device="cpu0")
  85. assert str(out.device).split(":")[0] == "cpu0"
  86. @pytest.mark.parametrize("is_varnode", [True, False])
  87. def test_stack(is_varnode):
  88. if is_varnode:
  89. network = Network()
  90. else:
  91. network = None
  92. data1 = np.random.random((3, 2, 2)).astype("float32")
  93. data2 = np.random.random((3, 2, 2)).astype("float32")
  94. data3 = np.random.random((3, 2, 2)).astype("float32")
  95. cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]
  96. for ai in range(3):
  97. def run(data1, data2):
  98. return F.stack([data1, data2], axis=ai)
  99. opr_test(
  100. cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai), network=network
  101. )
  102. @pytest.mark.parametrize("is_varnode", [True, False])
  103. def test_split_basic(is_varnode):
  104. if is_varnode:
  105. network = Network()
  106. saved_symbolic_shape = set_symbolic_shape(False)
  107. else:
  108. network = None
  109. data = np.random.random((2, 3, 4, 5)).astype(np.float32)
  110. inp = make_tensor(data, network)
  111. mge_out0 = F.split(inp, 2, axis=3)
  112. mge_out1 = F.split(inp, [3], axis=3)
  113. np_out = np.split(data, [3, 5], axis=3)
  114. assert len(mge_out0) == 2
  115. assert len(mge_out1) == 2
  116. np.testing.assert_equal(mge_out0[0].numpy(), np_out[0])
  117. np.testing.assert_equal(mge_out1[0].numpy(), np_out[0])
  118. np.testing.assert_equal(mge_out0[1].numpy(), np_out[1])
  119. np.testing.assert_equal(mge_out1[1].numpy(), np_out[1])
  120. try:
  121. F.split(inp, 4)
  122. assert False
  123. except ValueError as e:
  124. pass
  125. try:
  126. F.split(inp, [3, 2, 5], axis=3)
  127. assert False
  128. except ValueError as e:
  129. assert str(e) == "Invalid nsplits_or_secions: [3, 2, 5]"
  130. if is_varnode:
  131. set_symbolic_shape(saved_symbolic_shape)
  132. @pytest.mark.parametrize("symbolic", [None, False, True])
  133. def test_split(symbolic):
  134. inp1 = np.random.random((3, 4, 5, 6)).astype(np.float32)
  135. inp2 = np.random.random((0, 4, 5, 6)).astype(np.float32)
  136. def ref(inp, nsplits_or_sections, axis):
  137. return np.split(inp, nsplits_or_sections, axis)
  138. def func(inp, nsplits_or_sections, axis):
  139. return F.split(inp, nsplits_or_sections, axis)
  140. cases = [
  141. (inp1, 2, 3),
  142. (inp1, [3], 3),
  143. (inp1, [3, 3, 5], 3),
  144. (inp2, 2, 3),
  145. (inp2, [3], 3),
  146. (inp2, [3, 3, 5], 3),
  147. ]
  148. for case in cases:
  149. if symbolic is None:
  150. fn = func
  151. else:
  152. fn = trace(symbolic=symbolic)(func)
  153. for i in range(3 if symbolic is not None else 1):
  154. ref_out = ref(*case)
  155. out = fn(tensor(case[0]), case[1], case[2])
  156. assert len(ref_out) == len(out)
  157. for idx in range(len(ref_out)):
  158. np.testing.assert_equal(ref_out[idx], out[idx].numpy())
  159. @pytest.mark.parametrize("is_varnode", [True, False])
  160. def test_reshape(is_varnode):
  161. if is_varnode:
  162. network = Network()
  163. else:
  164. network = None
  165. x = np.arange(6, dtype="float32")
  166. xx = make_tensor(x, network)
  167. y = x.reshape(1, 2, 3)
  168. for shape in [
  169. (1, 2, 3),
  170. (1, -1, 3),
  171. (1, make_tensor(-1, network), 3),
  172. np.array([1, -1, 3], dtype="int32"),
  173. make_tensor([1, -1, 3], network),
  174. ]:
  175. yy = F.reshape(xx, shape)
  176. np.testing.assert_equal(yy.numpy(), y)
  177. @pytest.mark.parametrize("is_trace", [True, False])
  178. def test_reshape_on_empty_tensor(is_trace):
  179. input1_shape = (100, 0, 1)
  180. output1_shape = (100, 0, 10)
  181. data1 = tensor(np.random.random(input1_shape).astype(np.float32))
  182. input2_shape = (10, 0)
  183. output2_shape = (0,)
  184. data2 = tensor(np.random.random(input2_shape).astype(np.float32))
  185. input3_shape = (10, 0, 10)
  186. output3_shape = (0, 1, 2, 3)
  187. data3 = tensor(np.random.random(input3_shape).astype(np.float32))
  188. def comp(out, target_shp):
  189. assert out._tuple_shape == target_shp
  190. def func(x, shp):
  191. return F.reshape(x, shp)
  192. cases = [
  193. [data1, output1_shape],
  194. [data2, output2_shape],
  195. [data3, output3_shape],
  196. ]
  197. def test(func, inp, comp, target_shp):
  198. out = func(inp, target_shp)
  199. comp(out, target_shp)
  200. if is_trace:
  201. for symbolic in [False, True]:
  202. for inp, target_shp in cases:
  203. func_traced = trace(symbolic=symbolic)(func)
  204. test(func_traced, inp, comp, target_shp)
  205. test(func_traced, inp, comp, target_shp)
  206. test(func_traced, inp, comp, target_shp)
  207. else:
  208. for inp, target_shp in cases:
  209. test(func, inp, comp, target_shp)
  210. @pytest.mark.parametrize("is_varnode", [True, False])
  211. def test_reshape_shape_inference(is_varnode):
  212. if is_varnode:
  213. network = Network()
  214. saved_symbolic_shape = set_symbolic_shape(False)
  215. else:
  216. network = None
  217. x_shape_known = make_tensor([1, 2, 3, 4], network)
  218. x_shape_unknown = F.broadcast_to(
  219. make_tensor([1.0], network), shape=make_tensor([1, 1, 1, 1], network).sum()
  220. )
  221. tshp_unknown = astensor1d(
  222. (make_tensor([2], network), make_tensor([2], network)), x_shape_known
  223. )
  224. tshp_known = astensor1d((2, 2), x_shape_known)
  225. tshp_known_unspec = astensor1d((2, -1), x_shape_known)
  226. def check_shape(output, target):
  227. source = output.shape
  228. if isinstance(source, tensor):
  229. source = source.numpy()
  230. np.testing.assert_equal(source, target)
  231. def func(x, target_shape):
  232. return x.reshape(target_shape)
  233. cases = [
  234. {"input": [x_shape_known, tshp_unknown], "output": [(2, 2),]},
  235. {"input": [x_shape_unknown, tshp_unknown], "output": [(2, 2),]},
  236. {"input": [x_shape_known, tshp_known], "output": [(2, 2),]},
  237. {"input": [x_shape_known, tshp_known_unspec], "output": [(2, 2),]},
  238. {"input": [x_shape_unknown, tshp_known], "output": [(2, 2),]},
  239. {"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]},
  240. ]
  241. opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network)
  242. if is_varnode:
  243. set_symbolic_shape(saved_symbolic_shape)
  244. @pytest.mark.parametrize("is_varnode", [True, False])
  245. def test_squeeze(is_varnode):
  246. if is_varnode:
  247. network = Network()
  248. saved_symbolic_shape = set_symbolic_shape(False)
  249. else:
  250. network = None
  251. x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1)
  252. xx = make_tensor(x, network)
  253. for axis in [None, 3, -4, (3, -4)]:
  254. y = np.squeeze(x, axis)
  255. yy = F.squeeze(xx, axis)
  256. np.testing.assert_equal(y, yy.numpy())
  257. if is_varnode:
  258. set_symbolic_shape(saved_symbolic_shape)
  259. @pytest.mark.parametrize("is_varnode", [True, False])
  260. def test_expand_dims(is_varnode):
  261. if is_varnode:
  262. network = Network()
  263. else:
  264. network = None
  265. x = np.arange(6, dtype="float32").reshape(2, 3)
  266. xx = make_tensor(x, network)
  267. for axis in [2, -3, (3, -4), (1, -4)]:
  268. y = np.expand_dims(x, axis)
  269. yy = F.expand_dims(xx, axis)
  270. np.testing.assert_equal(y, yy.numpy())
  271. def test_expand_dims_for_scalar():
  272. x = np.array(1, dtype="float32")
  273. xx = make_tensor(x, None)
  274. for axis in [0, -1, (0, 1), (-1, -2), (0, -1)]:
  275. y = np.expand_dims(x, axis)
  276. yy = F.expand_dims(xx, axis)
  277. np.testing.assert_equal(y, yy.numpy())
  278. for axis in [1, -2, (1, 2), (-2, -3)]:
  279. np.testing.assert_raises(np.AxisError, np.expand_dims, x, axis)
  280. np.testing.assert_raises(AssertionError, F.expand_dims, xx, axis)
  281. @pytest.mark.parametrize("is_varnode", [True, False])
  282. def test_elemwise_dtype_promotion(is_varnode):
  283. if is_varnode:
  284. network = Network()
  285. else:
  286. network = None
  287. x = np.random.rand(2, 3).astype("float32")
  288. y = np.random.rand(1, 3).astype("float16")
  289. xx = make_tensor(x, network)
  290. yy = make_tensor(y, network)
  291. z = xx * yy
  292. np.testing.assert_equal(z.numpy(), x * y)
  293. z = xx + y
  294. np.testing.assert_equal(z.numpy(), x + y)
  295. z = x - yy
  296. np.testing.assert_equal(z.numpy(), x - y)
  297. @pytest.mark.parametrize("is_varnode", [True, False])
  298. def test_linspace(is_varnode):
  299. if is_varnode:
  300. network = Network()
  301. else:
  302. network = None
  303. cases = [
  304. {"input": [1, 9, 9]},
  305. {"input": [3, 10, 8]},
  306. ]
  307. opr_test(
  308. cases,
  309. F.linspace,
  310. ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
  311. network=network,
  312. )
  313. cases = [
  314. {"input": [9, 1, 9]},
  315. {"input": [10, 3, 8]},
  316. ]
  317. opr_test(
  318. cases,
  319. F.linspace,
  320. ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
  321. network=network,
  322. )
  323. cases = [
  324. {"input": [1, make_tensor(9, network), 9]},
  325. {"input": [make_tensor(1, network), 9, make_tensor(9, network)]},
  326. ]
  327. opr_test(
  328. cases,
  329. F.linspace,
  330. ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32),
  331. network=network,
  332. )
  333. @pytest.mark.parametrize("is_varnode", [True, False])
  334. def test_arange(is_varnode):
  335. if is_varnode:
  336. network = Network()
  337. else:
  338. network = None
  339. cases = [
  340. {"input": [1, 9, 1]},
  341. {"input": [2, 10, 2]},
  342. ]
  343. opr_test(
  344. cases,
  345. F.arange,
  346. ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
  347. network=network,
  348. )
  349. cases = [
  350. {"input": [9, 1, -1]},
  351. {"input": [10, 2, -2]},
  352. ]
  353. opr_test(
  354. cases,
  355. F.arange,
  356. ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
  357. network=network,
  358. )
  359. cases = [
  360. {"input": [9.3, 1.2, -0.5]},
  361. {"input": [10.3, 2.1, -1.7]},
  362. ]
  363. opr_test(
  364. cases,
  365. F.arange,
  366. ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
  367. network=network,
  368. )
  369. @pytest.mark.parametrize("is_varnode", [True, False])
  370. def test_round(is_varnode):
  371. if is_varnode:
  372. network = Network()
  373. else:
  374. network = None
  375. data1_shape = (15,)
  376. data2_shape = (25,)
  377. data1 = np.random.random(data1_shape).astype(np.float32)
  378. data2 = np.random.random(data2_shape).astype(np.float32)
  379. cases = [{"input": data1}, {"input": data2}]
  380. opr_test(cases, F.round, ref_fn=np.round, network=network)
  381. @pytest.mark.parametrize("is_varnode", [True, False])
  382. def test_flatten(is_varnode):
  383. if is_varnode:
  384. network = Network()
  385. else:
  386. network = None
  387. data0_shape = (2, 3, 4, 5)
  388. data1_shape = (4, 5, 6, 7)
  389. data0 = np.random.random(data0_shape).astype(np.float32)
  390. data1 = np.random.random(data1_shape).astype(np.float32)
  391. def compare_fn(x, y):
  392. assert x._tuple_shape[0] == y
  393. output0 = (2 * 3 * 4 * 5,)
  394. output1 = (4 * 5 * 6 * 7,)
  395. cases = [
  396. {"input": data0, "output": output0},
  397. {"input": data1, "output": output1},
  398. ]
  399. opr_test(cases, F.flatten, compare_fn=compare_fn, network=network)
  400. output0 = (2, 3 * 4 * 5)
  401. output1 = (4, 5 * 6 * 7)
  402. cases = [
  403. {"input": data0, "output": output0},
  404. {"input": data1, "output": output1},
  405. ]
  406. opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, network=network)
  407. output0 = (2, 3, 4 * 5)
  408. output1 = (4, 5, 6 * 7)
  409. cases = [
  410. {"input": data0, "output": output0},
  411. {"input": data1, "output": output1},
  412. ]
  413. opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2, network=network)
  414. output0 = (2, 3 * 4, 5)
  415. output1 = (4, 5 * 6, 7)
  416. cases = [
  417. {"input": data0, "output": output0},
  418. {"input": data1, "output": output1},
  419. ]
  420. opr_test(
  421. cases,
  422. F.flatten,
  423. compare_fn=compare_fn,
  424. start_axis=1,
  425. end_axis=2,
  426. network=network,
  427. )
  428. @pytest.mark.parametrize("is_varnode", [True, False])
  429. def test_broadcast(is_varnode):
  430. if is_varnode:
  431. network = Network()
  432. else:
  433. network = None
  434. input1_shape = (20, 30)
  435. output1_shape = (30, 20, 30)
  436. data1 = np.random.random(input1_shape).astype(np.float32)
  437. input2_shape = (10, 1)
  438. output2_shape = (20, 10, 20)
  439. data2 = np.random.random(input2_shape).astype(np.float32)
  440. input3_shape = (10, 10)
  441. output3_shape = (10, 10)
  442. data3 = np.random.random(input3_shape).astype(np.float32)
  443. def compare_fn(x, y):
  444. assert x._tuple_shape[0] == y
  445. cases = [
  446. {"input": [data1, output1_shape], "output": output1_shape},
  447. {"input": [data2, output2_shape], "output": output2_shape},
  448. {"input": [data3, output3_shape], "output": output3_shape},
  449. ]
  450. opr_test(cases, F.broadcast_to, compare_fn=compare_fn, network=network)
  451. x = F.ones((2, 1, 3))
  452. with pytest.raises(RuntimeError):
  453. F.broadcast_to(x, (2, 3, 4))
  454. with pytest.raises(RuntimeError):
  455. F.broadcast_to(x, (4, 1, 3))
  456. with pytest.raises(RuntimeError):
  457. F.broadcast_to(x, (1, 3))
  458. @pytest.mark.parametrize("is_trace", [True, False])
  459. def test_broadcast_on_empty_tensor(is_trace):
  460. input1_shape = (100, 0, 1)
  461. output1_shape = (100, 0, 10)
  462. data1 = tensor(np.random.random(input1_shape).astype(np.float32))
  463. input2_shape = (10, 0)
  464. output2_shape = (10, 10, 0)
  465. data2 = tensor(np.random.random(input2_shape).astype(np.float32))
  466. input3_shape = (0, 0, 1, 10)
  467. output3_shape = (10, 0, 0, 10, 10)
  468. data3 = tensor(np.random.random(input3_shape).astype(np.float32))
  469. def comp(out, target_shp):
  470. assert out._tuple_shape == target_shp
  471. def func(x, shp):
  472. return F.broadcast_to(x, shp)
  473. cases = [
  474. [data1, output1_shape],
  475. [data2, output2_shape],
  476. [data3, output3_shape],
  477. ]
  478. def test(func, inp, comp, target_shp):
  479. out = func(inp, target_shp)
  480. comp(out, target_shp)
  481. if is_trace:
  482. for symbolic in [False, True]:
  483. for inp, target_shp in cases:
  484. func_traced = trace(symbolic=symbolic)(func)
  485. test(func_traced, inp, comp, target_shp)
  486. test(func_traced, inp, comp, target_shp)
  487. test(func_traced, inp, comp, target_shp)
  488. else:
  489. for inp, target_shp in cases:
  490. test(func, inp, comp, target_shp)
  491. @pytest.mark.parametrize("is_varnode", [True, False])
  492. def test_utils_astensor1d(is_varnode):
  493. if is_varnode:
  494. network = Network()
  495. else:
  496. network = None
  497. reference = make_tensor(0, network)
  498. # literal
  499. x = [1, 2, 3]
  500. for dtype in [None, "float32"]:
  501. xx = astensor1d(x, reference, dtype=dtype)
  502. assert isinstance(xx, type(reference))
  503. np.testing.assert_equal(xx.numpy(), x)
  504. # numpy array
  505. x = np.asarray([1, 2, 3], dtype="int32")
  506. for dtype in [None, "float32"]:
  507. xx = astensor1d(x, reference, dtype=dtype)
  508. assert isinstance(xx, type(reference))
  509. np.testing.assert_equal(xx.numpy(), x.astype(dtype) if dtype else x)
  510. # tensor
  511. x = make_tensor([1, 2, 3], network)
  512. for dtype in [None, "float32"]:
  513. xx = astensor1d(x, reference, dtype=dtype)
  514. assert isinstance(xx, type(reference))
  515. np.testing.assert_equal(xx.numpy(), x.numpy())
  516. # mixed
  517. x = [1, make_tensor(2, network), 3]
  518. for dtype in [None, "float32"]:
  519. xx = astensor1d(x, reference, dtype=dtype)
  520. assert isinstance(xx, type(reference))
  521. np.testing.assert_equal(xx.numpy(), [1, 2, 3])
  522. def test_device():
  523. x = tensor([1, 2, 3], dtype="float32")
  524. y1 = F.eye(x.shape, dtype="float32")
  525. y2 = F.eye(x.shape, dtype="float32", device=None)
  526. np.testing.assert_almost_equal(y1.numpy(), y2.numpy())
  527. y3 = F.eye(x.shape, dtype="float32", device="xpux")
  528. y4 = F.eye(x.shape, dtype="float32", device=x.device)
  529. np.testing.assert_almost_equal(y3.numpy(), y4.numpy())
  530. y5 = F.full((3, 2), 4, device=x.device)
  531. y6 = F.full((3, 2), 4, device="xpux")
  532. np.testing.assert_almost_equal(y5.numpy(), y6.numpy())
  533. @pytest.mark.parametrize("is_varnode", [True, False])
  534. def test_identity(is_varnode):
  535. if is_varnode:
  536. network = Network()
  537. else:
  538. network = None
  539. x = make_tensor(np.random.random((5, 10)).astype(np.float32), network)
  540. y = F.copy(x)
  541. np.testing.assert_equal(y.numpy(), x)
  542. def copy_test(dst, src, network):
  543. data = np.random.random((2, 3)).astype(np.float32)
  544. x = make_tensor(data, device=src, network=network)
  545. y = F.copy(x, dst)
  546. assert np.allclose(data, y.numpy())
  547. if network is None:
  548. z = x.to(dst)
  549. assert np.allclose(data, z.numpy())
  550. @pytest.mark.require_ngpu(1)
  551. @pytest.mark.parametrize("is_varnode", [True, False])
  552. def test_copy_h2d(is_varnode):
  553. if is_varnode:
  554. network = Network()
  555. else:
  556. network = None
  557. copy_test("cpu0", "gpu0", network=network)
  558. @pytest.mark.require_ngpu(1)
  559. @pytest.mark.parametrize("is_varnode", [True, False])
  560. def test_copy_d2h(is_varnode):
  561. if is_varnode:
  562. network = Network()
  563. else:
  564. network = None
  565. copy_test("gpu0", "cpu0", network=network)
  566. @pytest.mark.require_ngpu(2)
  567. @pytest.mark.parametrize("is_varnode", [True, False])
  568. def test_copy_d2d(is_varnode):
  569. if is_varnode:
  570. network = Network()
  571. else:
  572. network = None
  573. copy_test("gpu0", "gpu1", network=network)
  574. copy_test("gpu0:0", "gpu0:1", network=network)
  575. @pytest.mark.require_ngpu(2)
  576. @pytest.mark.parametrize(
  577. "shape, device_src, device_dst",
  578. [
  579. ((0,), "cpu0", "cpu0"),
  580. ((10, 0), "cpu0", "cpu1"),
  581. ((2, 0, 3), "cpu0", "gpu0"),
  582. ((1, 0, 1, 0), "gpu0", "cpu0"),
  583. ((2, 3, 4, 5, 0), "gpu0", "gpu1"),
  584. ],
  585. )
  586. @pytest.mark.parametrize("is_symbolic", [None, True, False])
  587. def test_copy_empty(shape, device_src, device_dst, is_symbolic):
  588. inp = tensor(np.random.randn(*shape).astype("float32"), device=device_src)
  589. def func(inp):
  590. return F.copy(inp, device_dst)
  591. if is_symbolic is not None:
  592. func = trace(symbolic=is_symbolic)(func)
  593. for _ in range(3):
  594. out = func(inp)
  595. assert out.numpy().shape == shape
  596. assert out.device == device_dst
  597. if is_symbolic is None:
  598. break
  599. @pytest.mark.parametrize(
  600. "shape, repeats, axis",
  601. [
  602. ((2,), 2, 0),
  603. ((2, 3, 4, 5), 3, 0),
  604. ((2, 3, 4, 5), 4, 3),
  605. ((2,), 2, None),
  606. ((2, 3, 4, 5), 3, None),
  607. ((), 1, None),
  608. ((), 10, None),
  609. ],
  610. )
  611. @pytest.mark.parametrize("is_varnode", [True, False])
  612. def test_repeat(shape, repeats, axis, is_varnode):
  613. if is_varnode:
  614. network = Network()
  615. else:
  616. network = None
  617. def repeat_func(inp):
  618. return F.repeat(inp=inp, repeats=repeats, axis=axis)
  619. if shape != ():
  620. cases = [
  621. {"input": np.random.randn(*shape).astype("float32")},
  622. ]
  623. else:
  624. cases = [{"input": np.array(1.23)}]
  625. opr_test(
  626. cases,
  627. repeat_func,
  628. ref_fn=lambda inp: np.repeat(inp, repeats, axis),
  629. network=network,
  630. )
  631. @pytest.mark.parametrize(
  632. "shape, reps",
  633. [
  634. ((2,), (2,)),
  635. ((2, 3, 4, 5), (1, 1, 1, 1)),
  636. ((2, 3, 4, 5), (1, 2, 3, 4)),
  637. ((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
  638. ],
  639. )
  640. @pytest.mark.parametrize("is_varnode", [True])
  641. def test_tile(shape, reps, is_varnode):
  642. if is_varnode:
  643. network = Network()
  644. else:
  645. network = None
  646. def tile_func(inp):
  647. return F.tile(inp=inp, reps=reps)
  648. cases = [{"input": np.random.randn(*shape).astype("float32")}]
  649. opr_test(cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps), network=network)
  650. @pytest.mark.parametrize(
  651. "shape, shifts, axis",
  652. [
  653. ((2, 3), 0, None),
  654. ((2, 3), 1, 0),
  655. ((2, 3), 100, 0),
  656. ((2, 3), -100, 0),
  657. ((2, 3, 4, 5), (-1, 1), (0, 1)),
  658. ((2, 3, 4, 5), (-2, 1, 2), (1, 2, 3)),
  659. ],
  660. )
  661. @pytest.mark.parametrize("is_varnode", [True, False])
  662. def test_roll(shape, shifts, axis, is_varnode):
  663. if is_varnode:
  664. network = Network()
  665. else:
  666. network = None
  667. inp = np.random.randn(*shape).astype("float32")
  668. def func(inp):
  669. return F.roll(inp, shifts, axis)
  670. cases = [
  671. {"input": inp},
  672. ]
  673. opr_test(
  674. cases, func, ref_fn=lambda inp: np.roll(inp, shifts, axis), network=network
  675. )
  676. @pytest.mark.parametrize(
  677. "shape, shifts, axis", [((10, 0), 5, 1), ((10, 0), -10, 1),],
  678. )
  679. @pytest.mark.parametrize("is_symbolic", [None, True, False])
  680. def test_roll_empty_tensor(shape, shifts, axis, is_symbolic):
  681. inp = tensor(np.random.randn(*shape).astype("float32"))
  682. def func(inp):
  683. return F.roll(inp, shifts, axis)
  684. if is_symbolic is not None:
  685. func = trace(symbolic=is_symbolic)(func)
  686. out_ref = np.roll(inp.numpy(), shifts, axis)
  687. for _ in range(3):
  688. out = F.roll(inp, shifts, axis)
  689. np.testing.assert_equal(out.numpy(), out_ref)
  690. if is_symbolic is None:
  691. break

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。