You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_tensor.py 31 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059
# -*- coding: utf-8 -*-
import os
import platform

import numpy as np
import pytest

import megengine.functional as F
from megengine import Tensor
from megengine.core._trace_option import use_symbolic_shape
from megengine.core.tensor import megbrain_graph as G
from megengine.core.tensor.utils import astensor1d
from megengine.jit import trace
from megengine.utils.network import Network, set_symbolic_shape
from megengine.utils.network_node import VarNode

from utils import get_var_value, make_tensor, opr_test
  15. def test_eye():
  16. dtypes = [np.float32, np.bool]
  17. cases = [{"input": [10, 20]}, {"input": [30]}]
  18. for dtype in dtypes:
  19. for case in cases:
  20. np.testing.assert_allclose(
  21. F.eye(case["input"], dtype=dtype).numpy(),
  22. np.eye(*case["input"]).astype(dtype),
  23. )
  24. np.testing.assert_allclose(
  25. F.eye(*case["input"], dtype=dtype).numpy(),
  26. np.eye(*case["input"]).astype(dtype),
  27. )
  28. np.testing.assert_allclose(
  29. F.eye(Tensor(case["input"]), dtype=dtype).numpy(),
  30. np.eye(*case["input"]).astype(dtype),
  31. )
  32. @pytest.mark.parametrize("is_varnode", [False, True])
  33. def test_diag(is_varnode):
  34. if is_varnode:
  35. network = Network()
  36. else:
  37. network = None
  38. shapes = [(10, 10), (6, 9), (8, 7), (8,)]
  39. cases = []
  40. for shp in shapes:
  41. cases.append({"input": [np.random.random(shp).astype("float32")]})
  42. for axis in range(-2, 3):
  43. def run(data):
  44. return F.diag(data, k=axis)
  45. opr_test(cases, run, ref_fn=lambda x: np.diag(x, axis), network=network)
  46. def test_full():
  47. shape = (2, 3)
  48. values = [True, 4, 5.0]
  49. for value in values:
  50. np.testing.assert_allclose(F.full(shape, value).numpy(), np.full(shape, value))
  51. assert F.full(shape, value).dtype == Tensor(value).dtype
  52. @pytest.mark.parametrize("is_varnode", [True, False])
  53. def test_cumsum(is_varnode):
  54. if is_varnode:
  55. network = Network()
  56. else:
  57. network = None
  58. x = Tensor([[1, 2, 3], [4, 5, 6]], np.int32)
  59. y = F.cumsum(x, -1)
  60. np.testing.assert_equal(
  61. y.numpy(), np.array([[1, 3, 6], [4, 9, 15]]).astype(np.int32)
  62. )
  63. @pytest.mark.parametrize("is_varnode", [True, False])
  64. def test_concat(is_varnode):
  65. if is_varnode:
  66. network = Network()
  67. else:
  68. network = None
  69. def get_data_shape(length: int):
  70. return (length, 2, 3)
  71. data1 = np.random.random(get_data_shape(5)).astype("float32")
  72. data2 = np.random.random(get_data_shape(6)).astype("float32")
  73. data3 = np.random.random(get_data_shape(7)).astype("float32")
  74. def run(data1, data2):
  75. return F.concat([data1, data2])
  76. cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]
  77. opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network)
  78. x1 = Tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3)))
  79. x2 = Tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3)))
  80. y = F.concat([x1, x2], axis=-1)
  81. np.testing.assert_equal(
  82. y.numpy(),
  83. np.array([[0, 1, 2, 6, 7, 8], [3, 4, 5, 9, 10, 11]]).astype(np.float32),
  84. )
  85. @pytest.mark.parametrize("is_varnode", [True, False])
  86. def test_condtake(is_varnode):
  87. if is_varnode:
  88. network = Network()
  89. else:
  90. network = None
  91. x = np.array([[1, 2, 3], [4, 5, 6]]).astype("float32")
  92. y = np.array([[True, False, True], [False, True, True]])
  93. xx = make_tensor(x, network)
  94. yy = make_tensor(y, network)
  95. val, idx = F.cond_take(yy, xx)
  96. if is_varnode:
  97. np.testing.assert_equal(get_var_value(val), x[y])
  98. np.testing.assert_equal(get_var_value(idx), np.where(y.reshape(-1))[0])
  99. else:
  100. np.testing.assert_equal(val.numpy(), x[y])
  101. np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])
  102. @pytest.mark.parametrize("is_varnode", [True, False])
  103. def test_concat_device(is_varnode):
  104. if is_varnode:
  105. network = Network()
  106. else:
  107. network = None
  108. data1 = make_tensor(np.random.random((3, 2, 2)).astype("float32"), network, "cpu0")
  109. data2 = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu1")
  110. out = F.concat([data1, data2], device="cpu0")
  111. assert str(out.device).split(":")[0] == "cpu0"
  112. @pytest.mark.parametrize("is_varnode", [True, False])
  113. def test_stack(is_varnode):
  114. if is_varnode:
  115. network = Network()
  116. else:
  117. network = None
  118. data1 = np.random.random((3, 2, 2)).astype("float32")
  119. data2 = np.random.random((3, 2, 2)).astype("float32")
  120. data3 = np.random.random((3, 2, 2)).astype("float32")
  121. cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]
  122. for ai in range(3):
  123. def run(data1, data2):
  124. return F.stack([data1, data2], axis=ai)
  125. opr_test(
  126. cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai), network=network
  127. )
  128. x1 = Tensor(np.arange(0, 3, dtype=np.float32).reshape((3)))
  129. x2 = Tensor(np.arange(6, 9, dtype=np.float32).reshape((3)))
  130. y = F.stack([x1, x2], axis=-1)
  131. np.testing.assert_equal(
  132. y.numpy(), np.array([[0, 6], [1, 7], [2, 8]]).astype(np.float32)
  133. )
  134. x1 = Tensor(np.arange(0, 3, dtype=np.float32).reshape((3)))
  135. x2 = Tensor(np.arange(6, 9, dtype=np.float32).reshape((3)))
  136. y = F.stack([x1, x2], axis=-1)
  137. np.testing.assert_equal(
  138. y.numpy(), np.array([[0, 6], [1, 7], [2, 8]]).astype(np.float32)
  139. )
  140. @pytest.mark.parametrize("is_varnode", [True, False])
  141. def test_split_basic(is_varnode):
  142. if is_varnode:
  143. network = Network()
  144. saved_symbolic_shape = set_symbolic_shape(False)
  145. else:
  146. network = None
  147. data = np.random.random((2, 3, 4, 5)).astype(np.float32)
  148. inp = make_tensor(data, network)
  149. mge_out0 = F.split(inp, 2, axis=3)
  150. mge_out1 = F.split(inp, [3], axis=3)
  151. np_out = np.split(data, [3, 5], axis=3)
  152. assert len(mge_out0) == 2
  153. assert len(mge_out1) == 2
  154. np.testing.assert_equal(mge_out0[0].numpy(), np_out[0])
  155. np.testing.assert_equal(mge_out1[0].numpy(), np_out[0])
  156. np.testing.assert_equal(mge_out0[1].numpy(), np_out[1])
  157. np.testing.assert_equal(mge_out1[1].numpy(), np_out[1])
  158. try:
  159. F.split(inp, 4)
  160. assert False
  161. except ValueError as e:
  162. pass
  163. try:
  164. F.split(inp, [3, 2, 5], axis=3)
  165. assert False
  166. except ValueError as e:
  167. assert str(e) == "Invalid nsplits_or_secions: [3, 2, 5]"
  168. if is_varnode:
  169. set_symbolic_shape(saved_symbolic_shape)
  170. @pytest.mark.parametrize("symbolic", [None, False, True])
  171. def test_split(symbolic):
  172. x = Tensor(np.random.random((10, 20)), dtype=np.float32)
  173. y = F.split(x, 3, axis=-1)
  174. z = F.split(x, [6, 17], axis=-1)
  175. assert str([i.numpy().shape for i in y]) == "[(10, 7), (10, 7), (10, 6)]"
  176. assert str([i.numpy().shape for i in z]) == "[(10, 6), (10, 11), (10, 3)]"
  177. inp1 = np.random.random((3, 4, 5, 6)).astype(np.float32)
  178. inp2 = np.random.random((0, 4, 5, 6)).astype(np.float32)
  179. def ref(inp, nsplits_or_sections, axis):
  180. return np.split(inp, nsplits_or_sections, axis)
  181. def func(inp, nsplits_or_sections, axis):
  182. return F.split(inp, nsplits_or_sections, axis)
  183. cases = [
  184. (inp1, 2, 3),
  185. (inp1, [3], 3),
  186. (inp1, [3, 3, 5], 3),
  187. (inp2, 2, 3),
  188. (inp2, [3], 3),
  189. (inp2, [3, 3, 5], 3),
  190. ]
  191. for case in cases:
  192. if symbolic is None:
  193. fn = func
  194. else:
  195. fn = trace(symbolic=symbolic)(func)
  196. for i in range(3 if symbolic is not None else 1):
  197. ref_out = ref(*case)
  198. out = fn(Tensor(case[0]), case[1], case[2])
  199. assert len(ref_out) == len(out)
  200. for idx in range(len(ref_out)):
  201. np.testing.assert_equal(ref_out[idx], out[idx].numpy())
  202. def test_gather():
  203. x = Tensor([[1, 2], [3, 4], [5, 6],])
  204. index = Tensor([[0, 1], [1, 0], [1, 1]])
  205. y = F.gather(x, 1, index)
  206. np.testing.assert_equal(
  207. y.numpy(), np.array([[1, 2], [4, 3], [6, 6]]).astype(np.int32)
  208. )
  209. def test_scatter():
  210. x = Tensor(np.zeros(shape=(3, 5), dtype=np.float32))
  211. source = Tensor(
  212. [
  213. [0.9935, 0.9465, 0.2256, 0.8926, 0.4396],
  214. [0.7723, 0.0718, 0.5939, 0.357, 0.4576],
  215. ]
  216. )
  217. index = Tensor([[0, 2, 0, 2, 1], [2, 0, 1, 1, 2]])
  218. y = F.scatter(x, -2, index, source)
  219. np.testing.assert_equal(
  220. y.numpy().round(decimals=4),
  221. np.array(
  222. [
  223. [0.9935, 0.0718, 0.2256, 0.0, 0.0],
  224. [0.0, 0.0, 0.5939, 0.357, 0.4396],
  225. [0.7723, 0.9465, 0.0, 0.8926, 0.4576],
  226. ]
  227. ).astype(np.float32),
  228. )
  229. @pytest.mark.parametrize("is_varnode", [True, False])
  230. def test_swapaxes(is_varnode):
  231. if is_varnode:
  232. network = Network()
  233. else:
  234. network = None
  235. x = Tensor(np.array([[1, 2, 3]], dtype=np.int32))
  236. y = F.swapaxes(x, 0, 1)
  237. np.testing.assert_equal(y.numpy(), np.array([[1], [2], [3]]).astype(np.int32))
  238. @pytest.mark.parametrize("is_varnode", [True, False])
  239. def test_reshape(is_varnode):
  240. if is_varnode:
  241. network = Network()
  242. else:
  243. network = None
  244. x = np.arange(6, dtype="float32")
  245. xx = make_tensor(x, network)
  246. y = x.reshape(1, 2, 3)
  247. for shape in [
  248. (1, 2, 3),
  249. (1, -1, 3),
  250. (1, make_tensor(-1, network), 3),
  251. np.array([1, -1, 3], dtype="int32"),
  252. make_tensor([1, -1, 3], network),
  253. ]:
  254. yy = F.reshape(xx, shape)
  255. np.testing.assert_equal(yy.numpy(), y)
  256. @pytest.mark.parametrize("is_varnode", [True, False])
  257. def test_broadcast_auto_infer(is_varnode):
  258. if is_varnode:
  259. network = Network()
  260. else:
  261. network = None
  262. x = np.random.random((1, 2, 3)).astype(np.float32)
  263. xx = make_tensor(x, network)
  264. for shape in [
  265. (1, 2, 3),
  266. (1, None, 3),
  267. ]:
  268. yy = F.broadcast_to(xx, shape)
  269. np.testing.assert_equal(yy.numpy(), x)
  270. with pytest.raises(ValueError):
  271. F.broadcast_to(xx, (1, -1, 3))
  272. with pytest.raises(ValueError):
  273. F.broadcast_to(xx, (None, 1, 2, 3))
  274. F.broadcast_to(xx, (1, None, 2, 3))
  275. t = make_tensor(2, network)
  276. F.broadcast_to(xx, (t, None, 2, 3))
  277. @pytest.mark.parametrize("is_trace", [True, False])
  278. def test_reshape_on_empty_tensor(is_trace):
  279. input1_shape = (100, 0, 1)
  280. output1_shape = (100, 0, 10)
  281. data1 = Tensor(np.random.random(input1_shape).astype(np.float32))
  282. input2_shape = (10, 0)
  283. output2_shape = (0,)
  284. data2 = Tensor(np.random.random(input2_shape).astype(np.float32))
  285. input3_shape = (10, 0, 10)
  286. output3_shape = (0, 1, 2, 3)
  287. data3 = Tensor(np.random.random(input3_shape).astype(np.float32))
  288. def comp(out, target_shp):
  289. assert out._tuple_shape == target_shp
  290. def func(x, shp):
  291. return F.reshape(x, shp)
  292. cases = [
  293. [data1, output1_shape],
  294. [data2, output2_shape],
  295. [data3, output3_shape],
  296. ]
  297. def test(func, inp, comp, target_shp):
  298. out = func(inp, target_shp)
  299. comp(out, target_shp)
  300. if is_trace:
  301. for symbolic in [False, True]:
  302. for inp, target_shp in cases:
  303. func_traced = trace(symbolic=symbolic)(func)
  304. test(func_traced, inp, comp, target_shp)
  305. test(func_traced, inp, comp, target_shp)
  306. test(func_traced, inp, comp, target_shp)
  307. else:
  308. for inp, target_shp in cases:
  309. test(func, inp, comp, target_shp)
  310. @pytest.mark.parametrize("is_varnode", [True, False])
  311. def test_reshape_shape_inference(is_varnode):
  312. if is_varnode:
  313. network = Network()
  314. saved_symbolic_shape = set_symbolic_shape(False)
  315. else:
  316. network = None
  317. x_shape_known = make_tensor([1, 2, 3, 4], network)
  318. x_shape_unknown = F.broadcast_to(
  319. make_tensor([1.0], network), shape=make_tensor([1, 1, 1, 1], network).sum()
  320. )
  321. tshp_unknown = astensor1d(
  322. (make_tensor([2], network), make_tensor([2], network)), x_shape_known
  323. )
  324. tshp_known = astensor1d((2, 2), x_shape_known)
  325. tshp_known_unspec = astensor1d((2, -1), x_shape_known)
  326. def check_shape(output, target):
  327. source = output.shape
  328. if isinstance(source, Tensor):
  329. source = source.numpy()
  330. np.testing.assert_equal(source, target.shape)
  331. def func(x, target_shape):
  332. return x.reshape(target_shape)
  333. cases = [
  334. {"input": [x_shape_known, tshp_unknown], "output": [np.zeros((2, 2)),]},
  335. {"input": [x_shape_unknown, tshp_unknown], "output": [np.zeros((2, 2)),]},
  336. {"input": [x_shape_known, tshp_known], "output": [np.zeros((2, 2)),]},
  337. {"input": [x_shape_known, tshp_known_unspec], "output": [np.zeros((2, 2)),]},
  338. {"input": [x_shape_unknown, tshp_known], "output": [np.zeros((2, 2)),]},
  339. {"input": [x_shape_unknown, tshp_known_unspec], "output": [np.zeros((2, 2)),]},
  340. ]
  341. opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network)
  342. if is_varnode:
  343. set_symbolic_shape(saved_symbolic_shape)
  344. @pytest.mark.parametrize("is_varnode", [True, False])
  345. def test_squeeze(is_varnode):
  346. if is_varnode:
  347. network = Network()
  348. saved_symbolic_shape = set_symbolic_shape(False)
  349. else:
  350. network = None
  351. x = Tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1))
  352. y = F.squeeze(x, -1)
  353. np.testing.assert_equal(y.numpy(), np.array([[[1, 2]]]).astype(np.int32))
  354. x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1)
  355. xx = make_tensor(x, network)
  356. for axis in [None, 3, -4, (3, -4)]:
  357. y = np.squeeze(x, axis)
  358. yy = F.squeeze(xx, axis)
  359. np.testing.assert_equal(y, yy.numpy())
  360. if is_varnode:
  361. set_symbolic_shape(saved_symbolic_shape)
  362. @pytest.mark.parametrize("is_varnode", [True, False])
  363. def test_expand_dims(is_varnode):
  364. if is_varnode:
  365. network = Network()
  366. else:
  367. network = None
  368. x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  369. y = F.expand_dims(x, -1)
  370. np.testing.assert_equal(
  371. y.numpy(), np.array([[[1], [2], [3]], [[4], [5], [6]]]).astype(np.int32)
  372. )
  373. x = np.arange(6, dtype="float32").reshape(2, 3)
  374. xx = make_tensor(x, network)
  375. for axis in [2, -3, (3, -4), (1, -4)]:
  376. y = np.expand_dims(x, axis)
  377. yy = F.expand_dims(xx, axis)
  378. np.testing.assert_equal(y, yy.numpy())
  379. def test_expand_dims_for_scalar():
  380. x = np.array(1, dtype="float32")
  381. xx = make_tensor(x, None)
  382. for axis in [0, -1, (0, 1), (-1, -2), (0, -1)]:
  383. y = np.expand_dims(x, axis)
  384. yy = F.expand_dims(xx, axis)
  385. np.testing.assert_equal(y, yy.numpy())
  386. for axis in [1, -2, (1, 2), (-2, -3)]:
  387. np.testing.assert_raises(np.AxisError, np.expand_dims, x, axis)
  388. np.testing.assert_raises(RuntimeError, F.expand_dims, xx, axis)
  389. @pytest.mark.parametrize("is_varnode", [True, False])
  390. def test_elemwise_dtype_promotion(is_varnode):
  391. if is_varnode:
  392. network = Network()
  393. else:
  394. network = None
  395. x = np.random.rand(2, 3).astype("float32")
  396. y = np.random.rand(1, 3).astype("float16")
  397. xx = make_tensor(x, network)
  398. yy = make_tensor(y, network)
  399. z = xx * yy
  400. np.testing.assert_equal(z.numpy(), x * y)
  401. z = xx + y
  402. np.testing.assert_equal(z.numpy(), x + y)
  403. z = x - yy
  404. np.testing.assert_equal(z.numpy(), x - y)
  405. @pytest.mark.parametrize("is_varnode", [True, False])
  406. def test_linspace(is_varnode):
  407. if is_varnode:
  408. network = Network()
  409. else:
  410. network = None
  411. cases = [
  412. {"input": [1, 9, 9]},
  413. {"input": [3, 10, 8]},
  414. ]
  415. opr_test(
  416. cases,
  417. F.linspace,
  418. ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
  419. network=network,
  420. )
  421. cases = [
  422. {"input": [9, 1, 9]},
  423. {"input": [10, 3, 8]},
  424. ]
  425. opr_test(
  426. cases,
  427. F.linspace,
  428. ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
  429. network=network,
  430. )
  431. cases = [
  432. {"input": [1, make_tensor(9, network), 9]},
  433. {"input": [make_tensor(1, network), 9, make_tensor(9, network)]},
  434. ]
  435. opr_test(
  436. cases,
  437. F.linspace,
  438. ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32),
  439. network=network,
  440. )
  441. @pytest.mark.parametrize("is_varnode", [True, False])
  442. def test_arange(is_varnode):
  443. if is_varnode:
  444. network = Network()
  445. else:
  446. network = None
  447. cases = [
  448. {"input": [1, 9, 1]},
  449. {"input": [2, 10, 2]},
  450. ]
  451. opr_test(
  452. cases,
  453. F.arange,
  454. ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
  455. network=network,
  456. )
  457. cases = [
  458. {"input": [9, 1, -1]},
  459. {"input": [10, 2, -2]},
  460. ]
  461. opr_test(
  462. cases,
  463. F.arange,
  464. ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
  465. network=network,
  466. )
  467. cases = [
  468. {"input": [9.3, 1.2, -0.5]},
  469. {"input": [10.3, 2.1, -1.7]},
  470. ]
  471. opr_test(
  472. cases,
  473. F.arange,
  474. ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
  475. network=network,
  476. )
  477. @pytest.mark.parametrize("is_varnode", [True, False])
  478. def test_round(is_varnode):
  479. if is_varnode:
  480. network = Network()
  481. else:
  482. network = None
  483. data1_shape = (15,)
  484. data2_shape = (25,)
  485. data1 = np.random.random(data1_shape).astype(np.float32)
  486. data2 = np.random.random(data2_shape).astype(np.float32)
  487. cases = [{"input": data1}, {"input": data2}]
  488. opr_test(cases, F.round, ref_fn=np.round, network=network)
  489. @pytest.mark.parametrize("is_varnode", [True, False])
  490. def test_flatten(is_varnode):
  491. if is_varnode:
  492. network = Network()
  493. else:
  494. network = None
  495. inp_shape = (2, 2, 3, 3)
  496. x = Tensor(np.arange(36, dtype=np.int32).reshape(inp_shape),)
  497. y = F.flatten(x, -2, -1)
  498. np.testing.assert_equal(
  499. y.numpy(),
  500. np.array(
  501. [
  502. [[0, 1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 15, 16, 17]],
  503. [
  504. [18, 19, 20, 21, 22, 23, 24, 25, 26],
  505. [27, 28, 29, 30, 31, 32, 33, 34, 35],
  506. ],
  507. ]
  508. ).astype(np.int32),
  509. )
  510. data0_shape = (2, 3, 4, 5)
  511. data1_shape = (4, 5, 6, 7)
  512. data0 = np.random.random(data0_shape).astype(np.float32)
  513. data1 = np.random.random(data1_shape).astype(np.float32)
  514. cases = [
  515. {"input": data0, "output": data0.flatten()},
  516. {"input": data1, "output": data1.flatten()},
  517. ]
  518. opr_test(cases, F.flatten, network=network)
  519. cases = [
  520. {"input": data0, "output": data0.reshape(2, -1)},
  521. {"input": data1, "output": data1.reshape(4, -1)},
  522. ]
  523. opr_test(cases, F.flatten, start_axis=1, network=network)
  524. cases = [
  525. {"input": data0, "output": data0.reshape(2, 3, -1)},
  526. {"input": data1, "output": data1.reshape(4, 5, -1)},
  527. ]
  528. opr_test(cases, F.flatten, start_axis=2, network=network)
  529. cases = [
  530. {"input": data0, "output": data0.reshape(2, -1, 5)},
  531. {"input": data1, "output": data1.reshape(4, -1, 7)},
  532. ]
  533. opr_test(
  534. cases, F.flatten, start_axis=1, end_axis=2, network=network,
  535. )
  536. @pytest.mark.parametrize("is_varnode", [True, False])
  537. def test_broadcast(is_varnode):
  538. if is_varnode:
  539. network = Network()
  540. else:
  541. network = None
  542. input1_shape = (20, 30)
  543. output1_shape = (30, 20, 30)
  544. data1 = np.random.random(input1_shape).astype(np.float32)
  545. input2_shape = (10, 1)
  546. output2_shape = (20, 10, 20)
  547. data2 = np.random.random(input2_shape).astype(np.float32)
  548. input3_shape = (10, 10)
  549. output3_shape = (10, 10)
  550. data3 = np.random.random(input3_shape).astype(np.float32)
  551. cases = [
  552. {
  553. "input": [data1, output1_shape],
  554. "output": np.broadcast_to(data1, output1_shape),
  555. },
  556. {
  557. "input": [data2, output2_shape],
  558. "output": np.broadcast_to(data2, output2_shape),
  559. },
  560. {
  561. "input": [data3, output3_shape],
  562. "output": np.broadcast_to(data3, output3_shape),
  563. },
  564. ]
  565. opr_test(cases, F.broadcast_to, network=network)
  566. x = F.ones((2, 1, 3))
  567. with pytest.raises(RuntimeError):
  568. F.broadcast_to(x, (2, 3, 4))
  569. with pytest.raises(RuntimeError):
  570. F.broadcast_to(x, (4, 1, 3))
  571. with pytest.raises(RuntimeError):
  572. F.broadcast_to(x, (1, 3))
  573. @pytest.mark.parametrize("is_trace", [True, False])
  574. def test_broadcast_on_empty_tensor(is_trace):
  575. input1_shape = (100, 0, 1)
  576. output1_shape = (100, 0, 10)
  577. data1 = Tensor(np.random.random(input1_shape).astype(np.float32))
  578. input2_shape = (10, 0)
  579. output2_shape = (10, 10, 0)
  580. data2 = Tensor(np.random.random(input2_shape).astype(np.float32))
  581. input3_shape = (0, 0, 1, 10)
  582. output3_shape = (10, 0, 0, 10, 10)
  583. data3 = Tensor(np.random.random(input3_shape).astype(np.float32))
  584. def comp(out, target_shp):
  585. assert out._tuple_shape == target_shp
  586. def func(x, shp):
  587. return F.broadcast_to(x, shp)
  588. cases = [
  589. [data1, output1_shape],
  590. [data2, output2_shape],
  591. [data3, output3_shape],
  592. ]
  593. def test(func, inp, comp, target_shp):
  594. out = func(inp, target_shp)
  595. comp(out, target_shp)
  596. if is_trace:
  597. for symbolic in [False, True]:
  598. for inp, target_shp in cases:
  599. func_traced = trace(symbolic=symbolic)(func)
  600. test(func_traced, inp, comp, target_shp)
  601. test(func_traced, inp, comp, target_shp)
  602. test(func_traced, inp, comp, target_shp)
  603. else:
  604. for inp, target_shp in cases:
  605. test(func, inp, comp, target_shp)
  606. @pytest.mark.parametrize(
  607. "input_shape, target_shapes",
  608. [
  609. ((3,), [(2, 1, 3), (1, 2, 3), (2, 2, 3)]),
  610. ((1, 3, 1), [(2, None, 3), (3, None, 3), (1, None, 1)]),
  611. ],
  612. )
  613. @pytest.mark.parametrize("is_symbolic", [True, False])
  614. def test_broadcast_on_trace(is_symbolic, input_shape, target_shapes):
  615. x = F.ones(input_shape)
  616. @trace(symbolic=is_symbolic)
  617. def broadcast(inp, shape):
  618. return F.broadcast_to(inp, shape)
  619. for target_shape in target_shapes:
  620. if None in target_shape:
  621. symbolic_target_shape = tuple(
  622. map(lambda x: None if x is None else Tensor(x), target_shape)
  623. )
  624. output = broadcast(x, symbolic_target_shape)
  625. for i in range(len(target_shape)):
  626. if target_shape[i] is not None:
  627. assert output._tuple_shape[i] == target_shape[i]
  628. else:
  629. assert (
  630. output._tuple_shape[i] == x._tuple_shape[i - len(target_shape)]
  631. )
  632. else:
  633. symbolic_target_shape = Tensor(target_shape)
  634. output = broadcast(x, symbolic_target_shape)
  635. assert output._tuple_shape == target_shape
  636. @pytest.mark.parametrize("is_varnode", [True, False])
  637. def test_utils_astensor1d(is_varnode):
  638. if is_varnode:
  639. network = Network()
  640. else:
  641. network = None
  642. reference = make_tensor(0, network)
  643. # literal
  644. x = [1, 2, 3]
  645. for dtype in [None, "float32"]:
  646. xx = astensor1d(x, reference, dtype=dtype)
  647. assert isinstance(xx, type(reference))
  648. np.testing.assert_equal(xx.numpy(), x)
  649. # numpy array
  650. x = np.asarray([1, 2, 3], dtype="int32")
  651. for dtype in [None, "float32"]:
  652. xx = astensor1d(x, reference, dtype=dtype)
  653. assert isinstance(xx, type(reference))
  654. np.testing.assert_equal(xx.numpy(), x.astype(dtype) if dtype else x)
  655. # tensor
  656. x = make_tensor([1, 2, 3], network)
  657. for dtype in [None, "float32"]:
  658. xx = astensor1d(x, reference, dtype=dtype)
  659. assert isinstance(xx, type(reference))
  660. np.testing.assert_equal(xx.numpy(), x.numpy())
  661. # mixed
  662. x = [1, make_tensor(2, network), 3]
  663. for dtype in [None, "float32"]:
  664. xx = astensor1d(x, reference, dtype=dtype)
  665. assert isinstance(xx, type(reference))
  666. np.testing.assert_equal(xx.numpy(), [1, 2, 3])
  667. # varnode
  668. if is_varnode:
  669. a = np.array([[1, 2, 3], [4, 5, 6]]).astype("float32")
  670. b = np.array([[True, False, True], [False, True, True]])
  671. aa = make_tensor(a, network)
  672. bb = make_tensor(b, network)
  673. x, y = F.cond_take(bb, aa)
  674. for dtype in [None, "float32"]:
  675. xx = astensor1d(x, reference, dtype=dtype)
  676. assert isinstance(xx, type(reference))
  677. np.testing.assert_equal(get_var_value(xx), get_var_value(x))
  678. def test_device():
  679. x = Tensor([1, 2, 3], dtype="float32")
  680. y1 = F.eye(x.shape, dtype="float32")
  681. y2 = F.eye(x.shape, dtype="float32", device=None)
  682. np.testing.assert_almost_equal(y1.numpy(), y2.numpy())
  683. y3 = F.eye(x.shape, dtype="float32", device="xpux")
  684. y4 = F.eye(x.shape, dtype="float32", device=x.device)
  685. np.testing.assert_almost_equal(y3.numpy(), y4.numpy())
  686. y5 = F.full((3, 2), 4, device=x.device)
  687. y6 = F.full((3, 2), 4, device="xpux")
  688. np.testing.assert_almost_equal(y5.numpy(), y6.numpy())
  689. @pytest.mark.parametrize("is_varnode", [True, False])
  690. def test_identity(is_varnode):
  691. if is_varnode:
  692. network = Network()
  693. else:
  694. network = None
  695. x = make_tensor(np.random.random((5, 10)).astype(np.float32), network)
  696. y = F.copy(x)
  697. np.testing.assert_equal(y.numpy(), x)
  698. def copy_test(dst, src, network):
  699. data = np.random.random((2, 3)).astype(np.float32)
  700. x = make_tensor(data, device=src, network=network)
  701. y = F.copy(x, dst)
  702. assert np.allclose(data, y.numpy())
  703. if network is None:
  704. z = x.to(dst)
  705. assert np.allclose(data, z.numpy())
  706. @pytest.mark.require_ngpu(1)
  707. @pytest.mark.parametrize("is_varnode", [True, False])
  708. def test_copy_h2d(is_varnode):
  709. if is_varnode:
  710. network = Network()
  711. else:
  712. network = None
  713. copy_test("cpu0", "gpu0", network=network)
  714. @pytest.mark.require_ngpu(1)
  715. @pytest.mark.parametrize("is_varnode", [True, False])
  716. def test_copy_d2h(is_varnode):
  717. if is_varnode:
  718. network = Network()
  719. else:
  720. network = None
  721. copy_test("gpu0", "cpu0", network=network)
  722. @pytest.mark.require_ngpu(2)
  723. @pytest.mark.parametrize("is_varnode", [True, False])
  724. def test_copy_d2d(is_varnode):
  725. if is_varnode:
  726. network = Network()
  727. else:
  728. network = None
  729. copy_test("gpu0", "gpu1", network=network)
  730. copy_test("gpu0:0", "gpu0:1", network=network)
  731. @pytest.mark.require_ngpu(2)
  732. @pytest.mark.parametrize(
  733. "shape, device_src, device_dst",
  734. [
  735. ((0,), "cpu0", "cpu0"),
  736. ((10, 0), "cpu0", "cpu1"),
  737. ((2, 0, 3), "cpu0", "gpu0"),
  738. ((1, 0, 1, 0), "gpu0", "cpu0"),
  739. ((2, 3, 4, 5, 0), "gpu0", "gpu1"),
  740. ],
  741. )
  742. @pytest.mark.parametrize("is_symbolic", [None, True, False])
  743. def test_copy_empty(shape, device_src, device_dst, is_symbolic):
  744. inp = Tensor(np.random.randn(*shape).astype("float32"), device=device_src)
  745. def func(inp):
  746. return F.copy(inp, device_dst)
  747. if is_symbolic is not None:
  748. func = trace(symbolic=is_symbolic)(func)
  749. for _ in range(3):
  750. out = func(inp)
  751. assert out.numpy().shape == shape
  752. assert out.device == device_dst
  753. if is_symbolic is None:
  754. break
  755. @pytest.mark.parametrize(
  756. "shape, repeats, axis",
  757. [
  758. ((2,), 2, 0),
  759. ((2, 3, 4, 5), 3, 0),
  760. ((2, 3, 4, 5), 4, 3),
  761. ((2,), 2, None),
  762. ((2, 3, 4, 5), 3, None),
  763. ((), 1, None),
  764. ((), 10, None),
  765. ],
  766. )
  767. @pytest.mark.parametrize("is_varnode", [True, False])
  768. def test_repeat(shape, repeats, axis, is_varnode):
  769. if is_varnode:
  770. network = Network()
  771. else:
  772. network = None
  773. def repeat_func(inp):
  774. return F.repeat(inp=inp, repeats=repeats, axis=axis)
  775. if shape != ():
  776. cases = [
  777. {"input": np.random.randn(*shape).astype("float32")},
  778. ]
  779. else:
  780. cases = [{"input": np.array(1.23)}]
  781. opr_test(
  782. cases,
  783. repeat_func,
  784. ref_fn=lambda inp: np.repeat(inp, repeats, axis),
  785. network=network,
  786. )
  787. @pytest.mark.parametrize(
  788. "shape, reps",
  789. [
  790. ((2,), (2,)),
  791. ((2, 3, 4, 5), (1, 1, 1, 1)),
  792. ((2, 3, 4, 5), (1, 2, 3, 4)),
  793. # FIXME: tile does not support ndim 7
  794. # ((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
  795. ],
  796. )
  797. @pytest.mark.parametrize("is_varnode", [True])
  798. def test_tile(shape, reps, is_varnode):
  799. if is_varnode:
  800. network = Network()
  801. else:
  802. network = None
  803. def tile_func(inp):
  804. return F.tile(inp=inp, reps=reps)
  805. cases = [{"input": np.random.randn(*shape).astype("float32")}]
  806. opr_test(cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps), network=network)
  807. @pytest.mark.parametrize(
  808. "shape, shifts, axis",
  809. [
  810. ((2, 3), 0, None),
  811. ((2, 3), 1, 0),
  812. ((2, 3), 100, 0),
  813. ((2, 3), -100, 0),
  814. ((2, 3, 4, 5), (-1, 1), (0, 1)),
  815. ((2, 3, 4, 5), (-2, 1, 2), (1, 2, 3)),
  816. ],
  817. )
  818. @pytest.mark.parametrize("is_varnode", [True, False])
  819. def test_roll(shape, shifts, axis, is_varnode):
  820. if is_varnode:
  821. network = Network()
  822. else:
  823. network = None
  824. x = Tensor([[1, 2], [3, 4], [5, 6]], np.int32)
  825. y = F.roll(x, 1, -1)
  826. np.testing.assert_equal(
  827. y.numpy(), np.array([[2, 1], [4, 3], [6, 5]]).astype(np.int32)
  828. )
  829. inp = np.random.randn(*shape).astype("float32")
  830. def func(inp):
  831. return F.roll(inp, shifts, axis)
  832. cases = [
  833. {"input": inp},
  834. ]
  835. opr_test(
  836. cases, func, ref_fn=lambda inp: np.roll(inp, shifts, axis), network=network
  837. )
  838. @pytest.mark.parametrize(
  839. "shape, shifts, axis", [((10, 0), 5, 1), ((10, 0), -10, 1),],
  840. )
  841. @pytest.mark.parametrize("is_symbolic", [None, True, False])
  842. def test_roll_empty_tensor(shape, shifts, axis, is_symbolic):
  843. inp = Tensor(np.random.randn(*shape).astype("float32"))
  844. def func(inp):
  845. return F.roll(inp, shifts, axis)
  846. if is_symbolic is not None:
  847. func = trace(symbolic=is_symbolic)(func)
  848. out_ref = np.roll(inp.numpy(), shifts, axis)
  849. for _ in range(3):
  850. out = F.roll(inp, shifts, axis)
  851. np.testing.assert_equal(out.numpy(), out_ref)
  852. if is_symbolic is None:
  853. break