
fix(dnn/arm): fix stride 1 support for int8 nchw_nchw44

GitOrigin-RevId: 9d718eb7a4
tags/v0.5.0
Megvii Engine Team (Xu Xinran) committed 5 years ago
parent commit 7b0dbe6af8
11 changed files with 1682 additions and 1146 deletions
1. +3 -2      dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp
2. +2 -2      dnn/src/arm_common/conv_bias/int8/algos.h
3. +373 -0    dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp
4. +1287 -0   dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h
5. +0 -305    dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_algo.cpp
6. +0 -789    dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.cpp
7. +0 -44     dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h
8. +2 -2      dnn/src/arm_common/conv_bias/opr_impl.cpp
9. +1 -1      dnn/src/arm_common/conv_bias/opr_impl.h
10. +8 -0     dnn/test/arm_common/conv_bias.cpp
11. +6 -1     dnn/test/arm_common/conv_bias_multi_thread.cpp

+3 -2  dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp

@@ -37,7 +37,7 @@ static inline size_t get_perthread_cache_bytes(const int ic, const int ih2,
static void get_rectified_size(
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2,
int& iw2, int& oh2, int& ow2) {
- constexpr int cacheline = 64 / sizeof(float);
+ constexpr int nr_elements_in_cacheline = 64 / sizeof(float);
int ic = param.filter_meta.icpg;
int iw = param.isz[1];
int oh = param.osz[0];
@@ -52,7 +52,8 @@ static void get_rectified_size(
int block_oh = l2_block_helper(param.nr_threads, oh,
ic * iw * sizeof(float) * stride_h);
ih2 = block_oh * stride_h + filter_h - stride_h;
- iw2 = round_up(iw + 2 * static_cast<int>(fm.padding[1]), cacheline);
+ iw2 = round_up(iw + 2 * static_cast<int>(fm.padding[1]),
+                nr_elements_in_cacheline);
}

static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {


+2 -2  dnn/src/arm_common/conv_bias/int8/algos.h

@@ -90,9 +90,9 @@ public:
const NCBKernSizeParam& param) const override;
};

- class ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44 final : public AlgoBase {
+ class ConvBiasImpl::AlgoS8DirectNCHWNCHW44 final : public AlgoBase {
public:
- AlgoS8DirectStride2NCHWNCHW44() {}
+ AlgoS8DirectNCHWNCHW44() {}
bool is_reproducible() const override { return true; }
const char* name() const override { return "S8_CONV_NCHW_NCHW44"; }
bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param,


+373 -0  dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp

@@ -0,0 +1,373 @@
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h"
#include "src/arm_common/conv_bias/int8/strategy.h"
#include "src/arm_common/elemwise_op.h"
#include "src/common/opr_delegate.h"

#include "midout.h"

using namespace megdnn;
using namespace arm_common;
using conv_fun = std::function<void(
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>;
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw_nchw44)

static void get_rectified_size(
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih2,
int& iw2, int& oh2, int& ow2) {
auto&& fm = param.filter_meta;
int ih = param.isz[0];
int iw = param.isz[1];
int oh = param.osz[0];
int ow = param.osz[1];
int ph = fm.padding[0];
int pw = fm.padding[1];
int stride_h = fm.stride[0];

oh2 = oh;
ow2 = ow;
ih2 = stride_h == 2 ? round_up(ih + 2 * ph, 2) : ih + 2 * ph;
iw2 = iw + 2 * pw;
}
static inline size_t get_temp_bytes(const int iw, const int pw) {
//! border_size is used to avoid reading illegal memory
constexpr int cacheline_size = 64;
constexpr int border_size = 1 * cacheline_size;

return round_up(iw + pw * 2, cacheline_size) + border_size;
}
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
auto&& fm = param.filter_meta;
int group = fm.group;
int batch = param.n;
int ic = fm.icpg;
int oc = fm.ocpg;
int fh = fm.spatial[0];
int fw = fm.spatial[1];
int stride_h = fm.stride[0];
int iw = param.isz[1];
int pw = fm.padding[1];
int ih2, iw2, oh2, ow2;
const size_t src_expand = stride_h == 2 ? 4 : 16;
get_rectified_size(param, ih2, iw2, oh2, ow2);
megdnn_assert(group == 1, "only support group == 1 now");
size_t src_size =
batch * group * ic * ih2 * iw2 * sizeof(int8_t) * src_expand;
size_t weight_size = group * oc * ic * fh * fw * sizeof(int8_t);
size_t tmp_size = 0;
if (stride_h == 1) {
weight_size = group * oc * ic * fh * round_up(fw, 4) * sizeof(int8_t);
tmp_size = get_temp_bytes(iw, pw);
}
return {nullptr, {src_size, weight_size, tmp_size * param.nr_threads}};
};
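For a feel of how src_expand drives the bundle size, here is a standalone sketch of the same arithmetic under a hypothetical stride-1 shape (helpers re-declared; all concrete numbers are illustrative, not from the commit):

#include <cstddef>
#include <cstdio>

static int round_up_to(int v, int a) { return (v + a - 1) / a * a; }

int main() {
    // Hypothetical shape: batch=1, group=1, ic=3, oc=8, 3x3 filter,
    // ih=iw=32, ph=pw=1, stride 1.
    const int stride_h = 1, batch = 1, group = 1, ic = 3, oc = 8;
    const int fh = 3, fw = 3, ih = 32, iw = 32, ph = 1, pw = 1;
    const int src_expand = stride_h == 2 ? 4 : 16;  // 16x expansion for stride 1
    const int ih2 = stride_h == 2 ? round_up_to(ih + 2 * ph, 2) : ih + 2 * ph;
    const int iw2 = iw + 2 * pw;
    size_t src_size = (size_t)batch * group * ic * ih2 * iw2 * src_expand;
    size_t weight_size =  // stride 1 pads fw up to a multiple of 4
            (size_t)group * oc * ic * fh * round_up_to(fw, 4);
    std::printf("src=%zu weight=%zu\n", src_size, weight_size);  // 55488, 288
    return 0;
}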

static void copy_padding_kern(WorkspaceBundle bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids) {
int ih = kern_param.isz[0];
int iw = kern_param.isz[1];
int ic = kern_param.filter_meta.icpg;
int ph = kern_param.filter_meta.padding[0];
int pw = kern_param.filter_meta.padding[1];
int group = kern_param.filter_meta.group;
int stride_h = kern_param.filter_meta.stride[0];

int ih2, iw2, oh2, ow2;
get_rectified_size(kern_param, ih2, iw2, oh2, ow2);
int padding_group_size = ih2 * iw2 * ic;
bundle.set(kern_param.workspace_ptr);
//! Used to get the workspace offset
const int src_expand = stride_h == 2 ? 4 : 16;

//! TODO: the block dim should be taken from the arg
int workspace_ic_block = 1;
int workspace_batch_id = workspace_ids[0];
int workspace_group_id = workspace_ids[1];
int workspace_ic_id = workspace_ids[2];
int workspace_ic = workspace_ic_id * workspace_ic_block;
int batch_id = ncb_index.ndrange_id[0];
int group_id = ncb_index.ndrange_id[1];

const int8_t* sptr = static_cast<const int8_t*>(
kern_param.src<int8_t>(batch_id, group_id, workspace_ic_id, 1, 1));
//! copy to sptr_base to eliminate padding effect
int8_t* sptr_base = static_cast<int8_t*>(bundle.get(0)) +
(workspace_batch_id * group * padding_group_size +
workspace_group_id * padding_group_size +
workspace_ic * ih2 * iw2) *
src_expand;
if (stride_h == 1) {
const size_t tmp_size = get_temp_bytes(iw, pw);
int8_t* tmp_ptr = reinterpret_cast<int8_t*>(bundle.get(2)) +
ncb_index.thread_id * tmp_size;
pack_nchw_src_for_nchw44_conv<1>(sptr, sptr_base, 1, ph, ph, pw, pw, ih,
iw, iw2, pw, tmp_ptr);
} else {
pack_nchw_src_for_nchw44_conv<2>(sptr, sptr_base, 1, ph, ph, pw, pw, ih,
iw, iw2, pw, nullptr);
}
}
static void pack_weight(WorkspaceBundle bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index) {
bundle.set(kern_param.workspace_ptr);
const int group_id = ncb_index.ndrange_id[0];
int fh = kern_param.filter_meta.spatial[0];
int fw = kern_param.filter_meta.spatial[1];
int oc = kern_param.filter_meta.ocpg;
int ic = kern_param.filter_meta.icpg;
int stride_h = kern_param.filter_meta.stride[0];
int fw2 = stride_h == 2 ? fw : round_up(fw, 4);
int oc_block = oc;
int oc_idx = 0;
const int8_t* fptr =
kern_param.filter<dt_int8>(group_id) + oc_idx * fh * fw * ic;
auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) +
group_id * oc * ic * fh * fw2 + oc_idx * ic * fh * fw2;

if (stride_h == 1) {
pack_nchw44_weight_for_nchw_conv<1>(fptr, packed_weight, ic, fh, fw,
oc_block);
} else {
pack_nchw44_weight_for_nchw_conv<2>(fptr, packed_weight, ic, fh, fw,
oc_block);
}
}
template <size_t filter, BiasMode bias_mode, typename Op, int stride>
static void do_conv_kern(WorkspaceBundle bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids,
const CpuNDRange& ncb_range) {
int oh = kern_param.osz[0];
int ow = kern_param.osz[1];
int fh = kern_param.filter_meta.spatial[0];
int fw = kern_param.filter_meta.spatial[1];
int fw2 = stride == 2 ? fw : round_up(fw, 4);
int ic = kern_param.filter_meta.icpg;
int oc = kern_param.filter_meta.ocpg;
int group = kern_param.filter_meta.group;
int ih2, iw2, oh2, ow2;
get_rectified_size(kern_param, ih2, iw2, oh2, ow2);
bool need_post_process =
kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8;
//! if dst_type is qint32, the op is not used; just fill it with (1.0f, 4.0f)
Op op = Op(1.0f, 4.0f);
if (need_post_process) {
float scale_bias =
kern_param.bias_type.param<dtype::QuantizedS32>().scale;
float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale;
op = Op(scale_bias, scale_dst);
}
int padding_group_size = ih2 * iw2 * ic;
bundle.set(kern_param.workspace_ptr);

constexpr int pack_c = 4;
constexpr int src_expand_size = stride == 2 ? 4 : 16;
const int workspace_batch_id = workspace_ids[0];
const int workspace_group_id = workspace_ids[1];
const int batch_id = ncb_index.ndrange_id[0];
const int group_id = ncb_index.ndrange_id[1];
const int oc_id = ncb_index.ndrange_id[2];
const int oc_block_num = ncb_range[2];
int nr_pack_per_step = div_ceil(div_ceil(oc, pack_c), oc_block_num);
int oc_block = nr_pack_per_step * pack_c;
const int oc_idx = oc_id * oc_block;
if (oc_id == (oc_block_num - 1)) {
oc_block = oc - oc_id * nr_pack_per_step * pack_c;
}
megdnn_assert(oc_block % pack_c == 0,
"oc must be devisible by 4, but oc = %d", oc_block);
const int8_t* sptr =
static_cast<int8_t*>(bundle.get(0)) +
workspace_batch_id * group * padding_group_size * src_expand_size +
workspace_group_id * padding_group_size * src_expand_size;

int8_t* dst = reinterpret_cast<int8_t*>(
reinterpret_cast<ptrdiff_t>(
kern_param.dst<void>(batch_id, group_id)) +
oc_idx * oh * ow);

const int32_t* bptr =
kern_param.bias<dt_int32>(batch_id, group_id) + oc_idx;
int8_t* packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) +
group_id * oc * ic * fh * fw2 +
oc_idx * ic * fh * fw2;
conv_direct_int8_nchw_nchw44<bias_mode, Op, filter, stride>(
sptr, packed_weight, bptr, nullptr, dst, oc_block, ic, ih2, iw2, oh,
ow, op);
}

bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::usable(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
MEGDNN_MARK_USED_VAR(algo_selection_strategy);
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
auto OC = fm.ocpg;
bool available = //! src and filter are qint8, dst is qint8
fm.icpg < 4 && // must be nchw input
((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
(param.dst_type.enumv() == DTypeEnum::QuantizedS8))) &&
(fm.format == param::Convolution::Format::NCHW44) &&
(OC % 4 == 0 && OC >= 4) && !fm.should_flip && fm.group == 1 &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && fm.stride[0] == fm.stride[1] &&
(fm.stride[0] == 1 || fm.stride[0] == 2) && FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7) && fm.group == 1 &&
param.bias_mode != BiasMode::BIAS;
return available;
}

bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::is_preferred(
megdnn::fallback::ConvBiasImpl* conv_bias_impl_ptr,
const NCBKernSizeParam& param) const {
// TODO: benchmark and fix
MEGDNN_MARK_USED_VAR(conv_bias_impl_ptr);
MEGDNN_MARK_USED_VAR(param);
return false;
}

size_t ConvBiasImpl::AlgoS8DirectNCHWNCHW44::get_workspace(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoS8DirectNCHWNCHW44::dispatch_kerns(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
auto fm = param.filter_meta;
size_t N = param.n;
size_t OC = fm.ocpg;
size_t group = fm.group;
WorkspaceBundle wbundle = get_bundle(param);
conv_fun do_conv_fun = nullptr;
// NOTE: remain_w is not used to generate the midout hash, for compatibility
// with shapes that change at runtime
#define DO_CONV_KERN_FUN(stride, filter, bias_mode, op) \
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw_nchw44, \
midout_iv(#stride #filter #bias_mode #op##_hash)) { \
do_conv_fun = do_conv_kern<filter, bias_mode, op, stride>; \
} \
MIDOUT_END();

#define GET_OP_PARAM(stride, filter, bias_mode) \
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(stride, filter, bias_mode, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::RELU: \
DO_CONV_KERN_FUN(stride, filter, bias_mode, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::H_SWISH: \
DO_CONV_KERN_FUN(stride, filter, bias_mode, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define GET_BIAS_MODE_PARAM(stride, filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
default: \
megdnn_assert(0); \
break; \
}

#define DISPATCH_CONV_KERN(stride) \
switch (param.filter_meta.spatial[0]) { \
case 2: \
GET_BIAS_MODE_PARAM(stride, 2) \
break; \
case 3: \
GET_BIAS_MODE_PARAM(stride, 3) \
break; \
case 5: \
GET_BIAS_MODE_PARAM(stride, 5) \
break; \
case 7: \
GET_BIAS_MODE_PARAM(stride, 7) \
break; \
default: \
megdnn_assert(0); \
break; \
}

switch (param.filter_meta.stride[0]) {
case 1:
DISPATCH_CONV_KERN(1);
break;
case 2:
DISPATCH_CONV_KERN(2);
break;
default:
megdnn_throw(ssprintf("Unsupport stride size %u for the first conv",
param.filter_meta.stride[0])
.c_str());
break;
}

#undef DO_CONV_KERN_FUN
#undef GET_REMAIN_W_PARAM
#undef GET_OP_PARAM
#undef GET_BIAS_MODE_PARAM
#undef DISPATCH_CONV_KERN

megdnn_assert(do_conv_fun);

SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
WorkspaceBundle bundle = wbundle;

constexpr size_t pack_oc = 8;
size_t oc_step = pack_oc;
auto copy_padding = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
copy_padding_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id);
};
ret_kerns.push_back({copy_padding, {N, group, fm.icpg}});

auto do_pack_weight = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
pack_weight(bundle, kern_param, ncb_index);
};
ret_kerns.push_back({do_pack_weight, {static_cast<size_t>(group)}});

CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)};
auto do_conv = [bundle, do_conv_fun, ncb_range](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id,
ncb_range);
};
ret_kerns.push_back({do_conv, ncb_range});

return ret_kerns;
}

// vim: syntax=cpp.doxygen
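Stepping back from dispatch_kerns: padding runs over {N, group, icpg}, weight packing over {group}, and the conv itself over {N, group, ceil(OC / 8)}. A tiny sketch of that last range under hypothetical sizes:

#include <cassert>
#include <cstddef>

static size_t div_ceil_z(size_t a, size_t b) { return (a + b - 1) / b; }

int main() {
    // Hypothetical dispatch: batch 2, one group, 24 output channels.
    const size_t N = 2, group = 1, OC = 24, oc_step = 8;  // pack_oc = 8
    size_t ncb_range[3] = {N, group, div_ceil_z(OC, oc_step)};
    assert(ncb_range[2] == 3);  // three conv kernel slots per (batch, group)
    return 0;
}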

+1287 -0  dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h
File diff suppressed because it is too large


+0 -305  dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_algo.cpp

@@ -1,305 +0,0 @@
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_algo.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megdnn/oprs.h"
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h"
#include "src/arm_common/conv_bias/int8/strategy.h"
#include "src/arm_common/elemwise_op.h"
#include "src/common/opr_delegate.h"

#include "midout.h"

using namespace megdnn;
using namespace arm_common;
using conv_fun = std::function<void(
WorkspaceBundle bundle, const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids, const CpuNDRange& ncb_range)>;
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw_nchw44_stride2)

static void get_rectified_size(
const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param,
size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) {
auto&& fm = param.filter_meta;
size_t IH = param.isz[0];
size_t IW = param.isz[1];
size_t OH = param.osz[0];
size_t OW = param.osz[1];

OH2 = OH;
OW2 = OW;
IH2 = round_up(IH + 2 * fm.padding[0], static_cast<size_t>(2));
IW2 = IW + 2 * fm.padding[1];
}
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
constexpr size_t src_expand = 4;
auto&& fm = param.filter_meta;
size_t group = fm.group;
size_t batch = param.n;
size_t IC = fm.icpg;
size_t OC = fm.ocpg;
size_t FH = fm.spatial[0];
size_t FW = fm.spatial[1];
size_t IH2, IW2, OH2, OW2;
get_rectified_size(param, IH2, IW2, OH2, OW2);
megdnn_assert(group == 1, "only support group == 1 now");
size_t src_size =
batch * group * IC * IH2 * IW2 * sizeof(int8_t) * src_expand;
size_t weight_size = group * OC * IC * FH * FW * sizeof(int8_t);
return {nullptr, {src_size, weight_size}};
};

static void copy_padding_kern(WorkspaceBundle bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids) {
size_t IH = kern_param.isz[0];
size_t IW = kern_param.isz[1];
size_t IC = kern_param.filter_meta.icpg;
size_t PH = kern_param.filter_meta.padding[0];
size_t PW = kern_param.filter_meta.padding[1];
size_t GROUP = kern_param.filter_meta.group;

size_t IH2, IW2, OH2, OW2;
get_rectified_size(kern_param, IH2, IW2, OH2, OW2);
size_t padding_group_size = IH2 * IW2 * IC;
bundle.set(kern_param.workspace_ptr);
//! Used to get the workspace offset
constexpr int expend_element = 4;
// TODO: the block dim should be taken from the arg
size_t workspace_ic_block = 1;
size_t workspace_batch_id = workspace_ids[0];
size_t workspace_group_id = workspace_ids[1];
size_t workspace_ic_id = workspace_ids[2];
size_t workspace_ic = workspace_ic_id * workspace_ic_block;
size_t batch_id = ncb_index.ndrange_id[0];
size_t group_id = ncb_index.ndrange_id[1];

const int8_t* sptr = static_cast<const int8_t*>(
kern_param.src<int8_t>(batch_id, group_id, workspace_ic_id, 1, 1));
//! copy to sptr_base to eliminate padding effect
int8_t* sptr_base = static_cast<int8_t*>(bundle.get(0)) +
(workspace_batch_id * GROUP * padding_group_size +
workspace_group_id * padding_group_size +
workspace_ic * IH2 * IW2) *
expend_element;
conv_bias::pack_nchw_src_for_nchw44_conv(sptr, sptr_base, 1, PH, PH, PW, PW,
IH, IW);
}

template <size_t filter, BiasMode bias_mode, typename Op>
static void do_conv_kern(WorkspaceBundle bundle,
const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids,
const CpuNDRange& ncb_range) {
size_t OH = kern_param.osz[0];
size_t OW = kern_param.osz[1];
size_t FH = kern_param.filter_meta.spatial[0];
size_t FW = kern_param.filter_meta.spatial[1];
size_t IC = kern_param.filter_meta.icpg;
size_t OC = kern_param.filter_meta.ocpg;
size_t GROUP = kern_param.filter_meta.group;
size_t IH2, IW2, OH2, OW2;
get_rectified_size(kern_param, IH2, IW2, OH2, OW2);
bool need_post_process =
kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8;
//! if dst_type is qint32, the op is not used; just fill it with (1.0f, 4.0f)
Op op = Op(1.0f, 4.0f);
if (need_post_process) {
float scale_bias =
kern_param.bias_type.param<dtype::QuantizedS32>().scale;
float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale;
op = Op(scale_bias, scale_dst);
}
size_t padding_group_size = IH2 * IW2 * IC;
bundle.set(kern_param.workspace_ptr);

constexpr size_t pack_c = 4;
constexpr size_t src_expand_size = 4;
const size_t workspace_batch_id = workspace_ids[0];
const size_t workspace_group_id = workspace_ids[1];
const size_t batch_id = ncb_index.ndrange_id[0];
const size_t group_id = ncb_index.ndrange_id[1];
const size_t oc_id = ncb_index.ndrange_id[2];
const size_t oc_block_num = ncb_range[2];
size_t nr_pack_per_step = div_ceil(div_ceil(OC, pack_c), oc_block_num);
size_t oc_block = nr_pack_per_step * pack_c;
const size_t oc_idx = oc_id * oc_block;
if (oc_id == (oc_block_num - 1)) {
oc_block = OC - oc_id * nr_pack_per_step * pack_c;
}
megdnn_assert(oc_block % pack_c == 0,
"oc must be devisible by 4, but oc = %zu", oc_block);
const int8_t* sptr =
static_cast<int8_t*>(bundle.get(0)) +
workspace_batch_id * GROUP * padding_group_size * src_expand_size +
workspace_group_id * padding_group_size * src_expand_size;

const int8_t* fptr =
kern_param.filter<dt_int8>(group_id) + oc_idx * FH * FW * IC;
void* dst = reinterpret_cast<void*>(
reinterpret_cast<ptrdiff_t>(
kern_param.dst<void>(batch_id, group_id)) +
oc_idx * OH * OW);
const int32_t* bptr =
kern_param.bias<dt_int32>(batch_id, group_id) + oc_idx;
auto packed_weight = reinterpret_cast<int8_t*>(bundle.get(1)) +
group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW;

conv_bias::pack_nchw44_weight_for_nchw_conv(fptr, packed_weight, IC, FH, FW,
oc_block);
#define KERN1_NCHW44_CONV(filter) \
conv_bias::conv_direct_stride2_##filter##x##filter##_int8_nchw_nchw44< \
bias_mode, Op>(sptr, packed_weight, bptr, nullptr, \
static_cast<int8_t*>(dst), oc_block, IC, IH2, IW2, \
OH, OW, op)
DISPATCH_FILTER(filter, KERN1_NCHW44_CONV);
#undef KERN1_NCHW44_CONV
}

/* ===================== stride2 algo ===================== */
bool ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::usable(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
MEGDNN_MARK_USED_VAR(algo_selection_strategy);
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
auto OC = fm.ocpg;
bool available = //! src and filter are qint8, dst is qint8
fm.icpg < 4 && // must be nchw input
((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
(param.dst_type.enumv() == DTypeEnum::QuantizedS8))) &&
(fm.format == param::Convolution::Format::NCHW44) &&
(OC % 4 == 0 && OC >= 4) && !fm.should_flip && fm.group == 1 &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 &&
FH == fm.spatial[1] && (FH == 3 || FH == 5 || FH == 7) &&
fm.group == 1 && param.bias_mode != BiasMode::BIAS;
return available;
}

bool ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::is_preferred(
megdnn::fallback::ConvBiasImpl* conv_bias_impl_ptr,
const NCBKernSizeParam& param) const {
// TODO: benchmark and fix
MEGDNN_MARK_USED_VAR(conv_bias_impl_ptr);
MEGDNN_MARK_USED_VAR(param);
return false;
}

size_t ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::get_workspace(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoS8DirectStride2NCHWNCHW44::dispatch_kerns(
fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const {
auto fm = param.filter_meta;
size_t N = param.n;
size_t OC = fm.ocpg;
size_t group = fm.group;
WorkspaceBundle wbundle = get_bundle(param);
conv_fun do_conv_fun = nullptr;
// NOTE: remain_w is not used to generate the midout hash, for compatibility
// with shapes that change at runtime
#define DO_CONV_KERN_FUN(filter, bias_mode, op) \
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw_nchw44_stride2, \
midout_iv(#filter #bias_mode #op##_hash)) { \
do_conv_fun = do_conv_kern<filter, bias_mode, op>; \
} \
MIDOUT_END();

#define GET_OP_PARAM(filter, bias_mode) \
switch (param.nonlineMode) { \
case param::ConvBias::NonlineMode::IDENTITY: \
DO_CONV_KERN_FUN(filter, bias_mode, \
TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::RELU: \
DO_CONV_KERN_FUN(filter, bias_mode, \
ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
case param::ConvBias::NonlineMode::H_SWISH: \
DO_CONV_KERN_FUN(filter, bias_mode, \
HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
break; \
default: \
megdnn_assert(0); \
break; \
}
#define GET_BIAS_MODE_PARAM(filter) \
switch (param.bias_mode) { \
case BiasMode::NO_BIAS: \
GET_OP_PARAM(filter, BiasMode::NO_BIAS) \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
GET_OP_PARAM(filter, BiasMode::BROADCAST_CHANNEL_BIAS) \
break; \
default: \
megdnn_assert(0); \
break; \
}

#define DISPATCH_CONV_KERN() \
switch (param.filter_meta.spatial[0]) { \
case 3: \
GET_BIAS_MODE_PARAM(3) \
break; \
case 5: \
GET_BIAS_MODE_PARAM(5) \
break; \
case 7: \
GET_BIAS_MODE_PARAM(7) \
break; \
default: \
megdnn_assert(0); \
break; \
}

DISPATCH_CONV_KERN();

#undef DO_CONV_KERN_FUN
#undef GET_REMAIN_W_PARAM
#undef GET_OP_PARAM
#undef GET_BIAS_MODE_PARAM
#undef DISPATCH_CONV_KERN

megdnn_assert(do_conv_fun);

SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
WorkspaceBundle bundle = wbundle;

constexpr size_t pack_oc = 8;
size_t oc_step = pack_oc;
auto copy_padding = [bundle](const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
copy_padding_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id);
};
ret_kerns.push_back({copy_padding, {N, group, fm.icpg}});

CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)};
auto do_conv = [bundle, do_conv_fun, ncb_range](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id,
ncb_range);
};
ret_kerns.push_back({do_conv, ncb_range});

return ret_kerns;
}

// vim: syntax=cpp.doxygen

+0 -789  dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.cpp

@@ -1,789 +0,0 @@
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h"
#include "src/arm_common/conv_bias/intrinsic_helper.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/common.h"

using namespace megdnn;
using namespace arm_common;
namespace {

template <int src_idx, int weight_idx, int c_dim, typename Func, typename T,
typename T2, typename T3, typename T4>
struct ShiftCalHelper {
static void impl(T& c, T2& src, T3& weight, T4& temp);
static void impl(T& c, T2& src, T3& weight);
};
template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 2, Func, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight, T4& temp) {
c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0],
temp[0]);
c[1][0] = Func::impl(src[0 + src_idx], weight[1][weight_idx], c[1][0],
temp[1]);
c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1],
temp[2]);
c[1][1] = Func::impl(src[1 + src_idx], weight[1][weight_idx], c[1][1],
temp[3]);
c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2],
temp[0]);
c[1][2] = Func::impl(src[2 + src_idx], weight[1][weight_idx], c[1][2],
temp[1]);
c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3],
temp[2]);
c[1][3] = Func::impl(src[3 + src_idx], weight[1][weight_idx], c[1][3],
temp[3]);
}
static void impl(T& c, T2& src, T3& weight) {
c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0]);
c[1][0] = Func::impl(src[0 + src_idx], weight[1][weight_idx], c[1][0]);
c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1]);
c[1][1] = Func::impl(src[1 + src_idx], weight[1][weight_idx], c[1][1]);
c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2]);
c[1][2] = Func::impl(src[2 + src_idx], weight[1][weight_idx], c[1][2]);
c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3]);
c[1][3] = Func::impl(src[3 + src_idx], weight[1][weight_idx], c[1][3]);
}
};
template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
typename T3, typename T4>
struct ShiftCalHelper<src_idx, weight_idx, 1, Func, T, T2, T3, T4> {
static void impl(T& c, T2& src, T3& weight, T4& temp) {
c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0],
temp[0]);
c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1],
temp[2]);
c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2],
temp[0]);
c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3],
temp[2]);
}
static void impl(T& c, T2& src, T3& weight) {
c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0]);
c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1]);
c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2]);
c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3]);
}
};

template <int src_idx, int weight_idx, int c_dim, typename FUNC, typename T,
typename T2, typename T3, typename T4>
inline void cal_helper(T& c, T2& src, T3& weight, T4& temp) {
ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, T, T2, T3, T4>::impl(
c, src, weight, temp);
}
template <int src_idx, int weight_idx, int c_dim, typename FUNC, typename T,
typename T2, typename T3>
inline void cal_helper(T& c, T2& src, T3& weight) {
ShiftCalHelper<src_idx, weight_idx, c_dim, FUNC, T, T2, T3, int>::impl(
c, src, weight);
};

template <int oc>
struct OCHelper {
public:
static const int val = 0;
};
template <>
struct OCHelper<4> {
public:
static const int val = 1;
};
template <>
struct OCHelper<8> {
public:
static const int val = 2;
};

template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
int oc_block>
struct KerNeonXXs2NchwNchw44 {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op);
};
/**
* filter shape = (oc/4, ic, 7, 7, 4), first 4 oc is f0 = filter[0, 0, :, :, :]
* calculate sequence \
* f0[0:1, 0:1, 4] dot4, \
* f0[0:1, 2:3, 4] dot4, \
* f0[0:1, 4:5, 4] dot4, \
* f0[0:1, 6, 4] dot2, \
* ...
* f0[6, 0:1, 4] dot2, \
* f0[6, 2:3, 4] dot2, \
* f0[6, 4:5, 4] dot2, \
* f0[6, 6, 4] dot1, \
* look like:
* |---|---|---|-|
* |x x|x x|x x|x|
* |x x|x x|x x|x|
* |---|---|---|-|
* |x x|x x|x x|x|
* |x x|x x|x x|x|
* |---|---|---|-|
* |x x|x x|x x|x|
* |x x|x x|x x|x|
* |---|---|---|-|
* |x x|x x|x x|x|
* |---|---|---|-|
**/
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 7, oc_block> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8,
0, 8, 0, 8, 0, 8, 0, 8};
constexpr int filter_size = 7;
constexpr int ic_step = 1;
constexpr int oc_step = 4;
constexpr int pack_iw_len = 4;
constexpr int fh_step = 2;
constexpr int fh_end = filter_size / fh_step * fh_step;
constexpr int c_dim = OCHelper<oc_block>::val;

const int ic_stride = ih * iw * pack_iw_len;
const int ld_dot4_weight_oc = oc_step * filter_size * filter_size * ic;

int32x4_t c[c_dim][4];

init_ocx_ow4<c_dim, bias_mode>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
for (int fh_idx = 0; fh_idx < fh_end; fh_idx += fh_step) {
const int8_t* nchw_src_ptr =
src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;
int8x16_t src[6];
int8x16_t dot4_weight[c_dim][3];
int16x8_t temp_c[4];
load_helper<3, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr,
ld_dot4_weight_oc);
load_helper<6, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0);
cal_helper<0, 0, c_dim, Vdotq_s32_h>(c, src, dot4_weight,
temp_c);
cal_helper<1, 1, c_dim, Vdotq_s32_h>(c, src, dot4_weight,
temp_c);
cal_helper<2, 2, c_dim, Vdotq_s32_h>(c, src, dot4_weight,
temp_c);

int8x8_t src_dot2[4];
int8x8_t dot2_weight[c_dim][1];
load_helper<1, 3 * 16, 8, c_dim, Vld1_s8>(
dot2_weight, weight_ptr, ld_dot4_weight_oc);
load_helper<4, 3 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr,
0);
cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight,
temp_c);
weight_ptr += filter_size * pack_iw_len * fh_step;
}
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride +
6 * iw * ic_step * pack_iw_len;

int8x8_t dot2_weight[c_dim][3];
int16x8_t temp_c[4];
int8x8_t src_dot2[6];
uint8x16_t tbl = vld1q_u8(src_idx_buffer);
load_helper<3, 0, 8, c_dim, Vld1_s8>(dot2_weight, weight_ptr,
ld_dot4_weight_oc);
load_helper_x<6, 0, 16, 0, Vldq_tbl_low_s8>(src_dot2, nchw_src_ptr,
0, tbl);
cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight,
temp_c);
cal_helper<1, 1, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight,
temp_c);
cal_helper<2, 2, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight,
temp_c);

int16x8_t dot1_weight[c_dim][1];
int16x8_t src_dot1[4];
load_helper<1, 3 * 8, 8, c_dim, Vldq_dup_4s8_8s16>(
dot1_weight, weight_ptr, ld_dot4_weight_oc);
load_helper<4, 3 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1,
nchw_src_ptr, 0);
cal_helper<0, 0, c_dim, Vmlal_s16>(c, src_dot1, dot1_weight);
weight_ptr += filter_size * pack_iw_len;
}
store_ocx_ow4_remain_static<c_dim, remain_w>(c, op, dst_ptr, ld_dst_oc);
}
};
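The comment above decomposes the 7x7 filter into dot4/dot2/dot1 pieces; a one-line check that the decomposition covers all 49 taps, with the counts read off the diagram:

#include <cassert>

int main() {
    // 3x3 grid of 2x2 dot4 blocks over rows/cols {0..5}, dot2 strips along
    // column 6 and row 6, and a single dot1 at tap (6, 6).
    int dot4_taps = 3 * 3 * 4;
    int dot2_taps = 3 * 2 + 3 * 2;
    int dot1_taps = 1;
    assert(dot4_taps + dot2_taps + dot1_taps == 7 * 7);
    return 0;
}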
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 5, oc_block> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int filter_size = 5;
static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8,
0, 8, 0, 8, 0, 8, 0, 8};
constexpr int ih_step = 2;
constexpr int ic_step = 1;
constexpr int oc_step = 4;
constexpr int pack_iw_len = 4;
constexpr int fh_end = filter_size / ih_step * ih_step;

const int ic_stride = ih * iw * pack_iw_len;
const int ld_dot4_weight_oc = oc_step * filter_size * filter_size * ic;
constexpr int c_dim = OCHelper<oc_block>::val;
int32x4_t c[c_dim][4];

init_ocx_ow4<c_dim, bias_mode>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
for (int fh_idx = 0; fh_idx < fh_end; fh_idx += ih_step) {
const int8_t* nchw_src_ptr =
src_ptr + ic_idx * ic_stride +
fh_idx * iw * ic_step * pack_iw_len;
int8x16_t src[5];
int8x16_t dot4_weight[c_dim][2];
int16x8_t temp_c[4];
load_helper<2, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr,
ld_dot4_weight_oc);
load_helper<5, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0);
cal_helper<0, 0, c_dim, Vdotq_s32_h>(c, src, dot4_weight,
temp_c);
cal_helper<1, 1, c_dim, Vdotq_s32_h>(c, src, dot4_weight,
temp_c);

int8x8_t src_dot2[4];
int8x8_t dot2_weight[c_dim][1];
load_helper<1, 2 * 16, 8, c_dim, Vld1_s8>(
dot2_weight, weight_ptr, ld_dot4_weight_oc);
load_helper<4, 2 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr,
0);
cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight,
temp_c);
weight_ptr += filter_size * pack_iw_len * ih_step;
}
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride +
fh_end * iw * ic_step * pack_iw_len;

int8x8_t dot2_weight[c_dim][2];
int16x8_t temp_c[4];
int8x8_t src_dot2[5];
uint8x16_t tbl = vld1q_u8(src_idx_buffer);
load_helper<2, 0, 8, c_dim, Vld1_s8>(dot2_weight, weight_ptr,
ld_dot4_weight_oc);
load_helper_x<5, 0, 16, 0, Vldq_tbl_low_s8>(src_dot2, nchw_src_ptr,
0, tbl);

cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight,
temp_c);
cal_helper<1, 1, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight,
temp_c);

int16x8_t dot1_weight[c_dim][1];
int16x8_t src_dot1[4];
load_helper<1, 2 * 8, 8, c_dim, Vldq_dup_4s8_8s16>(
dot1_weight, weight_ptr, ld_dot4_weight_oc);
load_helper<4, 2 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1,
nchw_src_ptr, 0);

cal_helper<0, 0, c_dim, Vmlal_s16>(c, src_dot1, dot1_weight);
weight_ptr += filter_size * pack_iw_len;
}
store_ocx_ow4_remain_static<c_dim, remain_w>(c, op, dst_ptr, ld_dst_oc);
}
};
/**
* filter shape = (oc/4, ic, 3, 3, 4), first 4 oc is f0 = filter[0, 0, :, :, :]
* calculate sequence \
* f0[0:1, 0:1, 4] dot4, \
* f0[0:1, 2, 4] dot2, \
* f0[2, 0:1, 4] dot2, \
* f0[2, 2, 4] dot1 \
* look like:
* |---|-|
* |x x|x|
* |x x|x|
* |-----|
* |x x|x|
* |-----|
**/
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 3, oc_block> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int filter_size = 3;
static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8,
0, 8, 0, 8, 0, 8, 0, 8};
constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int loop_ic_step = 1;
constexpr int pack_iw_len = 4;

const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][4];
init_ocx_ow4<c_dim, bias_mode>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
// first 2 line
{
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
int8x16_t src[4];
int8x16_t dot4_weight[c_dim][1];
int16x8_t temp_c[4];
load_helper<1, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr,
ld_weight_oc);
load_helper<4, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0);
cal_helper<0, 0, c_dim, Vdotq_s32_h>(c, src, dot4_weight,
temp_c);

int8x8_t src_dot2[4];
int8x8_t dot2_weight[c_dim][1];
load_helper<1, 1 * 16, 8, c_dim, Vld1_s8>(
dot2_weight, weight_ptr, ld_weight_oc);
load_helper<4, 1 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr,
0);
cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight,
temp_c);
}
// last line
{
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride +
2 * iw * ic_step * pack_iw_len;
int16x8_t temp_c[4];
int8x8_t src_dot2[4];
int8x8_t dot2_weight[c_dim][1];
uint8x16_t tbl = vld1q_u8(src_idx_buffer);
load_helper<1, 24, 8, c_dim, Vld1_s8>(dot2_weight, weight_ptr,
ld_weight_oc);
load_helper_x<4, 0, 16, 0, Vldq_tbl_low_s8>(
src_dot2, nchw_src_ptr, 0, tbl);
cal_helper<0, 0, c_dim, Vdot2_s32_h>(c, src_dot2, dot2_weight,
temp_c);
int16x8_t dot1_weight[c_dim][1];
int16x8_t src_dot1[4];
load_helper<1, 32, 8, c_dim, Vldq_dup_4s8_8s16>(
dot1_weight, weight_ptr, ld_weight_oc);
load_helper<4, 1 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1,
nchw_src_ptr, 0);
cal_helper<0, 0, c_dim, Vmlal_s16>(c, src_dot1, dot1_weight);
weight_ptr += filter_size * filter_size * pack_iw_len;
}
}
store_ocx_ow4_remain_static<c_dim, remain_w>(c, op, dst_ptr, ld_dst_oc);
}
};

} // namespace
enum PACK_MODE { NO_PAD = 0, FIRST_PAD = 1, LAST_PAD = 2 };
template <PACK_MODE mode>
inline void pack_src_one_line(const int8_t* inptr, int8_t* outptr, int left_pad,
int right_pad, const int iw) {
const int8_t* src_row_0 = inptr;
const int8_t* src_row_1 = inptr + iw;
constexpr int combine_row = 2;
constexpr int iw_step = 16;
constexpr int src_expand = 4;
constexpr int out_gap = iw_step * src_expand;
const int iw_end = iw / iw_step * iw_step;

memset(outptr, 0, combine_row * left_pad * src_expand * sizeof(int8_t));
outptr += combine_row * left_pad * src_expand;

for (int iw_idx = 0; iw_idx < iw_end; iw_idx += iw_step) {
int8x16_t row0 = vld1q_s8(src_row_0 + iw_idx);
int8x16_t row1 = vdupq_n_s8(0);
if (mode == PACK_MODE::NO_PAD) {
row1 = vld1q_s8(src_row_1 + iw_idx);
} else if (mode == PACK_MODE::FIRST_PAD) {
row1 = row0;
row0 = vdupq_n_s8(0);
}
int8x16x2_t pack_rows = vzipq_s8(row0, row1);
#define STORE_8S8(step) \
vst1_s8(outptr + step * 8, \
vreinterpret_s8_s16(vdup_laneq_s16( \
vreinterpretq_s16_s8(pack_rows.val[0]), step)));

UNROLL_CALL_RAW(8, STORE_8S8);
#undef STORE_8S8
#define STORE_8S8(step) \
vst1_s8(outptr + out_gap + step * 8, \
vreinterpret_s8_s16(vdup_laneq_s16( \
vreinterpretq_s16_s8(pack_rows.val[1]), step)));

UNROLL_CALL_RAW(8, STORE_8S8);
#undef STORE_8S8
outptr += out_gap * combine_row;
}
for (int iw_idx = iw_end; iw_idx < iw; iw_idx++) {
int8x8_t row0 = vld1_dup_s8(src_row_0 + iw_idx);
int8x8_t row1 = vdup_n_s8(0);
if (mode == PACK_MODE::NO_PAD) {
row1 = vld1_dup_s8(src_row_1 + iw_idx);
} else if (mode == PACK_MODE::FIRST_PAD) {
row1 = row0;
row0 = vdup_n_s8(0);
}
int8x8x2_t pack_rows = vzip_s8(row0, row1);
vst1_s8(outptr, pack_rows.val[0]);
outptr += src_expand * combine_row;
}
memset(outptr, 0, combine_row * right_pad * src_expand * sizeof(int8_t));
outptr += combine_row * right_pad * src_expand;
}
/**
* pack (ic, h, w) to (ic, h / 2, 2 * w):
* interleave two adjacent rows of src and repeat each pair 4 times, storing
* them to one output row
* */
void conv_bias::pack_nchw_src_for_nchw44_conv(
const int8_t* inptr, int8_t* outptr, const int ic, const int top_pad,
const int bottom_pad, const int left_pad, const int right_pad,
const int ih, const int iw) {
constexpr int src_expand = 4;
constexpr int oh_step = 2;
const int oh = ih + top_pad + bottom_pad;
const int oh_end = div_floor(ih + top_pad, oh_step) * oh_step;
const int ow = (iw + left_pad + right_pad) * src_expand;

for (int ic_idx = 0; ic_idx < ic; ++ic_idx) {
int oh_idx = 0;
for (; oh_idx < top_pad; oh_idx += oh_step) {
if (top_pad - oh_idx >= oh_step) {
memset(outptr, 0, oh_step * ow * sizeof(int8_t));
} else {
pack_src_one_line<PACK_MODE::FIRST_PAD>(inptr, outptr, left_pad,
right_pad, iw);
inptr += iw;
}
outptr += oh_step * ow;
}

for (; oh_idx < oh_end; oh_idx += oh_step) {
pack_src_one_line<PACK_MODE::NO_PAD>(inptr, outptr, left_pad,
right_pad, iw);
inptr += oh_step * iw;
outptr += oh_step * ow;
}

for (; oh_idx < oh; oh_idx += oh_step) {
const int last_pad = oh_idx - ih - top_pad;
if (last_pad >= 0) {
memset(outptr, 0, oh_step * ow * sizeof(int8_t));
} else {
pack_src_one_line<PACK_MODE::LAST_PAD>(inptr, outptr, left_pad,
right_pad, iw);
inptr += iw;
}
outptr += oh_step * ow;
}
}
}
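As a readability aid, here is a scalar model of what pack_src_one_line does per pixel column: zip two adjacent rows and repeat each (row0, row1) pair 4 times (src_expand = 4); the NEON version above does this 16 columns at a time. All concrete values below are hypothetical:

#include <cstdint>
#include <cstdio>

// Scalar model of the stride-2 source packing for one pair of rows.
void pack_two_rows_scalar(const int8_t* row0, const int8_t* row1,
                          int8_t* out, int iw) {
    for (int x = 0; x < iw; ++x) {
        for (int rep = 0; rep < 4; ++rep) {  // src_expand = 4
            *out++ = row0[x];
            *out++ = row1[x];
        }
    }
}

int main() {
    int8_t r0[4] = {1, 2, 3, 4}, r1[4] = {5, 6, 7, 8};
    int8_t out[4 * 8];
    pack_two_rows_scalar(r0, r1, out, 4);
    for (int i = 0; i < 16; ++i) std::printf("%d ", out[i]);
    std::printf("\n");  // 1 5 1 5 1 5 1 5 2 6 2 6 2 6 2 6
    return 0;
}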

/**
* pack {oc / 4, fh, fw, ic, 4(oc)} to {oc / 4, ic, fh * fw, 4(oc)}:
* interleave two adjacent filter rows into one output row
* */
void conv_bias::pack_nchw44_weight_for_nchw_conv(const int8_t* inptr,
int8_t* outptr, const int ic,
const int fh, const int fw,
const int oc) {
constexpr int oc_step = 4;
constexpr int ic_step = 2;
constexpr int fh_step = 2;
constexpr int fw_step = 2;
const int ic_end = ic / ic_step * ic_step;
const int ic_remain = ic - ic_end;
const int fh_end = fh / fh_step * fh_step;
const int fh_remain = fh - fh_end;
const int fw_end = fw / fw_step * fw_step;
const int fw_remain = fw - fw_end;
const int filter_stride = ic * oc_step;
static const uint8_t ic2_idx_h_buffer[16] = {0, 8, 1, 9, 2, 10, 3, 11,
4, 12, 5, 13, 6, 14, 7, 15};
uint8x16_t ic2_idx_h = vld1q_u8(ic2_idx_h_buffer);
for (int oc_idx = 0; oc_idx < oc; oc_idx += oc_step) {
for (int ic_idx = 0; ic_idx < ic_end; ic_idx += ic_step) {
const int ic_offset = ic_idx * oc_step;
int8_t* output_ic0 = outptr + ic_idx * fh * fw * oc_step;
int8_t* output_ic1 = output_ic0 + fh * fw * oc_step;
for (int fh_idx = 0; fh_idx < fh_end; fh_idx += fh_step) {
const int fh_offset = fh_idx * fw * filter_stride;
for (int fw_idx = 0; fw_idx < fw; ++fw_idx) {
const int8_t* filter_ptr = inptr + fh_offset +
fw_idx * filter_stride +
ic_offset;
int8x8_t row_0 = vld1_s8(filter_ptr);
int8x8_t row_1 = vld1_s8(filter_ptr + fw * filter_stride);
int8x16_t combine_row = vcombine_s8(row_0, row_1);
combine_row = vqtbl1q_s8(combine_row, ic2_idx_h);
vst1_s8(output_ic0, vget_low_s8(combine_row));
vst1_s8(output_ic1, vget_high_s8(combine_row));
output_ic0 += 8;
output_ic1 += 8;
}
}
if (fh_remain > 0) {
const int fh_offset = fh_end * fw * filter_stride;
for (int fw_idx = 0; fw_idx < fw_end; fw_idx += fw_step) {
const int8_t* filter_ptr = inptr + fh_offset +
fw_idx * filter_stride +
ic_offset;
int8x8_t row_0 = vld1_s8(filter_ptr);
int8x8_t row_1 = vld1_s8(filter_ptr + filter_stride);
int8x16_t combine_row = vcombine_s8(row_0, row_1);
combine_row = vqtbl1q_s8(combine_row, ic2_idx_h);
vst1_s8(output_ic0, vget_low_s8(combine_row));
vst1_s8(output_ic1, vget_high_s8(combine_row));
output_ic0 += 8;
output_ic1 += 8;
}
if (fw_remain > 0) {
const int8_t* filter_ptr = inptr + fh_offset +
fw_end * filter_stride +
ic_offset;
int8x8_t row_0 = vld1_s8(filter_ptr);
vst1_lane_s32((int32_t*)output_ic0,
vreinterpret_s32_s8(row_0), 0);
vst1_lane_s32((int32_t*)output_ic1,
vreinterpret_s32_s8(row_0), 1);
output_ic0 += 4;
output_ic1 += 4;
}
}
}
if (ic_remain > 0) {
const int ic_offset = ic_end * oc_step;
int8_t* output_ic0 = outptr + ic_end * fh * fw * oc_step;
for (int fh_idx = 0; fh_idx < fh_end; fh_idx += fh_step) {
const int fh_offset = fh_idx * fw * filter_stride;
for (int fw_idx = 0; fw_idx < fw; ++fw_idx) {
const int8_t* filter_ptr = inptr + fh_offset +
fw_idx * filter_stride +
ic_offset;
int8x8_t row_0 = vreinterpret_s8_s32(
vld1_dup_s32((const int32_t*)(filter_ptr)));
int8x8_t row_1 = vreinterpret_s8_s32(vld1_dup_s32(
(const int32_t*)(filter_ptr + fw * filter_stride)));
int8x16_t combine_row = vcombine_s8(row_0, row_1);
combine_row = vqtbl1q_s8(combine_row, ic2_idx_h);
vst1_s8(output_ic0, vget_low_s8(combine_row));
output_ic0 += 8;
}
}
if (fh_remain > 0) {
const int fh_offset = fh_end * fw * filter_stride;
for (int fw_idx = 0; fw_idx < fw_end; fw_idx += fw_step) {
const int8_t* filter_ptr = inptr + fh_offset +
fw_idx * filter_stride +
ic_offset;
int8x8_t row_0 = vreinterpret_s8_s32(
vld1_dup_s32((const int32_t*)(filter_ptr)));
int8x8_t row_1 = vreinterpret_s8_s32(vld1_dup_s32(
(const int32_t*)(filter_ptr + filter_stride)));
int8x16_t combine_row = vcombine_s8(row_0, row_1);
combine_row = vqtbl1q_s8(combine_row, ic2_idx_h);
vst1_s8(output_ic0, vget_low_s8(combine_row));
output_ic0 += 8;
}
if (fw_remain > 0) {
const int8_t* filter_ptr = inptr + fh_offset +
fw_end * filter_stride +
ic_offset;
*(int32_t*)(output_ic0) = *(const int32_t*)(filter_ptr);
output_ic0 += 4;
}
}
}
inptr += oc_step * fh * fw * ic;
outptr += oc_step * fh * fw * ic;
}
}

template <BiasMode bias_mode, typename Op, size_t filter_size>
static void conv_direct_stride2_int8_nchw_nchw44(
const int8_t* src, const int8_t* filter, const int32_t* bias,
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic,
const size_t ih, const size_t iw, const size_t oh, const size_t ow,
const Op& op) {
MEGDNN_MARK_USED_VAR(temp);
constexpr size_t fh = filter_size;
constexpr size_t fw = filter_size;
constexpr size_t ic_step = 1;
constexpr size_t big_oc_step = 8;
constexpr size_t oc_step = 4;
constexpr size_t ih_step = 2;
constexpr size_t oh_step = 1;
constexpr size_t ow_step = 4;
constexpr size_t stride_h = 2;
constexpr size_t stride_w = 2;
constexpr int pack_iw_len = 4;

const size_t img_stride = oh * ow;
const size_t ow_end = ow / ow_step * ow_step;
const size_t ow_remain = ow - ow_end;
const size_t oc_end = oc / big_oc_step * big_oc_step;
const size_t oc_remain = oc - oc_end;
const int ld_dst_oc = oc_step * img_stride;

using remain_fun =
std::function<void(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op)>;
remain_fun kern_big_oc_remain = nullptr;
remain_fun kern_small_oc_remain = nullptr;
switch (ow_remain) {
#define cb(step) \
case step: \
kern_big_oc_remain = \
KerNeonXXs2NchwNchw44<bias_mode, Op, step, filter_size, \
big_oc_step>::impl; \
kern_small_oc_remain = \
KerNeonXXs2NchwNchw44<bias_mode, Op, step, filter_size, \
oc_step>::impl; \
break;

UNROLL_CALL_RAW(4, cb);
default:
megdnn_assert(0, "no remain %zu for kern", ow_remain);
}

for (size_t oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonXXs2NchwNchw44<bias_mode, Op, 0, filter_size,
big_oc_step>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic,
ih, iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_big_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih, iw,
ld_dst_oc, op);
}
}
}
if (oc_remain > 0) {
size_t oc_idx = oc_end;
const size_t weight_offset = oc_idx * ic * fh * fw;
for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) {
for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
KerNeonXXs2NchwNchw44<bias_mode, Op, 0, filter_size,
oc_step>::impl(src + src_offset,
filter + weight_offset,
bias + oc_idx,
dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
if (ow_remain > 0) {
const size_t src_offset =
(oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
ic_step * pack_iw_len;
const size_t dst_offset =
oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
kern_small_oc_remain(src + src_offset, filter + weight_offset,
bias + oc_idx, dst + dst_offset, ic, ih,
iw, ld_dst_oc, op);
}
}
}
}
#define CONSTRUCT_FUNC(filter_size) \
template <BiasMode bias_mode, typename Op> \
void conv_bias:: \
conv_direct_stride2_##filter_size##x##filter_size##_int8_nchw_nchw44( \
const int8_t* src, const int8_t* filter, \
const int32_t* bias, int32_t* temp, int8_t* dst, \
const size_t oc, const size_t ic, const size_t ih, \
const size_t iw, const size_t oh, const size_t ow, \
const Op& op) { \
conv_direct_stride2_int8_nchw_nchw44<bias_mode, Op, filter_size>( \
src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); \
}

CONSTRUCT_FUNC(3);
CONSTRUCT_FUNC(5);
CONSTRUCT_FUNC(7);
#undef CONSTRUCT_FUNC

template <BiasMode bias_mode, typename Op>
void conv_bias::conv_direct_stride2_2x2_int8_nchw_nchw44(
const int8_t* src, const int8_t* filter, const int32_t* bias,
int32_t* temp, int8_t* dst, const size_t oc, const size_t ic,
const size_t ih, const size_t iw, const size_t oh, const size_t ow,
const Op& op) {
MEGDNN_MARK_USED_VAR(src);
MEGDNN_MARK_USED_VAR(filter);
MEGDNN_MARK_USED_VAR(bias);
MEGDNN_MARK_USED_VAR(temp);
MEGDNN_MARK_USED_VAR(dst);
MEGDNN_MARK_USED_VAR(oc);
MEGDNN_MARK_USED_VAR(ic);
MEGDNN_MARK_USED_VAR(ih);
MEGDNN_MARK_USED_VAR(iw);
MEGDNN_MARK_USED_VAR(oh);
MEGDNN_MARK_USED_VAR(ow);
MEGDNN_MARK_USED_VAR(op);
megdnn_assert(0, "not imple nchw_nchw44 2x2s2 conv");
}

#define INSTANTIATION(stride, i, bias, Op) \
template void conv_bias:: \
conv_direct_##stride##_##i##x##i##_int8_nchw_nchw44<bias, Op>( \
const int8_t*, const int8_t*, const int32_t*, int32_t*, \
int8_t*, const size_t, const size_t, const size_t, \
const size_t, const size_t, const size_t, const Op&);

#define FOR_OP(stride, i, bias) \
INSTANTIATION(stride, i, bias, TypeCvtOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(stride, i, bias, ReluOp<dt_qint32 MEGDNN_COMMA dt_qint8>) \
INSTANTIATION(stride, i, bias, HSwishOp<dt_qint32 MEGDNN_COMMA dt_qint8>)

#define FOR_BIAS(stride, i) \
FOR_OP(stride, i, BiasMode::NO_BIAS) \
FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS)

#define FOR_FILTER(stride) \
FOR_BIAS(stride, 2) \
FOR_BIAS(stride, 3) \
FOR_BIAS(stride, 5) \
FOR_BIAS(stride, 7)

FOR_FILTER(stride2)

#undef FOR_STRIDE
#undef FOR_FILTER
#undef FOR_IC
#undef FOR_BIAS
#undef FOR_NONLINEAR
#undef FOR_REMAIN
#undef INSTANTIATION

+0 -44  dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h

@@ -1,44 +0,0 @@
/**
* \file dnn/src/arm_common/conv_bias/int8/direct_stride2_nchw_nchw44_kern.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/common.h"

namespace megdnn {
namespace arm_common {
namespace conv_bias {
#define KERN(stride, i, layout) \
template <BiasMode bias_mode, typename Op> \
void conv_direct_##stride##_##i##x##i##_int8_nchw_##layout( \
const int8_t* src, const int8_t* filter, const int32_t* bias, \
int32_t* temp, int8_t* dst, const size_t OC, const size_t IC, \
const size_t IH, const size_t IW, const size_t OH, \
const size_t OW, const Op& op);

KERN(stride2, 2, nchw44)
KERN(stride2, 3, nchw44)
KERN(stride2, 5, nchw44)
KERN(stride2, 7, nchw44)
#undef KERN

void pack_nchw44_weight_for_nchw_conv(const int8_t* inptr, int8_t* outptr,
const int ic, const int fh, const int fw,
const int oc);

void pack_nchw_src_for_nchw44_conv(const int8_t* inptr, int8_t* outptr,
const int ic, const int top_pad,
const int bottom_pad, const int left_pad,
const int right_pad, const int ih,
const int iw);
} // namespace conv_bias
} // namespace arm_common
} // namespace megdnn

+2 -2  dnn/src/arm_common/conv_bias/opr_impl.cpp

@@ -47,7 +47,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoS8DirectStride2 s8_direct_stride2_large_group{true};
AlgoS8DirectStride2 s8_direct_stride2_small_group{false};
AlgoS8DirectStride2NCHW44 s8_direct_stride2_nchw44;
- AlgoS8DirectStride2NCHWNCHW44 s8_direct_stride2_nchw_nchw44;
+ AlgoS8DirectNCHWNCHW44 s8_direct_nchw_nchw44;
AlgoS8DirectStride1 s8_direct_stride1_large_group{true};
AlgoS8DirectStride1 s8_direct_stride1_small_group{false};
AlgoS8DirectStride1NCHW44 s8_direct_stride1_nchw44;
@@ -115,7 +115,7 @@ public:
direct_algos.emplace_back(&s8_direct_stride2_large_group);
direct_algos.emplace_back(&s8_direct_stride2_small_group);
direct_algos.emplace_back(&s8_direct_stride2_nchw44);
- direct_algos.emplace_back(&s8_direct_stride2_nchw_nchw44);
+ direct_algos.emplace_back(&s8_direct_nchw_nchw44);
direct_algos.emplace_back(&s8_direct_stride1_large_group);
direct_algos.emplace_back(&s8_direct_stride1_small_group);
direct_algos.emplace_back(&s8_direct_stride1_nchw44);


+1 -1  dnn/src/arm_common/conv_bias/opr_impl.h

@@ -40,7 +40,7 @@ private:
class AlgoS8DirectStride1NCHW44;
class AlgoS8DirectStride2;
class AlgoS8DirectStride2NCHW44;
- class AlgoS8DirectStride2NCHWNCHW44;
+ class AlgoS8DirectNCHWNCHW44;
class AlgoQU8DirectStride1;
class AlgoQU8DirectStride2;
class AlgoFP32WinogradF23_4x4;


+8 -0  dnn/test/arm_common/conv_bias.cpp

@@ -244,18 +244,26 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_NCHW44) {
#if MEGDNN_AARCH64
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
"IM2COLMATMUL:AARCH64_F32K8X12X1:192", true);
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
"IM2COLMATMUL:AARCH64_F32K8X12X1:192", false);
#else
benchmark_convbias(handle(), "IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8:384",
"IM2COLMATMUL:ARMV7_F32:192", true);
benchmark_convbias(handle(), "IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8:384",
"IM2COLMATMUL:ARMV7_F32:192", false);
#endif
}
TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_CONVBIAS_NCHW44) {
#if MEGDNN_AARCH64
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
"IM2COLMATMUL:AARCH64_F32K8X12X1:192", true);
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
"IM2COLMATMUL:AARCH64_F32K8X12X1:192", false);
#else
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
"IM2COLMATMUL:ARMV7_F32:192", true);
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
"IM2COLMATMUL:ARMV7_F32:192", false);
#endif
}



+6 -1  dnn/test/arm_common/conv_bias_multi_thread.cpp

@@ -541,7 +541,12 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT2_NCHW44) {

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44) {
checker_conv_bias_qint8x8x8(
- get_nchw44_conv_bias_args({3, 5, 7}, 2, false, false, false, true),
+ get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false,
+                           true),
  handle(), "S8_CONV_NCHW_NCHW44");
+ checker_conv_bias_qint8x8x8(
+         get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false,
+                                   true),
+         handle(), "S8_CONV_NCHW_NCHW44");
}


