Browse Source

perf(dnn/arm_common): add nchw44 winograd f73

GitOrigin-RevId: 8ed98ab85b
tags/v1.0.0-rc1
Megvii Engine Team 5 years ago
parent
commit
f6018422fd
13 changed files with 756 additions and 11 deletions
  1. +45
    -0
      dnn/src/arm_common/conv_bias/fp32/algos.cpp
  2. +16
    -0
      dnn/src/arm_common/conv_bias/fp32/algos.h
  3. +3
    -0
      dnn/src/arm_common/conv_bias/fp32/strategy.h
  4. +1
    -2
      dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp
  5. +587
    -0
      dnn/src/arm_common/conv_bias/fp32/strategy_f73_mk4_nchw44.cpp
  6. +7
    -0
      dnn/src/arm_common/conv_bias/opr_impl.cpp
  7. +1
    -0
      dnn/src/arm_common/conv_bias/opr_impl.h
  8. +4
    -0
      dnn/src/arm_common/winograd_filter_preprocess/opr_impl.cpp
  9. +37
    -0
      dnn/src/common/unroll_macro.h
  10. +6
    -1
      dnn/src/naive/winograd_filter_preprocess/opr_impl.cpp
  11. +22
    -7
      dnn/test/arm_common/conv_bias.cpp
  12. +26
    -0
      dnn/test/arm_common/conv_bias_multi_thread.cpp
  13. +1
    -1
      dnn/test/arm_common/matrix_mul.cpp

+ 45
- 0
dnn/src/arm_common/conv_bias/fp32/algos.cpp View File

@@ -331,6 +331,51 @@ MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoFP32WinogradF63_4x4_NCHW44,
megdnn_arm_common_winograd_fp32,
param::MatrixMul::Format::MK4);

/* =================== AlgoFP32WinogradF73_4x4_NCHW44 ===================== */

//! Check whether the Winograd F(7,3) MK4 NCHW44 algorithm can run on the
//! given problem: float32, 3x3 kernel, stride 1, no dilation/flip, channels
//! divisible by 4, and an underlying matmul algo that works without packing.
bool ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44::usable(
        const NCBKernSizeParam& param,
        AlgoSelectionStrategy /*algo_selection_strategy*/) const {
    MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32,
                 midout_iv("AlgoFP32WinogradF73_4x4_NCHW44"_hash)) {
        // MK4 layout packs both input and output channels in groups of 4
        if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0)
            return false;
        using Strategy = winograd::winograd_F73_mk4_f_nchw44;
        Strategy strategy(param.src_type, param.filter_type, param.dst_type);
        // build the winograd conv wrapper only to derive the matmul kern
        // param, so the underlying matmul algo can be queried for usability
        auto&& matmul_param =
                megdnn::winograd::ConvBias<Strategy,
                                           param::MatrixMul::Format::MK4>(
                        strategy, m_tile_size, param)
                        .get_matmul_kern_param(param);
        // NO_PACK matmul is required; the filter is either plain NCHW44 or
        // already transformed (NCHW44_WINOGRAD with output block 7 + MK4)
        return m_matmul_algo->usable(matmul_param) &&
               m_matmul_algo->packmode() ==
                       fallback::MatrixMulImpl::AlgoBase::PackMode::NO_PACK &&
               (param.filter_meta.format == param::ConvBias::Format::NCHW44 ||
                (param.filter_meta.format ==
                         param::ConvBias::Format::NCHW44_WINOGRAD &&
                 param.output_block_size == 7 &&
                 param.winograd_matmul_format ==
                         param::MatrixMul::Format::MK4)) &&
               !param.filter_meta.should_flip &&
               (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
                param.filter_meta.spatial[0] == 3) &&
               (param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
                param.filter_meta.stride[0] == 1) &&
               (param.filter_meta.dilation[0] ==
                        param.filter_meta.dilation[1] &&
                param.filter_meta.dilation[0] == 1) &&
               param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
               param.src_type.enumv() == DTypeEnum::Float32;
    }
    MIDOUT_END();
    return false;
}

MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoFP32WinogradF73_4x4_NCHW44,
winograd::winograd_F73_mk4_f_nchw44,
megdnn_arm_common_winograd_fp32,
param::MatrixMul::Format::MK4);

/* ===================== direct algo ===================== */
MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_kimpl);



+ 16
- 0
dnn/src/arm_common/conv_bias/fp32/algos.h View File

@@ -124,6 +124,22 @@ public:
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
};

//! Winograd F(7,3) conv-bias algorithm for float32 NCHW44 (MK4 matmul),
//! delegating the batched GEMM to the given fallback matmul algorithm.
class ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44 final : public AlgoBase {
public:
    AlgoFP32WinogradF73_4x4_NCHW44(
            fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size)
            : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}

    const char* name() const override {
        // name is built lazily and cached; {4, 7, tile} encodes
        // pack_size=4, output_block_size=7 and the tile size
        if (!m_name.empty()) {
            return m_name.c_str();
        }
        m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
                m_matmul_algo->name(), {4, 7, m_tile_size},
                param::ConvBias::Format::NCHW44);
        return m_name.c_str();
    }
    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
};
// ================================================================= //

class ConvBiasImpl::AlgoF32Direct final : public AlgoBase {


+ 3
- 0
dnn/src/arm_common/conv_bias/fp32/strategy.h View File

@@ -38,6 +38,9 @@ MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 2, 3, 4, 4,

MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 4, 4,
winograd_F63_mk4_f_nchw44)

MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 7, 3, 4, 4,
winograd_F73_mk4_f_nchw44)
} // namespace winograd
} // namespace arm_common
} // namespace megdnn


+ 1
- 2
dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp View File

@@ -1,6 +1,5 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/strategy_f
* 63_mk4_nchw44.cpp
* \file dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.


+ 587
- 0
dnn/src/arm_common/conv_bias/fp32/strategy_f73_mk4_nchw44.cpp View File

@@ -0,0 +1,587 @@
/**
* \file dnn/src/arm_common/conv_bias/fp32/strategy_f73_mk4_nchw44.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/arm_common/conv_bias/fp32/filter_transform.h"
#include "src/arm_common/conv_bias/fp32/helper.h"
#include "src/arm_common/conv_bias/fp32/strategy.h"
#include "src/arm_common/elemwise_helper/op_unary.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/common/winograd/winograd_helper.h"
#include "src/fallback/conv_bias/winograd/winograd.h"

#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F73_mk4)

using namespace megdnn;
using namespace arm_common;

namespace {

constexpr size_t alpha = 7 + 3 - 1;
constexpr size_t pack_size = 4;
constexpr float input_parameters[28] = {
1.5f, 0.75f, 3.0f, 7.875f, 0.5f, 2.5f, 0.125f,
0.875f, 4.0f, 8.0f, 5.25f, 7.375f, 5.375f, 3.5f,
7.75f, 0.25f, 2.125f, 10.625f, 0.625f, 4.375f, 5.0f,
10.0f, 5.75f, 2.75f, 4.25f, 1.75f, 2.0f, 0.0f};

//! Input transform for Winograd F(7,3) NCHW44: applies B^T * d * B to one
//! 4-channel 9x9 input patch (alpha = 7 + 3 - 1 = 9) and scatters the result
//! into the winograd input buffer layout consumed by the MK4 matmul.
struct InputTransformF73_NCHW44 {
    //! Copy a 9x9 spatial patch of one 4-channel pack from the NCHW44 input
    //! into the contiguous scratch buffer patchT.
    //! \tparam inner  true when the whole patch lies inside the image; when
    //!                false, out-of-bound positions stay zero-filled.
    template <bool inner>
    static void prepare(const float* input, float* patch, float* patchT,
                        int ih_start, int iw_start, size_t IH, size_t IW,
                        size_t ic, size_t IC) {
        MEGDNN_MARK_USED_VAR(patch);
        size_t IW4 = IW * pack_size;
        size_t iw4_start = iw_start * pack_size;
        size_t icb = ic / pack_size;
        // zero the scratch patch when padding (or the channel tail) may leave
        // parts of it unwritten
        if (!(inner && ic + pack_size < IC)) {
            memset(patchT, 0, sizeof(float) * pack_size * alpha * alpha);
        }
        if (inner) {
            const float* input_ptr =
                    input + icb * IH * IW4 + ih_start * IW4 + iw4_start;
            // fast path: 9 contiguous NEON loads/stores per patch row
            for (size_t ih = 0; ih < alpha; ih++) {
#define cb(i) auto v##i = vld1q_f32(input_ptr + pack_size * i);
                UNROLL_CALL_NOWRAPPER(9, cb);
#undef cb

#define cb(i) vst1q_f32(patchT + ih * pack_size * alpha + i * pack_size, v##i);
                UNROLL_CALL_NOWRAPPER(9, cb);
#undef cb
                input_ptr += IW4;
            }
        } else {
            // clamp the patch window against the valid image region
            int ih0_act = std::max<int>(ih_start, 0),
                ih1_act = std::min<int>(ih_start + alpha, IH),
                iw0_act = std::max<int>(iw_start, 0),
                iw1_act = std::min<int>(iw_start + alpha, IW);
            const float* input_ptr = input + icb * IH * IW4;
            // partial copy
            for (int ih = ih0_act; ih < ih1_act; ++ih) {
                for (int iw = iw0_act; iw < iw1_act; ++iw) {
                    size_t iho = ih - ih_start, iwo = iw - iw_start;
                    auto src = vld1q_f32(input_ptr + ih * IW4 + iw * pack_size);
                    vst1q_f32(
                            patchT + iho * pack_size * alpha + iwo * pack_size,
                            src);
                }
            }
        }
    }

    //! Apply the two-sided transform B^T * d * B to the prepared patch and
    //! store the 9x9 result into input_transform_buf in the layout
    //! [(row*alpha+col)][icb][unit] with 4 floats per element.
    static void transform(const float* patchT, float* input_transform_buf,
                          size_t unit_idx, size_t nr_units_in_tile, size_t ic,
                          size_t IC) {
        // BT * d * B

        size_t ICB = IC / pack_size;
        size_t icb = ic / pack_size;

        float32x4_t d0, d1, d2, d3, d4, d5, d6, d7, d8;
        // the non-trivial B coefficients, packed 4-per-vector so they can be
        // consumed by lane-indexed fused multiply-add/subtract
        float32x4_t v0 = vld1q_f32(input_parameters + 0);
        float32x4_t v1 = vld1q_f32(input_parameters + 4);
        float32x4_t v2 = vld1q_f32(input_parameters + 8);
        float32x4_t v3 = vld1q_f32(input_parameters + 12);
        float32x4_t v4 = vld1q_f32(input_parameters + 16);
        float32x4_t v5 = vld1q_f32(input_parameters + 20);
        float32x4_t v6 = vld1q_f32(input_parameters + 24);

        //! B
        //! 1.5      0     0      0      0     0     0     0     0
        //! -1      -1.5   1.5   -0.75   0.75 -3     3    -1     1.5
        //! -7.875  -0.5  -2.5    0.125 -0.875 -4    -8     0    -1
        //! 5.25     7.375 -5.375 4     -3.5    7.75  0.25  5.25 -7.875
        //! 7.875    2.125 10.625 -0.625 4.375  5    10     0     5.25
        //! -5.25   -5.75 -2.75  -4.25   1.75  -5.75 -4.25 -5.25  7.875
        //! -1.5    -0.5  -2.5    0.5   -3.5   -1    -2     0    -5.25
        //! 1        1     1      1      1      1     1     1    -1.5
        //! 0        0     0      0      0      0     0     0     1

        // lane mapping of the packed coefficient vectors:
        // 1.5f,   0.75f,   3.0f,   7.875f,  v0
        // 0.5f,   2.5f,    0.125f, 0.875f,  v1
        // 4.0f,   8.0f,    5.25f,  7.375f,  v2
        // 5.375f, 3.5f,    7.75f,  0.25f,   v3
        // 2.125f, 10.625f, 0.625f, 4.375f,  v4
        // 5.0f,   10.0f,   5.75f,  2.75f,   v5
        // 4.25f,  1.75f,   2.0f,   0.0f,    v6

// first pass: for patch row i, combine its 9 pixels (d0..d7 plus the direct
// load of pixel 8) into the 9 intermediates t##i##0..t##i##8
#define cb(i)                                                                 \
    d0 = vld1q_f32(patchT + i * alpha * pack_size + 0 * pack_size);           \
    d1 = vld1q_f32(patchT + i * alpha * pack_size + 1 * pack_size);           \
    d2 = vld1q_f32(patchT + i * alpha * pack_size + 2 * pack_size);           \
    d3 = vld1q_f32(patchT + i * alpha * pack_size + 3 * pack_size);           \
    d4 = vld1q_f32(patchT + i * alpha * pack_size + 4 * pack_size);           \
    d5 = vld1q_f32(patchT + i * alpha * pack_size + 5 * pack_size);           \
    d6 = vld1q_f32(patchT + i * alpha * pack_size + 6 * pack_size);           \
    d7 = vld1q_f32(patchT + i * alpha * pack_size + 7 * pack_size);           \
    auto t##i##8 = vld1q_f32(patchT + i * alpha * pack_size + 8 * pack_size); \
    auto t##i##0 = d7;                                                        \
    auto t##i##1 = d7;                                                        \
    auto t##i##2 = d7;                                                        \
    auto t##i##3 = d7;                                                        \
    auto t##i##4 = d7;                                                        \
    auto t##i##5 = d7;                                                        \
    auto t##i##6 = d7;                                                        \
    auto t##i##7 = d7;                                                        \
    t##i##8 = vfmsq_laneq_f32(t##i##8, d7, v0, 0);                            \
    t##i##0 = t##i##0 - d1;                                                   \
    t##i##1 = vfmsq_laneq_f32(t##i##1, d1, v0, 0);                            \
    t##i##2 = vfmaq_laneq_f32(t##i##2, d1, v0, 0);                            \
    t##i##3 = vfmsq_laneq_f32(t##i##3, d1, v0, 1);                            \
    t##i##4 = vfmaq_laneq_f32(t##i##4, d1, v0, 1);                            \
    t##i##5 = vfmsq_laneq_f32(t##i##5, d1, v0, 2);                            \
    t##i##6 = vfmaq_laneq_f32(t##i##6, d1, v0, 2);                            \
    t##i##7 = t##i##7 - d1;                                                   \
    t##i##8 = vfmaq_laneq_f32(t##i##8, d1, v0, 0);                            \
    t##i##0 = vfmsq_laneq_f32(t##i##0, d2, v0, 3);                            \
    t##i##1 = vfmsq_laneq_f32(t##i##1, d2, v1, 0);                            \
    t##i##2 = vfmsq_laneq_f32(t##i##2, d2, v1, 1);                            \
    t##i##3 = vfmaq_laneq_f32(t##i##3, d2, v1, 2);                            \
    t##i##4 = vfmsq_laneq_f32(t##i##4, d2, v1, 3);                            \
    t##i##5 = vfmsq_laneq_f32(t##i##5, d2, v2, 0);                            \
    t##i##6 = vfmsq_laneq_f32(t##i##6, d2, v2, 1);                            \
    t##i##8 = t##i##8 - d2;                                                   \
    t##i##0 = vfmaq_laneq_f32(t##i##0, d3, v2, 2);                            \
    t##i##1 = vfmaq_laneq_f32(t##i##1, d3, v2, 3);                            \
    t##i##2 = vfmsq_laneq_f32(t##i##2, d3, v3, 0);                            \
    t##i##3 = vfmaq_laneq_f32(t##i##3, d3, v2, 0);                            \
    t##i##4 = vfmsq_laneq_f32(t##i##4, d3, v3, 1);                            \
    t##i##5 = vfmaq_laneq_f32(t##i##5, d3, v3, 2);                            \
    t##i##6 = vfmaq_laneq_f32(t##i##6, d3, v3, 3);                            \
    t##i##7 = vfmaq_laneq_f32(t##i##7, d3, v2, 2);                            \
    t##i##8 = vfmsq_laneq_f32(t##i##8, d3, v0, 3);                            \
    t##i##0 = vfmaq_laneq_f32(t##i##0, d4, v0, 3);                            \
    t##i##1 = vfmaq_laneq_f32(t##i##1, d4, v4, 0);                            \
    t##i##2 = vfmaq_laneq_f32(t##i##2, d4, v4, 1);                            \
    t##i##3 = vfmsq_laneq_f32(t##i##3, d4, v4, 2);                            \
    t##i##4 = vfmaq_laneq_f32(t##i##4, d4, v4, 3);                            \
    t##i##5 = vfmaq_laneq_f32(t##i##5, d4, v5, 0);                            \
    t##i##6 = vfmaq_laneq_f32(t##i##6, d4, v5, 1);                            \
    t##i##8 = vfmaq_laneq_f32(t##i##8, d4, v2, 2);                            \
    t##i##0 = vfmsq_laneq_f32(t##i##0, d5, v2, 2);                            \
    t##i##1 = vfmsq_laneq_f32(t##i##1, d5, v5, 2);                            \
    t##i##2 = vfmsq_laneq_f32(t##i##2, d5, v5, 3);                            \
    t##i##3 = vfmsq_laneq_f32(t##i##3, d5, v6, 0);                            \
    t##i##4 = vfmaq_laneq_f32(t##i##4, d5, v6, 1);                            \
    t##i##5 = vfmsq_laneq_f32(t##i##5, d5, v5, 2);                            \
    t##i##6 = vfmsq_laneq_f32(t##i##6, d5, v6, 0);                            \
    t##i##7 = vfmsq_laneq_f32(t##i##7, d5, v2, 2);                            \
    t##i##8 = vfmaq_laneq_f32(t##i##8, d5, v0, 3);                            \
    t##i##0 = vfmsq_laneq_f32(t##i##0, d6, v0, 0);                            \
    t##i##1 = vfmsq_laneq_f32(t##i##1, d6, v1, 0);                            \
    t##i##2 = vfmsq_laneq_f32(t##i##2, d6, v1, 1);                            \
    t##i##3 = vfmaq_laneq_f32(t##i##3, d6, v1, 0);                            \
    t##i##4 = vfmsq_laneq_f32(t##i##4, d6, v3, 1);                            \
    t##i##5 = t##i##5 - d6;                                                   \
    t##i##6 = vfmsq_laneq_f32(t##i##6, d6, v6, 2);                            \
    t##i##8 = vfmsq_laneq_f32(t##i##8, d6, v2, 2);                            \
    t##i##0 = vfmaq_laneq_f32(t##i##0, d0, v0, 0);

        UNROLL_CALL_RAW(9, cb);
#undef cb

// second pass: for column i, combine t0##i..t8##i across rows into d0..d8
// (same coefficient pattern as the first pass) and store each result at
// [(row * alpha + i)][icb][unit_idx] in the winograd input buffer
#define cb(i)                                                  \
    d8 = t8##i;                                                \
    d0 = t7##i;                                                \
    d1 = t7##i;                                                \
    d2 = t7##i;                                                \
    d3 = t7##i;                                                \
    d4 = t7##i;                                                \
    d5 = t7##i;                                                \
    d6 = t7##i;                                                \
    d7 = t7##i;                                                \
    d8 = vfmsq_laneq_f32(d8, t7##i, v0, 0);                    \
    d0 = d0 - t1##i;                                           \
    d1 = vfmsq_laneq_f32(d1, t1##i, v0, 0);                    \
    d2 = vfmaq_laneq_f32(d2, t1##i, v0, 0);                    \
    d3 = vfmsq_laneq_f32(d3, t1##i, v0, 1);                    \
    d4 = vfmaq_laneq_f32(d4, t1##i, v0, 1);                    \
    d5 = vfmsq_laneq_f32(d5, t1##i, v0, 2);                    \
    d6 = vfmaq_laneq_f32(d6, t1##i, v0, 2);                    \
    d7 = d7 - t1##i;                                           \
    d8 = vfmaq_laneq_f32(d8, t1##i, v0, 0);                    \
    d0 = vfmsq_laneq_f32(d0, t2##i, v0, 3);                    \
    d1 = vfmsq_laneq_f32(d1, t2##i, v1, 0);                    \
    d2 = vfmsq_laneq_f32(d2, t2##i, v1, 1);                    \
    d3 = vfmaq_laneq_f32(d3, t2##i, v1, 2);                    \
    d4 = vfmsq_laneq_f32(d4, t2##i, v1, 3);                    \
    d5 = vfmsq_laneq_f32(d5, t2##i, v2, 0);                    \
    d6 = vfmsq_laneq_f32(d6, t2##i, v2, 1);                    \
    d8 = d8 - t2##i;                                           \
    d0 = vfmaq_laneq_f32(d0, t3##i, v2, 2);                    \
    d1 = vfmaq_laneq_f32(d1, t3##i, v2, 3);                    \
    d2 = vfmsq_laneq_f32(d2, t3##i, v3, 0);                    \
    d3 = vfmaq_laneq_f32(d3, t3##i, v2, 0);                    \
    d4 = vfmsq_laneq_f32(d4, t3##i, v3, 1);                    \
    d5 = vfmaq_laneq_f32(d5, t3##i, v3, 2);                    \
    d6 = vfmaq_laneq_f32(d6, t3##i, v3, 3);                    \
    d7 = vfmaq_laneq_f32(d7, t3##i, v2, 2);                    \
    d8 = vfmsq_laneq_f32(d8, t3##i, v0, 3);                    \
    d0 = vfmaq_laneq_f32(d0, t4##i, v0, 3);                    \
    d1 = vfmaq_laneq_f32(d1, t4##i, v4, 0);                    \
    d2 = vfmaq_laneq_f32(d2, t4##i, v4, 1);                    \
    d3 = vfmsq_laneq_f32(d3, t4##i, v4, 2);                    \
    d4 = vfmaq_laneq_f32(d4, t4##i, v4, 3);                    \
    d5 = vfmaq_laneq_f32(d5, t4##i, v5, 0);                    \
    d6 = vfmaq_laneq_f32(d6, t4##i, v5, 1);                    \
    d8 = vfmaq_laneq_f32(d8, t4##i, v2, 2);                    \
    d0 = vfmsq_laneq_f32(d0, t5##i, v2, 2);                    \
    d1 = vfmsq_laneq_f32(d1, t5##i, v5, 2);                    \
    d2 = vfmsq_laneq_f32(d2, t5##i, v5, 3);                    \
    d3 = vfmsq_laneq_f32(d3, t5##i, v6, 0);                    \
    d4 = vfmaq_laneq_f32(d4, t5##i, v6, 1);                    \
    d5 = vfmsq_laneq_f32(d5, t5##i, v5, 2);                    \
    d6 = vfmsq_laneq_f32(d6, t5##i, v6, 0);                    \
    d7 = vfmsq_laneq_f32(d7, t5##i, v2, 2);                    \
    d8 = vfmaq_laneq_f32(d8, t5##i, v0, 3);                    \
    d0 = vfmsq_laneq_f32(d0, t6##i, v0, 0);                    \
    d1 = vfmsq_laneq_f32(d1, t6##i, v1, 0);                    \
    d2 = vfmsq_laneq_f32(d2, t6##i, v1, 1);                    \
    d3 = vfmaq_laneq_f32(d3, t6##i, v1, 0);                    \
    d4 = vfmsq_laneq_f32(d4, t6##i, v3, 1);                    \
    d5 = d5 - t6##i;                                           \
    d6 = vfmsq_laneq_f32(d6, t6##i, v6, 2);                    \
    d8 = vfmsq_laneq_f32(d8, t6##i, v2, 2);                    \
    d0 = vfmaq_laneq_f32(d0, t0##i, v0, 0);                    \
    vst1q_f32(input_transform_buf +                            \
                      (0 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
                      icb * nr_units_in_tile * pack_size +     \
                      unit_idx * pack_size,                    \
              d0);                                             \
    vst1q_f32(input_transform_buf +                            \
                      (1 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
                      icb * nr_units_in_tile * pack_size +     \
                      unit_idx * pack_size,                    \
              d1);                                             \
    vst1q_f32(input_transform_buf +                            \
                      (2 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
                      icb * nr_units_in_tile * pack_size +     \
                      unit_idx * pack_size,                    \
              d2);                                             \
    vst1q_f32(input_transform_buf +                            \
                      (3 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
                      icb * nr_units_in_tile * pack_size +     \
                      unit_idx * pack_size,                    \
              d3);                                             \
    vst1q_f32(input_transform_buf +                            \
                      (4 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
                      icb * nr_units_in_tile * pack_size +     \
                      unit_idx * pack_size,                    \
              d4);                                             \
    vst1q_f32(input_transform_buf +                            \
                      (5 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
                      icb * nr_units_in_tile * pack_size +     \
                      unit_idx * pack_size,                    \
              d5);                                             \
    vst1q_f32(input_transform_buf +                            \
                      (6 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
                      icb * nr_units_in_tile * pack_size +     \
                      unit_idx * pack_size,                    \
              d6);                                             \
    vst1q_f32(input_transform_buf +                            \
                      (7 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
                      icb * nr_units_in_tile * pack_size +     \
                      unit_idx * pack_size,                    \
              d7);                                             \
    vst1q_f32(input_transform_buf +                            \
                      (8 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
                      icb * nr_units_in_tile * pack_size +     \
                      unit_idx * pack_size,                    \
              d8);

        UNROLL_CALL_RAW(9, cb);
#undef cb
    }
};

/**
 * Output transform for Winograd F(7,3) NCHW44: computes A^T * m * A for one
 * 4-channel tile, adds bias (per-channel or full), applies the activation Op
 * and writes the up-to-7x7 output block, clipping against the OH/OW borders.
 *
 * Fix vs. original: the literal `3.375` lacked the `f` suffix used by every
 * other coefficient, creating a double constant that is then narrowed; it is
 * now `3.375f` for consistency (same value, no behavioral change).
 */
template <BiasMode bmode, typename Op>
struct OutputTransformF73_NCHW44 {
    static void transform(const float* output_transform_buf, const float* bias,
                          float* output, float* transform_mid_buf,
                          size_t oh_start, size_t ow_start, size_t OH,
                          size_t OW, size_t oc_start, size_t oc_end,
                          size_t oc_index, size_t unit_idx,
                          size_t nr_units_in_tile, const DType& src_dtype,
                          const DType& dst_dtype) {
        MEGDNN_MARK_USED_VAR(transform_mid_buf);
        Op op(src_dtype, dst_dtype);
        //! AT * m * A

        size_t oc = oc_start + oc_index;
        size_t OCB = (oc_end - oc_start) / pack_size;
        size_t ocb = oc_index / pack_size;

// load the 9x9 post-matmul tile for this channel pack and unit
#define cb(m, n)                                                        \
    auto v##m##n = Vector<float, 4>::load(                              \
            output_transform_buf +                                      \
            (m * alpha + n) * OCB * nr_units_in_tile * pack_size +      \
            ocb * nr_units_in_tile * pack_size + unit_idx * pack_size);

        UNROLL_CALL_NOWRAPPER_D2(9, 9, cb);
#undef cb

        /**
         * A
         *
         * 1    0      0      0      0       0        0
         * 1    1      1      1      1       1        1
         * 1   -1      1     -1      1      -1        1
         * 1    2      4      8     16      32       64
         * 1   -2      4     -8     16     -32       64
         * 1    0.5    0.25   0.125  0.0625  0.03125  0.015625
         * 1   -0.5    0.25  -0.125  0.0625 -0.03125  0.015625
         * 1    1.5    2.25   3.375  5.0625  7.59375 11.390625
         * 0    0      0      0      0       0        1
         */

        Vector<float, 4> v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6;
// first pass (rows): reduce the 9 rows of column m into t0##m..t6##m,
// factoring out the shared +/- pairs so each pair is computed once
#define cb(m)                                                                  \
    v1addv2 = v1##m + v2##m;                                                   \
    v1subv2 = v1##m - v2##m;                                                   \
    v3addv4 = v3##m + v4##m;                                                   \
    v3subv4 = v3##m - v4##m;                                                   \
    v5addv6 = v5##m + v6##m;                                                   \
    v5subv6 = v5##m - v6##m;                                                   \
    auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6 + v7##m;                  \
    auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f + v7##m * 1.5f;      \
    auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f + v7##m * 2.25f;    \
    auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f + v7##m * 3.375f;  \
    auto t4##m =                                                               \
            v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f + v7##m * 5.0625f;    \
    auto t5##m =                                                               \
            v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m * 7.59375f;  \
    auto t6##m = v1addv2 + v3addv4 * 64.f + v5addv6 * 0.015625f +              \
                 v7##m * 11.390625f + v8##m;

        UNROLL_CALL_NOWRAPPER(9, cb);
#undef cb

// second pass (columns): reduce the 9 columns of row m into the final
// 7x7 outputs, reusing v##m##0..v##m##6 as the result registers
#define cb(m)                                                                  \
    v1addv2 = t##m##1 + t##m##2;                                               \
    v1subv2 = t##m##1 - t##m##2;                                               \
    v3addv4 = t##m##3 + t##m##4;                                               \
    v3subv4 = t##m##3 - t##m##4;                                               \
    v5addv6 = t##m##5 + t##m##6;                                               \
    v5subv6 = t##m##5 - t##m##6;                                               \
    v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6 + t##m##7;                 \
    v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f + t##m##7 * 1.5f;       \
    v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f + t##m##7 * 2.25f;     \
    v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f + t##m##7 * 3.375f;   \
    v##m##4 =                                                                  \
            v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f + t##m##7 * 5.0625f;  \
    v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f +                  \
              t##m##7 * 7.59375f;                                              \
    v##m##6 = v1addv2 + v3addv4 * 64.f + v5addv6 * 0.015625f +                 \
              t##m##7 * 11.390625f + t##m##8;

        UNROLL_CALL_NOWRAPPER(7, cb);
#undef cb

        Vector<float, 4> vbias;
        if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
            vbias = Vector<float, 4>::load(bias + oc);

#define cb(m, n) v##m##n += vbias;
            UNROLL_CALL_RAW_D2(7, 7, cb);
#undef cb
        }
        // full BIAS is added per output position below, after border
        // clipping, so the activation is applied there instead
        if (bmode != BiasMode::BIAS) {
#define cb(m, n) v##m##n = op(CONCAT(v##m, n).value);
            UNROLL_CALL_RAW_D2(7, 7, cb);
#undef cb
        }
// store each in-bounds output position; the trailing semicolon is required
// because the UNROLL macros concatenate expansions without separators
#define out_save(oho, owo)                                                  \
    do {                                                                    \
        size_t oh = oh_start + oho;                                         \
        size_t ow = ow_start + owo;                                         \
        if (oh < OH && ow < OW) {                                           \
            if (bmode == BiasMode::BIAS) {                                  \
                v##oho##owo += Vector<float, 4>::load(bias + oc * OH * OW + \
                                                      oh * OW * pack_size + \
                                                      ow * pack_size);      \
                v##oho##owo = op(v##oho##owo.value);                        \
            }                                                               \
            v##oho##owo.save(output + oc * OH * OW + oh * OW * pack_size +  \
                             ow * pack_size);                               \
        }                                                                   \
    } while (0);
        UNROLL_CALL_RAW_D2(7, 7, out_save);
    }
#undef out_save
};
} // namespace

namespace megdnn {
namespace arm_common {
namespace winograd {

MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_F73_mk4_f_nchw44)

//! Transform the 3x3 filters into the 9x9 Winograd domain (G * g * G^T) for
//! output channels [oc_start, oc_end), reading/writing NCHW44 4x4 packs and
//! storing results in the [(m*alpha+n)][ocb][icb] winograd filter layout.
void winograd_F73_mk4_f_nchw44::filter(const float* filter,
                                       float* filter_transform_buf,
                                       float* transform_mid_buf, size_t OC,
                                       size_t IC, size_t oc_start,
                                       size_t oc_end) {
    constexpr size_t pack_size = 4;
    // Gg * GT
    // G
    //  0.6666667  0.0000000  0.0000000
    //  0.4444444  0.4444444  0.4444444
    //  0.0888889 -0.0888889  0.0888889
    //  0.0222222  0.0444444  0.0888889
    // -0.0031746  0.0063492 -0.0126984
    // -0.7111111 -0.3555556 -0.1777778
    // -0.3555556  0.1777778 -0.0888889
    // -0.1523810 -0.2285714 -0.3428572
    //  0.0000000  0.0000000  1.0000000
    MEGDNN_MARK_USED_VAR(transform_mid_buf);
    megdnn_assert((oc_end - oc_start) % pack_size == 0 &&
                          oc_start % pack_size == 0 &&
                          oc_end % pack_size == 0 && IC % pack_size == 0 &&
                          OC % pack_size == 0,
                  "NCHW44 Winograd filter transform requires both OC and IC "
                  "are times of 4");

    size_t ICB = IC / pack_size;

    for (size_t ocb = oc_start / pack_size; ocb < oc_end / pack_size; ocb++) {
        for (size_t icb = 0; icb < ICB; icb++) {
            for (size_t ic_inner = 0; ic_inner < pack_size; ic_inner++) {
                // base of the 3x3 kernel for this (ocb, icb) pack pair;
                // ic_inner selects one input-channel lane within the pack
                const float* fptr = filter +
                                    (ocb * ICB + icb) * KERNEL_SIZE *
                                            KERNEL_SIZE * pack_size *
                                            pack_size +
                                    ic_inner * pack_size;

// load the 3x3 kernel taps, 4 output-channel lanes at a time
#define cb(m, n)                                             \
    Vector<float, 4> g##m##n = Vector<float, 4>::load(       \
            fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size);
                UNROLL_CALL_NOWRAPPER_D2(3, 3, cb)
#undef cb

// applied twice: first maps g (3 cols) -> wd (9 rows), i.e. G * g,
// then maps wd -> ret, i.e. (G * g) * G^T
#define FILTER_TRANSFORM(n, wd, g)                                      \
    auto wd##n##0 = g##0##n * 0.6666667f;                               \
    auto wd##n##1 = (g##0##n + g##1##n + g##2##n) * 0.4444444f;         \
    auto wd##n##2 = (g##0##n - g##1##n + g##2##n) * 0.0888889f;         \
    auto wd##n##3 = g##0##n * 0.0222222f + g##1##n * 0.0444444f +       \
                    g##2##n * 0.0888889f;                               \
    auto wd##n##4 = g##0##n * -0.0031746f + g##1##n * 0.0063492f +      \
                    g##2##n * -0.0126984f;                              \
    auto wd##n##5 = g##0##n * -0.7111111f + g##1##n * -0.3555556f +     \
                    g##2##n * -0.1777778f;                              \
    auto wd##n##6 = g##0##n * -0.3555556f + g##1##n * 0.1777778f +      \
                    g##2##n * -0.0888889f;                              \
    auto wd##n##7 = g##0##n * -0.1523810f + g##1##n * -0.2285714f +     \
                    g##2##n * -0.3428572f;                              \
    auto wd##n##8 = g##2##n;
                UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g);
                UNROLL_CALL_RAW(9, FILTER_TRANSFORM, ret, wd);
#undef FILTER_TRANSFORM
// scatter the 9x9 transformed taps into the winograd filter buffer
#define cb_save(m, n)                                                     \
    ret##m##n.save(filter_transform_buf + (m * alpha + n) * OC * IC +     \
                   ocb * IC * pack_size + icb * pack_size * pack_size +   \
                   ic_inner * pack_size);
                UNROLL_CALL_NOWRAPPER_D2(9, 9, cb_save)
#undef cb_save
            }
        }
    }
}

//! Transform nr_units_in_tile input units (starting at unit_start_idx) of all
//! channel packs into the winograd input buffer: for each unit, gather its
//! 9x9 patch (zero-padded at borders) and apply B^T * d * B.
//!
//! Fix vs. original: the identical InputTransformF73_NCHW44::transform call
//! was duplicated in both branches of the border check; only prepare<inner>
//! actually depends on the branch, so the transform call is hoisted out.
void winograd_F73_mk4_f_nchw44::input(const float* input,
                                      float* input_transform_buf,
                                      float* transform_mid_buf, size_t IH,
                                      size_t IW, size_t IC, size_t PH,
                                      size_t PW, size_t unit_start_idx,
                                      size_t nr_units_in_tile) {
    constexpr size_t pack_size = 4;
    megdnn_assert(IC % pack_size == 0);
    constexpr int alpha = 3 + 7 - 1;

    //! OW = IW + 2 * PW - KERNEL_SIZE + 1
    auto units_w =
            div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
    float* patch = transform_mid_buf;
    float* patchT = transform_mid_buf + pack_size * alpha * alpha;

    for (size_t ic = 0; ic < IC; ic += pack_size) {
        rep(unit_idx, nr_units_in_tile) {
            size_t index = unit_start_idx + unit_idx;
            size_t nh = index / units_w;
            size_t nw = index % units_w;
            int ih_start = nh * OUTPUT_BLOCK_SIZE - PH;
            int iw_start = nw * OUTPUT_BLOCK_SIZE - PW;
            // a unit is "inner" when its whole alpha x alpha patch lies
            // inside the image; otherwise prepare<false> zero-pads
            bool is_inner = ih_start >= 0 &&
                            ih_start + alpha <= static_cast<int>(IH) &&
                            iw_start >= 0 &&
                            iw_start + alpha <= static_cast<int>(IW);
            if (is_inner) {
                InputTransformF73_NCHW44::prepare<true>(input, patch, patchT,
                                                        ih_start, iw_start, IH,
                                                        IW, ic, IC);
            } else {
                InputTransformF73_NCHW44::prepare<false>(input, patch, patchT,
                                                         ih_start, iw_start,
                                                         IH, IW, ic, IC);
            }
            InputTransformF73_NCHW44::transform(patchT, input_transform_buf,
                                                unit_idx, nr_units_in_tile,
                                                ic, IC);
        }
    }
}

//! Apply the output transform (A^T * m * A), bias and activation for output
//! channels [oc_start, oc_end) over nr_units_in_tile units, dispatching on
//! (bmode, nonline_mode) to a statically specialized OutputTransformF73_NCHW44.
void winograd_F73_mk4_f_nchw44::output(const float* output_transform_buf,
                                       const float* bias, float* output,
                                       float* transform_mid_buf, BiasMode bmode,
                                       NonlineMode nonline_mode, size_t OH,
                                       size_t OW, size_t oc_start,
                                       size_t oc_end, size_t unit_start_idx,
                                       size_t nr_units_in_tile) {
// NOTE: this macro body references units_w / pack_size, which are declared
// below; that is fine because the macro is only expanded inside
// DISPATCH_CONV_WINOGRAD_BIAS at the end of the function.
#define cb(_bmode, _nonline_op, ...)                                    \
    for (size_t oc = oc_start; oc < oc_end; oc += pack_size) {          \
        size_t oc_index = oc - oc_start;                                \
        rep(unit_idx, nr_units_in_tile) {                               \
            size_t index = unit_start_idx + unit_idx;                   \
            auto nh = index / units_w;                                  \
            auto nw = index % units_w;                                  \
            size_t oh_start = nh * OUTPUT_BLOCK_SIZE;                   \
            size_t ow_start = nw * OUTPUT_BLOCK_SIZE;                   \
            OutputTransformF73_NCHW44<_bmode MEGDNN_COMMA _nonline_op>::\
                    transform(output_transform_buf, bias, output,       \
                              transform_mid_buf, oh_start, ow_start, OH, OW, \
                              oc_start, oc_end, oc_index, unit_idx,     \
                              nr_units_in_tile, src_dtype, dst_dtype);  \
        }                                                               \
    }

    auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);
    constexpr size_t pack_size = 4;

    size_t OC = oc_end - oc_start;
    megdnn_assert(OC % pack_size == 0 && oc_start % pack_size == 0 &&
                          oc_end % pack_size == 0,
                  "NCHW44 Winograd filter transform requires OC is times of 4");

    // expands cb for every supported (bias mode, nonlinearity) pair
    DISPATCH_CONV_WINOGRAD_BIAS(megdnn_arm_common_winograd_fp32_F73_mk4, cb,
                                float, float, bmode, nonline_mode);
#undef cb
}

} // namespace winograd
} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 7
- 0
dnn/src/arm_common/conv_bias/opr_impl.cpp View File

@@ -151,6 +151,13 @@ public:
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
//! uncomment this when low precision mode is done
#if 0
refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
#endif
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
refhold.emplace_back(new AlgoFP16WinogradF23(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),


+ 1
- 0
dnn/src/arm_common/conv_bias/opr_impl.h View File

@@ -50,6 +50,7 @@ private:

class AlgoFP32WinogradF23_4x4_NCHW44;
class AlgoFP32WinogradF63_4x4_NCHW44;
class AlgoFP32WinogradF73_4x4_NCHW44;

class AlgoS8ChanWiseStride1NCHW44;
class AlgoS8ChanWiseStride2NCHW44;


+ 4
- 0
dnn/src/arm_common/winograd_filter_preprocess/opr_impl.cpp View File

@@ -94,6 +94,10 @@ void WinogradFilterPreprocessImpl::exec(_megdnn_tensor_in src,
DISPATCH(winograd_F63_mk4_f_nchw44,
param::Winograd::Format::MK4, 0, 6);
}
} else if (m == 7) {
megdnn_assert(pack_c_size == 4, "WINOGRAD F(7,3) Only Supports NCHW44");
DISPATCH(winograd_F73_mk4_f_nchw44,
param::Winograd::Format::MK4, 0, 7);
}
} else if (FW == 4) {
if (m == 5) {


+ 37
- 0
dnn/src/common/unroll_macro.h View File

@@ -122,6 +122,23 @@
cb(5, 0, ##a) cb(5, 1, ##a) cb(5, 2, ##a) cb(5, 3, ##a) \
cb(5, 4, ##a) cb(5, 5, ##a) \

#define UNROLL_RAW_7x7(cb, v0, a...) \
cb(0, 0, ##a) cb(0, 1, ##a) cb(0, 2, ##a) cb(0, 3, ##a) \
cb(0, 4, ##a) cb(0, 5, ##a) cb(0, 6, ##a) \
cb(1, 0, ##a) cb(1, 1, ##a) cb(1, 2, ##a) cb(1, 3, ##a) \
cb(1, 4, ##a) cb(1, 5, ##a) cb(1, 6, ##a) \
cb(2, 0, ##a) cb(2, 1, ##a) cb(2, 2, ##a) cb(2, 3, ##a) \
cb(2, 4, ##a) cb(2, 5, ##a) cb(2, 6, ##a) \
cb(3, 0, ##a) cb(3, 1, ##a) cb(3, 2, ##a) cb(3, 3, ##a) \
cb(3, 4, ##a) cb(3, 5, ##a) cb(3, 6, ##a) \
cb(4, 0, ##a) cb(4, 1, ##a) cb(4, 2, ##a) cb(4, 3, ##a) \
cb(4, 4, ##a) cb(4, 5, ##a) cb(4, 6, ##a) \
cb(5, 0, ##a) cb(5, 1, ##a) cb(5, 2, ##a) cb(5, 3, ##a) \
cb(5, 4, ##a) cb(5, 5, ##a) cb(5, 6, ##a) \
cb(6, 0, ##a) cb(6, 1, ##a) cb(6, 2, ##a) cb(6, 3, ##a) \
cb(6, 4, ##a) cb(6, 5, ##a) cb(6, 6, ##a) \


#define UNROLL_RAW_8x8(cb, v0, a...) \
cb(0, 0, ##a) cb(0, 1, ##a) cb(0, 2, ##a) cb(0, 3, ##a) \
cb(0, 4, ##a) cb(0, 5, ##a) cb(0, 6, ##a) cb(0, 7, ##a) \
@@ -140,6 +157,26 @@
cb(7, 0, ##a) cb(7, 1, ##a) cb(7, 2, ##a) cb(7, 3, ##a) \
cb(7, 4, ##a) cb(7, 5, ##a) cb(7, 6, ##a) cb(7, 7, ##a)

#define UNROLL_RAW_9x9(cb, v0, a...) \
cb(0, 0, ##a) cb(0, 1, ##a) cb(0, 2, ##a) cb(0, 3, ##a) \
cb(0, 4, ##a) cb(0, 5, ##a) cb(0, 6, ##a) cb(0, 7, ##a) cb(0, 8, ##a) \
cb(1, 0, ##a) cb(1, 1, ##a) cb(1, 2, ##a) cb(1, 3, ##a) \
cb(1, 4, ##a) cb(1, 5, ##a) cb(1, 6, ##a) cb(1, 7, ##a) cb(1, 8, ##a) \
cb(2, 0, ##a) cb(2, 1, ##a) cb(2, 2, ##a) cb(2, 3, ##a) \
cb(2, 4, ##a) cb(2, 5, ##a) cb(2, 6, ##a) cb(2, 7, ##a) cb(2, 8, ##a) \
cb(3, 0, ##a) cb(3, 1, ##a) cb(3, 2, ##a) cb(3, 3, ##a) \
cb(3, 4, ##a) cb(3, 5, ##a) cb(3, 6, ##a) cb(3, 7, ##a) cb(3, 8, ##a) \
cb(4, 0, ##a) cb(4, 1, ##a) cb(4, 2, ##a) cb(4, 3, ##a) \
cb(4, 4, ##a) cb(4, 5, ##a) cb(4, 6, ##a) cb(4, 7, ##a) cb(4, 8, ##a) \
cb(5, 0, ##a) cb(5, 1, ##a) cb(5, 2, ##a) cb(5, 3, ##a) \
cb(5, 4, ##a) cb(5, 5, ##a) cb(5, 6, ##a) cb(5, 7, ##a) cb(5, 8, ##a) \
cb(6, 0, ##a) cb(6, 1, ##a) cb(6, 2, ##a) cb(6, 3, ##a) \
cb(6, 4, ##a) cb(6, 5, ##a) cb(6, 6, ##a) cb(6, 7, ##a) cb(6, 8, ##a) \
cb(7, 0, ##a) cb(7, 1, ##a) cb(7, 2, ##a) cb(7, 3, ##a) \
cb(7, 4, ##a) cb(7, 5, ##a) cb(7, 6, ##a) cb(7, 7, ##a) cb(7, 8, ##a) \
cb(8, 0, ##a) cb(8, 1, ##a) cb(8, 2, ##a) cb(8, 3, ##a) \
cb(8, 4, ##a) cb(8, 5, ##a) cb(8, 6, ##a) cb(8, 7, ##a) cb(8, 8, ##a)

#define UNROLL_CALL0_D2(step, step2, cb, v...) \
UNROLL_RAW_##step##x##step2(cb, 0, ##v)
#define UNROLL_CALL1_D2(step, step2, cb, v...) \


+ 6
- 1
dnn/src/naive/winograd_filter_preprocess/opr_impl.cpp View File

@@ -25,7 +25,7 @@ void WinogradFilterPreprocessImpl::exec(_megdnn_tensor_in src,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size);
//! nchw88 group conv
size_t flt_start = 0;
size_t pack_c_size = 1;
@@ -212,6 +212,10 @@ void WinogradFilterPreprocessImpl::exec(_megdnn_tensor_in src,
std::vector<float> interp_points = {0, 1, -1, 2,
-2, 0.5, -0.5};
DISPATCH_DTYPE(7);
} else if (m == 7) {
std::vector<float> interp_points = {0, 1, -1, 2,
-2, 0.5, -0.5, 1.5};
DISPATCH_DTYPE(8);
}
}
#undef cb
@@ -221,6 +225,7 @@ void WinogradFilterPreprocessImpl::exec(_megdnn_tensor_in src,
#undef DISPATCH_DTYPE
}
}

megdnn_assert(execed,
"Unsupport winograd filter preprocess. m: %zu src: %s", m,
src.layout.to_string().c_str());


+ 22
- 7
dnn/test/arm_common/conv_bias.cpp View File

@@ -777,7 +777,8 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F16_F23_8x8) {
}
#endif

void benchmark_winograd_nchw_vs_nchw44(const char* algo_name, Handle* handle) {
void benchmark_winograd_nchw_vs_nchw44(const char* algo_name0,
const char* algo_name1, Handle* handle) {
using namespace conv_bias;
using NLMode = param::ConvBias::NonlineMode;
std::vector<conv_bias::TestArg> args_nchw44;
@@ -846,9 +847,9 @@ void benchmark_winograd_nchw_vs_nchw44(const char* algo_name, Handle* handle) {
benchmark_winograd_nchw44.set_display(false);
benchmark_winograd_nchw44.set_times(RUN);

std::string winograd_nchw_algo_name = ssprintf("WINOGRAD:%s", algo_name);
std::string winograd_nchw_algo_name = ssprintf("WINOGRAD:%s", algo_name0);
std::string winograd_nchw44_algo_name =
ssprintf("WINOGRAD_NCHW44:%s", algo_name);
ssprintf("WINOGRAD_NCHW44:%s", algo_name1);

for (size_t i = 0; i < args_nchw.size(); ++i) {
auto arg_nchw = args_nchw[i];
@@ -892,17 +893,31 @@ void benchmark_winograd_nchw_vs_nchw44(const char* algo_name, Handle* handle) {

TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F23_MK4_NCHW_VS_NCHW44) {
#if MEGDNN_AARCH64
benchmark_winograd_nchw_vs_nchw44("AARCH64_F32_MK4_4x16:4:2", handle());
benchmark_winograd_nchw_vs_nchw44("AARCH64_F32_MK4_4x16:4:2",
"AARCH64_F32_MK4_4x16:4:2", handle());
#else
benchmark_winograd_nchw_vs_nchw44("ARMV7_F32_MK4_4x8:4:2", handle());
benchmark_winograd_nchw_vs_nchw44("ARMV7_F32_MK4_4x8:4:2",
"ARMV7_F32_MK4_4x8:4:2", handle());
#endif
}

TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F63_MK4_NCHW_VS_NCHW44) {
#if MEGDNN_AARCH64
benchmark_winograd_nchw_vs_nchw44("AARCH64_F32_MK4_4x16:4:6", handle());
benchmark_winograd_nchw_vs_nchw44("AARCH64_F32_MK4_4x16:4:6",
"AARCH64_F32_MK4_4x16:4:6", handle());
#else
benchmark_winograd_nchw_vs_nchw44("ARMV7_F32_MK4_4x8:4:6", handle());
benchmark_winograd_nchw_vs_nchw44("ARMV7_F32_MK4_4x8:4:6",
"ARMV7_F32_MK4_4x8:4:6", handle());
#endif
}

//! Benchmark NCHW winograd F63 against the new NCHW44 F73 kernels.
//! NOTE(review): the aarch64 NCHW44 side uses the MK4 GEMV matmul —
//! presumably because the F73 algo requires a NO_PACK matmul; confirm the
//! armv7 MK4_4x8 algo satisfies the same constraint.
TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F73_MK4_NCHW_VS_NCHW44) {
#if MEGDNN_AARCH64
    benchmark_winograd_nchw_vs_nchw44("AARCH64_F32_MK4_4x16:4:6",
                                      "ARM_COMMON_F32_GEMV_MK4:4:7", handle());
#else
    benchmark_winograd_nchw_vs_nchw44("ARMV7_F32_MK4_4x8:4:6",
                                      "ARMV7_F32_MK4_4x8:4:7", handle());
#endif
}



+ 26
- 0
dnn/test/arm_common/conv_bias_multi_thread.cpp View File

@@ -750,6 +750,26 @@ TEST_F(ARM_COMMON_MULTI_THREADS,
param::ConvBias::Format::NCHW44);
}

//! uncomment it when low precision mode is ok
#if 0
//! Correctness check for the F(7,3) NCHW44 winograd algo on 3x3 stride-1
//! cases; "4:7:16" matches the {pack_size, output_block, tile_size} triple
//! used when building the algorithm name.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_nchw44_conv_bias_args({3}, 1);
    Checker<ConvBiasForward> checker(handle());
    check_winograd("4:7:16", checker, args, param::MatrixMul::Format::MK4,
                   param::ConvBias::Format::NCHW44);
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44_WEIGHT_PREPROCESS) {
using namespace conv_bias;
std::vector<TestArg> args = get_nchw44_conv_bias_args({3}, 1);
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
handle());
check_winograd("4:7:16", checker, args, param::MatrixMul::Format::MK4,
param::ConvBias::Format::NCHW44);
}
#endif

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F54) {
using namespace conv_bias;
std::vector<TestArg> args = get_winograd_args(4);
@@ -923,6 +943,12 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_PREPROCESS_NCHW44) {
}
}
};

//! uncomment this when low precision mode is ok
// run(handle(), nchw44_args, {2, 6, 7}, dtype::Float32(), dtype::Float32(),
// dtype::Float32(), dtype::Float32(), 1e-2f);

//! remove this when low precision mode is ok
run(handle(), nchw44_args, {2, 6}, dtype::Float32(), dtype::Float32(),
dtype::Float32(), dtype::Float32(), 1e-3f);
}


+ 1
- 1
dnn/test/arm_common/matrix_mul.cpp View File

@@ -399,7 +399,7 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_MK4) {
.set_param(param);

auto run = [&](size_t M, size_t K) {
printf("SGEMV_MK4: (%zu, %zu)\n", M, K);
printf("SGEMV_MK4: (%zu, %zu, 1)\n", M, K);
TensorShape A, B;
A = TensorShape{M / 4, K / 4, 4, 4};
B = TensorShape{K / 4, 1, 4};


Loading…
Cancel
Save