
feat(mge/dnn): support checknonfinite for fp16

GitOrigin-RevId: 83fa139ac0
release-1.10
Megvii Engine Team, 3 years ago
parent commit 73112558d0
7 changed files with 49 additions and 19 deletions

  1. +0  -1   dnn/include/megdnn/oprs/general.h
  2. +1  -2   dnn/src/common/check_non_finite.cpp
  3. +1  -1   dnn/src/common/reduce_helper_device.h
  4. +8  -4   dnn/src/cuda/check_non_finite/kern.cu
  5. +31 -9   dnn/src/cuda/check_non_finite/opr_impl.cpp
  6. +7  -1   dnn/src/cuda/check_non_finite/opr_impl.h
  7. +1  -1   dnn/src/naive/check_non_finite/opr_impl.h

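Note: CheckNonFinite reduces a list of tensors to a single int32 flag that is non-zero if any element is NaN or Inf; this commit lets the CUDA backend run that reduction on fp16 inputs as well as fp32. A rough host-side sketch of the semantics only (illustrative, not MegDNN code; the real operator runs as a CUDA reduction and also carries a scale parameter that is not modeled here):

#include <cmath>
#include <cstdint>
#include <type_traits>
#include <vector>

// Illustrative host-side version of the semantics: return 1 if any element
// of any input tensor is NaN or +/-Inf, else 0. The element type is
// templated the same way the commit templates the CUDA implementation.
template <typename T>
std::int32_t check_non_finite(const std::vector<std::vector<T>>& srcs) {
    static_assert(std::is_floating_point<T>::value,
                  "sketch handles built-in floating types only");
    for (const auto& tensor : srcs) {
        for (T v : tensor) {
            if (!std::isfinite(v)) {
                return 1;  // at least one NaN or Inf found
            }
        }
    }
    return 0;  // every element is finite
}
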
dnn/include/megdnn/oprs/general.h  (+0, -1)

@@ -1383,7 +1383,6 @@ public:
 protected:
     void check_exec(
             const TensorNDArray& srcs, const TensorND& dst, size_t workspace_in_bytes);
-    virtual size_t _get_workspace_in_bytes() = 0;
 };

 /*!


dnn/src/common/check_non_finite.cpp  (+1, -2)

@@ -18,8 +18,7 @@ void CheckNonFinite::check_exec(
         const TensorNDArray& srcs, const TensorND& dst, size_t workspace_in_bytes) {
     megdnn_assert_contiguous(dst.layout);
     megdnn_assert(srcs.size() > 0);
-    megdnn_assert(srcs.begin()->layout.dtype == dtype::Float32());
-    auto required_workspace_in_bytes = _get_workspace_in_bytes();
+    auto required_workspace_in_bytes = get_workspace_in_bytes(srcs, dst.layout);
     megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
 }




dnn/src/common/reduce_helper_device.h  (+1, -1)

@@ -236,4 +236,4 @@ void get_ABC(const TensorShape& shape, size_t& A, size_t& B, size_t& C, size_t axis)
 
 } // namespace megdnn
 
-// vim: syntax=cpp.doxygen
+// vim: syntax=cpp.doxygen

dnn/src/cuda/check_non_finite/kern.cu  (+8, -4)

@@ -18,11 +18,15 @@ namespace cuda {
 
 #define COMMA ,
 
-INST_REDUCE(
-        device_reduce::CheckNonFiniteOp<
-                dt_float32 COMMA size_t COMMA dt_int32 COMMA dt_int32>,
-        false);
+#define cb(_dtype)                                                      \
+    INST_REDUCE(                                                        \
+            device_reduce::CheckNonFiniteOp<                            \
+                    _dtype COMMA size_t COMMA dt_int32 COMMA dt_int32>, \
+            false);
+
+cb(dt_float32);
+cb(dt_float16);
+#undef cb
 #undef COMMA
 } // namespace cuda
 } // namespace megdnn
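
The cb(_dtype) helper above is the usual MegDNN way of stamping out one explicit instantiation per dtype; the COMMA macro exists because a bare comma inside the macro argument would otherwise be parsed as an argument separator. A self-contained sketch of the same trick, with hypothetical Op and INST names rather than the real INST_REDUCE signature:

// Hypothetical stand-ins for device_reduce::CheckNonFiniteOp / INST_REDUCE.
template <typename T, typename A, typename B, typename C>
struct Op {};

#define COMMA ,
#define INST(T) template struct T  // explicit instantiation of one Op<...>

// COMMA survives the argument scan of INST() as a plain identifier and only
// expands to ',' in the replacement, so the whole Op<...> stays one argument.
#define cb(_dtype) INST(Op<_dtype COMMA int COMMA int COMMA int>)
cb(float);   // expands to: template struct Op<float, int, int, int>;
cb(double);  // adding a dtype is one line instead of a copied block
#undef cb
#undef INST
#undef COMMA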


dnn/src/cuda/check_non_finite/opr_impl.cpp  (+31, -9)

@@ -22,13 +22,14 @@ namespace cuda {
 
 using device_reduce::CheckNonFiniteOp;
 #define total_nr_elems_max 2048
+template <typename T>
 size_t CheckNonFiniteImpl::_get_workspace_in_bytes() {
     // Call the _get_workspace_in_bytes to reduce the loop fetch workspace bytes
-    typedef CheckNonFiniteOp<dt_float32, size_t, dt_int32, dt_int32> Op;
+    typedef CheckNonFiniteOp<T, size_t, dt_int32, dt_int32> Op;
     megdnn_assert(m_size > 0);
     WorkspaceBundle bundle(
             nullptr, {
-                             sizeof(dt_float32*) * m_size,
+                             sizeof(T*) * m_size,
                              sizeof(size_t) * m_size,
                      });
     return get_reduce_workspace_in_bytes<Op>(1, m_size * total_nr_elems_max, 1) +
@@ -41,17 +42,38 @@ size_t CheckNonFiniteImpl::get_workspace_in_bytes(
     for (const auto& src : srcs) {
         m_size += DIVUP(src.layout.total_nr_elems(), total_nr_elems_max);
     }
-    return _get_workspace_in_bytes();
+    if (srcs.begin()->layout.dtype == dtype::Float32()) {
+        return _get_workspace_in_bytes<dt_float32>();
+    } else if (srcs.begin()->layout.dtype == dtype::Float16()) {
+        return _get_workspace_in_bytes<dt_float16>();
+    } else {
+        megdnn_log_warn("only support fp16 and fp32, fallback to fp32");
+        return _get_workspace_in_bytes<dt_float32>();
+    }
 }
 
 void CheckNonFiniteImpl::exec(
         _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
         _megdnn_workspace workspace) {
+    if (srcs.begin()->layout.dtype == dtype::Float32()) {
+        _exec<dt_float32>(srcs, dst, workspace);
+    }
+#ifdef DNN_INC_FLOAT16
+    else if (srcs.begin()->layout.dtype == dtype::Float16()) {
+        _exec<dt_float16>(srcs, dst, workspace);
+    }
+#endif
+}
+
+template <typename T>
+void CheckNonFiniteImpl::_exec(
+        _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
+        _megdnn_workspace workspace) {
     check_exec(srcs, dst, workspace.size);
-    typedef CheckNonFiniteOp<dt_float32, size_t, dt_int32, dt_int32> Op;
+    typedef CheckNonFiniteOp<T, size_t, dt_int32, dt_int32> Op;
     auto stream = cuda_stream(this->handle());
     SmallVector<size_t> workspace_sizes{
-            sizeof(dt_float32*) * m_size,
+            sizeof(T*) * m_size,
             sizeof(size_t) * m_size,
     };
     WorkspaceBundle workspace_cpu(nullptr, workspace_sizes),
@@ -63,8 +85,8 @@ void CheckNonFiniteImpl::exec(
     workspace_cpu = WorkspaceBundle(workspace_cpu_raw, workspace_sizes);
     workspace_gpu = WorkspaceBundle(workspace_gpu_raw, workspace_sizes);
 
-    auto srcs_cpu = static_cast<dt_float32**>(workspace_cpu.get(0));
-    auto srcs_gpu = static_cast<dt_float32**>(workspace_gpu.get(0));
+    auto srcs_cpu = static_cast<T**>(workspace_cpu.get(0));
+    auto srcs_gpu = static_cast<T**>(workspace_gpu.get(0));
     auto srcs_total_nr_elems_cpu = static_cast<size_t*>(workspace_cpu.get(1));
     auto srcs_total_nr_elems_gpu = static_cast<size_t*>(workspace_gpu.get(1));
 
@@ -75,7 +97,7 @@ void CheckNonFiniteImpl::exec(
         size_t src_nr_elems = src.layout.total_nr_elems();
         size_t nr_elems = DIVUP(src_nr_elems, total_nr_elems_max);
         for (size_t j = 0; j < nr_elems; ++j, ++i) {
-            srcs_cpu[i] = src.ptr<dt_float32>() + j * total_nr_elems_max;
+            srcs_cpu[i] = src.ptr<T>() + j * total_nr_elems_max;
             if (j + 1 == nr_elems && src_nr_elems % total_nr_elems_max) {
                 srcs_total_nr_elems_cpu[i] = src_nr_elems % total_nr_elems_max;
             } else {
@@ -97,7 +119,7 @@ void CheckNonFiniteImpl::exec(
                     workspace_gpu.total_size_in_bytes())),
             1, m_size * total_nr_elems_max, 1, stream,
             Op(srcs_gpu, srcs_total_nr_elems_gpu, dst.ptr<dt_int32>(),
-               total_nr_elems_max, param().scale));
+               total_nr_elems_max, static_cast<T>(param().scale)));
 }
 
 } // namespace cuda
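
The new exec() / _exec<T>() split is a standard runtime-to-compile-time dispatch: the dtype is inspected once, then everything downstream (pointer casts, workspace sizing, the reduce Op) is typed on T. A condensed sketch of the pattern with illustrative names, not the MegDNN classes (and note the fp16 branch in the real code is additionally guarded by DNN_INC_FLOAT16):

#include <cstdint>
#include <cstdio>

enum class DType { Float32, Float16 };   // stand-in for megdnn dtypes
struct half16 { std::uint16_t bits; };   // stand-in for dt_float16

struct CheckerSketch {
    // Public entry point: branch on the runtime dtype exactly once...
    void exec(DType dtype) {
        if (dtype == DType::Float32) {
            exec_impl<float>();
        } else if (dtype == DType::Float16) {
            exec_impl<half16>();
        }
    }

private:
    // ...so the worker can use T for sizeof(T*), pointer casts and the
    // reduce-op typedef without any further dtype checks.
    template <typename T>
    void exec_impl() {
        std::printf("per-element bytes: %zu\n", sizeof(T));
    }
};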


dnn/src/cuda/check_non_finite/opr_impl.h  (+7, -1)

@@ -18,7 +18,13 @@ namespace megdnn {
 namespace cuda {
 
 class CheckNonFiniteImpl final : public CheckNonFinite {
-    size_t _get_workspace_in_bytes() override;
+    template <typename T>
+    size_t _get_workspace_in_bytes();
+
+    template <typename T>
+    void _exec(
+            _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
+            _megdnn_workspace workspace);
 
 public:
     using CheckNonFinite::CheckNonFinite;
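
Replacing the override with private function templates is forced by the language rather than taste: a member function template cannot be virtual, so once _get_workspace_in_bytes is parameterized on the element type it can no longer override a base-class virtual (which is why the pure-virtual was dropped from general.h above). A minimal illustration:

#include <cstddef>

struct Base {
    virtual ~Base() = default;
    // Ill-formed in C++: member function templates cannot be virtual.
    // template <typename T>
    // virtual std::size_t workspace() = 0;
};

struct Derived : Base {
    // Works: non-virtual template helper; callers pick T from the runtime
    // dtype, as the CUDA CheckNonFiniteImpl does in exec().
    template <typename T>
    std::size_t workspace() const { return sizeof(T); }
};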


dnn/src/naive/check_non_finite/opr_impl.h  (+1, -1)

@@ -17,7 +17,7 @@ namespace megdnn {
 namespace naive {
 
 class CheckNonFiniteImpl final : public CheckNonFinite {
-    size_t _get_workspace_in_bytes() override { return 0; }
+    size_t _get_workspace_in_bytes() { return 0; }
 
 public:
     using CheckNonFinite::CheckNonFinite;

