GitOrigin-RevId: c2ae93d568
release-1.5
@@ -11,11 +11,11 @@
 #include "megdnn/dtype.h"
+#include "src/cuda/utils.cuh"
 #if !MEGDNN_DISABLE_FLOAT16
-__device__ void atomicAdd(megdnn::dt_float16 *, megdnn::dt_float16) {
-__trap();
-((int*)0)[0] = 1;
+__device__ void atomicAdd(megdnn::dt_float16 * address, megdnn::dt_float16 val) {
+::megdnn::cuda::atomic_add(address, val);
 }
 __device__ void atomicAdd(megdnn::dt_bfloat16 *, megdnn::dt_bfloat16) {
@@ -199,9 +199,6 @@ size_t IndexingIncrMultiAxisVecImpl::get_workspace_in_bytes(
 void IndexingIncrMultiAxisVecImpl::exec(
 _megdnn_tensor_inout data, _megdnn_tensor_in value,
 const IndexDesc &index, _megdnn_workspace workspace) {
-DNN_INC_FLOAT16(
-megdnn_assert(data.layout.dtype != dtype::Float16(),
-"float16 incr on cuda currently not supported"));
 auto info = check_exec(data.layout, value.layout, index, workspace.size);
 info.error_tracker = m_error_tracker;
 info.error_info = async_error_info(handle());
@@ -32,6 +32,11 @@ namespace {
 for (size_t i = 0, it = span.dist_elem(); i < it; ++ i) {
 ptr[i] = i;
 }
+} else if (tensor.layout.dtype == dtype::Float16()) {
+auto ptr = tensor.ptr<dt_float16>() + span.low_elem;
+for (size_t i = 0, it = span.dist_elem(); i < it; ++ i) {
+ptr[i] = i;
+}
 } else {
 auto ptr = tensor.ptr<int>() + span.low_elem;
 for (size_t i = 0, it = span.dist_elem(); i < it; ++ i) {
@@ -135,6 +140,19 @@ TEST_F(CUDA, INDEXING_MULTI_AXIS_VEC) {
 TEST_F(CUDA, INDEXING_INCR_MULTI_AXIS_VEC) {
 run_check<IndexingIncrMultiAxisVec>(handle_cuda());
+Checker<IndexingIncrMultiAxisVec> checker(handle_cuda());
+OrderedRNG rng;
+checker.
+set_dtype(0, dtype::Float16()). // data
+set_dtype(1, dtype::Float16()). // value
+set_dtype(2, dtype::Int32()). // idx0
+set_rng(0, &rng).
+set_rng(1, &rng).
+set_rng(2, &rng);
+checker.
+set_proxy({{1}}).
+execs({{5, 8, 3}, {5, 2, 3}, {2}});
 }
 TEST_F(CUDA, INDEXING_SET_MULTI_AXIS_VEC) {