GitOrigin-RevId: ab86e66533
release-1.5
@@ -110,35 +110,33 @@ MEGDNN_DEVICE __forceinline__ static int unpack_integer_4bits(T storage,
     return (result << (shift - bits)) >> shift;
 }
 
-MEGDNN_DEVICE __forceinline__ static void transform_int4x8_to_int8(
-        int (&result)[8], const int& source) {
-#pragma unroll
-    for (int i = 0; i < 8; i++) {
-        result[i] = unpack_integer_4bits<true>(
-                reinterpret_cast<unsigned const&>(source), (i << 2));
-    }
-}
-
-MEGDNN_DEVICE __forceinline__ static void transform_uint4x8_to_int8(
+template <bool signedness>
+MEGDNN_DEVICE __forceinline__ static void transform_b4x8_to_int8(
         int (&result)[8], const int& source) {
 #pragma unroll
     for (int i = 0; i < 8; i++) {
-        result[i] = unpack_integer_4bits<false>(
+        result[i] = unpack_integer_4bits<signedness>(
                 reinterpret_cast<unsigned const&>(source), (i << 2));
     }
 }
 
-MEGDNN_DEVICE __forceinline__ static void transform_int4x2_to_int8(
+template <bool signedness>
+MEGDNN_DEVICE __forceinline__ static void transform_b4x2_to_int8(
         int (&result)[2], const uint8_t& source) {
-    result[0] = unpack_integer_4bits<true>(source, 0);
-    result[1] = unpack_integer_4bits<true>(source, 4);
+    result[0] = unpack_integer_4bits<signedness>(source, 0);
+    result[1] = unpack_integer_4bits<signedness>(source, 4);
 }
 
-MEGDNN_DEVICE __forceinline__ static void transform_uint4x2_to_int8(
-        int (&result)[2], const uint8_t& source) {
-    result[0] = unpack_integer_4bits<false>(source, 0);
-    result[1] = unpack_integer_4bits<false>(source, 4);
+template <bool signedness>
+MEGDNN_DEVICE __forceinline__ static int transform_int8_to_b4x8(
+        int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) {
+    if (signedness) {
+        return transform_int8_to_int4x8(s0, s1, s2, s3, s4, s5, s6, s7);
+    } else {
+        return transform_int8_to_uint4x8(s0, s1, s2, s3, s4, s5, s6, s7);
+    }
 }
 
 } // namespace integer_subbyte
 } // namespace cuda
 } // namespace megdnn
@@ -0,0 +1,171 @@
/**
 * \file dnn/src/cuda/relayout_format/cuda_post_process.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#pragma once

#include "src/cuda/relayout_format/relayout_format.cuh"

namespace megdnn {
namespace cuda {
namespace relayout_format {
namespace internal {

template <typename SrcType, typename DstType, bool same_scale>
struct CudaPostProcess;

template <>
struct CudaPostProcess<dtype::Uint8, dtype::QuantizedS8, true> {
    CudaPostProcess(float, uint8_t, float, uint8_t){};
    inline __device__ int8_t operator()(uint8_t val) { return val - 128; }
};

template <>
struct CudaPostProcess<dtype::Uint8, dtype::QuantizedS8, false> {
    CudaDTypeParamImpl<dt_qint8> m_dst_type_cvt;
    CudaPostProcess(float, uint8_t, float dst_scale, uint8_t) {
        m_dst_type_cvt = CudaDTypeParamImpl<dt_qint8>(dst_scale);
    };
    inline __device__ int8_t operator()(uint8_t val) {
        return m_dst_type_cvt.quantize((float)val - 128.f).as_int8();
    }
};

template <>
struct CudaPostProcess<dtype::Quantized8Asymm, dtype::QuantizedS8, false> {
    CudaDTypeParamImpl<dt_qint8> m_dst_type_cvt;
    CudaDTypeParamImpl<dt_quint8> m_src_type_cvt;
    CudaPostProcess(float src_scale, uint8_t src_zero_point, float dst_scale,
                    uint8_t) {
        m_dst_type_cvt = CudaDTypeParamImpl<dt_qint8>(dst_scale);
        m_src_type_cvt =
                CudaDTypeParamImpl<dt_quint8>(src_scale, src_zero_point);
    };
    inline __device__ int8_t operator()(uint8_t val) {
        float med_var = m_src_type_cvt.dequantize(dt_quint8(val));
        return m_dst_type_cvt.quantize(med_var).as_int8();
    }
};

template <>
struct CudaPostProcess<dtype::Quantized8Asymm, dtype::QuantizedS8, true> {
    uint8_t m_src_zero_point = 0;
    CudaPostProcess(float, uint8_t src_zero_point, float, uint8_t) {
        m_src_zero_point = src_zero_point;
    };
    inline __device__ int8_t operator()(uint8_t val) {
        return val - m_src_zero_point;
    }
};

template <>
struct CudaPostProcess<dtype::QuantizedS8, dtype::QuantizedS8, false> {
    CudaDTypeParamImpl<dt_qint8> m_dst_type_cvt;
    CudaDTypeParamImpl<dt_qint8> m_src_type_cvt;
    CudaPostProcess(float src_scale, uint8_t, float dst_scale, uint8_t) {
        m_dst_type_cvt = CudaDTypeParamImpl<dt_qint8>(dst_scale);
        m_src_type_cvt = CudaDTypeParamImpl<dt_qint8>(src_scale);
    };
    inline __device__ int8_t operator()(int8_t val) {
        float med_var = m_src_type_cvt.dequantize(dt_qint8(val));
        return m_dst_type_cvt.quantize(med_var).as_int8();
    }
};

template <>
struct CudaPostProcess<dtype::QuantizedS8, dtype::QuantizedS8, true> {
    CudaPostProcess(){};
    CudaPostProcess(float, uint8_t, float, uint8_t){};
    inline __device__ int8_t operator()(int8_t val) { return val; }
};

template <>
struct CudaPostProcess<dtype::QuantizedS32, dtype::QuantizedS32, false> {
    CudaDTypeParamImpl<dt_qint32> m_dst_type_cvt;
    CudaDTypeParamImpl<dt_qint32> m_src_type_cvt;
    CudaPostProcess(float src_scale, int, float dst_scale, int) {
        m_dst_type_cvt = CudaDTypeParamImpl<dt_qint32>(dst_scale);
        m_src_type_cvt = CudaDTypeParamImpl<dt_qint32>(src_scale);
    };
    inline __device__ int operator()(int val) {
        float med_var = m_src_type_cvt.dequantize(dt_qint32(val));
        return m_dst_type_cvt.quantize(med_var).as_int32();
    }
};

template <>
struct CudaPostProcess<dtype::QuantizedS32, dtype::QuantizedS32, true> {
    CudaPostProcess(float, int, float, int){};
    inline __device__ int operator()(int val) { return val; }
};

template <>
struct CudaPostProcess<dtype::QuantizedS4, dtype::QuantizedS4, false> {
    using SrcType = dtype::QuantizedS4;
    using DstType = dtype::QuantizedS4;
    CudaDTypeParamImpl<dt_qint4> m_dst_type_cvt;
    CudaDTypeParamImpl<dt_qint4> m_src_type_cvt;
    CudaPostProcess(float src_scale, uint8_t, float dst_scale, uint8_t) {
        m_dst_type_cvt = CudaDTypeParamImpl<dt_qint4>(dst_scale);
        m_src_type_cvt = CudaDTypeParamImpl<dt_qint4>(src_scale);
    }
    inline __device__ int8_t operator()(int8_t val) {
        float intermediate = m_src_type_cvt.dequantize(dt_qint4(val));
        return m_dst_type_cvt.quantize(intermediate).as_int8();
    }
};

template <>
struct CudaPostProcess<dtype::QuantizedS4, dtype::QuantizedS4, true> {
    using SrcType = dtype::QuantizedS4;
    using DstType = dtype::QuantizedS4;
    CudaPostProcess(float, uint8_t, float, uint8_t){};
    inline __device__ int8_t operator()(int8_t val) { return val; }
};

template <>
struct CudaPostProcess<dtype::Quantized4Asymm, dtype::Quantized4Asymm, false> {
    using SrcType = dtype::Quantized4Asymm;
    using DstType = dtype::Quantized4Asymm;
    CudaDTypeParamImpl<dt_quint4> m_dst_type_cvt;
    CudaDTypeParamImpl<dt_quint4> m_src_type_cvt;
    CudaPostProcess(float src_scale, uint8_t src_zero_point, float dst_scale,
                    uint8_t dst_zero_point) {
        m_dst_type_cvt =
                CudaDTypeParamImpl<dt_quint4>(dst_scale, dst_zero_point);
        m_src_type_cvt =
                CudaDTypeParamImpl<dt_quint4>(src_scale, src_zero_point);
    };
    inline __device__ uint8_t operator()(uint8_t val) {
        float intermediate = m_src_type_cvt.dequantize(dt_quint4(val));
        return m_dst_type_cvt.quantize(intermediate).as_uint8();
    }
};

template <>
struct CudaPostProcess<dtype::Quantized4Asymm, dtype::Quantized4Asymm, true> {
    using SrcType = dtype::Quantized4Asymm;
    using DstType = dtype::Quantized4Asymm;
    uint8_t m_src_zero_point = 0;
    uint8_t m_dst_zero_point = 0;
    CudaPostProcess(float, uint8_t src_zero_point, float,
                    uint8_t dst_zero_point) {
        m_src_zero_point = src_zero_point;
        m_dst_zero_point = dst_zero_point;
    };
    inline __device__ uint8_t operator()(uint8_t val) {
        int result = val - m_src_zero_point + m_dst_zero_point;
        result = result >= 0 ? result : 0;
        result = result < 16 ? result : 15;
        return static_cast<uint8_t>(result);
    }
};

} // namespace internal
} // namespace relayout_format
} // namespace cuda
} // namespace megdnn
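The scale-converting specializations all compute the same scalar recipe: dequantize with the source parameters, then requantize with the destination parameters. A plain host-side reference for the asymmetric-uint8 to symmetric-int8 path (a sketch assuming round-to-nearest and int8 saturation, which is what CudaDTypeParamImpl<dt_qint8>::quantize is expected to provide):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Reference for CudaPostProcess<Quantized8Asymm, QuantizedS8, false>:
// real value = (val - src_zero_point) * src_scale, requantized by dst_scale.
int8_t requantize_u8_to_s8(uint8_t val, float src_scale,
                           uint8_t src_zero_point, float dst_scale) {
    float real = (int(val) - int(src_zero_point)) * src_scale;
    int q = int(std::nearbyint(real / dst_scale));
    return int8_t(std::min(127, std::max(-128, q)));  // saturate to int8
}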
@@ -1,252 +0,0 @@
/**
 * \file dnn/src/cuda/relayout_format/helper.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */

namespace megdnn {
namespace cuda {
namespace relayout_format {

#define devfunc __forceinline__ __device__

template <int size_nbits>
devfunc int make_zero(int zero_point);

template <>
devfunc int make_zero<4>(int zero_point) {
    return transform_int8_to_uint4x8(zero_point, zero_point, zero_point,
                                     zero_point, zero_point, zero_point,
                                     zero_point, zero_point);
}

template <typename AccessType, int LoadBytes>
struct global_load_with_zero_point;

/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Specializations
//
/////////////////////////////////////////////////////////////////////////////////////////////////

// The redundant mov PTX instruction is used to enforce the compiler to
// initialize data to zero before ld.global
template <typename AccessType>
struct global_load_with_zero_point<AccessType, 32> {
    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
                                        bool pred_guard, int zero_point) {
        uint4* data = reinterpret_cast<uint4*>(&D);
        asm volatile(
                "{\n"
                " .reg .pred p;\n"
                " setp.ne.b32 p, %9, 0;\n"
                " mov.b32 %0, %10;\n"
                " mov.b32 %1, %10;\n"
                " mov.b32 %2, %10;\n"
                " mov.b32 %3, %10;\n"
                " mov.b32 %4, %10;\n"
                " mov.b32 %5, %10;\n"
                " mov.b32 %6, %10;\n"
                " mov.b32 %7, %10;\n"
                " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%8];\n"
                " @p ld.global.v4.u32 {%4, %5, %6, %7}, [%11];\n"
                "}\n"
                : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z),
                  "=r"(data[0].w), "=r"(data[1].x), "=r"(data[1].y),
                  "=r"(data[1].z), "=r"(data[1].w)
                : "l"(ptr), "r"((int)pred_guard),
                  "r"(reinterpret_cast<unsigned&>(zero_point)),
                  "l"(((uint8_t*)ptr) + 16));
    }
};

template <typename AccessType>
struct global_load_with_zero_point<AccessType, 16> {
    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
                                        bool pred_guard, int zero_point) {
        uint4& data = reinterpret_cast<uint4&>(D);
        asm volatile(
                "{\n"
                " .reg .pred p;\n"
                " setp.ne.b32 p, %5, 0;\n"
                " mov.b32 %0, %6;\n"
                " mov.b32 %1, %6;\n"
                " mov.b32 %2, %6;\n"
                " mov.b32 %3, %6;\n"
                " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n"
                "}\n"
                : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
                : "l"(ptr), "r"((int)pred_guard),
                  "r"(reinterpret_cast<unsigned&>(zero_point)));
    }
};

template <typename AccessType>
struct global_load_with_zero_point<AccessType, 8> {
    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
                                        bool pred_guard, int zero_point) {
        uint2& data = reinterpret_cast<uint2&>(D);
        asm volatile(
                "{\n"
                " .reg .pred p;\n"
                " setp.ne.b32 p, %3, 0;\n"
                " mov.b32 %0, %4;\n"
                " mov.b32 %1, %4;\n"
                " @p ld.global.v2.u32 {%0, %1}, [%2];\n"
                "}\n"
                : "=r"(data.x), "=r"(data.y)
                : "l"(ptr), "r"((int)pred_guard),
                  "r"(reinterpret_cast<unsigned&>(zero_point)));
    }
};

template <typename AccessType>
struct global_load_with_zero_point<AccessType, 4> {
    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
                                        bool pred_guard, int zero_point) {
        unsigned& data = reinterpret_cast<unsigned&>(D);
        asm volatile(
                "{\n"
                " .reg .pred p;\n"
                " setp.ne.b32 p, %2, 0;\n"
                " mov.b32 %0, %3;\n"
                " @p ld.global.u32 %0, [%1];\n"
                "}\n"
                : "=r"(data)
                : "l"(ptr), "r"((int)pred_guard),
                  "r"(reinterpret_cast<unsigned&>(zero_point)));
    }
};

template <typename AccessType>
struct global_load_with_zero_point<AccessType, 1> {
    devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
                                        bool pred_guard, int zero_point) {
        if (pred_guard)
            D = *(reinterpret_cast<AccessType const*>(ptr));
        else {
            unsigned uv = reinterpret_cast<unsigned&>(zero_point);
            uint8_t& data = reinterpret_cast<uint8_t&>(D);
            data = uv & 0xff;
        }
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
        /// Fragment type to store loaded data
        typename AccessType,
        /// The bytes of loading
        int LoadBytes>
struct global_store;

/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Specializations
//
/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename AccessType>
struct global_store<AccessType, 32> {
    devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) {
        uint4 const* data = reinterpret_cast<uint4 const*>(&D);
        asm volatile(
                "{\n"
                " .reg .pred p;\n"
                " setp.ne.b32 p, %5, 0;\n"
                " @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n"
                " @p st.global.v4.u32 [%6], {%7, %8, %9, %10};\n"
                "}\n"
                :
                : "l"(ptr), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z),
                  "r"(data[0].w), "r"((int)pred_guard),
                  "l"(((uint8_t*)ptr) + 16), "r"(data[1].x), "r"(data[1].y),
                  "r"(data[1].z), "r"(data[1].w));
    }
};

template <typename AccessType>
struct global_store<AccessType, 16> {
    devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) {
        uint4 const& data = reinterpret_cast<uint4 const&>(D);
        asm volatile(
                "{\n"
                " .reg .pred p;\n"
                " setp.ne.b32 p, %5, 0;\n"
                " @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n"
                "}\n"
                :
                : "l"(ptr), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w),
                  "r"((int)pred_guard));
    }
};

template <typename AccessType>
struct global_store<AccessType, 8> {
    devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) {
        uint2 const& data = reinterpret_cast<uint2 const&>(D);
        asm volatile(
                "{\n"
                " .reg .pred p;\n"
                " setp.ne.b32 p, %3, 0;\n"
                " @p st.global.v2.u32 [%0], {%1, %2};\n"
                "}\n"
                :
                : "l"(ptr), "r"(data.x), "r"(data.y), "r"((int)pred_guard));
    }
};

template <typename AccessType>
struct global_store<AccessType, 4> {
    devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) {
        uint32_t const& data = reinterpret_cast<uint32_t const&>(D);
        asm volatile(
                "{\n"
                " .reg .pred p;\n"
                " setp.ne.b32 p, %2, 0;\n"
                " @p st.global.u32 [%0], %1;\n"
                "}\n"
                :
                : "l"(ptr), "r"(data), "r"((int)pred_guard));
    }
};

template <typename AccessType>
struct global_store<AccessType, 2> {
    devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) {
        uint16_t const& data = reinterpret_cast<uint16_t const&>(D);
        asm volatile(
                "{\n"
                " .reg .pred p;\n"
                " setp.ne.b32 p, %2, 0;\n"
                " @p st.global.u16 [%0], %1;\n"
                "}\n"
                :
                : "l"(ptr), "h"(data), "r"((int)pred_guard));
    }
};

template <typename AccessType>
struct global_store<AccessType, 1> {
    devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) {
        if (pred_guard)
            *(reinterpret_cast<AccessType*>(ptr)) = D;
    }
};

#undef devfunc
} // namespace relayout_format
} // namespace cuda
} // namespace megdnn
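The deleted inline-assembly loaders all follow one pattern: fill the fragment with the replicated zero point, then conditionally overwrite it with a vector load under a PTX predicate. A plain CUDA sketch of the 16-byte case (semantics only; the asm form exists so the compiler cannot drop or reorder the zero-fill relative to ld.global):

// Semantic sketch of global_load_with_zero_point<AccessType, 16>.
template <typename AccessType>
__device__ __forceinline__ void load_or_zero_ref(AccessType& D,
                                                 const void* ptr,
                                                 bool pred_guard,
                                                 int zero_point) {
    uint4& data = reinterpret_cast<uint4&>(D);
    unsigned zp = reinterpret_cast<unsigned&>(zero_point);
    data = make_uint4(zp, zp, zp, zp);  // the four redundant mov.b32
    if (pred_guard) {
        data = *reinterpret_cast<const uint4*>(ptr);  // ld.global.v4.u32
    }
}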
@@ -39,6 +39,20 @@ void relayout_format_cuda_nchwx_nchw(const TensorND& src, const TensorND& dst,
                                      const uint8_t src_zero_point = 0,
                                      const uint8_t dst_zero_point = 0);
 
+void relayout_format_cuda_nchw_nhwc(const TensorND& src, const TensorND& dst,
+                                    const cudaStream_t& stream,
+                                    const float src_scale = 1.f,
+                                    const float dst_scale = 1.f,
+                                    const uint8_t src_zero_point = 0,
+                                    const uint8_t dst_zero_point = 0);
+
+void relayout_format_cuda_nhwc_nchw(const TensorND& src, const TensorND& dst,
+                                    const cudaStream_t& stream,
+                                    const float src_scale = 1.f,
+                                    const float dst_scale = 1.f,
+                                    const uint8_t src_zero_point = 0,
+                                    const uint8_t dst_zero_point = 0);
+
 void relayout_format_cuda_nchw_nchw4_weight(const TensorND& src,
                                             const TensorND& dst,
                                             const cudaStream_t& stream);
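A hypothetical call site for the new NCHW -> NHWC entry point (illustrative only; the scale and zero-point arguments would normally come from the quantized dtype params of the two tensors):

// src and dst are TensorNDs whose layouts already describe NCHW and NHWC.
relayout_format_cuda_nchw_nhwc(src, dst, stream, src_scale, dst_scale,
                               src_zero_point, dst_zero_point);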
@@ -0,0 +1,346 @@
/**
 * \file dnn/src/cuda/relayout_format/relayout_format_kern.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#pragma once

#include "src/cuda/int_fastdiv.cuh"
#include "src/cuda/memory_utils.cuh"
#include "src/cuda/relayout_format/translayout.cuh"

namespace megdnn {
namespace cuda {
namespace relayout_format {
namespace internal {

using namespace memory;

template <typename Type_, int pack_size_, int chan_blk_, int width_,
          int size_nbits_>
class TensorIteratorOverChannel {
public:
    using Type = Type_;
    static constexpr int pack_size = pack_size_;
    static constexpr int chan_blk = chan_blk_;
    static constexpr int width = width_;
    static constexpr int size_nbits = size_nbits_;
    static constexpr int elements_in_type =
            chan_blk * width * size_nbits / (8 * sizeof(Type));
    static constexpr int lane_size_in_type =
            (width * pack_size * size_nbits) / (8 * sizeof(Type));
    static constexpr int pack_size_in_type =
            (pack_size * size_nbits) >= (8 * sizeof(Type))
                    ? (pack_size * size_nbits / (8 * sizeof(Type)))
                    : (width * pack_size * size_nbits / (8 * sizeof(Type)));
    static constexpr int pack_size_in_byte = pack_size_in_type * sizeof(Type);
    using AccessType = array_wrapper<Type, pack_size_in_type>;
    using Fragment = array_wrapper<Type, elements_in_type>;

    MEGDNN_HOST TensorIteratorOverChannel()
            : pointer{nullptr}, chan_stride_in_elements{0}, channel{0} {}
    MEGDNN_HOST TensorIteratorOverChannel(Type* pointer_,
                                          int chan_stride_in_elements_,
                                          int channel_, int, int)
            : pointer{pointer_},
              chan_stride_in_elements{chan_stride_in_elements_},
              channel{channel_} {}

    MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) {
        pointer += (c_idx / pack_size) * chan_stride_in_elements +
                   hw_idx * pack_size * size_nbits / (8 * sizeof(Type));
        channel -= c_idx;
    }

    MEGDNN_DEVICE __forceinline__ void add_pointer_offset(
            size_t offset_in_type) {
        pointer += offset_in_type;
    }

    MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) {
        AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
        Type* pointer_ = pointer;
#pragma unroll
        for (int i = 0; i < chan_blk; i += pack_size) {
#pragma unroll
            for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) {
                int frag_idx = i / pack_size *
                                       (lane_size_in_type / pack_size_in_type) +
                               j;
                bool guard = i < channel;
                global_load<AccessType, pack_size_in_byte>(
                        frag_ptr[frag_idx],
                        reinterpret_cast<void*>(pointer_ +
                                                j * pack_size_in_type),
                        guard, zero_point);
            }
            pointer_ += chan_stride_in_elements;
        }
    }

    MEGDNN_DEVICE __forceinline__ void store(const Fragment& frag) {
        const AccessType* frag_ptr = reinterpret_cast<const AccessType*>(&frag);
        Type* pointer_ = pointer;
#pragma unroll
        for (int i = 0; i < chan_blk; i += pack_size) {
#pragma unroll
            for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) {
                int frag_idx = i / pack_size *
                                       (lane_size_in_type / pack_size_in_type) +
                               j;
                bool guard = i < channel;
                global_store<AccessType, pack_size_in_byte>(
                        frag_ptr[frag_idx],
                        reinterpret_cast<void*>(pointer_ +
                                                j * pack_size_in_type),
                        guard);
            }
            pointer_ += chan_stride_in_elements;
        }
    }

    MEGDNN_DEVICE __forceinline__ void advance() {
        pointer += (chan_blk / pack_size) * chan_stride_in_elements;
        channel -= chan_blk;
    }

private:
    Type* pointer;
    int chan_stride_in_elements;
    int channel;
};
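To make the constexpr arithmetic concrete, here is a hypothetical instantiation check (illustrative parameter values, not taken from the patch): 64-channel blocks of 4-bit data accessed through char give a 256-byte fragment and 32-byte accesses, which matches the widest global_load/global_store specialization.

// Illustrative only: Type = char, pack_size = 64, chan_blk = 64, width = 8,
// size_nbits = 4.
using Iter = TensorIteratorOverChannel<char, 64, 64, 8, 4>;
static_assert(Iter::elements_in_type == 256, "64 chan * 8 wide * 4 bit / 8");
static_assert(Iter::lane_size_in_type == 256, "one 64-channel pack, 8 wide");
static_assert(Iter::pack_size_in_type == 32, "64 * 4 bit = 32 chars");
static_assert(Iter::pack_size_in_byte == 32, "widest vectorized access");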
template <typename Type_, int pack_size_, int chan_blk_, int width_,
          int size_nbits_>
class MaskedTensorIteratorOverChannel {
public:
    using Type = Type_;
    static constexpr int pack_size = pack_size_;
    static constexpr int chan_blk = chan_blk_;
    static constexpr int width = width_;
    static constexpr int size_nbits = size_nbits_;
    static constexpr int elements_in_type =
            chan_blk * width * size_nbits / (8 * sizeof(Type));
    static constexpr int lane_size_in_type =
            (width * pack_size * size_nbits) / (8 * sizeof(Type));
    static constexpr int pack_size_in_type =
            (pack_size * size_nbits) >= (8 * sizeof(Type))
                    ? (pack_size * size_nbits / (8 * sizeof(Type)))
                    : (width * pack_size * size_nbits / (8 * sizeof(Type)));
    static constexpr int pack_size_in_byte = pack_size_in_type * sizeof(Type);
    static constexpr int accesses = elements_in_type / pack_size_in_type;
    static constexpr int mask_size = (accesses + 32 - 1) / 32;
    using AccessType = array_wrapper<Type, pack_size_in_type>;
    using Fragment = array_wrapper<Type, elements_in_type>;

    MEGDNN_HOST MaskedTensorIteratorOverChannel()
            : pointer{nullptr}, chan_stride_in_elements{0}, channel{0} {}
    MEGDNN_HOST MaskedTensorIteratorOverChannel(Type* pointer_,
                                                int chan_stride_in_elements_,
                                                int channel_, int bound_,
                                                int div_)
            : pointer{pointer_},
              chan_stride_in_elements{chan_stride_in_elements_},
              channel{channel_},
              bound{bound_},
              div{uint32_t(div_)} {}

    MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) {
        pointer += (c_idx / pack_size) * chan_stride_in_elements;
        channel -= c_idx;
        int w[lane_size_in_type / pack_size_in_type];
#pragma unroll
        for (int i = 0; i < mask_size; ++i) {
            mask[i] = 0;
        }
#pragma unroll
        for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) {
            int offset = hw_idx + j;
            int h = (int)((uint32_t)(offset) / div);
            w[j] = (int)((uint32_t)(offset) % div);
            stride[j] = (h * bound + w[j]) * pack_size * size_nbits /
                        (8 * sizeof(Type));
        }
#pragma unroll
        for (int i = 0; i < chan_blk; i += pack_size) {
#pragma unroll
            for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) {
                bool guard = (i < channel) && (w[j] < bound);
                int index = (i / pack_size) *
                                    (lane_size_in_type / pack_size_in_type) +
                            j;
                int mask_index = (index >> 5);
                int mask_shift = (index & 0x1f);
                mask[mask_index] |= (guard << mask_shift);
            }
        }
    }

    MEGDNN_DEVICE __forceinline__ void add_pointer_offset(
            size_t offset_in_type) {
        pointer += offset_in_type;
    }

    MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) {
        AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
        Type* pointer_ = pointer;
#pragma unroll
        for (int i = 0; i < chan_blk; i += pack_size) {
#pragma unroll
            for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) {
                int frag_idx = i / pack_size *
                                       (lane_size_in_type / pack_size_in_type) +
                               j;
                int mask_index = (frag_idx >> 5);
                int mask_shift = (frag_idx & 0x1f);
                bool guard = (mask[mask_index] & (1 << mask_shift));
                global_load<AccessType, pack_size_in_byte>(
                        frag_ptr[frag_idx],
                        reinterpret_cast<void*>(pointer_ + stride[j]), guard,
                        zero_point);
            }
            pointer_ += chan_stride_in_elements;
        }
    }

    MEGDNN_DEVICE __forceinline__ void store(const Fragment& frag) {
        const AccessType* frag_ptr = reinterpret_cast<const AccessType*>(&frag);
        Type* pointer_ = pointer;
#pragma unroll
        for (int i = 0; i < chan_blk; i += pack_size) {
#pragma unroll
            for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) {
                int frag_idx = i / pack_size *
                                       (lane_size_in_type / pack_size_in_type) +
                               j;
                int mask_index = (frag_idx >> 5);
                int mask_shift = (frag_idx & 0x1f);
                bool guard = (mask[mask_index] & (1 << mask_shift));
                global_store<AccessType, pack_size_in_byte>(
                        frag_ptr[frag_idx],
                        reinterpret_cast<void*>(pointer_ + stride[j]), guard);
            }
            pointer_ += chan_stride_in_elements;
        }
    }

    MEGDNN_DEVICE __forceinline__ void advance() {
        pointer += (chan_blk / pack_size) * chan_stride_in_elements;
        channel -= chan_blk;
    }

private:
    Type* pointer;
    int chan_stride_in_elements;
    int channel;
    int bound;
    Uint32Fastdiv div;
    uint32_t mask[mask_size];
    size_t stride[lane_size_in_type / pack_size_in_type];
};
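The mask bookkeeping above packs one guard bit per vector access into 32-bit words, indexed with a shift and an AND. The same bookkeeping in scalar form (reference sketch, assuming plain division in place of Uint32Fastdiv):

// Guard bit for access k lives at bit (k % 32) of word mask[k / 32].
__device__ __forceinline__ bool get_guard(const uint32_t* mask, int index) {
    return (mask[index >> 5] >> (index & 0x1f)) & 1;
}
__device__ __forceinline__ void set_guard(uint32_t* mask, int index,
                                          bool guard) {
    mask[index >> 5] |= (uint32_t(guard) << (index & 0x1f));
}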
template <bool padding_, typename Type_, int pack_size_, int chan_blk_,
          int width_, int size_nbits_>
struct TensorIteratorPolicy;

template <typename Type_, int pack_size_, int chan_blk_, int width_,
          int size_nbits_>
struct TensorIteratorPolicy<true, Type_, pack_size_, chan_blk_, width_,
                            size_nbits_> {
    using TensorIterator =
            MaskedTensorIteratorOverChannel<Type_, pack_size_, chan_blk_,
                                            width_, size_nbits_>;
};

template <typename Type_, int pack_size_, int chan_blk_, int width_,
          int size_nbits_>
struct TensorIteratorPolicy<false, Type_, pack_size_, chan_blk_, width_,
                            size_nbits_> {
    using TensorIterator =
            TensorIteratorOverChannel<Type_, pack_size_, chan_blk_, width_,
                                      size_nbits_>;
};

template <typename SrcIterator_, typename DstIterator_, typename Transpose_,
          typename CudaPostProcess_>
struct RelayoutProblem {
    using SrcIterator = SrcIterator_;
    using DstIterator = DstIterator_;
    using Transpose = Transpose_;
    using CudaPostProcess = CudaPostProcess_;
    MEGDNN_STATIC_ASSERT(SrcIterator::chan_blk == DstIterator::chan_blk,
                         "channel block mismatch");
    MEGDNN_STATIC_ASSERT(SrcIterator::width == DstIterator::width,
                         "width block mismatch");
    MEGDNN_STATIC_ASSERT(SrcIterator::size_nbits == DstIterator::size_nbits,
                         "size in bits of elements mismatch");
    static constexpr int pack_chan = SrcIterator::chan_blk;
    static constexpr int pack_width = SrcIterator::width;
    using DnnSrcType = typename CudaPostProcess::SrcType;
    using DnnDstType = typename CudaPostProcess::DstType;
    struct Param {
        SrcIterator src_iterator;
        DstIterator dst_iterator;
        CudaPostProcess post_process;
        int n_stride_src;
        int n_stride_dst;
        int batch_size;
        int channels;
        int hw;
        int zero_point;
        MEGDNN_HOST MEGDNN_DEVICE Param(SrcIterator src_iterator_,
                                        DstIterator dst_iterator_,
                                        CudaPostProcess post_process_,
                                        int n_stride_src_, int n_stride_dst_,
                                        int batch_size_, int channels_, int hw_,
                                        int zero_point_)
                : src_iterator{src_iterator_},
                  dst_iterator{dst_iterator_},
                  post_process{post_process_},
                  n_stride_src{n_stride_src_},
                  n_stride_dst{n_stride_dst_},
                  batch_size{batch_size_},
                  channels{channels_},
                  hw{hw_},
                  zero_point{zero_point_} {}
    };
};

template <typename RelayoutProblem_>
__global__ void relayout_kern(typename RelayoutProblem_::Param param) {
    using SrcIterator = typename RelayoutProblem_::SrcIterator;
    using DstIterator = typename RelayoutProblem_::DstIterator;
    static constexpr int pack_chan = RelayoutProblem_::pack_chan;
    static constexpr int pack_width = RelayoutProblem_::pack_width;
    const int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int thread_offset = thread_idx * pack_width;
    const int hw_idx = (thread_offset % param.hw);
    const int nc_blks = thread_offset / param.hw;
    const int c_blks = (param.channels + pack_chan - 1) / pack_chan;
    const int n_idx = nc_blks / c_blks;
    const int c_blk_idx = nc_blks % c_blks;
    const int c_idx = c_blk_idx * pack_chan;
    if (n_idx < param.batch_size) {
        const int src_offset = n_idx * param.n_stride_src;
        const int dst_offset = n_idx * param.n_stride_dst;
        param.src_iterator.add_pointer_offset(src_offset);
        param.dst_iterator.add_pointer_offset(dst_offset);
        param.src_iterator.initialize(c_idx, hw_idx);
        param.dst_iterator.initialize(c_idx, hw_idx);
        typename SrcIterator::Fragment src_frag;
        typename DstIterator::Fragment dst_frag;
        int zp = make_zero<SrcIterator::size_nbits>(param.zero_point);
        param.src_iterator.load(src_frag, zp);
        RelayoutProblem_::Transpose::trans(
                reinterpret_cast<typename SrcIterator::Fragment&>(dst_frag),
                src_frag, param.post_process);
        param.dst_iterator.store(dst_frag);
    }
}

} // namespace internal
} // namespace relayout_format
} // namespace cuda
} // namespace megdnn
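A hypothetical launch wrapper for relayout_kern (illustrative only; it assumes hw is a multiple of pack_width, which the index arithmetic above needs so that no thread straddles two (batch, channel-block) tiles):

// One thread moves pack_width pixels of one channel block of one batch item.
template <typename Problem>
void launch_relayout(typename Problem::Param param, int batch_size, int c_blks,
                     int hw, cudaStream_t stream) {
    constexpr int pack_width = Problem::pack_width;
    const int nr_threads = batch_size * c_blks * (hw / pack_width);
    const int nr_threads_per_block = 256;  // illustrative choice
    const int nr_blocks =
            (nr_threads + nr_threads_per_block - 1) / nr_threads_per_block;
    relayout_kern<Problem>
            <<<nr_blocks, nr_threads_per_block, 0, stream>>>(param);
}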
@@ -0,0 +1,128 @@
/**
 * \file dnn/src/cuda/relayout_format/relayout_format_utils.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#pragma once

#include "src/cuda/integer_subbyte_utils.cuh"
#include "src/cuda/relayout_format/relayout_format.cuh"

namespace megdnn {
namespace cuda {
namespace relayout_format {
namespace internal {

template <typename ctype, int pack_w, typename enable = void>
struct DTypeRWHelper;

template <typename ctype>
struct DTypeRWHelper<
        ctype, 1,
        typename std::enable_if<std::is_same<ctype, dt_qint8>::value ||
                                std::is_same<ctype, dt_quint8>::value ||
                                std::is_same<ctype, dt_uint8>::value>::type> {
    using InnerDtype = char;
    using DstDtype = char4;
};

template <typename ctype>
struct DTypeRWHelper<
        ctype, 4,
        typename std::enable_if<std::is_same<ctype, dt_qint8>::value ||
                                std::is_same<ctype, dt_quint8>::value ||
                                std::is_same<ctype, dt_uint8>::value>::type> {
    using InnerDtype = char4;
    using DstDtype = char4;
};

template <>
struct DTypeRWHelper<dt_qint32, 1> {
    using InnerDtype = int;
    using DstDtype = int4;
};

template <>
struct DTypeRWHelper<dt_qint32, 4> {
    using InnerDtype = int4;
    using DstDtype = int4;
};

template <typename ctype>
struct DTypeRWHelper<
        ctype, 2,
        typename std::enable_if<std::is_same<ctype, dt_qint4>::value ||
                                std::is_same<ctype, dt_quint4>::value>::type> {
    using InnerDtype = char;
    using DstDtype = array_wrapper<uint8_t, 32>;
};

template <typename ctype>
struct DTypeRWHelper<
        ctype, 8,
        typename std::enable_if<std::is_same<ctype, dt_qint4>::value ||
                                std::is_same<ctype, dt_quint4>::value>::type> {
    using InnerDtype = unsigned;
    using DstDtype = array_wrapper<uint8_t, 32>;
};

template <typename DstType>
inline __device__ DstType make_zero_pad(const uint8_t zero_point) {
    return zero_point;
}

template <>
inline __device__ char4 make_zero_pad<char4>(const uint8_t zero_point) {
    char izp = reinterpret_cast<const char&>(zero_point);
    return {izp, izp, izp, izp};
}

template <>
inline __device__ int4 make_zero_pad<int4>(const uint8_t zero_point) {
    return {zero_point, zero_point, zero_point, zero_point};
}

template <int size_nbits>
inline __device__ int make_zero(int zero_point);

template <>
inline __device__ int make_zero<4>(int zero_point) {
    return integer_subbyte::transform_int8_to_uint4x8(
            zero_point, zero_point, zero_point, zero_point, zero_point,
            zero_point, zero_point, zero_point);
}

template <typename DstDtype>
inline __device__ void write_helper(DstDtype* ptr, DstDtype val) {
    *ptr = val;
}

template <>
inline __device__ void write_helper<char4>(char4* ptr, char4 val) {
    int32_t* rel_ptr = (int32_t*)ptr;
    *rel_ptr = *(int32_t*)(&val);
}

template <>
inline __device__ void write_helper<array_wrapper<uint8_t, 32>>(
        array_wrapper<uint8_t, 32>* ptr, array_wrapper<uint8_t, 32> val) {
    uint4 const* data = reinterpret_cast<uint4 const*>(&val);
    void* ptr_ = reinterpret_cast<void*>(ptr);
    asm volatile(
            "{\n"
            " st.global.v4.u32 [%0], {%1, %2, %3, %4};\n"
            " st.global.v4.u32 [%5], {%6, %7, %8, %9};\n"
            "}\n"
            :
            : "l"(ptr_), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z),
              "r"(data[0].w), "l"(((uint8_t*)ptr_) + 16), "r"(data[1].x),
              "r"(data[1].y), "r"(data[1].z), "r"(data[1].w));
}

} // namespace internal
} // namespace relayout_format
} // namespace cuda
} // namespace megdnn
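As a concrete instance of make_zero<4>: every argument of transform_int8_to_uint4x8 lands in one 4-bit lane, so a zero point of 8 replicates to 0x88888888 regardless of lane order. A device-side spot check (hypothetical test code, not part of the patch):

#include <cassert>

__global__ void check_make_zero() {
    // 8 -> 0b1000 in each of the eight nibbles.
    int splat = megdnn::cuda::relayout_format::internal::make_zero<4>(8);
    assert(splat == int(0x88888888u));
}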
@@ -0,0 +1,537 @@
/**
 * \file dnn/src/cuda/relayout_format/translayout.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#pragma once

#include "src/cuda/integer_subbyte_utils.cuh"
#include "src/cuda/relayout_format/cuda_post_process.cuh"
#include "src/cuda/relayout_format/relayout_format.cuh"
#include "src/cuda/relayout_format/relayout_format_utils.cuh"

namespace megdnn {
namespace cuda {
namespace relayout_format {
namespace internal {

using namespace integer_subbyte;

template <typename dt>
struct qtype_signedness;

template <>
struct qtype_signedness<dtype::QuantizedS4> {
    static constexpr bool value = true;
};

template <>
struct qtype_signedness<dtype::Quantized4Asymm> {
    static constexpr bool value = false;
};

template <typename dt_src, typename dt_dst>
struct enable_qtype_b4 {
    static constexpr bool val_src =
            std::is_same<dt_src, dtype::QuantizedS4>::value ||
            std::is_same<dt_src, dtype::Quantized4Asymm>::value;
    static constexpr bool val_dst =
            std::is_same<dt_dst, dtype::QuantizedS4>::value ||
            std::is_same<dt_dst, dtype::Quantized4Asymm>::value;
    using type = typename std::enable_if<std::is_same<dt_src, dt_dst>::value &&
                                         val_src && val_dst>::type;
};

// The input fragment is stored in RowMajor order. The translayout operator
// performs a transpose operation on the input fragment, and produces a
// reordered fragment, i.e. a fragment stored in ColumnMajor order.
template <int col, int row, typename SrcType, typename DnnSrcType,
          typename DnnDstType, bool same_scale, typename enable = void>
struct Translayout;

// partial specialization for translayout operator for qint8 and quint8
template <typename SrcType, typename DnnSrcType, typename DnnDstType,
          bool same_scale>
struct Translayout<1, 4, SrcType, DnnSrcType, DnnDstType, same_scale> {
    using InnerDtype =
            typename DTypeRWHelper<typename DTypeTrait<DnnSrcType>::ctype,
                                   1>::InnerDtype;
    using DstDtype =
            typename DTypeRWHelper<typename DTypeTrait<DnnSrcType>::ctype,
                                   1>::DstDtype;
    static inline __device__ void trans(
            DstDtype (&dst_width)[1], InnerDtype (&read_channel)[4],
            CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process,
            const char zero_point) {
        dst_width[0].x = post_process(read_channel[0]);
        dst_width[0].y = post_process(read_channel[1]);
        dst_width[0].z = post_process(read_channel[2]);
        dst_width[0].w = post_process(read_channel[3]);
    }
};

template <typename SrcType, typename DnnSrcType, typename DnnDstType,
          bool same_scale>
struct Translayout<4, 4, SrcType, DnnSrcType, DnnDstType, same_scale> {
    using InnerDtype =
            typename DTypeRWHelper<typename DTypeTrait<DnnSrcType>::ctype,
                                   4>::InnerDtype;
    using DstDtype =
            typename DTypeRWHelper<typename DTypeTrait<DnnSrcType>::ctype,
                                   4>::DstDtype;
    static inline __device__ void trans(
            DstDtype (&dst_width)[4], InnerDtype (&read_channel)[4],
            CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process,
            const char zero_point) {
        dst_width[0].x = post_process(read_channel[0].x);
        dst_width[0].y = post_process(read_channel[1].x);
        dst_width[0].z = post_process(read_channel[2].x);
        dst_width[0].w = post_process(read_channel[3].x);
        dst_width[1].x = post_process(read_channel[0].y);
        dst_width[1].y = post_process(read_channel[1].y);
        dst_width[1].z = post_process(read_channel[2].y);
        dst_width[1].w = post_process(read_channel[3].y);
        dst_width[2].x = post_process(read_channel[0].z);
        dst_width[2].y = post_process(read_channel[1].z);
        dst_width[2].z = post_process(read_channel[2].z);
        dst_width[2].w = post_process(read_channel[3].z);
        dst_width[3].x = post_process(read_channel[0].w);
        dst_width[3].y = post_process(read_channel[1].w);
        dst_width[3].z = post_process(read_channel[2].w);
        dst_width[3].w = post_process(read_channel[3].w);
    }
};
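The unrolled assignments above are a plain 4x4 byte transpose: output pixel w takes its channel c from input channel c's pixel w. A compact scalar reference (illustrative only, post_process omitted for brevity):

// dst_width[w] channel c = read_channel[c] pixel w, for c, w in [0, 4).
__device__ void trans_4x4_ref(char4 (&dst_width)[4],
                              const char4 (&read_channel)[4]) {
    const int8_t* src = reinterpret_cast<const int8_t*>(read_channel);
    int8_t* dst = reinterpret_cast<int8_t*>(dst_width);
    for (int w = 0; w < 4; ++w)
        for (int c = 0; c < 4; ++c)
            dst[w * 4 + c] = src[c * 4 + w];
}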
// =========================================================
// partial specialization for translayout operator for qint4
// NCHW <-> NCHW64
template <typename SrcType, typename DnnSrcType_, typename DnnDstType_,
          bool same_scale>
struct Translayout<2, 64, SrcType, DnnSrcType_, DnnDstType_, same_scale,
                   typename enable_qtype_b4<DnnSrcType_, DnnDstType_>::type> {
    using DnnSrcType = DnnSrcType_;
    using DnnDstType = DnnDstType_;
    using InnerDtype =
            typename DTypeRWHelper<typename DTypeTrait<DnnSrcType>::ctype,
                                   2>::InnerDtype;
    using DstDtype =
            typename DTypeRWHelper<typename DTypeTrait<DnnSrcType>::ctype,
                                   2>::DstDtype;
    static constexpr bool signedness = qtype_signedness<DnnSrcType>::value;
    static inline __device__ void trans(
            DstDtype (&dst_width)[2], InnerDtype (&read_channel)[64],
            CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process,
            const char zero_point) {
        int intermediate[8][2];
        int* dst_frag = reinterpret_cast<int*>(dst_width);
        auto pack_channel = [&](int idx) -> int {
            return transform_int8_to_b4x8<signedness>(
                    post_process(intermediate[0][idx]),
                    post_process(intermediate[1][idx]),
                    post_process(intermediate[2][idx]),
                    post_process(intermediate[3][idx]),
                    post_process(intermediate[4][idx]),
                    post_process(intermediate[5][idx]),
                    post_process(intermediate[6][idx]),
                    post_process(intermediate[7][idx]));
        };
#pragma unroll
        for (int i = 0; i < 64; i += 8) {
            transform_b4x2_to_int8<signedness>(
                    intermediate[0],
                    reinterpret_cast<uint8_t&>(read_channel[i + 0]));
            transform_b4x2_to_int8<signedness>(
                    intermediate[1],
                    reinterpret_cast<uint8_t&>(read_channel[i + 1]));
            transform_b4x2_to_int8<signedness>(
                    intermediate[2],
                    reinterpret_cast<uint8_t&>(read_channel[i + 2]));
            transform_b4x2_to_int8<signedness>(
                    intermediate[3],
                    reinterpret_cast<uint8_t&>(read_channel[i + 3]));
            transform_b4x2_to_int8<signedness>(
                    intermediate[4],
                    reinterpret_cast<uint8_t&>(read_channel[i + 4]));
            transform_b4x2_to_int8<signedness>(
                    intermediate[5],
                    reinterpret_cast<uint8_t&>(read_channel[i + 5]));
            transform_b4x2_to_int8<signedness>(
                    intermediate[6],
                    reinterpret_cast<uint8_t&>(read_channel[i + 6]));
            transform_b4x2_to_int8<signedness>(
                    intermediate[7],
                    reinterpret_cast<uint8_t&>(read_channel[i + 7]));
            int frag_idx = i / 8;
            dst_frag[0 * 8 + frag_idx] = pack_channel(0);
            dst_frag[1 * 8 + frag_idx] = pack_channel(1);
        }
    }

    using Fragment = array_wrapper<SrcType, 64>;
    static inline __device__ void trans(
            Fragment& dst, Fragment& src,
            CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process) {
        trans(reinterpret_cast<DstDtype(&)[2]>(dst),
              reinterpret_cast<InnerDtype(&)[64]>(src), post_process, 0);
    }
};
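For intuition about this smallest 4-bit case (two pixels of 64 channels each): the input is 64 bytes, byte c holding the two nibbles (pixel 0, pixel 1) of channel c; the output is two 64-nibble lanes, one per pixel. A scalar reference of the reordering (illustrative only, post_process omitted):

// Input nibble (channel c, pixel w) moves to ColumnMajor position
// w * 64 + c of the output.
void trans_2x64_ref(const uint8_t (&src)[64], uint8_t (&dst)[64]) {
    for (int i = 0; i < 64; ++i) dst[i] = 0;
    for (int c = 0; c < 64; ++c) {
        for (int w = 0; w < 2; ++w) {
            int nibble = (src[c] >> (4 * w)) & 0xf;  // read (c, w)
            int out = w * 64 + c;                    // lane w, position c
            dst[out / 2] |= nibble << (4 * (out & 1));
        }
    }
}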
template <typename SrcType, typename DnnSrcType_, typename DnnDstType_, | |||||
bool same_scale> | |||||
struct Translayout<8, 64, SrcType, DnnSrcType_, DnnDstType_, same_scale, | |||||
typename enable_qtype_b4<DnnSrcType_, DnnDstType_>::type> { | |||||
using DnnSrcType = DnnSrcType_; | |||||
using DnnDstType = DnnDstType_; | |||||
using InnerDtype = | |||||
typename DTypeRWHelper<typename DTypeTrait<DnnSrcType>::ctype, | |||||
8>::InnerDtype; | |||||
using DstDtype = | |||||
typename DTypeRWHelper<typename DTypeTrait<DnnSrcType>::ctype, | |||||
8>::DstDtype; | |||||
static constexpr bool signedness = qtype_signedness<DnnSrcType>::value; | |||||
static inline __device__ void trans( | |||||
DstDtype (&dst_width)[8], InnerDtype (&read_channel)[64], | |||||
CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process, | |||||
const char zero_point) { | |||||
int intermediate[8][8]; | |||||
int* dst_frag = reinterpret_cast<int*>(dst_width); | |||||
auto pack_channel = [&](int idx) -> int { | |||||
return transform_int8_to_b4x8<signedness>( | |||||
post_process(intermediate[0][idx]), | |||||
post_process(intermediate[1][idx]), | |||||
post_process(intermediate[2][idx]), | |||||
post_process(intermediate[3][idx]), | |||||
post_process(intermediate[4][idx]), | |||||
post_process(intermediate[5][idx]), | |||||
post_process(intermediate[6][idx]), | |||||
post_process(intermediate[7][idx])); | |||||
}; | |||||
#pragma unroll | |||||
for (int i = 0; i < 64; i += 8) { | |||||
transform_b4x8_to_int8<signedness>(intermediate[0], | |||||
read_channel[i + 0]); | |||||
transform_b4x8_to_int8<signedness>(intermediate[1], | |||||
read_channel[i + 1]); | |||||
transform_b4x8_to_int8<signedness>(intermediate[2], | |||||
read_channel[i + 2]); | |||||
transform_b4x8_to_int8<signedness>(intermediate[3], | |||||
read_channel[i + 3]); | |||||
transform_b4x8_to_int8<signedness>(intermediate[4], | |||||
read_channel[i + 4]); | |||||
transform_b4x8_to_int8<signedness>(intermediate[5], | |||||
read_channel[i + 5]); | |||||
transform_b4x8_to_int8<signedness>(intermediate[6], | |||||
read_channel[i + 6]); | |||||
transform_b4x8_to_int8<signedness>(intermediate[7], | |||||
read_channel[i + 7]); | |||||
int frag_idx = i / 8; | |||||
dst_frag[0 * 8 + frag_idx] = pack_channel(0); | |||||
dst_frag[1 * 8 + frag_idx] = pack_channel(1); | |||||
dst_frag[2 * 8 + frag_idx] = pack_channel(2); | |||||
dst_frag[3 * 8 + frag_idx] = pack_channel(3); | |||||
dst_frag[4 * 8 + frag_idx] = pack_channel(4); | |||||
dst_frag[5 * 8 + frag_idx] = pack_channel(5); | |||||
dst_frag[6 * 8 + frag_idx] = pack_channel(6); | |||||
dst_frag[7 * 8 + frag_idx] = pack_channel(7); | |||||
} | |||||
} | |||||
using Fragment = array_wrapper<unsigned, 64>; | |||||
static inline __device__ void trans( | |||||
Fragment& dst, Fragment& src, | |||||
CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process) { | |||||
trans(reinterpret_cast<DstDtype(&)[8]>(dst), | |||||
reinterpret_cast<InnerDtype(&)[64]>(src), post_process, 0); | |||||
} | |||||
}; | |||||
template <typename SrcType, typename DnnSrcType_, typename DnnDstType_, | |||||
bool same_scale> | |||||
struct Translayout<64, 8, SrcType, DnnSrcType_, DnnDstType_, same_scale, | |||||
typename enable_qtype_b4<DnnSrcType_, DnnDstType_>::type> { | |||||
using DnnSrcType = DnnSrcType_; | |||||
using DnnDstType = DnnDstType_; | |||||
static constexpr int row = 8; | |||||
static constexpr int col = 64; | |||||
static constexpr int size_nbits = 4; | |||||
static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType)); | |||||
static constexpr int elements_in_type = row * col_in_type; | |||||
static constexpr int inc_col = 8; | |||||
static constexpr int inc_col_in_type = | |||||
inc_col * size_nbits / (8 * sizeof(SrcType)); | |||||
static constexpr bool signedness = qtype_signedness<DnnSrcType>::value; | |||||
using Fragment = array_wrapper<SrcType, elements_in_type>; | |||||
static MEGDNN_DEVICE __forceinline__ void trans( | |||||
Fragment& dst, const Fragment& src, | |||||
CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process) { | |||||
int intermediate[8][8]; | |||||
int* dst_frag = reinterpret_cast<int*>(&dst); | |||||
auto pack = [&](int idx) -> int { | |||||
return transform_int8_to_b4x8<signedness>( | |||||
post_process(intermediate[0][idx]), | |||||
post_process(intermediate[1][idx]), | |||||
post_process(intermediate[2][idx]), | |||||
post_process(intermediate[3][idx]), | |||||
post_process(intermediate[4][idx]), | |||||
post_process(intermediate[5][idx]), | |||||
post_process(intermediate[6][idx]), | |||||
post_process(intermediate[7][idx])); | |||||
}; | |||||
#pragma unroll | |||||
for (int j = 0; j < col_in_type; j += inc_col_in_type) { | |||||
transform_b4x8_to_int8<signedness>( | |||||
intermediate[0], | |||||
reinterpret_cast<const int&>(src[0 * col_in_type + j])); | |||||
transform_b4x8_to_int8<signedness>( | |||||
intermediate[1], | |||||
reinterpret_cast<const int&>(src[1 * col_in_type + j])); | |||||
transform_b4x8_to_int8<signedness>( | |||||
intermediate[2], | |||||
reinterpret_cast<const int&>(src[2 * col_in_type + j])); | |||||
transform_b4x8_to_int8<signedness>( | |||||
intermediate[3], | |||||
reinterpret_cast<const int&>(src[3 * col_in_type + j])); | |||||
transform_b4x8_to_int8<signedness>( | |||||
intermediate[4], | |||||
reinterpret_cast<const int&>(src[4 * col_in_type + j])); | |||||
transform_b4x8_to_int8<signedness>( | |||||
intermediate[5], | |||||
reinterpret_cast<const int&>(src[5 * col_in_type + j])); | |||||
transform_b4x8_to_int8<signedness>( | |||||
intermediate[6], | |||||
reinterpret_cast<const int&>(src[6 * col_in_type + j])); | |||||
transform_b4x8_to_int8<signedness>( | |||||
intermediate[7], | |||||
reinterpret_cast<const int&>(src[7 * col_in_type + j])); | |||||
dst_frag[(j / inc_col_in_type) * 8 + 0] = pack(0); | |||||
dst_frag[(j / inc_col_in_type) * 8 + 1] = pack(1); | |||||
dst_frag[(j / inc_col_in_type) * 8 + 2] = pack(2); | |||||
dst_frag[(j / inc_col_in_type) * 8 + 3] = pack(3); | |||||
dst_frag[(j / inc_col_in_type) * 8 + 4] = pack(4); | |||||
dst_frag[(j / inc_col_in_type) * 8 + 5] = pack(5); | |||||
dst_frag[(j / inc_col_in_type) * 8 + 6] = pack(6); | |||||
dst_frag[(j / inc_col_in_type) * 8 + 7] = pack(7); | |||||
} | |||||
} | |||||
}; | |||||
template <typename SrcType, typename DnnSrcType_, typename DnnDstType_, | |||||
bool same_scale> | |||||
struct Translayout<64, 2, SrcType, DnnSrcType_, DnnDstType_, same_scale, | |||||
typename enable_qtype_b4<DnnSrcType_, DnnDstType_>::type> { | |||||
using DnnSrcType = DnnSrcType_; | |||||
using DnnDstType = DnnDstType_; | |||||
static constexpr int row = 2; | |||||
static constexpr int col = 64; | |||||
static constexpr int size_nbits = 4; | |||||
static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType)); | |||||
static constexpr int elements_in_type = row * col_in_type; | |||||
static constexpr int inc_col = 8; | |||||
static constexpr int inc_col_in_type = | |||||
inc_col * size_nbits / (8 * sizeof(SrcType)); | |||||
static constexpr bool signedness = qtype_signedness<DnnSrcType>::value; | |||||
using Fragment = array_wrapper<SrcType, elements_in_type>; | |||||
static MEGDNN_DEVICE __forceinline__ void trans( | |||||
Fragment& dst, const Fragment& src, | |||||
CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process) { | |||||
int intermediate[2][8]; | |||||
int* dst_frag = reinterpret_cast<int*>(&dst); | |||||
#pragma unroll | |||||
for (int j = 0; j < col_in_type; j += inc_col_in_type) { | |||||
transform_b4x8_to_int8<signedness>( | |||||
intermediate[0], | |||||
reinterpret_cast<const int&>(src[0 * col_in_type + j])); | |||||
transform_b4x8_to_int8<signedness>( | |||||
intermediate[1], | |||||
reinterpret_cast<const int&>(src[1 * col_in_type + j])); | |||||
dst_frag[(j / inc_col_in_type) * 2 + 0] = | |||||
transform_int8_to_b4x8<signedness>( | |||||
post_process(intermediate[0][0]), | |||||
post_process(intermediate[1][0]), | |||||
post_process(intermediate[0][1]), | |||||
post_process(intermediate[1][1]), | |||||
post_process(intermediate[0][2]), | |||||
post_process(intermediate[1][2]), | |||||
post_process(intermediate[0][3]), | |||||
post_process(intermediate[1][3])); | |||||
dst_frag[(j / inc_col_in_type) * 2 + 1] = | |||||
transform_int8_to_b4x8<signedness>( | |||||
post_process(intermediate[0][4]), | |||||
post_process(intermediate[1][4]), | |||||
post_process(intermediate[0][5]), | |||||
post_process(intermediate[1][5]), | |||||
post_process(intermediate[0][6]), | |||||
post_process(intermediate[1][6]), | |||||
post_process(intermediate[0][7]), | |||||
post_process(intermediate[1][7])); | |||||
} | |||||
} | |||||
}; | |||||
// ========================================================= | |||||
// partial specialization for translayout operator for qint4 | |||||
// NCHW <-> NHWC | |||||
template <typename SrcType, typename DnnSrcType_, typename DnnDstType_, | |||||
bool same_scale> | |||||
struct Translayout<2, 8, SrcType, DnnSrcType_, DnnDstType_, same_scale, | |||||
typename enable_qtype_b4<DnnSrcType_, DnnDstType_>::type> { | |||||
using DnnSrcType = DnnSrcType_; | |||||
using DnnDstType = DnnDstType_; | |||||
static constexpr int row = 8; | |||||
static constexpr int col = 2; | |||||
static constexpr int size_nbits = 4; | |||||
static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType)); | |||||
static constexpr int elements_in_type = row * col_in_type; | |||||
static constexpr bool signedness = qtype_signedness<DnnSrcType>::value; | |||||
using Fragment = array_wrapper<SrcType, elements_in_type>; | |||||
static inline __device__ void trans( | |||||
Fragment& dst, const Fragment& src, | |||||
CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process, | |||||
const char zero_point) { | |||||
int intermediate[8][2]; | |||||
transform_b4x2_to_int8<signedness>(intermediate[0], | |||||
reinterpret_cast<uint8_t&>(src[0])); | |||||
transform_b4x2_to_int8<signedness>(intermediate[1], | |||||
reinterpret_cast<uint8_t&>(src[1])); | |||||
transform_b4x2_to_int8<signedness>(intermediate[2], | |||||
reinterpret_cast<uint8_t&>(src[2])); | |||||
transform_b4x2_to_int8<signedness>(intermediate[3], | |||||
reinterpret_cast<uint8_t&>(src[3])); | |||||
transform_b4x2_to_int8<signedness>(intermediate[4], | |||||
reinterpret_cast<uint8_t&>(src[4])); | |||||
transform_b4x2_to_int8<signedness>(intermediate[5], | |||||
reinterpret_cast<uint8_t&>(src[5])); | |||||
transform_b4x2_to_int8<signedness>(intermediate[6], | |||||
reinterpret_cast<uint8_t&>(src[6])); | |||||
transform_b4x2_to_int8<signedness>(intermediate[7], | |||||
reinterpret_cast<uint8_t&>(src[7])); | |||||
int* dst_frag = reinterpret_cast<int*>(&dst); | |||||
auto pack = [&](int idx) -> int { | |||||
return transform_int8_to_b4x8<signedness>( | |||||
post_process(intermediate[0][idx]), | |||||
post_process(intermediate[1][idx]), | |||||
post_process(intermediate[2][idx]), | |||||
post_process(intermediate[3][idx]), | |||||
post_process(intermediate[4][idx]), | |||||
post_process(intermediate[5][idx]), | |||||
post_process(intermediate[6][idx]), | |||||
post_process(intermediate[7][idx])); | |||||
}; | |||||
dst_frag[0] = pack(0); | |||||
dst_frag[1] = pack(1); | |||||
} | |||||
}; | |||||
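// partial specialization for an 8x8 qint4 tile: each source row packs | |||||
// eight 4-bit values into one 32-bit word | |||||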
template <typename SrcType, typename DnnSrcType_, typename DnnDstType_, | |||||
bool same_scale> | |||||
struct Translayout<8, 8, SrcType, DnnSrcType_, DnnDstType_, same_scale, | |||||
typename enable_qtype_b4<DnnSrcType_, DnnDstType_>::type> { | |||||
using DnnSrcType = DnnSrcType_; | |||||
using DnnDstType = DnnDstType_; | |||||
static constexpr int row = 8; | |||||
static constexpr int col = 8; | |||||
static constexpr int size_nbits = 4; | |||||
static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType)); | |||||
static constexpr int elements_in_type = row * col_in_type; | |||||
static constexpr bool signedness = qtype_signedness<DnnSrcType>::value; | |||||
using Fragment = array_wrapper<SrcType, elements_in_type>; | |||||
static inline __device__ void trans( | |||||
Fragment& dst, const Fragment& src, | |||||
CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process, | |||||
const char zero_point) { | |||||
int intermediate[8][8]; | |||||
#pragma unroll | |||||
        for (int i = 0; i < 8; i++) { | |||||
            transform_b4x8_to_int8<signedness>( | |||||
                    intermediate[i], reinterpret_cast<const int&>(src[i])); | |||||
        } | |||||
int* dst_frag = reinterpret_cast<int*>(&dst); | |||||
auto pack = [&](int idx) { | |||||
return transform_int8_to_b4x8<signedness>( | |||||
post_process(intermediate[0][idx]), | |||||
post_process(intermediate[1][idx]), | |||||
post_process(intermediate[2][idx]), | |||||
post_process(intermediate[3][idx]), | |||||
post_process(intermediate[4][idx]), | |||||
post_process(intermediate[5][idx]), | |||||
post_process(intermediate[6][idx]), | |||||
post_process(intermediate[7][idx])); | |||||
}; | |||||
#pragma unroll | |||||
        for (int i = 0; i < 8; i++) { | |||||
            dst_frag[i] = pack(i); | |||||
        } | |||||
} | |||||
}; | |||||
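// partial specialization for a tile of two rows of eight 4-bit values; | |||||
// the transpose interleaves the two rows nibble by nibble | |||||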
template <typename SrcType, typename DnnSrcType_, typename DnnDstType_, | |||||
bool same_scale> | |||||
struct Translayout<8, 2, SrcType, DnnSrcType_, DnnDstType_, same_scale, | |||||
typename enable_qtype_b4<DnnSrcType_, DnnDstType_>::type> { | |||||
using DnnSrcType = DnnSrcType_; | |||||
using DnnDstType = DnnDstType_; | |||||
static constexpr int row = 2; | |||||
static constexpr int col = 8; | |||||
static constexpr int size_nbits = 4; | |||||
static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType)); | |||||
static constexpr int elements_in_type = row * col_in_type; | |||||
static constexpr bool signedness = qtype_signedness<DnnSrcType>::value; | |||||
using Fragment = array_wrapper<SrcType, elements_in_type>; | |||||
static inline __device__ void trans( | |||||
Fragment& dst, const Fragment& src, | |||||
CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process, | |||||
const char zero_point) { | |||||
int intermediate[2][8]; | |||||
transform_b4x8_to_int8<signedness>( | |||||
intermediate[0], reinterpret_cast<const int&>(src[0])); | |||||
transform_b4x8_to_int8<signedness>( | |||||
intermediate[1], reinterpret_cast<const int&>(src[1])); | |||||
int* dst_frag = reinterpret_cast<int*>(&dst); | |||||
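        // interleave the two rows column by column: with row0 = {a0..a7} | |||||
        // and row1 = {b0..b7}, dst_frag[0] packs {a0,b0,a1,b1,a2,b2,a3,b3} | |||||
        // and dst_frag[1] packs {a4,b4,a5,b5,a6,b6,a7,b7} (each value run | |||||
        // through post_process first) | |||||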
dst_frag[0] = transform_int8_to_b4x8<signedness>( | |||||
post_process(intermediate[0][0]), | |||||
post_process(intermediate[1][0]), | |||||
post_process(intermediate[0][1]), | |||||
post_process(intermediate[1][1]), | |||||
post_process(intermediate[0][2]), | |||||
post_process(intermediate[1][2]), | |||||
post_process(intermediate[0][3]), | |||||
post_process(intermediate[1][3])); | |||||
dst_frag[1] = transform_int8_to_b4x8<signedness>( | |||||
post_process(intermediate[0][4]), | |||||
post_process(intermediate[1][4]), | |||||
post_process(intermediate[0][5]), | |||||
post_process(intermediate[1][5]), | |||||
post_process(intermediate[0][6]), | |||||
post_process(intermediate[1][6]), | |||||
post_process(intermediate[0][7]), | |||||
post_process(intermediate[1][7])); | |||||
} | |||||
}; | |||||
} // namespace internal | |||||
} // namespace relayout_format | |||||
} // namespace cuda | |||||
} // namespace megdnn | |||||
@@ -176,60 +176,22 @@ __global__ void kern_general_nchw4(SrcVisitor src, const float* __restrict mat, | |||||
} | } | ||||
} | } | ||||
template <bool signedness> | |||||
MEGDNN_DEVICE __forceinline__ int transform_int8_to_bit4x8(int s0, int s1, | |||||
int s2, int s3, | |||||
int s4, int s5, | |||||
int s6, int s7); | |||||
template <> | |||||
MEGDNN_DEVICE __forceinline__ int transform_int8_to_bit4x8<true>( | |||||
int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) { | |||||
return transform_int8_to_int4x8(s0, s1, s2, s3, s4, s5, s6, s7); | |||||
} | |||||
template <> | |||||
MEGDNN_DEVICE __forceinline__ int transform_int8_to_bit4x8<false>( | |||||
int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) { | |||||
return transform_int8_to_uint4x8(s0, s1, s2, s3, s4, s5, s6, s7); | |||||
} | |||||
template <bool signedness> | |||||
MEGDNN_DEVICE __forceinline__ void | |||||
transform_bit4x8_to_int8(int (&result)[8], const int& source); | |||||
template <> | |||||
MEGDNN_DEVICE __forceinline__ void | |||||
transform_bit4x8_to_int8<true>(int (&result)[8], const int& source){ | |||||
transform_int4x8_to_int8(result, source); | |||||
} | |||||
template <> | |||||
MEGDNN_DEVICE __forceinline__ void | |||||
transform_bit4x8_to_int8<false>(int (&result)[8], const int& source){ | |||||
transform_uint4x8_to_int8(result, source); | |||||
} | |||||
template <bool signedness, typename OutputConverter> | template <bool signedness, typename OutputConverter> | ||||
MEGDNN_DEVICE __forceinline__ int pack_output_func( | MEGDNN_DEVICE __forceinline__ int pack_output_func( | ||||
OutputConverter& output_converter, int (&s00)[8], int (&s01)[8], | OutputConverter& output_converter, int (&s00)[8], int (&s01)[8], | ||||
int (&s10)[8], int (&s11)[8], float w00, float w01, float w10, | int (&s10)[8], int (&s11)[8], float w00, float w01, float w10, | ||||
float w11) { | float w11) { | ||||
#define warp_perspective_transform(idx) \ | |||||
static_cast<int>(output_converter(s00[idx] * w00 + \ | |||||
s01[idx] * w01 + \ | |||||
s10[idx] * w10 + \ | |||||
s11[idx] * w11) \ | |||||
#define warp_perspective_transform(idx) \ | |||||
static_cast<int>(output_converter(s00[idx] * w00 + s01[idx] * w01 + \ | |||||
s10[idx] * w10 + s11[idx] * w11) \ | |||||
.as_storage()) | .as_storage()) | ||||
return transform_int8_to_bit4x8<signedness>( | |||||
return transform_int8_to_b4x8<signedness>( | |||||
warp_perspective_transform(0), warp_perspective_transform(1), | warp_perspective_transform(0), warp_perspective_transform(1), | ||||
warp_perspective_transform(2), warp_perspective_transform(3), | warp_perspective_transform(2), warp_perspective_transform(3), | ||||
warp_perspective_transform(4), warp_perspective_transform(5), | warp_perspective_transform(4), warp_perspective_transform(5), | ||||
warp_perspective_transform(6), warp_perspective_transform(7)); | warp_perspective_transform(6), warp_perspective_transform(7)); | ||||
#undef warp_perspective_transform | |||||
#undef warp_perspective_transform | |||||
} | } | ||||
template <typename ctype, typename Getter, typename SrcVisitor, | template <typename ctype, typename Getter, typename SrcVisitor, | ||||
@@ -278,31 +240,31 @@ __global__ void kern_general_nchw64(SrcVisitor src, const float* __restrict mat, | |||||
s[2] = __ldg(sptr_int4 + i_coor_10 + c1); | s[2] = __ldg(sptr_int4 + i_coor_10 + c1); | ||||
s[3] = __ldg(sptr_int4 + i_coor_11 + c1); | s[3] = __ldg(sptr_int4 + i_coor_11 + c1); | ||||
transform_bit4x8_to_int8<signedness>(s00, s[0].x); | |||||
transform_bit4x8_to_int8<signedness>(s01, s[1].x); | |||||
transform_bit4x8_to_int8<signedness>(s10, s[2].x); | |||||
transform_bit4x8_to_int8<signedness>(s11, s[3].x); | |||||
transform_b4x8_to_int8<signedness>(s00, s[0].x); | |||||
transform_b4x8_to_int8<signedness>(s01, s[1].x); | |||||
transform_b4x8_to_int8<signedness>(s10, s[2].x); | |||||
transform_b4x8_to_int8<signedness>(s11, s[3].x); | |||||
d.x = pack_output_func<signedness>(output_converter, s00, s01, s10, | d.x = pack_output_func<signedness>(output_converter, s00, s01, s10, | ||||
s11, w00, w01, w10, w11); | s11, w00, w01, w10, w11); | ||||
transform_bit4x8_to_int8<signedness>(s00, s[0].y); | |||||
transform_bit4x8_to_int8<signedness>(s01, s[1].y); | |||||
transform_bit4x8_to_int8<signedness>(s10, s[2].y); | |||||
transform_bit4x8_to_int8<signedness>(s11, s[3].y); | |||||
transform_b4x8_to_int8<signedness>(s00, s[0].y); | |||||
transform_b4x8_to_int8<signedness>(s01, s[1].y); | |||||
transform_b4x8_to_int8<signedness>(s10, s[2].y); | |||||
transform_b4x8_to_int8<signedness>(s11, s[3].y); | |||||
d.y = pack_output_func<signedness>(output_converter, s00, s01, s10, | d.y = pack_output_func<signedness>(output_converter, s00, s01, s10, | ||||
s11, w00, w01, w10, w11); | s11, w00, w01, w10, w11); | ||||
transform_bit4x8_to_int8<signedness>(s00, s[0].z); | |||||
transform_bit4x8_to_int8<signedness>(s01, s[1].z); | |||||
transform_bit4x8_to_int8<signedness>(s10, s[2].z); | |||||
transform_bit4x8_to_int8<signedness>(s11, s[3].z); | |||||
transform_b4x8_to_int8<signedness>(s00, s[0].z); | |||||
transform_b4x8_to_int8<signedness>(s01, s[1].z); | |||||
transform_b4x8_to_int8<signedness>(s10, s[2].z); | |||||
transform_b4x8_to_int8<signedness>(s11, s[3].z); | |||||
d.z = pack_output_func<signedness>(output_converter, s00, s01, s10, | d.z = pack_output_func<signedness>(output_converter, s00, s01, s10, | ||||
s11, w00, w01, w10, w11); | s11, w00, w01, w10, w11); | ||||
transform_bit4x8_to_int8<signedness>(s00, s[0].w); | |||||
transform_bit4x8_to_int8<signedness>(s01, s[1].w); | |||||
transform_bit4x8_to_int8<signedness>(s10, s[2].w); | |||||
transform_bit4x8_to_int8<signedness>(s11, s[3].w); | |||||
transform_b4x8_to_int8<signedness>(s00, s[0].w); | |||||
transform_b4x8_to_int8<signedness>(s01, s[1].w); | |||||
transform_b4x8_to_int8<signedness>(s10, s[2].w); | |||||
transform_b4x8_to_int8<signedness>(s11, s[3].w); | |||||
d.w = pack_output_func<signedness>(output_converter, s00, s01, s10, | d.w = pack_output_func<signedness>(output_converter, s00, s01, s10, | ||||
s11, w00, w01, w10, w11); | s11, w00, w01, w10, w11); | ||||
@@ -403,15 +365,7 @@ __global__ void kern_const_border_nchw4(SrcVisitor src, | |||||
} | } | ||||
} | } | ||||
} | } | ||||
template <bool signedness> | |||||
MEGDNN_DEVICE __forceinline__ static void transform_bit4x8_to_int8( | |||||
int (&result)[8], const int& source) { | |||||
#pragma unroll | |||||
for (int i = 0; i < 8; i++) { | |||||
result[i] = unpack_integer_4bits<signedness>( | |||||
reinterpret_cast<unsigned const&>(source), (i << 2)); | |||||
} | |||||
} | |||||
template <typename ctype, typename SrcVisitor, typename OutputConverter> | template <typename ctype, typename SrcVisitor, typename OutputConverter> | ||||
__global__ void kern_const_border_nchw64(SrcVisitor src, | __global__ void kern_const_border_nchw64(SrcVisitor src, | ||||
@@ -457,7 +411,7 @@ __global__ void kern_const_border_nchw64(SrcVisitor src, | |||||
bool flag00 = okh0 && okw0, flag01 = okh0 && okw1, | bool flag00 = okh0 && okw0, flag01 = okh0 && okw1, | ||||
flag10 = okh1 && okw0, flag11 = okh1 && okw1; | flag10 = okh1 && okw0, flag11 = okh1 && okw1; | ||||
int8_t bval_4 = bval.as_storage() & 0xF; | int8_t bval_4 = bval.as_storage() & 0xF; | ||||
int bval_8 = transform_int8_to_bit4x8<signedness>( | |||||
int bval_8 = transform_int8_to_b4x8<signedness>( | |||||
bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4); | bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4); | ||||
int4 bval_int4; | int4 bval_int4; | ||||
bval_int4.x = bval_8; | bval_int4.x = bval_8; | ||||
@@ -488,31 +442,31 @@ __global__ void kern_const_border_nchw64(SrcVisitor src, | |||||
s[3] = bval_int4; | s[3] = bval_int4; | ||||
} | } | ||||
transform_bit4x8_to_int8<signedness>(s00, s[0].x); | |||||
transform_bit4x8_to_int8<signedness>(s01, s[1].x); | |||||
transform_bit4x8_to_int8<signedness>(s10, s[2].x); | |||||
transform_bit4x8_to_int8<signedness>(s11, s[3].x); | |||||
transform_b4x8_to_int8<signedness>(s00, s[0].x); | |||||
transform_b4x8_to_int8<signedness>(s01, s[1].x); | |||||
transform_b4x8_to_int8<signedness>(s10, s[2].x); | |||||
transform_b4x8_to_int8<signedness>(s11, s[3].x); | |||||
d.x = pack_output_func<signedness>(output_converter, s00, s01, s10, | d.x = pack_output_func<signedness>(output_converter, s00, s01, s10, | ||||
s11, w00, w01, w10, w11); | s11, w00, w01, w10, w11); | ||||
transform_bit4x8_to_int8<signedness>(s00, s[0].y); | |||||
transform_bit4x8_to_int8<signedness>(s01, s[1].y); | |||||
transform_bit4x8_to_int8<signedness>(s10, s[2].y); | |||||
transform_bit4x8_to_int8<signedness>(s11, s[3].y); | |||||
transform_b4x8_to_int8<signedness>(s00, s[0].y); | |||||
transform_b4x8_to_int8<signedness>(s01, s[1].y); | |||||
transform_b4x8_to_int8<signedness>(s10, s[2].y); | |||||
transform_b4x8_to_int8<signedness>(s11, s[3].y); | |||||
d.y = pack_output_func<signedness>(output_converter, s00, s01, s10, | d.y = pack_output_func<signedness>(output_converter, s00, s01, s10, | ||||
s11, w00, w01, w10, w11); | s11, w00, w01, w10, w11); | ||||
transform_bit4x8_to_int8<signedness>(s00, s[0].z); | |||||
transform_bit4x8_to_int8<signedness>(s01, s[1].z); | |||||
transform_bit4x8_to_int8<signedness>(s10, s[2].z); | |||||
transform_bit4x8_to_int8<signedness>(s11, s[3].z); | |||||
transform_b4x8_to_int8<signedness>(s00, s[0].z); | |||||
transform_b4x8_to_int8<signedness>(s01, s[1].z); | |||||
transform_b4x8_to_int8<signedness>(s10, s[2].z); | |||||
transform_b4x8_to_int8<signedness>(s11, s[3].z); | |||||
d.z = pack_output_func<signedness>(output_converter, s00, s01, s10, | d.z = pack_output_func<signedness>(output_converter, s00, s01, s10, | ||||
s11, w00, w01, w10, w11); | s11, w00, w01, w10, w11); | ||||
transform_bit4x8_to_int8<signedness>(s00, s[0].w); | |||||
transform_bit4x8_to_int8<signedness>(s01, s[1].w); | |||||
transform_bit4x8_to_int8<signedness>(s10, s[2].w); | |||||
transform_bit4x8_to_int8<signedness>(s11, s[3].w); | |||||
transform_b4x8_to_int8<signedness>(s00, s[0].w); | |||||
transform_b4x8_to_int8<signedness>(s01, s[1].w); | |||||
transform_b4x8_to_int8<signedness>(s10, s[2].w); | |||||
transform_b4x8_to_int8<signedness>(s11, s[3].w); | |||||
d.w = pack_output_func<signedness>(output_converter, s00, s01, s10, | d.w = pack_output_func<signedness>(output_converter, s00, s01, s10, | ||||
s11, w00, w01, w10, w11); | s11, w00, w01, w10, w11); | ||||