GitOrigin-RevId: ab86e66533
release-1.5
@@ -110,35 +110,33 @@ MEGDNN_DEVICE __forceinline__ static int unpack_integer_4bits(T storage, | |||
return (result << (shift - bits)) >> shift; | |||
} | |||
MEGDNN_DEVICE __forceinline__ static void transform_int4x8_to_int8( | |||
int (&result)[8], const int& source) { | |||
#pragma unroll | |||
for (int i = 0; i < 8; i++) { | |||
result[i] = unpack_integer_4bits<true>( | |||
reinterpret_cast<unsigned const&>(source), (i << 2)); | |||
} | |||
} | |||
MEGDNN_DEVICE __forceinline__ static void transform_uint4x8_to_int8( | |||
template <bool signedness> | |||
MEGDNN_DEVICE __forceinline__ static void transform_b4x8_to_int8( | |||
int (&result)[8], const int& source) { | |||
#pragma unroll | |||
for (int i = 0; i < 8; i++) { | |||
result[i] = unpack_integer_4bits<false>( | |||
result[i] = unpack_integer_4bits<signedness>( | |||
reinterpret_cast<unsigned const&>(source), (i << 2)); | |||
} | |||
} | |||
MEGDNN_DEVICE __forceinline__ static void transform_int4x2_to_int8( | |||
template <bool signedness> | |||
MEGDNN_DEVICE __forceinline__ static void transform_b4x2_to_int8( | |||
int (&result)[2], const uint8_t& source) { | |||
result[0] = unpack_integer_4bits<true>(source, 0); | |||
result[1] = unpack_integer_4bits<true>(source, 4); | |||
result[0] = unpack_integer_4bits<signedness>(source, 0); | |||
result[1] = unpack_integer_4bits<signedness>(source, 4); | |||
} | |||
MEGDNN_DEVICE __forceinline__ static void transform_uint4x2_to_int8( | |||
int (&result)[2], const uint8_t& source) { | |||
result[0] = unpack_integer_4bits<false>(source, 0); | |||
result[1] = unpack_integer_4bits<false>(source, 4); | |||
template <bool signedness> | |||
MEGDNN_DEVICE __forceinline__ static int transform_int8_to_b4x8( | |||
int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) { | |||
if (signedness) { | |||
return transform_int8_to_int4x8(s0, s1, s2, s3, s4, s5, s6, s7); | |||
} else { | |||
return transform_int8_to_uint4x8(s0, s1, s2, s3, s4, s5, s6, s7); | |||
} | |||
} | |||
} // namespace integer_subbyte | |||
} // namespace cuda | |||
} // namespace megdnn | |||
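For intuition, here is a minimal host-side sketch of the sign-extension trick behind `unpack_integer_4bits` and the unified `transform_b4x8_to_int8` path; `unpack_nibble` is an illustrative name, not part of the MegDNN API.

```cuda
#include <cassert>
#include <cstdint>

// Extract the 4-bit lane starting at offset_in_bits. For signed data the
// nibble is shifted to the top of a 32-bit int, then an arithmetic right
// shift replicates its sign bit; unsigned data only needs a mask.
template <bool signedness>
int unpack_nibble(uint32_t storage, int offset_in_bits) {
    if (signedness) {
        int top = static_cast<int>(storage << (28 - offset_in_bits));
        return top >> 28;
    }
    return static_cast<int>((storage >> offset_in_bits) & 0xfu);
}

int main() {
    uint32_t packed = 0x8fu;  // nibble0 = 0xf, nibble1 = 0x8
    assert(unpack_nibble<true>(packed, 0) == -1);   // 0xf as int4 is -1
    assert(unpack_nibble<true>(packed, 4) == -8);   // 0x8 as int4 is -8
    assert(unpack_nibble<false>(packed, 0) == 15);  // 0xf as uint4 is 15
    return 0;
}
```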
@@ -0,0 +1,171 @@ | |||
/** | |||
* \file dnn/src/cuda/relayout_format/cuda_post_process.cuh | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "src/cuda/relayout_format/relayout_format.cuh" | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace relayout_format { | |||
namespace internal { | |||
template <typename SrcType, typename DstType, bool same_scale> | |||
struct CudaPostProcess; | |||
template <> | |||
struct CudaPostProcess<dtype::Uint8, dtype::QuantizedS8, true> { | |||
CudaPostProcess(float, uint8_t, float, uint8_t){}; | |||
inline __device__ int8_t operator()(uint8_t val) { return val - 128; } | |||
}; | |||
template <> | |||
struct CudaPostProcess<dtype::Uint8, dtype::QuantizedS8, false> { | |||
CudaDTypeParamImpl<dt_qint8> m_dst_type_cvt; | |||
CudaPostProcess(float, uint8_t, float dst_scale, uint8_t) { | |||
m_dst_type_cvt = CudaDTypeParamImpl<dt_qint8>(dst_scale); | |||
}; | |||
inline __device__ int8_t operator()(uint8_t val) { | |||
return m_dst_type_cvt.quantize((float)val - 128.f).as_int8(); | |||
} | |||
}; | |||
template <> | |||
struct CudaPostProcess<dtype::Quantized8Asymm, dtype::QuantizedS8, false> { | |||
CudaDTypeParamImpl<dt_qint8> m_dst_type_cvt; | |||
CudaDTypeParamImpl<dt_quint8> m_src_type_cvt; | |||
CudaPostProcess(float src_scale, uint8_t src_zero_point, float dst_scale, | |||
uint8_t) { | |||
m_dst_type_cvt = CudaDTypeParamImpl<dt_qint8>(dst_scale); | |||
m_src_type_cvt = | |||
CudaDTypeParamImpl<dt_quint8>(src_scale, src_zero_point); | |||
}; | |||
inline __device__ int8_t operator()(uint8_t val) { | |||
float med_var = m_src_type_cvt.dequantize(dt_quint8(val)); | |||
return m_dst_type_cvt.quantize(med_var).as_int8(); | |||
} | |||
}; | |||
template <> | |||
struct CudaPostProcess<dtype::Quantized8Asymm, dtype::QuantizedS8, true> { | |||
uint8_t m_src_zero_point = 0; | |||
CudaPostProcess(float, uint8_t src_zero_point, float, uint8_t) { | |||
m_src_zero_point = src_zero_point; | |||
}; | |||
inline __device__ int8_t operator()(uint8_t val) { | |||
return val - m_src_zero_point; | |||
} | |||
}; | |||
template <> | |||
struct CudaPostProcess<dtype::QuantizedS8, dtype::QuantizedS8, false> { | |||
CudaDTypeParamImpl<dt_qint8> m_dst_type_cvt; | |||
CudaDTypeParamImpl<dt_qint8> m_src_type_cvt; | |||
CudaPostProcess(float src_scale, uint8_t, float dst_scale, uint8_t) { | |||
m_dst_type_cvt = CudaDTypeParamImpl<dt_qint8>(dst_scale); | |||
m_src_type_cvt = CudaDTypeParamImpl<dt_qint8>(src_scale); | |||
}; | |||
inline __device__ int8_t operator()(int8_t val) { | |||
float med_var = m_src_type_cvt.dequantize(dt_qint8(val)); | |||
return m_dst_type_cvt.quantize(med_var).as_int8(); | |||
} | |||
}; | |||
template <> | |||
struct CudaPostProcess<dtype::QuantizedS8, dtype::QuantizedS8, true> { | |||
CudaPostProcess(){}; | |||
CudaPostProcess(float, uint8_t, float, uint8_t){}; | |||
inline __device__ int8_t operator()(int8_t val) { return val; } | |||
}; | |||
template <> | |||
struct CudaPostProcess<dtype::QuantizedS32, dtype::QuantizedS32, false> { | |||
CudaDTypeParamImpl<dt_qint32> m_dst_type_cvt; | |||
CudaDTypeParamImpl<dt_qint32> m_src_type_cvt; | |||
CudaPostProcess(float src_scale, int, float dst_scale, int) { | |||
m_dst_type_cvt = CudaDTypeParamImpl<dt_qint32>(dst_scale); | |||
m_src_type_cvt = CudaDTypeParamImpl<dt_qint32>(src_scale); | |||
}; | |||
inline __device__ int operator()(int val) { | |||
float med_var = m_src_type_cvt.dequantize(dt_qint32(val)); | |||
return m_dst_type_cvt.quantize(med_var).as_int32(); | |||
} | |||
}; | |||
template <> | |||
struct CudaPostProcess<dtype::QuantizedS32, dtype::QuantizedS32, true> { | |||
CudaPostProcess(float, int, float, int){}; | |||
inline __device__ int operator()(int val) { return val; } | |||
}; | |||
template <> | |||
struct CudaPostProcess<dtype::QuantizedS4, dtype::QuantizedS4, false> { | |||
using SrcType = dtype::QuantizedS4; | |||
using DstType = dtype::QuantizedS4; | |||
CudaDTypeParamImpl<dt_qint4> m_dst_type_cvt; | |||
CudaDTypeParamImpl<dt_qint4> m_src_type_cvt; | |||
CudaPostProcess(float src_scale, uint8_t, float dst_scale, uint8_t) { | |||
m_dst_type_cvt = CudaDTypeParamImpl<dt_qint4>(dst_scale); | |||
m_src_type_cvt = CudaDTypeParamImpl<dt_qint4>(src_scale); | |||
} | |||
inline __device__ int8_t operator()(int8_t val) { | |||
float intermediate = m_src_type_cvt.dequantize(dt_qint4(val)); | |||
return m_dst_type_cvt.quantize(intermediate).as_int8(); | |||
} | |||
}; | |||
template <> | |||
struct CudaPostProcess<dtype::QuantizedS4, dtype::QuantizedS4, true> { | |||
using SrcType = dtype::QuantizedS4; | |||
using DstType = dtype::QuantizedS4; | |||
CudaPostProcess(float, uint8_t, float, uint8_t){}; | |||
inline __device__ int8_t operator()(int8_t val) { return val; } | |||
}; | |||
template <> | |||
struct CudaPostProcess<dtype::Quantized4Asymm, dtype::Quantized4Asymm, false> { | |||
using SrcType = dtype::Quantized4Asymm; | |||
using DstType = dtype::Quantized4Asymm; | |||
CudaDTypeParamImpl<dt_quint4> m_dst_type_cvt; | |||
CudaDTypeParamImpl<dt_quint4> m_src_type_cvt; | |||
CudaPostProcess(float src_scale, uint8_t src_zero_point, float dst_scale, | |||
uint8_t dst_zero_point) { | |||
m_dst_type_cvt = | |||
CudaDTypeParamImpl<dt_quint4>(dst_scale, dst_zero_point); | |||
m_src_type_cvt = | |||
CudaDTypeParamImpl<dt_quint4>(src_scale, src_zero_point); | |||
}; | |||
inline __device__ uint8_t operator()(uint8_t val) { | |||
float intermediate = m_src_type_cvt.dequantize(dt_quint4(val)); | |||
return m_dst_type_cvt.quantize(intermediate).as_uint8(); | |||
} | |||
}; | |||
template <> | |||
struct CudaPostProcess<dtype::Quantized4Asymm, dtype::Quantized4Asymm, true> { | |||
using SrcType = dtype::Quantized4Asymm; | |||
using DstType = dtype::Quantized4Asymm; | |||
uint8_t m_src_zero_point = 0; | |||
uint8_t m_dst_zero_point = 0; | |||
CudaPostProcess(float, uint8_t src_zero_point, float, | |||
uint8_t dst_zero_point) { | |||
m_src_zero_point = src_zero_point; | |||
m_dst_zero_point = dst_zero_point; | |||
}; | |||
inline __device__ uint8_t operator()(uint8_t val) { | |||
int result = val - m_src_zero_point + m_dst_zero_point; | |||
result = result >= 0 ? result : 0; | |||
result = result < 16 ? result : 15; | |||
return static_cast<uint8_t>(result); | |||
} | |||
}; | |||
} // namespace internal | |||
} // namespace relayout_format | |||
} // namespace cuda | |||
} // namespace megdnn |
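The specializations above all reduce to dequantize-then-requantize, or to a plain zero-point shift when the scales match. A host-side sketch of the asymmetric-uint8 to signed-int8 case, with illustrative names:

```cuda
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Quantized8Asymm -> QuantizedS8 with distinct scales: dequantize with the
// source scale/zero point, requantize with the destination scale, saturate.
int8_t requantize_u8_to_s8(uint8_t val, float src_scale, uint8_t src_zp,
                           float dst_scale) {
    float real = (static_cast<int>(val) - src_zp) * src_scale;
    int q = static_cast<int>(std::nearbyint(real / dst_scale));
    return static_cast<int8_t>(std::min(127, std::max(-128, q)));
}

int main() {
    // with equal scales this collapses to the same_scale fast path:
    // val - zero_point
    std::printf("%d\n", requantize_u8_to_s8(130, 0.5f, 128, 0.5f));   // 2
    std::printf("%d\n", requantize_u8_to_s8(130, 0.5f, 128, 0.25f));  // 4
    return 0;
}
```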
@@ -1,252 +0,0 @@ | |||
/** | |||
* \file dnn/src/cuda/relayout_format/helper.cuh | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace relayout_format { | |||
#define devfunc __forceinline__ __device__ | |||
template <int size_nbits> | |||
devfunc int make_zero(int zero_point); | |||
template <> | |||
devfunc int make_zero<4>(int zero_point) { | |||
return transform_int8_to_uint4x8(zero_point, zero_point, zero_point, | |||
zero_point, zero_point, zero_point, | |||
zero_point, zero_point); | |||
} | |||
template <typename AccessType, int LoadBytes> | |||
struct global_load_with_zero_point; | |||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||
// | |||
// Specializations | |||
// | |||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||
// The redundant mov PTX instruction is used to enforce the compiler to | |||
// initialize data to zero before ld.global | |||
template <typename AccessType> | |||
struct global_load_with_zero_point<AccessType, 32> { | |||
devfunc global_load_with_zero_point(AccessType& D, void const* ptr, | |||
bool pred_guard, int zero_point) { | |||
uint4* data = reinterpret_cast<uint4*>(&D); | |||
asm volatile( | |||
"{\n" | |||
" .reg .pred p;\n" | |||
" setp.ne.b32 p, %9, 0;\n" | |||
" mov.b32 %0, %10;\n" | |||
" mov.b32 %1, %10;\n" | |||
" mov.b32 %2, %10;\n" | |||
" mov.b32 %3, %10;\n" | |||
" mov.b32 %4, %10;\n" | |||
" mov.b32 %5, %10;\n" | |||
" mov.b32 %6, %10;\n" | |||
" mov.b32 %7, %10;\n" | |||
" @p ld.global.v4.u32 {%0, %1, %2, %3}, [%8];\n" | |||
" @p ld.global.v4.u32 {%4, %5, %6, %7}, [%11];\n" | |||
"}\n" | |||
: "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), | |||
"=r"(data[0].w), "=r"(data[1].x), "=r"(data[1].y), | |||
"=r"(data[1].z), "=r"(data[1].w) | |||
: "l"(ptr), "r"((int)pred_guard), | |||
"r"(reinterpret_cast<unsigned&>(zero_point)), | |||
"l"(((uint8_t*)ptr) + 16)); | |||
} | |||
}; | |||
template <typename AccessType> | |||
struct global_load_with_zero_point<AccessType, 16> { | |||
devfunc global_load_with_zero_point(AccessType& D, void const* ptr, | |||
bool pred_guard, int zero_point) { | |||
uint4& data = reinterpret_cast<uint4&>(D); | |||
asm volatile( | |||
"{\n" | |||
" .reg .pred p;\n" | |||
" setp.ne.b32 p, %5, 0;\n" | |||
" mov.b32 %0, %6;\n" | |||
" mov.b32 %1, %6;\n" | |||
" mov.b32 %2, %6;\n" | |||
" mov.b32 %3, %6;\n" | |||
" @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n" | |||
"}\n" | |||
: "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) | |||
: "l"(ptr), "r"((int)pred_guard), | |||
"r"(reinterpret_cast<unsigned&>(zero_point))); | |||
} | |||
}; | |||
template <typename AccessType> | |||
struct global_load_with_zero_point<AccessType, 8> { | |||
devfunc global_load_with_zero_point(AccessType& D, void const* ptr, | |||
bool pred_guard, int zero_point) { | |||
uint2& data = reinterpret_cast<uint2&>(D); | |||
asm volatile( | |||
"{\n" | |||
" .reg .pred p;\n" | |||
" setp.ne.b32 p, %3, 0;\n" | |||
" mov.b32 %0, %4;\n" | |||
" mov.b32 %1, %4;\n" | |||
" @p ld.global.v2.u32 {%0, %1}, [%2];\n" | |||
"}\n" | |||
: "=r"(data.x), "=r"(data.y) | |||
: "l"(ptr), "r"((int)pred_guard), | |||
"r"(reinterpret_cast<unsigned&>(zero_point))); | |||
} | |||
}; | |||
template <typename AccessType> | |||
struct global_load_with_zero_point<AccessType, 4> { | |||
devfunc global_load_with_zero_point(AccessType& D, void const* ptr, | |||
bool pred_guard, int zero_point) { | |||
unsigned& data = reinterpret_cast<unsigned&>(D); | |||
asm volatile( | |||
"{\n" | |||
" .reg .pred p;\n" | |||
" setp.ne.b32 p, %2, 0;\n" | |||
" mov.b32 %0, %3;\n" | |||
" @p ld.global.u32 %0, [%1];\n" | |||
"}\n" | |||
: "=r"(data) | |||
: "l"(ptr), "r"((int)pred_guard), | |||
"r"(reinterpret_cast<unsigned&>(zero_point))); | |||
} | |||
}; | |||
template <typename AccessType> | |||
struct global_load_with_zero_point<AccessType, 1> { | |||
devfunc global_load_with_zero_point(AccessType& D, void const* ptr, | |||
bool pred_guard, int zero_point) { | |||
if (pred_guard) | |||
D = *(reinterpret_cast<AccessType const*>(ptr)); | |||
else { | |||
unsigned uv = reinterpret_cast<unsigned&>(zero_point); | |||
uint8_t& data = reinterpret_cast<uint8_t&>(D); | |||
data = uv & 0xff; | |||
} | |||
} | |||
}; | |||
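The inline PTX above zero-fills the destination registers with `mov` before the predicated `ld.global`, guaranteeing a defined value when the guard is false. A plain-CUDA sketch of the same contract (the compiler is free to lower this to a branch rather than a predicated load, which is why the asm version exists):

```cuda
#include <cstdint>

__device__ __forceinline__ void guarded_load16(uint4& dst, const void* ptr,
                                               bool pred_guard,
                                               unsigned zero_point) {
    // value observed by guarded-off accesses
    dst = make_uint4(zero_point, zero_point, zero_point, zero_point);
    if (pred_guard) {
        dst = *reinterpret_cast<const uint4*>(ptr);  // 16-byte vector load
    }
}
```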
///////////////////////////////////////////////////////////////////////////////////////////////// | |||
template < | |||
/// Fragment type to store loaded data | |||
typename AccessType, | |||
/// The bytes of loading | |||
int LoadBytes> | |||
struct global_store; | |||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||
// | |||
// Specializations | |||
// | |||
///////////////////////////////////////////////////////////////////////////////////////////////// | |||
template <typename AccessType> | |||
struct global_store<AccessType, 32> { | |||
devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) { | |||
uint4 const* data = reinterpret_cast<uint4 const*>(&D); | |||
asm volatile( | |||
"{\n" | |||
" .reg .pred p;\n" | |||
" setp.ne.b32 p, %5, 0;\n" | |||
" @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n" | |||
" @p st.global.v4.u32 [%6], {%7, %8, %9, %10};\n" | |||
"}\n" | |||
: | |||
: "l"(ptr), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), | |||
"r"(data[0].w), "r"((int)pred_guard), | |||
"l"(((uint8_t*)ptr) + 16), "r"(data[1].x), "r"(data[1].y), | |||
"r"(data[1].z), "r"(data[1].w)); | |||
} | |||
}; | |||
template <typename AccessType> | |||
struct global_store<AccessType, 16> { | |||
devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) { | |||
uint4 const& data = reinterpret_cast<uint4 const&>(D); | |||
asm volatile( | |||
"{\n" | |||
" .reg .pred p;\n" | |||
" setp.ne.b32 p, %5, 0;\n" | |||
" @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n" | |||
"}\n" | |||
: | |||
: "l"(ptr), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w), | |||
"r"((int)pred_guard)); | |||
} | |||
}; | |||
template <typename AccessType> | |||
struct global_store<AccessType, 8> { | |||
devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) { | |||
uint2 const& data = reinterpret_cast<uint2 const&>(D); | |||
asm volatile( | |||
"{\n" | |||
" .reg .pred p;\n" | |||
" setp.ne.b32 p, %3, 0;\n" | |||
" @p st.global.v2.u32 [%0], {%1, %2};\n" | |||
"}\n" | |||
: | |||
: "l"(ptr), "r"(data.x), "r"(data.y), "r"((int)pred_guard)); | |||
} | |||
}; | |||
template <typename AccessType> | |||
struct global_store<AccessType, 4> { | |||
devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) { | |||
uint32_t const& data = reinterpret_cast<uint32_t const&>(D); | |||
asm volatile( | |||
"{\n" | |||
" .reg .pred p;\n" | |||
" setp.ne.b32 p, %2, 0;\n" | |||
" @p st.global.u32 [%0], %1;\n" | |||
"}\n" | |||
: | |||
: "l"(ptr), "r"(data), "r"((int)pred_guard)); | |||
} | |||
}; | |||
template <typename AccessType> | |||
struct global_store<AccessType, 2> { | |||
devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) { | |||
uint16_t const& data = reinterpret_cast<uint16_t const&>(D); | |||
asm volatile( | |||
"{\n" | |||
" .reg .pred p;\n" | |||
" setp.ne.b32 p, %2, 0;\n" | |||
" @p st.global.u16 [%0], %1;\n" | |||
"}\n" | |||
: | |||
: "l"(ptr), "h"(data), "r"((int)pred_guard)); | |||
} | |||
}; | |||
template <typename AccessType> | |||
struct global_store<AccessType, 1> { | |||
devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) { | |||
if (pred_guard) | |||
*(reinterpret_cast<AccessType*>(ptr)) = D; | |||
} | |||
}; | |||
#undef devfunc | |||
} // namespace relayout_format | |||
} // namespace cuda | |||
} // namespace megdnn |
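A minimal sketch of how the (now removed) guarded helpers above pair up in a kernel; the structs do their work in the constructor, so constructing a temporary issues the access. Assumes this header is in scope:

```cuda
__global__ void guarded_copy(const uint4* __restrict__ src, uint4* dst,
                             int n, int zero_point) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    const bool guard = i < n;
    uint4 v;
    // guarded-off threads keep the replicated zero point instead of loading
    megdnn::cuda::relayout_format::global_load_with_zero_point<uint4, 16>(
            v, src + i, guard, zero_point);
    megdnn::cuda::relayout_format::global_store<uint4, 16>(v, dst + i, guard);
}
```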
@@ -39,6 +39,20 @@ void relayout_format_cuda_nchwx_nchw(const TensorND& src, const TensorND& dst, | |||
const uint8_t src_zero_point = 0, | |||
const uint8_t dst_zero_point = 0); | |||
void relayout_format_cuda_nchw_nhwc(const TensorND& src, const TensorND& dst, | |||
const cudaStream_t& stream, | |||
const float src_scale = 1.f, | |||
const float dst_scale = 1.f, | |||
const uint8_t src_zero_point = 0, | |||
const uint8_t dst_zero_point = 0); | |||
void relayout_format_cuda_nhwc_nchw(const TensorND& src, const TensorND& dst, | |||
const cudaStream_t& stream, | |||
const float src_scale = 1.f, | |||
const float dst_scale = 1.f, | |||
const uint8_t src_zero_point = 0, | |||
const uint8_t dst_zero_point = 0); | |||
void relayout_format_cuda_nchw_nchw4_weight(const TensorND& src, | |||
const TensorND& dst, | |||
const cudaStream_t& stream); | |||
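For intuition, the NCHW <-> NHWC entry points declared above implement the usual layout permutation (plus the quantization post-processing). A host-side reference of the index mapping, not the kernels' actual tiled traversal:

```cuda
#include <cstdio>
#include <vector>

int main() {
    const int N = 1, C = 3, H = 2, W = 2;
    std::vector<int> nchw(N * C * H * W), nhwc(nchw.size());
    for (int i = 0; i < (int)nchw.size(); ++i) nchw[i] = i;
    for (int n = 0; n < N; ++n)
        for (int c = 0; c < C; ++c)
            for (int h = 0; h < H; ++h)
                for (int w = 0; w < W; ++w)
                    nhwc[((n * H + h) * W + w) * C + c] =
                            nchw[((n * C + c) * H + h) * W + w];
    std::printf("%d\n", nhwc[1]);  // element (n=0, h=0, w=0, c=1) == 4
    return 0;
}
```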
@@ -0,0 +1,346 @@ | |||
/** | |||
* \file dnn/src/cuda/relayout_format/relayout_format_kern.cuh | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "src/cuda/int_fastdiv.cuh" | |||
#include "src/cuda/memory_utils.cuh" | |||
#include "src/cuda/relayout_format/translayout.cuh" | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace relayout_format { | |||
namespace internal { | |||
using namespace memory; | |||
template <typename Type_, int pack_size_, int chan_blk_, int width_, | |||
int size_nbits_> | |||
class TensorIteratorOverChannel { | |||
public: | |||
using Type = Type_; | |||
static constexpr int pack_size = pack_size_; | |||
static constexpr int chan_blk = chan_blk_; | |||
static constexpr int width = width_; | |||
static constexpr int size_nbits = size_nbits_; | |||
static constexpr int elements_in_type = | |||
chan_blk * width * size_nbits / (8 * sizeof(Type)); | |||
static constexpr int lane_size_in_type = | |||
(width * pack_size * size_nbits) / (8 * sizeof(Type)); | |||
static constexpr int pack_size_in_type = | |||
(pack_size * size_nbits) >= (8 * sizeof(Type)) | |||
? (pack_size * size_nbits / (8 * sizeof(Type))) | |||
: (width * pack_size * size_nbits / (8 * sizeof(Type))); | |||
static constexpr int pack_size_in_byte = pack_size_in_type * sizeof(Type); | |||
using AccessType = array_wrapper<Type, pack_size_in_type>; | |||
using Fragment = array_wrapper<Type, elements_in_type>; | |||
MEGDNN_HOST TensorIteratorOverChannel() | |||
: pointer{nullptr}, chan_stride_in_elements{0}, channel{0} {} | |||
MEGDNN_HOST TensorIteratorOverChannel(Type* pointer_, | |||
int chan_stride_in_elements_, | |||
int channel_, int, int) | |||
: pointer{pointer_}, | |||
chan_stride_in_elements{chan_stride_in_elements_}, | |||
channel{channel_} {} | |||
MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) { | |||
pointer += (c_idx / pack_size) * chan_stride_in_elements + | |||
hw_idx * pack_size * size_nbits / (8 * sizeof(Type)); | |||
channel -= c_idx; | |||
} | |||
MEGDNN_DEVICE __forceinline__ void add_pointer_offset( | |||
size_t offset_in_type) { | |||
pointer += offset_in_type; | |||
} | |||
MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) { | |||
AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag); | |||
Type* pointer_ = pointer; | |||
#pragma unroll | |||
for (int i = 0; i < chan_blk; i += pack_size) { | |||
#pragma unroll | |||
for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) { | |||
int frag_idx = i / pack_size * | |||
(lane_size_in_type / pack_size_in_type) + | |||
j; | |||
bool guard = i < channel; | |||
global_load<AccessType, pack_size_in_byte>( | |||
frag_ptr[frag_idx], | |||
reinterpret_cast<void*>(pointer_ + | |||
j * pack_size_in_type), | |||
guard, zero_point); | |||
} | |||
pointer_ += chan_stride_in_elements; | |||
} | |||
} | |||
MEGDNN_DEVICE __forceinline__ void store(const Fragment& frag) { | |||
const AccessType* frag_ptr = reinterpret_cast<const AccessType*>(&frag); | |||
Type* pointer_ = pointer; | |||
#pragma unroll | |||
for (int i = 0; i < chan_blk; i += pack_size) { | |||
#pragma unroll | |||
for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) { | |||
int frag_idx = i / pack_size * | |||
(lane_size_in_type / pack_size_in_type) + | |||
j; | |||
bool guard = i < channel; | |||
global_store<AccessType, pack_size_in_byte>( | |||
frag_ptr[frag_idx], | |||
reinterpret_cast<void*>(pointer_ + | |||
j * pack_size_in_type), | |||
guard); | |||
} | |||
pointer_ += chan_stride_in_elements; | |||
} | |||
} | |||
MEGDNN_DEVICE __forceinline__ void advance() { | |||
pointer += (chan_blk / pack_size) * chan_stride_in_elements; | |||
channel -= chan_blk; | |||
} | |||
private: | |||
Type* pointer; | |||
int chan_stride_in_elements; | |||
int channel; | |||
}; | |||
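The constexpr arithmetic above is easiest to check with concrete numbers. A sketch for one plausible instantiation, Type = uint8_t with pack_size = chan_blk = 64, width = 8, size_nbits = 4 (int4 NCHW64 data; the parameters are picked for illustration):

```cuda
#include <cstdint>
#include <cstdio>

int main() {
    constexpr int pack_size = 64, chan_blk = 64, width = 8, size_nbits = 4;
    constexpr int type_bits = 8 * sizeof(uint8_t);
    // one chan_blk x width tile, measured in Type elements
    constexpr int elements_in_type =
            chan_blk * width * size_nbits / type_bits;  // 256
    constexpr int lane_size_in_type =
            width * pack_size * size_nbits / type_bits;  // 256
    constexpr int pack_size_in_type =
            (pack_size * size_nbits) >= type_bits
                    ? pack_size * size_nbits / type_bits  // taken: 32
                    : width * pack_size * size_nbits / type_bits;
    std::printf("%d %d %d\n", elements_in_type, lane_size_in_type,
                pack_size_in_type);  // 256 256 32
    return 0;
}
```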
template <typename Type_, int pack_size_, int chan_blk_, int width_, | |||
int size_nbits_> | |||
class MaskedTensorIteratorOverChannel { | |||
public: | |||
using Type = Type_; | |||
static constexpr int pack_size = pack_size_; | |||
static constexpr int chan_blk = chan_blk_; | |||
static constexpr int width = width_; | |||
static constexpr int size_nbits = size_nbits_; | |||
static constexpr int elements_in_type = | |||
chan_blk * width * size_nbits / (8 * sizeof(Type)); | |||
static constexpr int lane_size_in_type = | |||
(width * pack_size * size_nbits) / (8 * sizeof(Type)); | |||
static constexpr int pack_size_in_type = | |||
(pack_size * size_nbits) >= (8 * sizeof(Type)) | |||
? (pack_size * size_nbits / (8 * sizeof(Type))) | |||
: (width * pack_size * size_nbits / (8 * sizeof(Type))); | |||
static constexpr int pack_size_in_byte = pack_size_in_type * sizeof(Type); | |||
static constexpr int accesses = elements_in_type / pack_size_in_type; | |||
static constexpr int mask_size = (accesses + 32 - 1) / 32; | |||
using AccessType = array_wrapper<Type, pack_size_in_type>; | |||
using Fragment = array_wrapper<Type, elements_in_type>; | |||
MEGDNN_HOST MaskedTensorIteratorOverChannel() | |||
: pointer{nullptr}, chan_stride_in_elements{0}, channel{0} {} | |||
MEGDNN_HOST MaskedTensorIteratorOverChannel(Type* pointer_, | |||
int chan_stride_in_elements_, | |||
int channel_, int bound_, | |||
int div_) | |||
: pointer{pointer_}, | |||
chan_stride_in_elements{chan_stride_in_elements_}, | |||
channel{channel_}, | |||
bound{bound_}, | |||
div{uint32_t(div_)} {} | |||
MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) { | |||
pointer += (c_idx / pack_size) * chan_stride_in_elements; | |||
channel -= c_idx; | |||
int w[lane_size_in_type / pack_size_in_type]; | |||
#pragma unroll | |||
for (int i = 0; i < mask_size; ++i) { | |||
mask[i] = 0; | |||
} | |||
#pragma unroll | |||
for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) { | |||
int offset = hw_idx + j; | |||
int h = (int)((uint32_t)(offset) / div); | |||
w[j] = (int)((uint32_t)(offset) % div); | |||
stride[j] = (h * bound + w[j]) * pack_size * size_nbits / | |||
(8 * sizeof(Type)); | |||
} | |||
#pragma unroll | |||
for (int i = 0; i < chan_blk; i += pack_size) { | |||
#pragma unroll | |||
for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) { | |||
bool guard = (i < channel) && (w[j] < bound); | |||
int index = (i / pack_size) * | |||
(lane_size_in_type / pack_size_in_type) + | |||
j; | |||
int mask_index = (index >> 5); | |||
int mask_shift = (index & 0x1f); | |||
mask[mask_index] |= (guard << mask_shift); | |||
} | |||
} | |||
} | |||
MEGDNN_DEVICE __forceinline__ void add_pointer_offset( | |||
size_t offset_in_type) { | |||
pointer += offset_in_type; | |||
} | |||
MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) { | |||
AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag); | |||
Type* pointer_ = pointer; | |||
#pragma unroll | |||
for (int i = 0; i < chan_blk; i += pack_size) { | |||
#pragma unroll | |||
for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) { | |||
int frag_idx = i / pack_size * | |||
(lane_size_in_type / pack_size_in_type) + | |||
j; | |||
int mask_index = (frag_idx >> 5); | |||
int mask_shift = (frag_idx & 0x1f); | |||
bool guard = (mask[mask_index] & (1 << mask_shift)); | |||
global_load<AccessType, pack_size_in_byte>( | |||
frag_ptr[frag_idx], | |||
reinterpret_cast<void*>(pointer_ + stride[j]), guard, | |||
zero_point); | |||
} | |||
pointer_ += chan_stride_in_elements; | |||
} | |||
} | |||
MEGDNN_DEVICE __forceinline__ void store(const Fragment& frag) { | |||
const AccessType* frag_ptr = reinterpret_cast<const AccessType*>(&frag); | |||
Type* pointer_ = pointer; | |||
#pragma unroll | |||
for (int i = 0; i < chan_blk; i += pack_size) { | |||
#pragma unroll | |||
for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) { | |||
int frag_idx = i / pack_size * | |||
(lane_size_in_type / pack_size_in_type) + | |||
j; | |||
int mask_index = (frag_idx >> 5); | |||
int mask_shift = (frag_idx & 0x1f); | |||
bool guard = (mask[mask_index] & (1 << mask_shift)); | |||
global_store<AccessType, pack_size_in_byte>( | |||
frag_ptr[frag_idx], | |||
reinterpret_cast<void*>(pointer_ + stride[j]), guard); | |||
} | |||
pointer_ += chan_stride_in_elements; | |||
} | |||
} | |||
MEGDNN_DEVICE __forceinline__ void advance() { | |||
pointer += (chan_blk / pack_size) * chan_stride_in_elements; | |||
channel -= chan_blk; | |||
} | |||
private: | |||
Type* pointer; | |||
int chan_stride_in_elements; | |||
int channel; | |||
int bound; | |||
Uint32Fastdiv div; | |||
uint32_t mask[mask_size]; | |||
size_t stride[lane_size_in_type / pack_size_in_type]; | |||
}; | |||
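The masked iterator packs one predicate bit per access into 32-bit words, addressed by `index >> 5` and `index & 0x1f`. A host-side sketch of that bookkeeping:

```cuda
#include <cassert>
#include <cstdint>

int main() {
    constexpr int accesses = 40;                          // illustrative count
    constexpr int mask_size = (accesses + 32 - 1) / 32;   // -> 2 words
    uint32_t mask[mask_size] = {};
    auto set_guard = [&](int index, bool guard) {
        mask[index >> 5] |= (uint32_t(guard) << (index & 0x1f));
    };
    auto get_guard = [&](int index) -> bool {
        return (mask[index >> 5] >> (index & 0x1f)) & 1u;
    };
    set_guard(33, true);  // word 1, bit 1
    assert(get_guard(33) && !get_guard(32));
    return 0;
}
```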
template <bool padding_, typename Type_, int pack_size_, int chan_blk_, | |||
int width_, int size_nbits_> | |||
struct TensorIteratorPolicy; | |||
template <typename Type_, int pack_size_, int chan_blk_, int width_, | |||
int size_nbits_> | |||
struct TensorIteratorPolicy<true, Type_, pack_size_, chan_blk_, width_, | |||
size_nbits_> { | |||
using TensorIterator = | |||
MaskedTensorIteratorOverChannel<Type_, pack_size_, chan_blk_, | |||
width_, size_nbits_>; | |||
}; | |||
template <typename Type_, int pack_size_, int chan_blk_, int width_, | |||
int size_nbits_> | |||
struct TensorIteratorPolicy<false, Type_, pack_size_, chan_blk_, width_, | |||
size_nbits_> { | |||
using TensorIterator = | |||
TensorIteratorOverChannel<Type_, pack_size_, chan_blk_, width_, | |||
size_nbits_>; | |||
}; | |||
template <typename SrcIterator_, typename DstIterator_, typename Transpose_, | |||
typename CudaPostProcess_> | |||
struct RelayoutProblem { | |||
using SrcIterator = SrcIterator_; | |||
using DstIterator = DstIterator_; | |||
using Transpose = Transpose_; | |||
using CudaPostProcess = CudaPostProcess_; | |||
MEGDNN_STATIC_ASSERT(SrcIterator::chan_blk == DstIterator::chan_blk, | |||
"channel block mismatch"); | |||
MEGDNN_STATIC_ASSERT(SrcIterator::width == DstIterator::width, | |||
"width block mismatch"); | |||
MEGDNN_STATIC_ASSERT(SrcIterator::size_nbits == DstIterator::size_nbits, | |||
"size in bits of elements mismatch"); | |||
static constexpr int pack_chan = SrcIterator::chan_blk; | |||
static constexpr int pack_width = SrcIterator::width; | |||
using DnnSrcType = typename CudaPostProcess::SrcType; | |||
using DnnDstType = typename CudaPostProcess::DstType; | |||
struct Param { | |||
SrcIterator src_iterator; | |||
DstIterator dst_iterator; | |||
CudaPostProcess post_process; | |||
int n_stride_src; | |||
int n_stride_dst; | |||
int batch_size; | |||
int channels; | |||
int hw; | |||
int zero_point; | |||
MEGDNN_HOST MEGDNN_DEVICE Param(SrcIterator src_iterator_, | |||
DstIterator dst_iterator_, | |||
CudaPostProcess post_process_, | |||
int n_stride_src_, int n_stride_dst_, | |||
int batch_size_, int channels_, int hw_, | |||
int zero_point_) | |||
: src_iterator{src_iterator_}, | |||
dst_iterator{dst_iterator_}, | |||
post_process{post_process_}, | |||
n_stride_src{n_stride_src_}, | |||
n_stride_dst{n_stride_dst_}, | |||
batch_size{batch_size_}, | |||
channels{channels_}, | |||
hw{hw_}, | |||
zero_point{zero_point_} {} | |||
}; | |||
}; | |||
template <typename RelayoutProblem_> | |||
__global__ void relayout_kern(typename RelayoutProblem_::Param param) { | |||
using SrcIterator = typename RelayoutProblem_::SrcIterator; | |||
using DstIterator = typename RelayoutProblem_::DstIterator; | |||
static constexpr int pack_chan = RelayoutProblem_::pack_chan; | |||
static constexpr int pack_width = RelayoutProblem_::pack_width; | |||
const int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; | |||
const int thread_offset = thread_idx * pack_width; | |||
const int hw_idx = (thread_offset % param.hw); | |||
const int nc_blks = thread_offset / param.hw; | |||
const int c_blks = (param.channels + pack_chan - 1) / pack_chan; | |||
const int n_idx = nc_blks / c_blks; | |||
const int c_blk_idx = nc_blks % c_blks; | |||
const int c_idx = c_blk_idx * pack_chan; | |||
if (n_idx < param.batch_size) { | |||
const int src_offset = n_idx * param.n_stride_src; | |||
const int dst_offset = n_idx * param.n_stride_dst; | |||
param.src_iterator.add_pointer_offset(src_offset); | |||
param.dst_iterator.add_pointer_offset(dst_offset); | |||
param.src_iterator.initialize(c_idx, hw_idx); | |||
param.dst_iterator.initialize(c_idx, hw_idx); | |||
typename SrcIterator::Fragment src_frag; | |||
typename DstIterator::Fragment dst_frag; | |||
int zp = make_zero<SrcIterator::size_nbits>(param.zero_point); | |||
param.src_iterator.load(src_frag, zp); | |||
RelayoutProblem_::Transpose::trans( | |||
reinterpret_cast<typename SrcIterator::Fragment&>(dst_frag), | |||
src_frag, param.post_process); | |||
param.dst_iterator.store(dst_frag); | |||
} | |||
} | |||
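Each thread of relayout_kern owns pack_width spatial positions of one channel block, so a launcher would size the grid roughly as below (a sketch assuming hw is a multiple of pack_width; the numbers are illustrative):

```cuda
#include <cstdio>

int main() {
    const int batch = 2, channels = 96, hw = 32 * 32;
    const int pack_chan = 64, pack_width = 8, block_dim = 256;
    const int c_blks = (channels + pack_chan - 1) / pack_chan;  // 2
    const int n_threads = batch * c_blks * (hw / pack_width);   // 512
    const int grid_dim = (n_threads + block_dim - 1) / block_dim;
    std::printf("grid=%d block=%d\n", grid_dim, block_dim);  // grid=2 block=256
    return 0;
}
```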
} // namespace internal | |||
} // namespace relayout_format | |||
} // namespace cuda | |||
} // namespace megdnn |
@@ -0,0 +1,128 @@ | |||
/** | |||
* \file dnn/src/cuda/relayout_format/relayout_format_utils.cuh | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "src/cuda/integer_subbyte_utils.cuh" | |||
#include "src/cuda/relayout_format/relayout_format.cuh" | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace relayout_format { | |||
namespace internal { | |||
template <typename ctype, int pack_w, typename enable = void>
struct DTypeRWHelper; | |||
template <typename ctype> | |||
struct DTypeRWHelper< | |||
ctype, 1, | |||
typename std::enable_if<std::is_same<ctype, dt_qint8>::value || | |||
std::is_same<ctype, dt_quint8>::value || | |||
std::is_same<ctype, dt_uint8>::value>::type> { | |||
using InnerDtype = char; | |||
using DstDtype = char4; | |||
}; | |||
template <typename ctype> | |||
struct DTypeRWHelper< | |||
ctype, 4, | |||
typename std::enable_if<std::is_same<ctype, dt_qint8>::value || | |||
std::is_same<ctype, dt_quint8>::value || | |||
std::is_same<ctype, dt_uint8>::value>::type> { | |||
using InnerDtype = char4; | |||
using DstDtype = char4; | |||
}; | |||
template <> | |||
struct DTypeRWHelper<dt_qint32, 1> { | |||
using InnerDtype = int; | |||
using DstDtype = int4; | |||
}; | |||
template <> | |||
struct DTypeRWHelper<dt_qint32, 4> { | |||
using InnerDtype = int4; | |||
using DstDtype = int4; | |||
}; | |||
template <typename ctype> | |||
struct DTypeRWHelper< | |||
ctype, 2, | |||
typename std::enable_if<std::is_same<ctype, dt_qint4>::value || | |||
std::is_same<ctype, dt_quint4>::value>::type> { | |||
using InnerDtype = char; | |||
using DstDtype = array_wrapper<uint8_t, 32>; | |||
}; | |||
template <typename ctype> | |||
struct DTypeRWHelper< | |||
ctype, 8, | |||
typename std::enable_if<std::is_same<ctype, dt_qint4>::value || | |||
std::is_same<ctype, dt_quint4>::value>::type> { | |||
using InnerDtype = unsigned; | |||
using DstDtype = array_wrapper<uint8_t, 32>; | |||
}; | |||
template <typename DstType> | |||
inline __device__ DstType make_zero_pad(const uint8_t zero_point) { | |||
return zero_point; | |||
} | |||
template <> | |||
inline __device__ char4 make_zero_pad<char4>(const uint8_t zero_point) { | |||
char izp = reinterpret_cast<const char&>(zero_point); | |||
return {izp, izp, izp, izp}; | |||
} | |||
template <> | |||
inline __device__ int4 make_zero_pad<int4>(const uint8_t zero_point) { | |||
return {zero_point, zero_point, zero_point, zero_point}; | |||
} | |||
template <int size_nbits> | |||
inline __device__ int make_zero(int zero_point); | |||
template <> | |||
inline __device__ int make_zero<4>(int zero_point) { | |||
return integer_subbyte::transform_int8_to_uint4x8( | |||
zero_point, zero_point, zero_point, zero_point, zero_point, | |||
zero_point, zero_point, zero_point); | |||
} | |||
template <typename DstDtype> | |||
inline __device__ void write_helper(DstDtype* ptr, DstDtype val) { | |||
*ptr = val; | |||
} | |||
template <> | |||
inline __device__ void write_helper<char4>(char4* ptr, char4 val) { | |||
int32_t* rel_ptr = (int32_t*)ptr; | |||
*rel_ptr = *(int32_t*)(&val); | |||
} | |||
template <> | |||
inline __device__ void write_helper<array_wrapper<uint8_t, 32>>( | |||
array_wrapper<uint8_t, 32>* ptr, array_wrapper<uint8_t, 32> val) { | |||
uint4 const* data = reinterpret_cast<uint4 const*>(&val); | |||
void* ptr_ = reinterpret_cast<void*>(ptr); | |||
asm volatile( | |||
"{\n" | |||
" st.global.v4.u32 [%0], {%1, %2, %3, %4};\n" | |||
" st.global.v4.u32 [%5], {%6, %7, %8, %9};\n" | |||
"}\n" | |||
: | |||
: "l"(ptr_), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), | |||
"r"(data[0].w), "l"(((uint8_t*)ptr_) + 16), "r"(data[1].x), | |||
"r"(data[1].y), "r"(data[1].z), "r"(data[1].w)); | |||
} | |||
} // namespace internal | |||
} // namespace relayout_format | |||
} // namespace cuda | |||
} // namespace megdnn |
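make_zero<4> replicates the 4-bit zero point into every nibble of a 32-bit word via transform_int8_to_uint4x8. A host-side reference of the resulting bit pattern:

```cuda
#include <cassert>
#include <cstdint>

int main() {
    const int zero_point = 8;
    uint32_t packed = 0;
    for (int i = 0; i < 8; ++i)
        packed |= (uint32_t(zero_point) & 0xfu) << (4 * i);
    assert(packed == 0x88888888u);  // eight copies of the zero point
    return 0;
}
```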
@@ -0,0 +1,537 @@ | |||
/** | |||
* \file dnn/src/cuda/relayout_format/translayout.cuh | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "src/cuda/integer_subbyte_utils.cuh" | |||
#include "src/cuda/relayout_format/cuda_post_process.cuh" | |||
#include "src/cuda/relayout_format/relayout_format.cuh" | |||
#include "src/cuda/relayout_format/relayout_format_utils.cuh" | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace relayout_format { | |||
namespace internal { | |||
using namespace integer_subbyte; | |||
template <typename dt> | |||
struct qtype_signedness; | |||
template <> | |||
struct qtype_signedness<dtype::QuantizedS4> { | |||
static constexpr bool value = true; | |||
}; | |||
template <> | |||
struct qtype_signedness<dtype::Quantized4Asymm> { | |||
static constexpr bool value = false; | |||
}; | |||
template <typename dt_src, typename dt_dst> | |||
struct enable_qtype_b4 { | |||
static constexpr bool val_src = | |||
std::is_same<dt_src, dtype::QuantizedS4>::value || | |||
std::is_same<dt_src, dtype::Quantized4Asymm>::value; | |||
static constexpr bool val_dst = | |||
std::is_same<dt_dst, dtype::QuantizedS4>::value || | |||
std::is_same<dt_dst, dtype::Quantized4Asymm>::value; | |||
using type = typename std::enable_if<std::is_same<dt_src, dt_dst>::value && | |||
val_src && val_dst>::type; | |||
}; | |||
// The input fragment is stored in RowMajor order. The translayout operator | |||
// performs a transpose operation on the input fragment, and produces a | |||
// reordered fragment, i.e. a fragment stored in ColumnMajor order. | |||
template <int col, int row, typename SrcType, typename DnnSrcType, | |||
typename DnnDstType, bool same_scale, typename enable = void> | |||
struct Translayout; | |||
// partial specialization for translayout operator for qint8 and quint8 | |||
template <typename SrcType, typename DnnSrcType, typename DnnDstType, | |||
bool same_scale> | |||
struct Translayout<1, 4, SrcType, DnnSrcType, DnnDstType, same_scale> { | |||
using InnerDtype = | |||
typename DTypeRWHelper<typename DTypeTrait<DnnSrcType>::ctype, | |||
1>::InnerDtype; | |||
using DstDtype = | |||
typename DTypeRWHelper<typename DTypeTrait<DnnSrcType>::ctype, | |||
1>::DstDtype; | |||
static inline __device__ void trans( | |||
DstDtype (&dst_width)[1], InnerDtype (&read_channel)[4], | |||
CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process, | |||
const char zero_point) { | |||
dst_width[0].x = post_process(read_channel[0]); | |||
dst_width[0].y = post_process(read_channel[1]); | |||
dst_width[0].z = post_process(read_channel[2]); | |||
dst_width[0].w = post_process(read_channel[3]); | |||
} | |||
}; | |||
template <typename SrcType, typename DnnSrcType, typename DnnDstType, | |||
bool same_scale> | |||
struct Translayout<4, 4, SrcType, DnnSrcType, DnnDstType, same_scale> { | |||
using InnerDtype = | |||
typename DTypeRWHelper<typename DTypeTrait<DnnSrcType>::ctype, | |||
4>::InnerDtype; | |||
using DstDtype = | |||
typename DTypeRWHelper<typename DTypeTrait<DnnSrcType>::ctype, | |||
4>::DstDtype; | |||
static inline __device__ void trans( | |||
DstDtype (&dst_width)[4], InnerDtype (&read_channel)[4], | |||
CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process, | |||
const char zero_point) { | |||
dst_width[0].x = post_process(read_channel[0].x); | |||
dst_width[0].y = post_process(read_channel[1].x); | |||
dst_width[0].z = post_process(read_channel[2].x); | |||
dst_width[0].w = post_process(read_channel[3].x); | |||
dst_width[1].x = post_process(read_channel[0].y); | |||
dst_width[1].y = post_process(read_channel[1].y); | |||
dst_width[1].z = post_process(read_channel[2].y); | |||
dst_width[1].w = post_process(read_channel[3].y); | |||
dst_width[2].x = post_process(read_channel[0].z); | |||
dst_width[2].y = post_process(read_channel[1].z); | |||
dst_width[2].z = post_process(read_channel[2].z); | |||
dst_width[2].w = post_process(read_channel[3].z); | |||
dst_width[3].x = post_process(read_channel[0].w); | |||
dst_width[3].y = post_process(read_channel[1].w); | |||
dst_width[3].z = post_process(read_channel[2].w); | |||
dst_width[3].w = post_process(read_channel[3].w); | |||
} | |||
}; | |||
// ========================================================= | |||
// partial specialization for translayout operator for qint4 | |||
// NCHW <-> NCHW64 | |||
template <typename SrcType, typename DnnSrcType_, typename DnnDstType_, | |||
bool same_scale> | |||
struct Translayout<2, 64, SrcType, DnnSrcType_, DnnDstType_, same_scale, | |||
typename enable_qtype_b4<DnnSrcType_, DnnDstType_>::type> { | |||
using DnnSrcType = DnnSrcType_; | |||
using DnnDstType = DnnDstType_; | |||
using InnerDtype = | |||
typename DTypeRWHelper<typename DTypeTrait<DnnSrcType>::ctype, | |||
2>::InnerDtype; | |||
using DstDtype = | |||
typename DTypeRWHelper<typename DTypeTrait<DnnSrcType>::ctype, | |||
2>::DstDtype; | |||
static constexpr bool signedness = qtype_signedness<DnnSrcType>::value; | |||
static inline __device__ void trans( | |||
DstDtype (&dst_width)[2], InnerDtype (&read_channel)[64], | |||
CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process, | |||
const char zero_point) { | |||
int intermediate[8][2]; | |||
int* dst_frag = reinterpret_cast<int*>(dst_width); | |||
auto pack_channel = [&](int idx) -> int { | |||
return transform_int8_to_b4x8<signedness>( | |||
post_process(intermediate[0][idx]), | |||
post_process(intermediate[1][idx]), | |||
post_process(intermediate[2][idx]), | |||
post_process(intermediate[3][idx]), | |||
post_process(intermediate[4][idx]), | |||
post_process(intermediate[5][idx]), | |||
post_process(intermediate[6][idx]), | |||
post_process(intermediate[7][idx])); | |||
}; | |||
#pragma unroll | |||
for (int i = 0; i < 64; i += 8) { | |||
transform_b4x2_to_int8<signedness>( | |||
intermediate[0], | |||
reinterpret_cast<uint8_t&>(read_channel[i + 0])); | |||
transform_b4x2_to_int8<signedness>( | |||
intermediate[1], | |||
reinterpret_cast<uint8_t&>(read_channel[i + 1])); | |||
transform_b4x2_to_int8<signedness>( | |||
intermediate[2], | |||
reinterpret_cast<uint8_t&>(read_channel[i + 2])); | |||
transform_b4x2_to_int8<signedness>( | |||
intermediate[3], | |||
reinterpret_cast<uint8_t&>(read_channel[i + 3])); | |||
transform_b4x2_to_int8<signedness>( | |||
intermediate[4], | |||
reinterpret_cast<uint8_t&>(read_channel[i + 4])); | |||
transform_b4x2_to_int8<signedness>( | |||
intermediate[5], | |||
reinterpret_cast<uint8_t&>(read_channel[i + 5])); | |||
transform_b4x2_to_int8<signedness>( | |||
intermediate[6], | |||
reinterpret_cast<uint8_t&>(read_channel[i + 6])); | |||
transform_b4x2_to_int8<signedness>( | |||
intermediate[7], | |||
reinterpret_cast<uint8_t&>(read_channel[i + 7])); | |||
int frag_idx = i / 8; | |||
dst_frag[0 * 8 + frag_idx] = pack_channel(0); | |||
dst_frag[1 * 8 + frag_idx] = pack_channel(1); | |||
} | |||
} | |||
using Fragment = array_wrapper<SrcType, 64>; | |||
static inline __device__ void trans( | |||
Fragment& dst, Fragment& src, | |||
CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process) { | |||
trans(reinterpret_cast<DstDtype(&)[2]>(dst), | |||
reinterpret_cast<InnerDtype(&)[64]>(src), post_process, 0); | |||
} | |||
}; | |||
template <typename SrcType, typename DnnSrcType_, typename DnnDstType_, | |||
bool same_scale> | |||
struct Translayout<8, 64, SrcType, DnnSrcType_, DnnDstType_, same_scale, | |||
typename enable_qtype_b4<DnnSrcType_, DnnDstType_>::type> { | |||
using DnnSrcType = DnnSrcType_; | |||
using DnnDstType = DnnDstType_; | |||
using InnerDtype = | |||
typename DTypeRWHelper<typename DTypeTrait<DnnSrcType>::ctype, | |||
8>::InnerDtype; | |||
using DstDtype = | |||
typename DTypeRWHelper<typename DTypeTrait<DnnSrcType>::ctype, | |||
8>::DstDtype; | |||
static constexpr bool signedness = qtype_signedness<DnnSrcType>::value; | |||
static inline __device__ void trans( | |||
DstDtype (&dst_width)[8], InnerDtype (&read_channel)[64], | |||
CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process, | |||
const char zero_point) { | |||
int intermediate[8][8]; | |||
int* dst_frag = reinterpret_cast<int*>(dst_width); | |||
auto pack_channel = [&](int idx) -> int { | |||
return transform_int8_to_b4x8<signedness>( | |||
post_process(intermediate[0][idx]), | |||
post_process(intermediate[1][idx]), | |||
post_process(intermediate[2][idx]), | |||
post_process(intermediate[3][idx]), | |||
post_process(intermediate[4][idx]), | |||
post_process(intermediate[5][idx]), | |||
post_process(intermediate[6][idx]), | |||
post_process(intermediate[7][idx])); | |||
}; | |||
#pragma unroll | |||
for (int i = 0; i < 64; i += 8) { | |||
transform_b4x8_to_int8<signedness>(intermediate[0], | |||
read_channel[i + 0]); | |||
transform_b4x8_to_int8<signedness>(intermediate[1], | |||
read_channel[i + 1]); | |||
transform_b4x8_to_int8<signedness>(intermediate[2], | |||
read_channel[i + 2]); | |||
transform_b4x8_to_int8<signedness>(intermediate[3], | |||
read_channel[i + 3]); | |||
transform_b4x8_to_int8<signedness>(intermediate[4], | |||
read_channel[i + 4]); | |||
transform_b4x8_to_int8<signedness>(intermediate[5], | |||
read_channel[i + 5]); | |||
transform_b4x8_to_int8<signedness>(intermediate[6], | |||
read_channel[i + 6]); | |||
transform_b4x8_to_int8<signedness>(intermediate[7], | |||
read_channel[i + 7]); | |||
int frag_idx = i / 8; | |||
dst_frag[0 * 8 + frag_idx] = pack_channel(0); | |||
dst_frag[1 * 8 + frag_idx] = pack_channel(1); | |||
dst_frag[2 * 8 + frag_idx] = pack_channel(2); | |||
dst_frag[3 * 8 + frag_idx] = pack_channel(3); | |||
dst_frag[4 * 8 + frag_idx] = pack_channel(4); | |||
dst_frag[5 * 8 + frag_idx] = pack_channel(5); | |||
dst_frag[6 * 8 + frag_idx] = pack_channel(6); | |||
dst_frag[7 * 8 + frag_idx] = pack_channel(7); | |||
} | |||
} | |||
using Fragment = array_wrapper<unsigned, 64>; | |||
static inline __device__ void trans( | |||
Fragment& dst, Fragment& src, | |||
CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process) { | |||
trans(reinterpret_cast<DstDtype(&)[8]>(dst), | |||
reinterpret_cast<InnerDtype(&)[64]>(src), post_process, 0); | |||
} | |||
}; | |||
template <typename SrcType, typename DnnSrcType_, typename DnnDstType_, | |||
bool same_scale> | |||
struct Translayout<64, 8, SrcType, DnnSrcType_, DnnDstType_, same_scale, | |||
typename enable_qtype_b4<DnnSrcType_, DnnDstType_>::type> { | |||
using DnnSrcType = DnnSrcType_; | |||
using DnnDstType = DnnDstType_; | |||
static constexpr int row = 8; | |||
static constexpr int col = 64; | |||
static constexpr int size_nbits = 4; | |||
static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType)); | |||
static constexpr int elements_in_type = row * col_in_type; | |||
static constexpr int inc_col = 8; | |||
static constexpr int inc_col_in_type = | |||
inc_col * size_nbits / (8 * sizeof(SrcType)); | |||
static constexpr bool signedness = qtype_signedness<DnnSrcType>::value; | |||
using Fragment = array_wrapper<SrcType, elements_in_type>; | |||
static MEGDNN_DEVICE __forceinline__ void trans( | |||
Fragment& dst, const Fragment& src, | |||
CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process) { | |||
int intermediate[8][8]; | |||
int* dst_frag = reinterpret_cast<int*>(&dst); | |||
auto pack = [&](int idx) -> int { | |||
return transform_int8_to_b4x8<signedness>( | |||
post_process(intermediate[0][idx]), | |||
post_process(intermediate[1][idx]), | |||
post_process(intermediate[2][idx]), | |||
post_process(intermediate[3][idx]), | |||
post_process(intermediate[4][idx]), | |||
post_process(intermediate[5][idx]), | |||
post_process(intermediate[6][idx]), | |||
post_process(intermediate[7][idx])); | |||
}; | |||
#pragma unroll | |||
for (int j = 0; j < col_in_type; j += inc_col_in_type) { | |||
transform_b4x8_to_int8<signedness>( | |||
intermediate[0], | |||
reinterpret_cast<const int&>(src[0 * col_in_type + j])); | |||
transform_b4x8_to_int8<signedness>( | |||
intermediate[1], | |||
reinterpret_cast<const int&>(src[1 * col_in_type + j])); | |||
transform_b4x8_to_int8<signedness>( | |||
intermediate[2], | |||
reinterpret_cast<const int&>(src[2 * col_in_type + j])); | |||
transform_b4x8_to_int8<signedness>( | |||
intermediate[3], | |||
reinterpret_cast<const int&>(src[3 * col_in_type + j])); | |||
transform_b4x8_to_int8<signedness>( | |||
intermediate[4], | |||
reinterpret_cast<const int&>(src[4 * col_in_type + j])); | |||
transform_b4x8_to_int8<signedness>( | |||
intermediate[5], | |||
reinterpret_cast<const int&>(src[5 * col_in_type + j])); | |||
transform_b4x8_to_int8<signedness>( | |||
intermediate[6], | |||
reinterpret_cast<const int&>(src[6 * col_in_type + j])); | |||
transform_b4x8_to_int8<signedness>( | |||
intermediate[7], | |||
reinterpret_cast<const int&>(src[7 * col_in_type + j])); | |||
dst_frag[(j / inc_col_in_type) * 8 + 0] = pack(0); | |||
dst_frag[(j / inc_col_in_type) * 8 + 1] = pack(1); | |||
dst_frag[(j / inc_col_in_type) * 8 + 2] = pack(2); | |||
dst_frag[(j / inc_col_in_type) * 8 + 3] = pack(3); | |||
dst_frag[(j / inc_col_in_type) * 8 + 4] = pack(4); | |||
dst_frag[(j / inc_col_in_type) * 8 + 5] = pack(5); | |||
dst_frag[(j / inc_col_in_type) * 8 + 6] = pack(6); | |||
dst_frag[(j / inc_col_in_type) * 8 + 7] = pack(7); | |||
} | |||
} | |||
}; | |||
template <typename SrcType, typename DnnSrcType_, typename DnnDstType_, | |||
bool same_scale> | |||
struct Translayout<64, 2, SrcType, DnnSrcType_, DnnDstType_, same_scale, | |||
typename enable_qtype_b4<DnnSrcType_, DnnDstType_>::type> { | |||
using DnnSrcType = DnnSrcType_; | |||
using DnnDstType = DnnDstType_; | |||
static constexpr int row = 2; | |||
static constexpr int col = 64; | |||
static constexpr int size_nbits = 4; | |||
static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType)); | |||
static constexpr int elements_in_type = row * col_in_type; | |||
static constexpr int inc_col = 8; | |||
static constexpr int inc_col_in_type = | |||
inc_col * size_nbits / (8 * sizeof(SrcType)); | |||
static constexpr bool signedness = qtype_signedness<DnnSrcType>::value; | |||
using Fragment = array_wrapper<SrcType, elements_in_type>; | |||
static MEGDNN_DEVICE __forceinline__ void trans( | |||
Fragment& dst, const Fragment& src, | |||
CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process) { | |||
int intermediate[2][8]; | |||
int* dst_frag = reinterpret_cast<int*>(&dst); | |||
#pragma unroll | |||
for (int j = 0; j < col_in_type; j += inc_col_in_type) { | |||
transform_b4x8_to_int8<signedness>( | |||
intermediate[0], | |||
reinterpret_cast<const int&>(src[0 * col_in_type + j])); | |||
transform_b4x8_to_int8<signedness>( | |||
intermediate[1], | |||
reinterpret_cast<const int&>(src[1 * col_in_type + j])); | |||
dst_frag[(j / inc_col_in_type) * 2 + 0] = | |||
transform_int8_to_b4x8<signedness>( | |||
post_process(intermediate[0][0]), | |||
post_process(intermediate[1][0]), | |||
post_process(intermediate[0][1]), | |||
post_process(intermediate[1][1]), | |||
post_process(intermediate[0][2]), | |||
post_process(intermediate[1][2]), | |||
post_process(intermediate[0][3]), | |||
post_process(intermediate[1][3])); | |||
dst_frag[(j / inc_col_in_type) * 2 + 1] = | |||
transform_int8_to_b4x8<signedness>( | |||
post_process(intermediate[0][4]), | |||
post_process(intermediate[1][4]), | |||
post_process(intermediate[0][5]), | |||
post_process(intermediate[1][5]), | |||
post_process(intermediate[0][6]), | |||
post_process(intermediate[1][6]), | |||
post_process(intermediate[0][7]), | |||
post_process(intermediate[1][7])); | |||
} | |||
} | |||
}; | |||
// ========================================================= | |||
// partial specialization for translayout operator for qint4 | |||
// NCHW <-> NHWC | |||
template <typename SrcType, typename DnnSrcType_, typename DnnDstType_, | |||
bool same_scale> | |||
struct Translayout<2, 8, SrcType, DnnSrcType_, DnnDstType_, same_scale, | |||
typename enable_qtype_b4<DnnSrcType_, DnnDstType_>::type> { | |||
using DnnSrcType = DnnSrcType_; | |||
using DnnDstType = DnnDstType_; | |||
static constexpr int row = 8; | |||
static constexpr int col = 2; | |||
static constexpr int size_nbits = 4; | |||
static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType)); | |||
static constexpr int elements_in_type = row * col_in_type; | |||
static constexpr bool signedness = qtype_signedness<DnnSrcType>::value; | |||
using Fragment = array_wrapper<SrcType, elements_in_type>; | |||
static inline __device__ void trans( | |||
Fragment& dst, const Fragment& src, | |||
CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process, | |||
const char zero_point) { | |||
int intermediate[8][2]; | |||
transform_b4x2_to_int8<signedness>(intermediate[0], | |||
reinterpret_cast<uint8_t&>(src[0])); | |||
transform_b4x2_to_int8<signedness>(intermediate[1], | |||
reinterpret_cast<uint8_t&>(src[1])); | |||
transform_b4x2_to_int8<signedness>(intermediate[2], | |||
reinterpret_cast<uint8_t&>(src[2])); | |||
transform_b4x2_to_int8<signedness>(intermediate[3], | |||
reinterpret_cast<uint8_t&>(src[3])); | |||
transform_b4x2_to_int8<signedness>(intermediate[4], | |||
reinterpret_cast<uint8_t&>(src[4])); | |||
transform_b4x2_to_int8<signedness>(intermediate[5], | |||
reinterpret_cast<uint8_t&>(src[5])); | |||
transform_b4x2_to_int8<signedness>(intermediate[6], | |||
reinterpret_cast<uint8_t&>(src[6])); | |||
transform_b4x2_to_int8<signedness>(intermediate[7], | |||
reinterpret_cast<uint8_t&>(src[7])); | |||
int* dst_frag = reinterpret_cast<int*>(&dst); | |||
auto pack = [&](int idx) -> int { | |||
return transform_int8_to_b4x8<signedness>( | |||
post_process(intermediate[0][idx]), | |||
post_process(intermediate[1][idx]), | |||
post_process(intermediate[2][idx]), | |||
post_process(intermediate[3][idx]), | |||
post_process(intermediate[4][idx]), | |||
post_process(intermediate[5][idx]), | |||
post_process(intermediate[6][idx]), | |||
post_process(intermediate[7][idx])); | |||
}; | |||
dst_frag[0] = pack(0); | |||
dst_frag[1] = pack(1); | |||
} | |||
}; | |||
template <typename SrcType, typename DnnSrcType_, typename DnnDstType_, | |||
bool same_scale> | |||
struct Translayout<8, 8, SrcType, DnnSrcType_, DnnDstType_, same_scale, | |||
typename enable_qtype_b4<DnnSrcType_, DnnDstType_>::type> { | |||
using DnnSrcType = DnnSrcType_; | |||
using DnnDstType = DnnDstType_; | |||
static constexpr int row = 8; | |||
static constexpr int col = 8; | |||
static constexpr int size_nbits = 4; | |||
static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType)); | |||
static constexpr int elements_in_type = row * col_in_type; | |||
static constexpr bool signedness = qtype_signedness<DnnSrcType>::value; | |||
using Fragment = array_wrapper<SrcType, elements_in_type>; | |||
static inline __device__ void trans( | |||
Fragment& dst, const Fragment& src, | |||
CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process, | |||
const char zero_point) { | |||
int intermediate[8][8]; | |||
transform_b4x8_to_int8<signedness>( | |||
intermediate[0], reinterpret_cast<const int&>(src[0])); | |||
transform_b4x8_to_int8<signedness>( | |||
intermediate[1], reinterpret_cast<const int&>(src[1])); | |||
transform_b4x8_to_int8<signedness>( | |||
intermediate[2], reinterpret_cast<const int&>(src[2])); | |||
transform_b4x8_to_int8<signedness>( | |||
intermediate[3], reinterpret_cast<const int&>(src[3])); | |||
transform_b4x8_to_int8<signedness>( | |||
intermediate[4], reinterpret_cast<const int&>(src[4])); | |||
transform_b4x8_to_int8<signedness>( | |||
intermediate[5], reinterpret_cast<const int&>(src[5])); | |||
transform_b4x8_to_int8<signedness>( | |||
intermediate[6], reinterpret_cast<const int&>(src[6])); | |||
transform_b4x8_to_int8<signedness>( | |||
intermediate[7], reinterpret_cast<const int&>(src[7])); | |||
int* dst_frag = reinterpret_cast<int*>(&dst); | |||
auto pack = [&](int idx) { | |||
return transform_int8_to_b4x8<signedness>( | |||
post_process(intermediate[0][idx]), | |||
post_process(intermediate[1][idx]), | |||
post_process(intermediate[2][idx]), | |||
post_process(intermediate[3][idx]), | |||
post_process(intermediate[4][idx]), | |||
post_process(intermediate[5][idx]), | |||
post_process(intermediate[6][idx]), | |||
post_process(intermediate[7][idx])); | |||
}; | |||
dst_frag[0] = pack(0); | |||
dst_frag[1] = pack(1); | |||
dst_frag[2] = pack(2); | |||
dst_frag[3] = pack(3); | |||
dst_frag[4] = pack(4); | |||
dst_frag[5] = pack(5); | |||
dst_frag[6] = pack(6); | |||
dst_frag[7] = pack(7); | |||
} | |||
}; | |||
template <typename SrcType, typename DnnSrcType_, typename DnnDstType_,
          bool same_scale>
struct Translayout<8, 2, SrcType, DnnSrcType_, DnnDstType_, same_scale,
                   typename enable_qtype_b4<DnnSrcType_, DnnDstType_>::type> {
    using DnnSrcType = DnnSrcType_;
    using DnnDstType = DnnDstType_;
    static constexpr int row = 2;
    static constexpr int col = 8;
    static constexpr int size_nbits = 4;
    static constexpr int col_in_type = col * size_nbits / (8 * sizeof(SrcType));
    static constexpr int elements_in_type = row * col_in_type;
    static constexpr bool signedness = qtype_signedness<DnnSrcType>::value;
    using Fragment = array_wrapper<SrcType, elements_in_type>;
    static inline __device__ void trans(
            Fragment& dst, const Fragment& src,
            CudaPostProcess<DnnSrcType, DnnDstType, same_scale>& post_process,
            const char zero_point) {
        int intermediate[2][8];
        transform_b4x8_to_int8<signedness>(
                intermediate[0], reinterpret_cast<const int&>(src[0]));
        transform_b4x8_to_int8<signedness>(
                intermediate[1], reinterpret_cast<const int&>(src[1]));
        int* dst_frag = reinterpret_cast<int*>(&dst);
        dst_frag[0] = transform_int8_to_b4x8<signedness>(
                post_process(intermediate[0][0]),
                post_process(intermediate[1][0]),
                post_process(intermediate[0][1]),
                post_process(intermediate[1][1]),
                post_process(intermediate[0][2]),
                post_process(intermediate[1][2]),
                post_process(intermediate[0][3]),
                post_process(intermediate[1][3]));
        dst_frag[1] = transform_int8_to_b4x8<signedness>(
                post_process(intermediate[0][4]),
                post_process(intermediate[1][4]),
                post_process(intermediate[0][5]),
                post_process(intermediate[1][5]),
                post_process(intermediate[0][6]),
                post_process(intermediate[1][6]),
                post_process(intermediate[0][7]),
                post_process(intermediate[1][7]));
    }
};
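Here the two unpacked rows are interleaved pairwise: output nibble k of word w takes intermediate[k % 2][w * 4 + k / 2], i.e. the 2x8 tile is emitted as its 8x2 transpose. A tiny sketch printing that index map, assuming transform_int8_to_b4x8 packs its first argument into the lowest nibble:

#include <cstdio>

int main() {
    // dst word w, nibble k <- intermediate[k % 2][w * 4 + k / 2]
    for (int w = 0; w < 2; ++w)
        for (int k = 0; k < 8; ++k)
            printf("dst_frag[%d] nibble %d = intermediate[%d][%d]\n",
                   w, k, k % 2, w * 4 + k / 2);
}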
}  // namespace internal
}  // namespace relayout_format
}  // namespace cuda
}  // namespace megdnn
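The trailing Enable parameter on these specializations is classic SFINAE: enable_qtype_b4 (defined in relayout_format.cuh) exposes a ::type only when the dtype pair is a 4-bit quantized one, so the b4 specializations drop out of overload resolution otherwise. A reduced, compilable model of that dispatch, with stand-in dtype tags and a simplified one-sided guard (all names here are hypothetical):

#include <type_traits>

struct QInt4 {};   // stand-ins for the dnn dtype tags
struct QUint4 {};
struct QInt8 {};

// Hypothetical guard: ::type exists only for 4-bit quantized sources.
template <typename S>
using enable_b4 = std::enable_if<std::is_same<S, QInt4>::value ||
                                 std::is_same<S, QUint4>::value>;

template <int row, int col, typename SrcDType, typename Enable = void>
struct Translayout;  // primary template left undefined

template <int row, int col, typename SrcDType>
struct Translayout<row, col, SrcDType, typename enable_b4<SrcDType>::type> {
    static constexpr bool b4_path = true;  // the b4 specialization is picked
};

static_assert(Translayout<8, 8, QInt4>::b4_path, "4-bit path selected");
int main() {}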
@@ -176,60 +176,22 @@ __global__ void kern_general_nchw4(SrcVisitor src, const float* __restrict mat,
    }
}
template <bool signedness>
MEGDNN_DEVICE __forceinline__ int transform_int8_to_bit4x8(int s0, int s1,
                                                           int s2, int s3,
                                                           int s4, int s5,
                                                           int s6, int s7);
template <>
MEGDNN_DEVICE __forceinline__ int transform_int8_to_bit4x8<true>(
        int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) {
    return transform_int8_to_int4x8(s0, s1, s2, s3, s4, s5, s6, s7);
}
template <>
MEGDNN_DEVICE __forceinline__ int transform_int8_to_bit4x8<false>(
        int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) {
    return transform_int8_to_uint4x8(s0, s1, s2, s3, s4, s5, s6, s7);
}
template <bool signedness>
MEGDNN_DEVICE __forceinline__ void
transform_bit4x8_to_int8(int (&result)[8], const int& source);
template <>
MEGDNN_DEVICE __forceinline__ void
transform_bit4x8_to_int8<true>(int (&result)[8], const int& source) {
    transform_int4x8_to_int8(result, source);
}
template <>
MEGDNN_DEVICE __forceinline__ void
transform_bit4x8_to_int8<false>(int (&result)[8], const int& source) {
    transform_uint4x8_to_int8(result, source);
}
template <bool signedness, typename OutputConverter>
MEGDNN_DEVICE __forceinline__ int pack_output_func(
        OutputConverter& output_converter, int (&s00)[8], int (&s01)[8],
        int (&s10)[8], int (&s11)[8], float w00, float w01, float w10,
        float w11) {
#define warp_perspective_transform(idx)                            \
    static_cast<int>(output_converter(s00[idx] * w00 +             \
                                      s01[idx] * w01 +             \
                                      s10[idx] * w10 +             \
                                      s11[idx] * w11)              \
#define warp_perspective_transform(idx)                                 \
    static_cast<int>(output_converter(s00[idx] * w00 + s01[idx] * w01 + \
                                      s10[idx] * w10 + s11[idx] * w11)  \
            .as_storage())
    return transform_int8_to_bit4x8<signedness>(
    return transform_int8_to_b4x8<signedness>(
            warp_perspective_transform(0), warp_perspective_transform(1),
            warp_perspective_transform(2), warp_perspective_transform(3),
            warp_perspective_transform(4), warp_perspective_transform(5),
            warp_perspective_transform(6), warp_perspective_transform(7));
#undef warp_perspective_transform
#undef warp_perspective_transform
}
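pack_output_func blends the four unpacked neighbours lane by lane with the bilinear weights w00..w11 that the surrounding kernel derives from the fractional sample position, then requantizes and repacks. A scalar host reference for one lane, where round-and-clamp stands in for the dtype-dependent OutputConverter (a sketch, not the kernel's converter):

#include <algorithm>
#include <cmath>
#include <cstdio>

int blend4(int s00, int s01, int s10, int s11, float fx, float fy, bool sign) {
    // Bilinear weights from the fractional coordinates (fx, fy).
    float w00 = (1 - fx) * (1 - fy), w01 = fx * (1 - fy);
    float w10 = (1 - fx) * fy, w11 = fx * fy;
    float v = s00 * w00 + s01 * w01 + s10 * w10 + s11 * w11;
    // Stand-in for OutputConverter: round, then saturate to the 4-bit range.
    int lo = sign ? -8 : 0, hi = sign ? 7 : 15;
    return std::min(hi, std::max(lo, (int)std::lround(v)));
}

int main() {
    // Sample a quarter of the way between four signed-int4 neighbours.
    printf("%d\n", blend4(-8, 7, 0, 3, 0.25f, 0.25f, true));  // prints -3
}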
template <typename ctype, typename Getter, typename SrcVisitor,
@@ -278,31 +240,31 @@ __global__ void kern_general_nchw64(SrcVisitor src, const float* __restrict mat,
        s[2] = __ldg(sptr_int4 + i_coor_10 + c1);
        s[3] = __ldg(sptr_int4 + i_coor_11 + c1);
        transform_bit4x8_to_int8<signedness>(s00, s[0].x);
        transform_bit4x8_to_int8<signedness>(s01, s[1].x);
        transform_bit4x8_to_int8<signedness>(s10, s[2].x);
        transform_bit4x8_to_int8<signedness>(s11, s[3].x);
        transform_b4x8_to_int8<signedness>(s00, s[0].x);
        transform_b4x8_to_int8<signedness>(s01, s[1].x);
        transform_b4x8_to_int8<signedness>(s10, s[2].x);
        transform_b4x8_to_int8<signedness>(s11, s[3].x);
        d.x = pack_output_func<signedness>(output_converter, s00, s01, s10,
                                           s11, w00, w01, w10, w11);
        transform_bit4x8_to_int8<signedness>(s00, s[0].y);
        transform_bit4x8_to_int8<signedness>(s01, s[1].y);
        transform_bit4x8_to_int8<signedness>(s10, s[2].y);
        transform_bit4x8_to_int8<signedness>(s11, s[3].y);
        transform_b4x8_to_int8<signedness>(s00, s[0].y);
        transform_b4x8_to_int8<signedness>(s01, s[1].y);
        transform_b4x8_to_int8<signedness>(s10, s[2].y);
        transform_b4x8_to_int8<signedness>(s11, s[3].y);
        d.y = pack_output_func<signedness>(output_converter, s00, s01, s10,
                                           s11, w00, w01, w10, w11);
        transform_bit4x8_to_int8<signedness>(s00, s[0].z);
        transform_bit4x8_to_int8<signedness>(s01, s[1].z);
        transform_bit4x8_to_int8<signedness>(s10, s[2].z);
        transform_bit4x8_to_int8<signedness>(s11, s[3].z);
        transform_b4x8_to_int8<signedness>(s00, s[0].z);
        transform_b4x8_to_int8<signedness>(s01, s[1].z);
        transform_b4x8_to_int8<signedness>(s10, s[2].z);
        transform_b4x8_to_int8<signedness>(s11, s[3].z);
        d.z = pack_output_func<signedness>(output_converter, s00, s01, s10,
                                           s11, w00, w01, w10, w11);
        transform_bit4x8_to_int8<signedness>(s00, s[0].w);
        transform_bit4x8_to_int8<signedness>(s01, s[1].w);
        transform_bit4x8_to_int8<signedness>(s10, s[2].w);
        transform_bit4x8_to_int8<signedness>(s11, s[3].w);
        transform_b4x8_to_int8<signedness>(s00, s[0].w);
        transform_b4x8_to_int8<signedness>(s01, s[1].w);
        transform_b4x8_to_int8<signedness>(s10, s[2].w);
        transform_b4x8_to_int8<signedness>(s11, s[3].w);
        d.w = pack_output_func<signedness>(output_converter, s00, s01, s10,
                                           s11, w00, w01, w10, w11);
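The four-step .x/.y/.z/.w pattern falls out of the NCHW64 layout: one (h, w) position holds 64 4-bit channels, which is 32 bytes, i.e. two 16-byte int4 vector loads (the c0/c1 halves seen above), and each 32-bit component of an int4 carries 8 channels. The arithmetic, as a sketch:

#include <cstdio>

int main() {
    const int channels = 64, bits = 4;
    const int bytes_per_pixel = channels * bits / 8;  // 32 bytes per (h, w)
    const int int4_loads = bytes_per_pixel / 16;      // 2 loads (c0 and c1)
    const int channels_per_word = 32 / bits;          // 8 per .x/.y/.z/.w
    printf("%d bytes -> %d int4 loads, %d channels per 32-bit word\n",
           bytes_per_pixel, int4_loads, channels_per_word);
}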
@@ -403,15 +365,7 @@ __global__ void kern_const_border_nchw4(SrcVisitor src,
        }
    }
}
template <bool signedness>
MEGDNN_DEVICE __forceinline__ static void transform_bit4x8_to_int8(
        int (&result)[8], const int& source) {
#pragma unroll
    for (int i = 0; i < 8; i++) {
        result[i] = unpack_integer_4bits<signedness>(
                reinterpret_cast<unsigned const&>(source), (i << 2));
    }
}
template <typename ctype, typename SrcVisitor, typename OutputConverter>
__global__ void kern_const_border_nchw64(SrcVisitor src,
@@ -457,7 +411,7 @@ __global__ void kern_const_border_nchw64(SrcVisitor src,
        bool flag00 = okh0 && okw0, flag01 = okh0 && okw1,
             flag10 = okh1 && okw0, flag11 = okh1 && okw1;
        int8_t bval_4 = bval.as_storage() & 0xF;
        int bval_8 = transform_int8_to_bit4x8<signedness>(
        int bval_8 = transform_int8_to_b4x8<signedness>(
                bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4);
        int4 bval_int4;
        bval_int4.x = bval_8;
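For the constant-border kernel, the low nibble of the quantized border value is replicated into all eight nibbles of a word (the patch does this by passing bval_4 eight times to transform_int8_to_b4x8) and then into all four words of an int4. An equivalent host-side trick using a multiply by 0x11111111 (a sketch, not the patch's code path):

#include <cstdint>
#include <cstdio>

int main() {
    int8_t bval = -3;                     // quantized border value
    uint32_t nib = (uint32_t)bval & 0xF;  // 0xD, the int4 bit pattern of -3
    uint32_t bval_8 = nib * 0x11111111u;  // replicate into all 8 nibbles
    printf("0x%08X\n", bval_8);           // prints 0xDDDDDDDD
}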
@@ -488,31 +442,31 @@ __global__ void kern_const_border_nchw64(SrcVisitor src,
            s[3] = bval_int4;
        }
        transform_bit4x8_to_int8<signedness>(s00, s[0].x);
        transform_bit4x8_to_int8<signedness>(s01, s[1].x);
        transform_bit4x8_to_int8<signedness>(s10, s[2].x);
        transform_bit4x8_to_int8<signedness>(s11, s[3].x);
        transform_b4x8_to_int8<signedness>(s00, s[0].x);
        transform_b4x8_to_int8<signedness>(s01, s[1].x);
        transform_b4x8_to_int8<signedness>(s10, s[2].x);
        transform_b4x8_to_int8<signedness>(s11, s[3].x);
        d.x = pack_output_func<signedness>(output_converter, s00, s01, s10,
                                           s11, w00, w01, w10, w11);
        transform_bit4x8_to_int8<signedness>(s00, s[0].y);
        transform_bit4x8_to_int8<signedness>(s01, s[1].y);
        transform_bit4x8_to_int8<signedness>(s10, s[2].y);
        transform_bit4x8_to_int8<signedness>(s11, s[3].y);
        transform_b4x8_to_int8<signedness>(s00, s[0].y);
        transform_b4x8_to_int8<signedness>(s01, s[1].y);
        transform_b4x8_to_int8<signedness>(s10, s[2].y);
        transform_b4x8_to_int8<signedness>(s11, s[3].y);
        d.y = pack_output_func<signedness>(output_converter, s00, s01, s10,
                                           s11, w00, w01, w10, w11);
        transform_bit4x8_to_int8<signedness>(s00, s[0].z);
        transform_bit4x8_to_int8<signedness>(s01, s[1].z);
        transform_bit4x8_to_int8<signedness>(s10, s[2].z);
        transform_bit4x8_to_int8<signedness>(s11, s[3].z);
        transform_b4x8_to_int8<signedness>(s00, s[0].z);
        transform_b4x8_to_int8<signedness>(s01, s[1].z);
        transform_b4x8_to_int8<signedness>(s10, s[2].z);
        transform_b4x8_to_int8<signedness>(s11, s[3].z);
        d.z = pack_output_func<signedness>(output_converter, s00, s01, s10,
                                           s11, w00, w01, w10, w11);
        transform_bit4x8_to_int8<signedness>(s00, s[0].w);
        transform_bit4x8_to_int8<signedness>(s01, s[1].w);
        transform_bit4x8_to_int8<signedness>(s10, s[2].w);
        transform_bit4x8_to_int8<signedness>(s11, s[3].w);
        transform_b4x8_to_int8<signedness>(s00, s[0].w);
        transform_b4x8_to_int8<signedness>(s01, s[1].w);
        transform_b4x8_to_int8<signedness>(s10, s[2].w);
        transform_b4x8_to_int8<signedness>(s11, s[3].w);
        d.w = pack_output_func<signedness>(output_converter, s00, s01, s10,
                                           s11, w00, w01, w10, w11);