
test_indexing_op.py 17 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections.abc

import numpy as np
import pytest

import megengine.core.ops.builtin
import megengine.core.tensor.raw_tensor
from megengine.core._trace_option import use_tensor_shape
from megengine.core.ops._internal import all_ops
from megengine.core.tensor import Tensor
from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import RawTensor, as_raw_tensor


def cvt_to_shape_desc(val, inpvar, config=None):
    def as_tensor(val, device):
        assert device is not None, "can not infer device"
        # TODO: should copy to appropriate device
        val = as_raw_tensor(val, device=device)
        return val

    device = None
    if inpvar is not None:
        assert isinstance(inpvar, RawTensor)
        device = device or inpvar.device

    if config is not None:
        device = device or config.device

    if isinstance(val, RawTensor):
        return as_tensor(val, device)

    if not isinstance(val, collections.abc.Iterable):
        val = [val]

    components = []
    on_host = True
    for i in val:
        if isinstance(i, RawTensor):
            on_host = False
            device = device or i.device
        else:
            assert isinstance(i, int), (
                "shape desc could contain either int or Tensor, got {}"
                " actually".format(repr(i))
            )
        components.append(i)
    assert components, "shape desc could not be empty"

    if on_host:
        shape = np.ascontiguousarray(components, dtype=np.int32)
        assert np.all(shape == components), "failed to convert to shape: {}".format(
            components
        )
        return as_tensor(shape, device)

    for idx, v in enumerate(components):
        if not isinstance(v, RawTensor):
            vi = int(v)
            assert vi == v, "could not convert {} to int".format(v)
            v = vi
        components[idx] = as_tensor(v, device)

    return invoke_op(all_ops.Concat(axis=0), components)
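
# Note: cvt_to_shape_desc accepts a RawTensor, a single int, or an iterable
# mixing ints and RawTensors.  A pure-int shape is packed into a contiguous
# int32 ndarray on the host; if any component is already a RawTensor, every
# component is converted to a tensor and the pieces are concatenated along
# axis 0 into a 1-d shape tensor.  (This is a summary of the code above, not
# a documented MegEngine API contract.)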


def canonize_reshape(inputs, *, config):
    src, tshape = inputs
    tshape = cvt_to_shape_desc(tshape, src, config)
    return src, tshape


def canonize_inputs(inputs, *, config):
    """convert immediate numbers and SharedND to SymbolVar in inputs; at least
    one of the inputs must be SymbolVar, so comp node and comp graph can
    be inferred

    :return: list of converted vars
    """
    if (
        isinstance(inputs, (list, tuple))
        and len(inputs) == 1
        and isinstance(inputs[0], (list, tuple))
    ):
        # handle the case when a list is passed to a function with
        # variable-length argument (e.g. concat has signature concat(*inputs)
        # and is called with concat([a, b]))
        inputs = inputs[0]

    if isinstance(inputs, RawTensor):
        return [inputs]

    old_inputs = inputs
    inputs = []
    get_comp_node = None
    need_cvt = False
    for i in old_inputs:
        if isinstance(i, RawTensor):
            get_comp_node = lambda cn=i.device.to_c(): cn
        else:
            need_cvt = True
        inputs.append(i)
    if not need_cvt:
        return inputs

    if get_comp_node is None:

        def get_comp_node():
            return config.comp_node

    for idx, var in enumerate(inputs):
        if not isinstance(var, RawTensor):
            var = as_raw_tensor(var)
            inputs[idx] = var
    return inputs


def invoke_op(op, inputs_, cvt_inputs=canonize_inputs):
    inputs = cvt_inputs(
        inputs_, config=megengine.core._imperative_rt.OperatorNodeConfig()
    )
    return apply(op, *inputs)
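
# apply() yields a sequence of output tensors, so invoke_op does as well;
# that is why the tests below unpack results as ``(yy,) = subtensor(...)``
# even when the op has a single output.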


def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
    assert isinstance(inp, RawTensor)
    if not isinstance(tuple_val, tuple):
        tuple_val = (tuple_val,)

    def as_tensor(v):
        if not isinstance(v, RawTensor):
            vi = np.ascontiguousarray(v, dtype=np.int32)
            assert np.abs(vi - v).max() == 0, "bad index: {!r}".format(v)
            v = as_raw_tensor(vi)
        return v

    new_axes = []
    tensors = []
    items = []
    cur_axis = -1
    for i_idx, i in enumerate(tuple_val):
        cur_axis += 1
        if i is np.newaxis:
            if cur_axis >= 0:
                new_axes.append(cur_axis)
            continue

        if i is Ellipsis:
            cur_axis = -1
            for j in tuple_val[:i_idx:-1]:
                if j is Ellipsis:
                    raise IndexError("only one ellipsis is allowed")
                if j is np.newaxis:
                    new_axes.append(cur_axis)
                cur_axis -= 1
            continue

        item = [
            cur_axis,
        ]

        def push(v, item, tensors):
            if v is None:
                item.append(False)
            else:
                item.append(True)
                tensors.append(as_tensor(v))

        if isinstance(i, slice):
            if i.start is None and i.stop is None and i.step is None:
                continue
            push(i.start, item, tensors)
            push(i.stop, item, tensors)
            push(i.step, item, tensors)
            item.append(False)  # idx
        else:
            item += [False,] * 3  # begin, end, step
            push(i, item, tensors)

        assert len(item) == 5
        items.append(item)

    if new_axes:
        raise IndexError("newaxis is not allowed here")

    return inp, tensors, items
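
# Each entry of ``items`` is a 5-element descriptor
#     [axis, has_begin, has_end, has_step, has_idx]
# where the four booleans record which of the slice components (begin/end/step)
# or the scalar index were supplied; the corresponding index values are
# collected, in the same order, as int32 tensors in ``tensors``.  (Summary of
# the code above, phrased here for readability rather than taken from
# MegEngine documentation.)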


def dimshuffle(*args, **kwargs):
    op = all_ops.Dimshuffle(**kwargs).to_c()
    return invoke_op(op, args)


def broadcast(input, tshape):
    op = all_ops.Broadcast().to_c()
    return invoke_op(op, (input, tshape), canonize_reshape)


def subtensor(input, tuple_val):
    input, tensors, items = unpack_getitem(input, tuple_val)
    op = all_ops.Subtensor(items).to_c()
    return invoke_op(op, (input, *tensors))


def set_subtensor(input, value, tuple_val):
    input, tensors, items = unpack_getitem(input, tuple_val)
    op = all_ops.SetSubtensor(items).to_c()
    return invoke_op(op, (input, value, *tensors))


def incr_subtensor(input, value, tuple_val):
    input, tensors, items = unpack_getitem(input, tuple_val)
    op = all_ops.IncrSubtensor(items).to_c()
    return invoke_op(op, (input, value, *tensors))


def advance_indexing(input, tuple_val):
    input, tensors, items = unpack_getitem(input, tuple_val)
    op = all_ops.IndexingMultiAxisVec(items).to_c()
    return invoke_op(op, (input, *tensors))


def set_advance_indexing(input, value, tuple_val):
    input, tensors, items = unpack_getitem(input, tuple_val)
    op = all_ops.IndexingSetMultiAxisVec(items).to_c()
    return invoke_op(op, (input, value, *tensors))


def incr_advance_indexing(input, value, tuple_val):
    input, tensors, items = unpack_getitem(input, tuple_val)
    op = all_ops.IndexingIncrMultiAxisVec(items).to_c()
    return invoke_op(op, (input, value, *tensors))


def mesh_indexing(input, tuple_val):
    input, tensors, items = unpack_getitem(input, tuple_val)
    op = all_ops.MeshIndexing(items).to_c()
    return invoke_op(op, (input, *tensors))


def set_mesh_indexing(input, value, tuple_val):
    input, tensors, items = unpack_getitem(input, tuple_val)
    op = all_ops.SetMeshIndexing(items).to_c()
    return invoke_op(op, (input, value, *tensors))


def incr_mesh_indexing(input, value, tuple_val):
    input, tensors, items = unpack_getitem(input, tuple_val)
    op = all_ops.IncrMeshIndexing(items).to_c()
    return invoke_op(op, (input, value, *tensors))


def batched_mesh_indexing(input, tuple_val):
    input, tensors, items = unpack_getitem(input, tuple_val)
    op = all_ops.BatchedMeshIndexing(items).to_c()
    return invoke_op(op, (input, *tensors))


def batched_set_mesh_indexing(input, value, tuple_val):
    input, tensors, items = unpack_getitem(input, tuple_val)
    op = all_ops.BatchedSetMeshIndexing(items).to_c()
    return invoke_op(op, (input, value, *tensors))


def batched_incr_mesh_indexing(input, value, tuple_val):
    input, tensors, items = unpack_getitem(input, tuple_val)
    op = all_ops.BatchedIncrMeshIndexing(items).to_c()
    return invoke_op(op, (input, value, *tensors))
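
# Usage sketch (mirrors test_subtensor below): each wrapper takes the same
# tuple that would be passed to NumPy's __getitem__, e.g.
#
#     x = np.arange(25).reshape(5, 5).astype("int32")
#     (y,) = subtensor(as_raw_tensor(x), (slice(0, 4, 2), 3))
#     # y.numpy() == x[0:4:2, 3]
#
# The set_*/incr_* variants additionally take the value to assign or add,
# matching ``x[idx] = d`` and ``x[idx] += d`` respectively.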


def test_dimshuffle():
    x = np.arange(10).reshape(2, 5).astype("int32")
    xx = as_raw_tensor(x)
    (yy,) = dimshuffle(xx, pattern="1x0")
    np.testing.assert_equal(np.expand_dims(x.transpose(), axis=1), yy.numpy())
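
# As the assertion above shows, the Dimshuffle pattern "1x0" places input
# axis 1 first, inserts a new axis of size 1 ("x"), and puts input axis 0
# last, so the (2, 5) input becomes a (5, 1, 2) output.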


def test_broadcast():
    x = np.arange(10).reshape(1, 10).astype("int32")
    xx = as_raw_tensor(x)
    (yy,) = broadcast(xx, (10, 10))
    np.testing.assert_equal(np.repeat(x, 10, 0), yy.numpy())


def test_subtensor():
    x = np.arange(25).reshape(5, 5).astype("int32")
    d = np.arange(2).astype("int32")
    xx = as_raw_tensor(x)
    (yy0,) = subtensor(xx, (slice(0, 4, 2), 3))
    (yy1,) = set_subtensor(xx, d, (slice(0, 4, 2), 3))
    (yy2,) = incr_subtensor(xx, d, (slice(0, 4, 2), 3))

    np.testing.assert_equal(x[0:4:2, 3], yy0.numpy())

    x_ = x.copy()
    x_[0:4:2, 3] = d
    np.testing.assert_equal(x_, yy1.numpy())

    x_ = x.copy()
    x_[0:4:2, 3] += d
    np.testing.assert_equal(x_, yy2.numpy())


def test_advance_indexing():
    x = np.arange(25).reshape(5, 5).astype("int32")
    d = np.arange(15).reshape(3, 5).astype("int32")
    xx = as_raw_tensor(x)
    (yy0,) = advance_indexing(xx, ((0, 4, 2), slice(None, None, None)))
    (yy1,) = set_advance_indexing(xx, d, ((0, 4, 2), slice(None, None, None)))
    (yy2,) = incr_advance_indexing(xx, d, ((0, 4, 2), slice(None, None, None)))

    np.testing.assert_equal(x[(0, 4, 2), :], yy0.numpy())

    x_ = x.copy()
    x_[(0, 4, 2), :] = d
    np.testing.assert_equal(x_, yy1.numpy())

    x_ = x.copy()
    x_[(0, 4, 2), :] += d
    np.testing.assert_equal(x_, yy2.numpy())


def test_mesh_indexing():
    x = np.arange(25).reshape(5, 5).astype("int32")
    d = np.arange(6).reshape(3, 2).astype("int32")
    xx = as_raw_tensor(x)
    (yy0,) = mesh_indexing(xx, (slice(0, 5, 2), (1, 3)))
    (yy1,) = set_mesh_indexing(xx, d, (slice(0, 5, 2), (1, 3)))
    (yy2,) = incr_mesh_indexing(xx, d, (slice(0, 5, 2), (1, 3)))

    r = np.ndarray(shape=(3, 2), dtype="int32")
    for i0, i1 in enumerate(range(0, 5, 2)):
        for j0, j1 in enumerate((1, 3)):
            r[i0, j0] = x[i1, j1]
    np.testing.assert_equal(r, yy0.numpy())

    r = x.copy()
    for i0, i1 in enumerate(range(0, 5, 2)):
        for j0, j1 in enumerate((1, 3)):
            r[i1, j1] = d[i0, j0]
    np.testing.assert_equal(r, yy1.numpy())

    r = x.copy()
    for i0, i1 in enumerate(range(0, 5, 2)):
        for j0, j1 in enumerate((1, 3)):
            r[i1, j1] += d[i0, j0]
    np.testing.assert_equal(r, yy2.numpy())


def test_batched_mesh_indexing():
    x = np.arange(24).reshape(2, 3, 4).astype("int32")
    d = np.arange(12).reshape(2, 2, 3).astype("int32")
    xx = as_raw_tensor(x)
    s = [(0, 1, 2), (1, 2, 3)]
    (yy0,) = batched_mesh_indexing(xx, (slice(None, None, None), [(0, 2)] * 2, s))
    (yy1,) = batched_set_mesh_indexing(
        xx, d, (slice(None, None, None), [(0, 2)] * 2, s)
    )
    (yy2,) = batched_incr_mesh_indexing(
        xx, d, (slice(None, None, None), [(0, 2)] * 2, s)
    )

    r = np.ndarray(shape=(2, 2, 3), dtype="int32")
    for i in range(2):
        for j0, j1 in enumerate((0, 2)):
            for k0, k1 in enumerate(s[i]):
                r[i, j0, k0] = x[i, j1, k1]
    np.testing.assert_equal(r, yy0.numpy())

    r = x.copy()
    for i in range(2):
        for j0, j1 in enumerate((0, 2)):
            for k0, k1 in enumerate(s[i]):
                r[i, j1, k1] = d[i, j0, k0]
    np.testing.assert_equal(r, yy1.numpy())

    r = x.copy()
    for i in range(2):
        for j0, j1 in enumerate((0, 2)):
            for k0, k1 in enumerate(s[i]):
                r[i, j1, k1] += d[i, j0, k0]
    np.testing.assert_equal(r, yy2.numpy())


# high level
def test_advance_indexing_high_level():
    x = np.arange(25).reshape(5, 5).astype("int32")
    d = np.arange(15).reshape(3, 5).astype("int32")
    xx = Tensor(x)

    np.testing.assert_equal(x[1, :], xx[1, :].numpy())
    np.testing.assert_equal(x[:, 1], xx[:, 1].numpy())
    np.testing.assert_equal(x[1:3, :], xx[1:3, :].numpy())
    np.testing.assert_equal(x[:, :], xx[:, :].numpy())
    np.testing.assert_equal(x[1, 1], xx[1, 1].numpy())

    yy = xx[(0, 4, 2), :]
    np.testing.assert_equal(x[(0, 4, 2), :], yy.numpy())

    x_ = x.copy()
    x_[(0, 4, 2), :] = d
    xx_ = Tensor(xx)
    xx_[(0, 4, 2), :] = d
    np.testing.assert_equal(x_, xx_.numpy())

    x = np.arange(27).reshape(3, 3, 3).astype("int32")
    xx = Tensor(x)

    np.testing.assert_equal(x[1, :, :], xx[1, :, :].numpy())
    np.testing.assert_equal(x[1, :, 1], xx[1, :, 1].numpy())
    np.testing.assert_equal(x[1, 0:1, :], xx[1, 0:1, :].numpy())
    np.testing.assert_equal(x[0:1, 1, 1], xx[0:1, 1, 1].numpy())
    np.testing.assert_equal(x[:, 1, 1], xx[:, 1, 1].numpy())
    np.testing.assert_equal(x[:, 1], xx[:, 1].numpy())
    np.testing.assert_equal(x[1, 1:2], xx[1, 1:2].numpy())

    x_ = x.copy()
    x_[1, 1, 1] = -1
    xx[1, 1, 1] = -1
    np.testing.assert_equal(x_, xx.numpy())

    x_[:, 1, 1] = -2
    xx[:, 1, 1] = x_[:, 1, 1]
    np.testing.assert_equal(x_, xx.numpy())

    x_[0:1, :, 1] = -3
    xx[0:1, :, 1] = x_[0:1, :, 1]
    np.testing.assert_equal(x_, xx.numpy())

    x_[0:1, :, 1] = -4
    y = Tensor(x_)
    xx[0:1, :, 1] = y[0:1, :, 1]
    np.testing.assert_equal(y.numpy(), xx.numpy())

    x[:] = 1
    xx[:] = 1
    np.testing.assert_equal(x, xx.numpy())

    x = np.arange(9).reshape(3, 3).astype("int32")
    xx = Tensor(x)
    y = np.array([1, 2])
    yy = Tensor(y)
    np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy())
    # np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy())  # FIXME
    np.testing.assert_equal(x[:, y], xx[:, y].numpy())
    np.testing.assert_equal(x[:, y], xx[:, yy].numpy())

    x_ = x.copy()
    x_[:, y[0]] = -1
    xx_ = Tensor(x_)
    xx[:, yy[0]] = xx_[:, yy[0]]
    np.testing.assert_equal(x_, xx.numpy())

    x_[:, y] = -1
    xx_ = Tensor(x_)
    xx[:, yy] = xx_[:, yy]
    np.testing.assert_equal(x_, xx.numpy())

    x = np.arange(9).reshape(3, 3).astype("int32")
    xx = Tensor(x)
    y = np.array([1])
    yy = Tensor(y)
    np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy())
    # np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy())  # FIXME
    np.testing.assert_equal(x[:, y], xx[:, y].numpy())
    # XXX: no way to tell whether yy is scalar or ndim=1 array
    np.testing.assert_equal(x[:, y], xx[:, yy].numpy())

    x = np.arange(9).reshape(3, 3).astype("int32")
    xx = Tensor(x)
    np.testing.assert_equal(x[[0, 1], 0], xx[[0, 1], 0].numpy())
    np.testing.assert_equal(x[0:2, 0], xx[0:2, 0].numpy())


def test_advance_indexing_with_bool():
    a = np.arange(9).reshape(3, 3).astype(np.float32)
    b = np.array([1, 2, 3])
    c = np.array([1, 2, 3])
    aa = Tensor(a)
    bb = Tensor(b)
    cc = Tensor(c)
    np.testing.assert_equal(a[b == 1, c == 2], aa[bb == 1, cc == 2].numpy())

    a[b == 1, c == 2] = -1.0
    aa[bb == 1, cc == 2] = -1.0
    np.testing.assert_equal(a, aa.numpy())

    a = np.arange(9).reshape(3, 3).astype(np.float32)
    b = np.array([False, True, True])
    c = np.array([2, 0]).astype(np.int32)
    aa = Tensor(a)
    bb = Tensor(b)
    cc = Tensor(c)
    np.testing.assert_equal(a[b, c], aa[bb, cc].numpy())

    a[b, c] = -1.0
    aa[bb, cc] = -1.0
    np.testing.assert_equal(a, aa.numpy())

    d = np.array([-1, -2], dtype=np.float32)
    dd = Tensor(d)
    a[b, c] = d
    aa[bb, cc] = dd
    np.testing.assert_equal(a, aa.numpy())

    a = np.ones((2, 2))
    b = np.array([[True, False], [False, True]])
    aa = Tensor(a)
    bb = Tensor(b)
    np.testing.assert_equal(a[b], aa[bb].numpy())

    b[:] = True
    bb[:] = True
    np.testing.assert_equal(a[b], aa[bb].numpy())
    np.testing.assert_equal(a[:, [True, False]], aa[:, [True, False]].numpy())

    # XXX: trace does not expect empty condtake tensor
    if not use_tensor_shape():
        a = np.ones((2, 2), dtype=np.int32)
        b = np.array([[False, False], [False, False]])
        aa = Tensor(a)
        bb = Tensor(b)
        np.testing.assert_equal(a[b], aa[b].numpy())
        np.testing.assert_equal(a[b], aa[bb].numpy())

        b = np.array([False, False])
        bb = Tensor(b)
        np.testing.assert_equal(a[b], aa[bb].numpy().reshape(a[b].shape))  # FIXME

    a = np.arange(576).reshape(2, 3, 4, 3, 4, 2).astype("int32")
    aa = Tensor(a)

    b = (np.random.sample((2, 3, 4)) > 0.5).astype("bool")
    bb = Tensor(b)
    np.testing.assert_equal(a[b, :, 0:4:2], aa[bb, :, 0:4:2].numpy())

    b = (np.random.sample((4, 3, 4)) > 0.5).astype("bool")
    bb = Tensor(b)
    np.testing.assert_equal(a[..., b, 0:2], aa[..., bb, 0:2].numpy())

    b = (np.random.sample((3, 4, 3)) > 0.5).astype("bool")
    bb = Tensor(b)
    np.testing.assert_equal(
        a[:, b, 0:2, [True, False]], aa[:, bb, 0:2, [True, False]].numpy()
    )
