You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_tensor.py 26 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920
  1. # -*- coding: utf-8 -*-
  2. import os
  3. import platform
  4. import numpy as np
  5. import pytest
  6. from utils import get_var_value, make_tensor, opr_test
  7. import megengine.functional as F
  8. from megengine import tensor
  9. from megengine.core._trace_option import use_symbolic_shape
  10. from megengine.core.tensor import megbrain_graph as G
  11. from megengine.core.tensor.utils import astensor1d
  12. from megengine.jit import trace
  13. from megengine.utils.network import Network, set_symbolic_shape
  14. from megengine.utils.network_node import VarNode
  def test_eye():
      """Check F.eye against np.eye for every accepted way of passing the shape.

      The shape is given as a list, as separate positional ints, and as a
      tensor; each form must match the NumPy reference for every dtype.
      """
      # ``np.bool`` was only an alias of the builtin ``bool``; it was deprecated
      # in NumPy 1.20 and removed in NumPy 1.24, so use ``bool`` directly.
      dtypes = [np.float32, bool]
      cases = [{"input": [10, 20]}, {"input": [30]}]
      for dtype in dtypes:
          for case in cases:
              shape = case["input"]
              expected = np.eye(*shape).astype(dtype)
              np.testing.assert_allclose(F.eye(shape, dtype=dtype).numpy(), expected)
              np.testing.assert_allclose(F.eye(*shape, dtype=dtype).numpy(), expected)
              np.testing.assert_allclose(
                  F.eye(tensor(shape), dtype=dtype).numpy(), expected
              )
  @pytest.mark.parametrize("is_varnode", [False, True])
  def test_diag(is_varnode):
      """Compare F.diag with np.diag for square, wide, tall and 1-d inputs.

      Every diagonal offset k in [-2, 2] is checked for every input shape.
      """
      network = Network() if is_varnode else None
      shapes = [(10, 10), (6, 9), (8, 7), (8,)]
      cases = [{"input": [np.random.random(shp).astype("float32")]} for shp in shapes]
      for k in range(-2, 3):
          # `run` and the reference lambda close over `k`; both are consumed by
          # opr_test within this iteration, so late binding is not an issue.
          def run(data):
              return F.diag(data, k=k)

          opr_test(cases, run, ref_fn=lambda x: np.diag(x, k), network=network)
  def test_full():
      """F.full must match np.full and pick the same dtype tensor() infers for the fill value."""
      shape = (2, 3)
      for fill in (True, 4, 5.0):
          np.testing.assert_allclose(F.full(shape, fill).numpy(), np.full(shape, fill))
          filled_dtype = F.full(shape, fill).dtype
          assert filled_dtype == tensor(fill).dtype
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_concat(is_varnode):
      """Concatenate pairs of arrays that differ only in their leading dimension."""
      network = Network() if is_varnode else None

      def random_block(rows):
          # Only the first axis varies between inputs; the rest is (2, 3).
          return np.random.random((rows, 2, 3)).astype("float32")

      first = random_block(5)
      second = random_block(6)
      third = random_block(7)

      def run(data1, data2):
          return F.concat([data1, data2])

      opr_test(
          [{"input": [first, second]}, {"input": [first, third]}],
          run,
          ref_fn=lambda x, y: np.concatenate([x, y]),
          network=network,
      )
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_condtake(is_varnode):
      """F.cond_take returns the masked values and their flattened indices."""
      network = Network() if is_varnode else None
      data = np.array([[1, 2, 3], [4, 5, 6]]).astype("float32")
      mask = np.array([[True, False, True], [False, True, True]])
      data_t = make_tensor(data, network)
      mask_t = make_tensor(mask, network)
      values, indices = F.cond_take(mask_t, data_t)
      # VarNode results are read through get_var_value; eager tensors via .numpy().
      if is_varnode:
          fetch = get_var_value
      else:
          fetch = lambda t: t.numpy()
      np.testing.assert_equal(fetch(values), data[mask])
      np.testing.assert_equal(fetch(indices), np.where(mask.reshape(-1))[0])
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_concat_device(is_varnode):
      """Inputs on cpu0 and cpu1 concatenated with device="cpu0" end up on cpu0."""
      network = Network() if is_varnode else None
      lhs = make_tensor(np.random.random((3, 2, 2)).astype("float32"), network, "cpu0")
      rhs = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu1")
      merged = F.concat([lhs, rhs], device="cpu0")
      # Compare only the part of the device string before any ':' suffix.
      device_name = str(merged.device).split(":")[0]
      assert device_name == "cpu0"
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_stack(is_varnode):
      """Stack pairs of (3, 2, 2) arrays along each new axis 0, 1 and 2."""
      network = Network() if is_varnode else None
      data1, data2, data3 = (
          np.random.random((3, 2, 2)).astype("float32") for _ in range(3)
      )
      cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]
      for axis in range(3):
          # Closures over `axis` are used immediately by opr_test.
          def run(data1, data2):
              return F.stack([data1, data2], axis=axis)

          opr_test(
              cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=axis), network=network
          )
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_split_basic(is_varnode):
      """Check F.split by section count and by split points, plus bad arguments.

      ``F.split(inp, 2, axis=3)`` and ``F.split(inp, [3], axis=3)`` must each
      yield two pieces equal to the first two pieces of
      ``np.split(data, [3, 5], axis=3)``. Invalid section arguments must raise
      ValueError.
      """
      if is_varnode:
          network = Network()
          saved_symbolic_shape = set_symbolic_shape(False)
      else:
          network = None
      # try/finally so the saved symbolic-shape setting is restored even when an
      # assertion fails; otherwise the change would leak into later tests.
      try:
          data = np.random.random((2, 3, 4, 5)).astype(np.float32)
          inp = make_tensor(data, network)
          mge_out0 = F.split(inp, 2, axis=3)
          mge_out1 = F.split(inp, [3], axis=3)
          np_out = np.split(data, [3, 5], axis=3)
          assert len(mge_out0) == 2
          assert len(mge_out1) == 2
          for i in range(2):
              np.testing.assert_equal(mge_out0[i].numpy(), np_out[i])
              np.testing.assert_equal(mge_out1[i].numpy(), np_out[i])

          with pytest.raises(ValueError):
              F.split(inp, 4)
          # The error message emitted by F.split is matched verbatim.
          with pytest.raises(ValueError) as excinfo:
              F.split(inp, [3, 2, 5], axis=3)
          assert str(excinfo.value) == "Invalid nsplits_or_secions: [3, 2, 5]"
      finally:
          if is_varnode:
              set_symbolic_shape(saved_symbolic_shape)
  @pytest.mark.parametrize("symbolic", [None, False, True])
  def test_split(symbolic):
      """F.split matches np.split eagerly and under trace, including zero-size inputs."""
      full = np.random.random((3, 4, 5, 6)).astype(np.float32)
      empty = np.random.random((0, 4, 5, 6)).astype(np.float32)

      def ref(inp, nsplits_or_sections, axis):
          return np.split(inp, nsplits_or_sections, axis)

      def func(inp, nsplits_or_sections, axis):
          return F.split(inp, nsplits_or_sections, axis)

      cases = [
          (arr, spec, 3)
          for arr in (full, empty)
          for spec in (2, [3], [3, 3, 5])
      ]
      # Traced functions are called three times per case (presumably so that
      # later calls run through the recorded trace); eager mode runs once.
      repeats = 1 if symbolic is None else 3
      for arr, spec, axis in cases:
          fn = func if symbolic is None else trace(symbolic=symbolic)(func)
          for _ in range(repeats):
              expected = ref(arr, spec, axis)
              actual = fn(tensor(arr), spec, axis)
              assert len(expected) == len(actual)
              for exp_part, act_part in zip(expected, actual):
                  np.testing.assert_equal(exp_part, act_part.numpy())
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_swapaxes(is_varnode):
      """Swapping axes 0 and 1 of a (1, 3) array gives the (3, 1) transpose."""
      network = Network() if is_varnode else None
      # Build the input through make_tensor so the is_varnode=True case really
      # exercises the VarNode path instead of silently repeating the eager test.
      x = make_tensor(np.array([[1, 2, 3]], dtype=np.int32), network)
      y = F.swapaxes(x, 0, 1)
      np.testing.assert_equal(y.numpy(), np.array([[1], [2], [3]]).astype(np.int32))
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_reshape(is_varnode):
      """Reshape a length-6 vector to (1, 2, 3) using several target-shape spellings.

      Covers plain tuples, -1 wildcards (also as a tensor element), NumPy
      arrays and a tensor holding the whole shape.
      """
      network = Network() if is_varnode else None
      flat = np.arange(6, dtype="float32")
      flat_t = make_tensor(flat, network)
      expected = flat.reshape(1, 2, 3)
      targets = (
          (1, 2, 3),
          (1, -1, 3),
          (1, make_tensor(-1, network), 3),
          np.array([1, -1, 3], dtype="int32"),
          make_tensor([1, -1, 3], network),
      )
      for target in targets:
          np.testing.assert_equal(F.reshape(flat_t, target).numpy(), expected)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_broadcast_auto_infer(is_varnode):
      """Check F.broadcast_to with None placeholders in the target shape.

      A None that lines up with an existing input axis is accepted; -1 and a
      None in an invalid position must raise ValueError.
      """
      network = Network() if is_varnode else None
      src = np.random.random((1, 2, 3)).astype(np.float32)
      src_t = make_tensor(src, network)
      for target in ((1, 2, 3), (1, None, 3)):
          np.testing.assert_equal(F.broadcast_to(src_t, target).numpy(), src)

      with pytest.raises(ValueError):
          F.broadcast_to(src_t, (1, -1, 3))
      # Presumably rejected because a leading None in a 4-d target has no
      # matching axis in the 3-d input.
      with pytest.raises(ValueError):
          F.broadcast_to(src_t, (None, 1, 2, 3))

      # These are expected to succeed (no assertion on the result).
      F.broadcast_to(src_t, (1, None, 2, 3))
      leading = make_tensor(2, network)
      F.broadcast_to(src_t, (leading, None, 2, 3))
  @pytest.mark.parametrize("is_trace", [True, False])
  def test_reshape_on_empty_tensor(is_trace):
      """Reshaping zero-size tensors yields exactly the requested shape."""
      shape_pairs = (
          ((100, 0, 1), (100, 0, 10)),
          ((10, 0), (0,)),
          ((10, 0, 10), (0, 1, 2, 3)),
      )
      cases = [
          (tensor(np.random.random(src).astype(np.float32)), dst)
          for src, dst in shape_pairs
      ]

      def func(x, shp):
          return F.reshape(x, shp)

      def check(fn, inp, target_shp):
          out = fn(inp, target_shp)
          assert out._tuple_shape == target_shp

      if not is_trace:
          for inp, target_shp in cases:
              check(func, inp, target_shp)
          return
      # Each freshly traced function is invoked three times per case.
      for symbolic in (False, True):
          for inp, target_shp in cases:
              traced = trace(symbolic=symbolic)(func)
              for _ in range(3):
                  check(traced, inp, target_shp)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_reshape_shape_inference(is_varnode):
      """Reshape to (2, 2) for every mix of known/unknown input and target shapes.

      The target is given fully known, with a -1 placeholder, or built from
      tensors; the input's shape is either known or derived at runtime.
      """
      if is_varnode:
          network = Network()
          saved_symbolic_shape = set_symbolic_shape(False)
      else:
          network = None
      # try/finally so the saved symbolic-shape setting is restored even when an
      # assertion fails; otherwise the change would leak into later tests.
      try:
          x_shape_known = make_tensor([1, 2, 3, 4], network)
          # The broadcast target comes from a runtime sum, presumably so the
          # result's shape is not statically known (hence the name).
          x_shape_unknown = F.broadcast_to(
              make_tensor([1.0], network), shape=make_tensor([1, 1, 1, 1], network).sum()
          )
          tshp_unknown = astensor1d(
              (make_tensor([2], network), make_tensor([2], network)), x_shape_known
          )
          tshp_known = astensor1d((2, 2), x_shape_known)
          tshp_known_unspec = astensor1d((2, -1), x_shape_known)

          def check_shape(output, target):
              source = output.shape
              if isinstance(source, tensor):
                  source = source.numpy()
              np.testing.assert_equal(source, target.shape)

          def func(x, target_shape):
              return x.reshape(target_shape)

          cases = [
              {"input": [x_shape_known, tshp_unknown], "output": [np.zeros((2, 2))]},
              {"input": [x_shape_unknown, tshp_unknown], "output": [np.zeros((2, 2))]},
              {"input": [x_shape_known, tshp_known], "output": [np.zeros((2, 2))]},
              {"input": [x_shape_known, tshp_known_unspec], "output": [np.zeros((2, 2))]},
              {"input": [x_shape_unknown, tshp_known], "output": [np.zeros((2, 2))]},
              {"input": [x_shape_unknown, tshp_known_unspec], "output": [np.zeros((2, 2))]},
          ]
          opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network)
      finally:
          if is_varnode:
              set_symbolic_shape(saved_symbolic_shape)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_squeeze(is_varnode):
      """Check F.squeeze against np.squeeze for default, single and tuple axes."""
      if is_varnode:
          network = Network()
          saved_symbolic_shape = set_symbolic_shape(False)
      else:
          network = None
      # try/finally so the saved symbolic-shape setting is restored even when an
      # assertion fails; otherwise the change would leak into later tests.
      try:
          x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1)
          xx = make_tensor(x, network)
          for axis in [None, 3, -4, (3, -4)]:
              y = np.squeeze(x, axis)
              yy = F.squeeze(xx, axis)
              np.testing.assert_equal(y, yy.numpy())
      finally:
          if is_varnode:
              set_symbolic_shape(saved_symbolic_shape)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_expand_dims(is_varnode):
      """F.expand_dims matches np.expand_dims for single and multiple (mixed-sign) axes."""
      network = Network() if is_varnode else None
      matrix = np.arange(6, dtype="float32").reshape(2, 3)
      matrix_t = make_tensor(matrix, network)
      for axis in (2, -3, (3, -4), (1, -4)):
          np.testing.assert_equal(
              np.expand_dims(matrix, axis), F.expand_dims(matrix_t, axis).numpy()
          )
  def test_expand_dims_for_scalar():
      """Expanding a 0-d tensor works for in-range axes and fails for out-of-range ones.

      For the invalid axes, NumPy must raise AxisError and F.expand_dims must
      raise RuntimeError.
      """
      # AxisError lives in ``np.exceptions`` since NumPy 1.25, and top-level
      # aliases of NumPy exception classes are being phased out; fall back to
      # the top-level name only on older NumPy that lacks ``np.exceptions``.
      axis_error = getattr(np, "exceptions", np).AxisError
      x = np.array(1, dtype="float32")
      xx = make_tensor(x, None)
      for axis in [0, -1, (0, 1), (-1, -2), (0, -1)]:
          y = np.expand_dims(x, axis)
          yy = F.expand_dims(xx, axis)
          np.testing.assert_equal(y, yy.numpy())
      for axis in [1, -2, (1, 2), (-2, -3)]:
          np.testing.assert_raises(axis_error, np.expand_dims, x, axis)
          np.testing.assert_raises(RuntimeError, F.expand_dims, xx, axis)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_elemwise_dtype_promotion(is_varnode):
      """Mixed float32/float16 elementwise ops match NumPy.

      Covers tensor*tensor, tensor+ndarray and ndarray-tensor operand pairs.
      """
      network = Network() if is_varnode else None
      a32 = np.random.rand(2, 3).astype("float32")
      b16 = np.random.rand(1, 3).astype("float16")
      a32_t = make_tensor(a32, network)
      b16_t = make_tensor(b16, network)
      np.testing.assert_equal((a32_t * b16_t).numpy(), a32 * b16)
      np.testing.assert_equal((a32_t + b16).numpy(), a32 + b16)
      np.testing.assert_equal((a32 - b16_t).numpy(), a32 - b16)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_linspace(is_varnode):
      """F.linspace matches float32 np.linspace for ascending, descending and tensor args."""
      network = Network() if is_varnode else None

      def ref(start, end, step):
          return np.linspace(start, end, step, dtype=np.float32)

      ascending = [{"input": [1, 9, 9]}, {"input": [3, 10, 8]}]
      descending = [{"input": [9, 1, 9]}, {"input": [10, 3, 8]}]
      for cases in (ascending, descending):
          opr_test(cases, F.linspace, ref_fn=ref, network=network)

      # Scalars mixed with tensors; both cases describe linspace(1, 9, 9).
      mixed = [
          {"input": [1, make_tensor(9, network), 9]},
          {"input": [make_tensor(1, network), 9, make_tensor(9, network)]},
      ]
      opr_test(
          mixed,
          F.linspace,
          ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32),
          network=network,
      )
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_arange(is_varnode):
      """F.arange matches float32 np.arange for positive, negative and fractional steps."""
      network = Network() if is_varnode else None

      def ref(start, end, step):
          return np.arange(start, end, step, dtype=np.float32)

      argument_groups = (
          ([1, 9, 1], [2, 10, 2]),
          ([9, 1, -1], [10, 2, -2]),
          ([9.3, 1.2, -0.5], [10.3, 2.1, -1.7]),
      )
      for group in argument_groups:
          cases = [{"input": args} for args in group]
          opr_test(cases, F.arange, ref_fn=ref, network=network)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_round(is_varnode):
      """F.round matches np.round on random float32 vectors of two lengths."""
      network = Network() if is_varnode else None
      cases = [
          {"input": np.random.random(shape).astype(np.float32)}
          for shape in ((15,), (25,))
      ]
      opr_test(cases, F.round, ref_fn=np.round, network=network)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_flatten(is_varnode):
      """F.flatten with various start/end axes equals the matching NumPy reshape."""
      network = Network() if is_varnode else None
      data0 = np.random.random((2, 3, 4, 5)).astype(np.float32)
      data1 = np.random.random((4, 5, 6, 7)).astype(np.float32)

      def build_cases(expected_of):
          # Pair each input with the expected output computed from it.
          return [{"input": d, "output": expected_of(d)} for d in (data0, data1)]

      opr_test(build_cases(lambda d: d.flatten()), F.flatten, network=network)
      opr_test(
          build_cases(lambda d: d.reshape(d.shape[0], -1)),
          F.flatten,
          start_axis=1,
          network=network,
      )
      opr_test(
          build_cases(lambda d: d.reshape(d.shape[0], d.shape[1], -1)),
          F.flatten,
          start_axis=2,
          network=network,
      )
      opr_test(
          build_cases(lambda d: d.reshape(d.shape[0], -1, d.shape[3])),
          F.flatten,
          start_axis=1,
          end_axis=2,
          network=network,
      )
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_broadcast(is_varnode):
      """F.broadcast_to matches np.broadcast_to and rejects incompatible shapes."""
      network = Network() if is_varnode else None
      shape_pairs = (
          ((20, 30), (30, 20, 30)),
          ((10, 1), (20, 10, 20)),
          ((10, 10), (10, 10)),
      )
      cases = []
      for in_shape, out_shape in shape_pairs:
          data = np.random.random(in_shape).astype(np.float32)
          cases.append(
              {"input": [data, out_shape], "output": np.broadcast_to(data, out_shape)}
          )
      opr_test(cases, F.broadcast_to, network=network)

      # Incompatible targets for a (2, 1, 3) input must raise RuntimeError.
      ones = F.ones((2, 1, 3))
      for bad_shape in ((2, 3, 4), (4, 1, 3), (1, 3)):
          with pytest.raises(RuntimeError):
              F.broadcast_to(ones, bad_shape)
  @pytest.mark.parametrize("is_trace", [True, False])
  def test_broadcast_on_empty_tensor(is_trace):
      """Broadcasting zero-size tensors yields exactly the requested shape."""
      shape_pairs = (
          ((100, 0, 1), (100, 0, 10)),
          ((10, 0), (10, 10, 0)),
          ((0, 0, 1, 10), (10, 0, 0, 10, 10)),
      )
      cases = [
          (tensor(np.random.random(src).astype(np.float32)), dst)
          for src, dst in shape_pairs
      ]

      def func(x, shp):
          return F.broadcast_to(x, shp)

      def check(fn, inp, target_shp):
          out = fn(inp, target_shp)
          assert out._tuple_shape == target_shp

      if not is_trace:
          for inp, target_shp in cases:
              check(func, inp, target_shp)
          return
      # Each freshly traced function is invoked three times per case.
      for symbolic in (False, True):
          for inp, target_shp in cases:
              traced = trace(symbolic=symbolic)(func)
              for _ in range(3):
                  check(traced, inp, target_shp)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_utils_astensor1d(is_varnode):
      """astensor1d returns the reference's type for literals, arrays, tensors and mixes."""
      network = Network() if is_varnode else None
      reference = make_tensor(0, network)
      dtypes = (None, "float32")

      def convert(value, dtype):
          # Convert and check the result has the same type as `reference`.
          converted = astensor1d(value, reference, dtype=dtype)
          assert isinstance(converted, type(reference))
          return converted

      # Python list literal.
      literal = [1, 2, 3]
      for dtype in dtypes:
          np.testing.assert_equal(convert(literal, dtype).numpy(), literal)

      # NumPy array, optionally cast to the requested dtype.
      array = np.asarray([1, 2, 3], dtype="int32")
      for dtype in dtypes:
          result = convert(array, dtype).numpy()
          np.testing.assert_equal(result, array.astype(dtype) if dtype else array)

      # Existing tensor.
      tensor_in = make_tensor([1, 2, 3], network)
      for dtype in dtypes:
          np.testing.assert_equal(convert(tensor_in, dtype).numpy(), tensor_in.numpy())

      # List mixing Python ints with a tensor element.
      mixed = [1, make_tensor(2, network), 3]
      for dtype in dtypes:
          np.testing.assert_equal(convert(mixed, dtype).numpy(), [1, 2, 3])

      # Output of cond_take; only exercised in the VarNode case.
      if is_varnode:
          a = np.array([[1, 2, 3], [4, 5, 6]]).astype("float32")
          b = np.array([[True, False, True], [False, True, True]])
          aa = make_tensor(a, network)
          bb = make_tensor(b, network)
          taken, _ = F.cond_take(bb, aa)
          for dtype in dtypes:
              np.testing.assert_equal(
                  get_var_value(convert(taken, dtype)), get_var_value(taken)
              )
  def test_device():
      """F.eye and F.full give identical values however the device is specified."""
      x = tensor([1, 2, 3], dtype="float32")

      def assert_same(lhs, rhs):
          np.testing.assert_almost_equal(lhs.numpy(), rhs.numpy())

      # Omitted vs. explicit None.
      eye_default = F.eye(x.shape, dtype="float32")
      eye_none = F.eye(x.shape, dtype="float32", device=None)
      assert_same(eye_default, eye_none)
      # "xpux" vs. the device of an existing tensor.
      eye_xpux = F.eye(x.shape, dtype="float32", device="xpux")
      eye_same = F.eye(x.shape, dtype="float32", device=x.device)
      assert_same(eye_xpux, eye_same)
      full_same = F.full((3, 2), 4, device=x.device)
      full_xpux = F.full((3, 2), 4, device="xpux")
      assert_same(full_same, full_xpux)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_identity(is_varnode):
      """F.copy without a target device returns the same values as its input."""
      network = Network() if is_varnode else None
      data = np.random.random((5, 10)).astype(np.float32)
      x = make_tensor(data, network)
      y = F.copy(x)
      # Compare against the NumPy source rather than `x` itself: `x` may be a
      # VarNode, and relying on NumPy to implicitly convert it is fragile.
      np.testing.assert_equal(y.numpy(), data)
  def copy_test(dst, src, network):
      """Copy random (2, 3) data from device `src` to device `dst`; values must survive.

      F.copy is always checked. Tensor.to is checked as well, but only for
      eager tensors (``network is None``).
      """
      data = np.random.random((2, 3)).astype(np.float32)
      source = make_tensor(data, device=src, network=network)
      assert np.allclose(data, F.copy(source, dst).numpy())
      if network is not None:
          return
      assert np.allclose(data, source.to(dst).numpy())
  @pytest.mark.require_ngpu(1)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_copy_h2d(is_varnode):
      """Copy host (cpu0) data to a device (gpu0)."""
      network = Network() if is_varnode else None
      # copy_test takes (dst, src): destination gpu0, source cpu0.
      copy_test("gpu0", "cpu0", network=network)
  @pytest.mark.require_ngpu(1)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_copy_d2h(is_varnode):
      """Copy device (gpu0) data back to the host (cpu0)."""
      network = Network() if is_varnode else None
      # copy_test takes (dst, src): destination cpu0, source gpu0.
      copy_test("cpu0", "gpu0", network=network)
  @pytest.mark.require_ngpu(2)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_copy_d2d(is_varnode):
      """Device-to-device copies: gpu1 -> gpu0, and "gpu0:1" -> "gpu0:0".

      The second pair presumably names two streams on the same GPU.
      """
      network = Network() if is_varnode else None
      for dst, src in (("gpu0", "gpu1"), ("gpu0:0", "gpu0:1")):
          copy_test(dst, src, network=network)
  @pytest.mark.require_ngpu(2)
  @pytest.mark.parametrize(
      "shape, device_src, device_dst",
      [
          ((0,), "cpu0", "cpu0"),
          ((10, 0), "cpu0", "cpu1"),
          ((2, 0, 3), "cpu0", "gpu0"),
          ((1, 0, 1, 0), "gpu0", "cpu0"),
          ((2, 3, 4, 5, 0), "gpu0", "gpu1"),
      ],
  )
  @pytest.mark.parametrize("is_symbolic", [None, True, False])
  def test_copy_empty(shape, device_src, device_dst, is_symbolic):
      """Copying zero-size tensors across devices keeps the shape and sets the device."""
      inp = tensor(np.random.randn(*shape).astype("float32"), device=device_src)

      def func(inp):
          return F.copy(inp, device_dst)

      # Eager mode runs once; traced modes run three times.
      runs = 1 if is_symbolic is None else 3
      if is_symbolic is not None:
          func = trace(symbolic=is_symbolic)(func)
      for _ in range(runs):
          out = func(inp)
          assert out.numpy().shape == shape
          assert out.device == device_dst
  @pytest.mark.parametrize(
      "shape, repeats, axis",
      [
          ((2,), 2, 0),
          ((2, 3, 4, 5), 3, 0),
          ((2, 3, 4, 5), 4, 3),
          ((2,), 2, None),
          ((2, 3, 4, 5), 3, None),
          ((), 1, None),
          ((), 10, None),
      ],
  )
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_repeat(shape, repeats, axis, is_varnode):
      """F.repeat matches np.repeat, including for 0-d inputs."""
      network = Network() if is_varnode else None

      def repeat_func(inp):
          return F.repeat(inp=inp, repeats=repeats, axis=axis)

      if shape == ():
          sample = np.array(1.23)
      else:
          sample = np.random.randn(*shape).astype("float32")
      opr_test(
          [{"input": sample}],
          repeat_func,
          ref_fn=lambda inp: np.repeat(inp, repeats, axis),
          network=network,
      )
  @pytest.mark.parametrize(
      "shape, reps",
      [
          ((2,), (2,)),
          ((2, 3, 4, 5), (1, 1, 1, 1)),
          ((2, 3, 4, 5), (1, 2, 3, 4)),
          # FIXME: tile does not support ndim 7
          # ((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
      ],
  )
  @pytest.mark.parametrize("is_varnode", [True])
  def test_tile(shape, reps, is_varnode):
      """F.tile matches np.tile (currently parametrized for VarNode inputs only)."""
      network = Network() if is_varnode else None

      def tile_func(inp):
          return F.tile(inp=inp, reps=reps)

      sample = np.random.randn(*shape).astype("float32")
      opr_test(
          [{"input": sample}],
          tile_func,
          ref_fn=lambda inp: np.tile(inp, reps),
          network=network,
      )
  @pytest.mark.parametrize(
      "shape, shifts, axis",
      [
          ((2, 3), 0, None),
          ((2, 3), 1, 0),
          ((2, 3), 100, 0),
          ((2, 3), -100, 0),
          ((2, 3, 4, 5), (-1, 1), (0, 1)),
          ((2, 3, 4, 5), (-2, 1, 2), (1, 2, 3)),
      ],
  )
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_roll(shape, shifts, axis, is_varnode):
      """F.roll matches np.roll for scalar shifts and per-axis tuple shifts."""
      network = Network() if is_varnode else None
      sample = np.random.randn(*shape).astype("float32")

      def func(inp):
          return F.roll(inp, shifts, axis)

      def ref(inp):
          return np.roll(inp, shifts, axis)

      opr_test([{"input": sample}], func, ref_fn=ref, network=network)
  @pytest.mark.parametrize(
      "shape, shifts, axis", [((10, 0), 5, 1), ((10, 0), -10, 1)],
  )
  @pytest.mark.parametrize("is_symbolic", [None, True, False])
  def test_roll_empty_tensor(shape, shifts, axis, is_symbolic):
      """Rolling a zero-size tensor matches np.roll, eagerly and under trace."""
      inp = tensor(np.random.randn(*shape).astype("float32"))

      def func(inp):
          return F.roll(inp, shifts, axis)

      if is_symbolic is not None:
          func = trace(symbolic=is_symbolic)(func)
      out_ref = np.roll(inp.numpy(), shifts, axis)
      for _ in range(3):
          # Call through `func` so the traced variant is actually exercised;
          # calling F.roll directly would bypass the trace entirely.
          out = func(inp)
          np.testing.assert_equal(out.numpy(), out_ref)
          if is_symbolic is None:
              break