From 92f7cceb11c7d696d2e2cd80dcacf0178ff12cd0 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 28 Aug 2020 10:53:39 +0800 Subject: [PATCH] fix(mge/functional): fix cond_take when index is numpy array GitOrigin-RevId: 5fb93740f49ff1b6283ca8d0e30e5d417d66717d --- imperative/python/megengine/core/tensor/indexing.py | 2 +- imperative/python/test/unit/test_indexing_op.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/core/tensor/indexing.py b/imperative/python/megengine/core/tensor/indexing.py index 2c612c9b..8da5a66d 100644 --- a/imperative/python/megengine/core/tensor/indexing.py +++ b/imperative/python/megengine/core/tensor/indexing.py @@ -180,7 +180,7 @@ def try_condtake(tensor, index): if index.dtype != np.bool_ or index.shape != tensor.shape: return [] if isinstance(index, np.ndarray): - (i,) = Const(i, dtype=np.bool_, device=inp.device)(inp) + (index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor) assert isinstance(index, (TensorBase, TensorWrapperBase)) if not isinstance(tensor, (TensorWrapperBase, TensorBase)): raise TypeError("input must be a tensor") diff --git a/imperative/python/test/unit/test_indexing_op.py b/imperative/python/test/unit/test_indexing_op.py index 70b2911f..213819da 100644 --- a/imperative/python/test/unit/test_indexing_op.py +++ b/imperative/python/test/unit/test_indexing_op.py @@ -522,6 +522,7 @@ def test_advance_indexing_with_bool(): b = np.array([[False, False], [False, False]]) aa = Tensor(a) bb = Tensor(b) + np.testing.assert_equal(a[b], aa[b].numpy()) np.testing.assert_equal(a[b], aa[bb].numpy()) b = np.array([False, False])