From 894a2407c2a90b0314a845bda8545f18d26e7e3d Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 4 Jun 2021 14:43:59 +0800 Subject: [PATCH] feat(dnn/cuda): add relayout format kernel for nchw <-> nhwc GitOrigin-RevId: e11f3e54085929ab9919fe2070fcc9e633755b69 --- dnn/scripts/opr_param_defs.py | 4 +- dnn/src/common/relayout_format.cpp | 28 +++ dnn/src/cuda/pooling/pooling2d_qint.cu | 4 +- dnn/src/cuda/relayout_format/opr_impl.cpp | 8 +- dnn/src/cuda/relayout_format/relayout_format.cpp | 16 ++ .../cuda/relayout_format/relayout_format_kern.cuh | 236 ++++++++++++++++++++- .../relayout_format/relayout_format_nchw_nhwc.cu | 211 ++++++++++++++++++ dnn/src/cuda/relayout_format/translayout.cuh | 75 ++++--- dnn/src/cuda/warp_perspective/forward.cu | 30 +-- dnn/test/cuda/relayout_format.cpp | 111 ++++++++++ 10 files changed, 660 insertions(+), 63 deletions(-) create mode 100644 dnn/src/cuda/relayout_format/relayout_format_nchw_nhwc.cu diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index 3a3fa37b..26738684 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -1001,7 +1001,9 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o 'NCHW_NCHW4_WEIGHT', 'NCHW_NCHW64', 'NCHW64_NCHW', - ) + 'NCHW_NHWC', + 'NHWC_NCHW', + ) ) (pdef('RelayoutFormat', 'Change the tensor layout format', version=1). diff --git a/dnn/src/common/relayout_format.cpp b/dnn/src/common/relayout_format.cpp index 01eb4595..d3bf7115 100644 --- a/dnn/src/common/relayout_format.cpp +++ b/dnn/src/common/relayout_format.cpp @@ -268,6 +268,22 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src, dst[2] = src[2]; dst[3] = src[3]; break; + case Param::Mode::NCHW_NHWC: + megdnn_assert(src.ndim == 4); + dst.ndim = 4; + dst[0] = src[0]; + dst[1] = src[2]; + dst[2] = src[3]; + dst[3] = src[1]; + break; + case Param::Mode::NHWC_NCHW: + megdnn_assert(src.ndim == 4); + dst.ndim = 4; + dst[0] = src[0]; + dst[1] = src[3]; + dst[2] = src[1]; + dst[3] = src[2]; + break; default: megdnn_assert(0, "Invalid RelayoutFormat Mode"); break; @@ -375,6 +391,10 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { case Param::Mode::NCHW64_NCHW: dst = src; break; + case Param::Mode::NCHW_NHWC: + case Param::Mode::NHWC_NCHW: + dst = src; + break; default: megdnn_throw("Invalid relayout format mode"); break; @@ -666,6 +686,14 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, exec_src = src.dimshuffle({0, 1, 4, 2, 3}); exec_dst = dst; break; + case Param::Mode::NCHW_NHWC: + exec_src = src.dimshuffle({0, 2, 3, 1}); + exec_dst = dst; + break; + case Param::Mode::NHWC_NCHW: + exec_src = src.dimshuffle({0, 3, 1, 2}); + exec_dst = dst; + break; default: megdnn_assert(0, "Invalid RelayoutFormat Mode"); } diff --git a/dnn/src/cuda/pooling/pooling2d_qint.cu b/dnn/src/cuda/pooling/pooling2d_qint.cu index ddd970ea..e0a7e6ca 100644 --- a/dnn/src/cuda/pooling/pooling2d_qint.cu +++ b/dnn/src/cuda/pooling/pooling2d_qint.cu @@ -505,7 +505,7 @@ void megdnn::cuda::pooling2d::do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src, void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv4hw4( const int8_t* d_src, int8_t* d_dst, const Param& param, - cudaStream_t stream, uint32_t mode, bool uint_case, int zero_point) { + cudaStream_t stream, uint32_t mode, bool /* uint_case */, int zero_point) { using Mode = megdnn::param_enumv::Pooling::Mode; void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param, int zero_point); @@ -545,7 +545,7 @@ 
void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv4hw4( void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv32hw32( const int8_t* d_src, int8_t* d_dst, const Param& param, - cudaStream_t stream, uint32_t mode, bool uint_case, int zero_point) { + cudaStream_t stream, uint32_t mode, bool /* uint_case */, int zero_point) { using Mode = megdnn::param_enumv::Pooling::Mode; void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param, int zero_point); diff --git a/dnn/src/cuda/relayout_format/opr_impl.cpp b/dnn/src/cuda/relayout_format/opr_impl.cpp index 5860ec43..70d8c58f 100644 --- a/dnn/src/cuda/relayout_format/opr_impl.cpp +++ b/dnn/src/cuda/relayout_format/opr_impl.cpp @@ -33,7 +33,9 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, Param::Mode:: NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT || param().mode == Param::Mode::NCHW_NCHW64 || - param().mode == Param::Mode::NCHW64_NCHW, + param().mode == Param::Mode::NCHW64_NCHW || + param().mode == Param::Mode::NCHW_NHWC || + param().mode == Param::Mode::NHWC_NCHW, "relayout format of cuda only support NCHW4->CHWN4 or " "CHWN4->NCHW4 or NCHW->NCHW4"); if ((param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 || @@ -82,7 +84,9 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, {src.raw_ptr, exec_src_layout}, {dst.raw_ptr, exec_dst_layout}); } bool is_trans_4bits = (param().mode == Param::Mode::NCHW_NCHW64 || - param().mode == Param::Mode::NCHW64_NCHW) && + param().mode == Param::Mode::NCHW64_NCHW || + param().mode == Param::Mode::NCHW_NHWC || + param().mode == Param::Mode::NHWC_NCHW) && (src_dtype.enumv() == DTypeEnum::QuantizedS4 || src_dtype.enumv() == DTypeEnum::Quantized4Asymm); bool is_nchw_nchw4 = param().mode == Param::Mode::NCHW_NCHW4 || diff --git a/dnn/src/cuda/relayout_format/relayout_format.cpp b/dnn/src/cuda/relayout_format/relayout_format.cpp index 8ac6656a..539901a9 100644 --- a/dnn/src/cuda/relayout_format/relayout_format.cpp +++ b/dnn/src/cuda/relayout_format/relayout_format.cpp @@ -66,6 +66,22 @@ void relayout_format::RelayoutFormatFast::exec(const TensorND& src, return relayout_format_cuda_nchwx_nchw(src, dst, stream, src_scale, dst_scale, src_zero_point, dst_zero_point); + } else if (mode == RelayoutFormat::Param::Mode::NCHW_NHWC) { +#define CHECK(dt) \ + megdnn_assert(dt.enumv() == DTypeEnum::Quantized4Asymm || \ + dt.enumv() == DTypeEnum::QuantizedS4) + CHECK(src.layout.dtype); + CHECK(dst.layout.dtype); + return relayout_format_cuda_nchw_nhwc(src, dst, stream, src_scale, + dst_scale, src_zero_point, + dst_zero_point); + } else if (mode == RelayoutFormat::Param::Mode::NHWC_NCHW) { + CHECK(src.layout.dtype); + CHECK(dst.layout.dtype); + return relayout_format_cuda_nhwc_nchw(src, dst, stream, src_scale, + dst_scale, src_zero_point, + dst_zero_point); +#undef CHECK } else if (mode == RelayoutFormat::Param::Mode::NCHW_NCHW4_WEIGHT) { return relayout_format_cuda_nchw_nchw4_weight(src, dst, stream); } else if (mode == RelayoutFormat::Param::Mode::NCHW4_NCHW) { diff --git a/dnn/src/cuda/relayout_format/relayout_format_kern.cuh b/dnn/src/cuda/relayout_format/relayout_format_kern.cuh index b2ce2eb6..6fc4b1f2 100644 --- a/dnn/src/cuda/relayout_format/relayout_format_kern.cuh +++ b/dnn/src/cuda/relayout_format/relayout_format_kern.cuh @@ -20,8 +20,17 @@ namespace relayout_format { namespace internal { using namespace memory; +struct LayoutType { + static constexpr uint32_t NCHWx = 0; + static constexpr uint32_t NHWC = 1; +}; + template + int size_nbits_, uint32_t 
layout_type_ = LayoutType::NCHWx> +class TensorIteratorOverChannel; + +template class TensorIteratorOverChannel { public: using Type = Type_; @@ -116,6 +125,98 @@ private: template +class TensorIteratorOverChannel { +public: + using Type = Type_; + static constexpr int pack_size = pack_size_; + static constexpr int chan_blk = chan_blk_; + static constexpr int width = width_; + static constexpr int size_nbits = size_nbits_; + static constexpr int elements_in_type = + chan_blk * width * size_nbits / (8 * sizeof(Type)); + static constexpr int pack_size_in_type = + pack_size * size_nbits / (8 * sizeof(Type)); + static constexpr int pack_size_in_byte = pack_size_in_type * sizeof(Type); + using AccessType = array_wrapper; + using Fragment = array_wrapper; + + MEGDNN_HOST TensorIteratorOverChannel() + : pointer{nullptr}, hw_stride_in_elements{0}, channel{0} {} + MEGDNN_HOST TensorIteratorOverChannel(Type* pointer_, + int hw_stride_in_elements_, + int channel_, int, int) + : pointer{pointer_}, + hw_stride_in_elements{hw_stride_in_elements_}, + channel{channel_} {} + + MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) { + pointer += c_idx * size_nbits / (8 * sizeof(Type)) + + hw_idx * hw_stride_in_elements; + channel -= c_idx; + } + + MEGDNN_DEVICE __forceinline__ void add_pointer_offset( + size_t offset_in_type) { + pointer += offset_in_type; + } + + MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) { + AccessType* frag_ptr = reinterpret_cast(&frag); + Type* pointer_ = pointer; +#pragma unroll + for (int i = 0; i < width; ++i) { +#pragma unroll + for (int j = 0; j < chan_blk; j += pack_size) { + int frag_idx = i * (chan_blk / pack_size) + (j / pack_size); + bool guard = j < channel; + global_load( + frag_ptr[frag_idx], + reinterpret_cast( + pointer_ + j * size_nbits / (8 * sizeof(Type))), + guard, zero_point); + } + pointer_ += hw_stride_in_elements; + } + } + + MEGDNN_DEVICE __forceinline__ void store(const Fragment& frag) { + const AccessType* frag_ptr = reinterpret_cast(&frag); + Type* pointer_ = pointer; +#pragma unroll + for (int i = 0; i < width; ++i) { +#pragma unroll + for (int j = 0; j < chan_blk; j += pack_size) { + int frag_idx = i * (chan_blk / pack_size) + (j / pack_size); + bool guard = j < channel; + global_store( + frag_ptr[frag_idx], + reinterpret_cast( + pointer_ + j * size_nbits / (8 * sizeof(Type))), + guard); + } + pointer_ += hw_stride_in_elements; + } + } + + MEGDNN_DEVICE __forceinline__ void advance() { + pointer += chan_blk * size_nbits / (8 * sizeof(Type)); + channel -= chan_blk; + } + +private: + Type* pointer; + int hw_stride_in_elements; + int channel; +}; + + +template +class MaskedTensorIteratorOverChannel; + +template class MaskedTensorIteratorOverChannel { public: using Type = Type_; @@ -243,24 +344,143 @@ private: size_t stride[lane_size_in_type / pack_size_in_type]; }; +template +class MaskedTensorIteratorOverChannel { +public: + using Type = Type_; + static constexpr int pack_size = pack_size_; + static constexpr int chan_blk = chan_blk_; + static constexpr int width = width_; + static constexpr int size_nbits = size_nbits_; + static constexpr int elements_in_type = + chan_blk * width * size_nbits / (8 * sizeof(Type)); + static constexpr int lane_size_in_type = + (width * pack_size * size_nbits) / (8 * sizeof(Type)); + static constexpr int pack_size_in_type = + pack_size * size_nbits / (8 * sizeof(Type)); + static constexpr int pack_size_in_byte = pack_size_in_type * sizeof(Type); + static constexpr int accesses = 
elements_in_type / pack_size_in_type; + static constexpr int mask_size = (accesses + 32 - 1) / 32; + using AccessType = array_wrapper; + using Fragment = array_wrapper; + + MEGDNN_HOST MaskedTensorIteratorOverChannel() + : pointer{nullptr}, hw_stride_in_elements{0}, channel{0} {} + MEGDNN_HOST MaskedTensorIteratorOverChannel(Type* pointer_, + int hw_stride_in_elements_, + int channel_, int bound_, + int div_) + : pointer{pointer_}, + hw_stride_in_elements{hw_stride_in_elements_}, + channel{channel_}, + bound{bound_}, + div{uint32_t(div_)} {} + + MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) { + pointer += c_idx * size_nbits / (8 * sizeof(Type)); + channel -= c_idx; +#pragma unroll + for (int i = 0; i < mask_size; ++i) { + mask[i] = 0; + } +#pragma unroll + for (int i = 0; i < width; ++i) { + int offset = hw_idx + i; + int h = (int)((uint32_t)(offset) / div); + int w = (int)((uint32_t)(offset) % div); + stride[i] = (h * bound + w) * hw_stride_in_elements; +#pragma unroll + for (int j = 0; j < chan_blk; j += pack_size) { + bool guard = (j < channel) && (w < bound); + int index = i * (chan_blk / pack_size) + (j / pack_size); + int mask_index = (index >> 5); + int mask_shift = (index & 0x1f); + mask[mask_index] |= (guard << mask_shift); + } + } + } + + MEGDNN_DEVICE __forceinline__ void add_pointer_offset( + size_t offset_in_type) { + pointer += offset_in_type; + } + + MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) { + AccessType* frag_ptr = reinterpret_cast(&frag); +#pragma unroll + for (int i = 0; i < width; ++i) { + Type* pointer_ = pointer + stride[i]; +#pragma unroll + for (int j = 0; j < chan_blk; j+= pack_size) { + int frag_idx = i * (chan_blk / pack_size) + (j / pack_size); + int mask_index = (frag_idx >> 5); + int mask_shift = (frag_idx & 0x1f); + bool guard = (mask[mask_index] & (1 << mask_shift)); + global_load( + frag_ptr[frag_idx], + reinterpret_cast( + pointer_ + j * size_nbits / (8 * sizeof(Type))), + guard, zero_point); + } + } + } + + MEGDNN_DEVICE __forceinline__ void store(const Fragment& frag) { + const AccessType* frag_ptr = reinterpret_cast(&frag); +#pragma unroll + for (int i = 0; i < width; ++i) { + Type* pointer_ = pointer + stride[i]; +#pragma unroll + for (int j = 0; j < chan_blk; j+= pack_size) { + int frag_idx = i * (chan_blk / pack_size) + (j / pack_size); + int mask_index = (frag_idx >> 5); + int mask_shift = (frag_idx & 0x1f); + bool guard = (mask[mask_index] & (1 << mask_shift)); + global_store( + frag_ptr[frag_idx], + reinterpret_cast( + pointer_ + j * size_nbits / (8 * sizeof(Type))), + guard); + } + } + } + + MEGDNN_DEVICE __forceinline__ void advance() { + pointer += chan_blk * size_nbits / (8 * sizeof(Type)); + channel -= chan_blk; + } + +private: + Type* pointer; + int hw_stride_in_elements; + int channel; + int bound; + Uint32Fastdiv div; + uint32_t mask[mask_size]; + size_t stride[width]; +}; + template + int width_, int size_nbits_, + uint32_t layout_type_ = LayoutType::NCHWx> struct TensorIteratorPolicy; template + int size_nbits_, uint32_t layout_type_> struct TensorIteratorPolicy { + size_nbits_, layout_type_> { using TensorIterator = MaskedTensorIteratorOverChannel; + width_, size_nbits_, layout_type_>; }; template + int size_nbits_, uint32_t layout_type_> struct TensorIteratorPolicy { + size_nbits_, layout_type_> { using TensorIterator = TensorIteratorOverChannel; + size_nbits_, layout_type_>; }; template +struct rwtype_helper; + +template <> +struct rwtype_helper<2> { + using InnerDtype = char; +}; 
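+// rwtype_helper picks the built-in type used for one vectorized access on the
+// NCHW side: pack_w == 2 four-bit elements fit in a char (one byte), while the
+// specialization below maps pack_w == 8 (32 bits) to a single unsigned load/store.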
+ +template <> +struct rwtype_helper<8> { + using InnerDtype = unsigned; +}; +} // namespace + +void relayout_format::relayout_format_cuda_nchw_nhwc( + const TensorND& src, const TensorND& dst, const cudaStream_t& stream, + const float src_scale, const float dst_scale, + const uint8_t src_zero_point, const uint8_t dst_zero_point) { + auto&& stype = src.layout.dtype; + auto&& dtype = dst.layout.dtype; + auto& src_layout = src.layout; + auto& dst_layout = dst.layout; + int n = src.layout[0]; + int ic = src.layout[1]; + int h = src.layout[2]; + int w = src.layout[3]; + int w_pad = DIVUP(w, 2) * 2; + int hw = h * w_pad; + int n_stride_src = src_layout.stride[0]; + int ic_stride = src_layout.stride[1]; + int n_stride_dst = dst_layout.stride[0]; + int hw_stride = dst_layout.stride[2]; + static constexpr int chan_blk = 8; + static constexpr int pack_oc = 8; + int problem_size = n * DIVUP(ic, chan_blk) * hw; + int oc = dst.layout[3]; + + bool same_scale = src_scale == dst_scale; + bool padding = w % 2 != 0; +#define DISPATCH_RAW(_padding, _same_scale, _pack_w, _src_type, _dst_type, \ + _src_c_type, _dst_c_type, _size_nbits) \ + if (padding == _padding && same_scale == _same_scale && \ + hw % _pack_w == 0 && stype.enumv().ev == DTypeEnum::Ev::_src_type && \ + dtype.enumv().ev == DTypeEnum::Ev::_dst_type) { \ + using InnerDtype_ = typename rwtype_helper<_pack_w>::InnerDtype; \ + using SrcIterator_ = \ + TensorIteratorOverChannel; \ + using DstIterator_ = typename TensorIteratorPolicy< \ + _padding, _dst_c_type, pack_oc, chan_blk, _pack_w, \ + _size_nbits, LayoutType::NHWC>::TensorIterator; \ + using CudaPostProcess_ = \ + CudaPostProcess; \ + using Transpose_ = \ + Translayout<_pack_w, chan_blk, InnerDtype_, dtype::_src_type, \ + dtype::_dst_type, _same_scale>; \ + using RelayoutProblem_ = \ + RelayoutProblem; \ + n_stride_src = n_stride_src * _size_nbits / (8 * sizeof(InnerDtype_)); \ + ic_stride = ic_stride * _size_nbits / (8 * sizeof(InnerDtype_)); \ + n_stride_dst = n_stride_dst * _size_nbits / (8 * sizeof(_dst_c_type)); \ + hw_stride = hw_stride * _size_nbits / (8 * sizeof(_dst_c_type)); \ + typename RelayoutProblem_::Param param{ \ + SrcIterator_{(InnerDtype_*)src.raw_ptr, ic_stride, ic, w, \ + w_pad}, \ + DstIterator_{(_dst_c_type*)dst.raw_ptr, hw_stride, oc, w, \ + w_pad}, \ + CudaPostProcess_{src_scale, src_zero_point, dst_scale, \ + dst_zero_point}, \ + n_stride_src, \ + n_stride_dst, \ + n, \ + ic, \ + hw, \ + src_zero_point}; \ + auto kernel = relayout_kern; \ + int nr_threads = query_blocksize_for_kernel(kernel); \ + nr_threads = std::min(nr_threads, DIVUP(problem_size, _pack_w)); \ + const dim3 block_dim(DIVUP(problem_size, nr_threads* _pack_w)); \ + const dim3 thread_dim(nr_threads); \ + return kernel<<>>(param); \ + } +#define DISPATCH_4BITS(_src_type, _dst_type) \ + DISPATCH_RAW(true, true, 8, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(true, false, 8, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(true, true, 2, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(true, false, 2, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(false, true, 8, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(false, false, 8, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(false, true, 2, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(false, false, 2, _src_type, _dst_type, char, char, 4); + DISPATCH_4BITS(QuantizedS4, QuantizedS4); + DISPATCH_4BITS(Quantized4Asymm, Quantized4Asymm); +#undef DISPATCH_4BITS +#undef DISPATCH_RAW + megdnn_assert(false, + 
"Unsupported data type(src:%s, dst:%s) or image size(%dx%d).", + stype.name(), dtype.name(), h, w); +} + +void relayout_format::relayout_format_cuda_nhwc_nchw( + const TensorND& src, const TensorND& dst, const cudaStream_t& stream, + const float src_scale, const float dst_scale, + const uint8_t src_zero_point, const uint8_t dst_zero_point) { + auto&& stype = src.layout.dtype; + auto&& dtype = dst.layout.dtype; + auto& src_layout = src.layout; + auto& dst_layout = dst.layout; + + int n = src.layout[0]; + int h = src.layout[1]; + int w = src.layout[2]; + int ic = src.layout[3]; + int w_pad = DIVUP(w, 2) * 2; + int hw = h * w_pad; + int n_stride_src = src_layout.stride[0]; + int hw_stride = src_layout.stride[2]; + int n_stride_dst = dst_layout.stride[0]; + int oc_stride = dst_layout.stride[1]; + static constexpr int chan_blk = 8; + static constexpr int pack_oc = 8; + int problem_size = n * DIVUP(ic, chan_blk) * hw; + int oc = dst.layout[1]; + + bool same_scale = src_scale == dst_scale; + bool padding = w % 2 != 0; +#define DISPATCH_RAW(_padding, _same_scale, _pack_w, _src_type, _dst_type, \ + _src_c_type, _dst_c_type, _size_nbits) \ + if (padding == _padding && same_scale == _same_scale && \ + hw % _pack_w == 0 && stype.enumv().ev == DTypeEnum::Ev::_src_type && \ + dtype.enumv().ev == DTypeEnum::Ev::_dst_type) { \ + using SrcIterator_ = typename TensorIteratorPolicy< \ + _padding, _src_c_type, pack_oc, chan_blk, _pack_w, \ + _size_nbits, LayoutType::NHWC>::TensorIterator; \ + using InnerDtype_ = typename rwtype_helper<_pack_w>::InnerDtype; \ + using DstIterator_ = \ + TensorIteratorOverChannel; \ + using CudaPostProcess_ = \ + CudaPostProcess; \ + using Transpose_ = \ + Translayout; \ + using RelayoutProblem_ = \ + RelayoutProblem; \ + n_stride_src = n_stride_src * _size_nbits / (8 * sizeof(_src_c_type)); \ + hw_stride = hw_stride * _size_nbits / (8 * sizeof(_src_c_type)); \ + n_stride_dst = n_stride_dst * _size_nbits / (8 * sizeof(InnerDtype_)); \ + oc_stride = oc_stride * _size_nbits / (8 * sizeof(InnerDtype_)); \ + typename RelayoutProblem_::Param param{ \ + SrcIterator_{(_src_c_type*)src.raw_ptr, hw_stride, ic, w, \ + w_pad}, \ + DstIterator_{(InnerDtype_*)dst.raw_ptr, oc_stride, oc, w, \ + w_pad}, \ + CudaPostProcess_{src_scale, src_zero_point, dst_scale, \ + dst_zero_point}, \ + n_stride_src, \ + n_stride_dst, \ + n, \ + ic, \ + hw, \ + src_zero_point}; \ + auto kernel = relayout_kern; \ + int nr_threads = query_blocksize_for_kernel(kernel); \ + nr_threads = std::min(nr_threads, DIVUP(problem_size, _pack_w)); \ + const dim3 block_dim(DIVUP(problem_size, nr_threads* _pack_w)); \ + const dim3 thread_dim(nr_threads); \ + return kernel<<>>(param); \ + } +#define DISPATCH_4BITS(_src_type, _dst_type) \ + DISPATCH_RAW(true, true, 8, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(true, false, 8, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(true, true, 2, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(true, false, 2, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(false, true, 8, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(false, false, 8, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(false, true, 2, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(false, false, 2, _src_type, _dst_type, char, char, 4); + DISPATCH_4BITS(QuantizedS4, QuantizedS4); + DISPATCH_4BITS(Quantized4Asymm, Quantized4Asymm); +#undef DISPATCH_4BITS +#undef DISPATCH_RAW + megdnn_assert(false, + "Unsupported data type(src:%s, dst:%s) or image size(%dx%d).", + 
stype.name(), dtype.name(), h, w); +} diff --git a/dnn/src/cuda/relayout_format/translayout.cuh b/dnn/src/cuda/relayout_format/translayout.cuh index d1cc8238..bea8b224 100644 --- a/dnn/src/cuda/relayout_format/translayout.cuh +++ b/dnn/src/cuda/relayout_format/translayout.cuh @@ -42,8 +42,9 @@ struct enable_qtype_b4 { static constexpr bool val_dst = std::is_same::value || std::is_same::value; - using type = typename std::enable_if::value && - val_src && val_dst>::type; + static constexpr bool value = + std::is_same::value && val_src && val_dst; + using type = typename std::enable_if::type; }; // The input fragment is stored in RowMajor order. The translayout operator @@ -393,26 +394,32 @@ struct Translayout<2, 8, SrcType, DnnSrcType_, DnnDstType_, same_scale, using Fragment = array_wrapper; static inline __device__ void trans( Fragment& dst, const Fragment& src, - CudaPostProcess& post_process, - const char zero_point) { + CudaPostProcess& post_process) { int intermediate[8][2]; - transform_b4x2_to_int8(intermediate[0], - reinterpret_cast(src[0])); - transform_b4x2_to_int8(intermediate[1], - reinterpret_cast(src[1])); - transform_b4x2_to_int8(intermediate[2], - reinterpret_cast(src[2])); - transform_b4x2_to_int8(intermediate[3], - reinterpret_cast(src[3])); - transform_b4x2_to_int8(intermediate[4], - reinterpret_cast(src[4])); - transform_b4x2_to_int8(intermediate[5], - reinterpret_cast(src[5])); - transform_b4x2_to_int8(intermediate[6], - reinterpret_cast(src[6])); - transform_b4x2_to_int8(intermediate[7], - reinterpret_cast(src[7])); - + transform_b4x2_to_int8( + intermediate[0], + reinterpret_cast(src[0 * col_in_type])); + transform_b4x2_to_int8( + intermediate[1], + reinterpret_cast(src[1 * col_in_type])); + transform_b4x2_to_int8( + intermediate[2], + reinterpret_cast(src[2 * col_in_type])); + transform_b4x2_to_int8( + intermediate[3], + reinterpret_cast(src[3 * col_in_type])); + transform_b4x2_to_int8( + intermediate[4], + reinterpret_cast(src[4 * col_in_type])); + transform_b4x2_to_int8( + intermediate[5], + reinterpret_cast(src[5 * col_in_type])); + transform_b4x2_to_int8( + intermediate[6], + reinterpret_cast(src[6 * col_in_type])); + transform_b4x2_to_int8( + intermediate[7], + reinterpret_cast(src[7 * col_in_type])); int* dst_frag = reinterpret_cast(&dst); auto pack = [&](int idx) -> int { return transform_int8_to_b4x8( @@ -445,25 +452,24 @@ struct Translayout<8, 8, SrcType, DnnSrcType_, DnnDstType_, same_scale, using Fragment = array_wrapper; static inline __device__ void trans( Fragment& dst, const Fragment& src, - CudaPostProcess& post_process, - const char zero_point) { + CudaPostProcess& post_process) { int intermediate[8][8]; transform_b4x8_to_int8( - intermediate[0], reinterpret_cast(src[0])); + intermediate[0], reinterpret_cast(src[0 * col_in_type])); transform_b4x8_to_int8( - intermediate[1], reinterpret_cast(src[1])); + intermediate[1], reinterpret_cast(src[1 * col_in_type])); transform_b4x8_to_int8( - intermediate[2], reinterpret_cast(src[2])); + intermediate[2], reinterpret_cast(src[2 * col_in_type])); transform_b4x8_to_int8( - intermediate[3], reinterpret_cast(src[3])); + intermediate[3], reinterpret_cast(src[3 * col_in_type])); transform_b4x8_to_int8( - intermediate[4], reinterpret_cast(src[4])); + intermediate[4], reinterpret_cast(src[4 * col_in_type])); transform_b4x8_to_int8( - intermediate[5], reinterpret_cast(src[5])); + intermediate[5], reinterpret_cast(src[5 * col_in_type])); transform_b4x8_to_int8( - intermediate[6], reinterpret_cast(src[6])); + 
intermediate[6], reinterpret_cast(src[6 * col_in_type])); transform_b4x8_to_int8( - intermediate[7], reinterpret_cast(src[7])); + intermediate[7], reinterpret_cast(src[7 * col_in_type])); int* dst_frag = reinterpret_cast(&dst); auto pack = [&](int idx) { return transform_int8_to_b4x8( @@ -502,13 +508,12 @@ struct Translayout<8, 2, SrcType, DnnSrcType_, DnnDstType_, same_scale, using Fragment = array_wrapper; static inline __device__ void trans( Fragment& dst, const Fragment& src, - CudaPostProcess& post_process, - const char zero_point) { + CudaPostProcess& post_process) { int intermediate[2][8]; transform_b4x8_to_int8( - intermediate[0], reinterpret_cast(src[0])); + intermediate[0], reinterpret_cast(src[0 * col_in_type])); transform_b4x8_to_int8( - intermediate[1], reinterpret_cast(src[1])); + intermediate[1], reinterpret_cast(src[1 * col_in_type])); int* dst_frag = reinterpret_cast(&dst); dst_frag[0] = transform_int8_to_b4x8( post_process(intermediate[0][0]), diff --git a/dnn/src/cuda/warp_perspective/forward.cu b/dnn/src/cuda/warp_perspective/forward.cu index 88b28c13..4e66dadb 100644 --- a/dnn/src/cuda/warp_perspective/forward.cu +++ b/dnn/src/cuda/warp_perspective/forward.cu @@ -508,7 +508,7 @@ struct KernCoreNHWC { "assert qu4 or q4"); constexpr bool signedness = std::is_same::value; int8_t bval_4 = bval.as_storage() & 0xF; - const int bval_int = transform_int8_to_bit4x8( + const int bval_int = transform_int8_to_b4x8( bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4); int src_ori[4]; src_ori[0] = src0_ok ? *(int*)(src_ptr0 + offset) : bval_int; @@ -516,10 +516,10 @@ struct KernCoreNHWC { src_ori[2] = src2_ok ? *(int*)(src_ptr2 + offset) : bval_int; src_ori[3] = src3_ok ? *(int*)(src_ptr3 + offset) : bval_int; int src[4][8]; - transform_bit4x8_to_int8(src[0], src_ori[0]); - transform_bit4x8_to_int8(src[1], src_ori[1]); - transform_bit4x8_to_int8(src[2], src_ori[2]); - transform_bit4x8_to_int8(src[3], src_ori[3]); + transform_b4x8_to_int8(src[0], src_ori[0]); + transform_b4x8_to_int8(src[1], src_ori[1]); + transform_b4x8_to_int8(src[2], src_ori[2]); + transform_b4x8_to_int8(src[3], src_ori[3]); int res = pack_output_func(output_converter, src[0], src[1], src[2], src[3], w00, w01, w10, w11); @@ -542,7 +542,7 @@ struct KernCoreNHWC { "assert qu4 or q4"); constexpr bool signedness = std::is_same::value; int8_t bval_4 = bval.as_storage() & 0xF; - const int bval_int_temp = transform_int8_to_bit4x8( + const int bval_int_temp = transform_int8_to_b4x8( bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4); const int2 bval_int{bval_int_temp, bval_int_temp}; @@ -552,15 +552,15 @@ struct KernCoreNHWC { src_ori[2] = src2_ok ? *(int2*)(src_ptr2 + offset) : bval_int; src_ori[3] = src3_ok ? 
*(int2*)(src_ptr3 + offset) : bval_int; int src[8][8]; - transform_bit4x8_to_int8(src[0], src_ori[0].x); - transform_bit4x8_to_int8(src[1], src_ori[1].x); - transform_bit4x8_to_int8(src[2], src_ori[2].x); - transform_bit4x8_to_int8(src[3], src_ori[3].x); - - transform_bit4x8_to_int8(src[4], src_ori[0].y); - transform_bit4x8_to_int8(src[5], src_ori[1].y); - transform_bit4x8_to_int8(src[6], src_ori[2].y); - transform_bit4x8_to_int8(src[7], src_ori[3].y); + transform_b4x8_to_int8(src[0], src_ori[0].x); + transform_b4x8_to_int8(src[1], src_ori[1].x); + transform_b4x8_to_int8(src[2], src_ori[2].x); + transform_b4x8_to_int8(src[3], src_ori[3].x); + + transform_b4x8_to_int8(src[4], src_ori[0].y); + transform_b4x8_to_int8(src[5], src_ori[1].y); + transform_b4x8_to_int8(src[6], src_ori[2].y); + transform_b4x8_to_int8(src[7], src_ori[3].y); int2 res; res.x = pack_output_func(output_converter, src[0], src[1], diff --git a/dnn/test/cuda/relayout_format.cpp b/dnn/test/cuda/relayout_format.cpp index 0da9533b..a4abe498 100644 --- a/dnn/test/cuda/relayout_format.cpp +++ b/dnn/test/cuda/relayout_format.cpp @@ -325,6 +325,91 @@ TEST_F(CUDA, RELAYOUT_FORMAT_NCHW64_NCHW) { } } +TEST_F(CUDA, RELAYOUT_FORMAT_NCHW_NHWC) { + Checker checker(handle_cuda()); + UniformIntRNG s4{-8, 7}; + UniformIntRNG u4{0, 15}; + param::RelayoutFormat param; + param.mode = param::RelayoutFormat::Mode::NCHW_NHWC; + for (size_t n : {1, 3}) { + for (size_t c : {8, 16}) { + for (size_t h : {7, 14, 16, 28}) { + for (size_t w : {2, 3, 7, 8, 16, 31}) { + checker.set_dtype(0, dtype::QuantizedS4{2.f}) + .set_dtype(1, dtype::QuantizedS4{2.f}) + .set_rng(0, &s4) + .set_param(param) + .execs({{n, c, h, w}, {}}); + + checker.set_dtype(0, dtype::Quantized4Asymm{1.2f, 8}) + .set_dtype(1, dtype::Quantized4Asymm{1.2f, 4}) + .set_rng(0, &u4) + .set_param(param) + .execs({{n, c, h, w}, {}}); + + checker.set_dtype(0, dtype::QuantizedS4{1.19990307f}) + .set_dtype(1, dtype::QuantizedS4{1.f}) + .set_rng(0, &s4) + .set_param(param) + .execs({{n, c, h, w}, {}}); + + checker.set_dtype(0, dtype::Quantized4Asymm{1.19990307f, 8}) + .set_dtype(1, dtype::Quantized4Asymm{1.f, 4}) + .set_rng(0, &u4) + .set_param(param) + .set_epsilon(1e-3) + .execs({{n, c, h, w}, {}}); + } + } + } + } + checker.execs({{1, 256, 384, 640}, {}}); +} + +TEST_F(CUDA, RELAYOUT_FORMAT_NHWC_NCHW) { + Checker checker(handle_cuda()); + UniformIntRNG s4{-8, 7}; + UniformIntRNG u4{0, 15}; + param::RelayoutFormat param; + param.mode = param::RelayoutFormat::Mode::NHWC_NCHW; + for (size_t n : {1, 3}) { + for (size_t c : {8, 16}) { + for (size_t h : {7, 14, 16, 28}) { + for (size_t w : {2, 3, 4, 7, 14, 16, 17}) { + checker.set_dtype(0, dtype::QuantizedS4{2.f}) + .set_dtype(1, dtype::QuantizedS4{2.f}) + .set_rng(0, &s4) + .set_param(param) + .set_epsilon(1e-3) + .execs({{n, h, w, c}, {}}); + + checker.set_dtype(0, dtype::Quantized4Asymm{1.2f, 4}) + .set_dtype(1, dtype::Quantized4Asymm{1.2f, 8}) + .set_rng(0, &u4) + .set_param(param) + .set_epsilon(1e-3) + .execs({{n, h, w, c}, {}}); + + checker.set_dtype(0, dtype::QuantizedS4{1.19990307f}) + .set_dtype(1, dtype::QuantizedS4{1.f}) + .set_rng(0, &s4) + .set_param(param) + .set_epsilon(1e-3) + .execs({{n, h, w, c}, {}}); + + checker.set_dtype(0, dtype::Quantized4Asymm{1.20211209f, 8}) + .set_dtype(1, dtype::Quantized4Asymm{1.f, 4}) + .set_rng(0, &u4) + .set_param(param) + .set_epsilon(1e-3) + .execs({{n, h, w, c}, {}}); + } + } + } + } + checker.execs({{1, 384, 640, 256}, {}}); +} + #if MEGDNN_WITH_BENCHMARK TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT) { 
using Param = RelayoutFormat::Param; @@ -393,6 +478,7 @@ TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT_QS4) { } }; + printf("nchw -> nchw64\n"); { TensorShapeArray shapes = { {1, 64, 56, 56}, {16, 64, 56, 56}, {64, 64, 56, 56}, @@ -403,6 +489,18 @@ TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT_QS4) { param.mode = param::RelayoutFormat::Mode::NCHW_NCHW64; run(shapes, param); } + printf("nchw -> nhwc\n"); + { + TensorShapeArray shapes = { + {1, 64, 56, 56}, {16, 64, 56, 56}, {64, 64, 56, 56}, + {1, 64, 56, 55}, {16, 64, 56, 55}, {64, 64, 56, 55}, + {1, 256, 384, 640}, {16, 16, 384, 640}, + }; + Param param; + param.mode = param::RelayoutFormat::Mode::NCHW_NHWC; + run(shapes, param); + } + printf("nchw64 -> nchw\n"); { TensorShapeArray shapes = { {64, 1, 56, 56, 64}, @@ -415,6 +513,19 @@ TEST_F(CUDA, BENCHMARK_RELAYOUT_FORMAT_QS4) { param.mode = param::RelayoutFormat::Mode::NCHW64_NCHW; run(shapes, param); } + printf("nhwc -> nchw\n"); + { + TensorShapeArray shapes = { + {64, 56, 56, 64}, + {1, 7, 7, 64*32}, + {16, 7, 7, 64*32}, + {64, 7, 7, 64*32}, + {1, 384, 640, 64*4}, + }; + Param param; + param.mode = param::RelayoutFormat::Mode::NHWC_NCHW; + run(shapes, param); + } } #endif
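Note on the new kernels (not part of the diff): relayout_format_cuda_nchw_nhwc moves the channel dimension to the innermost position and requantizes each 4-bit value via CudaPostProcess. The host-side sketch below shows the intended index mapping and rescaling for the Quantized4Asymm case, with values held one per int8_t; nibble packing, the even-width padding (w_pad = DIVUP(w, 2) * 2) and the vectorized iterators are omitted, and the round/clamp behaviour is an assumption modelled on the quantized post-process rather than code taken from this patch.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Reference for NCHW -> NHWC relayout of Quantized4Asymm data (one value per
// int8_t). QuantizedS4 follows the same mapping with zero points of 0 and a
// clamp to [-8, 7].
std::vector<int8_t> relayout_nchw_nhwc_ref(const std::vector<int8_t>& src,
                                           int n, int c, int h, int w,
                                           float src_scale, int src_zero_point,
                                           float dst_scale, int dst_zero_point) {
    std::vector<int8_t> dst(static_cast<size_t>(n) * h * w * c);
    const float rescale = src_scale / dst_scale;
    for (int in = 0; in < n; ++in)
        for (int ic = 0; ic < c; ++ic)
            for (int ih = 0; ih < h; ++ih)
                for (int iw = 0; iw < w; ++iw) {
                    size_t sidx = ((static_cast<size_t>(in) * c + ic) * h + ih) * w + iw;
                    size_t didx = ((static_cast<size_t>(in) * h + ih) * w + iw) * c + ic;
                    // dequantize with the source parameters, requantize with the
                    // destination parameters, then clamp to the 4-bit range
                    int q = static_cast<int>(std::nearbyint(
                                    (src[sidx] - src_zero_point) * rescale)) +
                            dst_zero_point;
                    dst[didx] = static_cast<int8_t>(std::min(15, std::max(0, q)));
                }
    return dst;
}

relayout_format_cuda_nhwc_nchw performs the inverse mapping; the DISPATCH_RAW macros only select the access width (pack_w of 2 or 8 four-bit values) and whether the odd-width padding path is taken.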