|
@@ -538,9 +538,9 @@ struct Translayout<64, 8, SrcType, dtype::QuantizedS4, dtype::QuantizedS4, |
|
|
}; |
|
|
}; |
|
|
#undef pack |
|
|
#undef pack |
|
|
|
|
|
|
|
|
#define pack(_idx) \ |
|
|
|
|
|
((uint8_t)(post_process(intermediate[0][_idx])) | \ |
|
|
|
|
|
((uint8_t)(post_process(intermediate[1][_idx])) << 4)) |
|
|
|
|
|
|
|
|
#define pack(_idx) \ |
|
|
|
|
|
((post_process(intermediate[0][_idx]) & 0xf) | \ |
|
|
|
|
|
(post_process(intermediate[1][_idx]) << 4)) |
|
|
template <typename SrcType, bool same_scale> |
|
|
template <typename SrcType, bool same_scale> |
|
|
struct Translayout<64, 2, SrcType, dtype::QuantizedS4, dtype::QuantizedS4, |
|
|
struct Translayout<64, 2, SrcType, dtype::QuantizedS4, dtype::QuantizedS4, |
|
|
same_scale> { |
|
|
same_scale> { |
|
@@ -648,9 +648,9 @@ struct Translayout<64, 8, SrcType, dtype::Quantized4Asymm, |
|
|
}; |
|
|
}; |
|
|
#undef pack |
|
|
#undef pack |
|
|
|
|
|
|
|
|
#define pack(_idx) \ |
|
|
|
|
|
((uint8_t)(post_process(intermediate[0][_idx])) | \ |
|
|
|
|
|
((uint8_t)(post_process(intermediate[1][_idx])) << 4)) |
|
|
|
|
|
|
|
|
#define pack(_idx) \ |
|
|
|
|
|
(post_process(intermediate[0][_idx]) | \ |
|
|
|
|
|
(post_process(intermediate[1][_idx]) << 4)) |
|
|
template <typename SrcType, bool same_scale> |
|
|
template <typename SrcType, bool same_scale> |
|
|
struct Translayout<64, 2, SrcType, dtype::Quantized4Asymm, |
|
|
struct Translayout<64, 2, SrcType, dtype::Quantized4Asymm, |
|
|
dtype::Quantized4Asymm, same_scale> { |
|
|
dtype::Quantized4Asymm, same_scale> { |
|
@@ -820,13 +820,25 @@ __global__ void kern_nchw_nchwx( |
|
|
int n_stride_src, int ic_stride, int n_stride_dst, int oc_stride, |
|
|
int n_stride_src, int ic_stride, int n_stride_dst, int oc_stride, |
|
|
CudaPostProcess<DnnSrcType, DnnDstType, same_scale> post_process, |
|
|
CudaPostProcess<DnnSrcType, DnnDstType, same_scale> post_process, |
|
|
const char zero_point, const int group, const int ocpg) { |
|
|
const char zero_point, const int group, const int ocpg) { |
|
|
|
|
|
static constexpr int size_src_type = sizeof(SrcType); |
|
|
|
|
|
static constexpr int size_dst_type = sizeof(DstType); |
|
|
|
|
|
#ifndef MEGDNN_COMMA |
|
|
|
|
|
#define MEGDNN_COMMA , |
|
|
|
|
|
#endif |
|
|
|
|
|
MEGDNN_STATIC_ASSERT(std::is_same<SrcType MEGDNN_COMMA DstType>::value, |
|
|
|
|
|
"Currently this kernel only support accessing tensor " |
|
|
|
|
|
"src and dst in same data type."); |
|
|
|
|
|
n_stride_src /= size_src_type; |
|
|
|
|
|
ic_stride /= size_src_type; |
|
|
|
|
|
n_stride_dst /= size_dst_type; |
|
|
|
|
|
oc_stride /= size_dst_type; |
|
|
|
|
|
|
|
|
const int n_idx = blockIdx.y; |
|
|
const int n_idx = blockIdx.y; |
|
|
const int ihw_block_idx = blockIdx.x * blockDim.x + threadIdx.x; |
|
|
const int ihw_block_idx = blockIdx.x * blockDim.x + threadIdx.x; |
|
|
const int ihw_offset = |
|
|
const int ihw_offset = |
|
|
ihw_block_idx * pack_w; |
|
|
ihw_block_idx * pack_w; |
|
|
const int ihw_offset_in_type = |
|
|
const int ihw_offset_in_type = |
|
|
ihw_offset * size_nbits / (8 * sizeof(SrcType)); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ihw_offset * size_nbits / (8 * size_src_type); |
|
|
if (ihw_offset < ihw) { |
|
|
if (ihw_offset < ihw) { |
|
|
const int src_offset_base = n_idx * n_stride_src + ihw_offset_in_type; |
|
|
const int src_offset_base = n_idx * n_stride_src + ihw_offset_in_type; |
|
|
const int dst_offset_base = |
|
|
const int dst_offset_base = |
|
@@ -836,7 +848,7 @@ __global__ void kern_nchw_nchwx( |
|
|
const int ic_block = icpg / pack_c; |
|
|
const int ic_block = icpg / pack_c; |
|
|
const int remain_ic = icpg % pack_c; |
|
|
const int remain_ic = icpg % pack_c; |
|
|
const int src_group_stride = icpg * ic_stride; |
|
|
const int src_group_stride = icpg * ic_stride; |
|
|
const int dst_group_stride = ocpg * oc_stride; |
|
|
|
|
|
|
|
|
const int dst_group_stride = (ocpg / pack_c) * oc_stride; |
|
|
for (int g_idx = 0; g_idx < group; ++g_idx) { |
|
|
for (int g_idx = 0; g_idx < group; ++g_idx) { |
|
|
const int src_offset = |
|
|
const int src_offset = |
|
|
src_offset_base + g_idx * src_group_stride; |
|
|
src_offset_base + g_idx * src_group_stride; |
|
@@ -1018,7 +1030,7 @@ public: |
|
|
int chan_stride_in_elements_, |
|
|
int chan_stride_in_elements_, |
|
|
int channel_) |
|
|
int channel_) |
|
|
: pointer{pointer_}, |
|
|
: pointer{pointer_}, |
|
|
chan_stride_in_elements{chan_stride_in_elements}, |
|
|
|
|
|
|
|
|
chan_stride_in_elements{chan_stride_in_elements_}, |
|
|
channel{channel_} {} |
|
|
channel{channel_} {} |
|
|
|
|
|
|
|
|
MEGDNN_DEVICE __forceinline__ void load(Fragment& frag) { |
|
|
MEGDNN_DEVICE __forceinline__ void load(Fragment& frag) { |
|
@@ -1031,7 +1043,7 @@ public: |
|
|
int frag_idx = i / pack_size * |
|
|
int frag_idx = i / pack_size * |
|
|
(lane_size_in_type / pack_size_in_type) + |
|
|
(lane_size_in_type / pack_size_in_type) + |
|
|
j; |
|
|
j; |
|
|
bool guard = i >= channel; |
|
|
|
|
|
|
|
|
bool guard = i < channel; |
|
|
cutlass::arch::global_load<AccessType, pack_size_in_byte>( |
|
|
cutlass::arch::global_load<AccessType, pack_size_in_byte>( |
|
|
frag_ptr[frag_idx], |
|
|
frag_ptr[frag_idx], |
|
|
reinterpret_cast<void*>(pointer_ + |
|
|
reinterpret_cast<void*>(pointer_ + |
|
@@ -1052,7 +1064,7 @@ public: |
|
|
int frag_idx = i / pack_size * |
|
|
int frag_idx = i / pack_size * |
|
|
(lane_size_in_type / pack_size_in_type) + |
|
|
(lane_size_in_type / pack_size_in_type) + |
|
|
j; |
|
|
j; |
|
|
bool guard = i >= channel; |
|
|
|
|
|
|
|
|
bool guard = i < channel; |
|
|
cutlass::arch::global_store<AccessType, pack_size_in_byte>( |
|
|
cutlass::arch::global_store<AccessType, pack_size_in_byte>( |
|
|
frag_ptr[frag_idx], |
|
|
frag_ptr[frag_idx], |
|
|
reinterpret_cast<void*>(pointer_ + |
|
|
reinterpret_cast<void*>(pointer_ + |
|
@@ -1092,11 +1104,24 @@ __global__ void kern_nchwx_nchw( |
|
|
size_nbits>; |
|
|
size_nbits>; |
|
|
using Transpose = Translayout<pack_c, pack_w, SrcType, DnnSrcType, |
|
|
using Transpose = Translayout<pack_c, pack_w, SrcType, DnnSrcType, |
|
|
DnnDstType, same_scale>; |
|
|
DnnDstType, same_scale>; |
|
|
|
|
|
static constexpr int size_src_type = sizeof(SrcType); |
|
|
|
|
|
static constexpr int size_dst_type = sizeof(DstType); |
|
|
|
|
|
MEGDNN_STATIC_ASSERT(std::is_same<SrcType MEGDNN_COMMA DstType>::value, |
|
|
|
|
|
"Currently this kernel only support accessing tensor " |
|
|
|
|
|
"src and dst in same data type."); |
|
|
|
|
|
n_stride_src /= size_src_type; |
|
|
|
|
|
ic_stride /= size_src_type; |
|
|
|
|
|
n_stride_dst /= size_dst_type; |
|
|
|
|
|
oc_stride /= size_dst_type; |
|
|
|
|
|
#undef MEGDNN_COMMA |
|
|
|
|
|
|
|
|
const int n_idx = blockIdx.y; |
|
|
const int n_idx = blockIdx.y; |
|
|
const int ihw_block_idx = blockIdx.x * blockDim.x + threadIdx.x; |
|
|
const int ihw_block_idx = blockIdx.x * blockDim.x + threadIdx.x; |
|
|
const int ihw_offset = ihw_block_idx * pack_w; |
|
|
const int ihw_offset = ihw_block_idx * pack_w; |
|
|
const int ihw_offset_in_type = |
|
|
const int ihw_offset_in_type = |
|
|
ihw_offset * size_nbits / (8 * sizeof(SrcType)); |
|
|
|
|
|
|
|
|
ihw_offset * size_nbits / (8 * size_src_type); |
|
|
|
|
|
const int oc_stride_inner_dtype = |
|
|
|
|
|
oc_stride * size_dst_type / sizeof(InnerDtype); |
|
|
if (ihw_offset < ihw) { |
|
|
if (ihw_offset < ihw) { |
|
|
const int ic_block = (ic + pack_c - 1) / pack_c; |
|
|
const int ic_block = (ic + pack_c - 1) / pack_c; |
|
|
const int src_offset_base = |
|
|
const int src_offset_base = |
|
@@ -1105,8 +1130,8 @@ __global__ void kern_nchwx_nchw( |
|
|
SrcIterator src_iterator{const_cast<SrcType*>(src + src_offset_base), |
|
|
SrcIterator src_iterator{const_cast<SrcType*>(src + src_offset_base), |
|
|
ic_stride, ic}; |
|
|
ic_stride, ic}; |
|
|
DstIteraotr dst_iterator{ |
|
|
DstIteraotr dst_iterator{ |
|
|
reinterpret_cast<InnerDtype*>(dst + dst_offset_base), oc_stride, |
|
|
|
|
|
ic}; |
|
|
|
|
|
|
|
|
reinterpret_cast<InnerDtype*>(dst + dst_offset_base), |
|
|
|
|
|
oc_stride_inner_dtype, ic}; |
|
|
|
|
|
|
|
|
for (int ic_blk_idx = 0; ic_blk_idx < ic_block; ++ic_blk_idx) { |
|
|
for (int ic_blk_idx = 0; ic_blk_idx < ic_block; ++ic_blk_idx) { |
|
|
typename SrcIterator::Fragment src_frag; |
|
|
typename SrcIterator::Fragment src_frag; |
|
@@ -1143,12 +1168,13 @@ void relayout_format::relayout_format_cuda_nchw_nchwx( |
|
|
DEF(64, Quantized4Asymm, Quantized4Asymm) |
|
|
DEF(64, Quantized4Asymm, Quantized4Asymm) |
|
|
DEF(4, QuantizedS8, QuantizedS8) |
|
|
DEF(4, QuantizedS8, QuantizedS8) |
|
|
DEF(4, Uint8, QuantizedS8) |
|
|
DEF(4, Uint8, QuantizedS8) |
|
|
DEF(4, Quantized8Asymm, Quantized8Asymm) |
|
|
|
|
|
DEF(4, QuantizedS32, QuantizedS32); |
|
|
|
|
|
|
|
|
DEF(4, Quantized8Asymm, QuantizedS8) |
|
|
|
|
|
DEF(4, QuantizedS32, QuantizedS32) |
|
|
// clang-format on |
|
|
// clang-format on |
|
|
megdnn_assert(pack_oc == 4 || pack_oc == 64, |
|
|
megdnn_assert(pack_oc == 4 || pack_oc == 64, |
|
|
"Unsupport pack size(pack_oc:%d)", pack_oc); |
|
|
|
|
|
#undef DEF |
|
|
|
|
|
|
|
|
"Unsupport pack size(pack_oc:%d, src:%s, dst:%s)", pack_oc, |
|
|
|
|
|
stype.name(), dtype.name()); |
|
|
|
|
|
#undef DEF |
|
|
const int in_n = src.layout[0]; |
|
|
const int in_n = src.layout[0]; |
|
|
const int out_n = dst.layout[0]; |
|
|
const int out_n = dst.layout[0]; |
|
|
const int ic = src.layout[1]; |
|
|
const int ic = src.layout[1]; |
|
@@ -1157,6 +1183,7 @@ void relayout_format::relayout_format_cuda_nchw_nchwx( |
|
|
const int oc = dst.layout[1] * pack_oc; |
|
|
const int oc = dst.layout[1] * pack_oc; |
|
|
const int hw = h * w; |
|
|
const int hw = h * w; |
|
|
const int ocpg = oc / group; |
|
|
const int ocpg = oc / group; |
|
|
|
|
|
// stride in byte |
|
|
const int n_stride_src = src_layout.dtype.size(src_layout.stride[0]); |
|
|
const int n_stride_src = src_layout.dtype.size(src_layout.stride[0]); |
|
|
const int ic_stride = src_layout.dtype.size(src_layout.stride[1]); |
|
|
const int ic_stride = src_layout.dtype.size(src_layout.stride[1]); |
|
|
const int n_stride_dst = dst_layout.dtype.size(dst_layout.stride[0]); |
|
|
const int n_stride_dst = dst_layout.dtype.size(dst_layout.stride[0]); |
|
@@ -1244,20 +1271,20 @@ void relayout_format::relayout_format_cuda_nchwx_nchw( |
|
|
auto& src_layout = src.layout; |
|
|
auto& src_layout = src.layout; |
|
|
auto& dst_layout = dst.layout; |
|
|
auto& dst_layout = dst.layout; |
|
|
// check pack size |
|
|
// check pack size |
|
|
int pack_oc = std::numeric_limits<int>::min(); |
|
|
|
|
|
#define DEF(_pack_oc, _src_type, _dst_type) \ |
|
|
|
|
|
|
|
|
int pack_ic = std::numeric_limits<int>::min(); |
|
|
|
|
|
#define DEF(_pack_ic, _src_type, _dst_type) \ |
|
|
if (stype.enumv().ev == DTypeEnum::Ev::_src_type && \ |
|
|
if (stype.enumv().ev == DTypeEnum::Ev::_src_type && \ |
|
|
dtype.enumv().ev == DTypeEnum::Ev::_dst_type) { \ |
|
|
dtype.enumv().ev == DTypeEnum::Ev::_dst_type) { \ |
|
|
pack_oc = _pack_oc; \ |
|
|
|
|
|
|
|
|
pack_ic = _pack_ic; \ |
|
|
} |
|
|
} |
|
|
// clang-format off |
|
|
// clang-format off |
|
|
DEF(64, QuantizedS4, QuantizedS4) |
|
|
DEF(64, QuantizedS4, QuantizedS4) |
|
|
DEF(64, Quantized4Asymm, Quantized4Asymm) |
|
|
DEF(64, Quantized4Asymm, Quantized4Asymm) |
|
|
// clang-format on |
|
|
// clang-format on |
|
|
megdnn_assert(pack_oc == 64, "Unsupport pack size(pack_oc:%d)", pack_oc); |
|
|
|
|
|
|
|
|
megdnn_assert(pack_ic == 64, "Unsupport pack size(pack_ic:%d)", pack_ic); |
|
|
#undef DEF |
|
|
#undef DEF |
|
|
const int n = src.layout[0]; |
|
|
const int n = src.layout[0]; |
|
|
const int c = src.layout[1]; |
|
|
|
|
|
|
|
|
const int c = src.layout[1] * pack_ic; |
|
|
const int h = src.layout[2]; |
|
|
const int h = src.layout[2]; |
|
|
// align to byte |
|
|
// align to byte |
|
|
const int w = src.layout[3]; |
|
|
const int w = src.layout[3]; |
|
@@ -1266,7 +1293,7 @@ void relayout_format::relayout_format_cuda_nchwx_nchw( |
|
|
const int ic_stride = src_layout.dtype.size(src_layout.stride[1]); |
|
|
const int ic_stride = src_layout.dtype.size(src_layout.stride[1]); |
|
|
const int n_stride_dst = dst_layout.dtype.size(dst_layout.stride[0]); |
|
|
const int n_stride_dst = dst_layout.dtype.size(dst_layout.stride[0]); |
|
|
const int oc_stride = dst_layout.dtype.size(dst_layout.stride[1]); |
|
|
const int oc_stride = dst_layout.dtype.size(dst_layout.stride[1]); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool same_scale = src_scale == dst_scale; |
|
|
bool same_scale = src_scale == dst_scale; |
|
|
#define DISPATCH_RAW(_same_scale, _pack_w, _pack_oc, _src_type, _dst_type, \ |
|
|
#define DISPATCH_RAW(_same_scale, _pack_w, _pack_oc, _src_type, _dst_type, \ |
|
|
_src_c_type, _dst_c_type, _size_nbits) \ |
|
|
_src_c_type, _dst_c_type, _size_nbits) \ |
|
|