Browse Source

feat(dnn/fallback): fuse im2col and packb 4x4x16

GitOrigin-RevId: 123920899d
tags/v0.5.0
Megvii Engine Team Xinran Xu 5 years ago
parent
commit
ca855d8d21
5 changed files with 74 additions and 24 deletions
  1. +41
    -24
      dnn/src/fallback/conv_bias/im2col/factory.h
  2. +2
    -0
      dnn/src/fallback/conv_bias/im2col/strategy_base.h
  3. +0
    -0
      dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp
  4. +15
    -0
      dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44.cpp
  5. +16
    -0
      dnn/test/arm_common/conv_bias_multi_thread.cpp

+ 41
- 24
dnn/src/fallback/conv_bias/im2col/factory.h View File

@@ -41,48 +41,56 @@ enum class StrategyType : uint32_t {
};

struct StrategyHashParam {
fallback::ConvBiasImpl::NCBKernSizeParam param;
param::ConvBias::Format format;
fallback::MatrixMulImpl::AlgoBase::PackMode packmode;
bool is_xcorr;
bool is_square; //! kernel_h == kernel_w, stride_h = stride_w
size_t block_m;
size_t block_n;
size_t block_k;
size_t kernel;
size_t stride;

fallback::ConvBiasImpl::NCBKernSizeParam param;
param::ConvBias::Format format;
fallback::MatrixMulImpl::AlgoBase::PackMode packmode;
};

struct StrategyHashParamHash {
std::size_t operator()(const StrategyHashParam& sparam) const {
constexpr size_t base = 1; //! avoid hashkey is zero
std::size_t result =
static_cast<std::size_t>(sparam.param.src_type.enumv()) + base;
uint64_t operator()(const StrategyHashParam& sparam) const {
constexpr uint64_t base = 1; //! avoid hashkey is zero
uint64_t result =
static_cast<uint64_t>(sparam.param.src_type.enumv()) + base;
result = result ^
((static_cast<std::size_t>(sparam.param.dst_type.enumv()) +
base)
((static_cast<uint64_t>(sparam.param.dst_type.enumv()) + base)
<< 3);
result = result ^
((static_cast<std::size_t>(sparam.param.filter_type.enumv()) +
((static_cast<uint64_t>(sparam.param.filter_type.enumv()) +
base)
<< 6);
result = result ^
((static_cast<std::size_t>(sparam.param.bias_type.enumv()) +
base)
((static_cast<uint64_t>(sparam.param.bias_type.enumv()) + base)
<< 9);
result = result ^ ((static_cast<uint64_t>(sparam.format) + base) << 12);
result = result ^
((static_cast<std::size_t>(sparam.format) + base) << 12);
result = result ^
((static_cast<std::size_t>(sparam.packmode) + base) << 15);
result = result ^
((static_cast<std::size_t>(sparam.block_m) + base) << 18);
((static_cast<uint64_t>(sparam.packmode) + base) << 15);
result =
result ^ ((static_cast<uint64_t>(sparam.block_m) + base) << 18);
result =
result ^ ((static_cast<uint64_t>(sparam.block_n) + base) << 22);
result =
result ^ ((static_cast<uint64_t>(sparam.block_k) + base) << 26);
result = result ^ ((static_cast<uint64_t>(sparam.kernel) + base) << 30);
result = result ^ ((static_cast<uint64_t>(sparam.stride) + base) << 34);
result = result ^
((static_cast<std::size_t>(sparam.block_n) + base) << 22);
((static_cast<uint64_t>(sparam.is_square) + base) << 35);
result = result ^
((static_cast<std::size_t>(sparam.block_k) + base) << 26);
((static_cast<uint64_t>(sparam.is_xcorr) + base) << 36);
return result;
};
};

struct StrategyHashParamEqual {
std::size_t operator()(const StrategyHashParam& param1,
const StrategyHashParam& param2) const {
bool operator()(const StrategyHashParam& param1,
const StrategyHashParam& param2) const {
bool flags = true;
flags = param1.param.src_type == param2.param.src_type && flags;
flags = param1.param.filter_type == param2.param.filter_type && flags;
@@ -93,6 +101,10 @@ struct StrategyHashParamEqual {
flags = param1.block_m == param2.block_m && flags;
flags = param1.block_n == param2.block_n && flags;
flags = param1.block_k == param2.block_k && flags;
flags = param1.kernel == param2.kernel && flags;
flags = param1.stride == param2.stride && flags;
flags = param1.is_square == param2.is_square && flags;
flags = param1.is_xcorr == param2.is_xcorr && flags;
return flags;
};
};
@@ -484,10 +496,15 @@ Strategy* StrategyDelegationStorage::get(
sparam.block_m = block_m;
sparam.block_n = block_n;
sparam.block_k = block_k;
sparam.kernel = param.filter_meta.spatial[0];
sparam.stride = param.filter_meta.stride[0];
sparam.is_square =
param.filter_meta.spatial[0] == param.filter_meta.spatial[0];
sparam.is_xcorr = param.filter_meta.should_flip;
MEGDNN_LOCK_GUARD(m_mtx);
if (map_strategys.find(sparam) == map_strategys.end()) {
MEGDNN_LOCK_GUARD(m_mtx);
auto strategy = Factory::make_strategy(matmul_algo, packmode,
param, stype);
auto strategy =
Factory::make_strategy(matmul_algo, packmode, param, stype);
map_strategys[sparam] = std::move(strategy);
}
return static_cast<Strategy*>(map_strategys[sparam].get());


+ 2
- 0
dnn/src/fallback/conv_bias/im2col/strategy_base.h View File

@@ -293,3 +293,5 @@ public:
WorkspaceBundle bundle_thread, const StrategyParam& sparam);
};
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 0
- 0
dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp View File


+ 15
- 0
dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44.cpp View File

@@ -0,0 +1,15 @@
/**
* \file dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/fallback/conv_bias/im2col/strategy_base.h"


// vim: syntax=cpp.doxygen

+ 16
- 0
dnn/test/arm_common/conv_bias_multi_thread.cpp View File

@@ -1209,6 +1209,22 @@ TEST_F(ARM_COMMON_MULTI_THREADS,
#undef cb
}

#if MEGDNN_AARCH64
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_FUSE) {
UniformIntRNG rng{-50, 50};

#define cb(name) \
checker_conv_bias(get_nchw44_conv_bias_args({3}, 1), handle(), &rng, \
epsilon, dtype::QuantizedS8(2.5f), \
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
dtype::QuantizedS8(60.25f), name);
float epsilon = 0.001;
cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
#undef cb
}
#endif

#endif
#endif



Loading…
Cancel
Save