You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_tensor.py 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import os
  10. import platform
  11. import numpy as np
  12. import pytest
  13. from utils import make_tensor, opr_test
  14. import megengine.functional as F
  15. from megengine import tensor
  16. from megengine.core._trace_option import use_symbolic_shape
  17. from megengine.core.tensor import megbrain_graph as G
  18. from megengine.core.tensor.utils import astensor1d
  19. from megengine.distributed.helper import get_device_count_by_fork
  20. from megengine.utils.network import Network, set_symbolic_shape
  21. from megengine.utils.network_node import VarNode
  22. def test_eye():
  23. dtype = np.float32
  24. cases = [{"input": [10, 20]}, {"input": [30]}]
  25. for case in cases:
  26. np.testing.assert_allclose(
  27. F.eye(case["input"], dtype=dtype).numpy(),
  28. np.eye(*case["input"]).astype(dtype),
  29. )
  30. np.testing.assert_allclose(
  31. F.eye(*case["input"], dtype=dtype).numpy(),
  32. np.eye(*case["input"]).astype(dtype),
  33. )
  34. np.testing.assert_allclose(
  35. F.eye(tensor(case["input"]), dtype=dtype).numpy(),
  36. np.eye(*case["input"]).astype(dtype),
  37. )
  38. @pytest.mark.parametrize("is_varnode", [True, False])
  39. def test_concat(is_varnode):
  40. if is_varnode:
  41. network = Network()
  42. else:
  43. network = None
  44. def get_data_shape(length: int):
  45. return (length, 2, 3)
  46. data1 = np.random.random(get_data_shape(5)).astype("float32")
  47. data2 = np.random.random(get_data_shape(6)).astype("float32")
  48. data3 = np.random.random(get_data_shape(7)).astype("float32")
  49. def run(data1, data2):
  50. return F.concat([data1, data2])
  51. cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]
  52. opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network)
  53. @pytest.mark.parametrize("is_varnode", [True, False])
  54. def test_condtake(is_varnode):
  55. if is_varnode:
  56. network = Network()
  57. else:
  58. network = None
  59. x = np.array([[1, 2, 3], [4, 5, 6]]).astype("float32")
  60. y = np.array([[True, False, True], [False, True, True]])
  61. xx = make_tensor(x, network)
  62. yy = make_tensor(y, network)
  63. val, idx = F.cond_take(yy, xx)
  64. np.testing.assert_equal(val.numpy(), x[y])
  65. np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])
  66. @pytest.mark.parametrize("is_varnode", [True, False])
  67. def test_concat_device(is_varnode):
  68. if is_varnode:
  69. network = Network()
  70. else:
  71. network = None
  72. data1 = make_tensor(np.random.random((3, 2, 2)).astype("float32"), network, "cpu0")
  73. data2 = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu1")
  74. out = F.concat([data1, data2], device="cpu0")
  75. assert str(out.device).split(":")[0] == "cpu0"
  76. @pytest.mark.parametrize("is_varnode", [True, False])
  77. def test_stack(is_varnode):
  78. if is_varnode:
  79. network = Network()
  80. else:
  81. network = None
  82. data1 = np.random.random((3, 2, 2)).astype("float32")
  83. data2 = np.random.random((3, 2, 2)).astype("float32")
  84. data3 = np.random.random((3, 2, 2)).astype("float32")
  85. cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]
  86. for ai in range(3):
  87. def run(data1, data2):
  88. return F.stack([data1, data2], axis=ai)
  89. opr_test(
  90. cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai), network=network
  91. )
  92. @pytest.mark.parametrize("is_varnode", [True, False])
  93. def test_split(is_varnode):
  94. if is_varnode:
  95. network = Network()
  96. saved_symbolic_shape = set_symbolic_shape(False)
  97. else:
  98. network = None
  99. data = np.random.random((2, 3, 4, 5)).astype(np.float32)
  100. inp = make_tensor(data, network)
  101. mge_out0 = F.split(inp, 2, axis=3)
  102. mge_out1 = F.split(inp, [3], axis=3)
  103. np_out = np.split(data, [3, 5], axis=3)
  104. assert len(mge_out0) == 2
  105. assert len(mge_out1) == 2
  106. np.testing.assert_equal(mge_out0[0].numpy(), np_out[0])
  107. np.testing.assert_equal(mge_out1[0].numpy(), np_out[0])
  108. np.testing.assert_equal(mge_out0[1].numpy(), np_out[1])
  109. np.testing.assert_equal(mge_out1[1].numpy(), np_out[1])
  110. try:
  111. F.split(inp, 4)
  112. assert False
  113. except ValueError as e:
  114. pass
  115. try:
  116. F.split(inp, [3, 3, 5], axis=3)
  117. assert False
  118. except ValueError as e:
  119. assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]"
  120. if is_varnode:
  121. set_symbolic_shape(saved_symbolic_shape)
  122. @pytest.mark.parametrize("is_varnode", [True, False])
  123. def test_reshape(is_varnode):
  124. if is_varnode:
  125. network = Network()
  126. else:
  127. network = None
  128. x = np.arange(6, dtype="float32")
  129. xx = make_tensor(x, network)
  130. y = x.reshape(1, 2, 3)
  131. for shape in [
  132. (1, 2, 3),
  133. (1, -1, 3),
  134. (1, make_tensor(-1, network), 3),
  135. np.array([1, -1, 3], dtype="int32"),
  136. make_tensor([1, -1, 3], network),
  137. ]:
  138. yy = F.reshape(xx, shape)
  139. np.testing.assert_equal(yy.numpy(), y)
  140. @pytest.mark.parametrize("is_varnode", [True, False])
  141. def test_reshape_shape_inference(is_varnode):
  142. if is_varnode:
  143. network = Network()
  144. saved_symbolic_shape = set_symbolic_shape(False)
  145. else:
  146. network = None
  147. x_shape_known = make_tensor([1, 2, 3, 4], network)
  148. x_shape_unknown = F.broadcast_to(
  149. make_tensor([1.0], network), shape=make_tensor([1, 1, 1, 1], network).sum()
  150. )
  151. tshp_unknown = astensor1d(
  152. (make_tensor([2], network), make_tensor([2], network)), x_shape_known
  153. )
  154. tshp_known = astensor1d((2, 2), x_shape_known)
  155. tshp_known_unspec = astensor1d((2, -1), x_shape_known)
  156. def check_shape(output, target):
  157. source = output.shape
  158. if isinstance(source, tensor):
  159. source = source.numpy()
  160. np.testing.assert_equal(source, target)
  161. def func(x, target_shape):
  162. return x.reshape(target_shape)
  163. cases = [
  164. {"input": [x_shape_known, tshp_unknown], "output": [(2, 2),]},
  165. {"input": [x_shape_unknown, tshp_unknown], "output": [(2, 2),]},
  166. {"input": [x_shape_known, tshp_known], "output": [(2, 2),]},
  167. {"input": [x_shape_known, tshp_known_unspec], "output": [(2, 2),]},
  168. {"input": [x_shape_unknown, tshp_known], "output": [(2, 2),]},
  169. {"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]},
  170. ]
  171. opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network)
  172. if is_varnode:
  173. set_symbolic_shape(saved_symbolic_shape)
  174. @pytest.mark.parametrize("is_varnode", [True, False])
  175. def test_squeeze(is_varnode):
  176. if is_varnode:
  177. network = Network()
  178. saved_symbolic_shape = set_symbolic_shape(False)
  179. else:
  180. network = None
  181. x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1)
  182. xx = make_tensor(x, network)
  183. for axis in [None, 3, -4, (3, -4)]:
  184. y = np.squeeze(x, axis)
  185. yy = F.squeeze(xx, axis)
  186. np.testing.assert_equal(y, yy.numpy())
  187. if is_varnode:
  188. set_symbolic_shape(saved_symbolic_shape)
  189. @pytest.mark.parametrize("is_varnode", [True, False])
  190. def test_expand_dims(is_varnode):
  191. if is_varnode:
  192. network = Network()
  193. else:
  194. network = None
  195. x = np.arange(6, dtype="float32").reshape(2, 3)
  196. xx = make_tensor(x, network)
  197. for axis in [2, -3, (3, -4), (1, -4)]:
  198. y = np.expand_dims(x, axis)
  199. yy = F.expand_dims(xx, axis)
  200. np.testing.assert_equal(y, yy.numpy())
  201. @pytest.mark.parametrize("is_varnode", [True, False])
  202. def test_elemwise_dtype_promotion(is_varnode):
  203. if is_varnode:
  204. network = Network()
  205. else:
  206. network = None
  207. x = np.random.rand(2, 3).astype("float32")
  208. y = np.random.rand(1, 3).astype("float16")
  209. xx = make_tensor(x, network)
  210. yy = make_tensor(y, network)
  211. z = xx * yy
  212. np.testing.assert_equal(z.numpy(), x * y)
  213. z = xx + y
  214. np.testing.assert_equal(z.numpy(), x + y)
  215. z = x - yy
  216. np.testing.assert_equal(z.numpy(), x - y)
  217. @pytest.mark.parametrize("is_varnode", [True, False])
  218. def test_linspace(is_varnode):
  219. if is_varnode:
  220. network = Network()
  221. else:
  222. network = None
  223. cases = [
  224. {"input": [1, 9, 9]},
  225. {"input": [3, 10, 8]},
  226. ]
  227. opr_test(
  228. cases,
  229. F.linspace,
  230. ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
  231. network=network,
  232. )
  233. cases = [
  234. {"input": [9, 1, 9]},
  235. {"input": [10, 3, 8]},
  236. ]
  237. opr_test(
  238. cases,
  239. F.linspace,
  240. ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
  241. network=network,
  242. )
  243. cases = [
  244. {"input": [1, make_tensor(9, network), 9]},
  245. {"input": [make_tensor(1, network), 9, make_tensor(9, network)]},
  246. ]
  247. opr_test(
  248. cases,
  249. F.linspace,
  250. ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32),
  251. network=network,
  252. )
  253. @pytest.mark.parametrize("is_varnode", [True, False])
  254. def test_arange(is_varnode):
  255. if is_varnode:
  256. network = Network()
  257. else:
  258. network = None
  259. cases = [
  260. {"input": [1, 9, 1]},
  261. {"input": [2, 10, 2]},
  262. ]
  263. opr_test(
  264. cases,
  265. F.arange,
  266. ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
  267. network=network,
  268. )
  269. cases = [
  270. {"input": [9, 1, -1]},
  271. {"input": [10, 2, -2]},
  272. ]
  273. opr_test(
  274. cases,
  275. F.arange,
  276. ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
  277. network=network,
  278. )
  279. cases = [
  280. {"input": [9.3, 1.2, -0.5]},
  281. {"input": [10.3, 2.1, -1.7]},
  282. ]
  283. opr_test(
  284. cases,
  285. F.arange,
  286. ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
  287. network=network,
  288. )
  289. @pytest.mark.parametrize("is_varnode", [True, False])
  290. def test_round(is_varnode):
  291. if is_varnode:
  292. network = Network()
  293. else:
  294. network = None
  295. data1_shape = (15,)
  296. data2_shape = (25,)
  297. data1 = np.random.random(data1_shape).astype(np.float32)
  298. data2 = np.random.random(data2_shape).astype(np.float32)
  299. cases = [{"input": data1}, {"input": data2}]
  300. opr_test(cases, F.round, ref_fn=np.round, network=network)
  301. @pytest.mark.parametrize("is_varnode", [True, False])
  302. def test_flatten(is_varnode):
  303. if is_varnode:
  304. network = Network()
  305. else:
  306. network = None
  307. data0_shape = (2, 3, 4, 5)
  308. data1_shape = (4, 5, 6, 7)
  309. data0 = np.random.random(data0_shape).astype(np.float32)
  310. data1 = np.random.random(data1_shape).astype(np.float32)
  311. def compare_fn(x, y):
  312. assert x._tuple_shape[0] == y
  313. output0 = (2 * 3 * 4 * 5,)
  314. output1 = (4 * 5 * 6 * 7,)
  315. cases = [
  316. {"input": data0, "output": output0},
  317. {"input": data1, "output": output1},
  318. ]
  319. opr_test(cases, F.flatten, compare_fn=compare_fn, network=network)
  320. output0 = (2, 3 * 4 * 5)
  321. output1 = (4, 5 * 6 * 7)
  322. cases = [
  323. {"input": data0, "output": output0},
  324. {"input": data1, "output": output1},
  325. ]
  326. opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, network=network)
  327. output0 = (2, 3, 4 * 5)
  328. output1 = (4, 5, 6 * 7)
  329. cases = [
  330. {"input": data0, "output": output0},
  331. {"input": data1, "output": output1},
  332. ]
  333. opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2, network=network)
  334. output0 = (2, 3 * 4, 5)
  335. output1 = (4, 5 * 6, 7)
  336. cases = [
  337. {"input": data0, "output": output0},
  338. {"input": data1, "output": output1},
  339. ]
  340. opr_test(
  341. cases,
  342. F.flatten,
  343. compare_fn=compare_fn,
  344. start_axis=1,
  345. end_axis=2,
  346. network=network,
  347. )
  348. @pytest.mark.parametrize("is_varnode", [True, False])
  349. def test_broadcast(is_varnode):
  350. if is_varnode:
  351. network = Network()
  352. else:
  353. network = None
  354. input1_shape = (20, 30)
  355. output1_shape = (30, 20, 30)
  356. data1 = np.random.random(input1_shape).astype(np.float32)
  357. input2_shape = (10, 1)
  358. output2_shape = (20, 10, 20)
  359. data2 = np.random.random(input2_shape).astype(np.float32)
  360. input3_shape = (10, 10)
  361. output3_shape = (10, 10)
  362. data3 = np.random.random(input3_shape).astype(np.float32)
  363. def compare_fn(x, y):
  364. assert x._tuple_shape[0] == y
  365. cases = [
  366. {"input": [data1, output1_shape], "output": output1_shape},
  367. {"input": [data2, output2_shape], "output": output2_shape},
  368. {"input": [data3, output3_shape], "output": output3_shape},
  369. ]
  370. opr_test(cases, F.broadcast_to, compare_fn=compare_fn, network=network)
  371. x = F.ones((2, 1, 3))
  372. with pytest.raises(RuntimeError):
  373. F.broadcast_to(x, (2, 3, 4))
  374. with pytest.raises(RuntimeError):
  375. F.broadcast_to(x, (4, 1, 3))
  376. with pytest.raises(RuntimeError):
  377. F.broadcast_to(x, (1, 3))
  378. @pytest.mark.parametrize("is_varnode", [True, False])
  379. def test_utils_astensor1d(is_varnode):
  380. if is_varnode:
  381. network = Network()
  382. else:
  383. network = None
  384. reference = make_tensor(0, network)
  385. # literal
  386. x = [1, 2, 3]
  387. for dtype in [None, "float32"]:
  388. xx = astensor1d(x, reference, dtype=dtype)
  389. assert isinstance(xx, type(reference))
  390. np.testing.assert_equal(xx.numpy(), x)
  391. # numpy array
  392. x = np.asarray([1, 2, 3], dtype="int32")
  393. for dtype in [None, "float32"]:
  394. xx = astensor1d(x, reference, dtype=dtype)
  395. assert isinstance(xx, type(reference))
  396. np.testing.assert_equal(xx.numpy(), x.astype(dtype) if dtype else x)
  397. # tensor
  398. x = make_tensor([1, 2, 3], network)
  399. for dtype in [None, "float32"]:
  400. xx = astensor1d(x, reference, dtype=dtype)
  401. assert isinstance(xx, type(reference))
  402. np.testing.assert_equal(xx.numpy(), x.numpy())
  403. # mixed
  404. x = [1, make_tensor(2, network), 3]
  405. for dtype in [None, "float32"]:
  406. xx = astensor1d(x, reference, dtype=dtype)
  407. assert isinstance(xx, type(reference))
  408. np.testing.assert_equal(xx.numpy(), [1, 2, 3])
  409. def test_device():
  410. x = tensor([1, 2, 3], dtype="float32")
  411. y1 = F.eye(x.shape, dtype="float32")
  412. y2 = F.eye(x.shape, dtype="float32", device=None)
  413. np.testing.assert_almost_equal(y1.numpy(), y2.numpy())
  414. y3 = F.eye(x.shape, dtype="float32", device="xpux")
  415. y4 = F.eye(x.shape, dtype="float32", device=x.device)
  416. np.testing.assert_almost_equal(y3.numpy(), y4.numpy())
  417. y5 = F.full((3, 2), 4, device=x.device)
  418. y6 = F.full((3, 2), 4, device="xpux")
  419. np.testing.assert_almost_equal(y5.numpy(), y6.numpy())
  420. @pytest.mark.parametrize("is_varnode", [True, False])
  421. def test_identity(is_varnode):
  422. if is_varnode:
  423. network = Network()
  424. else:
  425. network = None
  426. x = make_tensor(np.random.random((5, 10)).astype(np.float32), network)
  427. y = F.copy(x)
  428. np.testing.assert_equal(y.numpy(), x)
  429. def copy_test(dst, src, network):
  430. data = np.random.random((2, 3)).astype(np.float32)
  431. x = make_tensor(data, device=src, network=network)
  432. y = F.copy(x, dst)
  433. assert np.allclose(data, y.numpy())
  434. if network is None:
  435. z = x.to(dst)
  436. assert np.allclose(data, z.numpy())
  437. @pytest.mark.require_ngpu(1)
  438. @pytest.mark.parametrize("is_varnode", [True, False])
  439. def test_copy_h2d(is_varnode):
  440. if is_varnode:
  441. network = Network()
  442. else:
  443. network = None
  444. copy_test("cpu0", "gpu0", network=network)
  445. @pytest.mark.require_ngpu(1)
  446. @pytest.mark.parametrize("is_varnode", [True, False])
  447. def test_copy_d2h(is_varnode):
  448. if is_varnode:
  449. network = Network()
  450. else:
  451. network = None
  452. copy_test("gpu0", "cpu0", network=network)
  453. @pytest.mark.require_ngpu(2)
  454. @pytest.mark.parametrize("is_varnode", [True, False])
  455. def test_copy_d2d(is_varnode):
  456. if is_varnode:
  457. network = Network()
  458. else:
  459. network = None
  460. copy_test("gpu0", "gpu1", network=network)
  461. copy_test("gpu0:0", "gpu0:1", network=network)
  462. @pytest.mark.parametrize(
  463. "shape, repeats, axis",
  464. [
  465. ((2,), 2, 0),
  466. ((2, 3, 4, 5), 3, 0),
  467. ((2, 3, 4, 5), 4, 3),
  468. ((2,), 2, None),
  469. ((2, 3, 4, 5), 3, None),
  470. ((), 1, None),
  471. ((), 10, None),
  472. ],
  473. )
  474. @pytest.mark.parametrize("is_varnode", [True, False])
  475. def test_repeat(shape, repeats, axis, is_varnode):
  476. if is_varnode:
  477. network = Network()
  478. else:
  479. network = None
  480. def repeat_func(inp):
  481. return F.repeat(inp=inp, repeats=repeats, axis=axis)
  482. if shape != ():
  483. cases = [
  484. {"input": np.random.randn(*shape).astype("float32")},
  485. ]
  486. else:
  487. cases = [{"input": np.array(1.23)}]
  488. opr_test(
  489. cases,
  490. repeat_func,
  491. ref_fn=lambda inp: np.repeat(inp, repeats, axis),
  492. network=network,
  493. )
  494. @pytest.mark.parametrize(
  495. "shape, reps",
  496. [
  497. ((2,), (2,)),
  498. ((2, 3, 4, 5), (1, 1, 1, 1)),
  499. ((2, 3, 4, 5), (1, 2, 3, 4)),
  500. ((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
  501. ],
  502. )
  503. @pytest.mark.parametrize("is_varnode", [True])
  504. def test_tile(shape, reps, is_varnode):
  505. if is_varnode:
  506. network = Network()
  507. else:
  508. network = None
  509. def tile_func(inp):
  510. return F.tile(inp=inp, reps=reps)
  511. cases = [{"input": np.random.randn(*shape).astype("float32")}]
  512. opr_test(cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps), network=network)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台