GitOrigin-RevId: ad5ec7bc1c
release-1.6
@@ -72,6 +72,7 @@ namespace indexing_multi_axis_vec { | |||||
#define cb(_dtype) \ | #define cb(_dtype) \ | ||||
MEGDNN_FOREACH_TENSOR_NDIM(INST, DTypeTrait<_dtype>::ctype) | MEGDNN_FOREACH_TENSOR_NDIM(INST, DTypeTrait<_dtype>::ctype) | ||||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | ||||
cb(::megdnn::dtype::Bool) | |||||
#undef cb | #undef cb | ||||
#undef INST | #undef INST | ||||
@@ -38,6 +38,11 @@ __device__ void atomicAdd(megdnn::dt_int16 *, megdnn::dt_int16) { | |||||
((int*)0)[0] = 1; | ((int*)0)[0] = 1; | ||||
} | } | ||||
// Device-side overload of atomicAdd for megdnn::dt_bool that performs no
// addition: both parameters are unnamed and unused, and the body only
// aborts execution.
// NOTE(review): presumably this exists solely so that the OprAtomicIncr
// kernels pulled in by the include below can be instantiated for the Bool
// dtype, and is never meant to be reached at runtime -- confirm.
__device__ void atomicAdd(megdnn::dt_bool *, megdnn::dt_bool) { | |||||
// Issue the GPU trap instruction with trap id 2.
asm("s_trap 2;"); | |||||
// Then store through a null pointer, so the call still faults if execution
// continues past the trap (assumption: kept as a fallback -- confirm).
((int*)0)[0] = 1; | |||||
} | |||||
#define KERN_APPLY_OPR_OPR \ | #define KERN_APPLY_OPR_OPR \ | ||||
::megdnn::rocm::indexing_multi_axis_vec::OprAtomicIncr | ::megdnn::rocm::indexing_multi_axis_vec::OprAtomicIncr | ||||
#include "./kern_apply_opr_impl.hipinl" | #include "./kern_apply_opr_impl.hipinl" | ||||
@@ -71,6 +71,7 @@ void exec_src_normal(const TensorND& dst, const TensorND& src, | |||||
return; \ | return; \ | ||||
} | } | ||||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb); | MEGDNN_FOREACH_COMPUTING_DTYPE(cb); | ||||
cb(::megdnn::dtype::Bool); | |||||
#undef cb | #undef cb | ||||
default: | default: | ||||
megdnn_assert_internal(0); | megdnn_assert_internal(0); | ||||
@@ -106,6 +107,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||||
return; \ | return; \ | ||||
} | } | ||||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | ||||
cb(::megdnn::dtype::Bool); | |||||
#undef cb | #undef cb | ||||
default: | default: | ||||
megdnn_assert_internal(0); | megdnn_assert_internal(0); | ||||
@@ -136,6 +136,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, | |||||
#if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
#define MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, cb) \ | #define MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, cb) \ | ||||
cb(dtype_src, dt_bool) \ | |||||
cb(dtype_src, dt_int8) \ | cb(dtype_src, dt_int8) \ | ||||
cb(dtype_src, dt_int32) \ | cb(dtype_src, dt_int32) \ | ||||
cb(dtype_src, dt_int16) \ | cb(dtype_src, dt_int16) \ | ||||
@@ -147,6 +148,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, | |||||
#else | #else | ||||
#define MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, cb) \ | #define MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, cb) \ | ||||
cb(dtype_src, dt_bool) \ | |||||
cb(dtype_src, dt_int8) \ | cb(dtype_src, dt_int8) \ | ||||
cb(dtype_src, dt_int32) \ | cb(dtype_src, dt_int32) \ | ||||
cb(dtype_src, dt_int16) \ | cb(dtype_src, dt_int16) \ | ||||
@@ -171,6 +173,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, | |||||
#if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
#define MEGDNN_FOREACH_COMPUTING_CTYPE(cb) \ | #define MEGDNN_FOREACH_COMPUTING_CTYPE(cb) \ | ||||
cb(dt_bool) \ | |||||
cb(dt_int8) \ | cb(dt_int8) \ | ||||
cb(dt_int32) \ | cb(dt_int32) \ | ||||
cb(dt_int16) \ | cb(dt_int16) \ | ||||
@@ -181,6 +184,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, | |||||
#else | #else | ||||
#define MEGDNN_FOREACH_COMPUTING_CTYPE(cb) \ | #define MEGDNN_FOREACH_COMPUTING_CTYPE(cb) \ | ||||
cb(dt_bool) \ | |||||
cb(dt_int8) \ | cb(dt_int8) \ | ||||
cb(dt_int32) \ | cb(dt_int32) \ | ||||
cb(dt_int16) \ | cb(dt_int16) \ | ||||
@@ -176,6 +176,7 @@ void init_common(py::module m) { | |||||
py::enum_<CompNode::DeviceType>(m, "DeviceType") | py::enum_<CompNode::DeviceType>(m, "DeviceType") | ||||
.value("UNSPEC", CompNode::DeviceType::UNSPEC) | .value("UNSPEC", CompNode::DeviceType::UNSPEC) | ||||
.value("CUDA", CompNode::DeviceType::CUDA) | .value("CUDA", CompNode::DeviceType::CUDA) | ||||
.value("ROCM", CompNode::DeviceType::ROCM) | |||||
.value("CPU", CompNode::DeviceType::CPU) | .value("CPU", CompNode::DeviceType::CPU) | ||||
.value("CAMBRICON", CompNode::DeviceType::CAMBRICON) | .value("CAMBRICON", CompNode::DeviceType::CAMBRICON) | ||||
.value("ATLAS", CompNode::DeviceType::ATLAS) | .value("ATLAS", CompNode::DeviceType::ATLAS) | ||||
@@ -378,6 +378,7 @@ public: | |||||
if (is_finalized()) return; | if (is_finalized()) return; | ||||
for (auto&& i : m_used_comp_node) { | for (auto&& i : m_used_comp_node) { | ||||
if (i.device_type() == CompNode::DeviceType::CUDA) continue; | if (i.device_type() == CompNode::DeviceType::CUDA) continue; | ||||
if (i.device_type() == CompNode::DeviceType::ROCM) continue; | |||||
i.sync(); | i.sync(); | ||||
} | } | ||||
} | } | ||||