diff --git a/dnn/include/megdnn/tensor_format.h b/dnn/include/megdnn/tensor_format.h index c7ae4b7e..87f7065b 100644 --- a/dnn/include/megdnn/tensor_format.h +++ b/dnn/include/megdnn/tensor_format.h @@ -196,6 +196,32 @@ public: const TensorLayout& layout) const override; }; using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>; + +///*! +// * \brief used for tensors with lowbit data type +// * +// * \p SIZE_NBITS is the size in bits of element of the tensor. +// * +// */ +//template +//class LowbitTensorFormat : public TensorFormat::ImplBase { +// static constexpr size_t SIZE_NBITS = SIZE_NBITS_; +// size_t m_align_size_in_bits; +// +//protected: //? +// LowbitTensorFormat(Type type, size_t m_align_size_in_bits); +// +//public: +// size_t align_size_in_bits() const { +// return m_align_size_in_bits; +// } +// +// std::string to_string() const override; +// +// void serialize_append( +// +// +//}; } // namespace detail /*! diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index f0c5068a..dde58163 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -895,6 +895,7 @@ Relayout mode. * ``NCHW4`` layout: ``{N, C/4, H, W, 4}`` * ``NCHW88`` layout: ``{N, C/8, H, W, 8}`` * ``CHWN4`` layout: ``{C/4, H, W, N, 4}`` +* ``NCHW64`` layout: ``{N, C/64, H, W, 64}`` **Float weight transformation definitions** @@ -969,6 +970,8 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o 'NCHW_NCHW4', 'NCHW4_NCHW', 'NCHW_NCHW4_WEIGHT', + 'NCHW_NCHW64', + 'NCHW64_NCHW', ) ) diff --git a/dnn/src/common/relayout_format.cpp b/dnn/src/common/relayout_format.cpp index 0a63fe0e..d14f17ea 100644 --- a/dnn/src/common/relayout_format.cpp +++ b/dnn/src/common/relayout_format.cpp @@ -251,6 +251,23 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src, dst[3] = src[3]; megdnn_assert(dst[1] % param().group == 0); break; + case Param::Mode::NCHW_NCHW64: + megdnn_assert(src.ndim == 4 && (src[1] % 64) == 0); + dst.ndim = 5; + dst[0] = src[0]; + dst[1] = src[1] / 64; + dst[2] = src[2]; + dst[3] = src[3]; + dst[4] = 64; + break; + case Param::Mode::NCHW64_NCHW: + megdnn_assert(src.ndim == 5); + dst.ndim = 4; + dst[0] = src[0]; + dst[1] = src[1] * 64; + dst[2] = src[2]; + dst[3] = src[3]; + break; default: megdnn_assert(0, "Invalid RelayoutFormat Mode"); break; @@ -352,7 +369,12 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { CHECK_SRC(DefaultTensorFormat::make()); dst = src; break; - + case Param::Mode::NCHW_NCHW64: + dst = src; + break; + case Param::Mode::NCHW64_NCHW: + dst = src; + break; default: megdnn_throw("Invalid relayout format mode"); break; @@ -633,6 +655,19 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, exec_src = src.dimshuffle({3, 0, 1, 2, 4}); exec_dst = dst; break; + case Param::Mode::NCHW_NCHW64: + // src is {N, C, H, W} + // dst is {N, C/64, H, W, 64} + exec_src = src.reshape({src[0], src[1] / 64, 64, src[2], src[3]}) + .dimshuffle({0, 1, 3, 4, 2}); + exec_dst = dst; + break; + case Param::Mode::NCHW64_NCHW: + // src is {N, C/64, H, W, 64} + // dst is {N, C, H, W} + exec_src = src.dimshuffle({0, 1, 4, 2, 3}); + exec_dst = dst; + break; default: megdnn_assert(0, "Invalid RelayoutFormat Mode"); } diff --git a/dnn/src/cuda/conv_bias/fallback_nchw_qs4.cpp b/dnn/src/cuda/conv_bias/fallback_nchw_qs4.cpp index 07e08f5f..bbe47238 100644 --- a/dnn/src/cuda/conv_bias/fallback_nchw_qs4.cpp +++ b/dnn/src/cuda/conv_bias/fallback_nchw_qs4.cpp @@ -69,12 +69,9 @@ size_t 
ConvBiasForwardImpl::AlgoFallbackNCHWQS4::get_workspace_in_bytes( void ConvBiasForwardImpl::AlgoFallbackNCHWQS4::exec( const ExecArgs& args) const { - using Format = Param::Format; - auto&& param = args.opr->param(); - auto&& fm = args.filter_meta; auto layouts = make_underlying_tensor_layout( - *(args.src_layout), fm, *(args.bias_layout), *(args.z_layout), - *(args.dst_layout)); + *(args.src_layout), *(args.filter_layout), *(args.bias_layout), + *(args.z_layout), *(args.dst_layout)); auto ws = get_workspace_bundle(args.workspace.raw_ptr, args); auto ws_src = ws.get(0); auto ws_filter = ws.get(1); @@ -82,20 +79,27 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS4::exec( void* ws_z = nullptr; if (args.z_layout->ndim > 0) ws_z = ws.get(4); - auto&& stream = cuda_stream(args.opr->handle()); - auto nchw2nchw64 = [](const TensorND& src, void* raw_dptr) { - if (raw_dptr == nullptr) + // auto&& stream = cuda_stream(args.opr->handle()); + auto nchw2nchw64 = [&args](const TensorND& src, TensorND&& dst) { + if (dst.raw_ptr == nullptr) return; + auto relayout = args.handle->create_operator(); + relayout->param() = RelayoutFormat::Param::Mode::NCHW_NCHW64; + Workspace dummy; + relayout->exec(src, dst, dummy); }; - auto nchw642nchw = [](const TensorND& src, void* raw_dptr) { - + auto nchw642nchw = [&args](const TensorND& src, TensorND&& dst) { + auto relayout = args.handle->create_operator(); + relayout->param() = RelayoutFormat::Param::Mode::NCHW64_NCHW; + Workspace dummy; + relayout->exec(src, dst, dummy); }; // reformat src - nchw2nchw64(*(args.src_tensor), ws_src); + nchw2nchw64(*(args.src_tensor), {ws_src, layouts[0]}); // reformat filter - nchw2nchw64(*(args.filter_tensor), ws_filter); + nchw2nchw64(*(args.filter_tensor), {ws_filter, layouts[1]}); // reformat z - nchw2nchw64(*(args.z_tensor), ws_z); + nchw2nchw64(*(args.z_tensor), {ws_z, layouts[3]}); TensorND src_{ws_src, layouts[0]}, filter_{ws_filter, layouts[1]}, bias_{args.bias_tensor->raw_ptr, layouts[2]}, z_{ws_z, layouts[3]}, dst_{ws_dst, layouts[4]}; @@ -109,22 +113,22 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS4::exec( args.preprocessed_filter}; m_underlying_algo.exec(args); // reformat dst - nchw642nchw(dst_, args.dst_tensor->raw_ptr); + nchw642nchw(dst_, {args.dst_tensor->raw_ptr, args.dst_tensor->layout}); } SmallVector ConvBiasForwardImpl::AlgoFallbackNCHWQS4::make_underlying_tensor_layout( - const TensorLayout& src, const CanonizedFilterMeta& filter_meta, + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst) const { size_t n = src[0], ci = src[1], hi = src[2], wi = src[3]; size_t co = dst[1], ho = dst[2], wo = dst[3]; - size_t fh = filter_meta.spatial[0], fw = filter_meta.spatial[1]; + size_t fh = filter[2], fw = filter[3]; SmallVector rst; rst.emplace_back(TensorLayout{{n, ci / 64, hi, wi, 64}, src.dtype}); rst.emplace_back(TensorLayout{{co, ci / 64, fh, fw, 64}, filter.dtype}); rst.emplace_back(TensorLayout{{1, co / 64, 1, 1, 64}, bias.dtype}); - if (z.layout.ndim > 0) { + if (z.ndim > 0) { rst.emplace_back(TensorLayout{{n, co / 64, ho, wo, 64}, z.dtype}); } else { rst.emplace_back(TensorLayout{}); @@ -134,15 +138,13 @@ ConvBiasForwardImpl::AlgoFallbackNCHWQS4::make_underlying_tensor_layout( } WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS4::get_workspace_bundle( - void* ptr, const SizeArgs& args) const { + void* raw_ptr, const SizeArgs& args) const { size_t ws_size_src = args.src_layout->span().dist_byte(); size_t ws_size_filter = 
args.filter_layout->span().dist_byte(); size_t ws_size_dst = args.dst_layout->span().dist_byte(); - auto&& param = args.opr->param(); - auto&& fm = args.filter_meta; auto layouts = make_underlying_tensor_layout( - *(args.src_layout), fm, *(args.bias_layout), *(args.z_layout), - *(args.dst_layout)); + *(args.src_layout), *(args.filter_layout), *(args.bias_layout), + *(args.z_layout), *(args.dst_layout)); SizeArgs args_{args.opr, layouts[0], layouts[1], diff --git a/dnn/src/cuda/relayout_format/opr_impl.cpp b/dnn/src/cuda/relayout_format/opr_impl.cpp index 61cd5896..f7d1d9a2 100644 --- a/dnn/src/cuda/relayout_format/opr_impl.cpp +++ b/dnn/src/cuda/relayout_format/opr_impl.cpp @@ -78,29 +78,33 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, return handle()->create_operator()->exec( {src.raw_ptr, exec_src_layout}, {dst.raw_ptr, exec_dst_layout}); } - - if (param().mode == Param::Mode::NCHW_NCHW4 || - param().mode == Param::Mode::NCHW4_NCHW || - param().mode == Param::Mode::NCHW_NCHW4_WEIGHT) { + bool is_trans_4bits = (param().mode == Param::Mode::NCHW_NCHW64 || + param().mode == Param::Mode::NCHW64_NCHW) && + (src_dtype.enumv() == DTypeEnum::QuantizedS4 || + src_dtype.enumv() == DTypeEnum::Quantized4Asymm); + bool is_nchw_nchw4 = param().mode == Param::Mode::NCHW_NCHW4 || + param().mode == Param::Mode::NCHW4_NCHW || + param().mode == Param::Mode::NCHW_NCHW4_WEIGHT; + if (is_trans_4bits || is_nchw_nchw4) { bool is_usable = relayout_format::RelayoutFormatFast::usable( src.layout, dst.layout); megdnn_assert(is_usable, - "RelayoutFormatNCHW_NCHW4 kernel not usable for %s(%s) " - "to %s(%s)", + "RelayoutFormatFast kernel is not usable for " + "transforming %s(%s) to %s(%s).", src.layout.to_string().c_str(), src.layout.dtype.name(), dst.layout.to_string().c_str(), dst.layout.dtype.name()); - relayout_format::RelayoutFormatFast::exec(src, dst, - cuda_stream(this->handle()), - param().mode, param().group); - } else { - TensorLayout exec_src, exec_dst, exec_workspace; - deduce_exec_layout(src.layout, dst.layout, exec_workspace, exec_src, - exec_dst); - TensorND exec_src_nd{src.raw_ptr, exec_src}; - TensorND exec_dst_nd{dst.raw_ptr, exec_dst}; - handle()->create_operator()->exec(exec_src_nd, - exec_dst_nd); + return relayout_format::RelayoutFormatFast::exec( + src, dst, cuda_stream(this->handle()), param().mode, + param().group); } + // fallback impls + TensorLayout exec_src, exec_dst, exec_workspace; + deduce_exec_layout(src.layout, dst.layout, exec_workspace, exec_src, + exec_dst); + TensorND exec_src_nd{src.raw_ptr, exec_src}; + TensorND exec_dst_nd{dst.raw_ptr, exec_dst}; + handle()->create_operator()->exec(exec_src_nd, + exec_dst_nd); } size_t RelayoutFormatImpl::get_workspace_in_bytes( diff --git a/dnn/src/cuda/relayout_format/relayout_format.cpp b/dnn/src/cuda/relayout_format/relayout_format.cpp index be8612d9..456df8b1 100644 --- a/dnn/src/cuda/relayout_format/relayout_format.cpp +++ b/dnn/src/cuda/relayout_format/relayout_format.cpp @@ -24,6 +24,8 @@ inline void get_scale_zeropoint(const DType& tensor_dtype, float& scale, scale = tensor_dtype.param().scale; } else if (tensor_dtype.enumv() == DTypeEnum::QuantizedS8) { scale = tensor_dtype.param().scale; + } else if (tensor_dtype.enumv() == DTypeEnum::QuantizedS4) { + scale = tensor_dtype.param().scale; } } @@ -39,9 +41,8 @@ void relayout_format::RelayoutFormatFast::exec(const TensorND& src, cudaStream_t stream, RelayoutFormat::Param::Mode mode, int group) { - size_t ih = src.layout[2]; - size_t iw = src.layout[3]; 
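For reference, the logical transform behind the NCHW_NCHW64 mode handled in this file is dst(n, c/64, h, w, c%64) = src(n, c, h, w). A minimal host-side sketch of that index mapping, assuming an already-unpacked int8 buffer and a channel count divisible by 64 (the function name and the unpacked representation are illustrative only; the CUDA kernels below work on packed 4-bit nibbles):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Reference semantics of NCHW -> NCHW64: dst has layout {N, C/64, H, W, 64}.
std::vector<int8_t> relayout_nchw_to_nchw64_ref(
        const std::vector<int8_t>& src, size_t n, size_t c, size_t h, size_t w) {
    std::vector<int8_t> dst(n * c * h * w);  // same element count, repacked
    const size_t hw = h * w;
    for (size_t ni = 0; ni < n; ++ni)
        for (size_t ci = 0; ci < c; ++ci)
            for (size_t hwi = 0; hwi < hw; ++hwi) {
                const size_t src_idx = (ni * c + ci) * hw + hwi;
                const size_t dst_idx =
                        ((ni * (c / 64) + ci / 64) * hw + hwi) * 64 + ci % 64;
                dst[dst_idx] = src[src_idx];
            }
    return dst;
}
```

The NCHW64_NCHW direction simply inverts this mapping.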
- size_t hw = ih * iw; + auto&& stype = src.layout.dtype; + auto&& dtype = dst.layout.dtype; float src_scale = 1.f; float dst_scale = 1.f; uint8_t src_zero_point = 0; @@ -51,22 +52,28 @@ void relayout_format::RelayoutFormatFast::exec(const TensorND& src, if (src.layout.dtype.enumv() == DTypeEnum::Uint8) { src_zero_point = 128; } - if (mode == RelayoutFormat::Param::Mode::NCHW_NCHW4) { - if (hw % 4 == 0) { - relayout_format_cuda_nchw_nchw4<4>(src, dst, stream, src_scale, + if (mode == RelayoutFormat::Param::Mode::NCHW_NCHW4 || + mode == RelayoutFormat::Param::Mode::NCHW_NCHW64) { + return relayout_format_cuda_nchw_nchwx(src, dst, stream, src_scale, dst_scale, src_zero_point, dst_zero_point, group); - } else { - relayout_format_cuda_nchw_nchw4<1>(src, dst, stream, src_scale, + } else if (mode == RelayoutFormat::Param::Mode::NCHW64_NCHW) { + megdnn_assert(group == 1, + "RelayoutFormat kernel only support transforming NCHW64 " + "to NCHW with group = 1(group:%d)", + group); + return relayout_format_cuda_nchwx_nchw(src, dst, stream, src_scale, dst_scale, src_zero_point, - dst_zero_point, group); - } - + dst_zero_point); } else if (mode == RelayoutFormat::Param::Mode::NCHW_NCHW4_WEIGHT) { - relayout_format_cuda_nchw_nchw4_weight(src, dst, stream); + return relayout_format_cuda_nchw_nchw4_weight(src, dst, stream); } else if (mode == RelayoutFormat::Param::Mode::NCHW4_NCHW) { - relayout_format_cuda_nchw4_nchw(src, dst, stream, group); + return relayout_format_cuda_nchw4_nchw(src, dst, stream, group); } else { - megdnn_throw("only support nchw_nchw4 nchw4_nchw layout_format"); + megdnn_throw( + "only support nchw_nchw64/nchw64_nchw/nchw_nchw4/nchw4_nchw " + "layout_format"); } } + +// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/cuda/relayout_format/relayout_format.cu b/dnn/src/cuda/relayout_format/relayout_format.cu index c2637c60..722aa672 100644 --- a/dnn/src/cuda/relayout_format/relayout_format.cu +++ b/dnn/src/cuda/relayout_format/relayout_format.cu @@ -10,6 +10,12 @@ * implied. 
*/ +#include +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "cutlass/arch/memory.h" +#pragma GCC diagnostic pop #include "src/cuda/query_blocksize.cuh" #include "src/cuda/relayout_format/relayout_format.cuh" using namespace megdnn; @@ -104,37 +110,121 @@ struct CudaPostProcess { inline __device__ int operator()(int val) { return val; } }; -template -struct DTypeRWHelper; template <> -struct DTypeRWHelper { +struct CudaPostProcess { + CudaDTypeParamImpl m_dst_type_cvt; + CudaDTypeParamImpl m_src_type_cvt; + CudaPostProcess(float src_scale, uint8_t, float dst_scale, uint8_t) { + m_dst_type_cvt = CudaDTypeParamImpl(dst_scale); + m_src_type_cvt = CudaDTypeParamImpl(src_scale); + } + inline __device__ int8_t operator()(int8_t val) { + float intermediate = m_src_type_cvt.dequantize(dt_qint4(val)); + return m_dst_type_cvt.quantize(intermediate).as_int8(); + } +}; + +template <> +struct CudaPostProcess { + CudaPostProcess(float, uint8_t, float, uint8_t){}; + inline __device__ int8_t operator()(int8_t val) { return val; } +}; + +template <> +struct CudaPostProcess { + CudaDTypeParamImpl m_dst_type_cvt; + CudaDTypeParamImpl m_src_type_cvt; + CudaPostProcess(float src_scale, uint8_t src_zero_point, float dst_scale, + uint8_t dst_zero_point) { + m_dst_type_cvt = + CudaDTypeParamImpl(dst_scale, dst_zero_point); + m_src_type_cvt = + CudaDTypeParamImpl(src_scale, src_zero_point); + }; + inline __device__ uint8_t operator()(uint8_t val) { + float intermediate = m_src_type_cvt.dequantize(dt_quint4(val)); + return m_dst_type_cvt.quantize(intermediate).as_uint8(); + } +}; + +template <> +struct CudaPostProcess { + uint8_t m_src_zero_point = 0; + uint8_t m_dst_zero_point = 0; + CudaPostProcess(float, uint8_t src_zero_point, float, + uint8_t dst_zero_point) { + m_src_zero_point = src_zero_point; + m_dst_zero_point = dst_zero_point; + }; + inline __device__ uint8_t operator()(uint8_t val) { + int result = val - m_src_zero_point + m_dst_zero_point; + result = result >= 0 ? result : 0; + result = result < 16 ? 
result : 15; + return static_cast(result); + } +}; + +template +struct DTypeRWHelper; +template +struct DTypeRWHelper< + ctype, 1, + typename std::enable_if::value || + std::is_same::value || + std::is_same::value>::type> { using InnerDtype = char; using DstDtype = char4; }; -template <> -struct DTypeRWHelper { +template +struct DTypeRWHelper< + ctype, 4, + typename std::enable_if::value || + std::is_same::value || + std::is_same::value>::type> { using InnerDtype = char4; using DstDtype = char4; }; template <> -struct DTypeRWHelper { +struct DTypeRWHelper { using InnerDtype = int; using DstDtype = int4; }; template <> -struct DTypeRWHelper { +struct DTypeRWHelper { using InnerDtype = int4; using DstDtype = int4; }; + +template +struct DTypeRWHelper< + ctype, 2, + typename std::enable_if::value || + std::is_same::value>::type> { + using InnerDtype = char; + using DstDtype = array_wrapper; +}; + +template +struct DTypeRWHelper< + ctype, 8, + typename std::enable_if::value || + std::is_same::value>::type> { + using InnerDtype = unsigned; + using DstDtype = array_wrapper; +}; template struct Translayout { - using InnerDtype = typename DTypeRWHelper::InnerDtype; - using DstDtype = typename DTypeRWHelper::DstDtype; + using InnerDtype = + typename DTypeRWHelper::ctype, + pack_w>::InnerDtype; + using DstDtype = + typename DTypeRWHelper::ctype, + pack_w>::DstDtype; static inline __device__ void trans(DstDtype (&dst_width)[pack_w], InnerDtype (&read_channel)[pack_c], const char zero_point); @@ -143,8 +233,12 @@ struct Translayout { template struct Translayout<1, 4, SrcType, DnnSrcType, DnnDstType, same_scale> { - using InnerDtype = typename DTypeRWHelper::InnerDtype; - using DstDtype = typename DTypeRWHelper::DstDtype; + using InnerDtype = + typename DTypeRWHelper::ctype, + 1>::InnerDtype; + using DstDtype = + typename DTypeRWHelper::ctype, + 1>::DstDtype; static inline __device__ void trans( DstDtype (&dst_width)[1], InnerDtype (&read_channel)[4], CudaPostProcess& post_process, @@ -159,8 +253,12 @@ struct Translayout<1, 4, SrcType, DnnSrcType, DnnDstType, same_scale> { template struct Translayout<4, 4, SrcType, DnnSrcType, DnnDstType, same_scale> { - using InnerDtype = typename DTypeRWHelper::InnerDtype; - using DstDtype = typename DTypeRWHelper::DstDtype; + using InnerDtype = + typename DTypeRWHelper::ctype, + 4>::InnerDtype; + using DstDtype = + typename DTypeRWHelper::ctype, + 4>::DstDtype; static inline __device__ void trans( DstDtype (&dst_width)[4], InnerDtype (&read_channel)[4], CudaPostProcess& post_process, @@ -187,6 +285,412 @@ struct Translayout<4, 4, SrcType, DnnSrcType, DnnDstType, same_scale> { } }; +#define pack_channel(_idx) \ + transform_int8_to_int4x8(post_process(intermediate[0][_idx]), \ + post_process(intermediate[1][_idx]), \ + post_process(intermediate[2][_idx]), \ + post_process(intermediate[3][_idx]), \ + post_process(intermediate[4][_idx]), \ + post_process(intermediate[5][_idx]), \ + post_process(intermediate[6][_idx]), \ + post_process(intermediate[7][_idx])); +template +struct Translayout<2, 64, SrcType, dtype::QuantizedS4, dtype::QuantizedS4, + same_scale> { + using DnnSrcType = dtype::QuantizedS4; + using DnnDstType = dtype::QuantizedS4; + using InnerDtype = + typename DTypeRWHelper::ctype, + 2>::InnerDtype; + using DstDtype = + typename DTypeRWHelper::ctype, + 2>::DstDtype; + static inline __device__ void trans( + DstDtype (&dst_width)[2], InnerDtype (&read_channel)[64], + CudaPostProcess& post_process, + const char zero_point) { + int intermediate[8][2]; + int* 
dst_frag = reinterpret_cast(dst_width); +#pragma unroll + for (int i = 0; i < 64; i += 8) { +#define unpack_int4x2(_idx) \ + intermediate[_idx][0] = unpack_integer_4bits( \ + reinterpret_cast(read_channel[i + _idx]), 0); \ + intermediate[_idx][1] = unpack_integer_4bits( \ + reinterpret_cast(read_channel[i + _idx]), 4); + // clang-format off + unpack_int4x2(0) + unpack_int4x2(1) + unpack_int4x2(2) + unpack_int4x2(3) + unpack_int4x2(4) + unpack_int4x2(5) + unpack_int4x2(6) + unpack_int4x2(7) + // clang-format on + + int frag_idx = i / 8; + dst_frag[0 * 8 + frag_idx] = pack_channel(0); + dst_frag[1 * 8 + frag_idx] = pack_channel(1); +#undef unpack_int4x2 + } + } +}; + +template +struct Translayout<8, 64, SrcType, dtype::QuantizedS4, dtype::QuantizedS4, + same_scale> { + using DnnSrcType = dtype::QuantizedS4; + using DnnDstType = dtype::QuantizedS4; + using InnerDtype = + typename DTypeRWHelper::ctype, + 8>::InnerDtype; + using DstDtype = + typename DTypeRWHelper::ctype, + 8>::DstDtype; + static inline __device__ void trans( + DstDtype (&dst_width)[8], InnerDtype (&read_channel)[64], + CudaPostProcess& post_process, + const char zero_point) { + int intermediate[8][8]; + int* dst_frag = reinterpret_cast(dst_width); +#pragma unroll + for (int i = 0; i < 64; i += 8) { + transform_int4x8_to_int8(intermediate[0], read_channel[i + 0]); + transform_int4x8_to_int8(intermediate[1], read_channel[i + 1]); + transform_int4x8_to_int8(intermediate[2], read_channel[i + 2]); + transform_int4x8_to_int8(intermediate[3], read_channel[i + 3]); + transform_int4x8_to_int8(intermediate[4], read_channel[i + 4]); + transform_int4x8_to_int8(intermediate[5], read_channel[i + 5]); + transform_int4x8_to_int8(intermediate[6], read_channel[i + 6]); + transform_int4x8_to_int8(intermediate[7], read_channel[i + 7]); + int frag_idx = i / 8; + dst_frag[0 * 8 + frag_idx] = pack_channel(0); + dst_frag[1 * 8 + frag_idx] = pack_channel(1); + dst_frag[2 * 8 + frag_idx] = pack_channel(2); + dst_frag[3 * 8 + frag_idx] = pack_channel(3); + dst_frag[4 * 8 + frag_idx] = pack_channel(4); + dst_frag[5 * 8 + frag_idx] = pack_channel(5); + dst_frag[6 * 8 + frag_idx] = pack_channel(6); + dst_frag[7 * 8 + frag_idx] = pack_channel(7); + } + } +}; +#undef pack_channel + +#define pack_channel(_idx) \ + transform_int8_to_uint4x8(post_process(intermediate[0][_idx]), \ + post_process(intermediate[1][_idx]), \ + post_process(intermediate[2][_idx]), \ + post_process(intermediate[3][_idx]), \ + post_process(intermediate[4][_idx]), \ + post_process(intermediate[5][_idx]), \ + post_process(intermediate[6][_idx]), \ + post_process(intermediate[7][_idx])); +template +struct Translayout<2, 64, SrcType, dtype::Quantized4Asymm, + dtype::Quantized4Asymm, same_scale> { + using DnnSrcType = dtype::Quantized4Asymm; + using DnnDstType = dtype::Quantized4Asymm; + using InnerDtype = + typename DTypeRWHelper::ctype, + 2>::InnerDtype; + using DstDtype = + typename DTypeRWHelper::ctype, + 2>::DstDtype; + static inline __device__ void trans( + DstDtype (&dst_width)[2], InnerDtype (&read_channel)[64], + CudaPostProcess& post_process, + const char zero_point) { + int intermediate[8][2]; + int* dst_frag = reinterpret_cast(dst_width); +#pragma unroll + for (int i = 0; i < 64; i += 8) { +#define unpack_int4x2(_idx) \ + intermediate[_idx][0] = unpack_integer_4bits( \ + reinterpret_cast(read_channel[i + _idx]), 0); \ + intermediate[_idx][1] = unpack_integer_4bits( \ + reinterpret_cast(read_channel[i + _idx]), 4); + // clang-format off + unpack_int4x2(0) + unpack_int4x2(1) + 
unpack_int4x2(2) + unpack_int4x2(3) + unpack_int4x2(4) + unpack_int4x2(5) + unpack_int4x2(6) + unpack_int4x2(7) + // clang-format on + + int frag_idx = i / 8; + dst_frag[0 * 8 + frag_idx] = pack_channel(0); + dst_frag[1 * 8 + frag_idx] = pack_channel(1); +#undef unpack_int4x2 + } + } +}; + +template +struct Translayout<8, 64, SrcType, dtype::Quantized4Asymm, + dtype::Quantized4Asymm, same_scale> { + using DnnSrcType = dtype::Quantized4Asymm; + using DnnDstType = dtype::Quantized4Asymm; + using InnerDtype = + typename DTypeRWHelper::ctype, + 8>::InnerDtype; + using DstDtype = + typename DTypeRWHelper::ctype, + 8>::DstDtype; + static inline __device__ void trans( + DstDtype (&dst_width)[8], InnerDtype (&read_channel)[64], + CudaPostProcess& post_process, + const char zero_point) { + int intermediate[8][8]; + int* dst_frag = reinterpret_cast(dst_width); +#pragma unroll + for (int i = 0; i < 64; i += 8) { + transform_uint4x8_to_int8(intermediate[0], read_channel[i + 0]); + transform_uint4x8_to_int8(intermediate[1], read_channel[i + 1]); + transform_uint4x8_to_int8(intermediate[2], read_channel[i + 2]); + transform_uint4x8_to_int8(intermediate[3], read_channel[i + 3]); + transform_uint4x8_to_int8(intermediate[4], read_channel[i + 4]); + transform_uint4x8_to_int8(intermediate[5], read_channel[i + 5]); + transform_uint4x8_to_int8(intermediate[6], read_channel[i + 6]); + transform_uint4x8_to_int8(intermediate[7], read_channel[i + 7]); + int frag_idx = i / 8; + dst_frag[0 * 8 + frag_idx] = pack_channel(0); + dst_frag[1 * 8 + frag_idx] = pack_channel(1); + dst_frag[2 * 8 + frag_idx] = pack_channel(2); + dst_frag[3 * 8 + frag_idx] = pack_channel(3); + dst_frag[4 * 8 + frag_idx] = pack_channel(4); + dst_frag[5 * 8 + frag_idx] = pack_channel(5); + dst_frag[6 * 8 + frag_idx] = pack_channel(6); + dst_frag[7 * 8 + frag_idx] = pack_channel(7); + } + } +}; +#undef pack_channel + +#define pack(_idx) \ + transform_int8_to_int4x8(post_process(intermediate[0][_idx]), \ + post_process(intermediate[1][_idx]), \ + post_process(intermediate[2][_idx]), \ + post_process(intermediate[3][_idx]), \ + post_process(intermediate[4][_idx]), \ + post_process(intermediate[5][_idx]), \ + post_process(intermediate[6][_idx]), \ + post_process(intermediate[7][_idx])); +template +struct Translayout<64, 8, SrcType, dtype::QuantizedS4, dtype::QuantizedS4, + same_scale> { + using DnnSrcType = dtype::QuantizedS4; + using DnnDstType = dtype::QuantizedS4; + static constexpr int row = 8; + static constexpr int col = 64; + static constexpr int size_nbits = 4; + static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType)); + static constexpr int elements_in_type = row * col_in_type; + static constexpr int inc_col = 8; + static constexpr int inc_col_in_type = + inc_col * size_nbits / (8 * sizeof(SrcType)); + using Fragment = array_wrapper; + static MEGDNN_DEVICE __forceinline__ void trans( + Fragment& dst, const Fragment& src, + CudaPostProcess& post_process) { + int intermediate[8][8]; + int* dst_frag = reinterpret_cast(&dst); +#pragma unroll + for (int j = 0; j < col_in_type; j += inc_col_in_type) { + transform_int4x8_to_int8( + intermediate[0], + reinterpret_cast(src[0 * col_in_type + j])); + transform_int4x8_to_int8( + intermediate[1], + reinterpret_cast(src[1 * col_in_type + j])); + transform_int4x8_to_int8( + intermediate[2], + reinterpret_cast(src[2 * col_in_type + j])); + transform_int4x8_to_int8( + intermediate[3], + reinterpret_cast(src[3 * col_in_type + j])); + transform_int4x8_to_int8( + intermediate[4], + 
reinterpret_cast(src[4 * col_in_type + j])); + transform_int4x8_to_int8( + intermediate[5], + reinterpret_cast(src[5 * col_in_type + j])); + transform_int4x8_to_int8( + intermediate[6], + reinterpret_cast(src[6 * col_in_type + j])); + transform_int4x8_to_int8( + intermediate[7], + reinterpret_cast(src[7 * col_in_type + j])); + dst_frag[(j / inc_col_in_type) * 8 + 0] = pack(0); + dst_frag[(j / inc_col_in_type) * 8 + 1] = pack(1); + dst_frag[(j / inc_col_in_type) * 8 + 2] = pack(2); + dst_frag[(j / inc_col_in_type) * 8 + 3] = pack(3); + dst_frag[(j / inc_col_in_type) * 8 + 4] = pack(4); + dst_frag[(j / inc_col_in_type) * 8 + 5] = pack(5); + dst_frag[(j / inc_col_in_type) * 8 + 6] = pack(6); + dst_frag[(j / inc_col_in_type) * 8 + 7] = pack(7); + } + } +}; +#undef pack + +#define pack(_idx) \ + ((uint8_t)(post_process(intermediate[0][_idx])) | \ + ((uint8_t)(post_process(intermediate[1][_idx])) << 4)) +template +struct Translayout<64, 2, SrcType, dtype::QuantizedS4, dtype::QuantizedS4, + same_scale> { + using DnnSrcType = dtype::QuantizedS4; + using DnnDstType = dtype::QuantizedS4; + static constexpr int row = 2; + static constexpr int col = 64; + static constexpr int size_nbits = 4; + static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType)); + static constexpr int elements_in_type = row * col_in_type; + static constexpr int inc_col = 8; + static constexpr int inc_col_in_type = + inc_col * size_nbits / (8 * sizeof(SrcType)); + using Fragment = array_wrapper; + static MEGDNN_DEVICE __forceinline__ void trans( + Fragment& dst, const Fragment& src, + CudaPostProcess& post_process) { + int intermediate[2][8]; + uint8_t* dst_frag = reinterpret_cast(&dst); +#pragma unroll + for (int j = 0; j < col_in_type; j += inc_col_in_type) { + transform_int4x8_to_int8( + intermediate[0], + reinterpret_cast(src[0 * col_in_type + j])); + transform_int4x8_to_int8( + intermediate[1], + reinterpret_cast(src[1 * col_in_type + j])); + dst_frag[(j / inc_col_in_type) * 8 + 0] = pack(0); + dst_frag[(j / inc_col_in_type) * 8 + 1] = pack(1); + dst_frag[(j / inc_col_in_type) * 8 + 2] = pack(2); + dst_frag[(j / inc_col_in_type) * 8 + 3] = pack(3); + dst_frag[(j / inc_col_in_type) * 8 + 4] = pack(4); + dst_frag[(j / inc_col_in_type) * 8 + 5] = pack(5); + dst_frag[(j / inc_col_in_type) * 8 + 6] = pack(6); + dst_frag[(j / inc_col_in_type) * 8 + 7] = pack(7); + } + } +}; +#undef pack + +#define pack(_idx) \ + transform_int8_to_uint4x8(post_process(intermediate[0][_idx]), \ + post_process(intermediate[1][_idx]), \ + post_process(intermediate[2][_idx]), \ + post_process(intermediate[3][_idx]), \ + post_process(intermediate[4][_idx]), \ + post_process(intermediate[5][_idx]), \ + post_process(intermediate[6][_idx]), \ + post_process(intermediate[7][_idx])); +template +struct Translayout<64, 8, SrcType, dtype::Quantized4Asymm, + dtype::Quantized4Asymm, same_scale> { + using DnnSrcType = dtype::Quantized4Asymm; + using DnnDstType = dtype::Quantized4Asymm; + static constexpr int row = 8; + static constexpr int col = 64; + static constexpr int size_nbits = 4; + static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType)); + static constexpr int elements_in_type = row * col_in_type; + static constexpr int inc_col = 8; + static constexpr int inc_col_in_type = + inc_col * size_nbits / (8 * sizeof(SrcType)); + using Fragment = array_wrapper; + static MEGDNN_DEVICE __forceinline__ void trans( + Fragment& dst, const Fragment& src, + CudaPostProcess& post_process) { + int intermediate[8][8]; + int* dst_frag = 
reinterpret_cast(&dst); +#pragma unroll + for (int j = 0; j < col_in_type; j += inc_col_in_type) { + transform_uint4x8_to_int8( + intermediate[0], + reinterpret_cast(src[0 * col_in_type + j])); + transform_uint4x8_to_int8( + intermediate[1], + reinterpret_cast(src[1 * col_in_type + j])); + transform_uint4x8_to_int8( + intermediate[2], + reinterpret_cast(src[2 * col_in_type + j])); + transform_uint4x8_to_int8( + intermediate[3], + reinterpret_cast(src[3 * col_in_type + j])); + transform_uint4x8_to_int8( + intermediate[4], + reinterpret_cast(src[4 * col_in_type + j])); + transform_uint4x8_to_int8( + intermediate[5], + reinterpret_cast(src[5 * col_in_type + j])); + transform_uint4x8_to_int8( + intermediate[6], + reinterpret_cast(src[6 * col_in_type + j])); + transform_uint4x8_to_int8( + intermediate[7], + reinterpret_cast(src[7 * col_in_type + j])); + dst_frag[(j / inc_col_in_type) * 8 + 0] = pack(0); + dst_frag[(j / inc_col_in_type) * 8 + 1] = pack(1); + dst_frag[(j / inc_col_in_type) * 8 + 2] = pack(2); + dst_frag[(j / inc_col_in_type) * 8 + 3] = pack(3); + dst_frag[(j / inc_col_in_type) * 8 + 4] = pack(4); + dst_frag[(j / inc_col_in_type) * 8 + 5] = pack(5); + dst_frag[(j / inc_col_in_type) * 8 + 6] = pack(6); + dst_frag[(j / inc_col_in_type) * 8 + 7] = pack(7); + } + } +}; +#undef pack + +#define pack(_idx) \ + ((uint8_t)(post_process(intermediate[0][_idx])) | \ + ((uint8_t)(post_process(intermediate[1][_idx])) << 4)) +template +struct Translayout<64, 2, SrcType, dtype::Quantized4Asymm, + dtype::Quantized4Asymm, same_scale> { + using DnnSrcType = dtype::Quantized4Asymm; + using DnnDstType = dtype::Quantized4Asymm; + static constexpr int row = 2; + static constexpr int col = 64; + static constexpr int size_nbits = 4; + static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType)); + static constexpr int elements_in_type = row * col_in_type; + static constexpr int inc_col = 8; + static constexpr int inc_col_in_type = + inc_col * size_nbits / (8 * sizeof(SrcType)); + using Fragment = array_wrapper; + static MEGDNN_DEVICE __forceinline__ void trans( + Fragment& dst, const Fragment& src, + CudaPostProcess& post_process) { + int intermediate[2][8]; + uint8_t* dst_frag = reinterpret_cast(&dst); +#pragma unroll + for (int j = 0; j < col_in_type; j += inc_col_in_type) { + transform_uint4x8_to_int8( + intermediate[0], + reinterpret_cast(src[0 * col_in_type + j])); + transform_uint4x8_to_int8( + intermediate[1], + reinterpret_cast(src[1 * col_in_type + j])); + dst_frag[(j / inc_col_in_type) * 8 + 0] = pack(0); + dst_frag[(j / inc_col_in_type) * 8 + 1] = pack(1); + dst_frag[(j / inc_col_in_type) * 8 + 2] = pack(2); + dst_frag[(j / inc_col_in_type) * 8 + 3] = pack(3); + dst_frag[(j / inc_col_in_type) * 8 + 4] = pack(4); + dst_frag[(j / inc_col_in_type) * 8 + 5] = pack(5); + dst_frag[(j / inc_col_in_type) * 8 + 6] = pack(6); + dst_frag[(j / inc_col_in_type) * 8 + 7] = pack(7); + } + } +}; +#undef pack + template inline __device__ DstType make_zero_pad(const char zero_point) { return zero_point; @@ -213,13 +717,33 @@ inline __device__ void write_helper(char4* ptr, char4 val) { *rel_ptr = *(int32_t*)(&val); } +template <> +inline __device__ void write_helper>( + array_wrapper* ptr, array_wrapper val) { + uint4 const* data = reinterpret_cast(&val); + void* ptr_ = reinterpret_cast(ptr); + asm volatile( + "{\n" + " st.global.v4.u32 [%0], {%1, %2, %3, %4};\n" + " st.global.v4.u32 [%5], {%6, %7, %8, %9};\n" + "}\n" + : + : "l"(ptr_), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), + 
"r"(data[0].w), "l"(((uint8_t*)ptr_) + 16), "r"(data[1].x), + "r"(data[1].y), "r"(data[1].z), "r"(data[1].w)); +}; + template struct RelayoutKern { - using InnerDtype = typename DTypeRWHelper::InnerDtype; - using DstDtype = typename DTypeRWHelper::DstDtype; - static inline __device__ void write(DstType* dst_ptr, + using InnerDtype = + typename DTypeRWHelper::ctype, + pack_w>::InnerDtype; + using DstDtype = + typename DTypeRWHelper::ctype, + pack_w>::DstDtype; + static inline __device__ void write(DstDtype* dst_ptr, DstDtype (&dst_width)[pack_w]) { DstDtype* dst_inner_ptr = (DstDtype*)dst_ptr; #pragma unroll @@ -262,56 +786,57 @@ struct RelayoutKern { } static inline __device__ void core_relayout_kern( - const SrcType* src, DstType* dst, const int src_offset_base, - const int dst_offset_base, const int ic_offset, const int ic_stride, + const SrcType* src, DstType* dst, const int ic_stride, const int remain_ic, CudaPostProcess& post_process, const char zero_point) { InnerDtype read_channel[pack_c]; if (all_pad) { const InnerDtype zero_pad = make_zero_pad(zero_point); - fake_read(src + ic_offset + src_offset_base, read_channel, - ic_stride, remain_ic, zero_pad); + fake_read(src, read_channel, ic_stride, remain_ic, zero_pad); } else { if (with_pad) { const InnerDtype zero_pad = make_zero_pad(zero_point); - read_with_pad(src + ic_offset + src_offset_base, read_channel, - ic_stride, remain_ic, zero_pad); + read_with_pad(src, read_channel, ic_stride, remain_ic, + zero_pad); } else { - read(src + ic_offset + src_offset_base, read_channel, - ic_stride); + read(src, read_channel, ic_stride); } } DstDtype dst_width[pack_w]; Translayout::trans(dst_width, read_channel, post_process, zero_point); - write(dst + ic_offset + dst_offset_base, dst_width); + write(reinterpret_cast(dst), dst_width); } }; -template -__global__ void kern_nchw_nchw4( +template +__global__ void kern_nchw_nchwx( const SrcType* src, DstType* dst, int in_n, int ic, int ihw, - int n_stride_src, int ic_stride, int n_stride_dst, + int n_stride_src, int ic_stride, int n_stride_dst, int oc_stride, CudaPostProcess post_process, const char zero_point, const int group, const int ocpg) { - constexpr int pack_c = 4; const int n_idx = blockIdx.y; const int ihw_block_idx = blockIdx.x * blockDim.x + threadIdx.x; - const int ihw_offset = ihw_block_idx * pack_w; + const int ihw_offset = + ihw_block_idx * pack_w; + const int ihw_offset_in_type = + ihw_offset * size_nbits / (8 * sizeof(SrcType)); if (ihw_offset < ihw) { - const int src_offset_base = n_idx * n_stride_src + ihw_offset; - const int dst_offset_base = n_idx * n_stride_dst + ihw_offset * pack_c; + const int src_offset_base = n_idx * n_stride_src + ihw_offset_in_type; + const int dst_offset_base = + n_idx * n_stride_dst + ihw_offset_in_type * pack_c; if (n_idx < in_n) { const int icpg = ic / group; const int ic_block = icpg / pack_c; const int remain_ic = icpg % pack_c; const int src_group_stride = icpg * ic_stride; - const int dst_group_stride = ocpg * ic_stride; + const int dst_group_stride = ocpg * oc_stride; for (int g_idx = 0; g_idx < group; ++g_idx) { const int src_offset = src_offset_base + g_idx * src_group_stride; @@ -319,30 +844,24 @@ __global__ void kern_nchw_nchw4( dst_offset_base + g_idx * dst_group_stride; for (int ic_blk_idx = 0; ic_blk_idx < ic_block; ++ic_blk_idx) { const int ic_offset = ic_blk_idx * pack_c * ic_stride; + const int oc_offset = ic_blk_idx * oc_stride; RelayoutKern::core_relayout_kern(src, dst, - src_offset, - dst_offset, - ic_offset, - ic_stride, - 
remain_ic, - post_process, - zero_point); + SrcType, DstType, DnnSrcType, DnnDstType>:: + core_relayout_kern(src + src_offset + ic_offset, + dst + dst_offset + oc_offset, + ic_stride, remain_ic, + post_process, zero_point); } if (remain_ic > 0) { const int ic_offset = ic_block * pack_c * ic_stride; + const int oc_offset = ic_block * oc_stride; RelayoutKern::core_relayout_kern(src, dst, - src_offset, - dst_offset, - ic_offset, - ic_stride, - remain_ic, - post_process, - zero_point); + SrcType, DstType, DnnSrcType, DnnDstType>:: + core_relayout_kern(src + src_offset + ic_offset, + dst + dst_offset + oc_offset, + ic_stride, remain_ic, + post_process, zero_point); } } } else { @@ -350,13 +869,10 @@ __global__ void kern_nchw_nchw4( const int ic_full_block = group * ocpg / pack_c; for (int ic_blk_idx = 0; ic_blk_idx < ic_full_block; ++ic_blk_idx) { RelayoutKern::core_relayout_kern(src, dst, - src_offset_base, - dst_offset_base, 0, - ic_stride, 0, - post_process, - zero_point); + DstType, DnnSrcType, DnnDstType>:: + core_relayout_kern(src + src_offset_base, + dst + dst_offset_base, ic_stride, 0, + post_process, zero_point); } } } @@ -443,29 +959,21 @@ __global__ void kern_nchw_nchw4_weight( for (int ic_blk_idx = 0; ic_blk_idx < ic_block; ++ic_blk_idx) { const int ic_offset = ic_blk_idx * pack_c * ic_stride; RelayoutKern::core_relayout_kern(src, dst, - src_offset_base, - dst_offset_base, - ic_offset, - ic_stride, - remain_ic, - post_process, - zero_point); + DstType, DnnSrcType, DnnDstType>:: + core_relayout_kern(src + src_offset_base + ic_offset, + dst + dst_offset_base + ic_offset, + ic_stride, remain_ic, post_process, + zero_point); } if (remain_ic > 0) { const int ic_offset = ic_block * pack_c * ic_stride; RelayoutKern::core_relayout_kern(src, dst, - src_offset_base, - dst_offset_base, - ic_offset, - ic_stride, - remain_ic, - post_process, - zero_point); + DstType, DnnSrcType, DnnDstType>:: + core_relayout_kern(src + src_offset_base + ic_offset, + dst + dst_offset_base + ic_offset, + ic_stride, remain_ic, post_process, + zero_point); } } else { //! pad oc per group @@ -473,29 +981,174 @@ __global__ void kern_nchw_nchw4_weight( for (int ic_blk_idx = 0; ic_blk_idx < ic_full_block; ++ic_blk_idx) { const int ic_offset = ic_blk_idx * pack_c * ic_stride; RelayoutKern::core_relayout_kern(src, dst, - src_offset_base, - dst_offset_base, - ic_offset, - ic_stride, - remain_ic, - post_process, - zero_point); + DstType, DnnSrcType, DnnDstType>:: + core_relayout_kern(src + src_offset_base + ic_offset, + dst + dst_offset_base + ic_offset, + ic_stride, remain_ic, post_process, + zero_point); } } } } +template +class TensorIteratorOverChannel { +public: + using Type = Type_; + static constexpr int pack_size = pack_size_; + static constexpr int chan_blk = chan_blk_; + static constexpr int width = width_; + static constexpr int size_nbits = size_nbits_; + static constexpr int elements_in_type = + chan_blk * width * size_nbits / (8 * sizeof(Type)); + static constexpr int lane_size_in_type = + (width * pack_size * size_nbits) / (8 * sizeof(Type)); + static constexpr int pack_size_in_type = + (pack_size * size_nbits) >= (8 * sizeof(Type)) + ? 
(pack_size * size_nbits / (8 * sizeof(Type))) + : (width * pack_size * size_nbits / (8 * sizeof(Type))); + static constexpr int pack_size_in_byte = pack_size_in_type * sizeof(Type); + using AccessType = array_wrapper; + using Fragment = array_wrapper; + + MEGDNN_DEVICE TensorIteratorOverChannel() + : pointer{nullptr}, chan_stride_in_elements{0}, channel{0} {} + MEGDNN_DEVICE TensorIteratorOverChannel(Type* pointer_, + int chan_stride_in_elements_, + int channel_) + : pointer{pointer_}, + chan_stride_in_elements{chan_stride_in_elements}, + channel{channel_} {} + + MEGDNN_DEVICE __forceinline__ void load(Fragment& frag) { + AccessType* frag_ptr = reinterpret_cast(&frag); + Type* pointer_ = pointer; +#pragma unroll + for (int i = 0; i < chan_blk; i += pack_size) { +#pragma unroll + for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) { + int frag_idx = i / pack_size * + (lane_size_in_type / pack_size_in_type) + + j; + bool guard = i >= channel; + cutlass::arch::global_load( + frag_ptr[frag_idx], + reinterpret_cast(pointer_ + + j * pack_size_in_type), + guard); + } + pointer_ += chan_stride_in_elements; + } + } + + MEGDNN_DEVICE __forceinline__ void store(const Fragment& frag) { + const AccessType* frag_ptr = reinterpret_cast(&frag); + Type* pointer_ = pointer; +#pragma unroll + for (int i = 0; i < chan_blk; i += pack_size) { +#pragma unroll + for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) { + int frag_idx = i / pack_size * + (lane_size_in_type / pack_size_in_type) + + j; + bool guard = i >= channel; + cutlass::arch::global_store( + frag_ptr[frag_idx], + reinterpret_cast(pointer_ + + j * pack_size_in_type), + guard); + } + pointer_ += chan_stride_in_elements; + } + } + + + MEGDNN_DEVICE __forceinline__ void advance() { + pointer += (chan_blk / pack_size) * chan_stride_in_elements; + channel -= chan_blk; + } + +private: + Type* pointer; + int chan_stride_in_elements; + int channel; +}; + +template +__global__ void kern_nchwx_nchw( + const SrcType* src, DstType* dst, int ic, int ihw, int n_stride_src, + int ic_stride, int n_stride_dst, int oc_stride, + CudaPostProcess post_process, + const char zero_point) { + using InnerDtype = + typename DTypeRWHelper::ctype, + pack_w>::InnerDtype; + using SrcIterator = TensorIteratorOverChannel; + using DstIteraotr = TensorIteratorOverChannel; + using Transpose = Translayout; + const int n_idx = blockIdx.y; + const int ihw_block_idx = blockIdx.x * blockDim.x + threadIdx.x; + const int ihw_offset = ihw_block_idx * pack_w; + const int ihw_offset_in_type = + ihw_offset * size_nbits / (8 * sizeof(SrcType)); + if (ihw_offset < ihw) { + const int ic_block = (ic + pack_c - 1) / pack_c; + const int src_offset_base = + n_idx * n_stride_src + ihw_offset_in_type * pack_c; + const int dst_offset_base = n_idx * n_stride_dst + ihw_offset_in_type; + SrcIterator src_iterator{const_cast(src + src_offset_base), + ic_stride, ic}; + DstIteraotr dst_iterator{ + reinterpret_cast(dst + dst_offset_base), oc_stride, + ic}; + + for (int ic_blk_idx = 0; ic_blk_idx < ic_block; ++ic_blk_idx) { + typename SrcIterator::Fragment src_frag; + typename DstIteraotr::Fragment dst_frag; + src_iterator.load(src_frag); + Transpose::trans( + reinterpret_cast(dst_frag), + src_frag, post_process); + dst_iterator.store(dst_frag); + src_iterator.advance(); + dst_iterator.advance(); + } + } +} } // namespace -template -void relayout_format::relayout_format_cuda_nchw_nchw4( +void relayout_format::relayout_format_cuda_nchw_nchwx( const TensorND& src, const TensorND& dst, const 
cudaStream_t& stream, const float src_scale, const float dst_scale, - const uint8_t src_zero_point, const uint8_t dst_zero_point, - const int group) { - constexpr int pack_oc = 4; + const uint8_t src_zero_point, const uint8_t dst_zero_point, int group) { + auto&& stype = src.layout.dtype; + auto&& dtype = dst.layout.dtype; + auto& src_layout = src.layout; + auto& dst_layout = dst.layout; + // check pack size + int pack_oc = std::numeric_limits::min(); +#define DEF(_pack_oc, _src_type, _dst_type) \ + if (stype.enumv().ev == DTypeEnum::Ev::_src_type && \ + dtype.enumv().ev == DTypeEnum::Ev::_dst_type) { \ + pack_oc = _pack_oc; \ + } + // clang-format off + DEF(64, QuantizedS4, QuantizedS4) + DEF(64, Quantized4Asymm, Quantized4Asymm) + DEF(4, QuantizedS8, QuantizedS8) + DEF(4, Uint8, QuantizedS8) + DEF(4, Quantized8Asymm, Quantized8Asymm) + DEF(4, QuantizedS32, QuantizedS32); + // clang-format on + megdnn_assert(pack_oc == 4 || pack_oc == 64, + "Unsupport pack size(pack_oc:%d)", pack_oc); +#undef DEF const int in_n = src.layout[0]; const int out_n = dst.layout[0]; const int ic = src.layout[1]; @@ -504,75 +1157,149 @@ void relayout_format::relayout_format_cuda_nchw_nchw4( const int oc = dst.layout[1] * pack_oc; const int hw = h * w; const int ocpg = oc / group; - const int n_stride_src = ic * hw; - const int ic_stride = hw; - const int n_stride_dst = oc * hw; + const int n_stride_src = src_layout.dtype.size(src_layout.stride[0]); + const int ic_stride = src_layout.dtype.size(src_layout.stride[1]); + const int n_stride_dst = dst_layout.dtype.size(dst_layout.stride[0]); + const int oc_stride = dst_layout.dtype.size(dst_layout.stride[1]); - auto& src_layout = src.layout; - auto& dst_layout = dst.layout; bool same_scale = src_scale == dst_scale; -#define RUN_KERNEL(same_scale, SRC_TYPE, DST_TYPE, SRC_C_TYPE, DST_C_TYPE) \ - if (same_scale) { \ - int nr_threads = query_blocksize_for_kernel( \ - kern_nchw_nchw4); \ - const dim3 block_dim(DIVUP(hw, nr_threads* pack_w), out_n); \ - const dim3 thread_dim(nr_threads); \ - kern_nchw_nchw4<<>>( \ - (SRC_C_TYPE*)src.raw_ptr, (DST_C_TYPE*)dst.raw_ptr, in_n, ic, \ - hw, n_stride_src, ic_stride, n_stride_dst, \ - CudaPostProcess( \ - src_scale, src_zero_point, dst_scale, dst_zero_point), \ - src_zero_point, group, ocpg); \ - } else { \ - int nr_threads = query_blocksize_for_kernel( \ - kern_nchw_nchw4); \ - const dim3 block_dim(DIVUP(hw, nr_threads* pack_w), out_n); \ - const dim3 thread_dim(nr_threads); \ - kern_nchw_nchw4<<>>( \ - (SRC_C_TYPE*)src.raw_ptr, (DST_C_TYPE*)dst.raw_ptr, in_n, ic, \ - hw, n_stride_src, ic_stride, n_stride_dst, \ - CudaPostProcess( \ - src_scale, src_zero_point, dst_scale, dst_zero_point), \ - src_zero_point, group, ocpg); \ - } - - if (src_layout.dtype.enumv().ev == DTypeEnum::Ev::Uint8 && - dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8) { - RUN_KERNEL(same_scale, dtype::Uint8, dtype::QuantizedS8, char, char); - } else if (src_layout.dtype.enumv().ev == DTypeEnum::Ev::Quantized8Asymm && - dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8) { - RUN_KERNEL(same_scale, dtype::Quantized8Asymm, dtype::QuantizedS8, char, - char); - } else if (src_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8 && - dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8) { - RUN_KERNEL(same_scale, dtype::QuantizedS8, dtype::QuantizedS8, char, - char); - } else if (src_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS32 && - dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS32) { - RUN_KERNEL(same_scale, 
dtype::QuantizedS32, dtype::QuantizedS32, int, - int); - } else { - megdnn_assert(0, "not support dtype %s %s", src_layout.dtype.name(), - dst_layout.dtype.name()); +#define DISPATCH_RAW(_same_scale, _pack_w, _pack_oc, _src_type, _dst_type, \ + _src_c_type, _dst_c_type, _size_nbits) \ + if (same_scale == _same_scale && hw % _pack_w == 0 && \ + stype.enumv().ev == DTypeEnum::Ev::_src_type && \ + dtype.enumv().ev == DTypeEnum::Ev::_dst_type) { \ + auto kernel = \ + kern_nchw_nchwx<_pack_w, _pack_oc, _same_scale, _src_c_type, \ + _dst_c_type, dtype::_src_type, \ + dtype::_dst_type, _size_nbits>; \ + int nr_threads = query_blocksize_for_kernel(kernel); \ + const dim3 block_dim(DIVUP(hw, nr_threads* _pack_w), out_n); \ + const dim3 thread_dim(nr_threads); \ + return kernel<<>>( \ + (_src_c_type*)src.raw_ptr, (_dst_c_type*)dst.raw_ptr, in_n, \ + ic, hw, n_stride_src, ic_stride, n_stride_dst, oc_stride, \ + CudaPostProcess(src_scale, src_zero_point, \ + dst_scale, dst_zero_point), \ + src_zero_point, group, ocpg); \ } +#define DISPATCH_INT(_src_type, _dst_type) \ + DISPATCH_RAW(true, 4, 4, _src_type, _dst_type, int, int, 32); \ + DISPATCH_RAW(false, 4, 4, _src_type, _dst_type, int, int, 32); \ + DISPATCH_RAW(true, 1, 4, _src_type, _dst_type, int, int, 32); \ + DISPATCH_RAW(false, 1, 4, _src_type, _dst_type, int, int, 32); +#define DISPATCH_BYTE(_src_type, _dst_type) \ + DISPATCH_RAW(true, 4, 4, _src_type, _dst_type, char, char, 8); \ + DISPATCH_RAW(false, 4, 4, _src_type, _dst_type, char, char, 8); \ + DISPATCH_RAW(true, 1, 4, _src_type, _dst_type, char, char, 8); \ + DISPATCH_RAW(false, 1, 4, _src_type, _dst_type, char, char, 8); +#define DISPATCH_4BITS(_src_type, _dst_type) \ + DISPATCH_RAW(true, 8, 64, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(false, 8, 64, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(true, 2, 64, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(false, 2, 64, _src_type, _dst_type, char, char, 4); + DISPATCH_INT(QuantizedS32, QuantizedS32); + DISPATCH_BYTE(Uint8, QuantizedS8); + DISPATCH_BYTE(Quantized8Asymm, QuantizedS8); + DISPATCH_BYTE(QuantizedS8, QuantizedS8); + DISPATCH_4BITS(QuantizedS4, QuantizedS4); + DISPATCH_4BITS(Quantized4Asymm, Quantized4Asymm); +#undef DISPATCH_4BITS +#undef DISPATCH_BYTE +#undef DISPATCH_INT +#undef DISPATCH_RAW + megdnn_assert(false, + "Unsupported data type(src:%s, dst:%s) or image size(%dx%d).", + stype.name(), dtype.name(), h, w); } bool relayout_format::relayout_format_cuda_usable( const TensorLayout& src_layout, const TensorLayout& dst_layout) { bool is_all_continue = src_layout.is_contiguous() && dst_layout.is_contiguous(); + bool is_all_int32 = + (src_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS32 && + dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS32); bool is_all_int8 = (src_layout.dtype.enumv().ev == DTypeEnum::Ev::Uint8 && dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8) || (src_layout.dtype.enumv().ev == DTypeEnum::Ev::Quantized8Asymm && dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8) || (src_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8 && - dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8) || - (src_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS32 && - dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS32); - return is_all_continue && is_all_int8; + dst_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS8); + bool is_all_int4 = + (src_layout.dtype.enumv().ev == DTypeEnum::Ev::QuantizedS4 && + dst_layout.dtype.enumv().ev == 
DTypeEnum::Ev::QuantizedS4) || + (src_layout.dtype.enumv().ev == DTypeEnum::Ev::Quantized4Asymm && + dst_layout.dtype.enumv().ev == DTypeEnum::Ev::Quantized4Asymm); + return is_all_continue && (is_all_int32 || is_all_int8 || is_all_int4); +} + +void relayout_format::relayout_format_cuda_nchwx_nchw( + const TensorND& src, const TensorND& dst, const cudaStream_t& stream, + const float src_scale, const float dst_scale, + const uint8_t src_zero_point, const uint8_t dst_zero_point) { + auto&& stype = src.layout.dtype; + auto&& dtype = dst.layout.dtype; + auto& src_layout = src.layout; + auto& dst_layout = dst.layout; + // check pack size + int pack_oc = std::numeric_limits::min(); +#define DEF(_pack_oc, _src_type, _dst_type) \ + if (stype.enumv().ev == DTypeEnum::Ev::_src_type && \ + dtype.enumv().ev == DTypeEnum::Ev::_dst_type) { \ + pack_oc = _pack_oc; \ + } + // clang-format off + DEF(64, QuantizedS4, QuantizedS4) + DEF(64, Quantized4Asymm, Quantized4Asymm) + // clang-format on + megdnn_assert(pack_oc == 64, "Unsupport pack size(pack_oc:%d)", pack_oc); +#undef DEF + const int n = src.layout[0]; + const int c = src.layout[1]; + const int h = src.layout[2]; + // align to byte + const int w = src.layout[3]; + const int hw = h * w; + const int n_stride_src = src_layout.dtype.size(src_layout.stride[0]); + const int ic_stride = src_layout.dtype.size(src_layout.stride[1]); + const int n_stride_dst = dst_layout.dtype.size(dst_layout.stride[0]); + const int oc_stride = dst_layout.dtype.size(dst_layout.stride[1]); + + bool same_scale = src_scale == dst_scale; +#define DISPATCH_RAW(_same_scale, _pack_w, _pack_oc, _src_type, _dst_type, \ + _src_c_type, _dst_c_type, _size_nbits) \ + if (same_scale == _same_scale && hw % _pack_w == 0 && \ + stype.enumv().ev == DTypeEnum::Ev::_src_type && \ + dtype.enumv().ev == DTypeEnum::Ev::_dst_type) { \ + auto kernel = \ + kern_nchwx_nchw<_pack_w, _pack_oc, _same_scale, _src_c_type, \ + _dst_c_type, dtype::_src_type, \ + dtype::_dst_type, _size_nbits>; \ + int nr_threads = query_blocksize_for_kernel(kernel); \ + const dim3 block_dim(DIVUP(hw, nr_threads* _pack_w), n); \ + const dim3 thread_dim(nr_threads); \ + return kernel<<>>( \ + (_src_c_type*)src.raw_ptr, (_dst_c_type*)dst.raw_ptr, c, hw, \ + n_stride_src, ic_stride, n_stride_dst, oc_stride, \ + CudaPostProcess(src_scale, src_zero_point, \ + dst_scale, dst_zero_point), \ + src_zero_point); \ + } +#define DISPATCH_4BITS(_src_type, _dst_type) \ + DISPATCH_RAW(true, 8, 64, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(false, 8, 64, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(true, 2, 64, _src_type, _dst_type, char, char, 4); \ + DISPATCH_RAW(false, 2, 64, _src_type, _dst_type, char, char, 4); + DISPATCH_4BITS(QuantizedS4, QuantizedS4); + DISPATCH_4BITS(Quantized4Asymm, Quantized4Asymm); +#undef DISPATCH_4BITS +#undef DISPATCH_RAW + megdnn_assert(false, + "Unsupported data type(src:%s, dst:%s) or image size(%dx%d).", + stype.name(), dtype.name(), h, w); } void relayout_format::relayout_format_cuda_nchw4_nchw( @@ -619,15 +1346,3 @@ void relayout_format::relayout_format_cuda_nchw_nchw4_weight( ic_stride, oc_stride_dst, group_stride_src, group_stride_dst, 0, {}); } - -template void relayout_format::relayout_format_cuda_nchw_nchw4<1>( - const TensorND& src, const TensorND& dst, const cudaStream_t& stream, - const float src_scale, const float dst_scale, - const uint8_t src_zero_point, const uint8_t dst_zero_point, - const int group); - -template void relayout_format::relayout_format_cuda_nchw_nchw4<4>( 
- const TensorND& src, const TensorND& dst, const cudaStream_t& stream, - const float src_scale, const float dst_scale, - const uint8_t src_zero_point, const uint8_t dst_zero_point, - const int group); diff --git a/dnn/src/cuda/relayout_format/relayout_format.cuh b/dnn/src/cuda/relayout_format/relayout_format.cuh index 3dd2058f..2d621940 100644 --- a/dnn/src/cuda/relayout_format/relayout_format.cuh +++ b/dnn/src/cuda/relayout_format/relayout_format.cuh @@ -19,14 +19,11 @@ namespace megdnn { namespace cuda { namespace relayout_format { -template -void relayout_format_cuda_nchw_nchw4(const TensorND& src, const TensorND& dst, - const cudaStream_t& stream, - const float src_scale = 1.f, - const float dst_scale = 1.f, - const uint8_t src_zero_point = 0, - const uint8_t dst_zero_point = 0, - const int group = 1); +void relayout_format_cuda_nchw_nchwx( + const TensorND& src, const TensorND& dst, const cudaStream_t& stream, + const float src_scale = 1.f, const float dst_scale = 1.f, + const uint8_t src_zero_point = 0, const uint8_t dst_zero_point = 0, + const int group = 1); bool relayout_format_cuda_usable(const TensorLayout& src_layout, const TensorLayout& dst_layout); @@ -35,6 +32,13 @@ void relayout_format_cuda_nchw4_nchw(const TensorND& src, const TensorND& dst, const cudaStream_t& stream, const int group); +void relayout_format_cuda_nchwx_nchw(const TensorND& src, const TensorND& dst, + const cudaStream_t& stream, + const float src_scale = 1.f, + const float dst_scale = 1.f, + const uint8_t src_zero_point = 0, + const uint8_t dst_zero_point = 0); + void relayout_format_cuda_nchw_nchw4_weight(const TensorND& src, const TensorND& dst, const cudaStream_t& stream); diff --git a/dnn/src/cuda/utils.cuh b/dnn/src/cuda/utils.cuh index 9cac97ec..c3feb91b 100644 --- a/dnn/src/cuda/utils.cuh +++ b/dnn/src/cuda/utils.cuh @@ -110,6 +110,12 @@ MEGDNN_NORETURN void report_error(const char* msg); template struct array_wrapper { T data[N]; + MEGDNN_DEVICE __forceinline__ T& operator[](size_t pos) { + return reinterpret_cast(data[pos]); + } + MEGDNN_DEVICE __forceinline__ T const& operator[](size_t pos) const { + return reinterpret_cast(data[pos]); + } }; /*! 
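The hunks that follow add dt_qint4/dt_quint4 quantization parameters and helpers that pack eight 32-bit integers into one 32-bit word of 4-bit nibbles (with a cvt.pack.sat PTX fast path on sm_75 and later). A host-side sketch of the equivalent signed pack/unpack round trip, with illustrative names that mirror, rather than reuse, the device-side helpers:

```cpp
#include <cstdint>

// Pack eight signed values into one 32-bit word, four bits each (s0 in the
// lowest nibble); mirrors the portable fallback of transform_int8_to_int4x8.
static uint32_t pack_int4x8_ref(const int (&s)[8]) {
    uint32_t out = 0;
    for (int i = 0; i < 8; ++i) {
        int v = s[i] < -8 ? -8 : (s[i] > 7 ? 7 : s[i]);  // saturate to [-8, 7]
        out |= (static_cast<uint32_t>(v) & 0xfu) << (4 * i);
    }
    return out;
}

// Extract the i-th nibble and sign-extend it, matching
// unpack_integer_4bits<true>(storage, 4 * i).
static int unpack_int4_ref(uint32_t storage, int i) {
    const int v = static_cast<int>((storage >> (4 * i)) & 0xfu);
    return (v & 0x8) ? (v - 16) : v;
}
```

The round trip unpack_int4_ref(pack_int4x8_ref(s), i) == clamp(s[i], -8, 7) is the invariant that both the PTX fast path and the portable fallback preserve; the unsigned dt_quint4 variants saturate to [0, 15] instead.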
@@ -207,12 +213,29 @@ struct CudaDTypeParamImpl : DTypeParamImpl { CudaDTypeParamImpl(const DTypeParamImpl& param) : CudaDTypeParamImpl(param.scale, param.zero_point) {} - __device__ uint8_t quantize(float in) const { + __device__ dt_quint4 quantize(float in) const { float v = in * inv_scale; v = roundf(v); v = v + zero_point; v = fmin(fmax(0.f, v), 15.f); - return static_cast(v); + return static_cast(v); + } +}; + +template <> +struct CudaDTypeParamImpl : DTypeParamImpl { + float inv_scale; + CudaDTypeParamImpl() = default; + CudaDTypeParamImpl(float scale) + : DTypeParamImpl(scale), inv_scale(1.0f / scale) {} + CudaDTypeParamImpl(const DTypeParamImpl& param) + : CudaDTypeParamImpl(param.scale) {} + + __device__ dt_qint4 quantize(float in) const { + float v = in * inv_scale; + v = roundf(v); + v = fmin(fmax(-8.f, v), 7.f); + return static_cast(v); } }; @@ -351,6 +374,110 @@ MEGDNN_DEVICE __forceinline__ static float4 operator+(float4 lval, return make_float4(lval.x + rval.x, lval.y + rval.y, lval.z + rval.z, lval.w + rval.w); } + +MEGDNN_DEVICE __forceinline__ static int transform_int8_to_int4x8( + int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) { + unsigned out; +#if __CUDA_ARCH__ >= 750 + asm volatile( + "{ .reg .u32 r4;" + "cvt.pack.sat.s4.s32.b32 r4, %8, %7, 0;" + "cvt.pack.sat.s4.s32.b32 r4, %6, %5, r4;" + "cvt.pack.sat.s4.s32.b32 r4, %4, %3, r4;" + "cvt.pack.sat.s4.s32.b32 %0, %2, %1, r4;" + "}" + : "=r"(out) + : "r"(s0), "r"(s1), "r"(s2), "r"(s3), "r"(s4), "r"(s5), "r"(s6), + "r"(s7)); +#else +#define CVT_SAT_S4_S32(r, bits) \ + r = r <= -8 ? -8 : r; \ + r = r > 7 ? 7 : r; \ + r = (((unsigned)r & 0xf) << bits); + CVT_SAT_S4_S32(s0, 0) + CVT_SAT_S4_S32(s1, 4) + CVT_SAT_S4_S32(s2, 8) + CVT_SAT_S4_S32(s3, 12) + CVT_SAT_S4_S32(s4, 16) + CVT_SAT_S4_S32(s5, 20) + CVT_SAT_S4_S32(s6, 24) + CVT_SAT_S4_S32(s7, 28) + out = s0 + s1 + s2 + s3 + s4 + s5 + s6 + s7; +#undef CVT_SAT_S4_S32 +#endif + return reinterpret_cast(out); +} + +MEGDNN_DEVICE __forceinline__ static int transform_int8_to_uint4x8( + int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) { + unsigned out; +#if __CUDA_ARCH__ >= 750 + asm volatile( + "{ .reg .u32 r4;" + "cvt.pack.sat.u4.s32.b32 r4, %8, %7, 0;" + "cvt.pack.sat.u4.s32.b32 r4, %6, %5, r4;" + "cvt.pack.sat.u4.s32.b32 r4, %4, %3, r4;" + "cvt.pack.sat.u4.s32.b32 %0, %2, %1, r4;" + "}" + : "=r"(out) + : "r"(s0), "r"(s1), "r"(s2), "r"(s3), "r"(s4), "r"(s5), "r"(s6), + "r"(s7)); +#else +#define CVT_SAT_U4_S32(r, bits) \ + r = r <= 0 ? 0 : r; \ + r = r > 15 ? 15 : r; \ + r = (((unsigned)r & 0xf) << bits); + CVT_SAT_U4_S32(s0, 0) + CVT_SAT_U4_S32(s1, 4) + CVT_SAT_U4_S32(s2, 8) + CVT_SAT_U4_S32(s3, 12) + CVT_SAT_U4_S32(s4, 16) + CVT_SAT_U4_S32(s5, 20) + CVT_SAT_U4_S32(s6, 24) + CVT_SAT_U4_S32(s7, 28) + out = s0 + s1 + s2 + s3 + s4 + s5 + s6 + s7; +#undef CVT_SAT_U4_S32 +#endif + return reinterpret_cast(out); +} + +template +MEGDNN_DEVICE __forceinline__ static int unpack_integer_4bits(unsigned storage, + unsigned bits); + +template <> +MEGDNN_DEVICE __forceinline__ int unpack_integer_4bits(unsigned storage, + unsigned bits) { + uint8_t result = (uint8_t)((unsigned)(storage >> bits) & 0xf); + static constexpr uint8_t mask = (uint8_t)((1 << 4) - 1); + return (result & uint8_t(1 << 3)) ? 
((int)(result) | ~(int)(mask)) + : (int)(result); +} + +template <> +MEGDNN_DEVICE __forceinline__ int unpack_integer_4bits(unsigned storage, + unsigned bits) { + uint8_t result = (uint8_t)((unsigned)(storage >> bits) & 0xf); + return (int)(result); +} + +MEGDNN_DEVICE __forceinline__ static void transform_int4x8_to_int8( + int (&result)[8], const int& source) { +#pragma unroll + for (int i = 0; i < 8; i++) { + result[i] = unpack_integer_4bits( + reinterpret_cast(source), (i << 2)); + } +} + +MEGDNN_DEVICE __forceinline__ static void transform_uint4x8_to_int8( + int (&result)[8], const int& source) { +#pragma unroll + for (int i = 0; i < 8; i++) { + result[i] = unpack_integer_4bits( + reinterpret_cast(source), (i << 2)); + } +} #endif } // namespace cuda } // namespace megdnn
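Taken together, the new modes can be exercised end to end through the RelayoutFormat operator. A minimal sketch of the intended call sequence for NCHW_NCHW64 with QuantizedS4 data follows; it mirrors the operator usage in fallback_nchw_qs4.cpp above, while handle creation, buffer allocation and the wrapper name are assumptions made for illustration only.

```cpp
#include "megdnn/handle.h"
#include "megdnn/oprs.h"

using namespace megdnn;

// Relayout a QuantizedS4 tensor from NCHW to NCHW64 through the public
// operator. Obtaining `handle` and the device buffers src_ptr/dst_ptr is
// assumed to happen elsewhere; this wrapper is illustrative, not part of
// the patch.
void relayout_nchw_to_nchw64(Handle* handle, void* src_ptr, void* dst_ptr,
                             size_t n, size_t c, size_t h, size_t w,
                             float scale) {
    // deduce_layout_fwd asserts c % 64 == 0 for the NCHW_NCHW64 mode
    TensorLayout src_layout{{n, c, h, w}, dtype::QuantizedS4(scale)};
    TensorLayout dst_layout{{n, c / 64, h, w, 64}, dtype::QuantizedS4(scale)};
    auto relayout = handle->create_operator<RelayoutFormat>();
    relayout->param() = RelayoutFormat::Param::Mode::NCHW_NCHW64;
    Workspace dummy;  // passed the same way as in the calls above
    relayout->exec({src_ptr, src_layout}, {dst_ptr, dst_layout}, dummy);
}
```

The reverse direction uses Mode::NCHW64_NCHW with the two layouts swapped; on CUDA it additionally requires group == 1, per the assert added in relayout_format.cpp.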