Browse Source

fix(dnn/arm): fix nchw44 fp32 direct algo oh block and unused stride2 algo

GitOrigin-RevId: 8012678fae
tags/v0.5.0
Megvii Engine Team Xu Xinran 5 years ago
parent
commit
02cbb13bbc
2 changed files with 16 additions and 9 deletions
  1. +4
    -3
      dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp
  2. +12
    -6
      dnn/test/arm_common/conv_bias.cpp

+ 4
- 3
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp View File

@@ -107,7 +107,7 @@ static void do_conv_kern(WorkspaceBundle bundle,
constexpr int oc_idx = 0;
int oc_block = oc;
int oh_block = block_helper(kern_param.nr_threads, oh2,
- ic * iw * sizeof(float) * 2);
+ ic * iw * sizeof(float) * stride_h);
const int oh_idx = ncb_index.ndrange_id[2];
const int oh_block_real = std::min(oh - oh_idx * oh_block, oh_block);
const int ih_real = oh_block_real * stride_h + fh - stride_h;
@@ -297,8 +297,9 @@ ConvBiasImpl::AlgoF32DirectNCHW44::dispatch_kerns(
int oh = param.osz[0];
int ic = param.filter_meta.icpg;
int iw = param.isz[1];
- int oh_block =
- block_helper(param.nr_threads, oh, ic * iw * sizeof(float) * 2);
+ int stride_h = param.filter_meta.stride[0];
+ int oh_block = block_helper(param.nr_threads, oh,
+ ic * iw * sizeof(float) * stride_h);
CpuNDRange ncb_range = {static_cast<size_t>(batch),
static_cast<size_t>(group),
static_cast<size_t>(div_ceil(oh, oh_block))};


+ 12
- 6
dnn/test/arm_common/conv_bias.cpp View File

@@ -118,24 +118,30 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) {
conv_bias::ConvBiasAlgoChecker<ConvBias>(
"IM2COLMATMUL:AARCH64_F32K8X12X1:192"));

- Benchmarker<ConvBias> benchmarker_int_nchw44(handle);
+ Benchmarker<ConvBias> benchmarker_nchw44(handle);
if (is_fp32) {
- benchmarker_int_nchw44.set_times(RUNS)
+ benchmarker_nchw44.set_times(RUNS)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_dtype(4, dtype::Float32())
.set_display(false);
} else {
- benchmarker_int_nchw44.set_times(RUNS)
+ benchmarker_nchw44.set_times(RUNS)
.set_dtype(0, dtype::QuantizedS8(2.5))
.set_dtype(1, dtype::QuantizedS8(2.5))
.set_dtype(2, dtype::QuantizedS32(6.25))
.set_dtype(4, dtype::QuantizedS8(60.25))
.set_display(false);
}
- benchmarker_int_nchw44.set_before_exec_callback(
- conv_bias::ConvBiasAlgoChecker<ConvBias>(".+"));
+ auto nchw44_algo_regx = ".*(DIRECT|NCHW_NCHW44).*";
+ #if __ARM_FEATURE_DOTPROD
+ if (!is_fp32) {
+ nchw44_algo_regx = ".*DOT.*";
+ }
+ #endif
+ benchmarker_nchw44.set_before_exec_callback(
+ conv_bias::ConvBiasAlgoChecker<ConvBias>(nchw44_algo_regx));

auto run = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W,
size_t FS, size_t stride, bool input_nchw = false) {
@@ -171,7 +177,7 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) {

bias = {1, OC / 4, 1, 1, 4};
dst = {N, OC / 4, OH, OW, 4};
- auto int_nchw44_used = benchmarker_int_nchw44.set_param(param).exec(
+ auto int_nchw44_used = benchmarker_nchw44.set_param(param).exec(
{src, filter, bias, {}, dst}) /
RUNS;
float computations = IC * (FS * FS) * dst.total_nr_elems() * 2 * 1e-6;


Loading…
Cancel
Save