Browse Source

feat(dnn/cuda): support float16 for index_incr_multi_axis_vec

GitOrigin-RevId: c2ae93d568
release-1.5
Megvii Engine Team 4 years ago
parent
commit
23032f50f2
3 changed files with 21 additions and 6 deletions
  1. +3
    -3
      dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu
  2. +0
    -3
      dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp
  3. +18
    -0
      dnn/test/cuda/indexing_multi_axis_vec.cpp

+ 3
- 3
dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu View File

@@ -11,11 +11,11 @@


#include "megdnn/dtype.h"
#include "src/cuda/utils.cuh"

#if !MEGDNN_DISABLE_FLOAT16
__device__ void atomicAdd(megdnn::dt_float16 *, megdnn::dt_float16) {
__trap();
((int*)0)[0] = 1;
__device__ void atomicAdd(megdnn::dt_float16 * address, megdnn::dt_float16 val) {
::megdnn::cuda::atomic_add(address, val);
}

__device__ void atomicAdd(megdnn::dt_bfloat16 *, megdnn::dt_bfloat16) {


+ 0
- 3
dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp View File

@@ -199,9 +199,6 @@ size_t IndexingIncrMultiAxisVecImpl::get_workspace_in_bytes(
void IndexingIncrMultiAxisVecImpl::exec(
_megdnn_tensor_inout data, _megdnn_tensor_in value,
const IndexDesc &index, _megdnn_workspace workspace) {
DNN_INC_FLOAT16(
megdnn_assert(data.layout.dtype != dtype::Float16(),
"float16 incr on cuda currently not supported"));
auto info = check_exec(data.layout, value.layout, index, workspace.size);
info.error_tracker = m_error_tracker;
info.error_info = async_error_info(handle());


+ 18
- 0
dnn/test/cuda/indexing_multi_axis_vec.cpp View File

@@ -32,6 +32,11 @@ namespace {
for (size_t i = 0, it = span.dist_elem(); i < it; ++ i) {
ptr[i] = i;
}
} else if (tensor.layout.dtype == dtype::Float16()) {
auto ptr = tensor.ptr<dt_float16>() + span.low_elem;
for (size_t i = 0, it = span.dist_elem(); i < it; ++ i) {
ptr[i] = i;
}
} else {
auto ptr = tensor.ptr<int>() + span.low_elem;
for (size_t i = 0, it = span.dist_elem(); i < it; ++ i) {
@@ -135,6 +140,19 @@ TEST_F(CUDA, INDEXING_MULTI_AXIS_VEC) {

TEST_F(CUDA, INDEXING_INCR_MULTI_AXIS_VEC) {
    // First run the shared run_check helper for this operator on the CUDA handle.
    run_check<IndexingIncrMultiAxisVec>(handle_cuda());

    // Then exercise an explicit float16 case through a dedicated checker.
    // Declaration order (checker before RNG) is kept as in the original so
    // that object lifetimes are unchanged.
    Checker<IndexingIncrMultiAxisVec> fp16_checker(handle_cuda());
    OrderedRNG ordered_rng;

    // Operand dtypes: data and value are float16, the index vector is int32.
    fp16_checker.set_dtype(0, dtype::Float16());  // data
    fp16_checker.set_dtype(1, dtype::Float16());  // value
    fp16_checker.set_dtype(2, dtype::Int32());    // idx0

    // All three operands are filled by the same OrderedRNG instance.
    // NOTE(review): OrderedRNG presumably writes 0, 1, 2, ... in element
    // order -- confirm against its definition earlier in this file.
    fp16_checker.set_rng(0, &ordered_rng);
    fp16_checker.set_rng(1, &ordered_rng);
    fp16_checker.set_rng(2, &ordered_rng);

    // NOTE(review): the proxy argument {{1}} presumably selects axis 1 as the
    // indexed axis -- confirm against the proxy type used by Checker.
    fp16_checker.set_proxy({{1}});

    // Shapes in operand order: data, value, idx0.
    fp16_checker.execs({{5, 8, 3}, {5, 2, 3}, {2}});
}

TEST_F(CUDA, INDEXING_SET_MULTI_AXIS_VEC) {


Loading…
Cancel
Save