diff --git a/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_impl.cuinl b/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_impl.cuinl index a640d865..de8ec033 100644 --- a/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_impl.cuinl +++ b/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_impl.cuinl @@ -72,6 +72,7 @@ namespace indexing_multi_axis_vec { #define cb0(_dtype) \ MEGDNN_FOREACH_TENSOR_NDIM(INST, DTypeTrait<_dtype>::ctype) MEGDNN_FOREACH_COMPUTING_DTYPE(cb0) + cb0(::megdnn::dtype::Bool) #undef cb0 #undef INST diff --git a/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu b/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu index 74874f23..cfefef12 100644 --- a/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu +++ b/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu @@ -39,6 +39,11 @@ __device__ void atomicAdd(megdnn::dt_int16 *, megdnn::dt_int16) { ((int*)0)[0] = 1; } +__device__ void atomicAdd(megdnn::dt_bool *, megdnn::dt_bool) { + __trap(); + ((int*)0)[0] = 1; +} + #define KERN_APPLY_OPR_OPR \ ::megdnn::cuda::indexing_multi_axis_vec::OprAtomicIncr #include "./kern_apply_opr_impl.cuinl" diff --git a/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp b/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp index 4e864905..85b83468 100644 --- a/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp +++ b/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp @@ -120,6 +120,7 @@ void ExecImpl::dispatch_exec() { case DTypeTrait<_dtype>::enumv: \ return dispatch_exec_ctype::ctype>(); MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) #undef cb default: megdnn_throw("bad dtype"); diff --git a/dnn/src/naive/indexing_multi_axis_vec/opr_impl.cpp b/dnn/src/naive/indexing_multi_axis_vec/opr_impl.cpp index 52de1335..16ca74b0 100644 --- a/dnn/src/naive/indexing_multi_axis_vec/opr_impl.cpp +++ b/dnn/src/naive/indexing_multi_axis_vec/opr_impl.cpp @@ -88,6 +88,7 @@ void dispatch_exec(HandleImpl *handle, } switch (data.layout.dtype.enumv()) { MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) default: megdnn_throw(megdnn_mangle("bad dtype")); } diff --git a/imperative/python/test/unit/core/test_indexing_op.py b/imperative/python/test/unit/core/test_indexing_op.py index e369ea08..80478d63 100644 --- a/imperative/python/test/unit/core/test_indexing_op.py +++ b/imperative/python/test/unit/core/test_indexing_op.py @@ -519,6 +519,18 @@ def test_advance_indexing_with_bool(): np.testing.assert_equal(a[b], aa[bb].numpy()) np.testing.assert_equal(a[:, [True, False]], aa[:, [True, False]].numpy()) + a = np.array([[True, False], [False, True]]) + b = np.array([1]) + aa = Tensor(a) + bb = Tensor(b) + np.testing.assert_equal(a[b], aa[bb].numpy()) + b = np.array([[True, True], [False, True]]) + bb = Tensor(b) + np.testing.assert_equal(a[b], aa[bb].numpy()) + a[b] = False + aa[bb] = False + np.testing.assert_equal(a, aa.numpy()) + # XXX: trace does not expect empty condtake tensor if not use_tensor_shape(): a = np.ones((2, 2), dtype=np.int32)