diff --git a/dnn/src/cuda/integer_subbyte_utils.cuh b/dnn/src/cuda/integer_subbyte_utils.cuh index d9a80ac8..0371933a 100644 --- a/dnn/src/cuda/integer_subbyte_utils.cuh +++ b/dnn/src/cuda/integer_subbyte_utils.cuh @@ -110,35 +110,33 @@ MEGDNN_DEVICE __forceinline__ static int unpack_integer_4bits(T storage, return (result << (shift - bits)) >> shift; } -MEGDNN_DEVICE __forceinline__ static void transform_int4x8_to_int8( - int (&result)[8], const int& source) { -#pragma unroll - for (int i = 0; i < 8; i++) { - result[i] = unpack_integer_4bits( - reinterpret_cast(source), (i << 2)); - } -} - -MEGDNN_DEVICE __forceinline__ static void transform_uint4x8_to_int8( +template +MEGDNN_DEVICE __forceinline__ static void transform_b4x8_to_int8( int (&result)[8], const int& source) { #pragma unroll for (int i = 0; i < 8; i++) { - result[i] = unpack_integer_4bits( + result[i] = unpack_integer_4bits( reinterpret_cast(source), (i << 2)); } } -MEGDNN_DEVICE __forceinline__ static void transform_int4x2_to_int8( +template +MEGDNN_DEVICE __forceinline__ static void transform_b4x2_to_int8( int (&result)[2], const uint8_t& source) { - result[0] = unpack_integer_4bits(source, 0); - result[1] = unpack_integer_4bits(source, 4); + result[0] = unpack_integer_4bits(source, 0); + result[1] = unpack_integer_4bits(source, 4); } -MEGDNN_DEVICE __forceinline__ static void transform_uint4x2_to_int8( - int (&result)[2], const uint8_t& source) { - result[0] = unpack_integer_4bits(source, 0); - result[1] = unpack_integer_4bits(source, 4); +template +MEGDNN_DEVICE __forceinline__ static int transform_int8_to_b4x8( + int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) { + if (signedness) { + return transform_int8_to_int4x8(s0, s1, s2, s3, s4, s5, s6, s7); + } else { + return transform_int8_to_uint4x8(s0, s1, s2, s3, s4, s5, s6, s7); + } } + } // namespace integer_subbyte } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/relayout_format/cuda_post_process.cuh 
b/dnn/src/cuda/relayout_format/cuda_post_process.cuh new file mode 100644 index 00000000..a4f8098d --- /dev/null +++ b/dnn/src/cuda/relayout_format/cuda_post_process.cuh @@ -0,0 +1,171 @@ +/** + * \file dnn/src/cuda/relayout_format/cuda_post_process.cuh + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include "src/cuda/relayout_format/relayout_format.cuh" + +namespace megdnn { +namespace cuda { +namespace relayout_format { +namespace internal { +template +struct CudaPostProcess; + +template <> +struct CudaPostProcess { + CudaPostProcess(float, uint8_t, float, uint8_t){}; + inline __device__ int8_t operator()(uint8_t val) { return val - 128; } +}; + +template <> +struct CudaPostProcess { + CudaDTypeParamImpl m_dst_type_cvt; + CudaPostProcess(float, uint8_t, float dst_scale, uint8_t) { + m_dst_type_cvt = CudaDTypeParamImpl(dst_scale); + }; + inline __device__ int8_t operator()(uint8_t val) { + return m_dst_type_cvt.quantize((float)val - 128.f).as_int8(); + } +}; + +template <> +struct CudaPostProcess { + CudaDTypeParamImpl m_dst_type_cvt; + CudaDTypeParamImpl m_src_type_cvt; + CudaPostProcess(float src_scale, uint8_t src_zero_point, float dst_scale, + uint8_t) { + m_dst_type_cvt = CudaDTypeParamImpl(dst_scale); + m_src_type_cvt = + CudaDTypeParamImpl(src_scale, src_zero_point); + }; + inline __device__ int8_t operator()(uint8_t val) { + float med_var = m_src_type_cvt.dequantize(dt_quint8(val)); + return m_dst_type_cvt.quantize(med_var).as_int8(); + } +}; + +template <> +struct CudaPostProcess { + uint8_t m_src_zero_point = 0; + CudaPostProcess(float, uint8_t src_zero_point, float, uint8_t) { + m_src_zero_point = src_zero_point; + 
}; + inline __device__ int8_t operator()(uint8_t val) { + return val - m_src_zero_point; + } +}; + +template <> +struct CudaPostProcess { + CudaDTypeParamImpl m_dst_type_cvt; + CudaDTypeParamImpl m_src_type_cvt; + CudaPostProcess(float src_scale, uint8_t, float dst_scale, uint8_t) { + m_dst_type_cvt = CudaDTypeParamImpl(dst_scale); + m_src_type_cvt = CudaDTypeParamImpl(src_scale); + }; + inline __device__ int8_t operator()(int8_t val) { + float med_var = m_src_type_cvt.dequantize(dt_qint8(val)); + return m_dst_type_cvt.quantize(med_var).as_int8(); + } +}; + +template <> +struct CudaPostProcess { + CudaPostProcess(){}; + CudaPostProcess(float, uint8_t, float, uint8_t){}; + inline __device__ int8_t operator()(int8_t val) { return val; } +}; + +template <> +struct CudaPostProcess { + CudaDTypeParamImpl m_dst_type_cvt; + CudaDTypeParamImpl m_src_type_cvt; + CudaPostProcess(float src_scale, int, float dst_scale, int) { + m_dst_type_cvt = CudaDTypeParamImpl(dst_scale); + m_src_type_cvt = CudaDTypeParamImpl(src_scale); + }; + inline __device__ int operator()(int val) { + float med_var = m_src_type_cvt.dequantize(dt_qint32(val)); + return m_dst_type_cvt.quantize(med_var).as_int32(); + } +}; +template <> +struct CudaPostProcess { + CudaPostProcess(float, int, float, int){}; + inline __device__ int operator()(int val) { return val; } +}; + +template <> +struct CudaPostProcess { + using SrcType = dtype::QuantizedS4; + using DstType = dtype::QuantizedS4; + CudaDTypeParamImpl m_dst_type_cvt; + CudaDTypeParamImpl m_src_type_cvt; + CudaPostProcess(float src_scale, uint8_t, float dst_scale, uint8_t) { + m_dst_type_cvt = CudaDTypeParamImpl(dst_scale); + m_src_type_cvt = CudaDTypeParamImpl(src_scale); + } + inline __device__ int8_t operator()(int8_t val) { + float intermediate = m_src_type_cvt.dequantize(dt_qint4(val)); + return m_dst_type_cvt.quantize(intermediate).as_int8(); + } +}; + +template <> +struct CudaPostProcess { + using SrcType = dtype::QuantizedS4; + using DstType = 
dtype::QuantizedS4; + CudaPostProcess(float, uint8_t, float, uint8_t){}; + inline __device__ int8_t operator()(int8_t val) { return val; } +}; + +template <> +struct CudaPostProcess { + using SrcType = dtype::Quantized4Asymm; + using DstType = dtype::Quantized4Asymm; + CudaDTypeParamImpl m_dst_type_cvt; + CudaDTypeParamImpl m_src_type_cvt; + CudaPostProcess(float src_scale, uint8_t src_zero_point, float dst_scale, + uint8_t dst_zero_point) { + m_dst_type_cvt = + CudaDTypeParamImpl(dst_scale, dst_zero_point); + m_src_type_cvt = + CudaDTypeParamImpl(src_scale, src_zero_point); + }; + inline __device__ uint8_t operator()(uint8_t val) { + float intermediate = m_src_type_cvt.dequantize(dt_quint4(val)); + return m_dst_type_cvt.quantize(intermediate).as_uint8(); + } +}; + +template <> +struct CudaPostProcess { + using SrcType = dtype::Quantized4Asymm; + using DstType = dtype::Quantized4Asymm; + uint8_t m_src_zero_point = 0; + uint8_t m_dst_zero_point = 0; + CudaPostProcess(float, uint8_t src_zero_point, float, + uint8_t dst_zero_point) { + m_src_zero_point = src_zero_point; + m_dst_zero_point = dst_zero_point; + }; + inline __device__ uint8_t operator()(uint8_t val) { + int result = val - m_src_zero_point + m_dst_zero_point; + result = result >= 0 ? result : 0; + result = result < 16 ? result : 15; + return static_cast(result); + } +}; + +} // namespace internal +} // namespace relayout_format +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/relayout_format/helper.cuh b/dnn/src/cuda/relayout_format/helper.cuh deleted file mode 100644 index 69d98b70..00000000 --- a/dnn/src/cuda/relayout_format/helper.cuh +++ /dev/null @@ -1,252 +0,0 @@ -/** - * \file dnn/src/cuda/relayout_format/helper.cuh - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
- * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. - */ - -namespace megdnn { -namespace cuda { -namespace relayout_format { - -#define devfunc __forceinline__ __device__ -template -devfunc int make_zero(int zero_point); - -template <> -devfunc int make_zero<4>(int zero_point) { - return transform_int8_to_uint4x8(zero_point, zero_point, zero_point, - zero_point, zero_point, zero_point, - zero_point, zero_point); -} - -template -struct global_load_with_zero_point; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Specializations -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// The redundant mov PTX instruction is used to enforce the compiler to -// initialize data to zero before ld.global -template -struct global_load_with_zero_point { - devfunc global_load_with_zero_point(AccessType& D, void const* ptr, - bool pred_guard, int zero_point) { - uint4* data = reinterpret_cast(&D); - - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %9, 0;\n" - " mov.b32 %0, %10;\n" - " mov.b32 %1, %10;\n" - " mov.b32 %2, %10;\n" - " mov.b32 %3, %10;\n" - " mov.b32 %4, %10;\n" - " mov.b32 %5, %10;\n" - " mov.b32 %6, %10;\n" - " mov.b32 %7, %10;\n" - " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%8];\n" - " @p ld.global.v4.u32 {%4, %5, %6, %7}, [%11];\n" - "}\n" - : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), - "=r"(data[0].w), "=r"(data[1].x), "=r"(data[1].y), - "=r"(data[1].z), "=r"(data[1].w) - : "l"(ptr), "r"((int)pred_guard), - "r"(reinterpret_cast(zero_point)), - "l"(((uint8_t*)ptr) + 16)); - } -}; - -template -struct global_load_with_zero_point { - devfunc 
global_load_with_zero_point(AccessType& D, void const* ptr, - bool pred_guard, int zero_point) { - uint4& data = reinterpret_cast(D); - - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %5, 0;\n" - " mov.b32 %0, %6;\n" - " mov.b32 %1, %6;\n" - " mov.b32 %2, %6;\n" - " mov.b32 %3, %6;\n" - " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n" - "}\n" - : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) - : "l"(ptr), "r"((int)pred_guard), - "r"(reinterpret_cast(zero_point))); - } -}; - -template -struct global_load_with_zero_point { - devfunc global_load_with_zero_point(AccessType& D, void const* ptr, - bool pred_guard, int zero_point) { - uint2& data = reinterpret_cast(D); - - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %3, 0;\n" - " mov.b32 %0, %4;\n" - " mov.b32 %1, %4;\n" - " @p ld.global.v2.u32 {%0, %1}, [%2];\n" - "}\n" - : "=r"(data.x), "=r"(data.y) - : "l"(ptr), "r"((int)pred_guard), - "r"(reinterpret_cast(zero_point))); - } -}; - -template -struct global_load_with_zero_point { - devfunc global_load_with_zero_point(AccessType& D, void const* ptr, - bool pred_guard, int zero_point) { - unsigned& data = reinterpret_cast(D); - - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %2, 0;\n" - " mov.b32 %0, %3;\n" - " @p ld.global.u32 %0, [%1];\n" - "}\n" - : "=r"(data) - : "l"(ptr), "r"((int)pred_guard), - "r"(reinterpret_cast(zero_point))); - } -}; - -template -struct global_load_with_zero_point { - devfunc global_load_with_zero_point(AccessType& D, void const* ptr, - bool pred_guard, int zero_point) { - if (pred_guard) - D = *(reinterpret_cast(ptr)); - else { - unsigned uv = reinterpret_cast(zero_point); - uint8_t& data = reinterpret_cast(D); - data = uv & 0xff; - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// -template < - /// Fragment type to store loaded data - typename AccessType, - /// The bytes of loading - int LoadBytes> -struct global_store; - 
-///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Specializations -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct global_store { - devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) { - uint4 const* data = reinterpret_cast(&D); - - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %5, 0;\n" - " @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n" - " @p st.global.v4.u32 [%6], {%7, %8, %9, %10};\n" - "}\n" - : - : "l"(ptr), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), - "r"(data[0].w), "r"((int)pred_guard), - "l"(((uint8_t*)ptr) + 16), "r"(data[1].x), "r"(data[1].y), - "r"(data[1].z), "r"(data[1].w)); - } -}; - -template -struct global_store { - devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) { - uint4 const& data = reinterpret_cast(D); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %5, 0;\n" - " @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n" - "}\n" - : - : "l"(ptr), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w), - "r"((int)pred_guard)); - } -}; - -template -struct global_store { - devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) { - uint2 const& data = reinterpret_cast(D); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %3, 0;\n" - " @p st.global.v2.u32 [%0], {%1, %2};\n" - "}\n" - : - : "l"(ptr), "r"(data.x), "r"(data.y), "r"((int)pred_guard)); - } -}; - -template -struct global_store { - devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) { - uint32_t const& data = reinterpret_cast(D); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %2, 0;\n" - " @p st.global.u32 [%0], %1;\n" - "}\n" - : - : "l"(ptr), "r"(data), "r"((int)pred_guard)); - } -}; - -template -struct global_store { - devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) { - uint16_t const& data = 
reinterpret_cast(D); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %2, 0;\n" - " @p st.global.u16 [%0], %1;\n" - "}\n" - : - : "l"(ptr), "h"(data), "r"((int)pred_guard)); - } -}; - -template -struct global_store { - devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) { - if (pred_guard) - *(reinterpret_cast(ptr)) = D; - } -}; - -#undef devfunc -} // namespace relayout_format -} // namespace cuda -} // namespace megdnn diff --git a/dnn/src/cuda/relayout_format/relayout_format.cu b/dnn/src/cuda/relayout_format/relayout_format.cu index 9c7b7b9f..02f4b09e 100644 --- a/dnn/src/cuda/relayout_format/relayout_format.cu +++ b/dnn/src/cuda/relayout_format/relayout_format.cu @@ -10,790 +10,15 @@ * implied. */ -#include "src/cuda/int_fastdiv.cuh" #include "src/cuda/query_blocksize.cuh" -#include "src/cuda/relayout_format/relayout_format.cuh" -#include "src/cuda/integer_subbyte_utils.cuh" -#include "src/cuda/memory_utils.cuh" +#include "src/cuda/relayout_format/relayout_format_kern.cuh" + using namespace megdnn; using namespace cuda; -using namespace integer_subbyte; +using namespace relayout_format; +using namespace internal; namespace { - -template -struct CudaPostProcess; - -template <> -struct CudaPostProcess { - CudaPostProcess(float, uint8_t, float, uint8_t){}; - inline __device__ int8_t operator()(uint8_t val) { return val - 128; } -}; - -template <> -struct CudaPostProcess { - CudaDTypeParamImpl m_dst_type_cvt; - CudaPostProcess(float, uint8_t, float dst_scale, uint8_t) { - m_dst_type_cvt = CudaDTypeParamImpl(dst_scale); - }; - inline __device__ int8_t operator()(uint8_t val) { - return m_dst_type_cvt.quantize((float)val - 128.f).as_int8(); - } -}; - -template <> -struct CudaPostProcess { - CudaDTypeParamImpl m_dst_type_cvt; - CudaDTypeParamImpl m_src_type_cvt; - CudaPostProcess(float src_scale, uint8_t src_zero_point, float dst_scale, - uint8_t) { - m_dst_type_cvt = CudaDTypeParamImpl(dst_scale); - m_src_type_cvt = - 
CudaDTypeParamImpl(src_scale, src_zero_point); - }; - inline __device__ int8_t operator()(uint8_t val) { - float med_var = m_src_type_cvt.dequantize(dt_quint8(val)); - return m_dst_type_cvt.quantize(med_var).as_int8(); - } -}; - -template <> -struct CudaPostProcess { - uint8_t m_src_zero_point = 0; - CudaPostProcess(float, uint8_t src_zero_point, float, uint8_t) { - m_src_zero_point = src_zero_point; - }; - inline __device__ int8_t operator()(uint8_t val) { - return val - m_src_zero_point; - } -}; - -template <> -struct CudaPostProcess { - CudaDTypeParamImpl m_dst_type_cvt; - CudaDTypeParamImpl m_src_type_cvt; - CudaPostProcess(float src_scale, uint8_t, float dst_scale, uint8_t) { - m_dst_type_cvt = CudaDTypeParamImpl(dst_scale); - m_src_type_cvt = CudaDTypeParamImpl(src_scale); - }; - inline __device__ int8_t operator()(int8_t val) { - float med_var = m_src_type_cvt.dequantize(dt_qint8(val)); - return m_dst_type_cvt.quantize(med_var).as_int8(); - } -}; - -template <> -struct CudaPostProcess { - CudaPostProcess(){}; - CudaPostProcess(float, uint8_t, float, uint8_t){}; - inline __device__ int8_t operator()(int8_t val) { return val; } -}; - -template <> -struct CudaPostProcess { - CudaDTypeParamImpl m_dst_type_cvt; - CudaDTypeParamImpl m_src_type_cvt; - CudaPostProcess(float src_scale, int, float dst_scale, int) { - m_dst_type_cvt = CudaDTypeParamImpl(dst_scale); - m_src_type_cvt = CudaDTypeParamImpl(src_scale); - }; - inline __device__ int operator()(int val) { - float med_var = m_src_type_cvt.dequantize(dt_qint32(val)); - return m_dst_type_cvt.quantize(med_var).as_int32(); - } -}; -template <> -struct CudaPostProcess { - CudaPostProcess(float, int, float, int){}; - inline __device__ int operator()(int val) { return val; } -}; - -template <> -struct CudaPostProcess { - using SrcType = dtype::QuantizedS4; - using DstType = dtype::QuantizedS4; - CudaDTypeParamImpl m_dst_type_cvt; - CudaDTypeParamImpl m_src_type_cvt; - CudaPostProcess(float src_scale, uint8_t, float 
dst_scale, uint8_t) { - m_dst_type_cvt = CudaDTypeParamImpl(dst_scale); - m_src_type_cvt = CudaDTypeParamImpl(src_scale); - } - inline __device__ int8_t operator()(int8_t val) { - float intermediate = m_src_type_cvt.dequantize(dt_qint4(val)); - return m_dst_type_cvt.quantize(intermediate).as_int8(); - } -}; - -template <> -struct CudaPostProcess { - using SrcType = dtype::QuantizedS4; - using DstType = dtype::QuantizedS4; - CudaPostProcess(float, uint8_t, float, uint8_t){}; - inline __device__ int8_t operator()(int8_t val) { return val; } -}; - -template <> -struct CudaPostProcess { - using SrcType = dtype::Quantized4Asymm; - using DstType = dtype::Quantized4Asymm; - CudaDTypeParamImpl m_dst_type_cvt; - CudaDTypeParamImpl m_src_type_cvt; - CudaPostProcess(float src_scale, uint8_t src_zero_point, float dst_scale, - uint8_t dst_zero_point) { - m_dst_type_cvt = - CudaDTypeParamImpl(dst_scale, dst_zero_point); - m_src_type_cvt = - CudaDTypeParamImpl(src_scale, src_zero_point); - }; - inline __device__ uint8_t operator()(uint8_t val) { - float intermediate = m_src_type_cvt.dequantize(dt_quint4(val)); - return m_dst_type_cvt.quantize(intermediate).as_uint8(); - } -}; - -template <> -struct CudaPostProcess { - using SrcType = dtype::Quantized4Asymm; - using DstType = dtype::Quantized4Asymm; - uint8_t m_src_zero_point = 0; - uint8_t m_dst_zero_point = 0; - CudaPostProcess(float, uint8_t src_zero_point, float, - uint8_t dst_zero_point) { - m_src_zero_point = src_zero_point; - m_dst_zero_point = dst_zero_point; - }; - inline __device__ uint8_t operator()(uint8_t val) { - int result = val - m_src_zero_point + m_dst_zero_point; - result = result >= 0 ? result : 0; - result = result < 16 ? 
result : 15; - return static_cast(result); - } -}; - -template -struct DTypeRWHelper; -template -struct DTypeRWHelper< - ctype, 1, - typename std::enable_if::value || - std::is_same::value || - std::is_same::value>::type> { - using InnerDtype = char; - using DstDtype = char4; -}; - -template -struct DTypeRWHelper< - ctype, 4, - typename std::enable_if::value || - std::is_same::value || - std::is_same::value>::type> { - using InnerDtype = char4; - using DstDtype = char4; -}; - -template <> -struct DTypeRWHelper { - using InnerDtype = int; - using DstDtype = int4; -}; - -template <> -struct DTypeRWHelper { - using InnerDtype = int4; - using DstDtype = int4; -}; - -template -struct DTypeRWHelper< - ctype, 2, - typename std::enable_if::value || - std::is_same::value>::type> { - using InnerDtype = char; - using DstDtype = array_wrapper; -}; - -template -struct DTypeRWHelper< - ctype, 8, - typename std::enable_if::value || - std::is_same::value>::type> { - using InnerDtype = unsigned; - using DstDtype = array_wrapper; -}; - -template -struct Translayout { - using InnerDtype = - typename DTypeRWHelper::ctype, - pack_w>::InnerDtype; - using DstDtype = - typename DTypeRWHelper::ctype, - pack_w>::DstDtype; - static inline __device__ void trans(DstDtype (&dst_width)[pack_w], - InnerDtype (&read_channel)[pack_c], - const char zero_point); -}; - -template -struct Translayout<1, 4, SrcType, DnnSrcType, DnnDstType, same_scale> { - using InnerDtype = - typename DTypeRWHelper::ctype, - 1>::InnerDtype; - using DstDtype = - typename DTypeRWHelper::ctype, - 1>::DstDtype; - static inline __device__ void trans( - DstDtype (&dst_width)[1], InnerDtype (&read_channel)[4], - CudaPostProcess& post_process, - const char zero_point) { - dst_width[0].x = post_process(read_channel[0]); - dst_width[0].y = post_process(read_channel[1]); - dst_width[0].z = post_process(read_channel[2]); - dst_width[0].w = post_process(read_channel[3]); - } -}; - -template -struct Translayout<4, 4, SrcType, 
DnnSrcType, DnnDstType, same_scale> { - using InnerDtype = - typename DTypeRWHelper::ctype, - 4>::InnerDtype; - using DstDtype = - typename DTypeRWHelper::ctype, - 4>::DstDtype; - static inline __device__ void trans( - DstDtype (&dst_width)[4], InnerDtype (&read_channel)[4], - CudaPostProcess& post_process, - const char zero_point) { - dst_width[0].x = post_process(read_channel[0].x); - dst_width[0].y = post_process(read_channel[1].x); - dst_width[0].z = post_process(read_channel[2].x); - dst_width[0].w = post_process(read_channel[3].x); - - dst_width[1].x = post_process(read_channel[0].y); - dst_width[1].y = post_process(read_channel[1].y); - dst_width[1].z = post_process(read_channel[2].y); - dst_width[1].w = post_process(read_channel[3].y); - - dst_width[2].x = post_process(read_channel[0].z); - dst_width[2].y = post_process(read_channel[1].z); - dst_width[2].z = post_process(read_channel[2].z); - dst_width[2].w = post_process(read_channel[3].z); - - dst_width[3].x = post_process(read_channel[0].w); - dst_width[3].y = post_process(read_channel[1].w); - dst_width[3].z = post_process(read_channel[2].w); - dst_width[3].w = post_process(read_channel[3].w); - } -}; - -#define pack_channel(_idx) \ - transform_int8_to_int4x8(post_process(intermediate[0][_idx]), \ - post_process(intermediate[1][_idx]), \ - post_process(intermediate[2][_idx]), \ - post_process(intermediate[3][_idx]), \ - post_process(intermediate[4][_idx]), \ - post_process(intermediate[5][_idx]), \ - post_process(intermediate[6][_idx]), \ - post_process(intermediate[7][_idx])); -template -struct Translayout<2, 64, SrcType, dtype::QuantizedS4, dtype::QuantizedS4, - same_scale> { - using DnnSrcType = dtype::QuantizedS4; - using DnnDstType = dtype::QuantizedS4; - using InnerDtype = - typename DTypeRWHelper::ctype, - 2>::InnerDtype; - using DstDtype = - typename DTypeRWHelper::ctype, - 2>::DstDtype; - static inline __device__ void trans( - DstDtype (&dst_width)[2], InnerDtype (&read_channel)[64], - 
CudaPostProcess& post_process, - const char zero_point) { - int intermediate[8][2]; - int* dst_frag = reinterpret_cast(dst_width); -#pragma unroll - for (int i = 0; i < 64; i += 8) { - transform_int4x2_to_int8( - intermediate[0], - reinterpret_cast(read_channel[i + 0])); - transform_int4x2_to_int8( - intermediate[1], - reinterpret_cast(read_channel[i + 1])); - transform_int4x2_to_int8( - intermediate[2], - reinterpret_cast(read_channel[i + 2])); - transform_int4x2_to_int8( - intermediate[3], - reinterpret_cast(read_channel[i + 3])); - transform_int4x2_to_int8( - intermediate[4], - reinterpret_cast(read_channel[i + 4])); - transform_int4x2_to_int8( - intermediate[5], - reinterpret_cast(read_channel[i + 5])); - transform_int4x2_to_int8( - intermediate[6], - reinterpret_cast(read_channel[i + 6])); - transform_int4x2_to_int8( - intermediate[7], - reinterpret_cast(read_channel[i + 7])); - - int frag_idx = i / 8; - dst_frag[0 * 8 + frag_idx] = pack_channel(0); - dst_frag[1 * 8 + frag_idx] = pack_channel(1); - } - } - using Fragment = array_wrapper; - static inline __device__ void trans( - Fragment& dst, Fragment& src, - CudaPostProcess& post_process) { - trans(reinterpret_cast(dst), - reinterpret_cast(src), post_process, 0); - } -}; - -template -struct Translayout<8, 64, SrcType, dtype::QuantizedS4, dtype::QuantizedS4, - same_scale> { - using DnnSrcType = dtype::QuantizedS4; - using DnnDstType = dtype::QuantizedS4; - using InnerDtype = - typename DTypeRWHelper::ctype, - 8>::InnerDtype; - using DstDtype = - typename DTypeRWHelper::ctype, - 8>::DstDtype; - static inline __device__ void trans( - DstDtype (&dst_width)[8], InnerDtype (&read_channel)[64], - CudaPostProcess& post_process, - const char zero_point) { - int intermediate[8][8]; - int* dst_frag = reinterpret_cast(dst_width); -#pragma unroll - for (int i = 0; i < 64; i += 8) { - transform_int4x8_to_int8(intermediate[0], read_channel[i + 0]); - transform_int4x8_to_int8(intermediate[1], read_channel[i + 1]); - 
transform_int4x8_to_int8(intermediate[2], read_channel[i + 2]); - transform_int4x8_to_int8(intermediate[3], read_channel[i + 3]); - transform_int4x8_to_int8(intermediate[4], read_channel[i + 4]); - transform_int4x8_to_int8(intermediate[5], read_channel[i + 5]); - transform_int4x8_to_int8(intermediate[6], read_channel[i + 6]); - transform_int4x8_to_int8(intermediate[7], read_channel[i + 7]); - int frag_idx = i / 8; - dst_frag[0 * 8 + frag_idx] = pack_channel(0); - dst_frag[1 * 8 + frag_idx] = pack_channel(1); - dst_frag[2 * 8 + frag_idx] = pack_channel(2); - dst_frag[3 * 8 + frag_idx] = pack_channel(3); - dst_frag[4 * 8 + frag_idx] = pack_channel(4); - dst_frag[5 * 8 + frag_idx] = pack_channel(5); - dst_frag[6 * 8 + frag_idx] = pack_channel(6); - dst_frag[7 * 8 + frag_idx] = pack_channel(7); - } - } - using Fragment = array_wrapper; - static inline __device__ void trans( - Fragment& dst, Fragment& src, - CudaPostProcess& post_process) { - trans(reinterpret_cast(dst), - reinterpret_cast(src), post_process, 0); - } -}; -#undef pack_channel - -#define pack_channel(_idx) \ - transform_int8_to_uint4x8(post_process(intermediate[0][_idx]), \ - post_process(intermediate[1][_idx]), \ - post_process(intermediate[2][_idx]), \ - post_process(intermediate[3][_idx]), \ - post_process(intermediate[4][_idx]), \ - post_process(intermediate[5][_idx]), \ - post_process(intermediate[6][_idx]), \ - post_process(intermediate[7][_idx])); -template -struct Translayout<2, 64, SrcType, dtype::Quantized4Asymm, - dtype::Quantized4Asymm, same_scale> { - using DnnSrcType = dtype::Quantized4Asymm; - using DnnDstType = dtype::Quantized4Asymm; - using InnerDtype = - typename DTypeRWHelper::ctype, - 2>::InnerDtype; - using DstDtype = - typename DTypeRWHelper::ctype, - 2>::DstDtype; - static inline __device__ void trans( - DstDtype (&dst_width)[2], InnerDtype (&read_channel)[64], - CudaPostProcess& post_process, - const char zero_point) { - int intermediate[8][2]; - int* dst_frag = 
reinterpret_cast(dst_width); -#pragma unroll - for (int i = 0; i < 64; i += 8) { - transform_uint4x2_to_int8( - intermediate[0], - reinterpret_cast(read_channel[i + 0])); - transform_uint4x2_to_int8( - intermediate[1], - reinterpret_cast(read_channel[i + 1])); - transform_uint4x2_to_int8( - intermediate[2], - reinterpret_cast(read_channel[i + 2])); - transform_uint4x2_to_int8( - intermediate[3], - reinterpret_cast(read_channel[i + 3])); - transform_uint4x2_to_int8( - intermediate[4], - reinterpret_cast(read_channel[i + 4])); - transform_uint4x2_to_int8( - intermediate[5], - reinterpret_cast(read_channel[i + 5])); - transform_uint4x2_to_int8( - intermediate[6], - reinterpret_cast(read_channel[i + 6])); - transform_uint4x2_to_int8( - intermediate[7], - reinterpret_cast(read_channel[i + 7])); - - int frag_idx = i / 8; - dst_frag[0 * 8 + frag_idx] = pack_channel(0); - dst_frag[1 * 8 + frag_idx] = pack_channel(1); - } - } - using Fragment = array_wrapper; - static inline __device__ void trans( - Fragment& dst, Fragment& src, - CudaPostProcess& post_process) { - trans(reinterpret_cast(dst), - reinterpret_cast(src), post_process, 0); - } -}; - -template -struct Translayout<8, 64, SrcType, dtype::Quantized4Asymm, - dtype::Quantized4Asymm, same_scale> { - using DnnSrcType = dtype::Quantized4Asymm; - using DnnDstType = dtype::Quantized4Asymm; - using InnerDtype = - typename DTypeRWHelper::ctype, - 8>::InnerDtype; - using DstDtype = - typename DTypeRWHelper::ctype, - 8>::DstDtype; - static inline __device__ void trans( - DstDtype (&dst_width)[8], InnerDtype (&read_channel)[64], - CudaPostProcess& post_process, - const char zero_point) { - int intermediate[8][8]; - int* dst_frag = reinterpret_cast(dst_width); -#pragma unroll - for (int i = 0; i < 64; i += 8) { - transform_uint4x8_to_int8(intermediate[0], read_channel[i + 0]); - transform_uint4x8_to_int8(intermediate[1], read_channel[i + 1]); - transform_uint4x8_to_int8(intermediate[2], read_channel[i + 2]); - 
transform_uint4x8_to_int8(intermediate[3], read_channel[i + 3]); - transform_uint4x8_to_int8(intermediate[4], read_channel[i + 4]); - transform_uint4x8_to_int8(intermediate[5], read_channel[i + 5]); - transform_uint4x8_to_int8(intermediate[6], read_channel[i + 6]); - transform_uint4x8_to_int8(intermediate[7], read_channel[i + 7]); - int frag_idx = i / 8; - dst_frag[0 * 8 + frag_idx] = pack_channel(0); - dst_frag[1 * 8 + frag_idx] = pack_channel(1); - dst_frag[2 * 8 + frag_idx] = pack_channel(2); - dst_frag[3 * 8 + frag_idx] = pack_channel(3); - dst_frag[4 * 8 + frag_idx] = pack_channel(4); - dst_frag[5 * 8 + frag_idx] = pack_channel(5); - dst_frag[6 * 8 + frag_idx] = pack_channel(6); - dst_frag[7 * 8 + frag_idx] = pack_channel(7); - } - } - using Fragment = array_wrapper; - static inline __device__ void trans( - Fragment& dst, Fragment& src, - CudaPostProcess& post_process) { - trans(reinterpret_cast(dst), - reinterpret_cast(src), post_process, 0); - } -}; -#undef pack_channel - -#define pack(_idx) \ - transform_int8_to_int4x8(post_process(intermediate[0][_idx]), \ - post_process(intermediate[1][_idx]), \ - post_process(intermediate[2][_idx]), \ - post_process(intermediate[3][_idx]), \ - post_process(intermediate[4][_idx]), \ - post_process(intermediate[5][_idx]), \ - post_process(intermediate[6][_idx]), \ - post_process(intermediate[7][_idx])); -template -struct Translayout<64, 8, SrcType, dtype::QuantizedS4, dtype::QuantizedS4, - same_scale> { - using DnnSrcType = dtype::QuantizedS4; - using DnnDstType = dtype::QuantizedS4; - static constexpr int row = 8; - static constexpr int col = 64; - static constexpr int size_nbits = 4; - static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType)); - static constexpr int elements_in_type = row * col_in_type; - static constexpr int inc_col = 8; - static constexpr int inc_col_in_type = - inc_col * size_nbits / (8 * sizeof(SrcType)); - using Fragment = array_wrapper; - static MEGDNN_DEVICE __forceinline__ void 
trans( - Fragment& dst, const Fragment& src, - CudaPostProcess& post_process) { - int intermediate[8][8]; - int* dst_frag = reinterpret_cast(&dst); -#pragma unroll - for (int j = 0; j < col_in_type; j += inc_col_in_type) { - transform_int4x8_to_int8( - intermediate[0], - reinterpret_cast(src[0 * col_in_type + j])); - transform_int4x8_to_int8( - intermediate[1], - reinterpret_cast(src[1 * col_in_type + j])); - transform_int4x8_to_int8( - intermediate[2], - reinterpret_cast(src[2 * col_in_type + j])); - transform_int4x8_to_int8( - intermediate[3], - reinterpret_cast(src[3 * col_in_type + j])); - transform_int4x8_to_int8( - intermediate[4], - reinterpret_cast(src[4 * col_in_type + j])); - transform_int4x8_to_int8( - intermediate[5], - reinterpret_cast(src[5 * col_in_type + j])); - transform_int4x8_to_int8( - intermediate[6], - reinterpret_cast(src[6 * col_in_type + j])); - transform_int4x8_to_int8( - intermediate[7], - reinterpret_cast(src[7 * col_in_type + j])); - dst_frag[(j / inc_col_in_type) * 8 + 0] = pack(0); - dst_frag[(j / inc_col_in_type) * 8 + 1] = pack(1); - dst_frag[(j / inc_col_in_type) * 8 + 2] = pack(2); - dst_frag[(j / inc_col_in_type) * 8 + 3] = pack(3); - dst_frag[(j / inc_col_in_type) * 8 + 4] = pack(4); - dst_frag[(j / inc_col_in_type) * 8 + 5] = pack(5); - dst_frag[(j / inc_col_in_type) * 8 + 6] = pack(6); - dst_frag[(j / inc_col_in_type) * 8 + 7] = pack(7); - } - } -}; -#undef pack - -#define pack(_idx) \ - ((post_process(intermediate[0][_idx]) & 0xf) | \ - (post_process(intermediate[1][_idx]) << 4)) -template -struct Translayout<64, 2, SrcType, dtype::QuantizedS4, dtype::QuantizedS4, - same_scale> { - using DnnSrcType = dtype::QuantizedS4; - using DnnDstType = dtype::QuantizedS4; - static constexpr int row = 2; - static constexpr int col = 64; - static constexpr int size_nbits = 4; - static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType)); - static constexpr int elements_in_type = row * col_in_type; - static constexpr int 
inc_col = 8; - static constexpr int inc_col_in_type = - inc_col * size_nbits / (8 * sizeof(SrcType)); - using Fragment = array_wrapper; - static MEGDNN_DEVICE __forceinline__ void trans( - Fragment& dst, const Fragment& src, - CudaPostProcess& post_process) { - int intermediate[2][8]; - uint8_t* dst_frag = reinterpret_cast(&dst); -#pragma unroll - for (int j = 0; j < col_in_type; j += inc_col_in_type) { - transform_int4x8_to_int8( - intermediate[0], - reinterpret_cast(src[0 * col_in_type + j])); - transform_int4x8_to_int8( - intermediate[1], - reinterpret_cast(src[1 * col_in_type + j])); - dst_frag[(j / inc_col_in_type) * 8 + 0] = pack(0); - dst_frag[(j / inc_col_in_type) * 8 + 1] = pack(1); - dst_frag[(j / inc_col_in_type) * 8 + 2] = pack(2); - dst_frag[(j / inc_col_in_type) * 8 + 3] = pack(3); - dst_frag[(j / inc_col_in_type) * 8 + 4] = pack(4); - dst_frag[(j / inc_col_in_type) * 8 + 5] = pack(5); - dst_frag[(j / inc_col_in_type) * 8 + 6] = pack(6); - dst_frag[(j / inc_col_in_type) * 8 + 7] = pack(7); - } - } -}; -#undef pack - -#define pack(_idx) \ - transform_int8_to_uint4x8(post_process(intermediate[0][_idx]), \ - post_process(intermediate[1][_idx]), \ - post_process(intermediate[2][_idx]), \ - post_process(intermediate[3][_idx]), \ - post_process(intermediate[4][_idx]), \ - post_process(intermediate[5][_idx]), \ - post_process(intermediate[6][_idx]), \ - post_process(intermediate[7][_idx])); -template -struct Translayout<64, 8, SrcType, dtype::Quantized4Asymm, - dtype::Quantized4Asymm, same_scale> { - using DnnSrcType = dtype::Quantized4Asymm; - using DnnDstType = dtype::Quantized4Asymm; - static constexpr int row = 8; - static constexpr int col = 64; - static constexpr int size_nbits = 4; - static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType)); - static constexpr int elements_in_type = row * col_in_type; - static constexpr int inc_col = 8; - static constexpr int inc_col_in_type = - inc_col * size_nbits / (8 * sizeof(SrcType)); - using 
Fragment = array_wrapper; - static MEGDNN_DEVICE __forceinline__ void trans( - Fragment& dst, const Fragment& src, - CudaPostProcess& post_process) { - int intermediate[8][8]; - int* dst_frag = reinterpret_cast(&dst); -#pragma unroll - for (int j = 0; j < col_in_type; j += inc_col_in_type) { - transform_uint4x8_to_int8( - intermediate[0], - reinterpret_cast(src[0 * col_in_type + j])); - transform_uint4x8_to_int8( - intermediate[1], - reinterpret_cast(src[1 * col_in_type + j])); - transform_uint4x8_to_int8( - intermediate[2], - reinterpret_cast(src[2 * col_in_type + j])); - transform_uint4x8_to_int8( - intermediate[3], - reinterpret_cast(src[3 * col_in_type + j])); - transform_uint4x8_to_int8( - intermediate[4], - reinterpret_cast(src[4 * col_in_type + j])); - transform_uint4x8_to_int8( - intermediate[5], - reinterpret_cast(src[5 * col_in_type + j])); - transform_uint4x8_to_int8( - intermediate[6], - reinterpret_cast(src[6 * col_in_type + j])); - transform_uint4x8_to_int8( - intermediate[7], - reinterpret_cast(src[7 * col_in_type + j])); - dst_frag[(j / inc_col_in_type) * 8 + 0] = pack(0); - dst_frag[(j / inc_col_in_type) * 8 + 1] = pack(1); - dst_frag[(j / inc_col_in_type) * 8 + 2] = pack(2); - dst_frag[(j / inc_col_in_type) * 8 + 3] = pack(3); - dst_frag[(j / inc_col_in_type) * 8 + 4] = pack(4); - dst_frag[(j / inc_col_in_type) * 8 + 5] = pack(5); - dst_frag[(j / inc_col_in_type) * 8 + 6] = pack(6); - dst_frag[(j / inc_col_in_type) * 8 + 7] = pack(7); - } - } -}; -#undef pack - -#define pack(_idx) \ - (post_process(intermediate[0][_idx]) | \ - (post_process(intermediate[1][_idx]) << 4)) -template -struct Translayout<64, 2, SrcType, dtype::Quantized4Asymm, - dtype::Quantized4Asymm, same_scale> { - using DnnSrcType = dtype::Quantized4Asymm; - using DnnDstType = dtype::Quantized4Asymm; - static constexpr int row = 2; - static constexpr int col = 64; - static constexpr int size_nbits = 4; - static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType)); - 
static constexpr int elements_in_type = row * col_in_type; - static constexpr int inc_col = 8; - static constexpr int inc_col_in_type = - inc_col * size_nbits / (8 * sizeof(SrcType)); - using Fragment = array_wrapper; - static MEGDNN_DEVICE __forceinline__ void trans( - Fragment& dst, const Fragment& src, - CudaPostProcess& post_process) { - int intermediate[2][8]; - uint8_t* dst_frag = reinterpret_cast(&dst); -#pragma unroll - for (int j = 0; j < col_in_type; j += inc_col_in_type) { - transform_uint4x8_to_int8( - intermediate[0], - reinterpret_cast(src[0 * col_in_type + j])); - transform_uint4x8_to_int8( - intermediate[1], - reinterpret_cast(src[1 * col_in_type + j])); - dst_frag[(j / inc_col_in_type) * 8 + 0] = pack(0); - dst_frag[(j / inc_col_in_type) * 8 + 1] = pack(1); - dst_frag[(j / inc_col_in_type) * 8 + 2] = pack(2); - dst_frag[(j / inc_col_in_type) * 8 + 3] = pack(3); - dst_frag[(j / inc_col_in_type) * 8 + 4] = pack(4); - dst_frag[(j / inc_col_in_type) * 8 + 5] = pack(5); - dst_frag[(j / inc_col_in_type) * 8 + 6] = pack(6); - dst_frag[(j / inc_col_in_type) * 8 + 7] = pack(7); - } - } -}; -#undef pack - -template -inline __device__ DstType make_zero_pad(const uint8_t zero_point) { - return zero_point; -} - -template <> -inline __device__ char4 make_zero_pad(const uint8_t zero_point) { - char izp = reinterpret_cast(zero_point); - return {izp, izp, izp, izp}; -} - -template <> -inline __device__ int4 make_zero_pad(const uint8_t zero_point) { - return {zero_point, zero_point, zero_point, zero_point}; -} - -template -inline __device__ int make_zero(int zero_point); - -template <> -inline __device__ int make_zero<4>(int zero_point) { - return transform_int8_to_uint4x8(zero_point, zero_point, zero_point, - zero_point, zero_point, zero_point, - zero_point, zero_point); -} - -template -inline __device__ void write_helper(DstDtype* ptr, DstDtype val) { - *ptr = val; -} - -template <> -inline __device__ void write_helper(char4* ptr, char4 val) { - int32_t* rel_ptr = 
(int32_t*)ptr; - *rel_ptr = *(int32_t*)(&val); -} - -template <> -inline __device__ void write_helper>( - array_wrapper* ptr, array_wrapper val) { - uint4 const* data = reinterpret_cast(&val); - void* ptr_ = reinterpret_cast(ptr); - asm volatile( - "{\n" - " st.global.v4.u32 [%0], {%1, %2, %3, %4};\n" - " st.global.v4.u32 [%5], {%6, %7, %8, %9};\n" - "}\n" - : - : "l"(ptr_), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), - "r"(data[0].w), "l"(((uint8_t*)ptr_) + 16), "r"(data[1].x), - "r"(data[1].y), "r"(data[1].z), "r"(data[1].w)); -} - template @@ -1063,323 +288,6 @@ __global__ void kern_nchw_nchw4_weight( } } } - -template -class TensorIteratorOverChannel { -public: - using Type = Type_; - static constexpr int pack_size = pack_size_; - static constexpr int chan_blk = chan_blk_; - static constexpr int width = width_; - static constexpr int size_nbits = size_nbits_; - static constexpr int elements_in_type = - chan_blk * width * size_nbits / (8 * sizeof(Type)); - static constexpr int lane_size_in_type = - (width * pack_size * size_nbits) / (8 * sizeof(Type)); - static constexpr int pack_size_in_type = - (pack_size * size_nbits) >= (8 * sizeof(Type)) - ? 
(pack_size * size_nbits / (8 * sizeof(Type))) - : (width * pack_size * size_nbits / (8 * sizeof(Type))); - static constexpr int pack_size_in_byte = pack_size_in_type * sizeof(Type); - using AccessType = array_wrapper; - using Fragment = array_wrapper; - - MEGDNN_HOST TensorIteratorOverChannel() - : pointer{nullptr}, chan_stride_in_elements{0}, channel{0} {} - MEGDNN_HOST TensorIteratorOverChannel(Type* pointer_, - int chan_stride_in_elements_, - int channel_, int, int) - : pointer{pointer_}, - chan_stride_in_elements{chan_stride_in_elements_}, - channel{channel_} {} - - MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) { - pointer += (c_idx / pack_size) * chan_stride_in_elements + - hw_idx * pack_size * size_nbits / (8 * sizeof(Type)); - channel -= c_idx; - } - - MEGDNN_DEVICE __forceinline__ void add_pointer_offset( - size_t offset_in_type) { - pointer += offset_in_type; - } - - MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) { - AccessType* frag_ptr = reinterpret_cast(&frag); - Type* pointer_ = pointer; -#pragma unroll - for (int i = 0; i < chan_blk; i += pack_size) { -#pragma unroll - for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) { - int frag_idx = i / pack_size * - (lane_size_in_type / pack_size_in_type) + - j; - bool guard = i < channel; - memory::global_load( - frag_ptr[frag_idx], - reinterpret_cast(pointer_ + - j * pack_size_in_type), - guard, zero_point); - } - pointer_ += chan_stride_in_elements; - } - } - - MEGDNN_DEVICE __forceinline__ void store(const Fragment& frag) { - const AccessType* frag_ptr = reinterpret_cast(&frag); - Type* pointer_ = pointer; -#pragma unroll - for (int i = 0; i < chan_blk; i += pack_size) { -#pragma unroll - for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) { - int frag_idx = i / pack_size * - (lane_size_in_type / pack_size_in_type) + - j; - bool guard = i < channel; - memory::global_store( - frag_ptr[frag_idx], - reinterpret_cast(pointer_ + - j * 
pack_size_in_type), - guard); - } - pointer_ += chan_stride_in_elements; - } - } - - - MEGDNN_DEVICE __forceinline__ void advance() { - pointer += (chan_blk / pack_size) * chan_stride_in_elements; - channel -= chan_blk; - } - -private: - Type* pointer; - int chan_stride_in_elements; - int channel; -}; - -template -class MaskedTensorIteratorOverChannel { -public: - using Type = Type_; - static constexpr int pack_size = pack_size_; - static constexpr int chan_blk = chan_blk_; - static constexpr int width = width_; - static constexpr int size_nbits = size_nbits_; - static constexpr int elements_in_type = - chan_blk * width * size_nbits / (8 * sizeof(Type)); - static constexpr int lane_size_in_type = - (width * pack_size * size_nbits) / (8 * sizeof(Type)); - static constexpr int pack_size_in_type = - (pack_size * size_nbits) >= (8 * sizeof(Type)) - ? (pack_size * size_nbits / (8 * sizeof(Type))) - : (width * pack_size * size_nbits / (8 * sizeof(Type))); - static constexpr int pack_size_in_byte = pack_size_in_type * sizeof(Type); - static constexpr int accesses = elements_in_type / pack_size_in_type; - static constexpr int mask_size = (accesses + 32 - 1) / 32; - using AccessType = array_wrapper; - using Fragment = array_wrapper; - - MEGDNN_HOST MaskedTensorIteratorOverChannel() - : pointer{nullptr}, - chan_stride_in_elements{0}, - channel{0} {} - MEGDNN_HOST MaskedTensorIteratorOverChannel( - Type* pointer_, int chan_stride_in_elements_, int channel_, - int bound_, int div_) - : pointer{pointer_}, - chan_stride_in_elements{chan_stride_in_elements_}, - channel{channel_}, - bound{bound_}, - div{uint32_t(div_)} {} - - MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) { - pointer += (c_idx / pack_size) * chan_stride_in_elements; - channel -= c_idx; -#pragma unroll - for (int i = 0; i < mask_size; ++i) { - mask[i] = 0; - } -#pragma unroll - for (int i = 0; i < chan_blk; i += pack_size) { -#pragma unroll - for (int j = 0; j < lane_size_in_type / 
pack_size_in_type; j++) { - int offset = hw_idx + j; - int h = (int)((uint32_t)(offset) / div); - int w = (int)((uint32_t)(offset) % div); - bool guard = (i < channel) && (w < bound); - int index = (i / pack_size) * - (lane_size_in_type / pack_size_in_type) + - j; - int mask_index = (index >> 5); - int mask_shift = (index & 0x1f); - mask[mask_index] |= (guard << mask_shift); - stride[j] = (h * bound + w) * pack_size * size_nbits / - (8 * sizeof(Type)); - } - } - } - - MEGDNN_DEVICE __forceinline__ void add_pointer_offset(size_t offset_in_type) { - pointer += offset_in_type; - } - - MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) { - AccessType* frag_ptr = reinterpret_cast(&frag); - Type* pointer_ = pointer; -#pragma unroll - for (int i = 0; i < chan_blk; i += pack_size) { -#pragma unroll - for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) { - int frag_idx = i / pack_size * - (lane_size_in_type / pack_size_in_type) + - j; - int mask_index = (frag_idx >> 5); - int mask_shift = (frag_idx & 0x1f); - bool guard = (mask[mask_index] & (1 << mask_shift)); - memory::global_load( - frag_ptr[frag_idx], - reinterpret_cast(pointer_ + stride[j]), guard, - zero_point); - } - pointer_ += chan_stride_in_elements; - } - } - - MEGDNN_DEVICE __forceinline__ void store(const Fragment& frag) { - const AccessType* frag_ptr = reinterpret_cast(&frag); - Type* pointer_ = pointer; -#pragma unroll - for (int i = 0; i < chan_blk; i += pack_size) { -#pragma unroll - for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) { - int frag_idx = i / pack_size * - (lane_size_in_type / pack_size_in_type) + - j; - int mask_index = (frag_idx >> 5); - int mask_shift = (frag_idx & 0x1f); - bool guard = (mask[mask_index] & (1 << mask_shift)); - memory::global_store( - frag_ptr[frag_idx], - reinterpret_cast(pointer_ + stride[j]), guard); - } - pointer_ += chan_stride_in_elements; - } - } - - MEGDNN_DEVICE __forceinline__ void advance() { - pointer += (chan_blk / 
pack_size) * chan_stride_in_elements; - channel -= chan_blk; - } - -private: - Type* pointer; - int chan_stride_in_elements; - int channel; - int bound; - Uint32Fastdiv div; - uint32_t mask[mask_size]; - size_t stride[lane_size_in_type / pack_size_in_type]; -}; - -template -struct TensorIteratorPolicy; -template -struct TensorIteratorPolicy { - using TensorIterator = - MaskedTensorIteratorOverChannel; -}; -template -struct TensorIteratorPolicy { - using TensorIterator = - TensorIteratorOverChannel; -}; - -template -struct RelayoutProblem { - using SrcIterator = SrcIterator_; - using DstIterator = DstIterator_; - using Transpose = Transpose_; - using CudaPostProcess = CudaPostProcess_; - MEGDNN_STATIC_ASSERT(SrcIterator::chan_blk == DstIterator::chan_blk, - "channel block mismatch"); - MEGDNN_STATIC_ASSERT(SrcIterator::width == DstIterator::width, - "width block mismatch"); - MEGDNN_STATIC_ASSERT(SrcIterator::size_nbits == DstIterator::size_nbits, - "size in bits of elements mismatch"); - static constexpr int pack_chan = SrcIterator::chan_blk; - static constexpr int pack_width = SrcIterator::width; - using DnnSrcType = typename CudaPostProcess::SrcType; - using DnnDstType = typename CudaPostProcess::DstType; - struct Param { - SrcIterator src_iterator; - DstIterator dst_iterator; - CudaPostProcess post_process; - int n_stride_src; - int n_stride_dst; - int batch_size; - int channels; - int hw; - int zero_point; - MEGDNN_HOST MEGDNN_DEVICE Param(SrcIterator src_iterator_, - DstIterator dst_iterator_, - CudaPostProcess post_process_, - int n_stride_src_, int n_stride_dst_, - int batch_size_, int channels_, int hw_, - int zero_point_) - : src_iterator{src_iterator_}, - dst_iterator{dst_iterator_}, - post_process{post_process_}, - n_stride_src{n_stride_src_}, - n_stride_dst{n_stride_dst_}, - batch_size{batch_size_}, - channels{channels_}, - hw{hw_}, - zero_point{zero_point_} {} - }; -}; - -template -__global__ void relayout_kern(typename RelayoutProblem_::Param param) { 
- using SrcIterator = typename RelayoutProblem_::SrcIterator; - using DstIterator = typename RelayoutProblem_::DstIterator; - static constexpr int pack_chan = RelayoutProblem_::pack_chan; - static constexpr int pack_width = RelayoutProblem_::pack_width; - const int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; - const int thread_offset = thread_idx * pack_width; - const int hw_idx = (thread_offset % param.hw); - const int nc_blks = thread_offset / param.hw; - const int c_blks = (param.channels + pack_chan - 1) / pack_chan; - const int n_idx = nc_blks / c_blks; - const int c_blk_idx = nc_blks % c_blks; - const int c_idx = c_blk_idx * pack_chan; - if (n_idx < param.batch_size) { - const int src_offset = n_idx * param.n_stride_src; - const int dst_offset = n_idx * param.n_stride_dst; - param.src_iterator.add_pointer_offset(src_offset); - param.dst_iterator.add_pointer_offset(dst_offset); - param.src_iterator.initialize(c_idx, hw_idx); - param.dst_iterator.initialize(c_idx, hw_idx); - typename SrcIterator::Fragment src_frag; - typename DstIterator::Fragment dst_frag; - int zp = make_zero(param.zero_point); - param.src_iterator.load(src_frag, zp); - RelayoutProblem_::Transpose::trans( - reinterpret_cast(dst_frag), - src_frag, param.post_process); - param.dst_iterator.store(dst_frag); - } -} } // namespace void relayout_format::relayout_format_cuda_nchw_nchwx( diff --git a/dnn/src/cuda/relayout_format/relayout_format.cuh b/dnn/src/cuda/relayout_format/relayout_format.cuh index 2d621940..fffae53f 100644 --- a/dnn/src/cuda/relayout_format/relayout_format.cuh +++ b/dnn/src/cuda/relayout_format/relayout_format.cuh @@ -39,6 +39,20 @@ void relayout_format_cuda_nchwx_nchw(const TensorND& src, const TensorND& dst, const uint8_t src_zero_point = 0, const uint8_t dst_zero_point = 0); +void relayout_format_cuda_nchw_nhwc(const TensorND& src, const TensorND& dst, + const cudaStream_t& stream, + const float src_scale = 1.f, + const float dst_scale = 1.f, + const uint8_t 
src_zero_point = 0, + const uint8_t dst_zero_point = 0); + +void relayout_format_cuda_nhwc_nchw(const TensorND& src, const TensorND& dst, + const cudaStream_t& stream, + const float src_scale = 1.f, + const float dst_scale = 1.f, + const uint8_t src_zero_point = 0, + const uint8_t dst_zero_point = 0); + void relayout_format_cuda_nchw_nchw4_weight(const TensorND& src, const TensorND& dst, const cudaStream_t& stream); diff --git a/dnn/src/cuda/relayout_format/relayout_format_kern.cuh b/dnn/src/cuda/relayout_format/relayout_format_kern.cuh new file mode 100644 index 00000000..b2ce2eb6 --- /dev/null +++ b/dnn/src/cuda/relayout_format/relayout_format_kern.cuh @@ -0,0 +1,346 @@ +/** + * \file dnn/src/cuda/relayout_format/relayout_format_kern.cuh + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include "src/cuda/int_fastdiv.cuh" +#include "src/cuda/memory_utils.cuh" +#include "src/cuda/relayout_format/translayout.cuh" + +namespace megdnn { +namespace cuda { +namespace relayout_format { +namespace internal { +using namespace memory; + +template +class TensorIteratorOverChannel { +public: + using Type = Type_; + static constexpr int pack_size = pack_size_; + static constexpr int chan_blk = chan_blk_; + static constexpr int width = width_; + static constexpr int size_nbits = size_nbits_; + static constexpr int elements_in_type = + chan_blk * width * size_nbits / (8 * sizeof(Type)); + static constexpr int lane_size_in_type = + (width * pack_size * size_nbits) / (8 * sizeof(Type)); + static constexpr int pack_size_in_type = + (pack_size * size_nbits) >= (8 * sizeof(Type)) + ? 
(pack_size * size_nbits / (8 * sizeof(Type))) + : (width * pack_size * size_nbits / (8 * sizeof(Type))); + static constexpr int pack_size_in_byte = pack_size_in_type * sizeof(Type); + using AccessType = array_wrapper; + using Fragment = array_wrapper; + + MEGDNN_HOST TensorIteratorOverChannel() + : pointer{nullptr}, chan_stride_in_elements{0}, channel{0} {} + MEGDNN_HOST TensorIteratorOverChannel(Type* pointer_, + int chan_stride_in_elements_, + int channel_, int, int) + : pointer{pointer_}, + chan_stride_in_elements{chan_stride_in_elements_}, + channel{channel_} {} + + MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) { + pointer += (c_idx / pack_size) * chan_stride_in_elements + + hw_idx * pack_size * size_nbits / (8 * sizeof(Type)); + channel -= c_idx; + } + + MEGDNN_DEVICE __forceinline__ void add_pointer_offset( + size_t offset_in_type) { + pointer += offset_in_type; + } + + MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) { + AccessType* frag_ptr = reinterpret_cast(&frag); + Type* pointer_ = pointer; +#pragma unroll + for (int i = 0; i < chan_blk; i += pack_size) { +#pragma unroll + for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) { + int frag_idx = i / pack_size * + (lane_size_in_type / pack_size_in_type) + + j; + bool guard = i < channel; + global_load( + frag_ptr[frag_idx], + reinterpret_cast(pointer_ + + j * pack_size_in_type), + guard, zero_point); + } + pointer_ += chan_stride_in_elements; + } + } + + MEGDNN_DEVICE __forceinline__ void store(const Fragment& frag) { + const AccessType* frag_ptr = reinterpret_cast(&frag); + Type* pointer_ = pointer; +#pragma unroll + for (int i = 0; i < chan_blk; i += pack_size) { +#pragma unroll + for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) { + int frag_idx = i / pack_size * + (lane_size_in_type / pack_size_in_type) + + j; + bool guard = i < channel; + global_store( + frag_ptr[frag_idx], + reinterpret_cast(pointer_ + + j * pack_size_in_type), + 
guard); + } + pointer_ += chan_stride_in_elements; + } + } + + MEGDNN_DEVICE __forceinline__ void advance() { + pointer += (chan_blk / pack_size) * chan_stride_in_elements; + channel -= chan_blk; + } + +private: + Type* pointer; + int chan_stride_in_elements; + int channel; +}; + +template +class MaskedTensorIteratorOverChannel { +public: + using Type = Type_; + static constexpr int pack_size = pack_size_; + static constexpr int chan_blk = chan_blk_; + static constexpr int width = width_; + static constexpr int size_nbits = size_nbits_; + static constexpr int elements_in_type = + chan_blk * width * size_nbits / (8 * sizeof(Type)); + static constexpr int lane_size_in_type = + (width * pack_size * size_nbits) / (8 * sizeof(Type)); + static constexpr int pack_size_in_type = + (pack_size * size_nbits) >= (8 * sizeof(Type)) + ? (pack_size * size_nbits / (8 * sizeof(Type))) + : (width * pack_size * size_nbits / (8 * sizeof(Type))); + static constexpr int pack_size_in_byte = pack_size_in_type * sizeof(Type); + static constexpr int accesses = elements_in_type / pack_size_in_type; + static constexpr int mask_size = (accesses + 32 - 1) / 32; + using AccessType = array_wrapper; + using Fragment = array_wrapper; + + MEGDNN_HOST MaskedTensorIteratorOverChannel() + : pointer{nullptr}, chan_stride_in_elements{0}, channel{0} {} + MEGDNN_HOST MaskedTensorIteratorOverChannel(Type* pointer_, + int chan_stride_in_elements_, + int channel_, int bound_, + int div_) + : pointer{pointer_}, + chan_stride_in_elements{chan_stride_in_elements_}, + channel{channel_}, + bound{bound_}, + div{uint32_t(div_)} {} + + MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) { + pointer += (c_idx / pack_size) * chan_stride_in_elements; + channel -= c_idx; + int w[lane_size_in_type / pack_size_in_type]; +#pragma unroll + for (int i = 0; i < mask_size; ++i) { + mask[i] = 0; + } +#pragma unroll + for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) { + int offset = hw_idx + j; + 
int h = (int)((uint32_t)(offset) / div); + w[j] = (int)((uint32_t)(offset) % div); + stride[j] = (h * bound + w[j]) * pack_size * size_nbits / + (8 * sizeof(Type)); + } +#pragma unroll + for (int i = 0; i < chan_blk; i += pack_size) { +#pragma unroll + for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) { + bool guard = (i < channel) && (w[j] < bound); + int index = (i / pack_size) * + (lane_size_in_type / pack_size_in_type) + + j; + int mask_index = (index >> 5); + int mask_shift = (index & 0x1f); + mask[mask_index] |= (guard << mask_shift); + } + } + } + + MEGDNN_DEVICE __forceinline__ void add_pointer_offset( + size_t offset_in_type) { + pointer += offset_in_type; + } + + MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) { + AccessType* frag_ptr = reinterpret_cast(&frag); + Type* pointer_ = pointer; +#pragma unroll + for (int i = 0; i < chan_blk; i += pack_size) { +#pragma unroll + for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) { + int frag_idx = i / pack_size * + (lane_size_in_type / pack_size_in_type) + + j; + int mask_index = (frag_idx >> 5); + int mask_shift = (frag_idx & 0x1f); + bool guard = (mask[mask_index] & (1 << mask_shift)); + global_load( + frag_ptr[frag_idx], + reinterpret_cast(pointer_ + stride[j]), guard, + zero_point); + } + pointer_ += chan_stride_in_elements; + } + } + + MEGDNN_DEVICE __forceinline__ void store(const Fragment& frag) { + const AccessType* frag_ptr = reinterpret_cast(&frag); + Type* pointer_ = pointer; +#pragma unroll + for (int i = 0; i < chan_blk; i += pack_size) { +#pragma unroll + for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) { + int frag_idx = i / pack_size * + (lane_size_in_type / pack_size_in_type) + + j; + int mask_index = (frag_idx >> 5); + int mask_shift = (frag_idx & 0x1f); + bool guard = (mask[mask_index] & (1 << mask_shift)); + global_store( + frag_ptr[frag_idx], + reinterpret_cast(pointer_ + stride[j]), guard); + } + pointer_ += 
chan_stride_in_elements; + } + } + + MEGDNN_DEVICE __forceinline__ void advance() { + pointer += (chan_blk / pack_size) * chan_stride_in_elements; + channel -= chan_blk; + } + +private: + Type* pointer; + int chan_stride_in_elements; + int channel; + int bound; + Uint32Fastdiv div; + uint32_t mask[mask_size]; + size_t stride[lane_size_in_type / pack_size_in_type]; +}; + +template +struct TensorIteratorPolicy; +template +struct TensorIteratorPolicy { + using TensorIterator = + MaskedTensorIteratorOverChannel; +}; +template +struct TensorIteratorPolicy { + using TensorIterator = + TensorIteratorOverChannel; +}; + +template +struct RelayoutProblem { + using SrcIterator = SrcIterator_; + using DstIterator = DstIterator_; + using Transpose = Transpose_; + using CudaPostProcess = CudaPostProcess_; + MEGDNN_STATIC_ASSERT(SrcIterator::chan_blk == DstIterator::chan_blk, + "channel block mismatch"); + MEGDNN_STATIC_ASSERT(SrcIterator::width == DstIterator::width, + "width block mismatch"); + MEGDNN_STATIC_ASSERT(SrcIterator::size_nbits == DstIterator::size_nbits, + "size in bits of elements mismatch"); + static constexpr int pack_chan = SrcIterator::chan_blk; + static constexpr int pack_width = SrcIterator::width; + using DnnSrcType = typename CudaPostProcess::SrcType; + using DnnDstType = typename CudaPostProcess::DstType; + struct Param { + SrcIterator src_iterator; + DstIterator dst_iterator; + CudaPostProcess post_process; + int n_stride_src; + int n_stride_dst; + int batch_size; + int channels; + int hw; + int zero_point; + MEGDNN_HOST MEGDNN_DEVICE Param(SrcIterator src_iterator_, + DstIterator dst_iterator_, + CudaPostProcess post_process_, + int n_stride_src_, int n_stride_dst_, + int batch_size_, int channels_, int hw_, + int zero_point_) + : src_iterator{src_iterator_}, + dst_iterator{dst_iterator_}, + post_process{post_process_}, + n_stride_src{n_stride_src_}, + n_stride_dst{n_stride_dst_}, + batch_size{batch_size_}, + channels{channels_}, + hw{hw_}, + 
zero_point{zero_point_} {} + }; +}; + +template +__global__ void relayout_kern(typename RelayoutProblem_::Param param) { + using SrcIterator = typename RelayoutProblem_::SrcIterator; + using DstIterator = typename RelayoutProblem_::DstIterator; + static constexpr int pack_chan = RelayoutProblem_::pack_chan; + static constexpr int pack_width = RelayoutProblem_::pack_width; + const int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + const int thread_offset = thread_idx * pack_width; + const int hw_idx = (thread_offset % param.hw); + const int nc_blks = thread_offset / param.hw; + const int c_blks = (param.channels + pack_chan - 1) / pack_chan; + const int n_idx = nc_blks / c_blks; + const int c_blk_idx = nc_blks % c_blks; + const int c_idx = c_blk_idx * pack_chan; + if (n_idx < param.batch_size) { + const int src_offset = n_idx * param.n_stride_src; + const int dst_offset = n_idx * param.n_stride_dst; + param.src_iterator.add_pointer_offset(src_offset); + param.dst_iterator.add_pointer_offset(dst_offset); + param.src_iterator.initialize(c_idx, hw_idx); + param.dst_iterator.initialize(c_idx, hw_idx); + typename SrcIterator::Fragment src_frag; + typename DstIterator::Fragment dst_frag; + int zp = make_zero(param.zero_point); + param.src_iterator.load(src_frag, zp); + RelayoutProblem_::Transpose::trans( + reinterpret_cast(dst_frag), + src_frag, param.post_process); + param.dst_iterator.store(dst_frag); + } +} + +} // namespace internal +} // namespace relayout_format +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/relayout_format/relayout_format_utils.cuh b/dnn/src/cuda/relayout_format/relayout_format_utils.cuh new file mode 100644 index 00000000..0df7f9da --- /dev/null +++ b/dnn/src/cuda/relayout_format/relayout_format_utils.cuh @@ -0,0 +1,128 @@ +/** + * \file dnn/src/cuda/relayout_format/relayout_format_utils.cuh + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. 
All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include "src/cuda/integer_subbyte_utils.cuh" +#include "src/cuda/relayout_format/relayout_format.cuh" + +namespace megdnn { +namespace cuda { +namespace relayout_format { +namespace internal { +template +struct DTypeRWHelper; +template +struct DTypeRWHelper< + ctype, 1, + typename std::enable_if::value || + std::is_same::value || + std::is_same::value>::type> { + using InnerDtype = char; + using DstDtype = char4; +}; + +template +struct DTypeRWHelper< + ctype, 4, + typename std::enable_if::value || + std::is_same::value || + std::is_same::value>::type> { + using InnerDtype = char4; + using DstDtype = char4; +}; + +template <> +struct DTypeRWHelper { + using InnerDtype = int; + using DstDtype = int4; +}; + +template <> +struct DTypeRWHelper { + using InnerDtype = int4; + using DstDtype = int4; +}; + +template +struct DTypeRWHelper< + ctype, 2, + typename std::enable_if::value || + std::is_same::value>::type> { + using InnerDtype = char; + using DstDtype = array_wrapper; +}; + +template +struct DTypeRWHelper< + ctype, 8, + typename std::enable_if::value || + std::is_same::value>::type> { + using InnerDtype = unsigned; + using DstDtype = array_wrapper; +}; + +template +inline __device__ DstType make_zero_pad(const uint8_t zero_point) { + return zero_point; +} + +template <> +inline __device__ char4 make_zero_pad(const uint8_t zero_point) { + char izp = reinterpret_cast(zero_point); + return {izp, izp, izp, izp}; +} + +template <> +inline __device__ int4 make_zero_pad(const uint8_t zero_point) { + return {zero_point, zero_point, zero_point, zero_point}; +} + +template +inline __device__ int make_zero(int zero_point); + +template <> +inline __device__ int make_zero<4>(int zero_point) { + return 
integer_subbyte::transform_int8_to_uint4x8( + zero_point, zero_point, zero_point, zero_point, zero_point, + zero_point, zero_point, zero_point); +} + +template +inline __device__ void write_helper(DstDtype* ptr, DstDtype val) { + *ptr = val; +} + +template <> +inline __device__ void write_helper(char4* ptr, char4 val) { + int32_t* rel_ptr = (int32_t*)ptr; + *rel_ptr = *(int32_t*)(&val); +} + +template <> +inline __device__ void write_helper>( + array_wrapper* ptr, array_wrapper val) { + uint4 const* data = reinterpret_cast(&val); + void* ptr_ = reinterpret_cast(ptr); + asm volatile( + "{\n" + " st.global.v4.u32 [%0], {%1, %2, %3, %4};\n" + " st.global.v4.u32 [%5], {%6, %7, %8, %9};\n" + "}\n" + : + : "l"(ptr_), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), + "r"(data[0].w), "l"(((uint8_t*)ptr_) + 16), "r"(data[1].x), + "r"(data[1].y), "r"(data[1].z), "r"(data[1].w)); +} + +} // namespace internal +} // namespace relayout_format +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/relayout_format/translayout.cuh b/dnn/src/cuda/relayout_format/translayout.cuh new file mode 100644 index 00000000..d1cc8238 --- /dev/null +++ b/dnn/src/cuda/relayout_format/translayout.cuh @@ -0,0 +1,537 @@ +/** + * \file dnn/src/cuda/relayout_format/translayout.cuh + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#pragma once +#include "src/cuda/integer_subbyte_utils.cuh" +#include "src/cuda/relayout_format/cuda_post_process.cuh" +#include "src/cuda/relayout_format/relayout_format.cuh" +#include "src/cuda/relayout_format/relayout_format_utils.cuh" + +namespace megdnn { +namespace cuda { +namespace relayout_format { +namespace internal { +using namespace integer_subbyte; + +template +struct qtype_signedness; + +template <> +struct qtype_signedness { + static constexpr bool value = true; +}; + +template <> +struct qtype_signedness { + static constexpr bool value = false; +}; + +template +struct enable_qtype_b4 { + static constexpr bool val_src = + std::is_same::value || + std::is_same::value; + static constexpr bool val_dst = + std::is_same::value || + std::is_same::value; + using type = typename std::enable_if::value && + val_src && val_dst>::type; +}; + +// The input fragment is stored in RowMajor order. The translayout operator +// performs a transpose operation on the input fragment, and produces a +// reordered fragment, i.e. a fragment stored in ColumnMajor order. 
+template +struct Translayout; + +// partial specialization for translayout operator for qint8 and quint8 +template +struct Translayout<1, 4, SrcType, DnnSrcType, DnnDstType, same_scale> { + using InnerDtype = + typename DTypeRWHelper::ctype, + 1>::InnerDtype; + using DstDtype = + typename DTypeRWHelper::ctype, + 1>::DstDtype; + static inline __device__ void trans( + DstDtype (&dst_width)[1], InnerDtype (&read_channel)[4], + CudaPostProcess& post_process, + const char zero_point) { + dst_width[0].x = post_process(read_channel[0]); + dst_width[0].y = post_process(read_channel[1]); + dst_width[0].z = post_process(read_channel[2]); + dst_width[0].w = post_process(read_channel[3]); + } +}; + +template +struct Translayout<4, 4, SrcType, DnnSrcType, DnnDstType, same_scale> { + using InnerDtype = + typename DTypeRWHelper::ctype, + 4>::InnerDtype; + using DstDtype = + typename DTypeRWHelper::ctype, + 4>::DstDtype; + static inline __device__ void trans( + DstDtype (&dst_width)[4], InnerDtype (&read_channel)[4], + CudaPostProcess& post_process, + const char zero_point) { + dst_width[0].x = post_process(read_channel[0].x); + dst_width[0].y = post_process(read_channel[1].x); + dst_width[0].z = post_process(read_channel[2].x); + dst_width[0].w = post_process(read_channel[3].x); + + dst_width[1].x = post_process(read_channel[0].y); + dst_width[1].y = post_process(read_channel[1].y); + dst_width[1].z = post_process(read_channel[2].y); + dst_width[1].w = post_process(read_channel[3].y); + + dst_width[2].x = post_process(read_channel[0].z); + dst_width[2].y = post_process(read_channel[1].z); + dst_width[2].z = post_process(read_channel[2].z); + dst_width[2].w = post_process(read_channel[3].z); + + dst_width[3].x = post_process(read_channel[0].w); + dst_width[3].y = post_process(read_channel[1].w); + dst_width[3].z = post_process(read_channel[2].w); + dst_width[3].w = post_process(read_channel[3].w); + } +}; + +// ========================================================= + +// 
partial specialization for translayout operator for qint4 +// NCHW <-> NCHW64 +template +struct Translayout<2, 64, SrcType, DnnSrcType_, DnnDstType_, same_scale, + typename enable_qtype_b4::type> { + using DnnSrcType = DnnSrcType_; + using DnnDstType = DnnDstType_; + using InnerDtype = + typename DTypeRWHelper::ctype, + 2>::InnerDtype; + using DstDtype = + typename DTypeRWHelper::ctype, + 2>::DstDtype; + static constexpr bool signedness = qtype_signedness::value; + static inline __device__ void trans( + DstDtype (&dst_width)[2], InnerDtype (&read_channel)[64], + CudaPostProcess& post_process, + const char zero_point) { + int intermediate[8][2]; + int* dst_frag = reinterpret_cast(dst_width); + auto pack_channel = [&](int idx) -> int { + return transform_int8_to_b4x8( + post_process(intermediate[0][idx]), + post_process(intermediate[1][idx]), + post_process(intermediate[2][idx]), + post_process(intermediate[3][idx]), + post_process(intermediate[4][idx]), + post_process(intermediate[5][idx]), + post_process(intermediate[6][idx]), + post_process(intermediate[7][idx])); + }; +#pragma unroll + for (int i = 0; i < 64; i += 8) { + transform_b4x2_to_int8( + intermediate[0], + reinterpret_cast(read_channel[i + 0])); + transform_b4x2_to_int8( + intermediate[1], + reinterpret_cast(read_channel[i + 1])); + transform_b4x2_to_int8( + intermediate[2], + reinterpret_cast(read_channel[i + 2])); + transform_b4x2_to_int8( + intermediate[3], + reinterpret_cast(read_channel[i + 3])); + transform_b4x2_to_int8( + intermediate[4], + reinterpret_cast(read_channel[i + 4])); + transform_b4x2_to_int8( + intermediate[5], + reinterpret_cast(read_channel[i + 5])); + transform_b4x2_to_int8( + intermediate[6], + reinterpret_cast(read_channel[i + 6])); + transform_b4x2_to_int8( + intermediate[7], + reinterpret_cast(read_channel[i + 7])); + + int frag_idx = i / 8; + dst_frag[0 * 8 + frag_idx] = pack_channel(0); + dst_frag[1 * 8 + frag_idx] = pack_channel(1); + } + } + using Fragment = array_wrapper; 
+ static inline __device__ void trans( + Fragment& dst, Fragment& src, + CudaPostProcess& post_process) { + trans(reinterpret_cast(dst), + reinterpret_cast(src), post_process, 0); + } +}; + +template +struct Translayout<8, 64, SrcType, DnnSrcType_, DnnDstType_, same_scale, + typename enable_qtype_b4::type> { + using DnnSrcType = DnnSrcType_; + using DnnDstType = DnnDstType_; + using InnerDtype = + typename DTypeRWHelper::ctype, + 8>::InnerDtype; + using DstDtype = + typename DTypeRWHelper::ctype, + 8>::DstDtype; + static constexpr bool signedness = qtype_signedness::value; + static inline __device__ void trans( + DstDtype (&dst_width)[8], InnerDtype (&read_channel)[64], + CudaPostProcess& post_process, + const char zero_point) { + int intermediate[8][8]; + int* dst_frag = reinterpret_cast(dst_width); + auto pack_channel = [&](int idx) -> int { + return transform_int8_to_b4x8( + post_process(intermediate[0][idx]), + post_process(intermediate[1][idx]), + post_process(intermediate[2][idx]), + post_process(intermediate[3][idx]), + post_process(intermediate[4][idx]), + post_process(intermediate[5][idx]), + post_process(intermediate[6][idx]), + post_process(intermediate[7][idx])); + }; +#pragma unroll + for (int i = 0; i < 64; i += 8) { + transform_b4x8_to_int8(intermediate[0], + read_channel[i + 0]); + transform_b4x8_to_int8(intermediate[1], + read_channel[i + 1]); + transform_b4x8_to_int8(intermediate[2], + read_channel[i + 2]); + transform_b4x8_to_int8(intermediate[3], + read_channel[i + 3]); + transform_b4x8_to_int8(intermediate[4], + read_channel[i + 4]); + transform_b4x8_to_int8(intermediate[5], + read_channel[i + 5]); + transform_b4x8_to_int8(intermediate[6], + read_channel[i + 6]); + transform_b4x8_to_int8(intermediate[7], + read_channel[i + 7]); + int frag_idx = i / 8; + dst_frag[0 * 8 + frag_idx] = pack_channel(0); + dst_frag[1 * 8 + frag_idx] = pack_channel(1); + dst_frag[2 * 8 + frag_idx] = pack_channel(2); + dst_frag[3 * 8 + frag_idx] = pack_channel(3); + 
dst_frag[4 * 8 + frag_idx] = pack_channel(4); + dst_frag[5 * 8 + frag_idx] = pack_channel(5); + dst_frag[6 * 8 + frag_idx] = pack_channel(6); + dst_frag[7 * 8 + frag_idx] = pack_channel(7); + } + } + using Fragment = array_wrapper; + static inline __device__ void trans( + Fragment& dst, Fragment& src, + CudaPostProcess& post_process) { + trans(reinterpret_cast(dst), + reinterpret_cast(src), post_process, 0); + } +}; + +template +struct Translayout<64, 8, SrcType, DnnSrcType_, DnnDstType_, same_scale, + typename enable_qtype_b4::type> { + using DnnSrcType = DnnSrcType_; + using DnnDstType = DnnDstType_; + static constexpr int row = 8; + static constexpr int col = 64; + static constexpr int size_nbits = 4; + static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType)); + static constexpr int elements_in_type = row * col_in_type; + static constexpr int inc_col = 8; + static constexpr int inc_col_in_type = + inc_col * size_nbits / (8 * sizeof(SrcType)); + static constexpr bool signedness = qtype_signedness::value; + using Fragment = array_wrapper; + static MEGDNN_DEVICE __forceinline__ void trans( + Fragment& dst, const Fragment& src, + CudaPostProcess& post_process) { + int intermediate[8][8]; + int* dst_frag = reinterpret_cast(&dst); + auto pack = [&](int idx) -> int { + return transform_int8_to_b4x8( + post_process(intermediate[0][idx]), + post_process(intermediate[1][idx]), + post_process(intermediate[2][idx]), + post_process(intermediate[3][idx]), + post_process(intermediate[4][idx]), + post_process(intermediate[5][idx]), + post_process(intermediate[6][idx]), + post_process(intermediate[7][idx])); + }; +#pragma unroll + for (int j = 0; j < col_in_type; j += inc_col_in_type) { + transform_b4x8_to_int8( + intermediate[0], + reinterpret_cast(src[0 * col_in_type + j])); + transform_b4x8_to_int8( + intermediate[1], + reinterpret_cast(src[1 * col_in_type + j])); + transform_b4x8_to_int8( + intermediate[2], + reinterpret_cast(src[2 * col_in_type + j])); + 
transform_b4x8_to_int8( + intermediate[3], + reinterpret_cast(src[3 * col_in_type + j])); + transform_b4x8_to_int8( + intermediate[4], + reinterpret_cast(src[4 * col_in_type + j])); + transform_b4x8_to_int8( + intermediate[5], + reinterpret_cast(src[5 * col_in_type + j])); + transform_b4x8_to_int8( + intermediate[6], + reinterpret_cast(src[6 * col_in_type + j])); + transform_b4x8_to_int8( + intermediate[7], + reinterpret_cast(src[7 * col_in_type + j])); + dst_frag[(j / inc_col_in_type) * 8 + 0] = pack(0); + dst_frag[(j / inc_col_in_type) * 8 + 1] = pack(1); + dst_frag[(j / inc_col_in_type) * 8 + 2] = pack(2); + dst_frag[(j / inc_col_in_type) * 8 + 3] = pack(3); + dst_frag[(j / inc_col_in_type) * 8 + 4] = pack(4); + dst_frag[(j / inc_col_in_type) * 8 + 5] = pack(5); + dst_frag[(j / inc_col_in_type) * 8 + 6] = pack(6); + dst_frag[(j / inc_col_in_type) * 8 + 7] = pack(7); + } + } +}; + +template +struct Translayout<64, 2, SrcType, DnnSrcType_, DnnDstType_, same_scale, + typename enable_qtype_b4::type> { + using DnnSrcType = DnnSrcType_; + using DnnDstType = DnnDstType_; + static constexpr int row = 2; + static constexpr int col = 64; + static constexpr int size_nbits = 4; + static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType)); + static constexpr int elements_in_type = row * col_in_type; + static constexpr int inc_col = 8; + static constexpr int inc_col_in_type = + inc_col * size_nbits / (8 * sizeof(SrcType)); + static constexpr bool signedness = qtype_signedness::value; + using Fragment = array_wrapper; + static MEGDNN_DEVICE __forceinline__ void trans( + Fragment& dst, const Fragment& src, + CudaPostProcess& post_process) { + int intermediate[2][8]; + int* dst_frag = reinterpret_cast(&dst); +#pragma unroll + for (int j = 0; j < col_in_type; j += inc_col_in_type) { + transform_b4x8_to_int8( + intermediate[0], + reinterpret_cast(src[0 * col_in_type + j])); + transform_b4x8_to_int8( + intermediate[1], + reinterpret_cast(src[1 * col_in_type + j])); 
+ dst_frag[(j / inc_col_in_type) * 2 + 0] = + transform_int8_to_b4x8( + post_process(intermediate[0][0]), + post_process(intermediate[1][0]), + post_process(intermediate[0][1]), + post_process(intermediate[1][1]), + post_process(intermediate[0][2]), + post_process(intermediate[1][2]), + post_process(intermediate[0][3]), + post_process(intermediate[1][3])); + dst_frag[(j / inc_col_in_type) * 2 + 1] = + transform_int8_to_b4x8( + post_process(intermediate[0][4]), + post_process(intermediate[1][4]), + post_process(intermediate[0][5]), + post_process(intermediate[1][5]), + post_process(intermediate[0][6]), + post_process(intermediate[1][6]), + post_process(intermediate[0][7]), + post_process(intermediate[1][7])); + } + } +}; +// ========================================================= + +// partial specialization for translayout operator for qint4 +// NCHW <-> NHWC +template +struct Translayout<2, 8, SrcType, DnnSrcType_, DnnDstType_, same_scale, + typename enable_qtype_b4::type> { + using DnnSrcType = DnnSrcType_; + using DnnDstType = DnnDstType_; + static constexpr int row = 8; + static constexpr int col = 2; + static constexpr int size_nbits = 4; + static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType)); + static constexpr int elements_in_type = row * col_in_type; + static constexpr bool signedness = qtype_signedness::value; + using Fragment = array_wrapper; + static inline __device__ void trans( + Fragment& dst, const Fragment& src, + CudaPostProcess& post_process, + const char zero_point) { + int intermediate[8][2]; + transform_b4x2_to_int8(intermediate[0], + reinterpret_cast(src[0])); + transform_b4x2_to_int8(intermediate[1], + reinterpret_cast(src[1])); + transform_b4x2_to_int8(intermediate[2], + reinterpret_cast(src[2])); + transform_b4x2_to_int8(intermediate[3], + reinterpret_cast(src[3])); + transform_b4x2_to_int8(intermediate[4], + reinterpret_cast(src[4])); + transform_b4x2_to_int8(intermediate[5], + reinterpret_cast(src[5])); + 
transform_b4x2_to_int8(intermediate[6], + reinterpret_cast(src[6])); + transform_b4x2_to_int8(intermediate[7], + reinterpret_cast(src[7])); + + int* dst_frag = reinterpret_cast(&dst); + auto pack = [&](int idx) -> int { + return transform_int8_to_b4x8( + post_process(intermediate[0][idx]), + post_process(intermediate[1][idx]), + post_process(intermediate[2][idx]), + post_process(intermediate[3][idx]), + post_process(intermediate[4][idx]), + post_process(intermediate[5][idx]), + post_process(intermediate[6][idx]), + post_process(intermediate[7][idx])); + }; + dst_frag[0] = pack(0); + dst_frag[1] = pack(1); + } +}; + +template +struct Translayout<8, 8, SrcType, DnnSrcType_, DnnDstType_, same_scale, + typename enable_qtype_b4::type> { + using DnnSrcType = DnnSrcType_; + using DnnDstType = DnnDstType_; + static constexpr int row = 8; + static constexpr int col = 8; + static constexpr int size_nbits = 4; + static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType)); + static constexpr int elements_in_type = row * col_in_type; + static constexpr bool signedness = qtype_signedness::value; + using Fragment = array_wrapper; + static inline __device__ void trans( + Fragment& dst, const Fragment& src, + CudaPostProcess& post_process, + const char zero_point) { + int intermediate[8][8]; + transform_b4x8_to_int8( + intermediate[0], reinterpret_cast(src[0])); + transform_b4x8_to_int8( + intermediate[1], reinterpret_cast(src[1])); + transform_b4x8_to_int8( + intermediate[2], reinterpret_cast(src[2])); + transform_b4x8_to_int8( + intermediate[3], reinterpret_cast(src[3])); + transform_b4x8_to_int8( + intermediate[4], reinterpret_cast(src[4])); + transform_b4x8_to_int8( + intermediate[5], reinterpret_cast(src[5])); + transform_b4x8_to_int8( + intermediate[6], reinterpret_cast(src[6])); + transform_b4x8_to_int8( + intermediate[7], reinterpret_cast(src[7])); + int* dst_frag = reinterpret_cast(&dst); + auto pack = [&](int idx) { + return transform_int8_to_b4x8( + 
post_process(intermediate[0][idx]), + post_process(intermediate[1][idx]), + post_process(intermediate[2][idx]), + post_process(intermediate[3][idx]), + post_process(intermediate[4][idx]), + post_process(intermediate[5][idx]), + post_process(intermediate[6][idx]), + post_process(intermediate[7][idx])); + }; + dst_frag[0] = pack(0); + dst_frag[1] = pack(1); + dst_frag[2] = pack(2); + dst_frag[3] = pack(3); + dst_frag[4] = pack(4); + dst_frag[5] = pack(5); + dst_frag[6] = pack(6); + dst_frag[7] = pack(7); + } +}; + +template +struct Translayout<8, 2, SrcType, DnnSrcType_, DnnDstType_, same_scale, + typename enable_qtype_b4::type> { + using DnnSrcType = DnnSrcType_; + using DnnDstType = DnnDstType_; + static constexpr int row = 2; + static constexpr int col = 8; + static constexpr int size_nbits = 4; + static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType)); + static constexpr int elements_in_type = row * col_in_type; + static constexpr bool signedness = qtype_signedness::value; + using Fragment = array_wrapper; + static inline __device__ void trans( + Fragment& dst, const Fragment& src, + CudaPostProcess& post_process, + const char zero_point) { + int intermediate[2][8]; + transform_b4x8_to_int8( + intermediate[0], reinterpret_cast(src[0])); + transform_b4x8_to_int8( + intermediate[1], reinterpret_cast(src[1])); + int* dst_frag = reinterpret_cast(&dst); + dst_frag[0] = transform_int8_to_b4x8( + post_process(intermediate[0][0]), + post_process(intermediate[1][0]), + post_process(intermediate[0][1]), + post_process(intermediate[1][1]), + post_process(intermediate[0][2]), + post_process(intermediate[1][2]), + post_process(intermediate[0][3]), + post_process(intermediate[1][3])); + dst_frag[1] = transform_int8_to_b4x8( + post_process(intermediate[0][4]), + post_process(intermediate[1][4]), + post_process(intermediate[0][5]), + post_process(intermediate[1][5]), + post_process(intermediate[0][6]), + post_process(intermediate[1][6]), + 
post_process(intermediate[0][7]), + post_process(intermediate[1][7])); + } +}; + +} // namespace internal +} // namespace relayout_format +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/warp_perspective/forward.cu b/dnn/src/cuda/warp_perspective/forward.cu index aa43591c..88b28c13 100644 --- a/dnn/src/cuda/warp_perspective/forward.cu +++ b/dnn/src/cuda/warp_perspective/forward.cu @@ -176,60 +176,22 @@ __global__ void kern_general_nchw4(SrcVisitor src, const float* __restrict mat, } } - - - -template -MEGDNN_DEVICE __forceinline__ int transform_int8_to_bit4x8(int s0, int s1, - int s2, int s3, - int s4, int s5, - int s6, int s7); - -template <> -MEGDNN_DEVICE __forceinline__ int transform_int8_to_bit4x8( - int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) { - return transform_int8_to_int4x8(s0, s1, s2, s3, s4, s5, s6, s7); -} - -template <> -MEGDNN_DEVICE __forceinline__ int transform_int8_to_bit4x8( - int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) { - return transform_int8_to_uint4x8(s0, s1, s2, s3, s4, s5, s6, s7); -} -template -MEGDNN_DEVICE __forceinline__ void -transform_bit4x8_to_int8(int (&result)[8], const int& source); - -template <> -MEGDNN_DEVICE __forceinline__ void -transform_bit4x8_to_int8(int (&result)[8], const int& source){ - transform_int4x8_to_int8(result, source); -} - -template <> -MEGDNN_DEVICE __forceinline__ void -transform_bit4x8_to_int8(int (&result)[8], const int& source){ - transform_uint4x8_to_int8(result, source); -} - template MEGDNN_DEVICE __forceinline__ int pack_output_func( OutputConverter& output_converter, int (&s00)[8], int (&s01)[8], int (&s10)[8], int (&s11)[8], float w00, float w01, float w10, float w11) { - #define warp_perspective_transform(idx) \ - static_cast(output_converter(s00[idx] * w00 + \ - s01[idx] * w01 + \ - s10[idx] * w10 + \ - s11[idx] * w11) \ +#define warp_perspective_transform(idx) \ + static_cast(output_converter(s00[idx] * w00 + s01[idx] * w01 + \ + 
s10[idx] * w10 + s11[idx] * w11) \ .as_storage()) - return transform_int8_to_bit4x8( + return transform_int8_to_b4x8( warp_perspective_transform(0), warp_perspective_transform(1), warp_perspective_transform(2), warp_perspective_transform(3), warp_perspective_transform(4), warp_perspective_transform(5), warp_perspective_transform(6), warp_perspective_transform(7)); - #undef warp_perspective_transform +#undef warp_perspective_transform } template (s00, s[0].x); - transform_bit4x8_to_int8(s01, s[1].x); - transform_bit4x8_to_int8(s10, s[2].x); - transform_bit4x8_to_int8(s11, s[3].x); + transform_b4x8_to_int8(s00, s[0].x); + transform_b4x8_to_int8(s01, s[1].x); + transform_b4x8_to_int8(s10, s[2].x); + transform_b4x8_to_int8(s11, s[3].x); d.x = pack_output_func(output_converter, s00, s01, s10, s11, w00, w01, w10, w11); - transform_bit4x8_to_int8(s00, s[0].y); - transform_bit4x8_to_int8(s01, s[1].y); - transform_bit4x8_to_int8(s10, s[2].y); - transform_bit4x8_to_int8(s11, s[3].y); + transform_b4x8_to_int8(s00, s[0].y); + transform_b4x8_to_int8(s01, s[1].y); + transform_b4x8_to_int8(s10, s[2].y); + transform_b4x8_to_int8(s11, s[3].y); d.y = pack_output_func(output_converter, s00, s01, s10, s11, w00, w01, w10, w11); - transform_bit4x8_to_int8(s00, s[0].z); - transform_bit4x8_to_int8(s01, s[1].z); - transform_bit4x8_to_int8(s10, s[2].z); - transform_bit4x8_to_int8(s11, s[3].z); + transform_b4x8_to_int8(s00, s[0].z); + transform_b4x8_to_int8(s01, s[1].z); + transform_b4x8_to_int8(s10, s[2].z); + transform_b4x8_to_int8(s11, s[3].z); d.z = pack_output_func(output_converter, s00, s01, s10, s11, w00, w01, w10, w11); - transform_bit4x8_to_int8(s00, s[0].w); - transform_bit4x8_to_int8(s01, s[1].w); - transform_bit4x8_to_int8(s10, s[2].w); - transform_bit4x8_to_int8(s11, s[3].w); + transform_b4x8_to_int8(s00, s[0].w); + transform_b4x8_to_int8(s01, s[1].w); + transform_b4x8_to_int8(s10, s[2].w); + transform_b4x8_to_int8(s11, s[3].w); d.w = pack_output_func(output_converter, s00, s01, 
s10, s11, w00, w01, w10, w11); @@ -403,15 +365,7 @@ __global__ void kern_const_border_nchw4(SrcVisitor src, } } } -template -MEGDNN_DEVICE __forceinline__ static void transform_bit4x8_to_int8( - int (&result)[8], const int& source) { -#pragma unroll - for (int i = 0; i < 8; i++) { - result[i] = unpack_integer_4bits( - reinterpret_cast(source), (i << 2)); - } -} + template __global__ void kern_const_border_nchw64(SrcVisitor src, @@ -457,7 +411,7 @@ __global__ void kern_const_border_nchw64(SrcVisitor src, bool flag00 = okh0 && okw0, flag01 = okh0 && okw1, flag10 = okh1 && okw0, flag11 = okh1 && okw1; int8_t bval_4 = bval.as_storage() & 0xF; - int bval_8 = transform_int8_to_bit4x8( + int bval_8 = transform_int8_to_b4x8( bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4); int4 bval_int4; bval_int4.x = bval_8; @@ -488,31 +442,31 @@ __global__ void kern_const_border_nchw64(SrcVisitor src, s[3] = bval_int4; } - transform_bit4x8_to_int8(s00, s[0].x); - transform_bit4x8_to_int8(s01, s[1].x); - transform_bit4x8_to_int8(s10, s[2].x); - transform_bit4x8_to_int8(s11, s[3].x); + transform_b4x8_to_int8(s00, s[0].x); + transform_b4x8_to_int8(s01, s[1].x); + transform_b4x8_to_int8(s10, s[2].x); + transform_b4x8_to_int8(s11, s[3].x); d.x = pack_output_func(output_converter, s00, s01, s10, s11, w00, w01, w10, w11); - transform_bit4x8_to_int8(s00, s[0].y); - transform_bit4x8_to_int8(s01, s[1].y); - transform_bit4x8_to_int8(s10, s[2].y); - transform_bit4x8_to_int8(s11, s[3].y); + transform_b4x8_to_int8(s00, s[0].y); + transform_b4x8_to_int8(s01, s[1].y); + transform_b4x8_to_int8(s10, s[2].y); + transform_b4x8_to_int8(s11, s[3].y); d.y = pack_output_func(output_converter, s00, s01, s10, s11, w00, w01, w10, w11); - transform_bit4x8_to_int8(s00, s[0].z); - transform_bit4x8_to_int8(s01, s[1].z); - transform_bit4x8_to_int8(s10, s[2].z); - transform_bit4x8_to_int8(s11, s[3].z); + transform_b4x8_to_int8(s00, s[0].z); + transform_b4x8_to_int8(s01, s[1].z); + 
transform_b4x8_to_int8(s10, s[2].z); + transform_b4x8_to_int8(s11, s[3].z); d.z = pack_output_func(output_converter, s00, s01, s10, s11, w00, w01, w10, w11); - transform_bit4x8_to_int8(s00, s[0].w); - transform_bit4x8_to_int8(s01, s[1].w); - transform_bit4x8_to_int8(s10, s[2].w); - transform_bit4x8_to_int8(s11, s[3].w); + transform_b4x8_to_int8(s00, s[0].w); + transform_b4x8_to_int8(s01, s[1].w); + transform_b4x8_to_int8(s10, s[2].w); + transform_b4x8_to_int8(s11, s[3].w); d.w = pack_output_func(output_converter, s00, s01, s10, s11, w00, w01, w10, w11);