
feat(dnn/cuda): add typecvt uint16

GitOrigin-RevId: d1368c414e
release-1.10
Megvii Engine Team, 3 years ago
parent commit 7dc347697a
6 changed files with 69 additions and 7 deletions
  1. dnn/src/cuda/elemwise_helper.cpp   +6  -0
  2. dnn/src/cuda/elemwise_helper.cuh   +1  -0
  3. dnn/src/cuda/type_cvt/kern.cu      +54 -2
  4. dnn/src/cuda/type_cvt/opr_impl.cpp +5  -4
  5. dnn/test/cuda/type_cvt.cpp         +2  -1
  6. src/core/impl/dtype.cpp            +1  -0

dnn/src/cuda/elemwise_helper.cpp  +6 -0

@@ -148,6 +148,9 @@ INST_FOR_CTYPE
 #define ct dt_int16
 INST_FOR_CTYPE
 #undef ct
+#define ct dt_uint16
+INST_FOR_CTYPE
+#undef ct
 #define ct dt_quint8
 INST_FOR_CTYPE
 #undef ct
@@ -201,6 +204,9 @@ INST_FOR_CTYPE
 #define ct dt_int16
 INST_FOR_CTYPE
 #undef ct
+#define ct dt_uint16
+INST_FOR_CTYPE
+#undef ct
 #define ct dt_quint8
 INST_FOR_CTYPE
 #undef ct


dnn/src/cuda/elemwise_helper.cuh  +1 -0

@@ -92,6 +92,7 @@ INST(dt_float16, half4);
 INST(dt_bfloat16, bhalf4);
 INST(dt_int32, int4);
 INST(dt_int16, short4);
+INST(dt_uint16, ushort4);
 INST(dt_bool, uchar4);
 #undef as_raw
 #define as_raw(x) x.as_int8()


dnn/src/cuda/type_cvt/kern.cu  +54 -2

@@ -247,6 +247,19 @@ struct TypeCvtOpFromQuantizedToQuantized4bit<
 namespace megdnn {
 namespace cuda {
 
+// currently only typecvt_kern_{n2q,n2q4} respect this. change others typecvt_kern_* if
+// needed.
+template <typename dtype_src, typename dtype_dest, typename sfinae = void>
+struct enable_typecvt_kern {
+static constexpr bool value = true;
+};
+
+#define MEGDNN_DISABLE_CUDA_TYPECVT_KERN(dtype_src, dtype_dest) \
+template <> \
+struct enable_typecvt_kern<dtype_src, dtype_dest, void> { \
+static constexpr bool value = false; \
+};
+
 template <typename dtype_src, typename dtype_dest>
 void typecvt_kern_q2q(
 const TensorND& dest, const TensorND& src,
@@ -257,13 +270,29 @@ void typecvt_kern_q2q(
 }
 
 template <typename dtype_src, typename dtype_dest>
-void typecvt_kern_n2q(
+typename std::enable_if<enable_typecvt_kern<dtype_src, dtype_dest>::value>::type
+typecvt_kern_n2q_impl(
 const TensorND& dest, const TensorND& src,
 const CudaDTypeParam<dtype_dest>& dst_param, cudaStream_t stream) {
 main_func(TypeCvtOpToQuantized, op.param = dst_param;);
 }
 
+template <typename dtype_src, typename dtype_dest>
+typename std::enable_if<!enable_typecvt_kern<dtype_src, dtype_dest>::value>::type
+typecvt_kern_n2q_impl(
+const TensorND& dest, const TensorND& src,
+const CudaDTypeParam<dtype_dest>& dst_param, cudaStream_t stream) {
+megdnn_throw("TypeCvt: CUDA kernel for this dtype pair is disabled");
+}
+
+template <typename dtype_src, typename dtype_dest>
+void typecvt_kern_n2q(
+const TensorND& dest, const TensorND& src,
+const CudaDTypeParam<dtype_dest>& dst_param, cudaStream_t stream) {
+typecvt_kern_n2q_impl<dtype_src, dtype_dest>(dest, src, dst_param, stream);
+}
+
 template <typename dtype_src, typename dtype_dest>
 void typecvt_kern_q2n(
 const TensorND& dest, const TensorND& src,
 const CudaDTypeParam<dtype_src>& src_param, cudaStream_t stream) {
@@ -312,12 +341,15 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cudaStream_t st
 cb(dtype_src, dt_qint8) \
 cb(dtype_src, dt_qint1) \
 
+MEGDNN_FOREACH_QUANTIZED_DTYPE_WITH_DTYPE_SRC(dt_uint16, MEGDNN_DISABLE_CUDA_TYPECVT_KERN)
+
 #define INST_SRC_QUANTIZED(dtype_src) \
 MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_Q2N) \
 MEGDNN_FOREACH_QUANTIZED_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_Q2Q) \
 
 #define INST_SRC_NORMAL(dtype_src) \
 MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_N2N) \
+INST_N2N(dtype_src, dt_uint16) \
 MEGDNN_FOREACH_QUANTIZED_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_N2Q) \
 
 #define MEGDNN_FOREACH_COMPUTING_CTYPE(cb) \
@@ -340,6 +372,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cudaStream_t st

 MEGDNN_FOREACH_QUANTIZED_CTYPE(INST_SRC_QUANTIZED)
 MEGDNN_FOREACH_COMPUTING_CTYPE(INST_SRC_NORMAL)
+INST_SRC_NORMAL(dt_uint16)
 // clang-format on
 
 template void typecvt_kern_n2q<dtype::Int8, dtype::QuantizedS8>(
@@ -377,12 +410,28 @@ void typecvt_kern_q2q4(
 }
 
 template <typename dtype_src, typename dtype_dest>
-void typecvt_kern_n2q4(
+typename std::enable_if<enable_typecvt_kern<dtype_src, dtype_dest>::value>::type
+typecvt_kern_n2q4_impl(
 const TensorND& dest, const TensorND& src,
 const CudaDTypeParam<dtype_dest>& dst_param, cudaStream_t stream) {
 main_func_to_q4(TypeCvtOpFromNormalToQuantized4bit, op.dst_param = dst_param;)
 }
 
+template <typename dtype_src, typename dtype_dest>
+typename std::enable_if<!enable_typecvt_kern<dtype_src, dtype_dest>::value>::type
+typecvt_kern_n2q4_impl(
+const TensorND& dest, const TensorND& src,
+const CudaDTypeParam<dtype_dest>& dst_param, cudaStream_t stream) {
+megdnn_throw("TypeCvt: CUDA kernel for this dtype pair is disabled");
+}
+
+template <typename dtype_src, typename dtype_dest>
+void typecvt_kern_n2q4(
+const TensorND& dest, const TensorND& src,
+const CudaDTypeParam<dtype_dest>& dst_param, cudaStream_t stream) {
+typecvt_kern_n2q4_impl<dtype_src, dtype_dest>(dest, src, dst_param, stream);
+}
+
 #define INST_Q2Q4(dtype_src, dtype_dest) \
 template void typecvt_kern_q2q4<dtype_src, dtype_dest>( \
 const TensorND& dest, const TensorND& src, \
@@ -399,6 +448,8 @@ void typecvt_kern_n2q4(
 cb(dtype_src, dt_qint4) \
 cb(dtype_src, dt_quint4) \
 
+MEGDNN_FOREACH_QUANTIZED_LOWBIT_WITH_DTYPE_SRC(dt_uint16, MEGDNN_DISABLE_CUDA_TYPECVT_KERN)
+
 #define INST_SRC_QUANTIZED_LOWBIT(dtype_src) \
 MEGDNN_FOREACH_QUANTIZED_LOWBIT_WITH_DTYPE_SRC(dtype_src, INST_Q2Q4) \

@@ -407,6 +458,7 @@ void typecvt_kern_n2q4(

 MEGDNN_FOREACH_QUANTIZED_CTYPE(INST_SRC_QUANTIZED_LOWBIT)
 MEGDNN_FOREACH_COMPUTING_CTYPE(INST_SRC_NORMAL_LOWBIT)
+INST_SRC_NORMAL_LOWBIT(dt_uint16)
 
 } // namespace cuda
 } // namespace megdnn
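
Note: the enable_typecvt_kern trait introduced above drives a small tag-dispatch pattern: each dtype pair defaults to "enabled", explicit specializations turn a pair off, and the wrapper forwards to whichever SFINAE overload is viable. The snippet below is a minimal, self-contained sketch of that pattern with hypothetical names (enable_kern, kern_impl); it is an illustration, not code from this commit.

    #include <cstdio>
    #include <type_traits>

    // Trait defaults to "enabled"; a specialization can switch a pair off,
    // in the spirit of MEGDNN_DISABLE_CUDA_TYPECVT_KERN above.
    template <typename Src, typename Dst, typename = void>
    struct enable_kern {
        static constexpr bool value = true;
    };

    template <>
    struct enable_kern<unsigned short, signed char, void> {
        static constexpr bool value = false;
    };

    // Enabled pairs get the real implementation ...
    template <typename Src, typename Dst>
    typename std::enable_if<enable_kern<Src, Dst>::value>::type kern_impl() {
        std::printf("launch kernel\n");
    }

    // ... disabled pairs get a stub that fails at run time instead of
    // instantiating a kernel that does not exist.
    template <typename Src, typename Dst>
    typename std::enable_if<!enable_kern<Src, Dst>::value>::type kern_impl() {
        std::printf("kernel disabled for this dtype pair\n");
    }

    int main() {
        kern_impl<float, signed char>();           // enabled path
        kern_impl<unsigned short, signed char>();  // disabled path
    }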


dnn/src/cuda/type_cvt/opr_impl.cpp  +5 -4

@@ -12,6 +12,8 @@
#include "./opr_impl.h"
#include "./kern.cuh"

#include "megdnn/dtype.h"
#include "src/common/utils.cuh"
#include "src/cuda/utils.cuh"
#include "src/cuda/utils.h"

@@ -87,10 +89,9 @@ void exec_src_normal(const TensorND& dst, const TensorND& src, cudaStream_t stre
 return; \
 }
 MEGDNN_FOREACH_COMPUTING_DTYPE(cb);
-cb(::megdnn::dtype::Bool);
+cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16)
 #undef cb
-default:
-megdnn_assert_internal(0);
+default : megdnn_assert_internal(0);
 }
 } else if (!is_dst_lowbit) {
 switch (dst.layout.dtype.enumv()) {
@@ -138,7 +139,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
 return; \
 }
 MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
-cb(::megdnn::dtype::Bool)
+cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16)
 #undef cb
 default : megdnn_assert_internal(0);
 }


dnn/test/cuda/type_cvt.cpp  +2 -1

@@ -19,7 +19,8 @@ using namespace test;
 TEST_F(CUDA, TYPE_CVT) {
 UniformFloatRNG init(0, 20);
 std::vector<DType> dtypes = {dtype::Float32(), dtype::Float16(), dtype::Int32(),
-dtype::Int16(), dtype::Int8(), dtype::Uint8()};
+dtype::Int16(), dtype::Int8(), dtype::Uint8(),
+dtype::Uint16()};
 for (auto sdtype : dtypes)
 for (auto ddtype : dtypes) {
 TensorLayout src({10, 10}, sdtype), dst({10, 10}, ddtype);
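
Note: the test above only adds Uint16 to the matrix of non-quantized dtypes; conversions from uint16 into quantized dtypes are excluded by the enable_typecvt_kern trait in kern.cu. A hypothetical compile-time check of that trait is sketched below; it assumes the trait and its specializations were exposed in a shared header, which this commit does not do (they live in the .cu source file).

    // Hypothetical sketch, not part of this commit or the test suite.
    static_assert(
            megdnn::cuda::enable_typecvt_kern<megdnn::dt_float32, megdnn::dt_qint8>::value,
            "float32 -> qint8 CUDA typecvt stays enabled");
    static_assert(
            !megdnn::cuda::enable_typecvt_kern<megdnn::dt_uint16, megdnn::dt_qint8>::value,
            "uint16 -> qint8 CUDA typecvt is disabled by this commit");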


src/core/impl/dtype.cpp  +1 -0

@@ -210,6 +210,7 @@ typename ctype_enable_if<ctype>::type DTypeScalar::set_retain_dtype(ctype val) {
 MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
 MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
 cb(dt_bool);
+cb(dt_uint16);
 #undef cb
 default:
 mgb_throw(ConversionError, "can not assign to dtype %s", m_dtype.name());

