You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

test_tensor.py 26 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import os
  10. import platform
  11. import numpy as np
  12. import pytest
  13. from utils import get_var_value, make_tensor, opr_test
  14. import megengine.functional as F
  15. from megengine import tensor
  16. from megengine.core._trace_option import use_symbolic_shape
  17. from megengine.core.tensor import megbrain_graph as G
  18. from megengine.core.tensor.utils import astensor1d
  19. from megengine.jit import trace
  20. from megengine.utils.network import Network, set_symbolic_shape
  21. from megengine.utils.network_node import VarNode
  def test_eye():
      """Compare ``F.eye`` with ``np.eye`` for float32 and boolean dtypes.

      Each shape is handed to ``F.eye`` three ways: as a list, unpacked into
      positional arguments, and wrapped in a tensor.
      """
      # ``np.bool`` was only a deprecated alias of the builtin ``bool`` and was
      # removed in NumPy 1.24; ``np.bool_`` is the same dtype and always exists.
      dtypes = [np.float32, np.bool_]
      cases = [{"input": [10, 20]}, {"input": [30]}]
      for dtype in dtypes:
          for case in cases:
              shape = case["input"]
              expected = np.eye(*shape).astype(dtype)
              np.testing.assert_allclose(F.eye(shape, dtype=dtype).numpy(), expected)
              np.testing.assert_allclose(F.eye(*shape, dtype=dtype).numpy(), expected)
              np.testing.assert_allclose(
                  F.eye(tensor(shape), dtype=dtype).numpy(), expected
              )
  @pytest.mark.parametrize("is_varnode", [False, True])
  def test_diag(is_varnode):
      """Compare ``F.diag`` with ``np.diag`` on square, wide, tall and 1-D inputs,
      sweeping the diagonal offset ``k`` over -2..2."""
      network = Network() if is_varnode else None
      shapes = [(10, 10), (6, 9), (8, 7), (8,)]
      cases = [
          {"input": [np.random.random(shp).astype("float32")]} for shp in shapes
      ]
      for k in range(-2, 3):
          # Both callables close over the current ``k``; opr_test invokes them
          # before the loop rebinds it.
          def diag_fn(data):
              return F.diag(data, k=k)

          opr_test(cases, diag_fn, ref_fn=lambda x: np.diag(x, k), network=network)
  def test_full():
      """``F.full`` must match ``np.full`` and take its dtype from the fill value."""
      shape = (2, 3)
      for fill in (True, 4, 5.0):
          result = F.full(shape, fill)
          np.testing.assert_allclose(result.numpy(), np.full(shape, fill))
          assert result.dtype == tensor(fill).dtype
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_concat(is_varnode):
      """Concatenate pairs of (n, 2, 3) float32 arrays along axis 0 and compare
      with ``np.concatenate``."""
      network = Network() if is_varnode else None

      def make_data(length: int):
          return np.random.random((length, 2, 3)).astype("float32")

      data1, data2, data3 = (make_data(n) for n in (5, 6, 7))

      def concat_pair(first, second):
          return F.concat([first, second])

      cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]
      opr_test(
          cases,
          concat_pair,
          ref_fn=lambda x, y: np.concatenate([x, y]),
          network=network,
      )
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_condtake(is_varnode):
      """``F.cond_take(mask, x)`` returns the selected values of ``x`` and their
      flat indices."""
      network = Network() if is_varnode else None
      x = np.array([[1, 2, 3], [4, 5, 6]]).astype("float32")
      mask = np.array([[True, False, True], [False, True, True]])
      xx = make_tensor(x, network)
      yy = make_tensor(mask, network)
      val, idx = F.cond_take(yy, xx)
      # VarNode results are read through the helper; eager tensors use .numpy().
      to_numpy = get_var_value if is_varnode else (lambda t: t.numpy())
      np.testing.assert_equal(to_numpy(val), x[mask])
      np.testing.assert_equal(to_numpy(idx), np.flatnonzero(mask))
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_concat_device(is_varnode):
      """Inputs on cpu0 and cpu1 concatenated with ``device="cpu0"`` produce a
      result on cpu0."""
      network = Network() if is_varnode else None
      first = make_tensor(
          np.random.random((3, 2, 2)).astype("float32"), network, "cpu0"
      )
      second = make_tensor(
          np.random.random((2, 2, 2)).astype("float32"), network, "cpu1"
      )
      out = F.concat([first, second], device="cpu0")
      device_name, *_ = str(out.device).split(":")
      assert device_name == "cpu0"
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_stack(is_varnode):
      """Stack two (3, 2, 2) arrays along each new axis position 0..2 and compare
      with ``np.stack``."""
      network = Network() if is_varnode else None
      data1, data2, data3 = (
          np.random.random((3, 2, 2)).astype("float32") for _ in range(3)
      )
      cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]
      for stack_axis in range(3):
          # Closures read ``stack_axis`` while opr_test runs in this iteration.
          def stack_pair(first, second):
              return F.stack([first, second], axis=stack_axis)

          opr_test(
              cases,
              stack_pair,
              ref_fn=lambda x, y: np.stack([x, y], axis=stack_axis),
              network=network,
          )
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_split_basic(is_varnode):
      """Split a (2, 3, 4, 5) tensor along axis 3 and check the error paths.

      ``F.split(inp, 2, axis=3)`` and ``F.split(inp, [3], axis=3)`` must both
      yield the pieces ``[..., :3]`` and ``[..., 3:]``. ``F.split(inp, 4)`` and
      the non-increasing split points ``[3, 2, 5]`` must raise ``ValueError``.
      """
      if is_varnode:
          network = Network()
          saved_symbolic_shape = set_symbolic_shape(False)
      else:
          network = None
      # Restore the symbolic-shape setting changed above even when an assertion
      # fails, so it cannot leak into later tests.
      try:
          data = np.random.random((2, 3, 4, 5)).astype(np.float32)
          inp = make_tensor(data, network)
          mge_out0 = F.split(inp, 2, axis=3)
          mge_out1 = F.split(inp, [3], axis=3)
          np_out = np.split(data, [3, 5], axis=3)
          assert len(mge_out0) == 2
          assert len(mge_out1) == 2
          for i in range(2):
              np.testing.assert_equal(mge_out0[i].numpy(), np_out[i])
              np.testing.assert_equal(mge_out1[i].numpy(), np_out[i])

          with pytest.raises(ValueError):
              F.split(inp, 4)

          with pytest.raises(ValueError) as excinfo:
              F.split(inp, [3, 2, 5], axis=3)
          # Exact message (including its spelling) as produced by F.split.
          assert str(excinfo.value) == "Invalid nsplits_or_secions: [3, 2, 5]"
      finally:
          if is_varnode:
              set_symbolic_shape(saved_symbolic_shape)
  @pytest.mark.parametrize("symbolic", [None, False, True])
  def test_split(symbolic):
      """``F.split`` matches ``np.split`` for section counts and split-point lists,
      on a regular and a zero-length-leading-axis input.

      ``symbolic=None`` runs eagerly once; otherwise the function is wrapped in
      ``trace(symbolic=symbolic)`` and invoked three times per case.
      """
      inp1 = np.random.random((3, 4, 5, 6)).astype(np.float32)
      inp2 = np.random.random((0, 4, 5, 6)).astype(np.float32)

      def ref(inp, nsplits_or_sections, axis):
          return np.split(inp, nsplits_or_sections, axis)

      def func(inp, nsplits_or_sections, axis):
          return F.split(inp, nsplits_or_sections, axis)

      cases = [
          (inp1, 2, 3),
          (inp1, [3], 3),
          (inp1, [3, 3, 5], 3),
          (inp2, 2, 3),
          (inp2, [3], 3),
          (inp2, [3, 3, 5], 3),
      ]
      n_runs = 1 if symbolic is None else 3
      for case in cases:
          inp, nsplits_or_sections, axis = case
          fn = func if symbolic is None else trace(symbolic=symbolic)(func)
          # The numpy reference does not change between runs; compute it once.
          ref_out = ref(*case)
          for _ in range(n_runs):
              out = fn(tensor(inp), nsplits_or_sections, axis)
              assert len(ref_out) == len(out)
              for expected, actual in zip(ref_out, out):
                  np.testing.assert_equal(expected, actual.numpy())
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_reshape(is_varnode):
      """``F.reshape`` accepts the target shape as a tuple (with ``-1`` given as an
      int or a tensor), as an int32 numpy array, or as a tensor."""
      network = Network() if is_varnode else None
      x = np.arange(6, dtype="float32")
      xx = make_tensor(x, network)
      expected = x.reshape(1, 2, 3)
      target_shapes = (
          (1, 2, 3),
          (1, -1, 3),
          (1, make_tensor(-1, network), 3),
          np.array([1, -1, 3], dtype="int32"),
          make_tensor([1, -1, 3], network),
      )
      for target in target_shapes:
          np.testing.assert_equal(F.reshape(xx, target).numpy(), expected)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_broadcast_auto_infer(is_varnode):
      """``None`` entries in the ``F.broadcast_to`` target shape are filled from the
      input; ``(1, -1, 3)`` and ``(None, 1, 2, 3)`` must raise ``ValueError``."""
      network = Network() if is_varnode else None
      x = np.random.random((1, 2, 3)).astype(np.float32)
      xx = make_tensor(x, network)
      for shape in ((1, 2, 3), (1, None, 3)):
          np.testing.assert_equal(F.broadcast_to(xx, shape).numpy(), x)
      for bad_shape in ((1, -1, 3), (None, 1, 2, 3)):
          with pytest.raises(ValueError):
              F.broadcast_to(xx, bad_shape)
      # Rank-increasing targets with an inferred dim must be accepted, including
      # one whose leading dim is a tensor; the results are not inspected.
      F.broadcast_to(xx, (1, None, 2, 3))
      t = tensor(2, dtype=np.int32)
      F.broadcast_to(xx, (t, None, 2, 3))
  @pytest.mark.parametrize("is_trace", [True, False])
  def test_reshape_on_empty_tensor(is_trace):
      """Reshaping zero-size tensors yields the requested shape, eagerly and under
      both non-symbolic and symbolic tracing."""
      cases = [
          [tensor(np.random.random((100, 0, 1)).astype(np.float32)), (100, 0, 10)],
          [tensor(np.random.random((10, 0)).astype(np.float32)), (0,)],
          [tensor(np.random.random((10, 0, 10)).astype(np.float32)), (0, 1, 2, 3)],
      ]

      def func(x, shp):
          return F.reshape(x, shp)

      def check(fn, inp, target_shp):
          out = fn(inp, target_shp)
          assert out._tuple_shape == target_shp

      if not is_trace:
          for inp, target_shp in cases:
              check(func, inp, target_shp)
          return

      for symbolic in (False, True):
          for inp, target_shp in cases:
              # A fresh trace per case, each invoked three times.
              traced = trace(symbolic=symbolic)(func)
              for _ in range(3):
                  check(traced, inp, target_shp)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_reshape_shape_inference(is_varnode):
      """Reshape inputs whose shape is "known" or "unknown" to targets that are
      "known", "known with -1" or "unknown", and check the result is (2, 2).

      The *_unknown values are built from broadcasts/tensors so that, presumably,
      their shape or value is not statically inferable (TODO confirm).
      """
      if is_varnode:
          network = Network()
          saved_symbolic_shape = set_symbolic_shape(False)
      else:
          network = None
      # Restore the symbolic-shape setting changed above even when an assertion
      # fails, so it cannot leak into later tests.
      try:
          x_shape_known = make_tensor([1, 2, 3, 4], network)
          x_shape_unknown = F.broadcast_to(
              make_tensor([1.0], network),
              shape=make_tensor([1, 1, 1, 1], network).sum(),
          )
          tshp_unknown = astensor1d(
              (make_tensor([2], network), make_tensor([2], network)), x_shape_known
          )
          tshp_known = astensor1d((2, 2), x_shape_known)
          tshp_known_unspec = astensor1d((2, -1), x_shape_known)

          def check_shape(output, target):
              source = output.shape
              if isinstance(source, tensor):
                  source = source.numpy()
              np.testing.assert_equal(source, target)

          def func(x, target_shape):
              return x.reshape(target_shape)

          cases = [
              {"input": [x_shape_known, tshp_unknown], "output": [(2, 2)]},
              {"input": [x_shape_unknown, tshp_unknown], "output": [(2, 2)]},
              {"input": [x_shape_known, tshp_known], "output": [(2, 2)]},
              {"input": [x_shape_known, tshp_known_unspec], "output": [(2, 2)]},
              {"input": [x_shape_unknown, tshp_known], "output": [(2, 2)]},
              {"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2)]},
          ]
          opr_test(
              cases, func, compare_fn=check_shape, test_trace=True, network=network
          )
      finally:
          if is_varnode:
              set_symbolic_shape(saved_symbolic_shape)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_squeeze(is_varnode):
      """``F.squeeze`` matches ``np.squeeze`` for ``axis=None``, a positive axis, a
      negative axis, and a tuple of axes."""
      if is_varnode:
          network = Network()
          saved_symbolic_shape = set_symbolic_shape(False)
      else:
          network = None
      # Restore the symbolic-shape setting changed above even when an assertion
      # fails, so it cannot leak into later tests.
      try:
          x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1)
          xx = make_tensor(x, network)
          for axis in [None, 3, -4, (3, -4)]:
              y = np.squeeze(x, axis)
              yy = F.squeeze(xx, axis)
              np.testing.assert_equal(y, yy.numpy())
      finally:
          if is_varnode:
              set_symbolic_shape(saved_symbolic_shape)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_expand_dims(is_varnode):
      """``F.expand_dims`` matches ``np.expand_dims`` on a (2, 3) input for single
      and tuple axes, positive and negative."""
      network = Network() if is_varnode else None
      x = np.arange(6, dtype="float32").reshape(2, 3)
      xx = make_tensor(x, network)
      for axis in (2, -3, (3, -4), (1, -4)):
          expected = np.expand_dims(x, axis)
          np.testing.assert_equal(expected, F.expand_dims(xx, axis).numpy())
  def test_expand_dims_for_scalar():
      """Expand a 0-d tensor: in-range axes match numpy; out-of-range axes make
      numpy raise ``AxisError`` and ``F.expand_dims`` raise ``RuntimeError``."""
      scalar = np.array(1, dtype="float32")
      scalar_t = make_tensor(scalar, None)
      valid_axes = [0, -1, (0, 1), (-1, -2), (0, -1)]
      invalid_axes = [1, -2, (1, 2), (-2, -3)]
      for axis in valid_axes:
          expected = np.expand_dims(scalar, axis)
          np.testing.assert_equal(expected, F.expand_dims(scalar_t, axis).numpy())
      for axis in invalid_axes:
          np.testing.assert_raises(np.AxisError, np.expand_dims, scalar, axis)
          np.testing.assert_raises(RuntimeError, F.expand_dims, scalar_t, axis)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_elemwise_dtype_promotion(is_varnode):
      """Mixed float32/float16 arithmetic (tensor*tensor, tensor+ndarray,
      ndarray-tensor) gives the same values as numpy."""
      network = Network() if is_varnode else None
      x = np.random.rand(2, 3).astype("float32")
      y = np.random.rand(1, 3).astype("float16")
      xx = make_tensor(x, network)
      yy = make_tensor(y, network)
      np.testing.assert_equal((xx * yy).numpy(), x * y)
      np.testing.assert_equal((xx + y).numpy(), x + y)
      np.testing.assert_equal((x - yy).numpy(), x - y)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_linspace(is_varnode):
      """``F.linspace`` matches float32 ``np.linspace`` for increasing and
      decreasing ranges, and when some arguments are tensors."""
      network = Network() if is_varnode else None

      def np_linspace(start, end, step):
          return np.linspace(start, end, step, dtype=np.float32)

      for cases in (
          [{"input": [1, 9, 9]}, {"input": [3, 10, 8]}],
          [{"input": [9, 1, 9]}, {"input": [10, 3, 8]}],
      ):
          opr_test(cases, F.linspace, ref_fn=np_linspace, network=network)

      # Tensor-valued arguments; both cases describe linspace(1, 9, 9).
      cases = [
          {"input": [1, make_tensor(9, network), 9]},
          {"input": [make_tensor(1, network), 9, make_tensor(9, network)]},
      ]
      opr_test(
          cases,
          F.linspace,
          ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32),
          network=network,
      )
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_arange(is_varnode):
      """``F.arange`` matches float32 ``np.arange`` for positive, negative and
      fractional steps."""
      network = Network() if is_varnode else None

      def np_arange(start, end, step):
          return np.arange(start, end, step, dtype=np.float32)

      case_groups = (
          [{"input": [1, 9, 1]}, {"input": [2, 10, 2]}],
          [{"input": [9, 1, -1]}, {"input": [10, 2, -2]}],
          [{"input": [9.3, 1.2, -0.5]}, {"input": [10.3, 2.1, -1.7]}],
      )
      for cases in case_groups:
          opr_test(cases, F.arange, ref_fn=np_arange, network=network)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_round(is_varnode):
      """``F.round`` matches ``np.round`` on random 1-D float32 inputs."""
      network = Network() if is_varnode else None
      cases = [
          {"input": np.random.random(shape).astype(np.float32)}
          for shape in ((15,), (25,))
      ]
      opr_test(cases, F.round, ref_fn=np.round, network=network)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_flatten(is_varnode):
      """Check the full output shape of ``F.flatten`` for several
      ``start_axis``/``end_axis`` combinations on two 4-D inputs."""
      network = Network() if is_varnode else None
      data0 = np.random.random((2, 3, 4, 5)).astype(np.float32)
      data1 = np.random.random((4, 5, 6, 7)).astype(np.float32)

      def compare_fn(x, y):
          assert x._tuple_shape == y

      # (extra F.flatten kwargs, expected shape for data0, expected for data1)
      configs = [
          ({}, (2 * 3 * 4 * 5,), (4 * 5 * 6 * 7,)),
          ({"start_axis": 1}, (2, 3 * 4 * 5), (4, 5 * 6 * 7)),
          ({"start_axis": 2}, (2, 3, 4 * 5), (4, 5, 6 * 7)),
          ({"start_axis": 1, "end_axis": 2}, (2, 3 * 4, 5), (4, 5 * 6, 7)),
      ]
      for kwargs, output0, output1 in configs:
          # Each expected shape is wrapped in a one-element list, as in
          # test_reshape_shape_inference, so compare_fn receives the whole tuple.
          # opr_test appears to treat a bare tuple as one expected value per
          # output, which previously meant only the first dimension was checked.
          cases = [
              {"input": data0, "output": [output0]},
              {"input": data1, "output": [output1]},
          ]
          opr_test(
              cases, F.flatten, compare_fn=compare_fn, network=network, **kwargs
          )
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_broadcast(is_varnode):
      """Check the full output shape of ``F.broadcast_to``, and that incompatible
      targets for a (2, 1, 3) tensor raise ``RuntimeError``."""
      network = Network() if is_varnode else None
      input1_shape = (20, 30)
      output1_shape = (30, 20, 30)
      data1 = np.random.random(input1_shape).astype(np.float32)
      input2_shape = (10, 1)
      output2_shape = (20, 10, 20)
      data2 = np.random.random(input2_shape).astype(np.float32)
      input3_shape = (10, 10)
      output3_shape = (10, 10)
      data3 = np.random.random(input3_shape).astype(np.float32)

      def compare_fn(x, y):
          assert x._tuple_shape == y

      # Expected shapes are wrapped in a one-element list, as in
      # test_reshape_shape_inference, so compare_fn receives the whole tuple
      # instead of (apparently) only its first entry.
      cases = [
          {"input": [data1, output1_shape], "output": [output1_shape]},
          {"input": [data2, output2_shape], "output": [output2_shape]},
          {"input": [data3, output3_shape], "output": [output3_shape]},
      ]
      opr_test(cases, F.broadcast_to, compare_fn=compare_fn, network=network)

      x = F.ones((2, 1, 3))
      for bad_shape in ((2, 3, 4), (4, 1, 3), (1, 3)):
          with pytest.raises(RuntimeError):
              F.broadcast_to(x, bad_shape)
  @pytest.mark.parametrize("is_trace", [True, False])
  def test_broadcast_on_empty_tensor(is_trace):
      """Broadcasting zero-size tensors (including to a higher rank) yields the
      requested shape, eagerly and under non-symbolic and symbolic tracing."""
      cases = [
          [tensor(np.random.random((100, 0, 1)).astype(np.float32)), (100, 0, 10)],
          [tensor(np.random.random((10, 0)).astype(np.float32)), (10, 10, 0)],
          [
              tensor(np.random.random((0, 0, 1, 10)).astype(np.float32)),
              (10, 0, 0, 10, 10),
          ],
      ]

      def func(x, shp):
          return F.broadcast_to(x, shp)

      def check(fn, inp, target_shp):
          out = fn(inp, target_shp)
          assert out._tuple_shape == target_shp

      if not is_trace:
          for inp, target_shp in cases:
              check(func, inp, target_shp)
          return

      for symbolic in (False, True):
          for inp, target_shp in cases:
              # A fresh trace per case, each invoked three times.
              traced = trace(symbolic=symbolic)(func)
              for _ in range(3):
                  check(traced, inp, target_shp)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_utils_astensor1d(is_varnode):
      """``astensor1d`` converts a list, an int32 array, a tensor, and a mixed list
      into an object of the same type as ``reference``, with and without a dtype."""
      network = Network() if is_varnode else None
      reference = make_tensor(0, network)

      def check(x, expected_for):
          # expected_for(dtype) gives the values the converted result must hold.
          for dtype in (None, "float32"):
              xx = astensor1d(x, reference, dtype=dtype)
              assert isinstance(xx, type(reference))
              np.testing.assert_equal(xx.numpy(), expected_for(dtype))

      # literal
      literal = [1, 2, 3]
      check(literal, lambda dtype: literal)
      # numpy array
      arr = np.asarray([1, 2, 3], dtype="int32")
      check(arr, lambda dtype: arr.astype(dtype) if dtype else arr)
      # tensor
      t = make_tensor([1, 2, 3], network)
      check(t, lambda dtype: t.numpy())
      # mixed python ints and tensors
      check([1, make_tensor(2, network), 3], lambda dtype: [1, 2, 3])
  def test_device():
      """Omitting ``device``, passing ``None``, ``"xpux"`` or a tensor's device all
      give the same ``F.eye`` / ``F.full`` values."""
      x = tensor([1, 2, 3], dtype="float32")

      def eye_values(**device_kw):
          return F.eye(x.shape, dtype="float32", **device_kw).numpy()

      np.testing.assert_almost_equal(eye_values(), eye_values(device=None))
      np.testing.assert_almost_equal(
          eye_values(device="xpux"), eye_values(device=x.device)
      )
      np.testing.assert_almost_equal(
          F.full((3, 2), 4, device=x.device).numpy(),
          F.full((3, 2), 4, device="xpux").numpy(),
      )
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_identity(is_varnode):
      """``F.copy`` without a target device returns values equal to its input."""
      network = Network() if is_varnode else None
      data = np.random.random((5, 10)).astype(np.float32)
      x = make_tensor(data, network)
      y = F.copy(x)
      # Compare against the source ndarray rather than handing the Tensor or
      # VarNode ``x`` to numpy, which would rely on an implicit array conversion.
      np.testing.assert_equal(y.numpy(), data)
  def copy_test(dst, src, network):
      """Create random (2, 3) data on ``src``, copy it to ``dst`` with ``F.copy``
      (and, for eager tensors, ``Tensor.to``) and check the values are kept."""
      data = np.random.random((2, 3)).astype(np.float32)
      x = make_tensor(data, device=src, network=network)
      copied = F.copy(x, dst)
      assert np.allclose(data, copied.numpy())
      # Tensor.to is only exercised when no network (graph) is in use.
      if network is not None:
          return
      moved = x.to(dst)
      assert np.allclose(data, moved.numpy())
  @pytest.mark.require_ngpu(1)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_copy_h2d(is_varnode):
      """Host-to-device copy: data created on cpu0 is copied to gpu0."""
      network = Network() if is_varnode else None
      # copy_test's signature is (dst, src); keywords make the direction explicit
      # so it matches the test name.
      copy_test(dst="gpu0", src="cpu0", network=network)


  @pytest.mark.require_ngpu(1)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_copy_d2h(is_varnode):
      """Device-to-host copy: data created on gpu0 is copied to cpu0."""
      network = Network() if is_varnode else None
      copy_test(dst="cpu0", src="gpu0", network=network)
  @pytest.mark.require_ngpu(2)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_copy_d2d(is_varnode):
      """Device-to-device copies: gpu1 -> gpu0, and gpu0:1 -> gpu0:0."""
      network = Network() if is_varnode else None
      for dst, src in (("gpu0", "gpu1"), ("gpu0:0", "gpu0:1")):
          copy_test(dst, src, network=network)
  @pytest.mark.require_ngpu(2)
  @pytest.mark.parametrize(
      "shape, device_src, device_dst",
      [
          ((0,), "cpu0", "cpu0"),
          ((10, 0), "cpu0", "cpu1"),
          ((2, 0, 3), "cpu0", "gpu0"),
          ((1, 0, 1, 0), "gpu0", "cpu0"),
          ((2, 3, 4, 5, 0), "gpu0", "gpu1"),
      ],
  )
  @pytest.mark.parametrize("is_symbolic", [None, True, False])
  def test_copy_empty(shape, device_src, device_dst, is_symbolic):
      """Copying a zero-size tensor keeps its shape and places it on the target
      device, eagerly (once) and under tracing (three invocations)."""
      inp = tensor(np.random.randn(*shape).astype("float32"), device=device_src)

      def func(inp):
          return F.copy(inp, device_dst)

      if is_symbolic is None:
          n_runs = 1
      else:
          func = trace(symbolic=is_symbolic)(func)
          n_runs = 3
      for _ in range(n_runs):
          out = func(inp)
          assert out.numpy().shape == shape
          assert out.device == device_dst
  @pytest.mark.parametrize(
      "shape, repeats, axis",
      [
          ((2,), 2, 0),
          ((2, 3, 4, 5), 3, 0),
          ((2, 3, 4, 5), 4, 3),
          ((2,), 2, None),
          ((2, 3, 4, 5), 3, None),
          ((), 1, None),
          ((), 10, None),
      ],
  )
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_repeat(shape, repeats, axis, is_varnode):
      """``F.repeat`` matches ``np.repeat`` along an axis, with ``axis=None``, and
      on 0-d inputs."""
      network = Network() if is_varnode else None

      def repeat_func(inp):
          return F.repeat(inp=inp, repeats=repeats, axis=axis)

      if shape == ():
          data = np.array(1.23)
      else:
          data = np.random.randn(*shape).astype("float32")
      opr_test(
          [{"input": data}],
          repeat_func,
          ref_fn=lambda inp: np.repeat(inp, repeats, axis),
          network=network,
      )
  @pytest.mark.parametrize(
      "shape, reps",
      [
          ((2,), (2,)),
          ((2, 3, 4, 5), (1, 1, 1, 1)),
          ((2, 3, 4, 5), (1, 2, 3, 4)),
          ((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
      ],
  )
  @pytest.mark.parametrize("is_varnode", [True])
  def test_tile(shape, reps, is_varnode):
      """``F.tile`` matches ``np.tile``, including ``reps`` longer than the input
      rank. Parametrized with ``is_varnode=True`` only, so only the Network path
      runs."""
      network = Network() if is_varnode else None
      cases = [{"input": np.random.randn(*shape).astype("float32")}]

      def tile_func(inp):
          return F.tile(inp=inp, reps=reps)

      opr_test(
          cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps), network=network
      )
  @pytest.mark.parametrize(
      "shape, shifts, axis",
      [
          ((2, 3), 0, None),
          ((2, 3), 1, 0),
          ((2, 3), 100, 0),
          ((2, 3), -100, 0),
          ((2, 3, 4, 5), (-1, 1), (0, 1)),
          ((2, 3, 4, 5), (-2, 1, 2), (1, 2, 3)),
      ],
  )
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_roll(shape, shifts, axis, is_varnode):
      """``F.roll`` matches ``np.roll`` for ``axis=None``, shifts larger than the
      axis length, negative shifts, and multi-axis rolls."""
      network = Network() if is_varnode else None
      data = np.random.randn(*shape).astype("float32")

      def roll_func(inp):
          return F.roll(inp, shifts, axis)

      opr_test(
          [{"input": data}],
          roll_func,
          ref_fn=lambda inp: np.roll(inp, shifts, axis),
          network=network,
      )
  @pytest.mark.parametrize(
      "shape, shifts, axis", [((10, 0), 5, 1), ((10, 0), -10, 1),],
  )
  @pytest.mark.parametrize("is_symbolic", [None, True, False])
  def test_roll_empty_tensor(shape, shifts, axis, is_symbolic):
      """Rolling a zero-size tensor matches ``np.roll``, eagerly (once) and under
      tracing (three invocations)."""
      inp = tensor(np.random.randn(*shape).astype("float32"))

      def func(inp):
          return F.roll(inp, shifts, axis)

      if is_symbolic is not None:
          func = trace(symbolic=is_symbolic)(func)
      out_ref = np.roll(inp.numpy(), shifts, axis)
      for _ in range(3):
          # Call ``func`` (traced when is_symbolic is set); calling F.roll
          # directly here would leave the traced path untested.
          out = func(inp)
          np.testing.assert_equal(out.numpy(), out_ref)
          if is_symbolic is None:
              break