From 606540bef477cada613ec680137c799fb7a6648b Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 24 May 2021 20:09:36 +0800 Subject: [PATCH] feat(dnn/cuda): add nhwc 4bit warp perspective GitOrigin-RevId: fbec4a4a1f7a5dd184b4e76d60b26a6cac8a24ef --- dnn/src/common/warp_perspective.cpp | 2 +- dnn/src/cuda/warp_perspective/common.cuh | 8 + dnn/src/cuda/warp_perspective/common.h | 8 + dnn/src/cuda/warp_perspective/forward.cpp | 80 +++++- dnn/src/cuda/warp_perspective/forward.cu | 380 +++++++++++++++++++++++----- dnn/src/naive/warp_perspective/opr_impl.cpp | 32 ++- dnn/test/cuda/warp_perspective.cpp | 92 ++++++- dnn/test/naive/warp_perspective.cpp | 106 +++++++- 8 files changed, 611 insertions(+), 97 deletions(-) diff --git a/dnn/src/common/warp_perspective.cpp b/dnn/src/common/warp_perspective.cpp index 68bcb087..b247c943 100644 --- a/dnn/src/common/warp_perspective.cpp +++ b/dnn/src/common/warp_perspective.cpp @@ -226,7 +226,7 @@ std::string WarpPerspectiveBase::param_msg() const { res.append("LANCZOS4"); break; } - res.append("bmode="); + res.append(", bmode="); switch (param().bmode) { case BorderMode::WRAP: res.append("WRAP"); diff --git a/dnn/src/cuda/warp_perspective/common.cuh b/dnn/src/cuda/warp_perspective/common.cuh index 2ded1af7..2ab6899d 100644 --- a/dnn/src/cuda/warp_perspective/common.cuh +++ b/dnn/src/cuda/warp_perspective/common.cuh @@ -63,6 +63,14 @@ class WrapGetter { } }; +class ConstGetter { + public: + __device__ int operator()(int i, int n) + { + return i; + } +}; + } // namespace warp_perspective } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/warp_perspective/common.h b/dnn/src/cuda/warp_perspective/common.h index c22b3782..8b13e7ad 100644 --- a/dnn/src/cuda/warp_perspective/common.h +++ b/dnn/src/cuda/warp_perspective/common.h @@ -28,6 +28,14 @@ void forward_proxy(bool is_nhwc, const ctype* src, const float* mat, megcore::AsyncErrorInfo* error_info, void* error_tracker, cudaStream_t stream); +template +void forward_proxy_nhwc_bit4(const ctype* src, const float* mat, + const int* mat_idx, ctype* dst, int N_SRC, + int N_MAT, int C, int IH, int IW, int OH, int OW, + ctype bval, BorderMode bmode, + megcore::AsyncErrorInfo* error_info, + void* error_tracker, cudaStream_t stream); + template void forward_proxy_nchw4(const ctype* src, const float* mat, const int* mat_idx, ctype* dst, int N_SRC, int N_MAT, int C, int IH, diff --git a/dnn/src/cuda/warp_perspective/forward.cpp b/dnn/src/cuda/warp_perspective/forward.cpp index 32491f9d..60d2f3a1 100644 --- a/dnn/src/cuda/warp_perspective/forward.cpp +++ b/dnn/src/cuda/warp_perspective/forward.cpp @@ -328,12 +328,10 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc, mat.layout[0], C, IH, IW, OH, OW, bval, bmode, async_error_info(handle()), m_error_tracker, stream); - } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS4) { - megdnn_assert( - param().format == Param::Format::NCHW64 || - param().format == Param::Format::NCHW, - "WarpPerspective on CUDA supports NCHW64 or NCHW+ " - "QuantizedS4"); + } else if ((src.layout.dtype.enumv() == + DTypeEnum::QuantizedS4) && + (param().format == Param::Format::NCHW64 || + param().format == Param::Format::NCHW)) { bval = roundf(bval); bval = fmin(fmax(-8.f, bval), 7.f); warp_perspective::forward_proxy_nchw64( @@ -355,13 +353,10 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc, relayout_opr->param() = trans_param; relayout_opr->exec(dst, sdst, {}); } - } else if (src.layout.dtype.enumv() == - DTypeEnum::Quantized4Asymm) { - megdnn_assert( - param().format == Param::Format::NCHW64 || - param().format == Param::Format::NCHW, - "WarpPerspective on CUDA supports NCHW64 or NCHW+ " - "Quantized4Asymm"); + } else if ((src.layout.dtype.enumv() == + DTypeEnum::Quantized4Asymm) && + (param().format == Param::Format::NCHW64 || + param().format == Param::Format::NCHW)) { bval = roundf(bval); bval = fmin(fmax(0, bval), 15); warp_perspective::forward_proxy_nchw64( @@ -383,6 +378,65 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc, relayout_opr->param() = trans_param; relayout_opr->exec(dst, sdst, {}); } + } else if ((src.layout.dtype.enumv() == + DTypeEnum::QuantizedS4 || + src.layout.dtype.enumv() == + DTypeEnum::Quantized4Asymm) && + (param().format == Param::Format::NHWC)) { + constexpr int pack_c = 8; + megdnn_assert(C % pack_c == 0); + bval = roundf(bval); + if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS4) { + bval = fmin(fmax(-8.f, bval), 7.f); + if (C % 16 == 0) { + warp_perspective::forward_proxy_nhwc_bit4( + src.ptr(), mat.ptr(), + mat_idx.raw_ptr ? mat_idx.ptr() + : nullptr, + dst.ptr(), src.layout[0], + mat.layout[0], C, IH, IW, OH, OW, + static_cast(bval), bmode, + async_error_info(handle()), m_error_tracker, + stream); + } else { + warp_perspective::forward_proxy_nhwc_bit4( + src.ptr(), mat.ptr(), + mat_idx.raw_ptr ? mat_idx.ptr() + : nullptr, + dst.ptr(), src.layout[0], + mat.layout[0], C, IH, IW, OH, OW, + static_cast(bval), bmode, + async_error_info(handle()), m_error_tracker, + stream); + } + } else { + bval = fmin(fmax(0.f, bval), 15.f); + if (C % 16 == 0) { + warp_perspective::forward_proxy_nhwc_bit4( + src.ptr(), mat.ptr(), + mat_idx.raw_ptr ? mat_idx.ptr() + : nullptr, + dst.ptr(), src.layout[0], + mat.layout[0], C, IH, IW, OH, OW, + static_cast(bval), bmode, + async_error_info(handle()), m_error_tracker, + stream); + } else { + warp_perspective::forward_proxy_nhwc_bit4( + src.ptr(), mat.ptr(), + mat_idx.raw_ptr ? mat_idx.ptr() + : nullptr, + dst.ptr(), src.layout[0], + mat.layout[0], C, IH, IW, OH, OW, + static_cast(bval), bmode, + async_error_info(handle()), m_error_tracker, + stream); + } + } } } else if ((src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm || diff --git a/dnn/src/cuda/warp_perspective/forward.cu b/dnn/src/cuda/warp_perspective/forward.cu index 9d994384..aa43591c 100644 --- a/dnn/src/cuda/warp_perspective/forward.cu +++ b/dnn/src/cuda/warp_perspective/forward.cu @@ -27,20 +27,51 @@ using namespace integer_subbyte; namespace { template +struct CtypeHelper; + +template <> +struct CtypeHelper { + static constexpr int bit_width = 32; +}; +template <> +struct CtypeHelper { + static constexpr int bit_width = 16; +}; +template <> +struct CtypeHelper { + static constexpr int bit_width = 8; +}; +template <> +struct CtypeHelper { + static constexpr int bit_width = 8; +}; +template <> +struct CtypeHelper { + static constexpr int bit_width = 4; +}; +template <> +struct CtypeHelper { + static constexpr int bit_width = 4; +}; + +template struct DirectSrcVisitor { - const ctype* ptr; + const void* ptr; __device__ __forceinline__ const ctype* get(int batch, int im_size) { - return ptr + - static_cast(batch) * static_cast(im_size); + return (ctype*)((char*)ptr + static_cast(batch) * + static_cast(im_size) * + CtypeHelper::bit_width / 8); } - void move_batch(size_t batch, size_t im_size) { ptr += batch * im_size; } + void move_batch(size_t batch, size_t im_size) { + ptr = (char*)ptr + batch * im_size * CtypeHelper::bit_width / 8; + } }; template struct IndexedSrcVisitor { - const ctype* ptr; + const void* ptr; const int* idx; int N_SRC; @@ -57,8 +88,9 @@ struct IndexedSrcVisitor { orig_batch, batch, N_SRC); batch = 0; } - return ptr + - static_cast(batch) * static_cast(im_size); + return (ctype*)((char*)ptr + static_cast(batch) * + static_cast(im_size) * + CtypeHelper::bit_width / 8); } void move_batch(size_t batch, size_t) { idx += batch; } @@ -183,13 +215,13 @@ transform_bit4x8_to_int8(int (&result)[8], const int& source){ template MEGDNN_DEVICE __forceinline__ int pack_output_func( OutputConverter& output_converter, int (&s00)[8], int (&s01)[8], - int (&s10)[8], int (&s11)[8], float palpha, float pbeta, float nalpha, - float nbeta) { + int (&s10)[8], int (&s11)[8], float w00, float w01, float w10, + float w11) { #define warp_perspective_transform(idx) \ - static_cast(output_converter(s00[idx] * nalpha * nbeta + \ - s01[idx] * nalpha * pbeta + \ - s10[idx] * palpha * nbeta + \ - s11[idx] * palpha * pbeta) \ + static_cast(output_converter(s00[idx] * w00 + \ + s01[idx] * w01 + \ + s10[idx] * w10 + \ + s11[idx] * w11) \ .as_storage()) return transform_int8_to_bit4x8( @@ -212,7 +244,7 @@ __global__ void kern_general_nchw64(SrcVisitor src, const float* __restrict mat, int c1 = ow % 2; ow = ow / 2; int oh = blockIdx.y * blockDim.y + threadIdx.y; - const ctype* __restrict sptr = src.get(blockIdx.z, C * IH * IW / 2); + const ctype* __restrict sptr = src.get(blockIdx.z, C * IH * IW); dst += blockIdx.z * C * OH * OW / 2; mat += blockIdx.z * 3 * 3; const int4* sptr_int4 = reinterpret_cast(sptr); @@ -229,6 +261,10 @@ __global__ void kern_general_nchw64(SrcVisitor src, const float* __restrict mat, float pbeta = iw - floor(iw); float nalpha = 1.0f - palpha; float nbeta = 1.0f - pbeta; + float w00 = nalpha * nbeta; + float w01 = nalpha * pbeta; + float w10 = palpha * nbeta; + float w11 = palpha * pbeta; int o_coor = (oh * OW + ow) << 1; int i_coor_00 = (ih0 * IW + iw0) << 1; int i_coor_01 = (ih0 * IW + iw1) << 1; @@ -247,32 +283,28 @@ __global__ void kern_general_nchw64(SrcVisitor src, const float* __restrict mat, transform_bit4x8_to_int8(s10, s[2].x); transform_bit4x8_to_int8(s11, s[3].x); d.x = pack_output_func(output_converter, s00, s01, s10, - s11, palpha, pbeta, nalpha, - nbeta); + s11, w00, w01, w10, w11); transform_bit4x8_to_int8(s00, s[0].y); transform_bit4x8_to_int8(s01, s[1].y); transform_bit4x8_to_int8(s10, s[2].y); transform_bit4x8_to_int8(s11, s[3].y); d.y = pack_output_func(output_converter, s00, s01, s10, - s11, palpha, pbeta, nalpha, - nbeta); + s11, w00, w01, w10, w11); transform_bit4x8_to_int8(s00, s[0].z); transform_bit4x8_to_int8(s01, s[1].z); transform_bit4x8_to_int8(s10, s[2].z); transform_bit4x8_to_int8(s11, s[3].z); d.z = pack_output_func(output_converter, s00, s01, s10, - s11, palpha, pbeta, nalpha, - nbeta); + s11, w00, w01, w10, w11); transform_bit4x8_to_int8(s00, s[0].w); transform_bit4x8_to_int8(s01, s[1].w); transform_bit4x8_to_int8(s10, s[2].w); transform_bit4x8_to_int8(s11, s[3].w); d.w = pack_output_func(output_converter, s00, s01, s10, - s11, palpha, pbeta, nalpha, - nbeta); + s11, w00, w01, w10, w11); dst_int4[o_coor + c1] = d; sptr_int4 += IH * IW * 2; @@ -392,7 +424,7 @@ __global__ void kern_const_border_nchw64(SrcVisitor src, int c1 = ow % 2; ow = ow / 2; int oh = blockIdx.y * blockDim.y + threadIdx.y; - const ctype* __restrict sptr = src.get(blockIdx.z, C * IH * IW / 2); + const ctype* __restrict sptr = src.get(blockIdx.z, C * IH * IW); dst += blockIdx.z * C * OH * OW / 2; mat += blockIdx.z * 3 * 3; const int4* sptr_int4 = reinterpret_cast(sptr); @@ -413,6 +445,10 @@ __global__ void kern_const_border_nchw64(SrcVisitor src, float pbeta = iw - floor(iw); float nalpha = 1.0f - palpha; float nbeta = 1.0f - pbeta; + float w00 = nalpha * nbeta; + float w01 = nalpha * pbeta; + float w10 = palpha * nbeta; + float w11 = palpha * pbeta; int o_coor = (oh * OW + ow) << 1; int i_coor_00 = (ih0 * IW + iw0) << 1; int i_coor_01 = (ih0 * IW + iw1) << 1; @@ -457,32 +493,28 @@ __global__ void kern_const_border_nchw64(SrcVisitor src, transform_bit4x8_to_int8(s10, s[2].x); transform_bit4x8_to_int8(s11, s[3].x); d.x = pack_output_func(output_converter, s00, s01, s10, - s11, palpha, pbeta, nalpha, - nbeta); + s11, w00, w01, w10, w11); transform_bit4x8_to_int8(s00, s[0].y); transform_bit4x8_to_int8(s01, s[1].y); transform_bit4x8_to_int8(s10, s[2].y); transform_bit4x8_to_int8(s11, s[3].y); d.y = pack_output_func(output_converter, s00, s01, s10, - s11, palpha, pbeta, nalpha, - nbeta); + s11, w00, w01, w10, w11); transform_bit4x8_to_int8(s00, s[0].z); transform_bit4x8_to_int8(s01, s[1].z); transform_bit4x8_to_int8(s10, s[2].z); transform_bit4x8_to_int8(s11, s[3].z); d.z = pack_output_func(output_converter, s00, s01, s10, - s11, palpha, pbeta, nalpha, - nbeta); + s11, w00, w01, w10, w11); transform_bit4x8_to_int8(s00, s[0].w); transform_bit4x8_to_int8(s01, s[1].w); transform_bit4x8_to_int8(s10, s[2].w); transform_bit4x8_to_int8(s11, s[3].w); d.w = pack_output_func(output_converter, s00, s01, s10, - s11, palpha, pbeta, nalpha, - nbeta); + s11, w00, w01, w10, w11); dst_int4[o_coor + c1] = d; sptr_int4 += IH * IW * 2; @@ -491,17 +523,114 @@ __global__ void kern_const_border_nchw64(SrcVisitor src, } } +template +struct KernCoreNHWC { + MEGDNN_DEVICE __forceinline__ static void func( + char* dst_ptr, const char* src_ptr0, const char* src_ptr1,const char* src_ptr2, const char* src_ptr3, const int offset, + float w00, float w01, float w10, float w11, + OutputConverter& output_converter, const bool src0_ok, const bool src1_ok, + const bool src2_ok, const bool src3_ok, const ctype bval) { + static_assert(pack_c == 1, "static_assert pack_c == 1"); + ctype v00 = src0_ok ? *(ctype*)(src_ptr0 + offset): bval; + ctype v01 = src1_ok ? *(ctype*)(src_ptr1 + offset): bval; + ctype v10 = src2_ok ? *(ctype*)(src_ptr2 + offset): bval; + ctype v11 = src3_ok ? *(ctype*)(src_ptr3 + offset): bval; + ctype res = + output_converter(v00 * w00+ v01 * w01 + + v10 * w10 + v11 * w11); + *(ctype*)(dst_ptr + offset) = res; + } +}; + +template +struct KernCoreNHWC { + MEGDNN_DEVICE __forceinline__ static void func( + char* dst_ptr, const char* src_ptr0, const char* src_ptr1,const char* src_ptr2, const char* src_ptr3, const int offset, + float w00, float w01, float w10, float w11, + OutputConverter& output_converter, const bool src0_ok, const bool src1_ok, + const bool src2_ok, const bool src3_ok, const ctype bval){ + static_assert(std::is_same::value || + std::is_same::value, + "assert qu4 or q4"); + constexpr bool signedness = std::is_same::value; + int8_t bval_4 = bval.as_storage() & 0xF; + const int bval_int = transform_int8_to_bit4x8( + bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4); + int src_ori[4]; + src_ori[0] = src0_ok ? *(int*)(src_ptr0 + offset) : bval_int; + src_ori[1] = src1_ok ? *(int*)(src_ptr1 + offset) : bval_int; + src_ori[2] = src2_ok ? *(int*)(src_ptr2 + offset) : bval_int; + src_ori[3] = src3_ok ? *(int*)(src_ptr3 + offset) : bval_int; + int src[4][8]; + transform_bit4x8_to_int8(src[0], src_ori[0]); + transform_bit4x8_to_int8(src[1], src_ori[1]); + transform_bit4x8_to_int8(src[2], src_ori[2]); + transform_bit4x8_to_int8(src[3], src_ori[3]); + int res = pack_output_func(output_converter, src[0], src[1], + src[2], src[3], w00, w01, w10, + w11); + *(int*)(dst_ptr + offset) = res; + } +}; + + +template +struct KernCoreNHWC { + MEGDNN_DEVICE __forceinline__ static void func( + char* dst_ptr, const char* src_ptr0, const char* src_ptr1, + const char* src_ptr2, const char* src_ptr3, const int offset, + float w00, float w01, float w10, float w11, + OutputConverter& output_converter, const bool src0_ok, + const bool src1_ok, const bool src2_ok, const bool src3_ok, + const ctype bval) { + static_assert(std::is_same::value || + std::is_same::value, + "assert qu4 or q4"); + constexpr bool signedness = std::is_same::value; + int8_t bval_4 = bval.as_storage() & 0xF; + const int bval_int_temp = transform_int8_to_bit4x8( + bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4); + const int2 bval_int{bval_int_temp, bval_int_temp}; + + int2 src_ori[4]; + src_ori[0] = src0_ok ? *(int2*)(src_ptr0 + offset) : bval_int; + src_ori[1] = src1_ok ? *(int2*)(src_ptr1 + offset) : bval_int; + src_ori[2] = src2_ok ? *(int2*)(src_ptr2 + offset) : bval_int; + src_ori[3] = src3_ok ? *(int2*)(src_ptr3 + offset) : bval_int; + int src[8][8]; + transform_bit4x8_to_int8(src[0], src_ori[0].x); + transform_bit4x8_to_int8(src[1], src_ori[1].x); + transform_bit4x8_to_int8(src[2], src_ori[2].x); + transform_bit4x8_to_int8(src[3], src_ori[3].x); + + transform_bit4x8_to_int8(src[4], src_ori[0].y); + transform_bit4x8_to_int8(src[5], src_ori[1].y); + transform_bit4x8_to_int8(src[6], src_ori[2].y); + transform_bit4x8_to_int8(src[7], src_ori[3].y); + + int2 res; + res.x = pack_output_func(output_converter, src[0], src[1], + src[2], src[3], w00, w01, w10, + w11); + res.y = pack_output_func(output_converter, src[4], src[5], + src[6], src[7], w00, w01, w10, + w11); + *(int2*)(dst_ptr + offset) = res; + } +}; + template + typename OutputConverter, int pack_c> __global__ void kern_general_nhwc(SrcVisitor src, const float* __restrict mat, ctype* __restrict dst, int C, int IH, int IW, int OH, int OW) { Getter getter; OutputConverter output_converter; + constexpr int bit_width = CtypeHelper::bit_width; int ow = blockIdx.x * blockDim.x + threadIdx.x; int oh = blockIdx.y * blockDim.y + threadIdx.y; const ctype* __restrict sptr = src.get(blockIdx.z, C * IH * IW); - dst += blockIdx.z * C * OH * OW; + dst = (ctype*)((char*)dst + blockIdx.z * C * OH * OW * bit_width / 8); mat += blockIdx.z * 3 * 3; if (ow < OW && oh < OH) { float denominator = mat[6] * ow + mat[7] * oh + mat[8]; @@ -515,52 +644,75 @@ __global__ void kern_general_nhwc(SrcVisitor src, const float* __restrict mat, float pbeta = iw - floor(iw); float nalpha = 1.0f - palpha; float nbeta = 1.0f - pbeta; - for (int c = 0; c < C; ++c) { - dst[(oh * OW + ow) * C + c] = output_converter( - sptr[(ih0 * IW + iw0) * C + c] * nalpha * nbeta + - sptr[(ih0 * IW + iw1) * C + c] * nalpha * pbeta + - sptr[(ih1 * IW + iw0) * C + c] * palpha * nbeta + - sptr[(ih1 * IW + iw1) * C + c] * palpha * pbeta); + float w00 = nalpha * nbeta; + float w01 = nalpha * pbeta; + float w10 = palpha * nbeta; + float w11 = palpha * pbeta; + const char* src_ptr0 = + (char*)sptr + (ih0 * IW + iw0) * C * bit_width / 8; + const char* src_ptr1 = + (char*)sptr + (ih0 * IW + iw1) * C * bit_width / 8; + const char* src_ptr2 = + (char*)sptr + (ih1 * IW + iw0) * C * bit_width / 8; + const char* src_ptr3 = + (char*)sptr + (ih1 * IW + iw1) * C * bit_width / 8; + char* dst_ptr = (char*)dst + (oh * OW + ow) * C * bit_width / 8; + + for (int c = 0; c < C; c += pack_c) { + KernCoreNHWC::func( + dst_ptr, src_ptr0, src_ptr1, src_ptr2, src_ptr3, + c * bit_width / 8, w00, w01, w10, w11, output_converter, + true, true, true, true, (ctype)0); } } } -template -__global__ void kern_const_border_nhwc(SrcVisitor src, - const float* __restrict mat, - ctype* __restrict dst, int C, int IH, - int IW, int OH, int OW, ctype bval) { +template +__global__ void kern_general_nhwc_const(SrcVisitor src, const float* __restrict mat, + ctype* __restrict dst, int C, int IH, int IW, + int OH, int OW, ctype bval) { + Getter getter; OutputConverter output_converter; + constexpr int bit_width = CtypeHelper::bit_width; int ow = blockIdx.x * blockDim.x + threadIdx.x; int oh = blockIdx.y * blockDim.y + threadIdx.y; const ctype* __restrict sptr = src.get(blockIdx.z, C * IH * IW); - dst += blockIdx.z * C * OH * OW; + dst = (ctype*)((char*)dst + blockIdx.z * C * OH * OW * bit_width / 8); mat += blockIdx.z * 3 * 3; if (ow < OW && oh < OH) { float denominator = mat[6] * ow + mat[7] * oh + mat[8]; float iw = (mat[0] * ow + mat[1] * oh + mat[2]) / denominator; float ih = (mat[3] * ow + mat[4] * oh + mat[5]) / denominator; - int iw0 = floor(iw) + 0; - int iw1 = floor(iw) + 1; - int ih0 = floor(ih) + 0; - int ih1 = floor(ih) + 1; - bool okw0 = (iw0 >= 0 && iw0 < IW); - bool okw1 = (iw1 >= 0 && iw1 < IW); - bool okh0 = (ih0 >= 0 && ih0 < IH); - bool okh1 = (ih1 >= 0 && ih1 < IH); + int iw0 = getter(floor(iw) + 0, IW); + int iw1 = getter(floor(iw) + 1, IW); + int ih0 = getter(floor(ih) + 0, IH); + int ih1 = getter(floor(ih) + 1, IH); float palpha = ih - floor(ih); float pbeta = iw - floor(iw); float nalpha = 1.0f - palpha; float nbeta = 1.0f - pbeta; - for (int c = 0; c < C; ++c) { - ctype v00 = (okh0 && okw0 ? sptr[(ih0 * IW + iw0) * C + c] : bval); - ctype v01 = (okh0 && okw1 ? sptr[(ih0 * IW + iw1) * C + c] : bval); - ctype v10 = (okh1 && okw0 ? sptr[(ih1 * IW + iw0) * C + c] : bval); - ctype v11 = (okh1 && okw1 ? sptr[(ih1 * IW + iw1) * C + c] : bval); - ctype val = output_converter( - v00 * nalpha * nbeta + v01 * nalpha * pbeta + - v10 * palpha * nbeta + v11 * palpha * pbeta); - dst[(oh * OW + ow) * C + c] = val; + float w00 = nalpha * nbeta; + float w01 = nalpha * pbeta; + float w10 = palpha * nbeta; + float w11 = palpha * pbeta; + const char* src_ptr0 = (char*)sptr + (ih0 * IW + iw0) * C * bit_width / 8; + const char* src_ptr1 = (char*)sptr + (ih0 * IW + iw1) * C * bit_width / 8; + const char* src_ptr2 = (char*)sptr + (ih1 * IW + iw0) * C * bit_width / 8; + const char* src_ptr3 = (char*)sptr + (ih1 * IW + iw1) * C * bit_width / 8; + char* dst_ptr = (char*)dst + (oh * OW + ow) * C * bit_width / 8; + bool okw0 = (iw0 >= 0 && iw0 < IW); + bool okw1 = (iw1 >= 0 && iw1 < IW); + bool okh0 = (ih0 >= 0 && ih0 < IH); + bool okh1 = (ih1 >= 0 && ih1 < IH); + bool src0_ok = okh0 && okw0; + bool src1_ok = okh0 && okw1; + bool src2_ok = okh1 && okw0; + bool src3_ok = okh1 && okw1; + for (int c = 0; c < C; c += pack_c) { + KernCoreNHWC::func( + dst_ptr, src_ptr0, src_ptr1, src_ptr2, src_ptr3, c * bit_width / 8, w00, w01, w10, w11, output_converter, src0_ok, src1_ok, + src2_ok, src3_ok, bval); } } } @@ -570,12 +722,13 @@ void dispatch_with_visitor(bool is_nhwc, SrcVisitor src, const float* mat, ctype* dst, int N, int C, int IH, int IW, int OH, int OW, ctype bval, BorderMode bmode, cudaStream_t stream) { + constexpr int pack_c = 1; const int BY = 16, BX = 32; #define DISPATCH(Getter) \ do { \ if (is_nhwc) { \ kern_general_nhwc> \ + rounding::RoundingConverter, pack_c> \ <<>>(src, mat, dst, C, IH, IW, \ OH, OW); \ } else { \ @@ -608,10 +761,10 @@ void dispatch_with_visitor(bool is_nhwc, SrcVisitor src, const float* mat, #undef DISPATCH case BORDER_CONSTANT: if (is_nhwc) { - kern_const_border_nhwc> - <<>>( - src, mat, dst, C, IH, IW, OH, OW, bval); + kern_general_nhwc_const, + pack_c><<>>( + src, mat, dst, C, IH, IW, OH, OW, bval); } else { kern_const_border> @@ -630,6 +783,59 @@ void dispatch_with_visitor(bool is_nhwc, SrcVisitor src, const float* mat, } } +template +void dispatch_with_visitor_nhwc_bit4(SrcVisitor src, const float* mat, + ctype* dst, int N, int C, int IH, int IW, + int OH, int OW, ctype bval, + BorderMode bmode, cudaStream_t stream) { + const int BY = 16, BX = 32; +#define DISPATCH(Getter) \ + do { \ + kern_general_nhwc, pack_c> \ + <<>>(src, mat, dst, C, IH, IW, OH, \ + OW); \ + } while (0) + + const int max_batch_size = 65535; + while (N) { + size_t curr_batch_size = N < max_batch_size ? N : max_batch_size; + dim3 threads(BX, BY); + dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size); + + switch (bmode) { + case BORDER_REPLICATE: + DISPATCH(ReplicateGetter); + break; + case BORDER_REFLECT: + DISPATCH(ReflectGetter); + break; + case BORDER_REFLECT_101: + DISPATCH(Reflect101Getter); + break; + case BORDER_WRAP: + DISPATCH(WrapGetter); + break; + case BORDER_CONSTANT: + { + kern_general_nhwc_const, pack_c> + <<>>(src, mat, dst, C, IH, IW, OH, + OW, bval); + } + break; + default: + break; + } +#undef DISPATCH + + N -= curr_batch_size; + src.move_batch(curr_batch_size, C * IH * IW / 2); + mat += curr_batch_size * 3 * 3; + dst += curr_batch_size * C * OH * OW / 2; + } +} + template void dispatch_with_visitor_nchw4(SrcVisitor src, const float* mat, ctype* dst, int N, int C, int IH, int IW, int OH, int OW, @@ -1440,6 +1646,34 @@ void forward_proxy(bool is_nhwc, const ctype* src, const float* mat, after_kernel_launch(); } +template +void forward_proxy_nhwc_bit4(const ctype* src, const float* mat, + const int* mat_idx, ctype* dst, int N_SRC, + int N_MAT, int C, int IH, int IW, int OH, int OW, + ctype bval, BorderMode bmode, + megcore::AsyncErrorInfo* error_info, + void* error_tracker, cudaStream_t stream) { + if (mat_idx) { + IndexedSrcVisitor visitor; + visitor.ptr = src; + visitor.idx = mat_idx; + visitor.N_SRC = N_SRC; + visitor.error_info = error_info; + visitor.error_tracker = error_tracker; + dispatch_with_visitor_nhwc_bit4, + pack_c>(visitor, mat, dst, N_MAT, C, IH, + IW, OH, OW, bval, bmode, + stream); + } else { + DirectSrcVisitor visitor; + visitor.ptr = src; + dispatch_with_visitor_nhwc_bit4, pack_c>( + visitor, mat, dst, N_MAT, C, IH, IW, OH, OW, bval, bmode, + stream); + } + after_kernel_launch(); +} + template void forward_proxy_nchw4(const ctype* src, const float* mat, const int* mat_idx, ctype* dst, int N_SRC, int N_MAT, int C, int IH, @@ -1520,6 +1754,18 @@ INST(dt_qint4) INST(dt_quint4) #undef INST +#define INST(ctype, pack_c) \ + template void forward_proxy_nhwc_bit4( \ + const ctype*, const float*, const int*, ctype*, int, int, int, \ + int, int, int, int, ctype, BorderMode, megcore::AsyncErrorInfo*, \ + void*, cudaStream_t); + +INST(dt_qint4, 8) +INST(dt_quint4, 8) +INST(dt_qint4, 16) +INST(dt_quint4, 16) +#undef INST + template void forward_proxy_quint8_dimshuffle_typecvt_nchw4( bool is_nhwc, const src_ctype* src, const float* mat, diff --git a/dnn/src/naive/warp_perspective/opr_impl.cpp b/dnn/src/naive/warp_perspective/opr_impl.cpp index b3647e16..44ba8e55 100644 --- a/dnn/src/naive/warp_perspective/opr_impl.cpp +++ b/dnn/src/naive/warp_perspective/opr_impl.cpp @@ -257,7 +257,7 @@ void WarpPerspectiveForwardImpl::kern_naive_int4( MIDOUT_BEGIN(megdnn_naive_warpperspective, ctype, mtype, midout_iv(0)) { UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM(kern_param); MEGDNN_MARK_USED_VAR(N_MAT); - uint8_t c_shift, c_mask, iw_shift = 0, ow_shift = 0; + uint32_t c_shift, c_mask, iw_shift = 0, ow_shift = 0; constexpr bool signedness = std::is_same::value; switch (param().format) { case Format::NCHW: @@ -270,19 +270,29 @@ void WarpPerspectiveForwardImpl::kern_naive_int4( c_shift = 6; c_mask = 0x3F; break; + case Format::NHWC: + megdnn_assert(C % 2 == 0); + c_shift = 0; + c_mask = 0; + break; default: megdnn_throw("bad format"); break; } //! strides of C, H, W on src and dst - size_t sstrd[2] = {IH * (IW + iw_shift), IW + iw_shift}, - dstrd[2] = {OH * (OW + ow_shift), OW + ow_shift}; + std::vector sstrd = {IH * ((IW + iw_shift) << c_shift), + (IW + iw_shift) << c_shift, 1}; + std::vector dstrd = {OH * ((OW + ow_shift) << c_shift), + (OW + ow_shift) << c_shift, 1}; + if (param().format == Format::NHWC) { + sstrd = {1, IW * C, C}; + dstrd = {1, OW * C, C}; + } static constexpr uint8_t mask = (uint8_t)((1 << 4) - 1); auto visit_src = [&sptr, sstrd, c_shift, c_mask](size_t c, int h, int w) -> float { - size_t index = ((sstrd[0] * (c >> c_shift) + sstrd[1] * h + w) - << c_shift) + - (c & c_mask); + size_t index = (c >> c_shift) * sstrd[0] + h * sstrd[1] + + (w << c_shift) * sstrd[2] + (c & c_mask); uint8_t result = (sptr[index / 2].as_storage() >> (4 * (index % 2))) & 0xF; if (signedness) { @@ -295,9 +305,8 @@ void WarpPerspectiveForwardImpl::kern_naive_int4( auto visit_src_bd = [&sptr, sstrd, border_val, c_shift, c_mask]( size_t c, int h, int w) -> float { if (h != -1 && w != -1) { - size_t index = ((sstrd[0] * (c >> c_shift) + sstrd[1] * h + w) - << c_shift) + - (c & c_mask); + size_t index = (c >> c_shift) * sstrd[0] + h * sstrd[1] + + (w << c_shift) * sstrd[2] + (c & c_mask); uint8_t result = (sptr[index / 2].as_storage() >> (4 * (index % 2))) & 0xF; @@ -312,9 +321,8 @@ void WarpPerspectiveForwardImpl::kern_naive_int4( }; auto set_visit_dst = [&dptr, dstrd, c_shift, c_mask](size_t c, int h, int w, ctype v) { - size_t index = ((dstrd[0] * (c >> c_shift) + dstrd[1] * h + w) - << c_shift) + - (c & c_mask); + size_t index = (c >> c_shift) * dstrd[0] + h * dstrd[1] + + (w << c_shift) * dstrd[2] + (c & c_mask); dptr[index / 2] = (dptr[index / 2].as_storage() & (0xF0 >> (4 * (index % 2)))) | (v.as_storage() << (4 * (index % 2))); diff --git a/dnn/test/cuda/warp_perspective.cpp b/dnn/test/cuda/warp_perspective.cpp index 3f5bed4f..f484f0f9 100644 --- a/dnn/test/cuda/warp_perspective.cpp +++ b/dnn/test/cuda/warp_perspective.cpp @@ -176,10 +176,12 @@ TEST_F(CUDA, WARP_PERSPECTIVE_FORWARD) { Checker checker(handle_cuda()); WarpPerspectiveMatRNG rng; checker.set_rng(1, &rng); - for (auto bmode : {WarpPerspective::BorderMode::WRAP, + for (auto bmode : { + WarpPerspective::BorderMode::WRAP, WarpPerspective::BorderMode::REFLECT, WarpPerspective::BorderMode::REPLICATE, - WarpPerspective::BorderMode::CONSTANT}) { + WarpPerspective::BorderMode::CONSTANT + }) { WarpPerspective::Param param; param.border_val = 0.3f; param.bmode = bmode; @@ -215,6 +217,84 @@ TEST_F(CUDA, WARP_PERSPECTIVE_FORWARD) { } } +TEST_F(CUDA, WARP_PERSPECTIVE_FORWARD_NHWC) { + using Param = WarpPerspective::Param; + Checker checker(handle_cuda()); + WarpPerspectiveMatRNG_V2 rng; + checker.set_dtype(0, dtype::QuantizedS4(0.1f)); + checker.set_dtype(2, dtype::QuantizedS4(0.1f)); + checker.set_rng(1, &rng); + for (auto bmode : {WarpPerspective::BorderMode::WRAP, + WarpPerspective::BorderMode::REFLECT, + WarpPerspective::BorderMode::REPLICATE, + WarpPerspective::BorderMode::CONSTANT}) { + WarpPerspective::Param param; + param.border_val = 1.2f; + param.bmode = bmode; + param.imode = Param::InterpolationMode::LINEAR; + + param.format = Param::Format::NHWC; + checker.set_param(param); + checker.set_epsilon(1 + 1e-3); + rng.set_hw(10, 11); + checker.execs({{23, 10, 11, 16}, {23, 3, 3}, {23, 11, 12, 16}}); + checker.execs({{20, 10, 11, 32}, {20, 3, 3}, {20, 11, 12, 32}}); + checker.execs({{20, 10, 11, 32}, {20, 3, 3}, {20, 11, 12, 32}}); + rng.set_hw(55, 66); + checker.execs({{20, 55, 66, 32}, {20, 3, 3}, {20, 44, 34, 32}}); + } + { + checker.set_dtype(0, dtype::Quantized4Asymm(0.1f, 3)); + checker.set_dtype(2, dtype::Quantized4Asymm(0.1f, 3)); + checker.set_rng(1, &rng); + for (auto bmode : {WarpPerspective::BorderMode::WRAP, + WarpPerspective::BorderMode::REFLECT, + WarpPerspective::BorderMode::REPLICATE, + WarpPerspective::BorderMode::CONSTANT}) { + WarpPerspective::Param param; + param.border_val = 0.3f; + param.bmode = bmode; + param.imode = Param::InterpolationMode::LINEAR; + + param.format = Param::Format::NHWC; + checker.set_param(param); + checker.set_epsilon(1 + 1e-3); + rng.set_hw(10, 11); + checker.execs({{23, 10, 11, 16}, {23, 3, 3}, {23, 11, 12, 16}}); + checker.execs({{20, 10, 11, 32}, {20, 3, 3}, {20, 11, 12, 32}}); + checker.execs({{20, 10, 11, 32}, {20, 3, 3}, {20, 11, 12, 32}}); + rng.set_hw(55, 66); + checker.execs({{20, 55, 66, 32}, {20, 3, 3}, {20, 44, 34, 32}}); + } + } + { + Checker checker( + handle_cuda()); + constexpr int N_SRC = 5; + UniformIntRNG mat_idx_rng{0, N_SRC - 1}; + checker.set_dtype(0, dtype::QuantizedS4(0.1f)); + checker.set_rng(1, &rng); + checker.set_dtype(2, dtype::Int32()); + checker.set_rng(2, &mat_idx_rng); + checker.set_dtype(3, dtype::QuantizedS4(0.1f)); + WarpPerspective::Param param; + param.border_val = 0.3f; + param.format = Param::Format::NHWC; + param.bmode = WarpPerspective::Param::BorderMode::REFLECT; + param.imode = param::WarpPerspective::InterpolationMode::LINEAR; + checker.set_param(param); + checker.set_epsilon(1 + 1e-3); + rng.set_hw(10, 11); + checker.set_rng(1, &rng); + checker.execs({{N_SRC, 10, 11, 48}, {2, 3, 3}, {2}, {2, 11, 12, 48}}); + rng.set_hw(17, 13); + checker.set_rng(1, &rng); + checker.execs( + {{N_SRC, 17, 13, 64}, {123, 3, 3}, {123}, {123, 16, 15, 64}}); + } +} + + TEST_F(CUDA, WARP_PERSPECTIVE_FORWARD_INTMAX) { require_compute_capability(6, 0); using Param = WarpPerspective::Param; @@ -895,6 +975,14 @@ TEST_F(CUDA, BENCHMARK_WARP_PERSPECTIVE_NCHW4) { run({TensorShape{1, 25, 256, 5120, 4}, {1, 3, 3}, {1, 25, 256, 256, 4}}); run({TensorShape{1, 25, 256, 256, 4}, {1, 3, 3}, {1, 25, 512, 512, 4}}); run({TensorShape{1, 25, 512, 512, 4}, {1, 3, 3}, {1, 25, 256, 256, 4}}); + + param.format = Param::Format::NHWC; + benchmarker.set_dtype(0, dtype::QuantizedS4(1.f)); + benchmarker.set_dtype(2, dtype::QuantizedS4(1.f)); + run({TensorShape{1, 256, 256, 4 * 24}, {1, 3, 3}, {1, 256, 5120, 4 * 24}}); + run({TensorShape{1, 256, 5120, 4 * 24}, {1, 3, 3}, {1, 256, 256, 4 * 24}}); + run({TensorShape{1, 256, 256, 4 * 24}, {1, 3, 3}, {1, 512, 512, 4 * 24}}); + run({TensorShape{1, 512, 512, 4 * 24}, {1, 3, 3}, {1, 256, 256, 4 * 24}}); } #endif diff --git a/dnn/test/naive/warp_perspective.cpp b/dnn/test/naive/warp_perspective.cpp index 102b1e1b..8f13ca81 100644 --- a/dnn/test/naive/warp_perspective.cpp +++ b/dnn/test/naive/warp_perspective.cpp @@ -642,12 +642,114 @@ TEST_F(NAIVE, WARP_PERSPECTIVE_NCHW64) { param.format = Param::Format::NCHW64; checker.set_param(param); checker.execs({{2, 1, 10, 10, 64}, {2, 3, 3}, {2, 1, 10, 12, 64}}); - checker.execs( - {{20, 3, 10, 12, 64}, {20, 3, 3}, {20, 3, 11, 12, 64}}); + checker.execs({{20, 3, 10, 12, 64}, {20, 3, 3}, {20, 3, 11, 12, 64}}); checker.execs({{1, 3, 25, 24, 64}, {1, 3, 3}, {1, 3, 25, 51, 64}}); checker.execs({{1, 3, 25, 51, 64}, {1, 3, 3}, {1, 3, 25, 24, 64}}); checker.execs({{1, 3, 25, 24, 64}, {1, 3, 3}, {1, 3, 51, 50, 64}}); checker.execs({{1, 3, 51, 50, 64}, {1, 3, 3}, {1, 3, 25, 24, 64}}); } } + +TEST_F(NAIVE, WARP_PERSPECTIVE_NHWC) { + using Param = WarpPerspective::Param; + + auto convert_true_format = [](const TensorLayout& layout) { + if (layout.ndim == 4) { + TensorLayout ret{{layout[0], layout[2], layout[3], layout[1]}, + layout.dtype}; + return ret.dimshuffle({0, 3, 1, 2}); + } else + return layout; + }; + + WarpPerspective::Param param; + auto extra_impl = [¶m, this, + convert_true_format](const TensorNDArray& tensors) { + auto warp_perspective = handle()->create_operator(); + warp_perspective->param() = param; + warp_perspective->param().format = Param::Format::NCHW; + + TensorNDArray nchw_tensors; + for (size_t i = 0; i < tensors.size(); ++i) { + TensorLayout ly; + auto layout = tensors[i].layout; + if (layout.ndim == 4) { + ly = TensorLayout{{layout[0], layout[3], layout[1], layout[2]}, + layout.dtype}; + } else { + ly = layout; + } + nchw_tensors.emplace_back(malloc(ly.span().dist_byte()), ly); + } + TensorNDArray nhwc_tensors; + for (size_t i = 0; i < tensors.size(); ++i) { + auto layout = convert_true_format(nchw_tensors[i].layout); + nhwc_tensors.emplace_back(tensors[i].raw_ptr, std::move(layout)); + } + + auto workspace_size = warp_perspective->get_workspace_in_bytes( + tensors[0].layout, tensors[1].layout, tensors[2].layout); + dt_byte* workspace_ptr = static_cast(malloc(workspace_size)); + Workspace workspace{workspace_ptr, workspace_size}; + + auto relayout = handle()->create_operator(); + relayout->exec(nhwc_tensors[0], nchw_tensors[0]); + relayout->exec(nhwc_tensors[1], nchw_tensors[1]); + + warp_perspective->exec(nchw_tensors[0], nchw_tensors[1], + nchw_tensors[2], workspace); + + relayout->exec(nchw_tensors[2], nhwc_tensors[2]); + free(workspace_ptr); + for (auto&& tensor : nchw_tensors) { + free(tensor.raw_ptr); + } + }; + + { + Checker checker(handle()); + WarpPerspectiveMatRNG rng; + checker.set_rng(1, &rng); + checker.set_dtype(0, dtype::QuantizedS4(0.1f)); + checker.set_dtype(2, dtype::QuantizedS4(0.1f)); + checker.set_extra_opr_impl(extra_impl); + for (auto bmode : {WarpPerspective::BorderMode::WRAP, + WarpPerspective::BorderMode::REFLECT, + WarpPerspective::BorderMode::REPLICATE, + WarpPerspective::BorderMode::CONSTANT}) { + param.border_val = 0.3f; + param.bmode = bmode; + param.imode = Param::InterpolationMode::LINEAR; + + param.format = Param::Format::NHWC; + checker.set_param(param); + checker.execs({{1, 2, 2, 4}, {1, 3, 3}, {1, 2, 2, 4}}); + checker.execs({{2, 10, 10, 4}, {2, 3, 3}, {2, 10, 12, 4}}); + checker.execs({{3, 25, 24, 8}, {3, 3, 3}, {3, 12, 10, 8}}); + checker.execs({{4, 33, 22, 16}, {4, 3, 3}, {4, 9, 12, 16}}); + } + } + { + Checker checker(handle()); + WarpPerspectiveMatRNG rng; + checker.set_rng(1, &rng); + checker.set_dtype(0, dtype::Quantized4Asymm(0.1f, 3)); + checker.set_dtype(2, dtype::Quantized4Asymm(0.1f, 3)); + checker.set_extra_opr_impl(extra_impl); + for (auto bmode : {WarpPerspective::BorderMode::WRAP, + WarpPerspective::BorderMode::REFLECT, + WarpPerspective::BorderMode::REPLICATE, + WarpPerspective::BorderMode::CONSTANT}) { + param.border_val = 0.3f; + param.bmode = bmode; + param.imode = Param::InterpolationMode::LINEAR; + param.format = Param::Format::NHWC; + checker.set_param(param); + checker.execs({{1, 2, 2, 4}, {1, 3, 3}, {1, 2, 2, 4}}); + checker.execs({{2, 10, 10, 4}, {2, 3, 3}, {2, 10, 12, 4}}); + checker.execs({{3, 25, 24, 8}, {3, 3, 3}, {3, 12, 10, 8}}); + checker.execs({{4, 33, 22, 16}, {4, 3, 3}, {4, 9, 12, 16}}); + } + } +} // vim: syntax=cpp.doxygen