|
|
@@ -247,6 +247,19 @@ struct TypeCvtOpFromQuantizedToQuantized4bit< |
|
|
|
namespace megdnn { |
|
|
|
namespace cuda { |
|
|
|
|
|
|
|
// currently only typecvt_kern_{n2q,n2q4} respect this. change others typecvt_kern_* if |
|
|
|
// needed. |
|
|
|
template <typename dtype_src, typename dtype_dest, typename sfinae = void> |
|
|
|
struct enable_typecvt_kern { |
|
|
|
static constexpr bool value = true; |
|
|
|
}; |
|
|
|
|
|
|
|
#define MEGDNN_DISABLE_CUDA_TYPECVT_KERN(dtype_src, dtype_dest) \ |
|
|
|
template <> \ |
|
|
|
struct enable_typecvt_kern<dtype_src, dtype_dest, void> { \ |
|
|
|
static constexpr bool value = false; \ |
|
|
|
}; |
|
|
|
|
|
|
|
template <typename dtype_src, typename dtype_dest> |
|
|
|
void typecvt_kern_q2q( |
|
|
|
const TensorND& dest, const TensorND& src, |
|
|
@@ -257,13 +270,29 @@ void typecvt_kern_q2q( |
|
|
|
} |
|
|
|
|
|
|
|
template <typename dtype_src, typename dtype_dest> |
|
|
|
void typecvt_kern_n2q( |
|
|
|
typename std::enable_if<enable_typecvt_kern<dtype_src, dtype_dest>::value>::type |
|
|
|
typecvt_kern_n2q_impl( |
|
|
|
const TensorND& dest, const TensorND& src, |
|
|
|
const CudaDTypeParam<dtype_dest>& dst_param, cudaStream_t stream) { |
|
|
|
main_func(TypeCvtOpToQuantized, op.param = dst_param;); |
|
|
|
} |
|
|
|
|
|
|
|
template <typename dtype_src, typename dtype_dest> |
|
|
|
typename std::enable_if<!enable_typecvt_kern<dtype_src, dtype_dest>::value>::type |
|
|
|
typecvt_kern_n2q_impl( |
|
|
|
const TensorND& dest, const TensorND& src, |
|
|
|
const CudaDTypeParam<dtype_dest>& dst_param, cudaStream_t stream) { |
|
|
|
megdnn_throw("TypeCvt: CUDA kernel for this dtype pair is disabled"); |
|
|
|
} |
|
|
|
|
|
|
|
template <typename dtype_src, typename dtype_dest> |
|
|
|
void typecvt_kern_n2q( |
|
|
|
const TensorND& dest, const TensorND& src, |
|
|
|
const CudaDTypeParam<dtype_dest>& dst_param, cudaStream_t stream) { |
|
|
|
typecvt_kern_n2q_impl<dtype_src, dtype_dest>(dest, src, dst_param, stream); |
|
|
|
} |
|
|
|
|
|
|
|
template <typename dtype_src, typename dtype_dest> |
|
|
|
void typecvt_kern_q2n( |
|
|
|
const TensorND& dest, const TensorND& src, |
|
|
|
const CudaDTypeParam<dtype_src>& src_param, cudaStream_t stream) { |
|
|
@@ -312,12 +341,15 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cudaStream_t st |
|
|
|
cb(dtype_src, dt_qint8) \ |
|
|
|
cb(dtype_src, dt_qint1) \ |
|
|
|
|
|
|
|
MEGDNN_FOREACH_QUANTIZED_DTYPE_WITH_DTYPE_SRC(dt_uint16, MEGDNN_DISABLE_CUDA_TYPECVT_KERN) |
|
|
|
|
|
|
|
#define INST_SRC_QUANTIZED(dtype_src) \ |
|
|
|
MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_Q2N) \ |
|
|
|
MEGDNN_FOREACH_QUANTIZED_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_Q2Q) \ |
|
|
|
|
|
|
|
#define INST_SRC_NORMAL(dtype_src) \ |
|
|
|
MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_N2N) \ |
|
|
|
INST_N2N(dtype_src, dt_uint16) \ |
|
|
|
MEGDNN_FOREACH_QUANTIZED_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_N2Q) \ |
|
|
|
|
|
|
|
#define MEGDNN_FOREACH_COMPUTING_CTYPE(cb) \ |
|
|
@@ -340,6 +372,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cudaStream_t st |
|
|
|
|
|
|
|
MEGDNN_FOREACH_QUANTIZED_CTYPE(INST_SRC_QUANTIZED) |
|
|
|
MEGDNN_FOREACH_COMPUTING_CTYPE(INST_SRC_NORMAL) |
|
|
|
INST_SRC_NORMAL(dt_uint16) |
|
|
|
// clang-format on |
|
|
|
|
|
|
|
template void typecvt_kern_n2q<dtype::Int8, dtype::QuantizedS8>( |
|
|
@@ -377,12 +410,28 @@ void typecvt_kern_q2q4( |
|
|
|
} |
|
|
|
|
|
|
|
template <typename dtype_src, typename dtype_dest> |
|
|
|
void typecvt_kern_n2q4( |
|
|
|
typename std::enable_if<enable_typecvt_kern<dtype_src, dtype_dest>::value>::type |
|
|
|
typecvt_kern_n2q4_impl( |
|
|
|
const TensorND& dest, const TensorND& src, |
|
|
|
const CudaDTypeParam<dtype_dest>& dst_param, cudaStream_t stream) { |
|
|
|
main_func_to_q4(TypeCvtOpFromNormalToQuantized4bit, op.dst_param = dst_param;) |
|
|
|
} |
|
|
|
|
|
|
|
template <typename dtype_src, typename dtype_dest> |
|
|
|
typename std::enable_if<!enable_typecvt_kern<dtype_src, dtype_dest>::value>::type |
|
|
|
typecvt_kern_n2q4_impl( |
|
|
|
const TensorND& dest, const TensorND& src, |
|
|
|
const CudaDTypeParam<dtype_dest>& dst_param, cudaStream_t stream) { |
|
|
|
megdnn_throw("TypeCvt: CUDA kernel for this dtype pair is disabled"); |
|
|
|
} |
|
|
|
|
|
|
|
template <typename dtype_src, typename dtype_dest> |
|
|
|
void typecvt_kern_n2q4( |
|
|
|
const TensorND& dest, const TensorND& src, |
|
|
|
const CudaDTypeParam<dtype_dest>& dst_param, cudaStream_t stream) { |
|
|
|
typecvt_kern_n2q4_impl<dtype_src, dtype_dest>(dest, src, dst_param, stream); |
|
|
|
} |
|
|
|
|
|
|
|
#define INST_Q2Q4(dtype_src, dtype_dest) \ |
|
|
|
template void typecvt_kern_q2q4<dtype_src, dtype_dest>( \ |
|
|
|
const TensorND& dest, const TensorND& src, \ |
|
|
@@ -399,6 +448,8 @@ void typecvt_kern_n2q4( |
|
|
|
cb(dtype_src, dt_qint4) \ |
|
|
|
cb(dtype_src, dt_quint4) \ |
|
|
|
|
|
|
|
MEGDNN_FOREACH_QUANTIZED_LOWBIT_WITH_DTYPE_SRC(dt_uint16, MEGDNN_DISABLE_CUDA_TYPECVT_KERN) |
|
|
|
|
|
|
|
#define INST_SRC_QUANTIZED_LOWBIT(dtype_src) \ |
|
|
|
MEGDNN_FOREACH_QUANTIZED_LOWBIT_WITH_DTYPE_SRC(dtype_src, INST_Q2Q4) \ |
|
|
|
|
|
|
@@ -407,6 +458,7 @@ void typecvt_kern_n2q4( |
|
|
|
|
|
|
|
MEGDNN_FOREACH_QUANTIZED_CTYPE(INST_SRC_QUANTIZED_LOWBIT) |
|
|
|
MEGDNN_FOREACH_COMPUTING_CTYPE(INST_SRC_NORMAL_LOWBIT) |
|
|
|
INST_SRC_NORMAL_LOWBIT(dt_uint16) |
|
|
|
|
|
|
|
} // namespace cuda |
|
|
|
} // namespace megdnn |
|
|
|