You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_tensor.py 26 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908
  1. # -*- coding: utf-8 -*-
  2. import os
  3. import platform
  4. import numpy as np
  5. import pytest
  6. from utils import get_var_value, make_tensor, opr_test
  7. import megengine.functional as F
  8. from megengine import tensor
  9. from megengine.core._trace_option import use_symbolic_shape
  10. from megengine.core.tensor import megbrain_graph as G
  11. from megengine.core.tensor.utils import astensor1d
  12. from megengine.jit import trace
  13. from megengine.utils.network import Network, set_symbolic_shape
  14. from megengine.utils.network_node import VarNode
  def test_eye():
      """Check ``F.eye`` against ``np.eye`` for each supported call form.

      The shape is passed as a list, as separate positional arguments and as a
      Tensor, for both a float dtype and a boolean dtype.
      """
      # ``np.bool`` was only a deprecated alias of the builtin ``bool`` and was
      # removed in NumPy 1.24; the builtin has exactly the same meaning.
      dtypes = [np.float32, bool]
      cases = [{"input": [10, 20]}, {"input": [30]}]
      for dtype in dtypes:
          for case in cases:
              shape = case["input"]
              expected = np.eye(*shape).astype(dtype)
              np.testing.assert_allclose(F.eye(shape, dtype=dtype).numpy(), expected)
              np.testing.assert_allclose(F.eye(*shape, dtype=dtype).numpy(), expected)
              np.testing.assert_allclose(
                  F.eye(tensor(shape), dtype=dtype).numpy(), expected
              )
  @pytest.mark.parametrize("is_varnode", [False, True])
  def test_diag(is_varnode):
      """Compare ``F.diag`` with ``np.diag`` for square, wide, tall and 1-D inputs.

      Every input is checked at each diagonal offset ``k`` in ``[-2, 2]``.
      """
      network = Network() if is_varnode else None

      shapes = [(10, 10), (6, 9), (8, 7), (8,)]
      cases = [
          {"input": [np.random.random(shp).astype("float32")]} for shp in shapes
      ]

      for offset in range(-2, 3):
          # opr_test runs immediately, so closing over ``offset`` is safe here.
          def diag_at_offset(data):
              return F.diag(data, k=offset)

          opr_test(
              cases,
              diag_at_offset,
              ref_fn=lambda x: np.diag(x, offset),
              network=network,
          )
  def test_full():
      """``F.full`` matches ``np.full`` in value and ``tensor(value)`` in dtype."""
      shape = (2, 3)
      for fill in (True, 4, 5.0):
          filled = F.full(shape, fill)
          np.testing.assert_allclose(filled.numpy(), np.full(shape, fill))
          assert filled.dtype == tensor(fill).dtype
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_concat(is_varnode):
      """Concatenate two inputs of differing leading length; compare with numpy."""
      network = Network() if is_varnode else None

      first, second, third = (
          np.random.random((length, 2, 3)).astype("float32") for length in (5, 6, 7)
      )

      def run(data1, data2):
          return F.concat([data1, data2])

      opr_test(
          [{"input": [first, second]}, {"input": [first, third]}],
          run,
          ref_fn=lambda x, y: np.concatenate([x, y]),
          network=network,
      )
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_condtake(is_varnode):
      """``F.cond_take(mask, x)`` returns the selected values and their flat indices."""
      network = Network() if is_varnode else None

      values = np.array([[1, 2, 3], [4, 5, 6]]).astype("float32")
      mask = np.array([[True, False, True], [False, True, True]])
      values_t = make_tensor(values, network)
      mask_t = make_tensor(mask, network)

      taken, flat_idx = F.cond_take(mask_t, values_t)

      # Graph-mode outputs must be read through get_var_value.
      to_numpy = get_var_value if is_varnode else (lambda t: t.numpy())
      np.testing.assert_equal(to_numpy(taken), values[mask])
      np.testing.assert_equal(to_numpy(flat_idx), np.flatnonzero(mask))
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_concat_device(is_varnode):
      """Inputs placed on ``cpu0`` and ``cpu1`` are concatenated onto ``device="cpu0"``."""
      network = Network() if is_varnode else None
      left = make_tensor(np.random.random((3, 2, 2)).astype("float32"), network, "cpu0")
      right = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu1")
      joined = F.concat([left, right], device="cpu0")
      device_name, _, _ = str(joined.device).partition(":")
      assert device_name == "cpu0"
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_stack(is_varnode):
      """``F.stack`` along each axis 0..2 matches ``np.stack``."""
      network = Network() if is_varnode else None

      data1, data2, data3 = (
          np.random.random((3, 2, 2)).astype("float32") for _ in range(3)
      )
      cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]

      for stack_axis in range(3):
          # opr_test is invoked within the same iteration, so the closure is safe.
          def run(data1, data2):
              return F.stack([data1, data2], axis=stack_axis)

          opr_test(
              cases,
              run,
              ref_fn=lambda x, y: np.stack([x, y], axis=stack_axis),
              network=network,
          )
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_split_basic(is_varnode):
      """Check ``F.split`` by section count and by split indices, plus its errors.

      In varnode mode the global symbolic-shape setting is switched off for the
      duration of the test. It is restored in ``finally`` so that a failing
      assertion cannot leak the modified setting into later tests.
      """
      if is_varnode:
          network = Network()
          saved_symbolic_shape = set_symbolic_shape(False)
      else:
          network = None
      try:
          data = np.random.random((2, 3, 4, 5)).astype(np.float32)
          inp = make_tensor(data, network)
          mge_out0 = F.split(inp, 2, axis=3)
          mge_out1 = F.split(inp, [3], axis=3)
          # numpy yields three pieces here; only the first two are compared.
          np_out = np.split(data, [3, 5], axis=3)
          assert len(mge_out0) == 2
          assert len(mge_out1) == 2
          for mge_out in (mge_out0, mge_out1):
              np.testing.assert_equal(mge_out[0].numpy(), np_out[0])
              np.testing.assert_equal(mge_out[1].numpy(), np_out[1])

          with pytest.raises(ValueError):
              F.split(inp, 4)

          # [3, 2, 5] is rejected, presumably because the indices are not
          # increasing. The message is matched exactly, spelling included.
          with pytest.raises(ValueError) as exc_info:
              F.split(inp, [3, 2, 5], axis=3)
          assert str(exc_info.value) == "Invalid nsplits_or_secions: [3, 2, 5]"
      finally:
          if is_varnode:
              set_symbolic_shape(saved_symbolic_shape)
  @pytest.mark.parametrize("symbolic", [None, False, True])
  def test_split(symbolic):
      """``F.split`` matches ``np.split`` eagerly and under ``trace``.

      Both a regular input and a zero-length one are covered.
      """
      inp1 = np.random.random((3, 4, 5, 6)).astype(np.float32)
      inp2 = np.random.random((0, 4, 5, 6)).astype(np.float32)

      def ref(inp, nsplits_or_sections, axis):
          return np.split(inp, nsplits_or_sections, axis)

      def func(inp, nsplits_or_sections, axis):
          return F.split(inp, nsplits_or_sections, axis)

      cases = [
          (inp1, 2, 3),
          (inp1, [3], 3),
          (inp1, [3, 3, 5], 3),
          (inp2, 2, 3),
          (inp2, [3], 3),
          (inp2, [3, 3, 5], 3),
      ]
      # Traced variants are called repeatedly (presumably to exercise replay of
      # the recorded trace); the eager variant once.
      repeats = 1 if symbolic is None else 3
      for data, sections, axis in cases:
          fn = func if symbolic is None else trace(symbolic=symbolic)(func)
          for _ in range(repeats):
              expected = ref(data, sections, axis)
              actual = fn(tensor(data), sections, axis)
              assert len(expected) == len(actual)
              for expected_part, actual_part in zip(expected, actual):
                  np.testing.assert_equal(expected_part, actual_part.numpy())
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_reshape(is_varnode):
      """``F.reshape`` accepts several target-shape forms.

      Covered forms: a tuple, a tuple with -1, a tuple containing a tensor, an
      int32 ndarray and a tensor.
      """
      network = Network() if is_varnode else None

      flat = np.arange(6, dtype="float32")
      flat_t = make_tensor(flat, network)
      expected = flat.reshape(1, 2, 3)

      target_shapes = [
          (1, 2, 3),
          (1, -1, 3),
          (1, make_tensor(-1, network), 3),
          np.array([1, -1, 3], dtype="int32"),
          make_tensor([1, -1, 3], network),
      ]
      for target in target_shapes:
          np.testing.assert_equal(F.reshape(flat_t, target).numpy(), expected)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_broadcast_auto_infer(is_varnode):
      """``None`` entries in the target shape of ``F.broadcast_to`` are inferred.

      ``-1`` and a ``None`` in an unsupported position are rejected.
      """
      network = Network() if is_varnode else None

      x = np.random.random((1, 2, 3)).astype(np.float32)
      xx = make_tensor(x, network)

      # Both targets are expected to reproduce x unchanged.
      for target in ((1, 2, 3), (1, None, 3)):
          np.testing.assert_equal(F.broadcast_to(xx, target).numpy(), x)

      # NOTE(review): (None, 1, 2, 3) presumably fails because the leading None
      # has no input dimension to infer from — confirm.
      for bad_target in ((1, -1, 3), (None, 1, 2, 3)):
          with pytest.raises(ValueError):
              F.broadcast_to(xx, bad_target)

      # Smoke checks: these calls must simply not raise.
      F.broadcast_to(xx, (1, None, 2, 3))
      leading = make_tensor(2, network)
      F.broadcast_to(xx, (leading, None, 2, 3))
  @pytest.mark.parametrize("is_trace", [True, False])
  def test_reshape_on_empty_tensor(is_trace):
      """Reshaping zero-sized tensors yields the requested shape.

      This holds both eagerly and under trace (symbolic and non-symbolic).
      """

      def random_tensor(shape):
          return tensor(np.random.random(shape).astype(np.float32))

      cases = [
          (random_tensor((100, 0, 1)), (100, 0, 10)),
          (random_tensor((10, 0)), (0,)),
          (random_tensor((10, 0, 10)), (0, 1, 2, 3)),
      ]

      def func(x, shp):
          return F.reshape(x, shp)

      def run_and_check(fn, inp, target_shp):
          out = fn(inp, target_shp)
          assert out._tuple_shape == target_shp

      if not is_trace:
          for inp, target_shp in cases:
              run_and_check(func, inp, target_shp)
          return

      for symbolic in (False, True):
          for inp, target_shp in cases:
              traced = trace(symbolic=symbolic)(func)
              # The same traced function is invoked three times in a row.
              for _ in range(3):
                  run_and_check(traced, inp, target_shp)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_reshape_shape_inference(is_varnode):
      """Check ``reshape`` output shapes for known/unknown input and target shapes.

      Every combination of input shape (known, or produced by a runtime
      reduction) and target shape (tensor-built, constant, or constant with
      -1) must reshape to (2, 2).

      In varnode mode symbolic shape is disabled for the test and restored in
      ``finally``, so a failing assertion cannot leak the setting into later
      tests.
      """
      if is_varnode:
          network = Network()
          saved_symbolic_shape = set_symbolic_shape(False)
      else:
          network = None
      try:
          x_shape_known = make_tensor([1, 2, 3, 4], network)
          # Broadcast to a length given by ``sum`` of a tensor. Presumably this
          # makes the shape unknown until run time, per the variable name —
          # TODO confirm.
          x_shape_unknown = F.broadcast_to(
              make_tensor([1.0], network),
              shape=make_tensor([1, 1, 1, 1], network).sum(),
          )
          tshp_unknown = astensor1d(
              (make_tensor([2], network), make_tensor([2], network)), x_shape_known
          )
          tshp_known = astensor1d((2, 2), x_shape_known)
          tshp_known_unspec = astensor1d((2, -1), x_shape_known)

          def check_shape(output, target):
              source = output.shape
              # ``shape`` may come back as a tensor rather than a tuple.
              if isinstance(source, tensor):
                  source = source.numpy()
              np.testing.assert_equal(source, target.shape)

          def func(x, target_shape):
              return x.reshape(target_shape)

          combos = [
              (x_shape_known, tshp_unknown),
              (x_shape_unknown, tshp_unknown),
              (x_shape_known, tshp_known),
              (x_shape_known, tshp_known_unspec),
              (x_shape_unknown, tshp_known),
              (x_shape_unknown, tshp_known_unspec),
          ]
          cases = [
              {"input": [x, tshp], "output": [np.zeros((2, 2))]} for x, tshp in combos
          ]
          opr_test(
              cases, func, compare_fn=check_shape, test_trace=True, network=network
          )
      finally:
          if is_varnode:
              set_symbolic_shape(saved_symbolic_shape)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_squeeze(is_varnode):
      """``F.squeeze`` matches ``np.squeeze`` for None, single and tuple axes.

      In varnode mode symbolic shape is disabled for the test and restored in
      ``finally``, so a failing assertion cannot leak the setting into later
      tests.
      """
      if is_varnode:
          network = Network()
          saved_symbolic_shape = set_symbolic_shape(False)
      else:
          network = None
      try:
          x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1)
          xx = make_tensor(x, network)
          for axis in [None, 3, -4, (3, -4)]:
              y = np.squeeze(x, axis)
              yy = F.squeeze(xx, axis)
              np.testing.assert_equal(y, yy.numpy())
      finally:
          if is_varnode:
              set_symbolic_shape(saved_symbolic_shape)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_expand_dims(is_varnode):
      """``F.expand_dims`` matches ``np.expand_dims`` for single and tuple axes."""
      network = Network() if is_varnode else None
      base = np.arange(6, dtype="float32").reshape(2, 3)
      base_t = make_tensor(base, network)
      for axes in (2, -3, (3, -4), (1, -4)):
          np.testing.assert_equal(
              np.expand_dims(base, axes), F.expand_dims(base_t, axes).numpy()
          )
  def test_expand_dims_for_scalar():
      """``F.expand_dims`` on a 0-d input.

      Valid axes must match numpy; out-of-range axes must raise.
      """
      # ``AxisError`` lives in ``np.exceptions`` since NumPy 1.25, and the
      # top-level ``np.AxisError`` alias was removed in NumPy 2.0. Fall back to
      # the old location on NumPy versions without ``np.exceptions``.
      axis_error = getattr(np, "exceptions", np).AxisError
      x = np.array(1, dtype="float32")
      xx = make_tensor(x, None)
      for axis in [0, -1, (0, 1), (-1, -2), (0, -1)]:
          y = np.expand_dims(x, axis)
          yy = F.expand_dims(xx, axis)
          np.testing.assert_equal(y, yy.numpy())
      # Out-of-range axes: numpy raises AxisError; MegEngine is expected to
      # raise RuntimeError.
      for axis in [1, -2, (1, 2), (-2, -3)]:
          np.testing.assert_raises(axis_error, np.expand_dims, x, axis)
          np.testing.assert_raises(RuntimeError, F.expand_dims, xx, axis)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_elemwise_dtype_promotion(is_varnode):
      """Mixed float32/float16 arithmetic matches numpy.

      Tensors may appear on either side, or both sides, of the operator.
      """
      network = Network() if is_varnode else None
      x = np.random.rand(2, 3).astype("float32")
      y = np.random.rand(1, 3).astype("float16")
      xx = make_tensor(x, network)
      yy = make_tensor(y, network)

      # (megengine expression, numpy reference), evaluated lazily and in order.
      checks = [
          (lambda: xx * yy, lambda: x * y),  # tensor op tensor
          (lambda: xx + y, lambda: x + y),  # tensor op ndarray
          (lambda: x - yy, lambda: x - y),  # ndarray op tensor
      ]
      for mge_expr, np_expr in checks:
          np.testing.assert_equal(mge_expr().numpy(), np_expr())
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_linspace(is_varnode):
      """``F.linspace`` matches float32 ``np.linspace``.

      Covers increasing and decreasing ranges, and tensor-valued arguments.
      """
      network = Network() if is_varnode else None

      def np_linspace(start, end, step):
          return np.linspace(start, end, step, dtype=np.float32)

      increasing = [{"input": [1, 9, 9]}, {"input": [3, 10, 8]}]
      decreasing = [{"input": [9, 1, 9]}, {"input": [10, 3, 8]}]
      for cases in (increasing, decreasing):
          opr_test(cases, F.linspace, ref_fn=np_linspace, network=network)

      # Both cases describe linspace(1, 9, 9); the reference ignores its
      # (partly tensor-valued) arguments and uses those constants directly.
      tensor_cases = [
          {"input": [1, make_tensor(9, network), 9]},
          {"input": [make_tensor(1, network), 9, make_tensor(9, network)]},
      ]
      opr_test(
          tensor_cases,
          F.linspace,
          ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32),
          network=network,
      )
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_arange(is_varnode):
      """``F.arange`` matches float32 ``np.arange``.

      Covers positive, negative and fractional steps.
      """
      network = Network() if is_varnode else None

      def np_arange(start, end, step):
          return np.arange(start, end, step, dtype=np.float32)

      argument_groups = [
          [[1, 9, 1], [2, 10, 2]],  # positive integer steps
          [[9, 1, -1], [10, 2, -2]],  # negative integer steps
          [[9.3, 1.2, -0.5], [10.3, 2.1, -1.7]],  # negative fractional steps
      ]
      for group in argument_groups:
          opr_test(
              [{"input": args} for args in group],
              F.arange,
              ref_fn=np_arange,
              network=network,
          )
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_round(is_varnode):
      """``F.round`` matches ``np.round`` on random 1-D float32 inputs."""
      network = Network() if is_varnode else None
      cases = [
          {"input": np.random.random((length,)).astype(np.float32)}
          for length in (15, 25)
      ]
      opr_test(cases, F.round, ref_fn=np.round, network=network)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_flatten(is_varnode):
      """``F.flatten`` with default, ``start_axis`` and ``start_axis``/``end_axis`` ranges."""
      network = Network() if is_varnode else None

      data0 = np.random.random((2, 3, 4, 5)).astype(np.float32)
      data1 = np.random.random((4, 5, 6, 7)).astype(np.float32)

      # (extra keyword arguments for opr_test — presumably forwarded to
      #  F.flatten — and the matching numpy reshape of an input array)
      variants = [
          ({}, lambda d: d.flatten()),
          ({"start_axis": 1}, lambda d: d.reshape(d.shape[0], -1)),
          ({"start_axis": 2}, lambda d: d.reshape(d.shape[0], d.shape[1], -1)),
          (
              {"start_axis": 1, "end_axis": 2},
              lambda d: d.reshape(d.shape[0], -1, d.shape[3]),
          ),
      ]
      for extra_kwargs, expected_of in variants:
          cases = [
              {"input": arr, "output": expected_of(arr)} for arr in (data0, data1)
          ]
          opr_test(cases, F.flatten, network=network, **extra_kwargs)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_broadcast(is_varnode):
      """``F.broadcast_to`` matches ``np.broadcast_to``; incompatible targets raise."""
      network = Network() if is_varnode else None

      data1 = np.random.random((20, 30)).astype(np.float32)
      data2 = np.random.random((10, 1)).astype(np.float32)
      data3 = np.random.random((10, 10)).astype(np.float32)
      targets = [
          (data1, (30, 20, 30)),
          (data2, (20, 10, 20)),
          (data3, (10, 10)),
      ]
      cases = [
          {"input": [arr, shp], "output": np.broadcast_to(arr, shp)}
          for arr, shp in targets
      ]
      opr_test(cases, F.broadcast_to, network=network)

      ones = F.ones((2, 1, 3))
      for bad_shape in ((2, 3, 4), (4, 1, 3), (1, 3)):
          with pytest.raises(RuntimeError):
              F.broadcast_to(ones, bad_shape)
  @pytest.mark.parametrize("is_trace", [True, False])
  def test_broadcast_on_empty_tensor(is_trace):
      """Broadcasting zero-sized tensors yields the requested shape.

      This holds both eagerly and under trace (symbolic and non-symbolic).
      """

      def random_tensor(shape):
          return tensor(np.random.random(shape).astype(np.float32))

      cases = [
          (random_tensor((100, 0, 1)), (100, 0, 10)),
          (random_tensor((10, 0)), (10, 10, 0)),
          (random_tensor((0, 0, 1, 10)), (10, 0, 0, 10, 10)),
      ]

      def func(x, shp):
          return F.broadcast_to(x, shp)

      def run_and_check(fn, inp, target_shp):
          out = fn(inp, target_shp)
          assert out._tuple_shape == target_shp

      if is_trace:
          for symbolic in (False, True):
              for inp, target_shp in cases:
                  traced = trace(symbolic=symbolic)(func)
                  # The same traced function is invoked three times in a row.
                  for _ in range(3):
                      run_and_check(traced, inp, target_shp)
      else:
          for inp, target_shp in cases:
              run_and_check(func, inp, target_shp)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_utils_astensor1d(is_varnode):
      """``astensor1d`` converts various inputs to the type of ``reference``.

      Covers literals, ndarrays, tensors, mixed lists and (in varnode mode) a
      graph output, with and without a ``dtype`` cast.
      """
      network = Network() if is_varnode else None
      reference = make_tensor(0, network)

      def check(source, expected, to_numpy=lambda t: t.numpy()):
          # ``expected`` receives the dtype so casts can be reflected in it.
          for dtype in [None, "float32"]:
              converted = astensor1d(source, reference, dtype=dtype)
              assert isinstance(converted, type(reference))
              np.testing.assert_equal(to_numpy(converted), expected(dtype))

      # literal
      literal = [1, 2, 3]
      check(literal, lambda dtype: literal)

      # numpy array
      arr = np.asarray([1, 2, 3], dtype="int32")
      check(arr, lambda dtype: arr.astype(dtype) if dtype else arr)

      # tensor
      as_tensor = make_tensor([1, 2, 3], network)
      check(as_tensor, lambda dtype: as_tensor.numpy())

      # mixed python numbers and tensors
      check([1, make_tensor(2, network), 3], lambda dtype: [1, 2, 3])

      # varnode: the values output of cond_take, only in varnode mode
      if is_varnode:
          a = np.array([[1, 2, 3], [4, 5, 6]]).astype("float32")
          b = np.array([[True, False, True], [False, True, True]])
          aa = make_tensor(a, network)
          bb = make_tensor(b, network)
          taken, _ = F.cond_take(bb, aa)
          check(taken, lambda dtype: get_var_value(taken), to_numpy=get_var_value)
  def test_device():
      """``F.eye``/``F.full`` give the same values however ``device`` is specified."""
      x = tensor([1, 2, 3], dtype="float32")

      # Pairs of keyword sets whose eye() results must agree.
      eye_pairs = [
          ({}, {"device": None}),
          ({"device": "xpux"}, {"device": x.device}),
      ]
      for lhs_kwargs, rhs_kwargs in eye_pairs:
          lhs = F.eye(x.shape, dtype="float32", **lhs_kwargs)
          rhs = F.eye(x.shape, dtype="float32", **rhs_kwargs)
          np.testing.assert_almost_equal(lhs.numpy(), rhs.numpy())

      full_on_x = F.full((3, 2), 4, device=x.device)
      full_on_xpux = F.full((3, 2), 4, device="xpux")
      np.testing.assert_almost_equal(full_on_x.numpy(), full_on_xpux.numpy())
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_identity(is_varnode):
      """``F.copy`` without a target device returns a copy equal to its input."""
      network = Network() if is_varnode else None
      data = np.random.random((5, 10)).astype(np.float32)
      x = make_tensor(data, network)
      y = F.copy(x)
      # Compare against the source ndarray rather than ``x`` itself, so the
      # check does not rely on implicit array conversion of a tensor/VarNode.
      np.testing.assert_equal(y.numpy(), data)
  def copy_test(dst, src, network):
      """Copy random data from device ``src`` to ``dst`` and check the values.

      ``F.copy`` is always checked. ``Tensor.to`` is additionally checked when
      no ``network`` is given (eager mode).
      """
      payload = np.random.random((2, 3)).astype(np.float32)
      source = make_tensor(payload, device=src, network=network)
      assert np.allclose(payload, F.copy(source, dst).numpy())
      if network is not None:
          return
      assert np.allclose(payload, source.to(dst).numpy())
  @pytest.mark.require_ngpu(1)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_copy_h2d(is_varnode):
      """Host-to-device copy: data starts on ``cpu0`` and is copied to ``gpu0``."""
      network = Network() if is_varnode else None
      # copy_test's positional order is (dst, src); keywords make the
      # host -> device direction explicit and match the test name.
      copy_test(dst="gpu0", src="cpu0", network=network)
  @pytest.mark.require_ngpu(1)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_copy_d2h(is_varnode):
      """Device-to-host copy: data starts on ``gpu0`` and is copied to ``cpu0``."""
      network = Network() if is_varnode else None
      # copy_test's positional order is (dst, src); keywords make the
      # device -> host direction explicit and match the test name.
      copy_test(dst="cpu0", src="gpu0", network=network)
  @pytest.mark.require_ngpu(2)
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_copy_d2d(is_varnode):
      """Device-to-device copies between two GPUs and between two ``gpu0`` locators."""
      network = Network() if is_varnode else None
      for dst, src in (("gpu0", "gpu1"), ("gpu0:0", "gpu0:1")):
          copy_test(dst, src, network=network)
  @pytest.mark.require_ngpu(2)
  @pytest.mark.parametrize(
      "shape, device_src, device_dst",
      [
          ((0,), "cpu0", "cpu0"),
          ((10, 0), "cpu0", "cpu1"),
          ((2, 0, 3), "cpu0", "gpu0"),
          ((1, 0, 1, 0), "gpu0", "cpu0"),
          ((2, 3, 4, 5, 0), "gpu0", "gpu1"),
      ],
  )
  @pytest.mark.parametrize("is_symbolic", [None, True, False])
  def test_copy_empty(shape, device_src, device_dst, is_symbolic):
      """Copying zero-sized tensors across devices keeps the shape.

      The result lands on the requested device, eagerly and under trace.
      """
      inp = tensor(np.random.randn(*shape).astype("float32"), device=device_src)

      def func(inp):
          return F.copy(inp, device_dst)

      if is_symbolic is not None:
          func = trace(symbolic=is_symbolic)(func)

      # Eager mode runs once; traced modes run three times.
      runs = 1 if is_symbolic is None else 3
      for _ in range(runs):
          out = func(inp)
          assert out.numpy().shape == shape
          assert out.device == device_dst
  @pytest.mark.parametrize(
      "shape, repeats, axis",
      [
          ((2,), 2, 0),
          ((2, 3, 4, 5), 3, 0),
          ((2, 3, 4, 5), 4, 3),
          ((2,), 2, None),
          ((2, 3, 4, 5), 3, None),
          ((), 1, None),
          ((), 10, None),
      ],
  )
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_repeat(shape, repeats, axis, is_varnode):
      """``F.repeat`` matches ``np.repeat``, including 0-d inputs and ``axis=None``."""
      network = Network() if is_varnode else None

      def repeat_func(inp):
          return F.repeat(inp=inp, repeats=repeats, axis=axis)

      # A 0-d shape uses a fixed scalar array; other shapes use random float32 data.
      if shape == ():
          data = np.array(1.23)
      else:
          data = np.random.randn(*shape).astype("float32")

      opr_test(
          [{"input": data}],
          repeat_func,
          ref_fn=lambda inp: np.repeat(inp, repeats, axis),
          network=network,
      )
  @pytest.mark.parametrize(
      "shape, reps",
      [
          ((2,), (2,)),
          ((2, 3, 4, 5), (1, 1, 1, 1)),
          ((2, 3, 4, 5), (1, 2, 3, 4)),
          # FIXME: tile does not support ndim 7
          # ((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
      ],
  )
  # NOTE(review): only the varnode path is parametrized; the eager path is not
  # exercised here — confirm whether that is intentional.
  @pytest.mark.parametrize("is_varnode", [True])
  def test_tile(shape, reps, is_varnode):
      """``F.tile`` matches ``np.tile`` for the listed shapes and repetitions."""
      network = Network() if is_varnode else None

      def tile_func(inp):
          return F.tile(inp=inp, reps=reps)

      data = np.random.randn(*shape).astype("float32")
      opr_test(
          [{"input": data}],
          tile_func,
          ref_fn=lambda inp: np.tile(inp, reps),
          network=network,
      )
  @pytest.mark.parametrize(
      "shape, shifts, axis",
      [
          ((2, 3), 0, None),
          ((2, 3), 1, 0),
          ((2, 3), 100, 0),
          ((2, 3), -100, 0),
          ((2, 3, 4, 5), (-1, 1), (0, 1)),
          ((2, 3, 4, 5), (-2, 1, 2), (1, 2, 3)),
      ],
  )
  @pytest.mark.parametrize("is_varnode", [True, False])
  def test_roll(shape, shifts, axis, is_varnode):
      """``F.roll`` matches ``np.roll``.

      Covers shifts larger than the axis length and multi-axis rolls.
      """
      network = Network() if is_varnode else None
      data = np.random.randn(*shape).astype("float32")

      def rolled(inp):
          return F.roll(inp, shifts, axis)

      opr_test(
          [{"input": data}],
          rolled,
          ref_fn=lambda inp: np.roll(inp, shifts, axis),
          network=network,
      )
  @pytest.mark.parametrize(
      "shape, shifts, axis", [((10, 0), 5, 1), ((10, 0), -10, 1)],
  )
  @pytest.mark.parametrize("is_symbolic", [None, True, False])
  def test_roll_empty_tensor(shape, shifts, axis, is_symbolic):
      """``F.roll`` on a zero-sized tensor matches ``np.roll``, eagerly and traced."""
      inp = tensor(np.random.randn(*shape).astype("float32"))

      def func(inp):
          return F.roll(inp, shifts, axis)

      if is_symbolic is not None:
          func = trace(symbolic=is_symbolic)(func)

      out_ref = np.roll(inp.numpy(), shifts, axis)
      for _ in range(3):
          # Call the (possibly traced) wrapper. Calling F.roll directly here
          # would leave the traced variants untested.
          out = func(inp)
          np.testing.assert_equal(out.numpy(), out_ref)
          if is_symbolic is None:
              break