You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_tensor.py 25 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import os
  10. import platform
  11. import numpy as np
  12. import pytest
  13. from utils import get_var_value, make_tensor, opr_test
  14. import megengine.functional as F
  15. from megengine import tensor
  16. from megengine.core._trace_option import use_symbolic_shape
  17. from megengine.core.tensor import megbrain_graph as G
  18. from megengine.core.tensor.utils import astensor1d
  19. from megengine.jit import trace
  20. from megengine.utils.network import Network, set_symbolic_shape
  21. from megengine.utils.network_node import VarNode
  22. def test_eye():
  23. dtypes = [np.float32, np.bool]
  24. cases = [{"input": [10, 20]}, {"input": [30]}]
  25. for dtype in dtypes:
  26. for case in cases:
  27. np.testing.assert_allclose(
  28. F.eye(case["input"], dtype=dtype).numpy(),
  29. np.eye(*case["input"]).astype(dtype),
  30. )
  31. np.testing.assert_allclose(
  32. F.eye(*case["input"], dtype=dtype).numpy(),
  33. np.eye(*case["input"]).astype(dtype),
  34. )
  35. np.testing.assert_allclose(
  36. F.eye(tensor(case["input"]), dtype=dtype).numpy(),
  37. np.eye(*case["input"]).astype(dtype),
  38. )
  39. def test_full():
  40. shape = (2, 3)
  41. values = [True, 4, 5.0]
  42. for value in values:
  43. np.testing.assert_allclose(F.full(shape, value).numpy(), np.full(shape, value))
  44. assert F.full(shape, value).dtype == tensor(value).dtype
  45. @pytest.mark.parametrize("is_varnode", [True, False])
  46. def test_concat(is_varnode):
  47. if is_varnode:
  48. network = Network()
  49. else:
  50. network = None
  51. def get_data_shape(length: int):
  52. return (length, 2, 3)
  53. data1 = np.random.random(get_data_shape(5)).astype("float32")
  54. data2 = np.random.random(get_data_shape(6)).astype("float32")
  55. data3 = np.random.random(get_data_shape(7)).astype("float32")
  56. def run(data1, data2):
  57. return F.concat([data1, data2])
  58. cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]
  59. opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network)
  60. @pytest.mark.parametrize("is_varnode", [True, False])
  61. def test_condtake(is_varnode):
  62. if is_varnode:
  63. network = Network()
  64. else:
  65. network = None
  66. x = np.array([[1, 2, 3], [4, 5, 6]]).astype("float32")
  67. y = np.array([[True, False, True], [False, True, True]])
  68. xx = make_tensor(x, network)
  69. yy = make_tensor(y, network)
  70. val, idx = F.cond_take(yy, xx)
  71. if is_varnode:
  72. np.testing.assert_equal(get_var_value(val), x[y])
  73. np.testing.assert_equal(get_var_value(idx), np.where(y.reshape(-1))[0])
  74. else:
  75. np.testing.assert_equal(val.numpy(), x[y])
  76. np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])
  77. @pytest.mark.parametrize("is_varnode", [True, False])
  78. def test_concat_device(is_varnode):
  79. if is_varnode:
  80. network = Network()
  81. else:
  82. network = None
  83. data1 = make_tensor(np.random.random((3, 2, 2)).astype("float32"), network, "cpu0")
  84. data2 = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu1")
  85. out = F.concat([data1, data2], device="cpu0")
  86. assert str(out.device).split(":")[0] == "cpu0"
  87. @pytest.mark.parametrize("is_varnode", [True, False])
  88. def test_stack(is_varnode):
  89. if is_varnode:
  90. network = Network()
  91. else:
  92. network = None
  93. data1 = np.random.random((3, 2, 2)).astype("float32")
  94. data2 = np.random.random((3, 2, 2)).astype("float32")
  95. data3 = np.random.random((3, 2, 2)).astype("float32")
  96. cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]
  97. for ai in range(3):
  98. def run(data1, data2):
  99. return F.stack([data1, data2], axis=ai)
  100. opr_test(
  101. cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai), network=network
  102. )
  103. @pytest.mark.parametrize("is_varnode", [True, False])
  104. def test_split_basic(is_varnode):
  105. if is_varnode:
  106. network = Network()
  107. saved_symbolic_shape = set_symbolic_shape(False)
  108. else:
  109. network = None
  110. data = np.random.random((2, 3, 4, 5)).astype(np.float32)
  111. inp = make_tensor(data, network)
  112. mge_out0 = F.split(inp, 2, axis=3)
  113. mge_out1 = F.split(inp, [3], axis=3)
  114. np_out = np.split(data, [3, 5], axis=3)
  115. assert len(mge_out0) == 2
  116. assert len(mge_out1) == 2
  117. np.testing.assert_equal(mge_out0[0].numpy(), np_out[0])
  118. np.testing.assert_equal(mge_out1[0].numpy(), np_out[0])
  119. np.testing.assert_equal(mge_out0[1].numpy(), np_out[1])
  120. np.testing.assert_equal(mge_out1[1].numpy(), np_out[1])
  121. try:
  122. F.split(inp, 4)
  123. assert False
  124. except ValueError as e:
  125. pass
  126. try:
  127. F.split(inp, [3, 2, 5], axis=3)
  128. assert False
  129. except ValueError as e:
  130. assert str(e) == "Invalid nsplits_or_secions: [3, 2, 5]"
  131. if is_varnode:
  132. set_symbolic_shape(saved_symbolic_shape)
  133. @pytest.mark.parametrize("symbolic", [None, False, True])
  134. def test_split(symbolic):
  135. inp1 = np.random.random((3, 4, 5, 6)).astype(np.float32)
  136. inp2 = np.random.random((0, 4, 5, 6)).astype(np.float32)
  137. def ref(inp, nsplits_or_sections, axis):
  138. return np.split(inp, nsplits_or_sections, axis)
  139. def func(inp, nsplits_or_sections, axis):
  140. return F.split(inp, nsplits_or_sections, axis)
  141. cases = [
  142. (inp1, 2, 3),
  143. (inp1, [3], 3),
  144. (inp1, [3, 3, 5], 3),
  145. (inp2, 2, 3),
  146. (inp2, [3], 3),
  147. (inp2, [3, 3, 5], 3),
  148. ]
  149. for case in cases:
  150. if symbolic is None:
  151. fn = func
  152. else:
  153. fn = trace(symbolic=symbolic)(func)
  154. for i in range(3 if symbolic is not None else 1):
  155. ref_out = ref(*case)
  156. out = fn(tensor(case[0]), case[1], case[2])
  157. assert len(ref_out) == len(out)
  158. for idx in range(len(ref_out)):
  159. np.testing.assert_equal(ref_out[idx], out[idx].numpy())
  160. @pytest.mark.parametrize("is_varnode", [True, False])
  161. def test_reshape(is_varnode):
  162. if is_varnode:
  163. network = Network()
  164. else:
  165. network = None
  166. x = np.arange(6, dtype="float32")
  167. xx = make_tensor(x, network)
  168. y = x.reshape(1, 2, 3)
  169. for shape in [
  170. (1, 2, 3),
  171. (1, -1, 3),
  172. (1, make_tensor(-1, network), 3),
  173. np.array([1, -1, 3], dtype="int32"),
  174. make_tensor([1, -1, 3], network),
  175. ]:
  176. yy = F.reshape(xx, shape)
  177. np.testing.assert_equal(yy.numpy(), y)
  178. @pytest.mark.parametrize("is_trace", [True, False])
  179. def test_reshape_on_empty_tensor(is_trace):
  180. input1_shape = (100, 0, 1)
  181. output1_shape = (100, 0, 10)
  182. data1 = tensor(np.random.random(input1_shape).astype(np.float32))
  183. input2_shape = (10, 0)
  184. output2_shape = (0,)
  185. data2 = tensor(np.random.random(input2_shape).astype(np.float32))
  186. input3_shape = (10, 0, 10)
  187. output3_shape = (0, 1, 2, 3)
  188. data3 = tensor(np.random.random(input3_shape).astype(np.float32))
  189. def comp(out, target_shp):
  190. assert out._tuple_shape == target_shp
  191. def func(x, shp):
  192. return F.reshape(x, shp)
  193. cases = [
  194. [data1, output1_shape],
  195. [data2, output2_shape],
  196. [data3, output3_shape],
  197. ]
  198. def test(func, inp, comp, target_shp):
  199. out = func(inp, target_shp)
  200. comp(out, target_shp)
  201. if is_trace:
  202. for symbolic in [False, True]:
  203. for inp, target_shp in cases:
  204. func_traced = trace(symbolic=symbolic)(func)
  205. test(func_traced, inp, comp, target_shp)
  206. test(func_traced, inp, comp, target_shp)
  207. test(func_traced, inp, comp, target_shp)
  208. else:
  209. for inp, target_shp in cases:
  210. test(func, inp, comp, target_shp)
  211. @pytest.mark.parametrize("is_varnode", [True, False])
  212. def test_reshape_shape_inference(is_varnode):
  213. if is_varnode:
  214. network = Network()
  215. saved_symbolic_shape = set_symbolic_shape(False)
  216. else:
  217. network = None
  218. x_shape_known = make_tensor([1, 2, 3, 4], network)
  219. x_shape_unknown = F.broadcast_to(
  220. make_tensor([1.0], network), shape=make_tensor([1, 1, 1, 1], network).sum()
  221. )
  222. tshp_unknown = astensor1d(
  223. (make_tensor([2], network), make_tensor([2], network)), x_shape_known
  224. )
  225. tshp_known = astensor1d((2, 2), x_shape_known)
  226. tshp_known_unspec = astensor1d((2, -1), x_shape_known)
  227. def check_shape(output, target):
  228. source = output.shape
  229. if isinstance(source, tensor):
  230. source = source.numpy()
  231. np.testing.assert_equal(source, target)
  232. def func(x, target_shape):
  233. return x.reshape(target_shape)
  234. cases = [
  235. {"input": [x_shape_known, tshp_unknown], "output": [(2, 2),]},
  236. {"input": [x_shape_unknown, tshp_unknown], "output": [(2, 2),]},
  237. {"input": [x_shape_known, tshp_known], "output": [(2, 2),]},
  238. {"input": [x_shape_known, tshp_known_unspec], "output": [(2, 2),]},
  239. {"input": [x_shape_unknown, tshp_known], "output": [(2, 2),]},
  240. {"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]},
  241. ]
  242. opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network)
  243. if is_varnode:
  244. set_symbolic_shape(saved_symbolic_shape)
  245. @pytest.mark.parametrize("is_varnode", [True, False])
  246. def test_squeeze(is_varnode):
  247. if is_varnode:
  248. network = Network()
  249. saved_symbolic_shape = set_symbolic_shape(False)
  250. else:
  251. network = None
  252. x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1)
  253. xx = make_tensor(x, network)
  254. for axis in [None, 3, -4, (3, -4)]:
  255. y = np.squeeze(x, axis)
  256. yy = F.squeeze(xx, axis)
  257. np.testing.assert_equal(y, yy.numpy())
  258. if is_varnode:
  259. set_symbolic_shape(saved_symbolic_shape)
  260. @pytest.mark.parametrize("is_varnode", [True, False])
  261. def test_expand_dims(is_varnode):
  262. if is_varnode:
  263. network = Network()
  264. else:
  265. network = None
  266. x = np.arange(6, dtype="float32").reshape(2, 3)
  267. xx = make_tensor(x, network)
  268. for axis in [2, -3, (3, -4), (1, -4)]:
  269. y = np.expand_dims(x, axis)
  270. yy = F.expand_dims(xx, axis)
  271. np.testing.assert_equal(y, yy.numpy())
  272. def test_expand_dims_for_scalar():
  273. x = np.array(1, dtype="float32")
  274. xx = make_tensor(x, None)
  275. for axis in [0, -1, (0, 1), (-1, -2), (0, -1)]:
  276. y = np.expand_dims(x, axis)
  277. yy = F.expand_dims(xx, axis)
  278. np.testing.assert_equal(y, yy.numpy())
  279. for axis in [1, -2, (1, 2), (-2, -3)]:
  280. np.testing.assert_raises(np.AxisError, np.expand_dims, x, axis)
  281. np.testing.assert_raises(RuntimeError, F.expand_dims, xx, axis)
  282. @pytest.mark.parametrize("is_varnode", [True, False])
  283. def test_elemwise_dtype_promotion(is_varnode):
  284. if is_varnode:
  285. network = Network()
  286. else:
  287. network = None
  288. x = np.random.rand(2, 3).astype("float32")
  289. y = np.random.rand(1, 3).astype("float16")
  290. xx = make_tensor(x, network)
  291. yy = make_tensor(y, network)
  292. z = xx * yy
  293. np.testing.assert_equal(z.numpy(), x * y)
  294. z = xx + y
  295. np.testing.assert_equal(z.numpy(), x + y)
  296. z = x - yy
  297. np.testing.assert_equal(z.numpy(), x - y)
  298. @pytest.mark.parametrize("is_varnode", [True, False])
  299. def test_linspace(is_varnode):
  300. if is_varnode:
  301. network = Network()
  302. else:
  303. network = None
  304. cases = [
  305. {"input": [1, 9, 9]},
  306. {"input": [3, 10, 8]},
  307. ]
  308. opr_test(
  309. cases,
  310. F.linspace,
  311. ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
  312. network=network,
  313. )
  314. cases = [
  315. {"input": [9, 1, 9]},
  316. {"input": [10, 3, 8]},
  317. ]
  318. opr_test(
  319. cases,
  320. F.linspace,
  321. ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
  322. network=network,
  323. )
  324. cases = [
  325. {"input": [1, make_tensor(9, network), 9]},
  326. {"input": [make_tensor(1, network), 9, make_tensor(9, network)]},
  327. ]
  328. opr_test(
  329. cases,
  330. F.linspace,
  331. ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32),
  332. network=network,
  333. )
  334. @pytest.mark.parametrize("is_varnode", [True, False])
  335. def test_arange(is_varnode):
  336. if is_varnode:
  337. network = Network()
  338. else:
  339. network = None
  340. cases = [
  341. {"input": [1, 9, 1]},
  342. {"input": [2, 10, 2]},
  343. ]
  344. opr_test(
  345. cases,
  346. F.arange,
  347. ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
  348. network=network,
  349. )
  350. cases = [
  351. {"input": [9, 1, -1]},
  352. {"input": [10, 2, -2]},
  353. ]
  354. opr_test(
  355. cases,
  356. F.arange,
  357. ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
  358. network=network,
  359. )
  360. cases = [
  361. {"input": [9.3, 1.2, -0.5]},
  362. {"input": [10.3, 2.1, -1.7]},
  363. ]
  364. opr_test(
  365. cases,
  366. F.arange,
  367. ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
  368. network=network,
  369. )
  370. @pytest.mark.parametrize("is_varnode", [True, False])
  371. def test_round(is_varnode):
  372. if is_varnode:
  373. network = Network()
  374. else:
  375. network = None
  376. data1_shape = (15,)
  377. data2_shape = (25,)
  378. data1 = np.random.random(data1_shape).astype(np.float32)
  379. data2 = np.random.random(data2_shape).astype(np.float32)
  380. cases = [{"input": data1}, {"input": data2}]
  381. opr_test(cases, F.round, ref_fn=np.round, network=network)
  382. @pytest.mark.parametrize("is_varnode", [True, False])
  383. def test_flatten(is_varnode):
  384. if is_varnode:
  385. network = Network()
  386. else:
  387. network = None
  388. data0_shape = (2, 3, 4, 5)
  389. data1_shape = (4, 5, 6, 7)
  390. data0 = np.random.random(data0_shape).astype(np.float32)
  391. data1 = np.random.random(data1_shape).astype(np.float32)
  392. def compare_fn(x, y):
  393. assert x._tuple_shape[0] == y
  394. output0 = (2 * 3 * 4 * 5,)
  395. output1 = (4 * 5 * 6 * 7,)
  396. cases = [
  397. {"input": data0, "output": output0},
  398. {"input": data1, "output": output1},
  399. ]
  400. opr_test(cases, F.flatten, compare_fn=compare_fn, network=network)
  401. output0 = (2, 3 * 4 * 5)
  402. output1 = (4, 5 * 6 * 7)
  403. cases = [
  404. {"input": data0, "output": output0},
  405. {"input": data1, "output": output1},
  406. ]
  407. opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, network=network)
  408. output0 = (2, 3, 4 * 5)
  409. output1 = (4, 5, 6 * 7)
  410. cases = [
  411. {"input": data0, "output": output0},
  412. {"input": data1, "output": output1},
  413. ]
  414. opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2, network=network)
  415. output0 = (2, 3 * 4, 5)
  416. output1 = (4, 5 * 6, 7)
  417. cases = [
  418. {"input": data0, "output": output0},
  419. {"input": data1, "output": output1},
  420. ]
  421. opr_test(
  422. cases,
  423. F.flatten,
  424. compare_fn=compare_fn,
  425. start_axis=1,
  426. end_axis=2,
  427. network=network,
  428. )
  429. @pytest.mark.parametrize("is_varnode", [True, False])
  430. def test_broadcast(is_varnode):
  431. if is_varnode:
  432. network = Network()
  433. else:
  434. network = None
  435. input1_shape = (20, 30)
  436. output1_shape = (30, 20, 30)
  437. data1 = np.random.random(input1_shape).astype(np.float32)
  438. input2_shape = (10, 1)
  439. output2_shape = (20, 10, 20)
  440. data2 = np.random.random(input2_shape).astype(np.float32)
  441. input3_shape = (10, 10)
  442. output3_shape = (10, 10)
  443. data3 = np.random.random(input3_shape).astype(np.float32)
  444. def compare_fn(x, y):
  445. assert x._tuple_shape[0] == y
  446. cases = [
  447. {"input": [data1, output1_shape], "output": output1_shape},
  448. {"input": [data2, output2_shape], "output": output2_shape},
  449. {"input": [data3, output3_shape], "output": output3_shape},
  450. ]
  451. opr_test(cases, F.broadcast_to, compare_fn=compare_fn, network=network)
  452. x = F.ones((2, 1, 3))
  453. with pytest.raises(RuntimeError):
  454. F.broadcast_to(x, (2, 3, 4))
  455. with pytest.raises(RuntimeError):
  456. F.broadcast_to(x, (4, 1, 3))
  457. with pytest.raises(RuntimeError):
  458. F.broadcast_to(x, (1, 3))
  459. @pytest.mark.parametrize("is_trace", [True, False])
  460. def test_broadcast_on_empty_tensor(is_trace):
  461. input1_shape = (100, 0, 1)
  462. output1_shape = (100, 0, 10)
  463. data1 = tensor(np.random.random(input1_shape).astype(np.float32))
  464. input2_shape = (10, 0)
  465. output2_shape = (10, 10, 0)
  466. data2 = tensor(np.random.random(input2_shape).astype(np.float32))
  467. input3_shape = (0, 0, 1, 10)
  468. output3_shape = (10, 0, 0, 10, 10)
  469. data3 = tensor(np.random.random(input3_shape).astype(np.float32))
  470. def comp(out, target_shp):
  471. assert out._tuple_shape == target_shp
  472. def func(x, shp):
  473. return F.broadcast_to(x, shp)
  474. cases = [
  475. [data1, output1_shape],
  476. [data2, output2_shape],
  477. [data3, output3_shape],
  478. ]
  479. def test(func, inp, comp, target_shp):
  480. out = func(inp, target_shp)
  481. comp(out, target_shp)
  482. if is_trace:
  483. for symbolic in [False, True]:
  484. for inp, target_shp in cases:
  485. func_traced = trace(symbolic=symbolic)(func)
  486. test(func_traced, inp, comp, target_shp)
  487. test(func_traced, inp, comp, target_shp)
  488. test(func_traced, inp, comp, target_shp)
  489. else:
  490. for inp, target_shp in cases:
  491. test(func, inp, comp, target_shp)
  492. @pytest.mark.parametrize("is_varnode", [True, False])
  493. def test_utils_astensor1d(is_varnode):
  494. if is_varnode:
  495. network = Network()
  496. else:
  497. network = None
  498. reference = make_tensor(0, network)
  499. # literal
  500. x = [1, 2, 3]
  501. for dtype in [None, "float32"]:
  502. xx = astensor1d(x, reference, dtype=dtype)
  503. assert isinstance(xx, type(reference))
  504. np.testing.assert_equal(xx.numpy(), x)
  505. # numpy array
  506. x = np.asarray([1, 2, 3], dtype="int32")
  507. for dtype in [None, "float32"]:
  508. xx = astensor1d(x, reference, dtype=dtype)
  509. assert isinstance(xx, type(reference))
  510. np.testing.assert_equal(xx.numpy(), x.astype(dtype) if dtype else x)
  511. # tensor
  512. x = make_tensor([1, 2, 3], network)
  513. for dtype in [None, "float32"]:
  514. xx = astensor1d(x, reference, dtype=dtype)
  515. assert isinstance(xx, type(reference))
  516. np.testing.assert_equal(xx.numpy(), x.numpy())
  517. # mixed
  518. x = [1, make_tensor(2, network), 3]
  519. for dtype in [None, "float32"]:
  520. xx = astensor1d(x, reference, dtype=dtype)
  521. assert isinstance(xx, type(reference))
  522. np.testing.assert_equal(xx.numpy(), [1, 2, 3])
  523. def test_device():
  524. x = tensor([1, 2, 3], dtype="float32")
  525. y1 = F.eye(x.shape, dtype="float32")
  526. y2 = F.eye(x.shape, dtype="float32", device=None)
  527. np.testing.assert_almost_equal(y1.numpy(), y2.numpy())
  528. y3 = F.eye(x.shape, dtype="float32", device="xpux")
  529. y4 = F.eye(x.shape, dtype="float32", device=x.device)
  530. np.testing.assert_almost_equal(y3.numpy(), y4.numpy())
  531. y5 = F.full((3, 2), 4, device=x.device)
  532. y6 = F.full((3, 2), 4, device="xpux")
  533. np.testing.assert_almost_equal(y5.numpy(), y6.numpy())
  534. @pytest.mark.parametrize("is_varnode", [True, False])
  535. def test_identity(is_varnode):
  536. if is_varnode:
  537. network = Network()
  538. else:
  539. network = None
  540. x = make_tensor(np.random.random((5, 10)).astype(np.float32), network)
  541. y = F.copy(x)
  542. np.testing.assert_equal(y.numpy(), x)
  543. def copy_test(dst, src, network):
  544. data = np.random.random((2, 3)).astype(np.float32)
  545. x = make_tensor(data, device=src, network=network)
  546. y = F.copy(x, dst)
  547. assert np.allclose(data, y.numpy())
  548. if network is None:
  549. z = x.to(dst)
  550. assert np.allclose(data, z.numpy())
  551. @pytest.mark.require_ngpu(1)
  552. @pytest.mark.parametrize("is_varnode", [True, False])
  553. def test_copy_h2d(is_varnode):
  554. if is_varnode:
  555. network = Network()
  556. else:
  557. network = None
  558. copy_test("cpu0", "gpu0", network=network)
  559. @pytest.mark.require_ngpu(1)
  560. @pytest.mark.parametrize("is_varnode", [True, False])
  561. def test_copy_d2h(is_varnode):
  562. if is_varnode:
  563. network = Network()
  564. else:
  565. network = None
  566. copy_test("gpu0", "cpu0", network=network)
  567. @pytest.mark.require_ngpu(2)
  568. @pytest.mark.parametrize("is_varnode", [True, False])
  569. def test_copy_d2d(is_varnode):
  570. if is_varnode:
  571. network = Network()
  572. else:
  573. network = None
  574. copy_test("gpu0", "gpu1", network=network)
  575. copy_test("gpu0:0", "gpu0:1", network=network)
  576. @pytest.mark.require_ngpu(2)
  577. @pytest.mark.parametrize(
  578. "shape, device_src, device_dst",
  579. [
  580. ((0,), "cpu0", "cpu0"),
  581. ((10, 0), "cpu0", "cpu1"),
  582. ((2, 0, 3), "cpu0", "gpu0"),
  583. ((1, 0, 1, 0), "gpu0", "cpu0"),
  584. ((2, 3, 4, 5, 0), "gpu0", "gpu1"),
  585. ],
  586. )
  587. @pytest.mark.parametrize("is_symbolic", [None, True, False])
  588. def test_copy_empty(shape, device_src, device_dst, is_symbolic):
  589. inp = tensor(np.random.randn(*shape).astype("float32"), device=device_src)
  590. def func(inp):
  591. return F.copy(inp, device_dst)
  592. if is_symbolic is not None:
  593. func = trace(symbolic=is_symbolic)(func)
  594. for _ in range(3):
  595. out = func(inp)
  596. assert out.numpy().shape == shape
  597. assert out.device == device_dst
  598. if is_symbolic is None:
  599. break
  600. @pytest.mark.parametrize(
  601. "shape, repeats, axis",
  602. [
  603. ((2,), 2, 0),
  604. ((2, 3, 4, 5), 3, 0),
  605. ((2, 3, 4, 5), 4, 3),
  606. ((2,), 2, None),
  607. ((2, 3, 4, 5), 3, None),
  608. ((), 1, None),
  609. ((), 10, None),
  610. ],
  611. )
  612. @pytest.mark.parametrize("is_varnode", [True, False])
  613. def test_repeat(shape, repeats, axis, is_varnode):
  614. if is_varnode:
  615. network = Network()
  616. else:
  617. network = None
  618. def repeat_func(inp):
  619. return F.repeat(inp=inp, repeats=repeats, axis=axis)
  620. if shape != ():
  621. cases = [
  622. {"input": np.random.randn(*shape).astype("float32")},
  623. ]
  624. else:
  625. cases = [{"input": np.array(1.23)}]
  626. opr_test(
  627. cases,
  628. repeat_func,
  629. ref_fn=lambda inp: np.repeat(inp, repeats, axis),
  630. network=network,
  631. )
  632. @pytest.mark.parametrize(
  633. "shape, reps",
  634. [
  635. ((2,), (2,)),
  636. ((2, 3, 4, 5), (1, 1, 1, 1)),
  637. ((2, 3, 4, 5), (1, 2, 3, 4)),
  638. ((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
  639. ],
  640. )
  641. @pytest.mark.parametrize("is_varnode", [True])
  642. def test_tile(shape, reps, is_varnode):
  643. if is_varnode:
  644. network = Network()
  645. else:
  646. network = None
  647. def tile_func(inp):
  648. return F.tile(inp=inp, reps=reps)
  649. cases = [{"input": np.random.randn(*shape).astype("float32")}]
  650. opr_test(cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps), network=network)
  651. @pytest.mark.parametrize(
  652. "shape, shifts, axis",
  653. [
  654. ((2, 3), 0, None),
  655. ((2, 3), 1, 0),
  656. ((2, 3), 100, 0),
  657. ((2, 3), -100, 0),
  658. ((2, 3, 4, 5), (-1, 1), (0, 1)),
  659. ((2, 3, 4, 5), (-2, 1, 2), (1, 2, 3)),
  660. ],
  661. )
  662. @pytest.mark.parametrize("is_varnode", [True, False])
  663. def test_roll(shape, shifts, axis, is_varnode):
  664. if is_varnode:
  665. network = Network()
  666. else:
  667. network = None
  668. inp = np.random.randn(*shape).astype("float32")
  669. def func(inp):
  670. return F.roll(inp, shifts, axis)
  671. cases = [
  672. {"input": inp},
  673. ]
  674. opr_test(
  675. cases, func, ref_fn=lambda inp: np.roll(inp, shifts, axis), network=network
  676. )
  677. @pytest.mark.parametrize(
  678. "shape, shifts, axis", [((10, 0), 5, 1), ((10, 0), -10, 1),],
  679. )
  680. @pytest.mark.parametrize("is_symbolic", [None, True, False])
  681. def test_roll_empty_tensor(shape, shifts, axis, is_symbolic):
  682. inp = tensor(np.random.randn(*shape).astype("float32"))
  683. def func(inp):
  684. return F.roll(inp, shifts, axis)
  685. if is_symbolic is not None:
  686. func = trace(symbolic=is_symbolic)(func)
  687. out_ref = np.roll(inp.numpy(), shifts, axis)
  688. for _ in range(3):
  689. out = F.roll(inp, shifts, axis)
  690. np.testing.assert_equal(out.numpy(), out_ref)
  691. if is_symbolic is None:
  692. break