@@ -59,6 +59,7 @@ CondTakeImpl::Output CondTakeImpl::exec( | |||||
break; \ | break; \ | ||||
} | } | ||||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | ||||
cb(::megdnn::dtype::Bool) | |||||
#undef cb | #undef cb | ||||
default: | default: | ||||
megdnn_throw("bad mask dtype"); | megdnn_throw("bad mask dtype"); | ||||
@@ -111,8 +111,7 @@ struct TypeCvtOpFromQuantized< | |||||
ctype_dest, ctype_src, | ctype_dest, ctype_src, | ||||
typename std::enable_if< | typename std::enable_if< | ||||
std::is_same<ctype_src, dt_qint8>::value || | std::is_same<ctype_src, dt_qint8>::value || | ||||
std::is_same<ctype_src, dt_quint8>::value || | |||||
std::is_same<ctype_src, dt_bool>::value>::type> { | |||||
std::is_same<ctype_src, dt_quint8>::value>::type> { | |||||
ctype_dest* dest; | ctype_dest* dest; | ||||
CudaDTypeParam<ctype_src> param; | CudaDTypeParam<ctype_src> param; | ||||
using src_vect_type = typename VectTypeTrait<ctype_src>::vect_type; | using src_vect_type = typename VectTypeTrait<ctype_src>::vect_type; | ||||
@@ -140,8 +139,7 @@ struct TypeCvtOpBetweenQuantized< | |||||
ctype_dest, ctype_src, | ctype_dest, ctype_src, | ||||
typename std::enable_if< | typename std::enable_if< | ||||
std::is_same<ctype_src, dt_qint8>::value || | std::is_same<ctype_src, dt_qint8>::value || | ||||
std::is_same<ctype_src, dt_quint8>::value || | |||||
std::is_same<ctype_src, dt_bool>::value>::type> { | |||||
std::is_same<ctype_src, dt_quint8>::value>::type> { | |||||
ctype_dest* dest; | ctype_dest* dest; | ||||
CudaDTypeParam<ctype_src> src_param; | CudaDTypeParam<ctype_src> src_param; | ||||
CudaDTypeParam<ctype_dest> dst_param; | CudaDTypeParam<ctype_dest> dst_param; | ||||
@@ -109,6 +109,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||||
return; \ | return; \ | ||||
} | } | ||||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | ||||
cb(::megdnn::dtype::Bool) | |||||
#undef cb | #undef cb | ||||
default: | default: | ||||
megdnn_assert_internal(0); | megdnn_assert_internal(0); | ||||
@@ -568,6 +568,27 @@ TEST(TestOprBasicArith, TypeCvtBool) { | |||||
ASSERT_EQ(TensorShape({3}), host_y.shape()); | ASSERT_EQ(TensorShape({3}), host_y.shape()); | ||||
} | } | ||||
// Checks TypeCvt from Bool to Int32: a two-element Bool input {true, false}
// must produce an Int32 output whose elements equal the integer value of the
// corresponding input element, and whose shape is {2}.
TEST(TestOprBasicArith, TypeCvtFromBool) {
    auto graph = ComputingGraph::make();

    // Build a Bool host tensor of shape {2} and overwrite the generated
    // contents with known values.
    HostTensorGenerator<dtype::Bool> gen;
    auto host_src = gen({2});
    auto src_vals = host_src->ptr<bool>();
    src_vals[0] = true;
    src_vals[1] = false;

    // Graph: host input -> TypeCvt(Int32).
    auto src = opr::Host2DeviceCopy::make(*graph, host_src);
    auto cvt = opr::TypeCvt::make(src, dtype::Int32{});

    // Compile with a callback that copies the converted value into host_dst.
    HostTensorND host_dst;
    auto exec_fn = graph->compile({make_callback_copy(cvt, host_dst)});
    exec_fn->execute();

    // Element-wise comparison first, then the shape check (same order as the
    // assertions have always run in this test).
    auto dst_vals = host_dst.ptr<int>();
    for (size_t idx = 0; idx < 2; ++idx) {
        ASSERT_EQ(static_cast<int>(src_vals[idx]), dst_vals[idx]);
    }
    ASSERT_EQ(TensorShape({2}), host_dst.shape());
}
TEST(TestOprBasicArith, ElemwiseMemFwd) { | TEST(TestOprBasicArith, ElemwiseMemFwd) { | ||||
auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
graph->options().graph_opt_level = 0; | graph->options().graph_opt_level = 0; | ||||