diff --git a/dnn/src/fallback/conv_bias/im2col/factory.h b/dnn/src/fallback/conv_bias/im2col/factory.h
index 528917e8..9aba640b 100644
--- a/dnn/src/fallback/conv_bias/im2col/factory.h
+++ b/dnn/src/fallback/conv_bias/im2col/factory.h
@@ -41,48 +41,56 @@ enum class StrategyType : uint32_t {
 };
 
 struct StrategyHashParam {
-    fallback::ConvBiasImpl::NCBKernSizeParam param;
-    param::ConvBias::Format format;
-    fallback::MatrixMulImpl::AlgoBase::PackMode packmode;
+    bool is_xcorr;
+    bool is_square;  //! kernel_h == kernel_w, stride_h == stride_w
     size_t block_m;
     size_t block_n;
     size_t block_k;
+    size_t kernel;
+    size_t stride;
+
+    fallback::ConvBiasImpl::NCBKernSizeParam param;
+    param::ConvBias::Format format;
+    fallback::MatrixMulImpl::AlgoBase::PackMode packmode;
 };
 
 struct StrategyHashParamHash {
-    std::size_t operator()(const StrategyHashParam& sparam) const {
-        constexpr size_t base = 1;  //! avoid hashkey is zero
-        std::size_t result =
-                static_cast<std::size_t>(sparam.param.src_type.enumv()) + base;
+    uint64_t operator()(const StrategyHashParam& sparam) const {
+        constexpr uint64_t base = 1;  //! avoid hashkey is zero
+        uint64_t result =
+                static_cast<uint64_t>(sparam.param.src_type.enumv()) + base;
         result = result ^
-                 ((static_cast<std::size_t>(sparam.param.dst_type.enumv()) +
-                   base)
+                 ((static_cast<uint64_t>(sparam.param.dst_type.enumv()) + base)
                   << 3);
         result = result ^
-                 ((static_cast<std::size_t>(sparam.param.filter_type.enumv()) +
+                 ((static_cast<uint64_t>(sparam.param.filter_type.enumv()) +
                    base)
                   << 6);
         result = result ^
-                 ((static_cast<std::size_t>(sparam.param.bias_type.enumv()) +
-                   base)
+                 ((static_cast<uint64_t>(sparam.param.bias_type.enumv()) + base)
                   << 9);
+        result = result ^ ((static_cast<uint64_t>(sparam.format) + base) << 12);
         result = result ^
-                 ((static_cast<std::size_t>(sparam.format) + base) << 12);
-        result = result ^
-                 ((static_cast<std::size_t>(sparam.packmode) + base) << 15);
-        result = result ^
-                 ((static_cast<std::size_t>(sparam.block_m) + base) << 18);
+                 ((static_cast<uint64_t>(sparam.packmode) + base) << 15);
+        result =
+                result ^ ((static_cast<uint64_t>(sparam.block_m) + base) << 18);
+        result =
+                result ^ ((static_cast<uint64_t>(sparam.block_n) + base) << 22);
+        result =
+                result ^ ((static_cast<uint64_t>(sparam.block_k) + base) << 26);
+        result = result ^ ((static_cast<uint64_t>(sparam.kernel) + base) << 30);
+        result = result ^ ((static_cast<uint64_t>(sparam.stride) + base) << 34);
         result = result ^
-                 ((static_cast<std::size_t>(sparam.block_n) + base) << 22);
+                 ((static_cast<uint64_t>(sparam.is_square) + base) << 35);
         result = result ^
-                 ((static_cast<std::size_t>(sparam.block_k) + base) << 26);
+                 ((static_cast<uint64_t>(sparam.is_xcorr) + base) << 36);
         return result;
     };
 };
 
 struct StrategyHashParamEqual {
-    std::size_t operator()(const StrategyHashParam& param1,
-                           const StrategyHashParam& param2) const {
+    bool operator()(const StrategyHashParam& param1,
+                    const StrategyHashParam& param2) const {
         bool flags = true;
         flags = param1.param.src_type == param2.param.src_type && flags;
         flags = param1.param.filter_type == param2.param.filter_type && flags;
@@ -93,6 +101,10 @@ struct StrategyHashParamEqual {
         flags = param1.block_m == param2.block_m && flags;
         flags = param1.block_n == param2.block_n && flags;
         flags = param1.block_k == param2.block_k && flags;
+        flags = param1.kernel == param2.kernel && flags;
+        flags = param1.stride == param2.stride && flags;
+        flags = param1.is_square == param2.is_square && flags;
+        flags = param1.is_xcorr == param2.is_xcorr && flags;
         return flags;
     };
 };
@@ -484,10 +496,18 @@ Strategy* StrategyDelegationStorage::get(
     sparam.block_m = block_m;
     sparam.block_n = block_n;
     sparam.block_k = block_k;
+    sparam.kernel = param.filter_meta.spatial[0];
+    sparam.stride = param.filter_meta.stride[0];
+    //! only spatial[0]/stride[0] are stored above, so is_square must tell
+    //! non-square kernels and strides apart in the cache key
+    sparam.is_square =
+            param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
+            param.filter_meta.stride[0] == param.filter_meta.stride[1];
+    sparam.is_xcorr = param.filter_meta.should_flip;
+    MEGDNN_LOCK_GUARD(m_mtx);
     if (map_strategys.find(sparam) == map_strategys.end()) {
-        MEGDNN_LOCK_GUARD(m_mtx);
-        auto strategy = Factory::make_strategy(matmul_algo, packmode,
-                                               param, stype);
+        auto strategy =
+                Factory::make_strategy(matmul_algo, packmode, param, stype);
         map_strategys[sparam] = std::move(strategy);
     }
     return static_cast<Strategy*>(map_strategys[sparam].get());
diff --git a/dnn/src/fallback/conv_bias/im2col/strategy_base.h b/dnn/src/fallback/conv_bias/im2col/strategy_base.h
index ab372680..1efbdff9 100644
--- a/dnn/src/fallback/conv_bias/im2col/strategy_base.h
+++ b/dnn/src/fallback/conv_bias/im2col/strategy_base.h
@@ -293,3 +293,5 @@ public:
             WorkspaceBundle bundle_thread, const StrategyParam& sparam);
 };
 }  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp b/dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp
old mode 100755
new mode 100644
diff --git a/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44.cpp b/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44.cpp
new file mode 100644
index 00000000..88980d4c
--- /dev/null
+++ b/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44.cpp
@@ -0,0 +1,15 @@
+/**
+ * \file dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "src/fallback/conv_bias/im2col/strategy_base.h"
+
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp
index 6c0553f6..28bc67d7 100644
--- a/dnn/test/arm_common/conv_bias_multi_thread.cpp
+++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp
@@ -1209,6 +1209,22 @@ TEST_F(ARM_COMMON_MULTI_THREADS,
 #undef cb
 }
 
+#if MEGDNN_AARCH64
+TEST_F(ARM_COMMON_MULTI_THREADS,
+       CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_FUSE) {
+    UniformIntRNG rng{-50, 50};
+
+#define cb(name)                                                               \
+    checker_conv_bias(get_nchw44_conv_bias_args({3}, 1), handle(), &rng,       \
+                      epsilon, dtype::QuantizedS8(2.5f),                       \
+                      dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),    \
+                      dtype::QuantizedS8(60.25f), name);
+    float epsilon = 0.001;
+    cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
+#undef cb
+}
+#endif
+
 #endif
 #endif
 