You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_tensor.py 30 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025
  1. # -*- coding: utf-8 -*-
  2. import os
  3. import platform
  4. import numpy as np
  5. import pytest
  6. from utils import get_var_value, make_tensor, opr_test
  7. import megengine.functional as F
  8. from megengine import Tensor
  9. from megengine.core._trace_option import use_symbolic_shape
  10. from megengine.core.tensor import megbrain_graph as G
  11. from megengine.core.tensor.utils import astensor1d
  12. from megengine.jit import trace
  13. from megengine.utils.network import Network, set_symbolic_shape
  14. from megengine.utils.network_node import VarNode
  15. def test_eye():
  16. dtypes = [np.float32, np.bool]
  17. cases = [{"input": [10, 20]}, {"input": [30]}]
  18. for dtype in dtypes:
  19. for case in cases:
  20. np.testing.assert_allclose(
  21. F.eye(case["input"], dtype=dtype).numpy(),
  22. np.eye(*case["input"]).astype(dtype),
  23. )
  24. np.testing.assert_allclose(
  25. F.eye(*case["input"], dtype=dtype).numpy(),
  26. np.eye(*case["input"]).astype(dtype),
  27. )
  28. np.testing.assert_allclose(
  29. F.eye(Tensor(case["input"]), dtype=dtype).numpy(),
  30. np.eye(*case["input"]).astype(dtype),
  31. )
  32. @pytest.mark.parametrize("is_varnode", [False, True])
  33. def test_diag(is_varnode):
  34. if is_varnode:
  35. network = Network()
  36. else:
  37. network = None
  38. shapes = [(10, 10), (6, 9), (8, 7), (8,)]
  39. cases = []
  40. for shp in shapes:
  41. cases.append({"input": [np.random.random(shp).astype("float32")]})
  42. for axis in range(-2, 3):
  43. def run(data):
  44. return F.diag(data, k=axis)
  45. opr_test(cases, run, ref_fn=lambda x: np.diag(x, axis), network=network)
  46. def test_full():
  47. shape = (2, 3)
  48. values = [True, 4, 5.0]
  49. for value in values:
  50. np.testing.assert_allclose(F.full(shape, value).numpy(), np.full(shape, value))
  51. assert F.full(shape, value).dtype == Tensor(value).dtype
  52. @pytest.mark.parametrize("is_varnode", [True, False])
  53. def test_cumsum(is_varnode):
  54. if is_varnode:
  55. network = Network()
  56. else:
  57. network = None
  58. x = Tensor([[1, 2, 3], [4, 5, 6]], np.int32)
  59. y = F.cumsum(x, -1)
  60. np.testing.assert_equal(
  61. y.numpy(), np.array([[1, 3, 6], [4, 9, 15]]).astype(np.int32)
  62. )
  63. @pytest.mark.parametrize("is_varnode", [True, False])
  64. def test_concat(is_varnode):
  65. if is_varnode:
  66. network = Network()
  67. else:
  68. network = None
  69. def get_data_shape(length: int):
  70. return (length, 2, 3)
  71. data1 = np.random.random(get_data_shape(5)).astype("float32")
  72. data2 = np.random.random(get_data_shape(6)).astype("float32")
  73. data3 = np.random.random(get_data_shape(7)).astype("float32")
  74. def run(data1, data2):
  75. return F.concat([data1, data2])
  76. cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]
  77. opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network)
  78. x1 = Tensor(np.arange(0, 6, dtype=np.float32).reshape((2, 3)))
  79. x2 = Tensor(np.arange(6, 12, dtype=np.float32).reshape((2, 3)))
  80. y = F.concat([x1, x2], axis=-1)
  81. np.testing.assert_equal(
  82. y.numpy(),
  83. np.array([[0, 1, 2, 6, 7, 8], [3, 4, 5, 9, 10, 11]]).astype(np.float32),
  84. )
  85. @pytest.mark.parametrize("is_varnode", [True, False])
  86. def test_condtake(is_varnode):
  87. if is_varnode:
  88. network = Network()
  89. else:
  90. network = None
  91. x = np.array([[1, 2, 3], [4, 5, 6]]).astype("float32")
  92. y = np.array([[True, False, True], [False, True, True]])
  93. xx = make_tensor(x, network)
  94. yy = make_tensor(y, network)
  95. val, idx = F.cond_take(yy, xx)
  96. if is_varnode:
  97. np.testing.assert_equal(get_var_value(val), x[y])
  98. np.testing.assert_equal(get_var_value(idx), np.where(y.reshape(-1))[0])
  99. else:
  100. np.testing.assert_equal(val.numpy(), x[y])
  101. np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])
  102. @pytest.mark.parametrize("is_varnode", [True, False])
  103. def test_concat_device(is_varnode):
  104. if is_varnode:
  105. network = Network()
  106. else:
  107. network = None
  108. data1 = make_tensor(np.random.random((3, 2, 2)).astype("float32"), network, "cpu0")
  109. data2 = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu1")
  110. out = F.concat([data1, data2], device="cpu0")
  111. assert str(out.device).split(":")[0] == "cpu0"
  112. @pytest.mark.parametrize("is_varnode", [True, False])
  113. def test_stack(is_varnode):
  114. if is_varnode:
  115. network = Network()
  116. else:
  117. network = None
  118. data1 = np.random.random((3, 2, 2)).astype("float32")
  119. data2 = np.random.random((3, 2, 2)).astype("float32")
  120. data3 = np.random.random((3, 2, 2)).astype("float32")
  121. cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]
  122. for ai in range(3):
  123. def run(data1, data2):
  124. return F.stack([data1, data2], axis=ai)
  125. opr_test(
  126. cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai), network=network
  127. )
  128. x1 = Tensor(np.arange(0, 3, dtype=np.float32).reshape((3)))
  129. x2 = Tensor(np.arange(6, 9, dtype=np.float32).reshape((3)))
  130. y = F.stack([x1, x2], axis=-1)
  131. np.testing.assert_equal(
  132. y.numpy(), np.array([[0, 6], [1, 7], [2, 8]]).astype(np.float32)
  133. )
  134. x1 = Tensor(np.arange(0, 3, dtype=np.float32).reshape((3)))
  135. x2 = Tensor(np.arange(6, 9, dtype=np.float32).reshape((3)))
  136. y = F.stack([x1, x2], axis=-1)
  137. np.testing.assert_equal(
  138. y.numpy(), np.array([[0, 6], [1, 7], [2, 8]]).astype(np.float32)
  139. )
  140. @pytest.mark.parametrize("is_varnode", [True, False])
  141. def test_split_basic(is_varnode):
  142. if is_varnode:
  143. network = Network()
  144. saved_symbolic_shape = set_symbolic_shape(False)
  145. else:
  146. network = None
  147. data = np.random.random((2, 3, 4, 5)).astype(np.float32)
  148. inp = make_tensor(data, network)
  149. mge_out0 = F.split(inp, 2, axis=3)
  150. mge_out1 = F.split(inp, [3], axis=3)
  151. np_out = np.split(data, [3, 5], axis=3)
  152. assert len(mge_out0) == 2
  153. assert len(mge_out1) == 2
  154. np.testing.assert_equal(mge_out0[0].numpy(), np_out[0])
  155. np.testing.assert_equal(mge_out1[0].numpy(), np_out[0])
  156. np.testing.assert_equal(mge_out0[1].numpy(), np_out[1])
  157. np.testing.assert_equal(mge_out1[1].numpy(), np_out[1])
  158. try:
  159. F.split(inp, 4)
  160. assert False
  161. except ValueError as e:
  162. pass
  163. try:
  164. F.split(inp, [3, 2, 5], axis=3)
  165. assert False
  166. except ValueError as e:
  167. assert str(e) == "Invalid nsplits_or_secions: [3, 2, 5]"
  168. if is_varnode:
  169. set_symbolic_shape(saved_symbolic_shape)
  170. @pytest.mark.parametrize("symbolic", [None, False, True])
  171. def test_split(symbolic):
  172. x = Tensor(np.random.random((10, 20)), dtype=np.float32)
  173. y = F.split(x, 3, axis=-1)
  174. z = F.split(x, [6, 17], axis=-1)
  175. assert str([i.numpy().shape for i in y]) == "[(10, 7), (10, 7), (10, 6)]"
  176. assert str([i.numpy().shape for i in z]) == "[(10, 6), (10, 11), (10, 3)]"
  177. inp1 = np.random.random((3, 4, 5, 6)).astype(np.float32)
  178. inp2 = np.random.random((0, 4, 5, 6)).astype(np.float32)
  179. def ref(inp, nsplits_or_sections, axis):
  180. return np.split(inp, nsplits_or_sections, axis)
  181. def func(inp, nsplits_or_sections, axis):
  182. return F.split(inp, nsplits_or_sections, axis)
  183. cases = [
  184. (inp1, 2, 3),
  185. (inp1, [3], 3),
  186. (inp1, [3, 3, 5], 3),
  187. (inp2, 2, 3),
  188. (inp2, [3], 3),
  189. (inp2, [3, 3, 5], 3),
  190. ]
  191. for case in cases:
  192. if symbolic is None:
  193. fn = func
  194. else:
  195. fn = trace(symbolic=symbolic)(func)
  196. for i in range(3 if symbolic is not None else 1):
  197. ref_out = ref(*case)
  198. out = fn(Tensor(case[0]), case[1], case[2])
  199. assert len(ref_out) == len(out)
  200. for idx in range(len(ref_out)):
  201. np.testing.assert_equal(ref_out[idx], out[idx].numpy())
  202. def test_gather():
  203. x = Tensor([[1, 2], [3, 4], [5, 6],])
  204. index = Tensor([[0, 1], [1, 0], [1, 1]])
  205. y = F.gather(x, 1, index)
  206. np.testing.assert_equal(
  207. y.numpy(), np.array([[1, 2], [4, 3], [6, 6]]).astype(np.int32)
  208. )
  209. def test_scatter():
  210. x = Tensor(np.zeros(shape=(3, 5), dtype=np.float32))
  211. source = Tensor(
  212. [
  213. [0.9935, 0.9465, 0.2256, 0.8926, 0.4396],
  214. [0.7723, 0.0718, 0.5939, 0.357, 0.4576],
  215. ]
  216. )
  217. index = Tensor([[0, 2, 0, 2, 1], [2, 0, 1, 1, 2]])
  218. y = F.scatter(x, -2, index, source)
  219. np.testing.assert_equal(
  220. y.numpy().round(decimals=4),
  221. np.array(
  222. [
  223. [0.9935, 0.0718, 0.2256, 0.0, 0.0],
  224. [0.0, 0.0, 0.5939, 0.357, 0.4396],
  225. [0.7723, 0.9465, 0.0, 0.8926, 0.4576],
  226. ]
  227. ).astype(np.float32),
  228. )
  229. @pytest.mark.parametrize("is_varnode", [True, False])
  230. def test_swapaxes(is_varnode):
  231. if is_varnode:
  232. network = Network()
  233. else:
  234. network = None
  235. x = Tensor(np.array([[1, 2, 3]], dtype=np.int32))
  236. y = F.swapaxes(x, 0, 1)
  237. np.testing.assert_equal(y.numpy(), np.array([[1], [2], [3]]).astype(np.int32))
  238. @pytest.mark.parametrize("is_varnode", [True, False])
  239. def test_reshape(is_varnode):
  240. if is_varnode:
  241. network = Network()
  242. else:
  243. network = None
  244. x = np.arange(6, dtype="float32")
  245. xx = make_tensor(x, network)
  246. y = x.reshape(1, 2, 3)
  247. for shape in [
  248. (1, 2, 3),
  249. (1, -1, 3),
  250. (1, make_tensor(-1, network), 3),
  251. np.array([1, -1, 3], dtype="int32"),
  252. make_tensor([1, -1, 3], network),
  253. ]:
  254. yy = F.reshape(xx, shape)
  255. np.testing.assert_equal(yy.numpy(), y)
  256. @pytest.mark.parametrize("is_varnode", [True, False])
  257. def test_broadcast_auto_infer(is_varnode):
  258. if is_varnode:
  259. network = Network()
  260. else:
  261. network = None
  262. x = np.random.random((1, 2, 3)).astype(np.float32)
  263. xx = make_tensor(x, network)
  264. for shape in [
  265. (1, 2, 3),
  266. (1, None, 3),
  267. ]:
  268. yy = F.broadcast_to(xx, shape)
  269. np.testing.assert_equal(yy.numpy(), x)
  270. with pytest.raises(ValueError):
  271. F.broadcast_to(xx, (1, -1, 3))
  272. with pytest.raises(ValueError):
  273. F.broadcast_to(xx, (None, 1, 2, 3))
  274. F.broadcast_to(xx, (1, None, 2, 3))
  275. t = make_tensor(2, network)
  276. F.broadcast_to(xx, (t, None, 2, 3))
  277. @pytest.mark.parametrize("is_trace", [True, False])
  278. def test_reshape_on_empty_tensor(is_trace):
  279. input1_shape = (100, 0, 1)
  280. output1_shape = (100, 0, 10)
  281. data1 = Tensor(np.random.random(input1_shape).astype(np.float32))
  282. input2_shape = (10, 0)
  283. output2_shape = (0,)
  284. data2 = Tensor(np.random.random(input2_shape).astype(np.float32))
  285. input3_shape = (10, 0, 10)
  286. output3_shape = (0, 1, 2, 3)
  287. data3 = Tensor(np.random.random(input3_shape).astype(np.float32))
  288. def comp(out, target_shp):
  289. assert out._tuple_shape == target_shp
  290. def func(x, shp):
  291. return F.reshape(x, shp)
  292. cases = [
  293. [data1, output1_shape],
  294. [data2, output2_shape],
  295. [data3, output3_shape],
  296. ]
  297. def test(func, inp, comp, target_shp):
  298. out = func(inp, target_shp)
  299. comp(out, target_shp)
  300. if is_trace:
  301. for symbolic in [False, True]:
  302. for inp, target_shp in cases:
  303. func_traced = trace(symbolic=symbolic)(func)
  304. test(func_traced, inp, comp, target_shp)
  305. test(func_traced, inp, comp, target_shp)
  306. test(func_traced, inp, comp, target_shp)
  307. else:
  308. for inp, target_shp in cases:
  309. test(func, inp, comp, target_shp)
  310. @pytest.mark.parametrize("is_varnode", [True, False])
  311. def test_reshape_shape_inference(is_varnode):
  312. if is_varnode:
  313. network = Network()
  314. saved_symbolic_shape = set_symbolic_shape(False)
  315. else:
  316. network = None
  317. x_shape_known = make_tensor([1, 2, 3, 4], network)
  318. x_shape_unknown = F.broadcast_to(
  319. make_tensor([1.0], network), shape=make_tensor([1, 1, 1, 1], network).sum()
  320. )
  321. tshp_unknown = astensor1d(
  322. (make_tensor([2], network), make_tensor([2], network)), x_shape_known
  323. )
  324. tshp_known = astensor1d((2, 2), x_shape_known)
  325. tshp_known_unspec = astensor1d((2, -1), x_shape_known)
  326. def check_shape(output, target):
  327. source = output.shape
  328. if isinstance(source, Tensor):
  329. source = source.numpy()
  330. np.testing.assert_equal(source, target.shape)
  331. def func(x, target_shape):
  332. return x.reshape(target_shape)
  333. cases = [
  334. {"input": [x_shape_known, tshp_unknown], "output": [np.zeros((2, 2)),]},
  335. {"input": [x_shape_unknown, tshp_unknown], "output": [np.zeros((2, 2)),]},
  336. {"input": [x_shape_known, tshp_known], "output": [np.zeros((2, 2)),]},
  337. {"input": [x_shape_known, tshp_known_unspec], "output": [np.zeros((2, 2)),]},
  338. {"input": [x_shape_unknown, tshp_known], "output": [np.zeros((2, 2)),]},
  339. {"input": [x_shape_unknown, tshp_known_unspec], "output": [np.zeros((2, 2)),]},
  340. ]
  341. opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network)
  342. if is_varnode:
  343. set_symbolic_shape(saved_symbolic_shape)
  344. @pytest.mark.parametrize("is_varnode", [True, False])
  345. def test_squeeze(is_varnode):
  346. if is_varnode:
  347. network = Network()
  348. saved_symbolic_shape = set_symbolic_shape(False)
  349. else:
  350. network = None
  351. x = Tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1))
  352. y = F.squeeze(x, -1)
  353. np.testing.assert_equal(y.numpy(), np.array([[[1, 2]]]).astype(np.int32))
  354. x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1)
  355. xx = make_tensor(x, network)
  356. for axis in [None, 3, -4, (3, -4)]:
  357. y = np.squeeze(x, axis)
  358. yy = F.squeeze(xx, axis)
  359. np.testing.assert_equal(y, yy.numpy())
  360. if is_varnode:
  361. set_symbolic_shape(saved_symbolic_shape)
  362. @pytest.mark.parametrize("is_varnode", [True, False])
  363. def test_expand_dims(is_varnode):
  364. if is_varnode:
  365. network = Network()
  366. else:
  367. network = None
  368. x = Tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
  369. y = F.expand_dims(x, -1)
  370. np.testing.assert_equal(
  371. y.numpy(), np.array([[[1], [2], [3]], [[4], [5], [6]]]).astype(np.int32)
  372. )
  373. x = np.arange(6, dtype="float32").reshape(2, 3)
  374. xx = make_tensor(x, network)
  375. for axis in [2, -3, (3, -4), (1, -4)]:
  376. y = np.expand_dims(x, axis)
  377. yy = F.expand_dims(xx, axis)
  378. np.testing.assert_equal(y, yy.numpy())
  379. def test_expand_dims_for_scalar():
  380. x = np.array(1, dtype="float32")
  381. xx = make_tensor(x, None)
  382. for axis in [0, -1, (0, 1), (-1, -2), (0, -1)]:
  383. y = np.expand_dims(x, axis)
  384. yy = F.expand_dims(xx, axis)
  385. np.testing.assert_equal(y, yy.numpy())
  386. for axis in [1, -2, (1, 2), (-2, -3)]:
  387. np.testing.assert_raises(np.AxisError, np.expand_dims, x, axis)
  388. np.testing.assert_raises(RuntimeError, F.expand_dims, xx, axis)
  389. @pytest.mark.parametrize("is_varnode", [True, False])
  390. def test_elemwise_dtype_promotion(is_varnode):
  391. if is_varnode:
  392. network = Network()
  393. else:
  394. network = None
  395. x = np.random.rand(2, 3).astype("float32")
  396. y = np.random.rand(1, 3).astype("float16")
  397. xx = make_tensor(x, network)
  398. yy = make_tensor(y, network)
  399. z = xx * yy
  400. np.testing.assert_equal(z.numpy(), x * y)
  401. z = xx + y
  402. np.testing.assert_equal(z.numpy(), x + y)
  403. z = x - yy
  404. np.testing.assert_equal(z.numpy(), x - y)
  405. @pytest.mark.parametrize("is_varnode", [True, False])
  406. def test_linspace(is_varnode):
  407. if is_varnode:
  408. network = Network()
  409. else:
  410. network = None
  411. cases = [
  412. {"input": [1, 9, 9]},
  413. {"input": [3, 10, 8]},
  414. ]
  415. opr_test(
  416. cases,
  417. F.linspace,
  418. ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
  419. network=network,
  420. )
  421. cases = [
  422. {"input": [9, 1, 9]},
  423. {"input": [10, 3, 8]},
  424. ]
  425. opr_test(
  426. cases,
  427. F.linspace,
  428. ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
  429. network=network,
  430. )
  431. cases = [
  432. {"input": [1, make_tensor(9, network), 9]},
  433. {"input": [make_tensor(1, network), 9, make_tensor(9, network)]},
  434. ]
  435. opr_test(
  436. cases,
  437. F.linspace,
  438. ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32),
  439. network=network,
  440. )
  441. @pytest.mark.parametrize("is_varnode", [True, False])
  442. def test_arange(is_varnode):
  443. if is_varnode:
  444. network = Network()
  445. else:
  446. network = None
  447. cases = [
  448. {"input": [1, 9, 1]},
  449. {"input": [2, 10, 2]},
  450. ]
  451. opr_test(
  452. cases,
  453. F.arange,
  454. ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
  455. network=network,
  456. )
  457. cases = [
  458. {"input": [9, 1, -1]},
  459. {"input": [10, 2, -2]},
  460. ]
  461. opr_test(
  462. cases,
  463. F.arange,
  464. ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
  465. network=network,
  466. )
  467. cases = [
  468. {"input": [9.3, 1.2, -0.5]},
  469. {"input": [10.3, 2.1, -1.7]},
  470. ]
  471. opr_test(
  472. cases,
  473. F.arange,
  474. ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
  475. network=network,
  476. )
  477. @pytest.mark.parametrize("is_varnode", [True, False])
  478. def test_round(is_varnode):
  479. if is_varnode:
  480. network = Network()
  481. else:
  482. network = None
  483. data1_shape = (15,)
  484. data2_shape = (25,)
  485. data1 = np.random.random(data1_shape).astype(np.float32)
  486. data2 = np.random.random(data2_shape).astype(np.float32)
  487. cases = [{"input": data1}, {"input": data2}]
  488. opr_test(cases, F.round, ref_fn=np.round, network=network)
  489. @pytest.mark.parametrize("is_varnode", [True, False])
  490. def test_flatten(is_varnode):
  491. if is_varnode:
  492. network = Network()
  493. else:
  494. network = None
  495. inp_shape = (2, 2, 3, 3)
  496. x = Tensor(np.arange(36, dtype=np.int32).reshape(inp_shape),)
  497. y = F.flatten(x, -2, -1)
  498. np.testing.assert_equal(
  499. y.numpy(),
  500. np.array(
  501. [
  502. [[0, 1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 15, 16, 17]],
  503. [
  504. [18, 19, 20, 21, 22, 23, 24, 25, 26],
  505. [27, 28, 29, 30, 31, 32, 33, 34, 35],
  506. ],
  507. ]
  508. ).astype(np.int32),
  509. )
  510. data0_shape = (2, 3, 4, 5)
  511. data1_shape = (4, 5, 6, 7)
  512. data0 = np.random.random(data0_shape).astype(np.float32)
  513. data1 = np.random.random(data1_shape).astype(np.float32)
  514. cases = [
  515. {"input": data0, "output": data0.flatten()},
  516. {"input": data1, "output": data1.flatten()},
  517. ]
  518. opr_test(cases, F.flatten, network=network)
  519. cases = [
  520. {"input": data0, "output": data0.reshape(2, -1)},
  521. {"input": data1, "output": data1.reshape(4, -1)},
  522. ]
  523. opr_test(cases, F.flatten, start_axis=1, network=network)
  524. cases = [
  525. {"input": data0, "output": data0.reshape(2, 3, -1)},
  526. {"input": data1, "output": data1.reshape(4, 5, -1)},
  527. ]
  528. opr_test(cases, F.flatten, start_axis=2, network=network)
  529. cases = [
  530. {"input": data0, "output": data0.reshape(2, -1, 5)},
  531. {"input": data1, "output": data1.reshape(4, -1, 7)},
  532. ]
  533. opr_test(
  534. cases, F.flatten, start_axis=1, end_axis=2, network=network,
  535. )
  536. @pytest.mark.parametrize("is_varnode", [True, False])
  537. def test_broadcast(is_varnode):
  538. if is_varnode:
  539. network = Network()
  540. else:
  541. network = None
  542. input1_shape = (20, 30)
  543. output1_shape = (30, 20, 30)
  544. data1 = np.random.random(input1_shape).astype(np.float32)
  545. input2_shape = (10, 1)
  546. output2_shape = (20, 10, 20)
  547. data2 = np.random.random(input2_shape).astype(np.float32)
  548. input3_shape = (10, 10)
  549. output3_shape = (10, 10)
  550. data3 = np.random.random(input3_shape).astype(np.float32)
  551. cases = [
  552. {
  553. "input": [data1, output1_shape],
  554. "output": np.broadcast_to(data1, output1_shape),
  555. },
  556. {
  557. "input": [data2, output2_shape],
  558. "output": np.broadcast_to(data2, output2_shape),
  559. },
  560. {
  561. "input": [data3, output3_shape],
  562. "output": np.broadcast_to(data3, output3_shape),
  563. },
  564. ]
  565. opr_test(cases, F.broadcast_to, network=network)
  566. x = F.ones((2, 1, 3))
  567. with pytest.raises(RuntimeError):
  568. F.broadcast_to(x, (2, 3, 4))
  569. with pytest.raises(RuntimeError):
  570. F.broadcast_to(x, (4, 1, 3))
  571. with pytest.raises(RuntimeError):
  572. F.broadcast_to(x, (1, 3))
  573. @pytest.mark.parametrize("is_trace", [True, False])
  574. def test_broadcast_on_empty_tensor(is_trace):
  575. input1_shape = (100, 0, 1)
  576. output1_shape = (100, 0, 10)
  577. data1 = Tensor(np.random.random(input1_shape).astype(np.float32))
  578. input2_shape = (10, 0)
  579. output2_shape = (10, 10, 0)
  580. data2 = Tensor(np.random.random(input2_shape).astype(np.float32))
  581. input3_shape = (0, 0, 1, 10)
  582. output3_shape = (10, 0, 0, 10, 10)
  583. data3 = Tensor(np.random.random(input3_shape).astype(np.float32))
  584. def comp(out, target_shp):
  585. assert out._tuple_shape == target_shp
  586. def func(x, shp):
  587. return F.broadcast_to(x, shp)
  588. cases = [
  589. [data1, output1_shape],
  590. [data2, output2_shape],
  591. [data3, output3_shape],
  592. ]
  593. def test(func, inp, comp, target_shp):
  594. out = func(inp, target_shp)
  595. comp(out, target_shp)
  596. if is_trace:
  597. for symbolic in [False, True]:
  598. for inp, target_shp in cases:
  599. func_traced = trace(symbolic=symbolic)(func)
  600. test(func_traced, inp, comp, target_shp)
  601. test(func_traced, inp, comp, target_shp)
  602. test(func_traced, inp, comp, target_shp)
  603. else:
  604. for inp, target_shp in cases:
  605. test(func, inp, comp, target_shp)
  606. @pytest.mark.parametrize("is_varnode", [True, False])
  607. def test_utils_astensor1d(is_varnode):
  608. if is_varnode:
  609. network = Network()
  610. else:
  611. network = None
  612. reference = make_tensor(0, network)
  613. # literal
  614. x = [1, 2, 3]
  615. for dtype in [None, "float32"]:
  616. xx = astensor1d(x, reference, dtype=dtype)
  617. assert isinstance(xx, type(reference))
  618. np.testing.assert_equal(xx.numpy(), x)
  619. # numpy array
  620. x = np.asarray([1, 2, 3], dtype="int32")
  621. for dtype in [None, "float32"]:
  622. xx = astensor1d(x, reference, dtype=dtype)
  623. assert isinstance(xx, type(reference))
  624. np.testing.assert_equal(xx.numpy(), x.astype(dtype) if dtype else x)
  625. # tensor
  626. x = make_tensor([1, 2, 3], network)
  627. for dtype in [None, "float32"]:
  628. xx = astensor1d(x, reference, dtype=dtype)
  629. assert isinstance(xx, type(reference))
  630. np.testing.assert_equal(xx.numpy(), x.numpy())
  631. # mixed
  632. x = [1, make_tensor(2, network), 3]
  633. for dtype in [None, "float32"]:
  634. xx = astensor1d(x, reference, dtype=dtype)
  635. assert isinstance(xx, type(reference))
  636. np.testing.assert_equal(xx.numpy(), [1, 2, 3])
  637. # varnode
  638. if is_varnode:
  639. a = np.array([[1, 2, 3], [4, 5, 6]]).astype("float32")
  640. b = np.array([[True, False, True], [False, True, True]])
  641. aa = make_tensor(a, network)
  642. bb = make_tensor(b, network)
  643. x, y = F.cond_take(bb, aa)
  644. for dtype in [None, "float32"]:
  645. xx = astensor1d(x, reference, dtype=dtype)
  646. assert isinstance(xx, type(reference))
  647. np.testing.assert_equal(get_var_value(xx), get_var_value(x))
  648. def test_device():
  649. x = Tensor([1, 2, 3], dtype="float32")
  650. y1 = F.eye(x.shape, dtype="float32")
  651. y2 = F.eye(x.shape, dtype="float32", device=None)
  652. np.testing.assert_almost_equal(y1.numpy(), y2.numpy())
  653. y3 = F.eye(x.shape, dtype="float32", device="xpux")
  654. y4 = F.eye(x.shape, dtype="float32", device=x.device)
  655. np.testing.assert_almost_equal(y3.numpy(), y4.numpy())
  656. y5 = F.full((3, 2), 4, device=x.device)
  657. y6 = F.full((3, 2), 4, device="xpux")
  658. np.testing.assert_almost_equal(y5.numpy(), y6.numpy())
  659. @pytest.mark.parametrize("is_varnode", [True, False])
  660. def test_identity(is_varnode):
  661. if is_varnode:
  662. network = Network()
  663. else:
  664. network = None
  665. x = make_tensor(np.random.random((5, 10)).astype(np.float32), network)
  666. y = F.copy(x)
  667. np.testing.assert_equal(y.numpy(), x)
  668. def copy_test(dst, src, network):
  669. data = np.random.random((2, 3)).astype(np.float32)
  670. x = make_tensor(data, device=src, network=network)
  671. y = F.copy(x, dst)
  672. assert np.allclose(data, y.numpy())
  673. if network is None:
  674. z = x.to(dst)
  675. assert np.allclose(data, z.numpy())
  676. @pytest.mark.require_ngpu(1)
  677. @pytest.mark.parametrize("is_varnode", [True, False])
  678. def test_copy_h2d(is_varnode):
  679. if is_varnode:
  680. network = Network()
  681. else:
  682. network = None
  683. copy_test("cpu0", "gpu0", network=network)
  684. @pytest.mark.require_ngpu(1)
  685. @pytest.mark.parametrize("is_varnode", [True, False])
  686. def test_copy_d2h(is_varnode):
  687. if is_varnode:
  688. network = Network()
  689. else:
  690. network = None
  691. copy_test("gpu0", "cpu0", network=network)
  692. @pytest.mark.require_ngpu(2)
  693. @pytest.mark.parametrize("is_varnode", [True, False])
  694. def test_copy_d2d(is_varnode):
  695. if is_varnode:
  696. network = Network()
  697. else:
  698. network = None
  699. copy_test("gpu0", "gpu1", network=network)
  700. copy_test("gpu0:0", "gpu0:1", network=network)
  701. @pytest.mark.require_ngpu(2)
  702. @pytest.mark.parametrize(
  703. "shape, device_src, device_dst",
  704. [
  705. ((0,), "cpu0", "cpu0"),
  706. ((10, 0), "cpu0", "cpu1"),
  707. ((2, 0, 3), "cpu0", "gpu0"),
  708. ((1, 0, 1, 0), "gpu0", "cpu0"),
  709. ((2, 3, 4, 5, 0), "gpu0", "gpu1"),
  710. ],
  711. )
  712. @pytest.mark.parametrize("is_symbolic", [None, True, False])
  713. def test_copy_empty(shape, device_src, device_dst, is_symbolic):
  714. inp = Tensor(np.random.randn(*shape).astype("float32"), device=device_src)
  715. def func(inp):
  716. return F.copy(inp, device_dst)
  717. if is_symbolic is not None:
  718. func = trace(symbolic=is_symbolic)(func)
  719. for _ in range(3):
  720. out = func(inp)
  721. assert out.numpy().shape == shape
  722. assert out.device == device_dst
  723. if is_symbolic is None:
  724. break
  725. @pytest.mark.parametrize(
  726. "shape, repeats, axis",
  727. [
  728. ((2,), 2, 0),
  729. ((2, 3, 4, 5), 3, 0),
  730. ((2, 3, 4, 5), 4, 3),
  731. ((2,), 2, None),
  732. ((2, 3, 4, 5), 3, None),
  733. ((), 1, None),
  734. ((), 10, None),
  735. ],
  736. )
  737. @pytest.mark.parametrize("is_varnode", [True, False])
  738. def test_repeat(shape, repeats, axis, is_varnode):
  739. if is_varnode:
  740. network = Network()
  741. else:
  742. network = None
  743. def repeat_func(inp):
  744. return F.repeat(inp=inp, repeats=repeats, axis=axis)
  745. if shape != ():
  746. cases = [
  747. {"input": np.random.randn(*shape).astype("float32")},
  748. ]
  749. else:
  750. cases = [{"input": np.array(1.23)}]
  751. opr_test(
  752. cases,
  753. repeat_func,
  754. ref_fn=lambda inp: np.repeat(inp, repeats, axis),
  755. network=network,
  756. )
  757. @pytest.mark.parametrize(
  758. "shape, reps",
  759. [
  760. ((2,), (2,)),
  761. ((2, 3, 4, 5), (1, 1, 1, 1)),
  762. ((2, 3, 4, 5), (1, 2, 3, 4)),
  763. # FIXME: tile does not support ndim 7
  764. # ((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
  765. ],
  766. )
  767. @pytest.mark.parametrize("is_varnode", [True])
  768. def test_tile(shape, reps, is_varnode):
  769. if is_varnode:
  770. network = Network()
  771. else:
  772. network = None
  773. def tile_func(inp):
  774. return F.tile(inp=inp, reps=reps)
  775. cases = [{"input": np.random.randn(*shape).astype("float32")}]
  776. opr_test(cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps), network=network)
  777. @pytest.mark.parametrize(
  778. "shape, shifts, axis",
  779. [
  780. ((2, 3), 0, None),
  781. ((2, 3), 1, 0),
  782. ((2, 3), 100, 0),
  783. ((2, 3), -100, 0),
  784. ((2, 3, 4, 5), (-1, 1), (0, 1)),
  785. ((2, 3, 4, 5), (-2, 1, 2), (1, 2, 3)),
  786. ],
  787. )
  788. @pytest.mark.parametrize("is_varnode", [True, False])
  789. def test_roll(shape, shifts, axis, is_varnode):
  790. if is_varnode:
  791. network = Network()
  792. else:
  793. network = None
  794. x = Tensor([[1, 2], [3, 4], [5, 6]], np.int32)
  795. y = F.roll(x, 1, -1)
  796. np.testing.assert_equal(
  797. y.numpy(), np.array([[2, 1], [4, 3], [6, 5]]).astype(np.int32)
  798. )
  799. inp = np.random.randn(*shape).astype("float32")
  800. def func(inp):
  801. return F.roll(inp, shifts, axis)
  802. cases = [
  803. {"input": inp},
  804. ]
  805. opr_test(
  806. cases, func, ref_fn=lambda inp: np.roll(inp, shifts, axis), network=network
  807. )
  808. @pytest.mark.parametrize(
  809. "shape, shifts, axis", [((10, 0), 5, 1), ((10, 0), -10, 1),],
  810. )
  811. @pytest.mark.parametrize("is_symbolic", [None, True, False])
  812. def test_roll_empty_tensor(shape, shifts, axis, is_symbolic):
  813. inp = Tensor(np.random.randn(*shape).astype("float32"))
  814. def func(inp):
  815. return F.roll(inp, shifts, axis)
  816. if is_symbolic is not None:
  817. func = trace(symbolic=is_symbolic)(func)
  818. out_ref = np.roll(inp.numpy(), shifts, axis)
  819. for _ in range(3):
  820. out = F.roll(inp, shifts, axis)
  821. np.testing.assert_equal(out.numpy(), out_ref)
  822. if is_symbolic is None:
  823. break