From 23032f50f28407e4298f26b0b0c6a8e3d1f58062 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 25 May 2021 16:54:32 +0800 Subject: [PATCH] feat(dnn/cuda): support float16 for index_incr_multi_axis_vec GitOrigin-RevId: c2ae93d568892d1af6a602aed3ed7c60f9dba1bd --- .../indexing_multi_axis_vec/kern_apply_opr_incr.cu | 6 +++--- dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp | 3 --- dnn/test/cuda/indexing_multi_axis_vec.cpp | 18 ++++++++++++++++++ 3 files changed, 21 insertions(+), 6 deletions(-) diff --git a/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu b/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu index a1601745..9879eb02 100644 --- a/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu +++ b/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu @@ -11,11 +11,11 @@ #include "megdnn/dtype.h" +#include "src/cuda/utils.cuh" #if !MEGDNN_DISABLE_FLOAT16 -__device__ void atomicAdd(megdnn::dt_float16 *, megdnn::dt_float16) { - __trap(); - ((int*)0)[0] = 1; +__device__ void atomicAdd(megdnn::dt_float16 * address, megdnn::dt_float16 val) { + ::megdnn::cuda::atomic_add(address, val); } __device__ void atomicAdd(megdnn::dt_bfloat16 *, megdnn::dt_bfloat16) { diff --git a/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp b/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp index 3dd62d2c..86c28fd4 100644 --- a/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp +++ b/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp @@ -199,9 +199,6 @@ size_t IndexingIncrMultiAxisVecImpl::get_workspace_in_bytes( void IndexingIncrMultiAxisVecImpl::exec( _megdnn_tensor_inout data, _megdnn_tensor_in value, const IndexDesc &index, _megdnn_workspace workspace) { - DNN_INC_FLOAT16( - megdnn_assert(data.layout.dtype != dtype::Float16(), - "float16 incr on cuda currently not supported")); auto info = check_exec(data.layout, value.layout, index, workspace.size); info.error_tracker = m_error_tracker; info.error_info = async_error_info(handle()); diff --git a/dnn/test/cuda/indexing_multi_axis_vec.cpp b/dnn/test/cuda/indexing_multi_axis_vec.cpp index 2d565eb0..553294f1 100644 --- a/dnn/test/cuda/indexing_multi_axis_vec.cpp +++ b/dnn/test/cuda/indexing_multi_axis_vec.cpp @@ -32,6 +32,11 @@ namespace { for (size_t i = 0, it = span.dist_elem(); i < it; ++ i) { ptr[i] = i; } + } else if (tensor.layout.dtype == dtype::Float16()) { + auto ptr = tensor.ptr() + span.low_elem; + for (size_t i = 0, it = span.dist_elem(); i < it; ++ i) { + ptr[i] = i; + } } else { auto ptr = tensor.ptr() + span.low_elem; for (size_t i = 0, it = span.dist_elem(); i < it; ++ i) { @@ -135,6 +140,19 @@ TEST_F(CUDA, INDEXING_MULTI_AXIS_VEC) { TEST_F(CUDA, INDEXING_INCR_MULTI_AXIS_VEC) { run_check(handle_cuda()); + Checker checker(handle_cuda()); + OrderedRNG rng; + checker. + set_dtype(0, dtype::Float16()). // data + set_dtype(1, dtype::Float16()). // value + set_dtype(2, dtype::Int32()). // idx0 + set_rng(0, &rng). + set_rng(1, &rng). + set_rng(2, &rng); + + checker. + set_proxy({{1}}). + execs({{5, 8, 3}, {5, 2, 3}, {2}}); } TEST_F(CUDA, INDEXING_SET_MULTI_AXIS_VEC) {