Browse Source

fix(mbg/arm_common): fix nchw44-dot misc issue

GitOrigin-RevId: f870ad964c
tags/v0.5.0
Megvii Engine Team Xu Xinran 5 years ago
parent
commit
02abc36ea6
7 changed files with 25 additions and 14 deletions
  1. +1
    -1
      dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp
  2. +2
    -2
      dnn/src/arm_common/conv_bias/opr_impl.cpp
  3. +13
    -8
      dnn/test/arm_common/conv_bias_multi_thread.cpp
  4. +6
    -0
      sdk/load-and-run/src/mgblar.cpp
  5. +0
    -1
      src/core/include/megbrain/utils/persistent_cache.h
  6. +0
    -1
      src/opr/impl/dnn/convolution.cpp
  7. +3
    -1
      src/plugin/impl/opr_footprint.cpp

+ 1
- 1
dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp View File

@@ -182,7 +182,7 @@ bool ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::usable(
bool ok_type = ((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
(param.dst_type.enumv() == DTypeEnum::QuantizedS8))) &&
- (fm.format == param::Convolution::Format::NCHW44);
+ (fm.format == param::Convolution::Format::NCHW44_DOT);
bool ok_src_dst = (oc % 4 == 0 && oc >= 4 && ic < 4);
bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] &&
(fh == 2 || fh == 3 || fh == 5 || fh == 7);


+ 2
- 2
dnn/src/arm_common/conv_bias/opr_impl.cpp View File

@@ -55,7 +55,6 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44;

#if __ARM_FEATURE_DOTPROD
- AlgoDotS8DirectNCHWNCHW44 ds8_direct_stride2_nchw_nchw44;
AlgoDotS8DirectStride1 ds8_direct_stride1_large_group{true};
AlgoDotS8DirectStride1 ds8_direct_stride1_small_group{false};
AlgoDotS8DirectStride2 ds8_direct_stride2_large_group{true};
@@ -66,6 +65,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoDotU8DirectStride2 du8_direct_stride2_small_group{false};

AlgoDotS8Direct_NCHW44 ds8_direct_nchw44;
+ AlgoDotS8DirectNCHWNCHW44 ds8_direct_nchw_nchw44;
#endif

AlgoF32DirectNCHWNCHW44 f32_direct_stride2_nchw_nchw44;
@@ -96,7 +96,6 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
public:
AlgoPack() {
#if __ARM_FEATURE_DOTPROD
- direct_algos.emplace_back(&ds8_direct_stride2_nchw_nchw44);
direct_algos.emplace_back(&ds8_direct_stride1_large_group);
direct_algos.emplace_back(&ds8_direct_stride1_small_group);
direct_algos.emplace_back(&ds8_direct_stride2_large_group);
@@ -107,6 +106,7 @@ public:
direct_algos.emplace_back(&du8_direct_stride2_small_group);

direct_algos.emplace_back(&ds8_direct_nchw44);
+ direct_algos.emplace_back(&ds8_direct_nchw_nchw44);
#endif
direct_algos.emplace_back(&qu8_direct_stride2_large_group);
direct_algos.emplace_back(&qu8_direct_stride2_small_group);


+ 13
- 8
dnn/test/arm_common/conv_bias_multi_thread.cpp View File

@@ -582,14 +582,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_SMALL_GROUP) {
/****************************dot qint8 direct*************************/
#if __ARM_FEATURE_DOTPROD
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) {
- checker_conv_bias_qint8x8x8(
- get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false,
- true),
- handle(), "ARMDOTS8_NCHW_NCHW44");
- checker_conv_bias_qint8x8x8(
- get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false,
- true),
- handle(), "ARMDOTS8_NCHW_NCHW44");
+ auto args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false,
+ true);
+ for (auto&& arg : args) {
+ arg.param.format = param::ConvBias::Format::NCHW44_DOT;
+ }
+ checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
+
+ args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false,
+ true);
+ for (auto&& arg : args) {
+ arg.param.format = param::ConvBias::Format::NCHW44_DOT;
+ }
+ checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) {


+ 6
- 0
sdk/load-and-run/src/mgblar.cpp View File

@@ -987,6 +987,12 @@ Args Args::from_argv(int argc, char **argv) {
cb(nchw32);
cb(nhwcd4);
#undef cb
+ if (!strcmp(argv[i], "--enable-nchw44-dot")) {
+ mgb_log_warn("enable-nchw44-dot optimization");
+ graph_opt.graph_opt.enable_nchw44_dot();
+ continue;
+ }
+
if (!strcmp(argv[i], "--enable-fuse-conv-bias-nonlinearity")) {
mgb_log_warn("enable fuse-conv-bias-nonlinearity optimization");
graph_opt.graph_opt.enable_fuse_conv_bias_nonlinearity();


+ 0
- 1
src/core/include/megbrain/utils/persistent_cache.h View File

@@ -94,7 +94,6 @@ namespace mgb {
m_param{param}, m_param_size{param_size}
{
}

//! build a blob representation to be used as cache key
PersistentCache::Blob build_blob() const;
};


+ 0
- 1
src/opr/impl/dnn/convolution.cpp View File

@@ -611,7 +611,6 @@ AlgoChooserProfileCache::Result AlgoChooser<Opr>::get_profile_result(
AlgoChooserProfileCache::Key cache_key{origin_layouts.data(),
origin_layouts.size(), &origin_param,
sizeof(origin_param)};

{
auto&& rst = cache.get(cache_key);
if (rst.valid())


+ 3
- 1
src/plugin/impl/opr_footprint.cpp View File

@@ -107,7 +107,8 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
src_shape[1] / group * 2;
return hybird_nchwx ? computation : computation * 8;
}
- if (param.format == Param::Format::NCHW44) {
+ if (param.format == Param::Format::NCHW44 ||
+ param.format == Param::Format::NCHW44_DOT) {
//! if channel wise weight layout is {group/4, FH, FW, 1, 1, 4}
if (filter_shape[1] == 1 && filter_shape[2] == 1) {
group *= 4;
@@ -145,6 +146,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
if (param.format == Param::Format::NCHW4 ||
param.format == Param::Format::NCHW88 ||
param.format == Param::Format::NCHW44 ||
+ param.format == Param::Format::NCHW44_DOT ||
param.format == Param::Format::NCHW32) {
return eval_conv_computation_nchwx();
}


Loading…
Cancel
Save