|
|
@@ -22,13 +22,14 @@ namespace cuda { |
|
|
|
|
|
|
|
using device_reduce::CheckNonFiniteOp; |
|
|
|
#define total_nr_elems_max 2048 |
|
|
|
template <typename T> |
|
|
|
size_t CheckNonFiniteImpl::_get_workspace_in_bytes() { |
|
|
|
// Call the _get_workspace_in_bytes to reduce the loop fetch workspace bytes |
|
|
|
typedef CheckNonFiniteOp<dt_float32, size_t, dt_int32, dt_int32> Op; |
|
|
|
typedef CheckNonFiniteOp<T, size_t, dt_int32, dt_int32> Op; |
|
|
|
megdnn_assert(m_size > 0); |
|
|
|
WorkspaceBundle bundle( |
|
|
|
nullptr, { |
|
|
|
sizeof(dt_float32*) * m_size, |
|
|
|
sizeof(T*) * m_size, |
|
|
|
sizeof(size_t) * m_size, |
|
|
|
}); |
|
|
|
return get_reduce_workspace_in_bytes<Op>(1, m_size * total_nr_elems_max, 1) + |
|
|
@@ -41,17 +42,38 @@ size_t CheckNonFiniteImpl::get_workspace_in_bytes( |
|
|
|
for (const auto& src : srcs) { |
|
|
|
m_size += DIVUP(src.layout.total_nr_elems(), total_nr_elems_max); |
|
|
|
} |
|
|
|
return _get_workspace_in_bytes(); |
|
|
|
if (srcs.begin()->layout.dtype == dtype::Float32()) { |
|
|
|
return _get_workspace_in_bytes<dt_float32>(); |
|
|
|
} else if (srcs.begin()->layout.dtype == dtype::Float16()) { |
|
|
|
return _get_workspace_in_bytes<dt_float16>(); |
|
|
|
} else { |
|
|
|
megdnn_log_warn("only support fp16 and fp32, fallback to fp32"); |
|
|
|
return _get_workspace_in_bytes<dt_float32>(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void CheckNonFiniteImpl::exec( |
|
|
|
_megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst, |
|
|
|
_megdnn_workspace workspace) { |
|
|
|
if (srcs.begin()->layout.dtype == dtype::Float32()) { |
|
|
|
_exec<dt_float32>(srcs, dst, workspace); |
|
|
|
} |
|
|
|
#ifdef DNN_INC_FLOAT16 |
|
|
|
else if (srcs.begin()->layout.dtype == dtype::Float16()) { |
|
|
|
_exec<dt_float16>(srcs, dst, workspace); |
|
|
|
} |
|
|
|
#endif |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
void CheckNonFiniteImpl::_exec( |
|
|
|
_megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst, |
|
|
|
_megdnn_workspace workspace) { |
|
|
|
check_exec(srcs, dst, workspace.size); |
|
|
|
typedef CheckNonFiniteOp<dt_float32, size_t, dt_int32, dt_int32> Op; |
|
|
|
typedef CheckNonFiniteOp<T, size_t, dt_int32, dt_int32> Op; |
|
|
|
auto stream = cuda_stream(this->handle()); |
|
|
|
SmallVector<size_t> workspace_sizes{ |
|
|
|
sizeof(dt_float32*) * m_size, |
|
|
|
sizeof(T*) * m_size, |
|
|
|
sizeof(size_t) * m_size, |
|
|
|
}; |
|
|
|
WorkspaceBundle workspace_cpu(nullptr, workspace_sizes), |
|
|
@@ -63,8 +85,8 @@ void CheckNonFiniteImpl::exec( |
|
|
|
workspace_cpu = WorkspaceBundle(workspace_cpu_raw, workspace_sizes); |
|
|
|
workspace_gpu = WorkspaceBundle(workspace_gpu_raw, workspace_sizes); |
|
|
|
|
|
|
|
auto srcs_cpu = static_cast<dt_float32**>(workspace_cpu.get(0)); |
|
|
|
auto srcs_gpu = static_cast<dt_float32**>(workspace_gpu.get(0)); |
|
|
|
auto srcs_cpu = static_cast<T**>(workspace_cpu.get(0)); |
|
|
|
auto srcs_gpu = static_cast<T**>(workspace_gpu.get(0)); |
|
|
|
auto srcs_total_nr_elems_cpu = static_cast<size_t*>(workspace_cpu.get(1)); |
|
|
|
auto srcs_total_nr_elems_gpu = static_cast<size_t*>(workspace_gpu.get(1)); |
|
|
|
|
|
|
@@ -75,7 +97,7 @@ void CheckNonFiniteImpl::exec( |
|
|
|
size_t src_nr_elems = src.layout.total_nr_elems(); |
|
|
|
size_t nr_elems = DIVUP(src_nr_elems, total_nr_elems_max); |
|
|
|
for (size_t j = 0; j < nr_elems; ++j, ++i) { |
|
|
|
srcs_cpu[i] = src.ptr<dt_float32>() + j * total_nr_elems_max; |
|
|
|
srcs_cpu[i] = src.ptr<T>() + j * total_nr_elems_max; |
|
|
|
if (j + 1 == nr_elems && src_nr_elems % total_nr_elems_max) { |
|
|
|
srcs_total_nr_elems_cpu[i] = src_nr_elems % total_nr_elems_max; |
|
|
|
} else { |
|
|
@@ -97,7 +119,7 @@ void CheckNonFiniteImpl::exec( |
|
|
|
workspace_gpu.total_size_in_bytes())), |
|
|
|
1, m_size * total_nr_elems_max, 1, stream, |
|
|
|
Op(srcs_gpu, srcs_total_nr_elems_gpu, dst.ptr<dt_int32>(), |
|
|
|
total_nr_elems_max, param().scale)); |
|
|
|
total_nr_elems_max, static_cast<T>(param().scale))); |
|
|
|
} |
|
|
|
|
|
|
|
} // namespace cuda |
|
|
|