diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index f7aff5bc..72940d61 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -557,7 +557,10 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) 'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'), Doc('MK4_DOT = 3', 'Split 4 from M and K, better for neon dotprod:' 'M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). if transposeA the ' - 'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))')) + 'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))'), + Doc('N32K4_DOT = 4', 'Split 32 from N and 4 from K, better for neon gevm dotprod:' + 'N/32, K/4, 32(n), 4(k)') + ) ) (pdef('SVD'). diff --git a/dnn/src/arm_common/conv_bias/int8/algos.h b/dnn/src/arm_common/conv_bias/int8/algos.h index 9a88d723..5c505953 100644 --- a/dnn/src/arm_common/conv_bias/int8/algos.h +++ b/dnn/src/arm_common/conv_bias/int8/algos.h @@ -127,6 +127,37 @@ public: MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_DOT_S8) }; +class ConvBiasImpl::AlgoDotS8DirectChanWiseLarge final : public AlgoBase { +public: + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } + const char* name() const override { return "ARMDOTS8_DIRECT_CHANWISE_LARGE"; } + bool usable(const NCBKernSizeParam&, AlgoSelectionStrategy algo_selection_strategy) + const override; + + size_t get_workspace(const NCBKernSizeParam&) const override; + virtual SmallVector dispatch_kerns( + const NCBKernSizeParam& param) const override; + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; + } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DOT_DIRECT_CHANWISE_LARGE_S8) +}; + +class ConvBiasImpl::AlgoDotS8Im2colChanWiseLarge final : public AlgoBase { +public: + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } + const char* name() const override { return "ARMDOTS8_IM2COL_CHANWISE_LARGE"; } + bool usable(const NCBKernSizeParam&, AlgoSelectionStrategy algo_selection_strategy) + const override; + + size_t get_workspace(const NCBKernSizeParam&) const override; + virtual SmallVector dispatch_kerns( + const NCBKernSizeParam& param) const override; + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL}; + } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DOT_IM2COL_CHANWISE_LARGE_S8) +}; class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase { public: AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } diff --git a/dnn/src/arm_common/conv_bias/int8/chanwise_direct_dot.cpp b/dnn/src/arm_common/conv_bias/int8/chanwise_direct_dot.cpp new file mode 100644 index 00000000..d30dce77 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/chanwise_direct_dot.cpp @@ -0,0 +1,270 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/chanwise_direct_dot.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#include +#include "src/arm_common/conv_bias/int8/algos.h" +#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h" +#include "src/common/unroll_macro.h" +#if MGB_ENABLE_DOT +#include "midout.h" + +MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_direct_dot_large_kernel) + +using namespace megdnn; +using namespace arm_common; + +using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; +using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; +using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex; + +namespace { + +class DirectConvRunner { +public: + DirectConvRunner(size_t flt_size, size_t stride) { + if (flt_size == 9 && stride == 1) { + m_func = megdnn_dot_nchw_large_chanwise_direct_conv_9x9s1_oh4_ow16; + } else { + megdnn_assert(flt_size == 9 && stride == 2); + m_func = megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16; + } + } + size_t get_round_fw(const ConvBiasImpl::NCBKernSizeParam& param) const { + auto&& fm = param.filter_meta; + auto FW = fm.spatial[1]; + return round_up((size_t)FW, m_block_k); + } + + size_t get_round_iw(const ConvBiasImpl::NCBKernSizeParam& param) const { + auto&& fm = param.filter_meta; + size_t SW = fm.stride[1]; + size_t OW = param.osz[1]; + size_t round_ow = round_up(OW, m_block_ow); + size_t round_fw = get_round_fw(param); + size_t pad_iw = round_ow * SW - SW + round_fw; + return round_up(pad_iw, m_align_iw); + } + + size_t get_round_ih(const ConvBiasImpl::NCBKernSizeParam& param) const { + auto&& fm = param.filter_meta; + size_t SH = fm.stride[0]; + size_t OH = param.osz[0]; + auto FH = fm.spatial[0]; + size_t round_oh = round_up(OH, m_block_oh); + return round_oh * SH - SH + FH; + } + + WorkspaceBundle get_sub_bundle(const ConvBiasImpl::NCBKernSizeParam& param) const { + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + + size_t round_filter = get_round_fw(param) * FH; + size_t round_ih = get_round_ih(param); + size_t round_iw = get_round_iw(param); + size_t pad_src = round_iw * round_ih; + return {nullptr, {pad_src, round_filter}}; + } + + WorkspaceBundle get_total_bundle( + const ConvBiasImpl::NCBKernSizeParam& param) const { + auto sub_bundle = get_sub_bundle(param); + auto sub_bundle_size = sub_bundle.total_size_in_bytes(); + size_t nr_threads = param.nr_threads; + SmallVector sizes_in_bytes; + for (size_t i = 0; i < nr_threads; ++i) { + sizes_in_bytes.push_back(sub_bundle_size); + } + WorkspaceBundle total_bundle(nullptr, sizes_in_bytes); + return total_bundle; + } + + void run( + const int8_t* pad_src_ptr, const int8_t* round_filter_ptr, int32_t bias, + int8_t* dst_ptr, size_t OH, size_t OW, size_t pad_iw, float scale, + int8_t relu_val) const { + const size_t ow_end = OW / m_block_ow * m_block_ow; + const size_t ow_remain = OW - ow_end; + const size_t oh_end = OH / m_block_oh * m_block_oh; + const size_t oh_remain = OH - oh_end; + int8_t cache[4 * 16]; + for (size_t oh = 0; oh < oh_end; oh += m_block_oh) { + for (size_t ow = 0; ow < ow_end; ow += m_block_ow) { + m_func(pad_src_ptr, round_filter_ptr, bias, dst_ptr, oh, ow, OH, OW, + pad_iw, scale, relu_val); + } + if (ow_remain > 0) { + m_func(pad_src_ptr, round_filter_ptr, bias, + &cache[0] - (oh * m_block_ow + ow_end), oh, ow_end, OH, + m_block_ow, pad_iw, scale, relu_val); + for (size_t i = 0; i < m_block_oh; ++i) { + for (size_t j = 0; j < ow_remain; ++j) { + dst_ptr[(i + oh) * OW + (j + ow_end)] = cache[i * 16 + j]; + } + } + } + } + if (oh_remain > 0) { + for (size_t ow = 0; ow < ow_end; ow += m_block_ow) { + m_func(pad_src_ptr, 
round_filter_ptr, bias, + &cache[0] - (oh_end * m_block_ow + ow), oh_end, ow, OH, + m_block_ow, pad_iw, scale, relu_val); + for (size_t i = 0; i < oh_remain; ++i) { + for (size_t j = 0; j < m_block_ow; ++j) { + dst_ptr[(i + oh_end) * OW + (j + ow)] = cache[i * 16 + j]; + } + } + } + if (ow_remain > 0) { + m_func(pad_src_ptr, round_filter_ptr, bias, + &cache[0] - (oh_end * m_block_ow + ow_end), oh_end, ow_end, OH, + m_block_ow, pad_iw, scale, relu_val); + for (size_t i = 0; i < oh_remain; ++i) { + for (size_t j = 0; j < ow_remain; ++j) { + dst_ptr[(i + oh_end) * OW + (j + ow_end)] = cache[i * 16 + j]; + } + } + } + } + } + +private: + std::function + m_func; + size_t m_block_oh{4}; + size_t m_block_ow{16}; + size_t m_block_k{4}; + size_t m_align_iw{16}; +}; + +void do_conv( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const DirectConvRunner& runner) { + auto&& fm = kern_param.filter_meta; + size_t PH = kern_param.filter_meta.padding[0]; + size_t PW = kern_param.filter_meta.padding[1]; + size_t OH = kern_param.osz[0]; + size_t OW = kern_param.osz[1]; + size_t IH = kern_param.isz[0]; + size_t IW = kern_param.isz[1]; + size_t FH = fm.spatial[0]; + size_t FW = fm.spatial[1]; + + float scale_bias = kern_param.bias_type.param().scale; + float scale_dst = kern_param.dst_type.param().scale; + float scale_dst_div = 1.f / scale_dst; + size_t batch_id = ncb_index.ndrange_id[0]; + size_t group_id = ncb_index.ndrange_id[1]; + int8_t* pad_src_ptr = static_cast(bundle.get(0)); + int8_t* round_filter_ptr = static_cast(bundle.get(1)); + const int8_t* sptr = kern_param.src(batch_id, group_id); + const int32_t* bptr = kern_param.bias(batch_id, group_id); + const int8_t* fptr = kern_param.filter(group_id); + void* dst = kern_param.dst(batch_id, group_id); + size_t pad_iw = runner.get_round_iw(kern_param); + + memset(pad_src_ptr, 0, bundle.get_size(0)); + rep(ih, IH) { + std::memcpy( + pad_src_ptr + (ih + PH) * pad_iw + PW, sptr + ih * IW, + sizeof(int8_t) * IW); + } + + memset(round_filter_ptr, 0, bundle.get_size(1)); + size_t round_fw = runner.get_round_fw(kern_param); + for (size_t fh = 0; fh < FH; ++fh) { + std::memcpy(round_filter_ptr + fh * round_fw, fptr + fh * FW, FW); + } + + int8_t relu_val = kern_param.nonlineMode == NonlineMode::RELU ? 0 : -128; + int32_t bias_val = kern_param.bias_mode == BiasMode::NO_BIAS ? 0 : *bptr; + + int8_t* dst_ptr = (int8_t*)dst; + runner.run( + pad_src_ptr, round_filter_ptr, bias_val, dst_ptr, OH, OW, pad_iw, + scale_bias * scale_dst_div, relu_val); +} + +} // namespace + +bool ConvBiasImpl::AlgoDotS8DirectChanWiseLarge::usable( + const NCBKernSizeParam& param, AlgoSelectionStrategy) const { + if (!cpuinfo_has_arm_neon_dot()) { + return false; + } + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + auto FW = fm.spatial[1]; + auto SH = fm.stride[0]; + auto SW = fm.stride[1]; + auto noline = param.nonlineMode; + auto bias_mode = param.bias_mode; + bool avaible = + //! 
src and filter are qint8, dst is qint8 or qint32 + (param.src_type.enumv() == DTypeEnum::QuantizedS8 && + param.filter_type.enumv() == DTypeEnum::QuantizedS8 && + (param.dst_type.enumv() == DTypeEnum::QuantizedS8)) && + fm.format == param::Convolution::Format::NCHW && !fm.should_flip && + (noline == NonlineMode::IDENTITY || noline == NonlineMode::RELU) && + (bias_mode == BiasMode::NO_BIAS || + bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && + SH == SW && (SH == 1 || SH == 2) && FH == FW && (FH == 9) && fm.icpg == 1 && + fm.ocpg == 1; + return avaible; +} + +size_t ConvBiasImpl::AlgoDotS8DirectChanWiseLarge::get_workspace( + const NCBKernSizeParam& param) const { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8_direct_dot_large_kernel, + midout_iv("AlgoDotS8DirectChanWiseLarge::get_workspace"_hash)) { + auto&& fm = param.filter_meta; + DirectConvRunner runner(fm.spatial[0], fm.stride[0]); + auto total_bundle = runner.get_total_bundle(param); + return total_bundle.total_size_in_bytes(); + } + MIDOUT_END(); + return 0; +} + +SmallVector ConvBiasImpl::AlgoDotS8DirectChanWiseLarge:: + dispatch_kerns(const NCBKernSizeParam& param) const { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8_direct_dot_large_kernel, + midout_iv("AlgoDotS8DirectChanWiseLarge::dispatch_kerns"_hash)) { + SmallVector ret_kerns; + auto&& fm = param.filter_meta; + DirectConvRunner runner(fm.spatial[0], fm.stride[0]); + WorkspaceBundle wbundle = runner.get_sub_bundle(param); + WorkspaceBundle total_bundle = runner.get_total_bundle(param); + + auto exec_one_group = [wbundle, total_bundle, runner]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { + WorkspaceBundle temp_total_bundle = total_bundle; + temp_total_bundle.set(kern_param.workspace_ptr); + WorkspaceBundle temp_bundle = wbundle; + temp_bundle.set(temp_total_bundle.get(ncb_index.thread_id)); + do_conv(temp_bundle, kern_param, ncb_index, runner); + }; + size_t N = param.n; + size_t group = fm.group; + ret_kerns.push_back({exec_one_group, {N, group}}); + return ret_kerns; + } + MIDOUT_END(); + return {}; +} + +#endif diff --git a/dnn/src/arm_common/conv_bias/int8/chanwise_im2col_dot.cpp b/dnn/src/arm_common/conv_bias/int8/chanwise_im2col_dot.cpp new file mode 100644 index 00000000..467ff51a --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/chanwise_im2col_dot.cpp @@ -0,0 +1,425 @@ +/** + * \file dnn/src/arm_common/conv_bias/int8/chanwise_im2col_dot.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include "megdnn/arch.h" +#if MGB_ENABLE_DOT +#include "src/arm_common/simd_macro/marm_neon.h" + +#include "src/arm_common/conv_bias/int8/algos.h" +#include "src/arm_common/matrix_mul/int8/gemv.h" + +#include "midout.h" + +MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_im2col_dot_large_kernel) + +using namespace megdnn; +using namespace arm_common; + +using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; +using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; +using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex; + +namespace { +constexpr size_t block_n = 32; +constexpr size_t block_k = 4; + +WorkspaceBundle get_sub_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { + auto&& fm = param.filter_meta; + auto OH = param.osz[0]; + auto OW = param.osz[1]; + size_t IH = param.isz[0]; + size_t IW = param.isz[1]; + auto FH = fm.spatial[0]; + auto FW = fm.spatial[1]; + size_t PH = param.filter_meta.padding[0]; + size_t PW = param.filter_meta.padding[1]; + + size_t round_ohw = round_up((size_t)OH * OW, block_n); + size_t round_filter = round_up((size_t)FW, block_k) * FH; + size_t pad_src = (IW + PW * 2) * (IH + PH * 2); + return {nullptr, + {pad_src, round_filter, round_ohw * round_filter, + round_ohw * sizeof(int32_t)}}; +} + +WorkspaceBundle get_total_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { + auto sub_bundle = get_sub_bundle(param); + auto sub_bundle_size = sub_bundle.total_size_in_bytes(); + size_t nr_threads = param.nr_threads; + SmallVector sizes_in_bytes; + for (size_t i = 0; i < nr_threads; ++i) { + sizes_in_bytes.push_back(sub_bundle_size); + } + WorkspaceBundle total_bundle(nullptr, sizes_in_bytes); + return total_bundle; +} + +template +void im2col( + const int8_t* src, int8_t* dst, size_t OH, size_t OW, size_t pad_iw, + size_t round_filter) { + constexpr size_t FH = flt_size; + constexpr size_t FW = flt_size; + constexpr size_t SH = stride; + constexpr size_t SW = stride; + constexpr size_t FW_ROUND = (FW + 3) / 4 * 4; + int bn = 0; + int ni = 0; + for (size_t oh = 0; oh < OH; ++oh) + for (size_t ow = 0; ow < OW; ++ow) { + const int8_t* src_n = src + oh * SH * pad_iw + ow * SW; + int bk = 0; + int ki = 0; + for (size_t fh = 0; fh < FH; ++fh) + for (size_t fw = 0; fw < FW_ROUND; ++fw) { + dst[bn * block_n * round_filter + bk * block_n * block_k + + ni * block_k + ki] = src_n[fh * pad_iw + fw]; + ++ki; + if (ki == block_k) { + ki = 0; + ++bk; + } + } + ++ni; + if (ni == block_n) { + ni = 0; + ++bn; + } + } +} + +template <> +void im2col<9, 1>( + const int8_t* src, int8_t* dst, size_t OH, size_t OW, size_t pad_iw, + size_t round_filter) { + constexpr size_t FH = 9; + constexpr size_t SH = 1; + constexpr size_t SW = 1; + constexpr size_t k_block_stride = block_k * block_n; + + constexpr size_t ow_block = 16; + static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 1, 2, 3, 4, + 2, 3, 4, 5, 3, 4, 5, 6}; + static const uint8_t tbl_array_1[16] = {4, 5, 6, 7, 5, 6, 7, 8, + 6, 7, 8, 9, 7, 8, 9, 10}; + static const uint8_t tbl_array_2[16] = {8, 9, 10, 11, 9, 10, 11, 12, + 10, 11, 12, 13, 11, 12, 13, 14}; + + uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]); + uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]); + uint8x16_t tbl_reg_2 = vld1q_u8(&tbl_array_2[0]); + + int bn = 0; + int ni = 0; + for (size_t oh = 0; oh < OH; ++oh) + + for (size_t ow = 0; ow < OW;) { + if (ow + ow_block <= OW && ni + ow_block <= block_n) { + const int8_t* src_n = src + oh * SH * pad_iw + ow * SW; + int8_t* dst_n = dst + bn * block_n * round_filter + ni * block_k; + for (size_t fh = 
0; fh < FH; ++fh) { + int8x16_t read_w[2]; + read_w[0] = vld1q_s8(src_n); + read_w[1] = vld1q_s8(src_n + 16); + + int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); + int8x16_t n4567_0 = vqtbl1q_s8(read_w[0], tbl_reg_1); + int8x16_t n89ab_0 = vqtbl1q_s8(read_w[0], tbl_reg_2); + int8x16_t ncdef_0 = + vqtbl1q_s8(vextq_s8(read_w[0], read_w[1], 12), tbl_reg_0); + + int8x16_t n0123_1 = n4567_0; + int8x16_t n4567_1 = n89ab_0; + int8x16_t n89ab_1 = ncdef_0; + int8x16_t ncdef_1 = vqtbl1q_s8(read_w[1], tbl_reg_0); + + int8x16_t n0123_2 = n89ab_0; + int8x16_t n4567_2 = ncdef_0; + int8x16_t n89ab_2 = ncdef_1; + int8x16_t ncdef_2 = vqtbl1q_s8(read_w[1], tbl_reg_1); + + vst1q_s8(dst_n + 0 * 16, n0123_0); + vst1q_s8(dst_n + 1 * 16, n4567_0); + vst1q_s8(dst_n + 2 * 16, n89ab_0); + vst1q_s8(dst_n + 3 * 16, ncdef_0); + + vst1q_s8(dst_n + 1 * k_block_stride + 0 * 16, n0123_1); + vst1q_s8(dst_n + 1 * k_block_stride + 1 * 16, n4567_1); + vst1q_s8(dst_n + 1 * k_block_stride + 2 * 16, n89ab_1); + vst1q_s8(dst_n + 1 * k_block_stride + 3 * 16, ncdef_1); + + vst1q_s8(dst_n + 2 * k_block_stride + 0 * 16, n0123_2); + vst1q_s8(dst_n + 2 * k_block_stride + 1 * 16, n4567_2); + vst1q_s8(dst_n + 2 * k_block_stride + 2 * 16, n89ab_2); + vst1q_s8(dst_n + 2 * k_block_stride + 3 * 16, ncdef_2); + + dst_n += 3 * k_block_stride; + src_n += pad_iw; + } + ni += ow_block; + ow += ow_block; + if (ni == block_n) { + ni = 0; + ++bn; + } + } else { + const int8_t* src_n = src + oh * SH * pad_iw + ow * SW; + int8_t* dst_n = dst + bn * block_n * round_filter + ni * block_k; + for (size_t fh = 0; fh < FH; ++fh) { + int8x16_t read_w[0]; + read_w[0] = vld1q_s8(src_n); + vst1q_lane_s32(dst_n, read_w[0], 0); + vst1q_lane_s32(dst_n + 1 * k_block_stride, read_w[0], 1); + vst1q_lane_s32(dst_n + 2 * k_block_stride, read_w[0], 2); + dst_n += 3 * k_block_stride; + src_n += pad_iw; + } + ++ni; + ++ow; + if (ni == block_n) { + ni = 0; + ++bn; + } + } + } +} + +template <> +void im2col<9, 2>( + const int8_t* src, int8_t* dst, size_t OH, size_t OW, size_t pad_iw, + size_t round_filter) { + constexpr size_t FH = 9; + constexpr size_t SH = 2; + constexpr size_t SW = 2; + constexpr size_t k_block_stride = block_k * block_n; + + constexpr size_t ow_block = 16; + static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 2, 3, 4, 5, + 4, 5, 6, 7, 6, 7, 8, 9}; + static const uint8_t tbl_array_1[16] = {4, 5, 6, 7, 6, 7, 8, 9, + 8, 9, 10, 11, 10, 11, 12, 13}; + + uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]); + uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]); + + int bn = 0; + int ni = 0; + for (size_t oh = 0; oh < OH; ++oh) + for (size_t ow = 0; ow < OW;) { + if (ow + ow_block <= OW && ni + ow_block <= block_n) { + const int8_t* src_n = src + oh * SH * pad_iw + ow * SW; + int8_t* dst_n = dst + bn * block_n * round_filter + ni * block_k; + for (size_t fh = 0; fh < FH; ++fh) { + int8x16_t read_w[3]; + read_w[0] = vld1q_s8(src_n); + read_w[1] = vld1q_s8(src_n + 16); + read_w[2] = vld1q_s8(src_n + 32); + int8x16_t ext_8 = vextq_s8(read_w[0], read_w[1], 8); + int8x16_t ext_24 = vextq_s8(read_w[1], read_w[2], 8); + + int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); + int8x16_t n4567_0 = vqtbl1q_s8(ext_8, tbl_reg_0); + int8x16_t n89ab_0 = vqtbl1q_s8(read_w[1], tbl_reg_0); + int8x16_t ncdef_0 = vqtbl1q_s8(ext_24, tbl_reg_0); + + int8x16_t n0123_1 = vqtbl1q_s8(read_w[0], tbl_reg_1); + int8x16_t n4567_1 = vqtbl1q_s8(ext_8, tbl_reg_1); + int8x16_t n89ab_1 = vqtbl1q_s8(read_w[1], tbl_reg_1); + int8x16_t ncdef_1 = vqtbl1q_s8(ext_24, tbl_reg_1); + + int8x16_t 
n0123_2 = n4567_0; + int8x16_t n4567_2 = n89ab_0; + int8x16_t n89ab_2 = ncdef_0; + int8x16_t ncdef_2 = vqtbl1q_s8(read_w[2], tbl_reg_0); + + vst1q_s8(dst_n + 0 * 16, n0123_0); + vst1q_s8(dst_n + 1 * 16, n4567_0); + vst1q_s8(dst_n + 2 * 16, n89ab_0); + vst1q_s8(dst_n + 3 * 16, ncdef_0); + + vst1q_s8(dst_n + 1 * k_block_stride + 0 * 16, n0123_1); + vst1q_s8(dst_n + 1 * k_block_stride + 1 * 16, n4567_1); + vst1q_s8(dst_n + 1 * k_block_stride + 2 * 16, n89ab_1); + vst1q_s8(dst_n + 1 * k_block_stride + 3 * 16, ncdef_1); + + vst1q_s8(dst_n + 2 * k_block_stride + 0 * 16, n0123_2); + vst1q_s8(dst_n + 2 * k_block_stride + 1 * 16, n4567_2); + vst1q_s8(dst_n + 2 * k_block_stride + 2 * 16, n89ab_2); + vst1q_s8(dst_n + 2 * k_block_stride + 3 * 16, ncdef_2); + + dst_n += 3 * k_block_stride; + src_n += pad_iw; + } + ni += ow_block; + ow += ow_block; + if (ni == block_n) { + ni = 0; + ++bn; + } + } else { + const int8_t* src_n = src + oh * SH * pad_iw + ow * SW; + int8_t* dst_n = dst + bn * block_n * round_filter + ni * block_k; + for (size_t fh = 0; fh < FH; ++fh) { + int8x16_t read_w[0]; + read_w[0] = vld1q_s8(src_n); + vst1q_lane_s32(dst_n, read_w[0], 0); + vst1q_lane_s32(dst_n + 1 * k_block_stride, read_w[0], 1); + vst1q_lane_s32(dst_n + 2 * k_block_stride, read_w[0], 2); + dst_n += 3 * k_block_stride; + src_n += pad_iw; + } + ++ni; + ++ow; + if (ni == block_n) { + ni = 0; + ++bn; + } + } + } +} + +void do_conv( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + auto&& fm = kern_param.filter_meta; + size_t PH = kern_param.filter_meta.padding[0]; + size_t PW = kern_param.filter_meta.padding[1]; + size_t OH = kern_param.osz[0]; + size_t OW = kern_param.osz[1]; + size_t IH = kern_param.isz[0]; + size_t IW = kern_param.isz[1]; + size_t FH = fm.spatial[0]; + size_t FW = fm.spatial[1]; + size_t SH = fm.stride[0]; + + float scale_bias = kern_param.bias_type.param().scale; + float scale_dst = kern_param.dst_type.param().scale; + float scale_dst_div = 1.f / scale_dst; + + size_t batch_id = ncb_index.ndrange_id[0]; + size_t group_id = ncb_index.ndrange_id[1]; + int8_t* pad_src_ptr = static_cast(bundle.get(0)); + int8_t* round_filter_ptr = static_cast(bundle.get(1)); + int8_t* im2col_ptr = static_cast(bundle.get(2)); + int32_t* i32_ptr = static_cast(bundle.get(3)); + const int8_t* sptr = kern_param.src(batch_id, group_id); + const int32_t* bptr = kern_param.bias(batch_id, group_id); + const int8_t* fptr = kern_param.filter(group_id); + void* dst = kern_param.dst(batch_id, group_id); + + size_t round_filter = round_up(FW, block_k) * FH; + size_t pad_iw = IW + 2 * PW; + + memset(pad_src_ptr, 0, bundle.get_size(0)); + rep(ih, IH) { + std::memcpy( + pad_src_ptr + (ih + PH) * pad_iw + PW, sptr + ih * IW, + sizeof(int8_t) * IW); + } + + memset(round_filter_ptr, 0, bundle.get_size(1)); + size_t fh_stride = round_up(FW, block_k); + for (size_t fh = 0; fh < FH; ++fh) { + std::memcpy(round_filter_ptr + fh * fh_stride, fptr + fh * FW, FW); + } + + memset(im2col_ptr, 0, bundle.get_size(2)); + if (SH == 1) { + im2col<9, 1>(pad_src_ptr, im2col_ptr, OH, OW, pad_iw, round_filter); + } else { + im2col<9, 2>(pad_src_ptr, im2col_ptr, OH, OW, pad_iw, round_filter); + } + + gevm_naive_n32k4_dot( + round_filter_ptr, im2col_ptr, i32_ptr, 1, OH * OW, round_filter, 0, 0, 0); + + int32_t bias_val = kern_param.bias_mode == BiasMode::NO_BIAS ? 0 : *bptr; + int8_t relu_val = kern_param.nonlineMode == NonlineMode::RELU ? 
0 : -128; + int8_t* dst_ptr = (int8_t*)dst; + for (size_t i = 0; i < OH * OW; ++i) { + //! optimize by tbl + int val = roundf(scale_bias * scale_dst_div * (i32_ptr[i] + bias_val)); + val = val < -128 ? -128 : val; + val = val > 127 ? 127 : val; + val = val > relu_val ? val : relu_val; + dst_ptr[i] = val; + } +} + +} // namespace + +bool ConvBiasImpl::AlgoDotS8Im2colChanWiseLarge::usable( + const NCBKernSizeParam& param, AlgoSelectionStrategy) const { + if (!cpuinfo_has_arm_neon_dot()) { + return false; + } + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + auto FW = fm.spatial[1]; + auto SH = fm.stride[0]; + auto SW = fm.stride[1]; + auto noline = param.nonlineMode; + auto bias_mode = param.bias_mode; + bool avaible = + //! src and filter are qint8, dst is qint8 or qint32 + (param.src_type.enumv() == DTypeEnum::QuantizedS8 && + param.filter_type.enumv() == DTypeEnum::QuantizedS8 && + (param.dst_type.enumv() == DTypeEnum::QuantizedS8)) && + fm.format == param::Convolution::Format::NCHW && !fm.should_flip && + (noline == NonlineMode::IDENTITY || noline == NonlineMode::RELU) && + (bias_mode == BiasMode::NO_BIAS || + bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && + SH == SW && (SH == 1 || SH == 2) && FH == FW && (FH == 9) && fm.icpg == 1 && + fm.ocpg == 1; + return avaible; +} + +size_t ConvBiasImpl::AlgoDotS8Im2colChanWiseLarge::get_workspace( + const NCBKernSizeParam& param) const { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8_im2col_dot_large_kernel, + midout_iv("AlgoDotS8Im2colChanWiseLarge::get_workspace"_hash)) { + auto bundle = get_total_bundle(param); + return bundle.total_size_in_bytes(); + } + MIDOUT_END(); + return 0; +} + +SmallVector ConvBiasImpl::AlgoDotS8Im2colChanWiseLarge:: + dispatch_kerns(const NCBKernSizeParam& param) const { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8_im2col_dot_large_kernel, + midout_iv("AlgoDotS8Im2colChanWiseLarge::dispatch_kerns"_hash)) { + SmallVector ret_kerns; + auto fm = param.filter_meta; + size_t N = param.n; + size_t group = fm.group; + WorkspaceBundle wbundle = get_sub_bundle(param); + WorkspaceBundle total_bundle = get_total_bundle(param); + auto exec_one_group = [wbundle, total_bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { + WorkspaceBundle temp_total_bundle = total_bundle; + temp_total_bundle.set(kern_param.workspace_ptr); + WorkspaceBundle temp_bundle = wbundle; + temp_bundle.set(temp_total_bundle.get(ncb_index.thread_id)); + do_conv(temp_bundle, kern_param, ncb_index); + }; + ret_kerns.push_back({exec_one_group, {N, group}}); + return ret_kerns; + } + MIDOUT_END(); + return {}; +} +#endif \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h new file mode 100644 index 00000000..eee0490d --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h @@ -0,0 +1,27 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#pragma once +#include "megdnn/arch.h" +#if MGB_ENABLE_DOT + +void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s1_oh4_ow16( + const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh, + size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale, + int8_t relu_val); + +void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16( + const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh, + size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale, + int8_t relu_val); + +#endif \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_common.h b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_common.h new file mode 100644 index 00000000..bb1aae19 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_common.h @@ -0,0 +1,40 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include "megdnn/arch.h" +#if MGB_ENABLE_DOT +#include "src/arm_common/simd_macro/marm_neon.h" + +static inline void quant_store_s8( + float32x4_t v0, float32x4_t v1, float32x4_t v2, float32x4_t v3, int8_t* ptr, + int8x16_t relu_reg) { + int32x4_t i0 = vcvtaq_s32_f32(v0); + int32x4_t i1 = vcvtaq_s32_f32(v1); + int32x4_t i2 = vcvtaq_s32_f32(v2); + int32x4_t i3 = vcvtaq_s32_f32(v3); + + int16x4_t i16_0 = vqmovn_s32(i0); + int16x4_t i16_1 = vqmovn_s32(i1); + int16x4_t i16_2 = vqmovn_s32(i2); + int16x4_t i16_3 = vqmovn_s32(i3); + + int8x8_t i8_0 = vqmovn_s16(vcombine_s16(i16_0, i16_1)); + int8x8_t i8_1 = vqmovn_s16(vcombine_s16(i16_2, i16_3)); + int8x16_t rst = vcombine_s8(i8_0, i8_1); + + rst = vmaxq_s8(rst, relu_reg); + + vst1q_s8(ptr, rst); +} + +#endif \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_9x9s1.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_9x9s1.cpp new file mode 100644 index 00000000..dfd6a2ba --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_9x9s1.cpp @@ -0,0 +1,221 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_9x9s1.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#include "megdnn/arch.h" +#if MGB_ENABLE_DOT +#include "src/arm_common/simd_macro/marm_neon.h" + +#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h" +#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_common.h" +#include "src/common/unroll_macro.h" + +MEGDNN_ATTRIBUTE_TARGET("dotprod") +void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s1_oh4_ow16( + const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh, + size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale, + int8_t relu_val) { + //! 4x16 + const size_t SH = 1; + const size_t SW = 1; + + static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 1, 2, 3, 4, + 2, 3, 4, 5, 3, 4, 5, 6}; + static const uint8_t tbl_array_1[16] = {4, 5, 6, 7, 5, 6, 7, 8, + 6, 7, 8, 9, 7, 8, 9, 10}; + static const uint8_t tbl_array_2[16] = {8, 9, 10, 11, 9, 10, 11, 12, + 10, 11, 12, 13, 11, 12, 13, 14}; + + uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]); + uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]); + uint8x16_t tbl_reg_2 = vld1q_u8(&tbl_array_2[0]); + + const int8_t* src_n = src + oh * SH * pad_iw + ow * SW; + + //! init + int32x4_t c[4][4]; +#define cb(step) \ + c[step][0] = vdupq_n_s32(bias); \ + c[step][1] = vdupq_n_s32(bias); \ + c[step][2] = vdupq_n_s32(bias); \ + c[step][3] = vdupq_n_s32(bias); + + UNROLL_CALL_RAW(4, cb); +#undef cb + int8x16_t flt[4]; + flt[0] = vld1q_s8(weight + 0 * 16); + flt[1] = vld1q_s8(weight + 1 * 16); + flt[2] = vld1q_s8(weight + 2 * 16); + flt[3] = vld1q_s8(weight + 3 * 16); + //! row 0 + + int8x16_t read_w[2]; + read_w[0] = vld1q_s8(src_n + 0 * pad_iw); + read_w[1] = vld1q_s8(src_n + 0 * pad_iw + 16); + int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); + int8x16_t n4567_0 = vqtbl1q_s8(read_w[0], tbl_reg_1); + int8x16_t n89ab_0 = vqtbl1q_s8(read_w[0], tbl_reg_2); + int8x16_t ncdef_0 = vqtbl1q_s8(vextq_s8(read_w[0], read_w[1], 12), tbl_reg_0); + int8x16_t n0123_1 = n4567_0; + int8x16_t n4567_1 = n89ab_0; + int8x16_t n89ab_1 = ncdef_0; + int8x16_t ncdef_1 = vqtbl1q_s8(read_w[1], tbl_reg_0); + int8x16_t n0123_2 = n89ab_0; + int8x16_t n4567_2 = ncdef_0; + int8x16_t n89ab_2 = ncdef_1; + int8x16_t ncdef_2 = vqtbl1q_s8(read_w[1], tbl_reg_1); +#define CAL_C(oh, flt_start) \ + c[oh][0] = vdotq_laneq_s32( \ + c[oh][0], n0123_0, flt[(flt_start + 0) / 4 % 4], (flt_start + 0) % 4); \ + c[oh][1] = vdotq_laneq_s32( \ + c[oh][1], n4567_0, flt[(flt_start + 0) / 4 % 4], (flt_start + 0) % 4); \ + c[oh][2] = vdotq_laneq_s32( \ + c[oh][2], n89ab_0, flt[(flt_start + 0) / 4 % 4], (flt_start + 0) % 4); \ + c[oh][3] = vdotq_laneq_s32( \ + c[oh][3], ncdef_0, flt[(flt_start + 0) / 4 % 4], (flt_start + 0) % 4); \ + c[oh][0] = vdotq_laneq_s32( \ + c[oh][0], n0123_1, flt[(flt_start + 1) / 4 % 4], (flt_start + 1) % 4); \ + c[oh][1] = vdotq_laneq_s32( \ + c[oh][1], n4567_1, flt[(flt_start + 1) / 4 % 4], (flt_start + 1) % 4); \ + c[oh][2] = vdotq_laneq_s32( \ + c[oh][2], n89ab_1, flt[(flt_start + 1) / 4 % 4], (flt_start + 1) % 4); \ + c[oh][3] = vdotq_laneq_s32( \ + c[oh][3], ncdef_1, flt[(flt_start + 1) / 4 % 4], (flt_start + 1) % 4); \ + c[oh][0] = vdotq_laneq_s32( \ + c[oh][0], n0123_2, flt[(flt_start + 2) / 4 % 4], (flt_start + 2) % 4); \ + c[oh][1] = vdotq_laneq_s32( \ + c[oh][1], n4567_2, flt[(flt_start + 2) / 4 % 4], (flt_start + 2) % 4); \ + c[oh][2] = vdotq_laneq_s32( \ + c[oh][2], n89ab_2, flt[(flt_start + 2) / 4 % 4], (flt_start + 2) % 4); \ + c[oh][3] = vdotq_laneq_s32( \ + c[oh][3], ncdef_2, flt[(flt_start + 2) / 4 % 4], (flt_start + 2) % 4); + 
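+ //! Editor's note (added comments, not in the original patch): CAL_C(oh, flt_start)
+ //! accumulates one padded filter row (9 weights rounded up to 12, i.e. three
+ //! 4-weight groups) into the 16 outputs of output row `oh` held in c[oh][0..3].
+ //! flt_start is 3 * filter_row; flt[(flt_start + i) / 4 % 4] with lane
+ //! (flt_start + i) % 4 selects the matching 4-weight group from the four filter
+ //! registers, which are reloaded cyclically as later rows are processed
+ //! (see the `flt[...] = vld1q_s8(weight + ...)` updates further down).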
+ CAL_C(0, 0); + + //! row 1 +#define LOAD_SRC(row_id) \ + read_w[0] = vld1q_s8(src_n + row_id * pad_iw); \ + read_w[1] = vld1q_s8(src_n + row_id * pad_iw + 16); \ + n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); \ + n4567_0 = vqtbl1q_s8(read_w[0], tbl_reg_1); \ + n89ab_0 = vqtbl1q_s8(read_w[0], tbl_reg_2); \ + ncdef_0 = vqtbl1q_s8(vextq_s8(read_w[0], read_w[1], 12), tbl_reg_0); \ + n0123_1 = n4567_0; \ + n4567_1 = n89ab_0; \ + n89ab_1 = ncdef_0; \ + ncdef_1 = vqtbl1q_s8(read_w[1], tbl_reg_0); \ + n0123_2 = n89ab_0; \ + n4567_2 = ncdef_0; \ + n89ab_2 = ncdef_1; \ + ncdef_2 = vqtbl1q_s8(read_w[1], tbl_reg_1); + + LOAD_SRC(1); + CAL_C(0, 3); + CAL_C(1, 0); + + //! row 2 + LOAD_SRC(2); + CAL_C(0, 3 * 2); + CAL_C(1, 3 * 1); + CAL_C(2, 3 * 0); + + //! row 3 + LOAD_SRC(3); + CAL_C(0, 3 * 3); + CAL_C(1, 3 * 2); + CAL_C(2, 3 * 1); + CAL_C(3, 3 * 0); + + //! row 4 + LOAD_SRC(4); + CAL_C(0, 3 * 4); + CAL_C(1, 3 * 3); + CAL_C(2, 3 * 2); + CAL_C(3, 3 * 1); + //! update flt 4 -> 0 + flt[0] = vld1q_s8(weight + 4 * 16); + + //! row 5 + LOAD_SRC(5); + CAL_C(0, 3 * 5); + CAL_C(1, 3 * 4); + CAL_C(2, 3 * 3); + CAL_C(3, 3 * 2); + //! update flt 5 -> 1 + flt[1] = vld1q_s8(weight + 5 * 16); + + //! row 6 + LOAD_SRC(6); + CAL_C(0, 3 * 6); + CAL_C(1, 3 * 5); + CAL_C(2, 3 * 4); + CAL_C(3, 3 * 3); + //! update flt 6 -> 2 + flt[2] = vld1q_s8(weight + 6 * 16); + + //! row 7 + LOAD_SRC(7); + CAL_C(0, 3 * 7); + CAL_C(1, 3 * 6); + CAL_C(2, 3 * 5); + CAL_C(3, 3 * 4); + + //! row 8 + LOAD_SRC(8); + CAL_C(0, 3 * 8); + CAL_C(1, 3 * 7); + CAL_C(2, 3 * 6); + CAL_C(3, 3 * 5); + + //! row 9 + LOAD_SRC(9); + CAL_C(1, 3 * 8); + CAL_C(2, 3 * 7); + CAL_C(3, 3 * 6); + + //! row 10 + LOAD_SRC(10); + CAL_C(2, 3 * 8); + CAL_C(3, 3 * 7); + + //! row 11 + LOAD_SRC(11); + CAL_C(3, 3 * 8); + + float32x4_t dst_reg[4][4]; +#define cb(step) \ + dst_reg[step][0] = vcvtq_f32_s32(c[step][0]); \ + dst_reg[step][1] = vcvtq_f32_s32(c[step][1]); \ + dst_reg[step][2] = vcvtq_f32_s32(c[step][2]); \ + dst_reg[step][3] = vcvtq_f32_s32(c[step][3]); + + UNROLL_CALL_RAW(4, cb); +#undef cb + +#define cb(step) \ + dst_reg[step][0] = vmulq_n_f32(dst_reg[step][0], scale); \ + dst_reg[step][1] = vmulq_n_f32(dst_reg[step][1], scale); \ + dst_reg[step][2] = vmulq_n_f32(dst_reg[step][2], scale); \ + dst_reg[step][3] = vmulq_n_f32(dst_reg[step][3], scale); + + UNROLL_CALL_RAW(4, cb); +#undef cb + int8_t* dst_store = dst + oh * OW + ow; + int8x16_t relu_reg = vdupq_n_s8(relu_val); +#define cb(step) \ + quant_store_s8( \ + dst_reg[step][0], dst_reg[step][1], dst_reg[step][2], dst_reg[step][3], \ + dst_store + step * OW, relu_reg); + + UNROLL_CALL_RAW(4, cb); +#undef cb +} +#endif \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_9x9s2.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_9x9s2.cpp new file mode 100644 index 00000000..e9d92378 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_9x9s2.cpp @@ -0,0 +1,250 @@ +/** + * \file + * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_9x9s2.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include "megdnn/arch.h" +#if MGB_ENABLE_DOT +#include "src/arm_common/simd_macro/marm_neon.h" + +#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h" +#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_common.h" +#include "src/common/unroll_macro.h" + +MEGDNN_ATTRIBUTE_TARGET("dotprod") +void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16( + const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh, + size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale, + int8_t relu_val) { + //! 4x16 + const size_t SH = 2; + const size_t SW = 2; + + static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 2, 3, 4, 5, + 4, 5, 6, 7, 6, 7, 8, 9}; + static const uint8_t tbl_array_1[16] = {4, 5, 6, 7, 6, 7, 8, 9, + 8, 9, 10, 11, 10, 11, 12, 13}; + + uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]); + uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]); + + const int8_t* src_n = src + oh * SH * pad_iw + ow * SW; + + //! init + int32x4_t c[4][4]; +#define cb(step) \ + c[step][0] = vdupq_n_s32(bias); \ + c[step][1] = vdupq_n_s32(bias); \ + c[step][2] = vdupq_n_s32(bias); \ + c[step][3] = vdupq_n_s32(bias); + + UNROLL_CALL_RAW(4, cb); +#undef cb + + constexpr int flt_reg = 7; + constexpr int flt_per_reg = 4; + int8x16_t flt[7]; + flt[0] = vld1q_s8(weight + 0 * 16); + flt[1] = vld1q_s8(weight + 1 * 16); + flt[2] = vld1q_s8(weight + 2 * 16); + flt[3] = vld1q_s8(weight + 3 * 16); + flt[4] = vld1q_s8(weight + 4 * 16); + flt[5] = vld1q_s8(weight + 5 * 16); + flt[6] = vld1q_s8(weight + 6 * 16); + +#define CAL_C(oh, flt_start) \ + c[oh][0] = vdotq_laneq_s32( \ + c[oh][0], n0123_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \ + (flt_start + 0) % flt_per_reg); \ + c[oh][1] = vdotq_laneq_s32( \ + c[oh][1], n4567_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \ + (flt_start + 0) % flt_per_reg); \ + c[oh][2] = vdotq_laneq_s32( \ + c[oh][2], n89ab_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \ + (flt_start + 0) % flt_per_reg); \ + c[oh][3] = vdotq_laneq_s32( \ + c[oh][3], ncdef_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \ + (flt_start + 0) % flt_per_reg); \ + c[oh][0] = vdotq_laneq_s32( \ + c[oh][0], n0123_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \ + (flt_start + 1) % flt_per_reg); \ + c[oh][1] = vdotq_laneq_s32( \ + c[oh][1], n4567_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \ + (flt_start + 1) % flt_per_reg); \ + c[oh][2] = vdotq_laneq_s32( \ + c[oh][2], n89ab_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \ + (flt_start + 1) % flt_per_reg); \ + c[oh][3] = vdotq_laneq_s32( \ + c[oh][3], ncdef_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \ + (flt_start + 1) % flt_per_reg); \ + c[oh][0] = vdotq_laneq_s32( \ + c[oh][0], n0123_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \ + (flt_start + 2) % flt_per_reg); \ + c[oh][1] = vdotq_laneq_s32( \ + c[oh][1], n4567_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \ + (flt_start + 2) % flt_per_reg); \ + c[oh][2] = vdotq_laneq_s32( \ + c[oh][2], n89ab_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \ + (flt_start + 2) % flt_per_reg); \ + c[oh][3] = vdotq_laneq_s32( \ + c[oh][3], ncdef_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \ + (flt_start + 2) % flt_per_reg); + +#define LOAD_SRC(row_id) \ + read_w[0] = vld1q_s8(src_n + row_id * pad_iw); \ + read_w[1] = vld1q_s8(src_n + row_id * pad_iw + 16); \ + read_w[2] = vld1q_s8(src_n + row_id * pad_iw + 32); \ + ext_8 = vextq_s8(read_w[0], read_w[1], 8); \ + ext_24 = vextq_s8(read_w[1], read_w[2], 8); \ + 
n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); \ + n4567_0 = vqtbl1q_s8(ext_8, tbl_reg_0); \ + n89ab_0 = vqtbl1q_s8(read_w[1], tbl_reg_0); \ + ncdef_0 = vqtbl1q_s8(ext_24, tbl_reg_0); \ + n0123_1 = vqtbl1q_s8(read_w[0], tbl_reg_1); \ + n4567_1 = vqtbl1q_s8(ext_8, tbl_reg_1); \ + n89ab_1 = vqtbl1q_s8(read_w[1], tbl_reg_1); \ + ncdef_1 = vqtbl1q_s8(ext_24, tbl_reg_1); \ + n0123_2 = n4567_0; \ + n4567_2 = n89ab_0; \ + n89ab_2 = ncdef_0; \ + ncdef_2 = vqtbl1q_s8(read_w[2], tbl_reg_0); + + //! row 0 + int8x16_t read_w[3]; + read_w[0] = vld1q_s8(src_n); + read_w[1] = vld1q_s8(src_n + 16); + read_w[2] = vld1q_s8(src_n + 32); + int8x16_t ext_8 = vextq_s8(read_w[0], read_w[1], 8); + int8x16_t ext_24 = vextq_s8(read_w[1], read_w[2], 8); + + int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); + int8x16_t n4567_0 = vqtbl1q_s8(ext_8, tbl_reg_0); + int8x16_t n89ab_0 = vqtbl1q_s8(read_w[1], tbl_reg_0); + int8x16_t ncdef_0 = vqtbl1q_s8(ext_24, tbl_reg_0); + + int8x16_t n0123_1 = vqtbl1q_s8(read_w[0], tbl_reg_1); + int8x16_t n4567_1 = vqtbl1q_s8(ext_8, tbl_reg_1); + int8x16_t n89ab_1 = vqtbl1q_s8(read_w[1], tbl_reg_1); + int8x16_t ncdef_1 = vqtbl1q_s8(ext_24, tbl_reg_1); + + int8x16_t n0123_2 = n4567_0; + int8x16_t n4567_2 = n89ab_0; + int8x16_t n89ab_2 = ncdef_0; + int8x16_t ncdef_2 = vqtbl1q_s8(read_w[2], tbl_reg_0); + + CAL_C(0, 0); + + //! row 1 + LOAD_SRC(1); + CAL_C(0, 3 * 1); + + //! row 2 + LOAD_SRC(2); + CAL_C(0, 3 * 2); + CAL_C(1, 3 * 0); + + //! row 3 + LOAD_SRC(3); + CAL_C(0, 3 * 3); + CAL_C(1, 3 * 1); + + //! row 4 + LOAD_SRC(4); + CAL_C(0, 3 * 4); + CAL_C(1, 3 * 2); + CAL_C(2, 3 * 0); + + //! row 5 + LOAD_SRC(5); + CAL_C(0, 3 * 5); + CAL_C(1, 3 * 3); + CAL_C(2, 3 * 1); + + //! row 6 + LOAD_SRC(6); + CAL_C(0, 3 * 6); + CAL_C(1, 3 * 4); + CAL_C(2, 3 * 2); + CAL_C(3, 3 * 0); + + //! row 7 + LOAD_SRC(7); + CAL_C(0, 3 * 7); + CAL_C(1, 3 * 5); + CAL_C(2, 3 * 3); + CAL_C(3, 3 * 1); + + //! row 8 + LOAD_SRC(8); + CAL_C(0, 3 * 8); + CAL_C(1, 3 * 6); + CAL_C(2, 3 * 4); + CAL_C(3, 3 * 2); + + //! row 9 + LOAD_SRC(9); + CAL_C(1, 3 * 7); + CAL_C(2, 3 * 5); + CAL_C(3, 3 * 3); + + //! row 10 + LOAD_SRC(10); + CAL_C(1, 3 * 8); + CAL_C(2, 3 * 6); + CAL_C(3, 3 * 4); + + //! row 11 + LOAD_SRC(11); + CAL_C(2, 3 * 7); + CAL_C(3, 3 * 5); + + //! row 12 + LOAD_SRC(12); + CAL_C(2, 3 * 8); + CAL_C(3, 3 * 6); + + //! row 13 + LOAD_SRC(13); + CAL_C(3, 3 * 7); + + //! 
row 14 + LOAD_SRC(14); + CAL_C(3, 3 * 8); + + float32x4_t dst_reg[4][4]; +#define cb(step) \ + dst_reg[step][0] = vcvtq_f32_s32(c[step][0]); \ + dst_reg[step][1] = vcvtq_f32_s32(c[step][1]); \ + dst_reg[step][2] = vcvtq_f32_s32(c[step][2]); \ + dst_reg[step][3] = vcvtq_f32_s32(c[step][3]); + + UNROLL_CALL_RAW(4, cb); +#undef cb + +#define cb(step) \ + dst_reg[step][0] = vmulq_n_f32(dst_reg[step][0], scale); \ + dst_reg[step][1] = vmulq_n_f32(dst_reg[step][1], scale); \ + dst_reg[step][2] = vmulq_n_f32(dst_reg[step][2], scale); \ + dst_reg[step][3] = vmulq_n_f32(dst_reg[step][3], scale); + + UNROLL_CALL_RAW(4, cb); +#undef cb + int8_t* dst_store = dst + oh * OW + ow; + int8x16_t relu_reg = vdupq_n_s8(relu_val); +#define cb(step) \ + quant_store_s8( \ + dst_reg[step][0], dst_reg[step][1], dst_reg[step][2], dst_reg[step][3], \ + dst_store + step * OW, relu_reg); + + UNROLL_CALL_RAW(4, cb); +#undef cb +} +#endif \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp index 6de11307..e43a8835 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.cpp +++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp @@ -54,6 +54,8 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoDotS8Direct_NCHW44 ds8_direct_nchw44; AlgoDotS8DirectNCHWNCHW44 ds8_direct_nchw_nchw44; + AlgoDotS8Im2colChanWiseLarge ds8_im2col_large_chanwise; + AlgoDotS8DirectChanWiseLarge ds8_direct_large_chanwise; #endif AlgoI8x8x16Direct i8x8x16_direct; @@ -75,6 +77,8 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { public: AlgoPack() { #if MGB_ENABLE_DOT + m_direct_algos.emplace_back(&ds8_direct_large_chanwise); + m_direct_algos.emplace_back(&ds8_im2col_large_chanwise); m_direct_algos.emplace_back(&ds8_direct_stride1); m_direct_algos.emplace_back(&ds8_direct_stride2); m_direct_algos.emplace_back(&du8_direct_stride1); diff --git a/dnn/src/arm_common/conv_bias/opr_impl.h b/dnn/src/arm_common/conv_bias/opr_impl.h index 0a533683..ec8a4c15 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.h +++ b/dnn/src/arm_common/conv_bias/opr_impl.h @@ -51,6 +51,8 @@ private: #endif #if MGB_ENABLE_DOT class AlgoDotS8DirectNCHWNCHW44; + class AlgoDotS8DirectChanWiseLarge; + class AlgoDotS8Im2colChanWiseLarge; class AlgoDotS8DirectStride1; class AlgoDotS8DirectStride2; class AlgoDotU8DirectStride1; diff --git a/dnn/src/arm_common/matrix_mul/algos.cpp b/dnn/src/arm_common/matrix_mul/algos.cpp index 0702c5fa..bf1bd711 100644 --- a/dnn/src/arm_common/matrix_mul/algos.cpp +++ b/dnn/src/arm_common/matrix_mul/algos.cpp @@ -143,8 +143,89 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4::get_kern( const KernSizeParam&) const { return int8x8x32_gemv_mk4_kern; } - #if MGB_ENABLE_DOT +namespace { +void int8x8x32_gevm_dot_kern(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_arm_exec_int8832, midout_iv("int8x8x32_gevm_dot_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + gevm_naive_dot(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoInt8x8x32GevmDot::usable( + const KernSizeParam& kern_size_param) const { + if (!cpuinfo_has_arm_neon_dot()) { + return false; + } + auto M = kern_size_param.M; + bool is_dtype_ok = + kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() && + (kern_size_param.A_type.enumv() == 
DTypeEnum::Int8 || + kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) && + (kern_size_param.C_type.enumv() == DTypeEnum::Int32 || + kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32); + + return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.format == param::MatrixMul::Format::DEFAULT && is_dtype_ok && + !kern_size_param.trA && !kern_size_param.trB && M == 1; +} + +bool MatrixMulImpl::AlgoInt8x8x32GevmDot::preferred( + const KernSizeParam& kern_size_param) const { + return true; +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GevmDot::get_kern( + const KernSizeParam&) const { + return int8x8x32_gevm_dot_kern; +} + +namespace { +void int8x8x32_gevm_n32k4_dot_kern(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_arm_exec_int8832, midout_iv("int8x8x32_gevm_dot_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + gevm_naive_n32k4_dot(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoInt8x8x32GevmN32K4Dot::usable( + const KernSizeParam& kern_size_param) const { + if (!cpuinfo_has_arm_neon_dot()) { + return false; + } + auto M = kern_size_param.M; + + bool is_dtype_ok = + kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() && + (kern_size_param.A_type.enumv() == DTypeEnum::Int8 || + kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) && + (kern_size_param.C_type.enumv() == DTypeEnum::Int32 || + kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32); + return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.format == param::MatrixMul::Format::N32K4_DOT && + is_dtype_ok && !kern_size_param.trA && !kern_size_param.trB && M == 1; +} + +bool MatrixMulImpl::AlgoInt8x8x32GevmN32K4Dot::preferred( + const KernSizeParam& kern_size_param) const { + return true; +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GevmN32K4Dot::get_kern( + const KernSizeParam&) const { + return int8x8x32_gevm_n32k4_dot_kern; +} + /* =================== Int8x8x32 Gemv MK4_DOT algo ==================== */ namespace { void int8x8x32_gemv_mk4_dot_kern(const MatrixMulImpl::KernParam& kern_param) { diff --git a/dnn/src/arm_common/matrix_mul/algos.h b/dnn/src/arm_common/matrix_mul/algos.h index 5e39b043..9c1d7c50 100644 --- a/dnn/src/arm_common/matrix_mul/algos.h +++ b/dnn/src/arm_common/matrix_mul/algos.h @@ -49,8 +49,69 @@ public: MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4) MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV_MK4) }; - #if MGB_ENABLE_DOT +class MatrixMulImpl::AlgoInt8x8x32GevmDot : public AlgoBase { +public: + AlgoAttribute attribute() const override { + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; + } + const char* name() const override { return "ARM_COMMON_INT8X8X32_GEVM_DOT"; } + bool usable(const KernSizeParam&) const override; + bool preferred(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override { return 0; } + kern_t get_kern(const KernSizeParam&) const override; + AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEVM; } + PackMode packmode() const override { return PackMode::NO_PACK; } + MEGDNN_OVERRIDE_MATMUL_DESC(1, 32, 4, 2, AlgoDataType::QINT8X8X32, DEFAULT) + WorkspaceBundle get_bundle(const KernSizeParam&) 
const override { + return WorkspaceBundle{nullptr, {}}; + } + kern_naked_t get_kern_naked(const KernSizeParam&) const override { + megdnn_assert(0, "naked kern no impl"); + } + void pack_A(const KernParam& kern_param, void* out, size_t index, size_t stride) + const override { + megdnn_assert(0, "pack_A no impl"); + } + void pack_B(const KernParam& kern_param, void* out, size_t x0, size_t xmax) + const override { + megdnn_assert(0, "pack_B no impl"); + } + InnerBlockSize get_inner_block_size() const override { return {1, 32, 4}; }; + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEVM_DOT) +}; + +class MatrixMulImpl::AlgoInt8x8x32GevmN32K4Dot : public AlgoBase { +public: + AlgoAttribute attribute() const override { + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; + } + const char* name() const override { return "ARM_COMMON_INT8X8X32_GEVM_N32K4_DOT"; } + bool usable(const KernSizeParam&) const override; + bool preferred(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override { return 0; } + kern_t get_kern(const KernSizeParam&) const override; + AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEVM; } + PackMode packmode() const override { return PackMode::NO_PACK; } + MEGDNN_OVERRIDE_MATMUL_DESC(1, 32, 4, 2, AlgoDataType::QINT8X8X32, N32K4_DOT) + WorkspaceBundle get_bundle(const KernSizeParam&) const override { + return WorkspaceBundle{nullptr, {}}; + } + kern_naked_t get_kern_naked(const KernSizeParam&) const override { + megdnn_assert(0, "naked kern no impl"); + } + void pack_A(const KernParam& kern_param, void* out, size_t index, size_t stride) + const override { + megdnn_assert(0, "pack_A no impl"); + } + void pack_B(const KernParam& kern_param, void* out, size_t x0, size_t xmax) + const override { + megdnn_assert(0, "pack_B no impl"); + } + InnerBlockSize get_inner_block_size() const override { return {1, 32, 4}; }; + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEVM_N32K4_DOT) +}; + class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase { public: AlgoAttribute attribute() const override { diff --git a/dnn/src/arm_common/matrix_mul/int8/gemv.cpp b/dnn/src/arm_common/matrix_mul/int8/gemv.cpp index 7588c6b4..50446818 100644 --- a/dnn/src/arm_common/matrix_mul/int8/gemv.cpp +++ b/dnn/src/arm_common/matrix_mul/int8/gemv.cpp @@ -2,6 +2,7 @@ #include "megdnn/oprs.h" #include "src/arm_common/matrix_mul/int8/gemv.h" +#include "src/common/unroll_macro.h" #include "src/common/utils.h" #include "midout.h" @@ -430,5 +431,398 @@ void arm_common::gemv_like_mk4_dot( MIDOUT_END(); } #endif +#if MGB_ENABLE_DOT +namespace { +MEGDNN_ATTRIBUTE_TARGET("dotprod") +void gevm_naive_dot_impl( + const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride, + bool load_c) { + constexpr size_t n_block = 32; + const size_t n_end = N / n_block * n_block; + const size_t n_remain = N - n_end; + + constexpr size_t k_block = 4; + constexpr size_t k_block_x2 = k_block * 2; + const size_t k_end = (K / k_block_x2) * k_block_x2; + const size_t k_remain = K - k_end; + for (size_t n = 0; n < n_end; n += n_block) { + if (K < k_block_x2) { + if (!load_c) { + for (size_t i = 0; i < n_block; ++i) { + C[n + i] = 0; + } + } + for (size_t k = 0; k < K; ++k) { + for (size_t i = 0; i < n_block; ++i) { + C[n + i] += A[k] * B[k * Bstride + n + i]; + } + } + continue; + } + int32x4_t c[8]; + if (load_c) { +#define cb(step) c[step] = vld1q_s32(C + n + step * 
4); + UNROLL_CALL_RAW(8, cb); +#undef cb + } else { +#define cb(step) c[step] = vdupq_n_s32(0); + UNROLL_CALL_RAW(8, cb); +#undef cb + } + int8x16_t a[2]; + a[0] = vld1q_dup_s32(A); + int8x16_t b[2][8]; +#define cb(step) \ + b[0][step * 2 + 0] = vld1q_s8(B + (0 + step) * Bstride + n); \ + b[0][step * 2 + 1] = vld1q_s8(B + (0 + step) * Bstride + n + 16); + UNROLL_CALL_RAW(4, cb); +#undef cb + size_t k_buffer_end = k_end - k_block_x2; + for (size_t k = 0; k < k_buffer_end; k += k_block_x2) { + //! double buffer main +#define cb(step) \ + b[1][step * 2 + 0] = vld1q_s8(B + (k + step + k_block) * Bstride + n); \ + b[1][step * 2 + 1] = vld1q_s8(B + (k + step + k_block) * Bstride + n + 16); + UNROLL_CALL_RAW(4, cb); +#undef cb + a[1] = vld1q_dup_s32(A + k + k_block); + + int8x16x2_t ab0 = vzipq_s8(b[0][0], b[0][2]); + int8x16x2_t cd0 = vzipq_s8(b[0][4], b[0][6]); + int8x16x2_t ab1 = vzipq_s8(b[0][1], b[0][3]); + int8x16x2_t cd1 = vzipq_s8(b[0][5], b[0][7]); + int16x8x2_t abcd0 = vzipq_s16(ab0.val[0], cd0.val[0]); + int16x8x2_t abcd1 = vzipq_s16(ab0.val[1], cd0.val[1]); + int16x8x2_t abcd2 = vzipq_s16(ab1.val[0], cd1.val[0]); + int16x8x2_t abcd3 = vzipq_s16(ab1.val[1], cd1.val[1]); + c[0] = vdotq_s32(c[0], abcd0.val[0], a[0]); + c[1] = vdotq_s32(c[1], abcd0.val[1], a[0]); + c[2] = vdotq_s32(c[2], abcd1.val[0], a[0]); + c[3] = vdotq_s32(c[3], abcd1.val[1], a[0]); + c[4] = vdotq_s32(c[4], abcd2.val[0], a[0]); + c[5] = vdotq_s32(c[5], abcd2.val[1], a[0]); + c[6] = vdotq_s32(c[6], abcd3.val[0], a[0]); + c[7] = vdotq_s32(c[7], abcd3.val[1], a[0]); +#define cb(step) \ + b[0][step * 2 + 0] = vld1q_s8(B + (k + step + k_block_x2) * Bstride + n); \ + b[0][step * 2 + 1] = vld1q_s8(B + (k + step + k_block_x2) * Bstride + n + 16); + UNROLL_CALL_RAW(4, cb); +#undef cb + a[0] = vld1q_dup_s32(A + k + k_block_x2); + + ab0 = vzipq_s8(b[1][0], b[1][2]); + cd0 = vzipq_s8(b[1][4], b[1][6]); + ab1 = vzipq_s8(b[1][1], b[1][3]); + cd1 = vzipq_s8(b[1][5], b[1][7]); + + abcd0 = vzipq_s16(ab0.val[0], cd0.val[0]); + abcd1 = vzipq_s16(ab0.val[1], cd0.val[1]); + abcd2 = vzipq_s16(ab1.val[0], cd1.val[0]); + abcd3 = vzipq_s16(ab1.val[1], cd1.val[1]); + c[0] = vdotq_s32(c[0], abcd0.val[0], a[1]); + c[1] = vdotq_s32(c[1], abcd0.val[1], a[1]); + c[2] = vdotq_s32(c[2], abcd1.val[0], a[1]); + c[3] = vdotq_s32(c[3], abcd1.val[1], a[1]); + c[4] = vdotq_s32(c[4], abcd2.val[0], a[1]); + c[5] = vdotq_s32(c[5], abcd2.val[1], a[1]); + c[6] = vdotq_s32(c[6], abcd3.val[0], a[1]); + c[7] = vdotq_s32(c[7], abcd3.val[1], a[1]); + } + //! 
double buffer remain +#define cb(step) \ + b[1][step * 2 + 0] = vld1q_s8(B + (k_buffer_end + step + k_block) * Bstride + n); \ + b[1][step * 2 + 1] = \ + vld1q_s8(B + (k_buffer_end + step + k_block) * Bstride + n + 16); + UNROLL_CALL_RAW(4, cb); +#undef cb + a[1] = vld1q_dup_s32(A + k_buffer_end + k_block); + + int8x16x2_t ab0 = vzipq_s8(b[0][0], b[0][2]); + int8x16x2_t cd0 = vzipq_s8(b[0][4], b[0][6]); + int8x16x2_t ab1 = vzipq_s8(b[0][1], b[0][3]); + int8x16x2_t cd1 = vzipq_s8(b[0][5], b[0][7]); + int16x8x2_t abcd0 = vzipq_s16(ab0.val[0], cd0.val[0]); + int16x8x2_t abcd1 = vzipq_s16(ab0.val[1], cd0.val[1]); + int16x8x2_t abcd2 = vzipq_s16(ab1.val[0], cd1.val[0]); + int16x8x2_t abcd3 = vzipq_s16(ab1.val[1], cd1.val[1]); + c[0] = vdotq_s32(c[0], abcd0.val[0], a[0]); + c[1] = vdotq_s32(c[1], abcd0.val[1], a[0]); + c[2] = vdotq_s32(c[2], abcd1.val[0], a[0]); + c[3] = vdotq_s32(c[3], abcd1.val[1], a[0]); + c[4] = vdotq_s32(c[4], abcd2.val[0], a[0]); + c[5] = vdotq_s32(c[5], abcd2.val[1], a[0]); + c[6] = vdotq_s32(c[6], abcd3.val[0], a[0]); + c[7] = vdotq_s32(c[7], abcd3.val[1], a[0]); + + ab0 = vzipq_s8(b[1][0], b[1][2]); + cd0 = vzipq_s8(b[1][4], b[1][6]); + ab1 = vzipq_s8(b[1][1], b[1][3]); + cd1 = vzipq_s8(b[1][5], b[1][7]); + abcd0 = vzipq_s16(ab0.val[0], cd0.val[0]); + abcd1 = vzipq_s16(ab0.val[1], cd0.val[1]); + abcd2 = vzipq_s16(ab1.val[0], cd1.val[0]); + abcd3 = vzipq_s16(ab1.val[1], cd1.val[1]); + c[0] = vdotq_s32(c[0], abcd0.val[0], a[1]); + c[1] = vdotq_s32(c[1], abcd0.val[1], a[1]); + c[2] = vdotq_s32(c[2], abcd1.val[0], a[1]); + c[3] = vdotq_s32(c[3], abcd1.val[1], a[1]); + c[4] = vdotq_s32(c[4], abcd2.val[0], a[1]); + c[5] = vdotq_s32(c[5], abcd2.val[1], a[1]); + c[6] = vdotq_s32(c[6], abcd3.val[0], a[1]); + c[7] = vdotq_s32(c[7], abcd3.val[1], a[1]); + + vst1q_s32(C + n + 0 * 4, c[0]); + vst1q_s32(C + n + 1 * 4, c[1]); + vst1q_s32(C + n + 2 * 4, c[2]); + vst1q_s32(C + n + 3 * 4, c[3]); + vst1q_s32(C + n + 4 * 4, c[4]); + vst1q_s32(C + n + 5 * 4, c[5]); + vst1q_s32(C + n + 6 * 4, c[6]); + vst1q_s32(C + n + 7 * 4, c[7]); + if (k_remain > 0) { + for (size_t k = k_end; k < K; ++k) { + for (size_t i = 0; i < n_block; ++i) { + C[n + i] += A[k] * B[k * Bstride + n + i]; + } + } + } + } + if (n_remain > 0) { + for (size_t n = n_end; n < N; ++n) { + if (!load_c) { + C[n] = 0; + } + for (size_t k = 0; k < K; ++k) { + C[n] += A[k] * B[k * Bstride + n]; + } + } + } +} +#if MEGDNN_ARMV7 +MEGDNN_ATTRIBUTE_TARGET("dotprod") +void gevm_naive_dot_n32k4_impl( + const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride, + bool load_c) { + //! input must be N/32, k/4, 32, 4 + //! TODO: add prefetch + //! 
TODO: add double buffer + constexpr size_t n_block = 32; + constexpr size_t k_block = 4; + for (size_t n = 0; n < N; n += n_block) { + int32x4_t c[n_block / 4]; +#define cb(step) c[step] = vdupq_n_s32(0); + UNROLL_CALL_RAW(8, cb); +#undef cb + const int8_t* b_base = B + n * K; + for (size_t k = 0; k < K; k += k_block) { + int8x16_t a[1]; + int8x16_t b[1][8]; +#define cb(step) b[0][step] = vld1q_s8(b_base + k * 32 + 16 * step); + UNROLL_CALL_RAW(8, cb); +#undef cb + a[0] = vld1q_dup_s32(A + k); + + c[0] = vdotq_s32(c[0], b[0][0], a[0]); + c[1] = vdotq_s32(c[1], b[0][1], a[0]); + c[2] = vdotq_s32(c[2], b[0][2], a[0]); + c[3] = vdotq_s32(c[3], b[0][3], a[0]); + c[4] = vdotq_s32(c[4], b[0][4], a[0]); + c[5] = vdotq_s32(c[5], b[0][5], a[0]); + c[6] = vdotq_s32(c[6], b[0][6], a[0]); + c[7] = vdotq_s32(c[7], b[0][7], a[0]); + } + vst1q_s32(C + n + 0 * 4, c[0]); + vst1q_s32(C + n + 1 * 4, c[1]); + vst1q_s32(C + n + 2 * 4, c[2]); + vst1q_s32(C + n + 3 * 4, c[3]); + vst1q_s32(C + n + 4 * 4, c[4]); + vst1q_s32(C + n + 5 * 4, c[5]); + vst1q_s32(C + n + 6 * 4, c[6]); + vst1q_s32(C + n + 7 * 4, c[7]); + } +} +#else +MEGDNN_ATTRIBUTE_TARGET("dotprod") +inline void n32k4_dot( + const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, + size_t K) { + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + //! C q0-q7 + //! A q8-q9 + //! B q10-q25 + asm volatile( + // load accumulator C + "1:\n" + "eor v0.16b, v0.16b, v0.16b\n" + "eor v1.16b, v1.16b, v1.16b\n" + "eor v2.16b, v2.16b, v2.16b\n" + "eor v3.16b, v3.16b, v3.16b\n" + "eor v4.16b, v4.16b, v4.16b\n" + "eor v5.16b, v5.16b, v5.16b\n" + "eor v6.16b, v6.16b, v6.16b\n" + "eor v7.16b, v7.16b, v7.16b\n" + + "ld1r {v8.4s}, [%[a_ptr]]\n" + "ld1 {v10.4s, v11.4s, v12.4s, v13.4s}, [%[b_ptr]], 64\n" + "ld1 {v14.4s, v15.4s, v16.4s, v17.4s}, [%[b_ptr]], 64\n" + "add %[a_ptr], %[a_ptr], #4\n" + + "cmp %w[k], #0\n" + "beq 4f\n" + + "2: \n" + // Loop proper + "3:\n" + "ld1r {v9.4s}, [%[a_ptr]]\n" + "sdot v0.4s, v10.16b, v8.16b\n" + "ldr q18, [%[b_ptr], #0]\n" + "sdot v1.4s, v11.16b, v8.16b\n" + "ldr q19, [%[b_ptr], #16]\n" + "sdot v2.4s, v12.16b, v8.16b\n" + "ldr q20, [%[b_ptr], #32]\n" + "add %[a_ptr], %[a_ptr], #4\n" + "sdot v3.4s, v13.16b, v8.16b\n" + "ldr q21, [%[b_ptr], #48]\n" + "sdot v4.4s, v14.16b, v8.16b\n" + "ldr q22, [%[b_ptr], #64]\n" + "sdot v5.4s, v15.16b, v8.16b\n" + "ldr q23, [%[b_ptr], #80]\n" + "sdot v6.4s, v16.16b, v8.16b\n" + "ldr q24, [%[b_ptr], #96]\n" + "sdot v7.4s, v17.16b, v8.16b\n" + "ldr q25, [%[b_ptr], #112]\n" + + "ld1r {v8.4s}, [%[a_ptr]]\n" + "sdot v0.4s, v18.16b, v9.16b\n" + "ldr q10, [%[b_ptr], #128]\n" + "sdot v1.4s, v19.16b, v9.16b\n" + "ldr q11, [%[b_ptr], #144]\n" + "sdot v2.4s, v20.16b, v9.16b\n" + "ldr q12, [%[b_ptr], #160]\n" + "sdot v3.4s, v21.16b, v9.16b\n" + "ldr q13, [%[b_ptr], #176]\n" + "sdot v4.4s, v22.16b, v9.16b\n" + "ldr q14, [%[b_ptr], #192]\n" + "sdot v5.4s, v23.16b, v9.16b\n" + "ldr q15, [%[b_ptr], #208]\n" + "sdot v6.4s, v24.16b, v9.16b\n" + "ldr q16, [%[b_ptr], #224]\n" + "sdot v7.4s, v25.16b, v9.16b\n" + "ldr q17, [%[b_ptr], #240]\n" + + "add %[a_ptr], %[a_ptr], #4\n" + "add %[b_ptr], %[b_ptr], #256\n" + + "subs %w[k], %w[k], #1\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + // Even tail + + "ld1r {v9.4s}, [%[a_ptr]]\n" + "sdot v0.4s, v10.16b, v8.16b\n" + "ldr q18, [%[b_ptr], #0]\n" + "sdot v1.4s, v11.16b, v8.16b\n" + "ldr q19, [%[b_ptr], #16]\n" + "sdot v2.4s, v12.16b, v8.16b\n" + "ldr q20, [%[b_ptr], #32]\n" + "sdot v3.4s, v13.16b, v8.16b\n" + "ldr q21, [%[b_ptr], #48]\n" + "sdot 
v4.4s, v14.16b, v8.16b\n" + "ldr q22, [%[b_ptr], #64]\n" + "sdot v5.4s, v15.16b, v8.16b\n" + "ldr q23, [%[b_ptr], #80]\n" + "sdot v6.4s, v16.16b, v8.16b\n" + "ldr q24, [%[b_ptr], #96]\n" + "sdot v7.4s, v17.16b, v8.16b\n" + "ldr q25, [%[b_ptr], #112]\n" + + "sdot v0.4s, v18.16b, v9.16b\n" + "sdot v1.4s, v19.16b, v9.16b\n" + "sdot v2.4s, v20.16b, v9.16b\n" + "sdot v3.4s, v21.16b, v9.16b\n" + "sdot v4.4s, v22.16b, v9.16b\n" + "sdot v5.4s, v23.16b, v9.16b\n" + "sdot v6.4s, v24.16b, v9.16b\n" + "sdot v7.4s, v25.16b, v9.16b\n" + "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[c_ptr]], 64\n" + "st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[c_ptr]], 64\n" + "b 6f\n" + + "5:\n" + // Odd tail + "sdot v0.4s, v10.16b, v8.16b\n" + "sdot v1.4s, v11.16b, v8.16b\n" + "sdot v2.4s, v12.16b, v8.16b\n" + "sdot v3.4s, v13.16b, v8.16b\n" + "sdot v4.4s, v14.16b, v8.16b\n" + "sdot v5.4s, v15.16b, v8.16b\n" + "sdot v6.4s, v16.16b, v8.16b\n" + "sdot v7.4s, v17.16b, v8.16b\n" + "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[c_ptr]], 64\n" + "st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[c_ptr]], 64\n" + "6:\n" + + : [a_ptr] "+r"(A), [b_ptr] "+r"(B), [k] "+r"(K), [c_ptr] "+r"(C), + [oddk] "+r"(oddk) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24", "v25", "cc", "memory"); +} +MEGDNN_ATTRIBUTE_TARGET("dotprod") +void gevm_naive_dot_n32k4_impl( + const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride, + bool load_c) { + //! input must be N/32, k/4, 32, 4 + //! TODO: add prefetch + //! TODO: add double buffer + constexpr size_t n_block = 32; + for (size_t n = 0; n < N; n += n_block) { + n32k4_dot(A, B + n * K, C + n, K / 4); + } +} +#endif +} // namespace + +void arm_common::gevm_naive_dot( + const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { + megdnn_assert(M == 1); + MIDOUT_BEGIN(megdnn_arm_common_int8_gemv, midout_iv("INT8_gevm_dot"_hash)) { + size_t cache_size = 256 * 1024; + size_t k_group = N * K / cache_size; + constexpr size_t k_align = 8; + if (k_group >= 2) { + size_t k_per_group = ((K / k_group) + k_align - 1) / k_align * k_align; + for (size_t k = 0; k < K; k += k_per_group) { + size_t real_k = std::min(K - k, k_per_group); + gevm_naive_dot_impl( + A + k, B + k * Bstride, C, M, N, real_k, Astride, Bstride, + Cstride, k != 0); + } + } else { + gevm_naive_dot_impl(A, B, C, M, N, K, Astride, Bstride, Cstride, false); + } + } + MIDOUT_END(); +} + +void arm_common::gevm_naive_n32k4_dot( + const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { + megdnn_assert(M == 1); + MIDOUT_BEGIN(megdnn_arm_common_int8_gemv, midout_iv("INT8_gevm_dot_nk4"_hash)) { + gevm_naive_dot_n32k4_impl(A, B, C, M, N, K, Astride, Bstride, Cstride, false); + } + MIDOUT_END(); +} +#endif // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/matrix_mul/int8/gemv.h b/dnn/src/arm_common/matrix_mul/int8/gemv.h index 9b8d3efc..cca596cc 100644 --- a/dnn/src/arm_common/matrix_mul/int8/gemv.h +++ b/dnn/src/arm_common/matrix_mul/int8/gemv.h @@ -22,6 +22,14 @@ void gemv_like_mk4( void gemv_like_mk4_dot( const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, 
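//! Note on gevm_naive_dot (declared below; description only, not part of the
//! patch): when the B panel of N * K int8 values exceeds the 256 KiB budget
//! used in its implementation, K is split into 8-aligned slices and the kernel
//! is re-run per slice, accumulating into C through the load_c flag. For
//! example, N = 512 and K = 1024 give k_group = 2 and k_per_group = 512, so the
//! kernel runs twice: the first pass with load_c = false, the second with
//! load_c = true adding the remaining K range onto the partial result.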
size_t Cstride); + +void gevm_naive_dot( + const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride); + +void gevm_naive_n32k4_dot( + const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride); #endif } // namespace arm_common diff --git a/dnn/src/arm_common/matrix_mul/opr_impl.cpp b/dnn/src/arm_common/matrix_mul/opr_impl.cpp index ccd5f4fb..f5d44ab0 100644 --- a/dnn/src/arm_common/matrix_mul/opr_impl.cpp +++ b/dnn/src/arm_common/matrix_mul/opr_impl.cpp @@ -14,6 +14,8 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { AlgoInt8x8x32GemvMK4 int8x8x32_gemv_mk4; #if MGB_ENABLE_DOT AlgoInt8x8x32GemvMK4Dot int8x8x32_gemv_mk4_dot; + AlgoInt8x8x32GevmDot int8x8x32_gevm_dot; + AlgoInt8x8x32GevmN32K4Dot int8x8x32_gevm_n32k4_dot; #endif AlgoGevm gevm; @@ -28,6 +30,8 @@ public: #endif #if MGB_ENABLE_DOT m_all_algos.emplace_back(&int8x8x32_gemv_mk4_dot); + m_all_algos.emplace_back(&int8x8x32_gevm_dot); + m_all_algos.emplace_back(&int8x8x32_gevm_n32k4_dot); #endif m_all_algos.emplace_back(&int8x8x32_gemv); m_all_algos.emplace_back(&int8x8x32_gemv_mk4); diff --git a/dnn/src/arm_common/matrix_mul/opr_impl.h b/dnn/src/arm_common/matrix_mul/opr_impl.h index 50676703..1e9e2db8 100644 --- a/dnn/src/arm_common/matrix_mul/opr_impl.h +++ b/dnn/src/arm_common/matrix_mul/opr_impl.h @@ -31,7 +31,9 @@ protected: class AlgoF16Gemv; #endif #if MGB_ENABLE_DOT - class AlgoInt8x8x32GemvMK4Dot; // Arm_common Int8x8x32 Gemv NCHW44_DOT + class AlgoInt8x8x32GemvMK4Dot; // Arm_common Int8x8x32 Gemv NCHW44_DOT + class AlgoInt8x8x32GevmDot; // Arm_common Int8x8x32 Gevm NCHW DOT + class AlgoInt8x8x32GevmN32K4Dot; // Arm_common Int8x8x32 Gevm NK4 #endif class AlgoInt8x8x16; // Arm_common Int 8x8x16 class AlgoPack; diff --git a/dnn/src/arm_common/simd_macro/marm_neon.h b/dnn/src/arm_common/simd_macro/marm_neon.h index 2f90eebe..ce18c882 100644 --- a/dnn/src/arm_common/simd_macro/marm_neon.h +++ b/dnn/src/arm_common/simd_macro/marm_neon.h @@ -469,7 +469,7 @@ __ai float64x2_t vbitq_f64(float64x2_t dst, float64x2_t v1, uint64x2_t mask) { #endif #if MEGDNN_ARMV7 -__ai int8x16_t vqtbl1q_s8(int8x16_t& a, uint8x16_t& idx) { +__ai int8x16_t vqtbl1q_s8(int8x16_t a, uint8x16_t idx) { int8x8_t src_low = vget_low_s8(a); int8x8_t src_high = vget_high_s8(a); return vcombine_s8( @@ -726,6 +726,13 @@ __ai float32x4_t Vfmsq_f32(float32x4_t& a, float32x4_t& b, float32x4_t& v) { asm volatile("fmls %0.4s, %1.4s, %2.4s\n" : "+w"(a) : "w"(b), "w"(v) :); return a; } +#if __ARM_ARCH < 8 +__ai int32x4_t vcvtaq_s32_f32(float32x4_t val) { + float32x4_t vinc0 = vbslq_f32( + vcgeq_f32(val, vdupq_n_f32(0.f)), vdupq_n_f32(0.5f), vdupq_n_f32(-0.5f)); + return vcvtq_s32_f32(vaddq_f32(val, vinc0)); +} +#endif #if MGB_ENABLE_DOT #undef __ARM_FEATURE_DOTPROD #endif diff --git a/dnn/src/common/matrix_mul.cpp b/dnn/src/common/matrix_mul.cpp index 91599502..494040a9 100644 --- a/dnn/src/common/matrix_mul.cpp +++ b/dnn/src/common/matrix_mul.cpp @@ -61,6 +61,15 @@ void MatrixMulForward::deduce_layout( "(transposed) B is (%zu,%zu)", A0, A1, B0, B1); C = TensorLayout(TensorShape({A0, B1}), C.dtype); + } else if (param().format == param::MatrixMul::Format::N32K4_DOT) { + A0 = A.shape[0]; + A1 = A.shape[1]; + B0 = B.shape[0]; + B1 = B.shape[1]; + megdnn_assert(!m_param.transposeA && !m_param.transposeB); + megdnn_assert(A0 == 1 && A1 % 4 == 0); + megdnn_assert(B.ndim == 4); + 
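        //! Worked example of the N32K4_DOT shapes (illustration only): with
        //! A = (1, 256) and a packed B = (N/32, K/4, 32, 4) = (16, 64, 32, 4),
        //! B0 = 16, so the deduced C is (1, B0 * 32) = (1, 512).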
C = TensorLayout(TensorShape({A0, B0 * 32}), C.dtype); } else { auto do_deduce = [&](size_t pack_size) { megdnn_assert( @@ -132,6 +141,18 @@ void MatrixMulForward::check_exec( megdnn_assert(A0 == C0, "%s", errmsg().c_str()); megdnn_assert(B1 == C1, "%s", errmsg().c_str()); megdnn_assert(A1 == B0, "%s", errmsg().c_str()); + } else if (param().format == param::MatrixMul::Format::N32K4_DOT) { + size_t A0 = A.shape[0]; + size_t A1 = A.shape[1]; + size_t B2 = B.shape[2]; + size_t B3 = B.shape[3]; + megdnn_assert(!m_param.transposeA && !m_param.transposeB); + megdnn_assert(A0 == 1 && A1 % 4 == 0); + megdnn_assert(B.ndim == 4); + megdnn_assert(B2 == 32 && B3 == 4); + megdnn_assert_contiguous(A); + megdnn_assert_contiguous(B); + megdnn_assert_contiguous(C); } else { megdnn_assert_eq_size_t(A.ndim, 4_z); megdnn_assert_eq_size_t(B.ndim, 3_z); diff --git a/dnn/src/fallback/conv_bias/opr_impl.cpp b/dnn/src/fallback/conv_bias/opr_impl.cpp index 19606065..836dfc61 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.cpp +++ b/dnn/src/fallback/conv_bias/opr_impl.cpp @@ -442,9 +442,11 @@ ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param( megdnn_assert(0, "invalid conv format %d", static_cast(param().format)); } BiasMode bias_mode; + //! dst only channel BIAS is viewed as BROADCAST_CHANNEL_BIAS + bool dst_only_c = dst[0] == 1 && dst[spatial_pos] == 1 && dst[spatial_pos + 1] == 1; if (bias.ndim == 0) { bias_mode = BiasMode::NO_BIAS; - } else if (bias.eq_shape(dst)) { + } else if (bias.eq_shape(dst) && !dst_only_c) { bias_mode = BiasMode::BIAS; } else { //! just check the ndim, the detail shape check is in check_exec diff --git a/dnn/src/fallback/conv_bias/opr_impl.h b/dnn/src/fallback/conv_bias/opr_impl.h index 5c8e4058..c07893ad 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.h +++ b/dnn/src/fallback/conv_bias/opr_impl.h @@ -258,6 +258,9 @@ public: ARM_COMMON_CHANWISE_STRD1_NCHW44_S8, ARM_COMMON_CHANWISE_STRD2_NCHW44_S8, ARM_COMMON_DIRECT_NCHW_NCHW44_DOT_S8, + //! 
LARGE for large filter + ARM_COMMON_DOT_IM2COL_CHANWISE_LARGE_S8, + ARM_COMMON_DOT_DIRECT_CHANWISE_LARGE_S8, ARM_COMMON_DIRECT_STRD1_DOT_S8, ARM_COMMON_DIRECT_STRD2_DOT_S8, ARM_COMMON_DIRECT_NCHW44_DOT_S8, diff --git a/dnn/src/fallback/matrix_mul/opr_impl.cpp b/dnn/src/fallback/matrix_mul/opr_impl.cpp index ebc621b6..71385eca 100644 --- a/dnn/src/fallback/matrix_mul/opr_impl.cpp +++ b/dnn/src/fallback/matrix_mul/opr_impl.cpp @@ -195,11 +195,11 @@ MatrixMulImpl::KernSizeParam MatrixMulImpl::make_kern_size_param( kern_size_param.trB = param().transposeB; kern_size_param.compute_mode = param().compute_mode; kern_size_param.format = param().format; - - size_t pack_size = MatrixMulForward::pack_size(param().format); - kern_size_param.K *= pack_size; - kern_size_param.M *= pack_size; - + if (param().format != Param::Format::N32K4_DOT) { + size_t pack_size = MatrixMulForward::pack_size(param().format); + kern_size_param.K *= pack_size; + kern_size_param.M *= pack_size; + } return kern_size_param; } diff --git a/dnn/src/fallback/matrix_mul/opr_impl.h b/dnn/src/fallback/matrix_mul/opr_impl.h index 8165f57b..ea0b7064 100644 --- a/dnn/src/fallback/matrix_mul/opr_impl.h +++ b/dnn/src/fallback/matrix_mul/opr_impl.h @@ -122,6 +122,8 @@ public: ARM_COMMON_INT8X8X32_GEMV, ARM_COMMON_INT8X8X32_GEMV_MK4, ARM_COMMON_INT8X8X32_GEMV_MK4_DOT, + ARM_COMMON_INT8X8X32_GEVM_DOT, + ARM_COMMON_INT8X8X32_GEVM_N32K4_DOT, ARM_COMMON_F16_GEMV, ARM_COMMON_GEVM, #if MEGDNN_AARCH64 @@ -175,6 +177,7 @@ public: enum class AlgoSet : uint32_t { ALGO_TYPE_GEMM = 0, ALGO_TYPE_GEMV = 1, + ALGO_TYPE_GEVM = 2, }; enum class PackMode : uint32_t { diff --git a/dnn/src/naive/matrix_mul/matrix_mul_helper.h b/dnn/src/naive/matrix_mul/matrix_mul_helper.h index d0c8fbe0..a1d24f32 100644 --- a/dnn/src/naive/matrix_mul/matrix_mul_helper.h +++ b/dnn/src/naive/matrix_mul/matrix_mul_helper.h @@ -105,6 +105,34 @@ void run_matrix_mul_mk4_dot_tpl( template < typename itype, typename otype, bool transA, bool transB, typename comp_type = otype> +void run_matrix_mul_n32k4_dot_tpl( + const itype* A, const itype* B, otype* C, size_t M, size_t N, size_t K, + size_t LDA, size_t, size_t, const DType& A_type, const DType& B_type) { + Getter getterA(A_type), getterB(B_type); + megdnn_assert(!transA && !transB); + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; n += 32) { + comp_type res[32] = {comp_type(0)}; + for (size_t k = 0; k < K; k += 4) { + for (size_t i = 0; i < 32; i++) { + comp_type av, bv; + for (size_t j = 0; j < 4; j++) { + av = getterA(A[m * LDA + k + j]); + bv = getterA(B[n * K + k * 32 + i * 4 + j]); + res[i] += av * bv; + } + } + } + for (size_t i = 0; i < 32; i++) { + C[n + i] = res[i]; + } + } + } +} + +template < + typename itype, typename otype, bool transA, bool transB, + typename comp_type = otype> void run_matrix_mul_mk8_tpl( const itype* A, const itype* B, otype* C, size_t M, size_t N, size_t K, size_t LDA, size_t LDB, size_t LDC, const DType& A_type, const DType& B_type) { @@ -251,6 +279,10 @@ void dispatch_ta_tb( return run_matrix_mul_mk4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \ static_cast(A), static_cast(B), \ static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type, B_type); \ + } else if (format == param::MatrixMul::Format::N32K4_DOT) { \ + return run_matrix_mul_n32k4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \ + static_cast(A), static_cast(B), \ + static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type, B_type); \ } else if (format == param::MatrixMul::Format::MK8) { \ return 
run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \ static_cast(A), static_cast(B), \ diff --git a/dnn/test/arm_common/conv_bias.cpp b/dnn/test/arm_common/conv_bias.cpp index d77643a8..6a89a537 100644 --- a/dnn/test/arm_common/conv_bias.cpp +++ b/dnn/test/arm_common/conv_bias.cpp @@ -160,7 +160,7 @@ static void benchmark_convbias( .set_display(false); } auto nchw44_algo_regx = ".*(DIRECT|NCHW_NCHW44).*"; -#if MGB_ENBALE_DOT +#if MGB_ENABLE_DOT if (!is_fp32) { nchw44_algo_regx = ".*DOT.*"; } @@ -1626,7 +1626,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QINT8_STRIDE1_NCHW44) { #endif -#if MGB_ENBALE_DOT +#if MGB_ENABLE_DOT #if MEGDNN_WITH_BENCHMARK TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) { // have to remove preferred restrict in usable func before run the benchmark @@ -2011,6 +2011,80 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_NCHW44_DOT) { } } +TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_LARGE_KERN_NCHW_DOT) { + using namespace conv_bias; + + std::vector args; + auto run = [&](size_t group, size_t w, size_t h, size_t kernel, size_t stride, + NonlineMode nonline_mode) { + size_t p = kernel / 2; + if (w + 2 * p < kernel || h + 2 * p < kernel) + return; + param::ConvBias param; + param.stride_h = stride; + param.stride_w = stride; + param.pad_h = p; + param.pad_w = p; + param.nonlineMode = nonline_mode; + param.format = param::ConvBias::Format::NCHW; + param.sparse = ConvBiasForward::Param::Sparse::GROUP; + + //! channel bias + args.emplace_back( + param, TensorShape{1, group, h, w}, + TensorShape{group, 1, 1, kernel, kernel}, TensorShape{1, group, 1, 1}); + }; + + run(64, 64, 64, 9, 1, NonlineMode::RELU); + run(64, 40, 40, 9, 2, NonlineMode::RELU); + run(64, 20, 20, 9, 1, NonlineMode::RELU); + + constexpr size_t RUN = 120; + Benchmarker benchmark0(handle()); + benchmark0.set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.5f)) + .set_dtype(2, dtype::QuantizedS32(6.25f)) + .set_dtype(4, dtype::QuantizedS8(60.25f)); + benchmark0.set_display(false); + benchmark0.set_times(RUN); + benchmark0.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker( + "ARMDOTS8_DIRECT_CHANWISE_LARGE")); + + Benchmarker benchmark1(handle()); + benchmark1.set_dtype(0, dtype::QuantizedS8(2.5f)) + .set_dtype(1, dtype::QuantizedS8(2.5f)) + .set_dtype(2, dtype::QuantizedS32(6.25f)) + .set_dtype(4, dtype::QuantizedS8(60.25f)); + benchmark1.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker( + "ARMDOTS8_IM2COL_CHANWISE_LARGE")); + benchmark1.set_display(false); + benchmark1.set_times(RUN); + + for (auto&& arg : args) { + TensorLayout dst_layout; + auto opr = handle()->create_operator(); + opr->param() = arg.param; + opr->deduce_layout( + {arg.src, dtype::Int8()}, {arg.filter, dtype::Int8()}, + {arg.bias, dtype::Int32()}, {}, dst_layout); + //! 
dst.nr_elems * FH * FW * 2 + float computations = + dst_layout.total_nr_elems() * arg.filter[3] * arg.filter[4] * 2.0 / 1e6; + + auto used0 = benchmark0.set_param(arg.param).exec( + {arg.src, arg.filter, arg.bias, {}, {}}) / + RUN; + auto used1 = benchmark1.set_param(arg.param).exec( + {arg.src, arg.filter, arg.bias, {}, {}}) / + RUN; + + printf("%s %s: Direct use: %f ms %f Gflops im2col: %f ms %f GFlops " + "speedup: %f\n", + arg.src.to_string().c_str(), arg.filter.to_string().c_str(), used0, + computations / used0, used1, computations / used1, used1 / used0); + } +} + #endif #endif @@ -2194,7 +2268,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_QUANTIZEDSYM) { dtype::QuantizedS8 stype(2.5f); dtype::QuantizedS32 dtype(6.25f); #if MEGDNN_AARCH64 -#if MGB_ENBALE_DOT +#if MGB_ENABLE_DOT benchmark_conv1x1( "AARCH64_INT8X8X32_K8X12X4_DOTPROD", handle(), stype, dtype, dtype, dtype); #else @@ -2212,7 +2286,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_QUANTIZEDASYM) { dtype::QuantizedS32 dtype(1.2 * 1.2); #if MEGDNN_AARCH64 -#if MGB_ENBALE_DOT +#if MGB_ENABLE_DOT benchmark_conv1x1( "AARCH64_QUINT8_K8X8X4_DOTPROD", handle(), stype, dtype, dtype, dtype); #else diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index 0c78be63..f08bb174 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -136,6 +136,84 @@ std::vector get_nchw44_channel_wise_args( return args; } +std::vector get_channel_wise_args( + std::vector kernel, size_t stride, bool no_bias, bool no_nonlinemode, + bool no_full_bias, bool support_relu) { + using namespace conv_bias; + using Param = param::ConvBias; + using NLMode = param::ConvBias::NonlineMode; + std::vector args; + + auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel, + size_t stride, NLMode nlmode, bool pad) { + Param param; + param.stride_h = stride; + param.stride_w = stride; + if (pad) { + param.pad_h = kernel / 2; + param.pad_w = kernel / 2; + } else { + param.pad_h = 0; + param.pad_w = 0; + } + param.nonlineMode = nlmode; + param.format = param::ConvBias::Format::NCHW; + param.sparse = param::ConvBias::Sparse::GROUP; + + args.emplace_back( + param, TensorShape{n, group, h, w}, + TensorShape{group, 1, 1, kernel, kernel}, TensorShape{}); + if (!no_bias) { + args.emplace_back( + param, TensorShape{n, group, h, w}, + TensorShape{group, 1, 1, kernel, kernel}, + TensorShape{1, group, 1, 1}); + } + if (!no_full_bias) { + args.emplace_back( + param, TensorShape{n, group, h, w}, + TensorShape{group, 1, 1, kernel, kernel}, + TensorShape{ + n, group, (h + 2 * param.pad_w - kernel) / stride + 1, + (w + 2 * param.pad_w - kernel) / stride + 1}); + } + }; + + std::vector nonlinemode = {NLMode::IDENTITY}; + if (!no_nonlinemode) { + nonlinemode.emplace_back(NLMode::RELU); + nonlinemode.emplace_back(NLMode::H_SWISH); + } else if (support_relu) { + nonlinemode.emplace_back(NLMode::RELU); + } + + for (size_t n : {1, 2}) { + for (auto nlmode : nonlinemode) { + for (bool pad : {true}) { + for (size_t group : {1, 3, 7}) { + for (size_t size : {4, 6, 7, 9, 16, 20, 32, 55}) { + for (size_t kern : kernel) { + pack(n, group, size, size, kern, stride, nlmode, pad); + } + } + } + } + for (bool pad : {false}) { + for (size_t group : {7}) { + for (size_t size : {37}) { + for (size_t kern : kernel) { + if (size < kern) + continue; + pack(n, group, size, size, kern, stride, nlmode, pad); + } + } + } + } + } + } + return args; +} + std::vector 
get_nchw88_channel_wise_args( std::vector kernel, size_t stride, bool no_bias, bool no_nonlinemode, bool no_full_bias) { @@ -226,7 +304,7 @@ void checker_conv_bias_qint8x8x8( .set_rng(1, &rng) .set_rng(2, &rng); for (auto&& arg : args) { - checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}}); + checker.set_param(arg.param).execs({arg.src, arg.filter, arg.bias, {}, {}}); } } void checker_conv_bias_qint8x8x32( @@ -532,6 +610,30 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2) { /****************************dot qint8 direct*************************/ #if MGB_ENABLE_DOT +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_DIRECT_LARGE_S1) { + checker_conv_bias_qint8x8x8( + get_channel_wise_args({9}, 1, false, true, true, true), handle(), + "ARMDOTS8_DIRECT_CHANWISE_LARGE"); +} + +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_DIRECT_LARGE_S2) { + checker_conv_bias_qint8x8x8( + get_channel_wise_args({9}, 2, false, true, true, true), handle(), + "ARMDOTS8_DIRECT_CHANWISE_LARGE"); +} + +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_IM2COL_LARGE_S1) { + checker_conv_bias_qint8x8x8( + get_channel_wise_args({9}, 1, false, true, true, true), handle(), + "ARMDOTS8_IM2COL_CHANWISE_LARGE"); +} + +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_IM2COL_LARGE_S2) { + checker_conv_bias_qint8x8x8( + get_channel_wise_args({9}, 2, false, true, true, true), handle(), + "ARMDOTS8_IM2COL_CHANWISE_LARGE"); +} + TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) { auto args = get_nchw44_conv_bias_args( {2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 2, false, true); diff --git a/dnn/test/arm_common/matrix_mul.cpp b/dnn/test/arm_common/matrix_mul.cpp index e16831e3..fb934f41 100644 --- a/dnn/test/arm_common/matrix_mul.cpp +++ b/dnn/test/arm_common/matrix_mul.cpp @@ -219,6 +219,113 @@ TEST_F(ARM_COMMON, QINT8x8x32_GEMV_MK4_DOT) { for (size_t K : {4, 8, 12, 16, 20, 24, 256, 1024}) run(M, K, 1); } + +TEST_F(ARM_COMMON, QINT8x8x32_GEVM_DOT) { + Checker checker(handle()); + using Param = MatrixMul::Param; + auto algo_ck = AlgoChecker("ARM_COMMON_INT8X8X32_GEVM_DOT"); + + checker.set_before_exec_callback(algo_ck); + + std::unique_ptr rng = std::make_unique(-30, 30); + checker.set_rng(0, rng.get()).set_rng(1, rng.get()); + Param param; + param.format = Param::Format::DEFAULT; + param.transposeA = false; + param.transposeB = false; + + auto run = [&](size_t M, size_t N, size_t K) { + TensorShape A, B; + A = TensorShape{M, K}; + B = TensorShape{K, N}; + checker.set_param(param) + .set_dtype(0, dtype::Int8()) + .set_dtype(1, dtype::Int8()) + .set_dtype(2, dtype::Int32()) + .execs({A, B, {}}); + }; + run(1, 32, 4); + for (int n = 7; n < 43; n += 3) { + for (int k = 1; k < 33; k += 3) { + run(1, n, k); + } + } +} + +TEST_F(ARM_COMMON, QINT8x8x32_GEVM_N32K4_DOT) { + Checker checker(handle()); + using Param = MatrixMul::Param; + auto algo_ck = AlgoChecker("ARM_COMMON_INT8X8X32_GEVM_N32K4_DOT"); + checker.set_before_exec_callback(algo_ck); + + std::unique_ptr rng = std::make_unique(-30, 30); + checker.set_rng(0, rng.get()).set_rng(1, rng.get()); + Param param; + param.format = Param::Format::N32K4_DOT; + param.transposeA = false; + param.transposeB = false; + + auto run = [&](size_t M, size_t N, size_t K) { + TensorShape A, B; + A = TensorShape{M, K}; + B = TensorShape{N / 32, K / 4, 32, 4}; + checker.set_param(param) + .set_dtype(0, dtype::Int8()) + .set_dtype(1, dtype::Int8()) + .set_dtype(2, dtype::Int32()) + .execs({A, B, {}}); + }; + run(1, 32, 4); + for (int n = 32; n < 65; n += 32) { + for 
(int k = 4; k < 39; k += 4) { + run(1, n, k); + } + } +} + +#if MEGDNN_WITH_BENCHMARK +TEST_F(ARM_COMMON, BENCHMARK_QINT8x8x32_GEVM_N32K4_DOT) { + using Param = MatrixMul::Param; + auto algo_ck = AlgoChecker("ARM_COMMON_INT8X8X32_GEVM_N32K4_DOT"); + + std::unique_ptr rng = std::make_unique(-30, 30); + Param param; + param.format = Param::Format::N32K4_DOT; + param.transposeA = false; + param.transposeB = false; + + constexpr size_t RUNS = 2000; + + Benchmarker benchmarker_int(handle()); + benchmarker_int.set_times(RUNS) + .set_dtype(0, dtype::Int8{}) + .set_dtype(1, dtype::Int8{}) + .set_dtype(2, dtype::Int32{}) + .set_param(param) + .set_before_exec_callback(algo_ck) + .set_display(false); + Benchmarker benchmarker_float(handle()); + benchmarker_float.set_display(false).set_times(RUNS); + + auto bench = [&](size_t M, size_t N, size_t K) { + auto int_used = + benchmarker_int.exec({{M, K}, {N / 32, K / 4, 32, 4}, {}}) / RUNS; + auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS; + float computations = 2.f * M * K * N * 1e-6; + float through_put = (M * K + N * K + M * N) * 1e-6; + printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops int: %f ms " + "%f Gflops speedup: %f, through put %f G\n", + M, K, N, float_used, computations / float_used, int_used, + computations / int_used, float_used / int_used, through_put / int_used); + }; + + bench(1, 256, 512); + bench(1, 256, 1024); + bench(1, 512, 512); + bench(1, 512, 1024); +} +#endif + #endif TEST_F(ARM_COMMON, QINT8x8x32_GEVM) {