Browse Source

fix(mge/functional): fix cond_take when index is numpy array

GitOrigin-RevId: 5fb93740f4
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
92f7cceb11
2 changed files with 2 additions and 1 deletions
  1. +1
    -1
      imperative/python/megengine/core/tensor/indexing.py
  2. +1
    -0
      imperative/python/test/unit/test_indexing_op.py

+ 1
- 1
imperative/python/megengine/core/tensor/indexing.py View File

@@ -180,7 +180,7 @@ def try_condtake(tensor, index):
if index.dtype != np.bool_ or index.shape != tensor.shape:
return []
if isinstance(index, np.ndarray):
(i,) = Const(i, dtype=np.bool_, device=inp.device)(inp)
(index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor)
assert isinstance(index, (TensorBase, TensorWrapperBase))
if not isinstance(tensor, (TensorWrapperBase, TensorBase)):
raise TypeError("input must be a tensor")


+ 1
- 0
imperative/python/test/unit/test_indexing_op.py View File

@@ -522,6 +522,7 @@ def test_advance_indexing_with_bool():
b = np.array([[False, False], [False, False]])
aa = Tensor(a)
bb = Tensor(b)
np.testing.assert_equal(a[b], aa[b].numpy())
np.testing.assert_equal(a[b], aa[bb].numpy())

b = np.array([False, False])


Loading…
Cancel
Save