From 96d90be1c6e534b3bc50dd84c04735dde3be0818 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Wed, 1 Jun 2022 19:40:01 +0800
Subject: [PATCH] feat(dnn): fallback support int4 relayout

GitOrigin-RevId: 3625f5847055940e646358654f296922f05afa93
---
 dnn/include/megdnn/dtype.h             |  24 +-
 dnn/src/fallback/relayout/opr_impl.cpp | 405 +++++++++++++++++++++++++++------
 dnn/test/common/checker.cpp            |   2 +-
 dnn/test/fallback/relayout.cpp         |  54 +++++
 4 files changed, 409 insertions(+), 76 deletions(-)

diff --git a/dnn/include/megdnn/dtype.h b/dnn/include/megdnn/dtype.h
index 35a27163..18387cf9 100644
--- a/dnn/include/megdnn/dtype.h
+++ b/dnn/include/megdnn/dtype.h
@@ -615,16 +615,17 @@ struct log<1> {
 // begin define DTypeTrait impls {
 
 #if MEGDNN_CC_HOST
-#define MEGDNN_DEF_DT_BASIC_FIELDS(_name, _ctype, _cat, _sign, _bits, _has_param)  \
-    static MEGDNN_CONSTEXPR const char* name = #_name;                             \
-    using ctype = _ctype;                                                          \
-    using dtype = ::megdnn::dtype::_name;                                          \
-    static MEGDNN_CONSTEXPR DTypeCategory category = DTypeCategory::_cat;          \
-    static MEGDNN_CONSTEXPR DTypeSignedness signedness = DTypeSignedness::_sign;   \
-    static MEGDNN_CONSTEXPR uint16_t size_log =                                    \
-            ::megdnn::dtype::log<sizeof(_ctype)>::value;                           \
-    static MEGDNN_CONSTEXPR DTypeEnum enumv = DTypeEnum::_name;                    \
-    static MEGDNN_CONSTEXPR uint16_t low_bit = _bits;                              \
+#define MEGDNN_DEF_DT_BASIC_FIELDS(_name, _ctype, _cat, _sign, _bits, _has_param)    \
+    static MEGDNN_CONSTEXPR const char* name = #_name;                               \
+    using ctype = _ctype;                                                            \
+    using dtype = ::megdnn::dtype::_name;                                            \
+    static MEGDNN_CONSTEXPR DTypeCategory category = DTypeCategory::_cat;            \
+    static MEGDNN_CONSTEXPR DTypeSignedness signedness = DTypeSignedness::_sign;     \
+    static MEGDNN_CONSTEXPR uint16_t size_log =                                      \
+            ::megdnn::dtype::log<sizeof(_ctype)>::value;                             \
+    static MEGDNN_CONSTEXPR DTypeEnum enumv = DTypeEnum::_name;                      \
+    static MEGDNN_CONSTEXPR uint16_t low_bit = _bits;                                \
+    static MEGDNN_CONSTEXPR uint16_t bits = _bits == 0 ? sizeof(_ctype) * 8 : _bits; \
     static MEGDNN_CONSTEXPR bool has_param = _has_param
 #else
 #define MEGDNN_DEF_DT_BASIC_FIELDS(_name, _ctype, _cat, _sign, _bits, _has_param) \
@@ -632,7 +633,8 @@ struct log<1> {
     typedef ::megdnn::dtype::_name dtype;                                         \
     static const uint16_t size_log = ::megdnn::dtype::log<sizeof(_ctype)>::value; \
     static MEGDNN_CONSTEXPR int enumv = DTypeEnum::_name;                         \
-    static MEGDNN_CONSTEXPR uint16_t low_bit = _bits
+    static MEGDNN_CONSTEXPR uint16_t low_bit = _bits;                             \
+    static MEGDNN_CONSTEXPR uint16_t bits = _bits == 0 ? sizeof(_ctype) * 8 : _bits;
 #endif  // MEGDNN_CC_HOST
 
 #define MEGDNN_DEF_DT(_name, _ctype, _cat, _sign, _minval, _maxval) \
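Aside on the dtype.h change: the new `bits` field gives every dtype its true element width, so generic code can key on bit width instead of special-casing low-bit types. A minimal standalone sketch of how the ternary resolves (plain C++; `Int8Trait` and `QuantizedS4Trait` are hypothetical stand-ins for what the macro expands to, not MegDNN code):

    #include <cstdint>

    // _bits == 0 marks a full-size type: bits falls back to sizeof(ctype) * 8.
    struct Int8Trait {
        using ctype = int8_t;
        static constexpr uint16_t low_bit = 0;  // not a low-bit type
        static constexpr uint16_t bits = low_bit == 0 ? sizeof(ctype) * 8 : low_bit;
    };

    // A 4-bit type is stored in a byte-sized ctype, but bits records the
    // true element width, which the relayout dispatch below keys on.
    struct QuantizedS4Trait {
        using ctype = int8_t;  // packed two elements per byte at runtime
        static constexpr uint16_t low_bit = 4;
        static constexpr uint16_t bits = low_bit == 0 ? sizeof(ctype) * 8 : low_bit;
    };

    static_assert(Int8Trait::bits == 8, "full-size type: storage width");
    static_assert(QuantizedS4Trait::bits == 4, "low-bit type: true width");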
diff --git a/dnn/src/fallback/relayout/opr_impl.cpp b/dnn/src/fallback/relayout/opr_impl.cpp
index 1c070009..01b5e945 100644
--- a/dnn/src/fallback/relayout/opr_impl.cpp
+++ b/dnn/src/fallback/relayout/opr_impl.cpp
@@ -8,12 +8,129 @@
 using namespace megdnn;
 using namespace fallback;
 
+namespace megdnn {
+namespace relayout {
+namespace transpose_fallback {
+template <>
+struct transpose_traits<dt_qint4> {
+    static constexpr size_t block_size = BLOCK_LINE_SIZE_BYTES;
+};
+
+template <>
+void transpose_block_fallback<dt_qint4>(
+        const dt_qint4* src, dt_qint4* dst, const size_t src_stride,
+        const size_t dst_stride, size_t block_h, size_t block_w) {
+    constexpr size_t block_size = transpose_traits<dt_qint4>::block_size;
+    uint8_t block[block_size][block_size];
+    uint8_t* src_ptr = (uint8_t*)src;
+    uint8_t* dst_ptr = (uint8_t*)dst;
+    for (size_t i = 0; i < block_h; ++i) {
+        size_t src_offset_base = i * src_stride;
+        for (size_t j = 0; j < block_w; ++j) {
+            size_t src_offset = src_offset_base + j;
+            size_t src_byte_offset = src_offset >> 1;
+            if (src_offset % 2 == 0) {
+                block[j][i] = src_ptr[src_byte_offset] & 0xf;
+            } else {
+                block[j][i] = ((src_ptr[src_byte_offset] & 0xf0) >> 4) & 0xf;
+            }
+        }
+    }
+    for (size_t i = 0; i < block_w; ++i) {
+        size_t dst_offset_base = i * dst_stride;
+        for (size_t j = 0; j < block_h; ++j) {
+            size_t dst_offset = dst_offset_base + j;
+            size_t dst_byte_offset = dst_offset >> 1;
+            uint8_t dst_temp = dst_ptr[dst_byte_offset];
+            uint8_t src_temp = block[i][j];
+            if (dst_offset % 2 == 0) {
+                dst_temp = (dst_temp & 0xf0) | src_temp;
+            } else {
+                dst_temp = (dst_temp & 0xf) | (src_temp << 4);
+            }
+            dst_ptr[dst_byte_offset] = dst_temp;
+        }
+    }
+}
+
+template <>
+void transpose<dt_qint4>(
+        size_t batch, size_t m, size_t n, dt_qint4* src, dt_qint4* dst,
+        size_t stride_m) {
+    if (stride_m == 0) {
+        stride_m = n;
+    }
+    uint8_t* batch_src = (uint8_t*)(src);
+    uint8_t* batch_dst = (uint8_t*)(dst);
+    constexpr size_t B = transpose_traits<dt_qint4>::block_size;
+
+    auto work_block = [m, stride_m, &batch_src, &batch_dst](
+                              const size_t i, const size_t j, const size_t h,
+                              const size_t w) {
+        size_t src_offset = i * stride_m + j;
+        size_t dst_offset = j * m + i;
+        megdnn_assert(src_offset % 2 == 0 && dst_offset % 2 == 0);
+        auto src = batch_src + (src_offset >> 1);
+        auto dst = batch_dst + (dst_offset >> 1);
+        MIDOUT_BEGIN(transpose_fallback, midout_iv(0)) {
+            if (h == B && w == B) {
+                transpose_block((dt_qint4*)src, (dt_qint4*)dst, stride_m, m);
+            } else {
+                transpose_block((dt_qint4*)src, (dt_qint4*)dst, stride_m, m, h, w);
+            }
+        }
+        MIDOUT_END();
+    };
+    auto work_row = [&work_block, n](size_t i, size_t h) {
+        size_t j = 0;
+        for (; j + B <= n; j += B) {
+            work_block(i, j, h, B);
+        }
+        if (j < n) {
+            work_block(i, j, h, n - j);
+        }
+    };
+
+    for (size_t b = 0; b < batch; ++b) {
+        size_t i = 0;
+        for (; i + B <= m; i += B) {
+            work_row(i, B);
+        }
+        if (i < m) {
+            work_row(i, m - i);
+        }
+        size_t src_offset = m * stride_m;
+        size_t dst_offset = m * n;
+        megdnn_assert(src_offset % 2 == 0 && dst_offset % 2 == 0);
+        batch_src += (src_offset >> 1);
+        batch_dst += (dst_offset >> 1);
+    }
+}
+
+}  // namespace transpose_fallback
+}  // namespace relayout
+}  // namespace megdnn
+
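The specializations above address nibble `i` of a packed buffer as byte `i >> 1`, low nibble for even `i`, high nibble for odd `i`. A self-contained sketch of that convention and of the per-element transpose it enables (hypothetical helper names `load_nibble`/`store_nibble`; the patch inlines this logic into the block kernel):

    #include <cassert>
    #include <cstddef>
    #include <cstdint>

    // Element 2k lives in the low nibble of byte k, element 2k+1 in the high one.
    static uint8_t load_nibble(const uint8_t* p, size_t idx) {
        uint8_t byte = p[idx >> 1];
        return (idx & 1) ? (byte >> 4) & 0xf : byte & 0xf;
    }

    static void store_nibble(uint8_t* p, size_t idx, uint8_t v) {
        uint8_t& byte = p[idx >> 1];
        if (idx & 1)
            byte = (byte & 0x0f) | ((v & 0xf) << 4);
        else
            byte = (byte & 0xf0) | (v & 0x0f);
    }

    int main() {
        // Transpose a 2x3 nibble matrix into 3x2, one element at a time,
        // the way transpose_block_fallback<dt_qint4> does block-wise.
        uint8_t src[3] = {0x21, 0x43, 0x65};  // elements 1,2,3,4,5,6
        uint8_t dst[3] = {0, 0, 0};
        for (size_t i = 0; i < 2; ++i)
            for (size_t j = 0; j < 3; ++j)
                store_nibble(dst, j * 2 + i, load_nibble(src, i * 3 + j));
        assert(load_nibble(dst, 0) == 1 && load_nibble(dst, 1) == 4);
        assert(load_nibble(dst, 2) == 2 && load_nibble(dst, 3) == 5);
        assert(load_nibble(dst, 4) == 3 && load_nibble(dst, 5) == 6);
    }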
 namespace {
 
 bool is_lastdim_contig(const TensorLayout& layout) {
     return layout.ndim <= 3 && layout.stride[layout.ndim - 1] == 1;
 }
 
+bool is_int4(const TensorLayout& layout) {
+    return layout.dtype.enumv() == DTypeEnum::QuantizedS4 ||
+           layout.dtype.enumv() == DTypeEnum::Quantized4Asymm;
+}
+
+inline bool check_dtype_support_transparam(
+        bool trans, bool is_bit4, const relayout::TransposeParam& param) {
+    if (trans && is_bit4) {
+        auto c = param.c;
+        return c == 1 || c == 2 || c == 4 || c == 8;
+    }
+    return trans;
+}
+
 template <typename T0, size_t sz>
 struct equiv_ctype_storage {
     T0 _[sz];
@@ -26,16 +143,111 @@ struct equiv_ctype {
             alignof(typename DTypeTrait<dtype>::ctype)>;
 };
 
-typedef void (*memcpy_policy_t)(void* cont, void* non_cont, size_t);
+typedef void (*memcpy_policy_t)(
+        void* cont, void* non_cont, size_t src_offset, size_t dst_offset, size_t size);
 
-void memcpy_cont2noncont(void* cont, void* non_cont, size_t size) {
+void memcpy_cont2noncont(void* cont, void* non_cont, size_t, size_t, size_t size) {
     memcpy(non_cont, cont, size);
 }
 
-void memcpy_noncont2cont(void* cont, void* non_cont, size_t size) {
+void memcpy_noncont2cont(void* cont, void* non_cont, size_t, size_t, size_t size) {
     memcpy(cont, non_cont, size);
 }
 
+void memcpy_4bit(
+        void* cont, void* nocont, size_t cont_offset, size_t nocont_offset,
+        size_t size) {
+    if (size == 0)
+        return;
+    uint8_t* cont_u8 = (uint8_t*)cont;
+    uint8_t* nocont_u8 = (uint8_t*)nocont;
+    size_t cont_bytes = cont_offset >> 1;
+    size_t nocont_bytes = nocont_offset >> 1;
+    size_t size_byte = size >> 1;
+    void* cont_ptr = cont_u8 + cont_bytes;
+    void* nocont_ptr = nocont_u8 + nocont_bytes;
+    bool size_align = size % 2 == 0;
+    bool cont_align = cont_offset % 2 == 0;
+    bool nocont_align = nocont_offset % 2 == 0;
+    if (cont_align && nocont_align) {
+        memcpy(cont_ptr, nocont_ptr, size_byte);
+        if (!size_align) {
+            uint8_t* dst_ptr = (uint8_t*)cont_ptr + size_byte;
+            uint8_t* src_ptr = (uint8_t*)nocont_ptr + size_byte;
+            *dst_ptr = (*src_ptr) & 0xf;
+        }
+    } else if (!cont_align && nocont_align) {
+        uint8_t* dst_ptr = (uint8_t*)cont_ptr;
+        uint8_t* src_ptr = (uint8_t*)nocont_ptr;
+        for (size_t i = 0; i < size_byte; ++i) {
+            uint8_t dst_low = *dst_ptr;
+            uint8_t src_all = *src_ptr;
+            uint8_t last = (dst_low & 0xf) | (src_all & 0xf) << 4;
+            uint8_t now = ((src_all & 0xf0) >> 4) & 0xf;
+            *dst_ptr = last;
+            ++dst_ptr;
+            *dst_ptr = now;
+            ++src_ptr;
+        }
+        if (!size_align) {
+            uint8_t dst_low = *dst_ptr;
+            uint8_t src_all = *src_ptr;
+            uint8_t last = (dst_low & 0xf) | (src_all & 0xf) << 4;
+            *dst_ptr = last;
+        }
+    } else if (cont_align && !nocont_align) {
+        uint8_t* dst_ptr = (uint8_t*)cont_ptr;
+        uint8_t* src_ptr = (uint8_t*)nocont_ptr;
+        for (size_t i = 0; i < size_byte; ++i) {
+            uint8_t src_last_high = *src_ptr;
+            ++src_ptr;
+            uint8_t src_low = *src_ptr;
+            uint8_t rst = (src_low & 0xf) << 4 | ((src_last_high >> 4) & 0xf);
+            *dst_ptr = rst;
+            ++dst_ptr;
+        }
+        if (!size_align) {
+            uint8_t src_last_high = *src_ptr;
+            *dst_ptr = ((src_last_high >> 4) & 0xf);
+        }
+    } else {
+        uint8_t* dst_ptr = (uint8_t*)cont_ptr;
+        uint8_t* src_ptr = (uint8_t*)nocont_ptr;
+        {
+            uint8_t src_last_high = *src_ptr;
+            uint8_t dst_last_low = *dst_ptr;
+            uint8_t rst = (dst_last_low & 0xf) | (src_last_high & 0xf0);
+            *dst_ptr = rst;
+            ++dst_ptr;
+            ++src_ptr;
+        }
+        if (!size_align) {
+            memcpy(dst_ptr, src_ptr, size_byte);
+        } else {
+            if (size_byte > 1) {
+                size_t align_size = size_byte - 1;
+                memcpy(dst_ptr, src_ptr, align_size);
+                dst_ptr += align_size;
+                src_ptr += align_size;
+            }
+            uint8_t src_last_low = *src_ptr;
+            *dst_ptr = src_last_low & 0xf;
+        }
+    }
+}
+
+void memcpy_cont2noncont_4bit(
+        void* cont, void* non_cont, size_t cont_offset, size_t nocont_offset,
+        size_t size) {
+    memcpy_4bit(non_cont, cont, nocont_offset, cont_offset, size);
+}
+
+void memcpy_noncont2cont_4bit(
+        void* cont, void* non_cont, size_t cont_offset, size_t nocont_offset,
+        size_t size) {
+    memcpy_4bit(cont, non_cont, cont_offset, nocont_offset, size);
+}
+
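All four alignment branches of `memcpy_4bit` implement one contract: copy `size` nibbles starting at nibble offset `nocont_offset`/`cont_offset`. A reference version of that contract, one nibble at a time (hypothetical code, not part of the patch, but useful for checking the fast paths against; the fast paths differ only in switching to byte-wise `memcpy` once both cursors sit on even nibble offsets):

    #include <cstddef>
    #include <cstdint>

    // Reference semantics: copy `size` 4-bit elements from nibble offset
    // src_off in src to nibble offset dst_off in dst.
    void memcpy_4bit_reference(
            uint8_t* dst, size_t dst_off, const uint8_t* src, size_t src_off,
            size_t size) {
        for (size_t i = 0; i < size; ++i) {
            size_t s = src_off + i, d = dst_off + i;
            uint8_t v = (s & 1) ? (src[s >> 1] >> 4) & 0xf : src[s >> 1] & 0xf;
            uint8_t& b = dst[d >> 1];
            b = (d & 1) ? (uint8_t)((b & 0x0f) | (v << 4))
                        : (uint8_t)((b & 0xf0) | v);
        }
    }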
 template <typename T>
 void call_transpose(
         size_t batch, size_t m, size_t n, size_t ch, void* src, void* dst,
@@ -46,7 +258,7 @@ void call_transpose(
 }
 
 //! one operand contiguous, and the other non-contiguous
-template <typename ctype>
+template <size_t bits>
 void dispatch_on_dtype_cont(
         Handle* handle, const TensorND& cont, const TensorND& nonc,
         memcpy_policy_t mcp_pol) {
@@ -54,13 +266,13 @@ void dispatch_on_dtype_cont(
     switch (nonc.layout.ndim) {
         case 2: {
             auto shp0 = nonc.layout.shape[0], shp1 = nonc.layout.shape[1];
-            auto strd0_n = nonc.layout.stride[0] * sizeof(ctype);
-            auto strd0_c = shp1 * sizeof(ctype);
+            auto strd0_n = nonc.layout.stride[0] * bits / 8;
+            auto strd0_c = shp1 * bits / 8;
             kern = [=]() {
                 auto cur_ctptr = static_cast<uint8_t*>(cont.raw_ptr());
                 auto cur_ncptr = static_cast<uint8_t*>(nonc.raw_ptr());
                 for (size_t i = 0; i < shp0; ++i) {
-                    mcp_pol(cur_ctptr, cur_ncptr, strd0_c);
+                    mcp_pol(cur_ctptr, cur_ncptr, 0, 0, strd0_c);
                     cur_ctptr += strd0_c;
                     cur_ncptr += strd0_n;
                 }
@@ -70,16 +282,16 @@ void dispatch_on_dtype_cont(
         case 3: {
             auto shp0 = nonc.layout.shape[0], shp1 = nonc.layout.shape[1],
                  shp2 = nonc.layout.shape[2];
-            auto strd0_n = nonc.layout.stride[0] * sizeof(ctype),
-                 strd1_n = nonc.layout.stride[1] * sizeof(ctype);
-            auto strd1_c = shp2 * sizeof(ctype);
+            auto strd0_n = nonc.layout.stride[0] * bits / 8,
+                 strd1_n = nonc.layout.stride[1] * bits / 8;
+            auto strd1_c = shp2 * bits / 8;
             kern = [=]() {
                 auto cur_ctptr = static_cast<uint8_t*>(cont.raw_ptr());
                 auto ncptr_row = static_cast<uint8_t*>(nonc.raw_ptr());
                 for (size_t i = 0; i < shp0; ++i) {
                     auto cur_ncptr = ncptr_row;
                     for (size_t j = 0; j < shp1; ++j) {
-                        mcp_pol(cur_ctptr, cur_ncptr, strd1_c);
+                        mcp_pol(cur_ctptr, cur_ncptr, 0, 0, strd1_c);
                         cur_ctptr += strd1_c;
                         cur_ncptr += strd1_n;
                     }
@@ -95,13 +307,64 @@ void dispatch_on_dtype_cont(
                     ncptr_row += strd0_n;
                 }
             };
             break;
         }
         default:
             megdnn_assert(0);
     }
 
     static_cast<naive::HandleImpl*>(handle)->dispatch_kern(std::move(kern));
 }
 
+template <>
+void dispatch_on_dtype_cont<4>(
+        Handle* handle, const TensorND& cont, const TensorND& nonc,
+        memcpy_policy_t mcp_pol) {
+    thin_function<void()> kern;
+    switch (nonc.layout.ndim) {
+        case 2: {
+            auto shp0 = nonc.layout.shape[0], shp1 = nonc.layout.shape[1];
+            auto strd0_n = nonc.layout.stride[0];
+            auto strd0_c = shp1;
+            kern = [=]() {
+                auto cur_ctptr = static_cast<uint8_t*>(cont.raw_ptr());
+                auto cur_ncptr = static_cast<uint8_t*>(nonc.raw_ptr());
+                size_t c_cnt = 0;
+                size_t n_cnt = 0;
+                for (size_t i = 0; i < shp0; ++i) {
+                    mcp_pol(cur_ctptr, cur_ncptr, c_cnt, n_cnt, strd0_c);
+                    c_cnt += strd0_c;
+                    n_cnt += strd0_n;
+                }
+            };
+            break;
+        }
+        case 3: {
+            auto shp0 = nonc.layout.shape[0], shp1 = nonc.layout.shape[1],
+                 shp2 = nonc.layout.shape[2];
+            auto strd0_n = nonc.layout.stride[0], strd1_n = nonc.layout.stride[1];
+            auto strd1_c = shp2;
+            kern = [=]() {
+                auto cur_ctptr = static_cast<uint8_t*>(cont.raw_ptr());
+                auto ncptr_row = static_cast<uint8_t*>(nonc.raw_ptr());
+                size_t c_cnt = 0;
+                size_t n_cnt = 0;
+                for (size_t i = 0; i < shp0; ++i) {
+                    n_cnt = i * strd0_n;
+                    for (size_t j = 0; j < shp1; ++j) {
+                        mcp_pol(cur_ctptr, ncptr_row, c_cnt, n_cnt, strd1_c);
+                        c_cnt += strd1_c;
+                        n_cnt += strd1_n;
+                    }
+                }
+            };
+            break;
+        }
+        default:
+            megdnn_assert(0);
+    }
+
+    static_cast<naive::HandleImpl*>(handle)->dispatch_kern(std::move(kern));
+}
+
 void dispatch_cont(
         Handle* handle, const TensorND& cont, const TensorND& nonc,
         memcpy_policy_t mcp_pol) {
     switch (cont.layout.dtype.enumv()) {
-#define cb(_dt)                                                          \
-    case DTypeTrait<dtype::_dt>::enumv:                                  \
-        return dispatch_on_dtype_cont<equiv_ctype<dtype::_dt>::type>(    \
+#define cb(_dt)                                                          \
+    case DTypeTrait<dtype::_dt>::enumv:                                  \
+        return dispatch_on_dtype_cont<DTypeTrait<dtype::_dt>::bits>(     \
                 handle, cont, nonc, mcp_pol);
         MEGDNN_FOREACH_DTYPE_NAME(cb)
         MEGDNN_FOREACH_PARAMETERIZED_DTYPE(cb)
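Why the 4-bit specialization threads element offsets (`c_cnt`, `n_cnt`) through to the copy policy instead of advancing byte pointers the way the generic `bits`-based template does: a sub-byte stride has no byte-pointer representation. A small illustration with hypothetical values:

    #include <cstddef>
    #include <cstdio>

    int main() {
        // A 4-bit row stride of 7 elements is 3.5 bytes: advancing a uint8_t*
        // per row would lose the half byte. Keeping (base pointer, element
        // offset) defers the >>1 to the copy routine, which also needs the
        // nibble parity to pick the right memcpy_4bit branch.
        const size_t stride_elems = 7;
        for (size_t row = 0; row < 4; ++row) {
            size_t off = row * stride_elems;
            std::printf("row %zu: byte %zu, starts in %s nibble\n", row,
                        off >> 1, (off & 1) ? "high" : "low");
        }
        // 8/16/32-bit dtypes fold the stride to bytes once: stride * bits / 8.
    }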
@@ -110,8 +373,8 @@ void dispatch_cont(
     }
 }
 
-const size_t BLOCK_SIZE = 16,
-             TRANSPOSE_CV_MAX_C = relayout::transpose_fallback::BLOCK_LINE_SIZE_BYTES;
+const size_t BLOCK_SIZE = 16;
+const size_t TRANSPOSE_CV_MAX_C = relayout::transpose_fallback::BLOCK_LINE_SIZE_BYTES;
 
 /*!
  * \tparam ctype The type of the data
@@ -221,69 +484,76 @@ void RelayoutForwardImpl::exec(
         return;
     }
 
-    // FIXME: optimize for lowbit cases
-    if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS4 ||
-        src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
-        NaiveRelayoutForwardImpl::do_exec(src, dst);
-        return;
-    }
-
+    bool is_bit4 = is_int4(src.layout);
+    bool allow_nocontig = !is_bit4;
     relayout::TransposeParam trans_param;
-    bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param, true);
+    bool trans =
+            relayout::is_transpose(src.layout, dst.layout, trans_param, allow_nocontig);
+    trans = check_dtype_support_transparam(trans, is_bit4, trans_param);
     exec_after_preprocess(src, dst, trans ? &trans_param : nullptr);
 }
 
 void RelayoutForwardImpl::exec_after_preprocess(
         const TensorND& src, const TensorND& dst, relayout::TransposeParam* transpose) {
     if (transpose) {
-        auto kernel = [tparam = *transpose, src, dst]() {
+        bool is_bit4 = is_int4(src.layout);
+        auto kernel = [tparam = *transpose, src, dst, is_bit4]() {
             auto t = tparam;
-            auto dsize = src.layout.dtype.size() * t.c;
             void (*kptr)(size_t, size_t, size_t, size_t, void*, void*, size_t) =
                     nullptr;
             auto src_addr = reinterpret_cast<uintptr_t>(src.raw_ptr()),
                  dst_addr = reinterpret_cast<uintptr_t>(dst.raw_ptr());
-            if (dsize == 1) {
-                megdnn_assert(t.c == 1);
-                kptr = call_transpose<uint8_t>;
-            } else if (dsize == 2) {
-                t.c = 1;
-                if (!((src_addr | dst_addr) & (alignof(uint16_t) - 1))) {
-                    kptr = call_transpose<uint16_t>;
-                } else {
-                    kptr = call_transpose<equiv_ctype_storage<uint8_t, 2>>;
-                    megdnn_log_error("unaligned addr in relayout");
-                }
-            } else if (dsize == 3) {
-                t.c = 1;
-                kptr = call_transpose<equiv_ctype_storage<uint8_t, 3>>;
-            } else if (dsize == 4) {
-                t.c = 1;
-                if (!((src_addr | dst_addr) & (alignof(uint32_t) - 1))) {
-                    kptr = call_transpose<uint32_t>;
-                } else {
-                    kptr = call_transpose<equiv_ctype_storage<uint8_t, 4>>;
-                    megdnn_log_error("unaligned addr in relayout");
-                }
-            } else if (dsize == 12) {
-                t.c = 1;
-                if (!((src_addr | dst_addr) & (alignof(uint32_t) - 1))) {
-                    kptr = call_transpose<equiv_ctype_storage<uint32_t, 3>>;
-                } else {
-                    kptr = call_transpose<equiv_ctype_storage<uint8_t, 12>>;
-                    megdnn_log_error("unaligned addr in relayout");
-                }
-            } else if (dsize <= TRANSPOSE_CV_MAX_C) {
-                switch (dst.layout.dtype.enumv()) {
+            size_t dsize = 0;
+            if (is_bit4) {
+                dsize = t.c >> 1;
+            } else {
+                dsize = src.layout.dtype.size() * t.c;
+            }
+            if (is_bit4 && dsize == 0) {
+                kptr = call_transpose<dt_qint4>;
+            } else {
+                if (dsize == 1) {
+                    megdnn_assert(t.c == 1);
+                    kptr = call_transpose<uint8_t>;
+                } else if (dsize == 2) {
+                    t.c = 1;
+                    if (!((src_addr | dst_addr) & (alignof(uint16_t) - 1))) {
+                        kptr = call_transpose<uint16_t>;
+                    } else {
+                        kptr = call_transpose<equiv_ctype_storage<uint8_t, 2>>;
+                        megdnn_log_error("unaligned addr in relayout");
+                    }
+                } else if (dsize == 3) {
+                    t.c = 1;
+                    kptr = call_transpose<equiv_ctype_storage<uint8_t, 3>>;
+                } else if (dsize == 4) {
+                    t.c = 1;
+                    if (!((src_addr | dst_addr) & (alignof(uint32_t) - 1))) {
+                        kptr = call_transpose<uint32_t>;
+                    } else {
+                        kptr = call_transpose<equiv_ctype_storage<uint8_t, 4>>;
+                        megdnn_log_error("unaligned addr in relayout");
+                    }
+                } else if (dsize == 12) {
+                    t.c = 1;
+                    if (!((src_addr | dst_addr) & (alignof(uint32_t) - 1))) {
+                        kptr = call_transpose<equiv_ctype_storage<uint32_t, 3>>;
+                    } else {
+                        kptr = call_transpose<equiv_ctype_storage<uint8_t, 12>>;
+                        megdnn_log_error("unaligned addr in relayout");
+                    }
+                } else if (dsize <= TRANSPOSE_CV_MAX_C) {
+                    switch (dst.layout.dtype.enumv()) {
 #define cb(_dt)                                                 \
     case DTypeTrait<dtype::_dt>::enumv:                         \
         kptr = transpose_cv<equiv_ctype<dtype::_dt>::type>;     \
         break;
-                MEGDNN_FOREACH_DTYPE_NAME(cb)
-                MEGDNN_FOREACH_PARAMETERIZED_DTYPE(cb)
+                        MEGDNN_FOREACH_DTYPE_NAME(cb)
+                        MEGDNN_FOREACH_PARAMETERIZED_DTYPE(cb)
 #undef cb
-                }
-                megdnn_assert(kptr);
+                    }
+                    megdnn_assert(kptr);
+                }
             }
 
             if (kptr) {
@@ -305,13 +575,20 @@ void RelayoutForwardImpl::exec_after_preprocess(
         MEGDNN_DISPATCH_CPU_KERN_OPR(memcpy(dst.raw_ptr(), src.raw_ptr(), sz));
         return;
     }
-
+    memcpy_policy_t cpy_noncont2cont = memcpy_noncont2cont;
+    memcpy_policy_t cpy_cont2noncont = memcpy_cont2noncont;
+    bool is_bit4 = src.layout.dtype.enumv() == DTypeEnum::QuantizedS4 ||
+                   src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm;
+    if (is_bit4) {
+        cpy_noncont2cont = memcpy_noncont2cont_4bit;
+        cpy_cont2noncont = memcpy_cont2noncont_4bit;
+    }
     if (is_contig(dst.layout) && is_lastdim_contig(src.layout)) {
-        return dispatch_cont(handle(), dst, src, memcpy_noncont2cont);
+        return dispatch_cont(handle(), dst, src, cpy_noncont2cont);
     }
     if (is_contig(src.layout) && is_lastdim_contig(dst.layout)) {
-        return dispatch_cont(handle(), src, dst, memcpy_cont2noncont);
+        return dispatch_cont(handle(), src, dst, cpy_cont2noncont);
     }
     NaiveRelayoutForwardImpl::do_exec(src, dst);
 }
diff --git a/dnn/test/common/checker.cpp b/dnn/test/common/checker.cpp
index c73ff9e3..abfc37bc 100644
--- a/dnn/test/common/checker.cpp
+++ b/dnn/test/common/checker.cpp
@@ -98,7 +98,7 @@ template <typename Impl>
 void copy_tensors(
         const CheckerHelper::TensorValueArray& dest,
         const CheckerHelper::TensorValueArray& src, const Impl& copy_impl) {
-    megdnn_assert(dest.size() == src.size());
+    megdnn_assert(dest.size() == src.size(), "%zu != %zu", dest.size(), src.size());
     for (size_t i = 0; i < src.size(); i++) {
         auto&& tensor = src[i];
         if (tensor.layout.ndim == 0)
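For a feel of the packed-byte arithmetic the new test exercises, take the `{{6400}, {1}}` to `{{20, 320}, {1024, 1}}` case in the hunk below: each row of 320 nibbles is byte-aligned and can be moved as one 160-byte memcpy, with rows starting 512 bytes apart. A hypothetical back-of-envelope check:

    #include <cassert>
    #include <cstddef>

    int main() {
        const size_t rows = 20, cols = 320, row_stride = 1024;  // 4-bit elements
        assert(rows * cols == 6400);       // matches the contiguous source
        assert(cols / 2 == 160);           // bytes copied per row
        assert(row_stride / 2 == 512);     // byte distance between row starts
        size_t last_elem = (rows - 1) * row_stride + (cols - 1);
        assert((last_elem + 2) / 2 == 9888);  // packed bytes spanned by dst
        return 0;
    }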
diff --git a/dnn/test/fallback/relayout.cpp b/dnn/test/fallback/relayout.cpp
index e70439ff..75ef5ec5 100644
--- a/dnn/test/fallback/relayout.cpp
+++ b/dnn/test/fallback/relayout.cpp
@@ -34,6 +34,60 @@ TEST_F(FALLBACK, RELAYOUT_RECORD) {
     checker.exec({{2, 2, 2}, {2, 2, 2}});
 }
 
+TEST_F(FALLBACK, RELAYOUT_Q4) {
+    Checker<Relayout> checker(handle());
+    UniformIntRNG rng_int4{-7, 7};
+    checker.set_rng(0, &rng_int4)
+            .set_rng(1, &rng_int4)
+            .set_dtype(0, dtype::QuantizedS4(1.f))
+            .set_dtype(1, dtype::QuantizedS4(1.f))
+            .execs({{2, 2, 1, 1}, {1, 1, 2, 2}})
+            .execs({{1, 64, 15, 15}, {1, 15, 15, 64}})
+            .execs({{1, 5, 9, 32}, {1, 5, 32, 9}})
+            .execl(TensorLayoutArray{
+                    {{6400}, {1}, dtype::QuantizedS4{1.f}},
+                    {{20, 320}, {1024, 1}, dtype::QuantizedS4{1.f}}})
+            .execl(TensorLayoutArray{
+                    {{156}, {1}, dtype::QuantizedS4{1.f}},
+                    {{13, 3, 4}, {16, 1, 4}, dtype::QuantizedS4{1.f}}})
+            .execl(TensorLayoutArray{
+                    {{48}, {1}, dtype::QuantizedS4{1.f}},
+                    {{3, 4, 4}, {16, 1, 4}, dtype::QuantizedS4{1.f}}})
+            .execl(TensorLayoutArray{
+                    {{84}, {1}, dtype::QuantizedS4{1.f}},
+                    {{3, 4, 7}, {28, 1, 4}, dtype::QuantizedS4{1.f}}})
+            .execl(TensorLayoutArray{
+                    {{336}, {1}, dtype::QuantizedS4{1.f}},
+                    {{3, 4, 7, 4}, {112, 4, 16, 1}, dtype::QuantizedS4{1.f}}})
+            .execl(TensorLayoutArray{
+                    {{54}, {1}, dtype::QuantizedS4{1.f}},
+                    {{6, 3, 3}, {16, 4, 1}, dtype::QuantizedS4{1.f}}})
+            .execl(TensorLayoutArray{
+                    {{1200, 3}, {4, 1}, dtype::QuantizedS4{1.f}},
+                    {{20, 60, 3}, {256, 4, 1}, dtype::QuantizedS4{1.f}}})
+            .execl(TensorLayoutArray{
+                    {{20, 20, 3, 3}, {256, 12, 4, 1}, dtype::QuantizedS4{1.f}},
+                    {{1200, 3}, {4, 1}, dtype::QuantizedS4{1.f}}})
+            .execl(TensorLayoutArray{
+                    {{5, 16, 7, 7, 4}, {3136, 196, 28, 4, 1}, dtype::QuantizedS4{1.f}},
+                    {{5, 16, 7, 7, 4}, {3136, 4, 448, 64, 1}, dtype::QuantizedS4{1.f}}})
+            .execl(TensorLayoutArray{
+                    {{5, 7, 7, 16, 4}, {3136, 448, 64, 4, 1}, dtype::QuantizedS4{1.f}},
+                    {{5, 7, 7, 16, 4}, {3136, 28, 4, 196, 1}, dtype::QuantizedS4{1.f}}})
+            .execl(TensorLayoutArray{
+                    {{5, 2, 7, 7, 32},
+                     {3136, 1568, 224, 32, 1},
+                     dtype::QuantizedS4{1.f}},
+                    {{5, 2, 7, 7, 32},
+                     {3136, 32, 448, 64, 1},
+                     dtype::QuantizedS4{1.f}}})
+            .execl(TensorLayoutArray{
+                    {{5, 7, 7, 2, 32}, {3136, 448, 64, 32, 1}, dtype::QuantizedS4{1.f}},
+                    {{5, 7, 7, 2, 32},
+                     {3136, 224, 32, 1568, 1},
+                     dtype::QuantizedS4{1.f}}});
+}
+
 #if MEGDNN_WITH_BENCHMARK
 TEST_F(FALLBACK, BENCHMARK_RELAYOUT_CV) {
     relayout::run_cv_benchmark(handle());