Browse Source

fix(dnn): fix bool cvt

GitOrigin-RevId: 2f883dcbe0
tags/v1.0.0-rc1
Megvii Engine Team Xinran Xu 4 years ago
parent
commit
96ec586d28
4 changed files with 25 additions and 4 deletions
  1. +1
    -0
      dnn/src/cuda/cond_take/opr_impl.cpp
  2. +2
    -4
      dnn/src/cuda/type_cvt/kern.cu
  3. +1
    -0
      dnn/src/cuda/type_cvt/opr_impl.cpp
  4. +21
    -0
      src/opr/test/basic_arith/others.cpp

+ 1
- 0
dnn/src/cuda/cond_take/opr_impl.cpp View File

@@ -59,6 +59,7 @@ CondTakeImpl::Output CondTakeImpl::exec(
break; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
default:
megdnn_throw("bad mask dtype");


+ 2
- 4
dnn/src/cuda/type_cvt/kern.cu View File

@@ -111,8 +111,7 @@ struct TypeCvtOpFromQuantized<
ctype_dest, ctype_src,
typename std::enable_if<
std::is_same<ctype_src, dt_qint8>::value ||
std::is_same<ctype_src, dt_quint8>::value ||
std::is_same<ctype_src, dt_bool>::value>::type> {
std::is_same<ctype_src, dt_quint8>::value>::type> {
ctype_dest* dest;
CudaDTypeParam<ctype_src> param;
using src_vect_type = typename VectTypeTrait<ctype_src>::vect_type;
@@ -140,8 +139,7 @@ struct TypeCvtOpBetweenQuantized<
ctype_dest, ctype_src,
typename std::enable_if<
std::is_same<ctype_src, dt_qint8>::value ||
std::is_same<ctype_src, dt_quint8>::value ||
std::is_same<ctype_src, dt_bool>::value>::type> {
std::is_same<ctype_src, dt_quint8>::value>::type> {
ctype_dest* dest;
CudaDTypeParam<ctype_src> src_param;
CudaDTypeParam<ctype_dest> dst_param;


+ 1
- 0
dnn/src/cuda/type_cvt/opr_impl.cpp View File

@@ -109,6 +109,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
default:
megdnn_assert_internal(0);


+ 21
- 0
src/opr/test/basic_arith/others.cpp View File

@@ -568,6 +568,27 @@ TEST(TestOprBasicArith, TypeCvtBool) {
ASSERT_EQ(TensorShape({3}), host_y.shape());
}

// Checks that TypeCvt from a Bool tensor to Int32 yields, for each element,
// the same value as static_cast<int> applied to the host-side bool.
TEST(TestOprBasicArith, TypeCvtFromBool) {
    auto graph = ComputingGraph::make();
    HostTensorGenerator<dtype::Bool> gen;
    auto host_x = gen({2});
    // Set the input explicitly so that both boolean values are exercised.
    auto px = host_x->ptr<bool>();
    px[0] = true;
    px[1] = false;

    auto x = opr::Host2DeviceCopy::make(*graph, host_x),
         y = opr::TypeCvt::make(x, dtype::Int32{});
    HostTensorND host_y;
    auto func = graph->compile({make_callback_copy(y, host_y)});
    func->execute();

    // Validate the shape before touching the data: ASSERT_EQ is fatal, so a
    // wrongly-sized output stops the test here instead of letting the loop
    // below index past the end of the result buffer.
    ASSERT_EQ(TensorShape({2}), host_y.shape());

    auto py = host_y.ptr<int>();
    for (size_t i = 0; i < 2; ++i) {
        ASSERT_EQ(static_cast<int>(px[i]), py[i]);
    }
}

TEST(TestOprBasicArith, ElemwiseMemFwd) {
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;


Loading…
Cancel
Save