Browse Source

fix(dnn): support bool for IndexingMultiAxisVec

GitOrigin-RevId: ddcfaa06b0
release-1.1
Megvii Engine Team 4 years ago
parent
commit
912d733ea9
5 changed files with 20 additions and 0 deletions
  1. +1
    -0
      dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_impl.cuinl
  2. +5
    -0
      dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu
  3. +1
    -0
      dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp
  4. +1
    -0
      dnn/src/naive/indexing_multi_axis_vec/opr_impl.cpp
  5. +12
    -0
      imperative/python/test/unit/core/test_indexing_op.py

+ 1
- 0
dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_impl.cuinl View File

@@ -72,6 +72,7 @@ namespace indexing_multi_axis_vec {
#define cb0(_dtype) \
MEGDNN_FOREACH_TENSOR_NDIM(INST, DTypeTrait<_dtype>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb0)
cb0(::megdnn::dtype::Bool)
#undef cb0
#undef INST



+ 5
- 0
dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu View File

@@ -39,6 +39,11 @@ __device__ void atomicAdd(megdnn::dt_int16 *, megdnn::dt_int16) {
((int*)0)[0] = 1;
}

__device__ void atomicAdd(megdnn::dt_bool *, megdnn::dt_bool) {
__trap();
((int*)0)[0] = 1;
}

#define KERN_APPLY_OPR_OPR \
::megdnn::cuda::indexing_multi_axis_vec::OprAtomicIncr
#include "./kern_apply_opr_impl.cuinl"


+ 1
- 0
dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp View File

@@ -120,6 +120,7 @@ void ExecImpl<Opr>::dispatch_exec() {
case DTypeTrait<_dtype>::enumv: \
return dispatch_exec_ctype<DTypeTrait<_dtype>::ctype>();
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
default:
megdnn_throw("bad dtype");


+ 1
- 0
dnn/src/naive/indexing_multi_axis_vec/opr_impl.cpp View File

@@ -88,6 +88,7 @@ void dispatch_exec(HandleImpl *handle,
}
switch (data.layout.dtype.enumv()) {
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
default:
megdnn_throw(megdnn_mangle("bad dtype"));
}


+ 12
- 0
imperative/python/test/unit/core/test_indexing_op.py View File

@@ -519,6 +519,18 @@ def test_advance_indexing_with_bool():
np.testing.assert_equal(a[b], aa[bb].numpy())
np.testing.assert_equal(a[:, [True, False]], aa[:, [True, False]].numpy())

a = np.array([[True, False], [False, True]])
b = np.array([1])
aa = Tensor(a)
bb = Tensor(b)
np.testing.assert_equal(a[b], aa[bb].numpy())
b = np.array([[True, True], [False, True]])
bb = Tensor(b)
np.testing.assert_equal(a[b], aa[bb].numpy())
a[b] = False
aa[bb] = False
np.testing.assert_equal(a, aa.numpy())

# XXX: trace does not expect empty condtake tensor
if not use_tensor_shape():
a = np.ones((2, 2), dtype=np.int32)


Loading…
Cancel
Save