feat(dnn/arm): nchw_nchw44 conv support 1x1s1

GitOrigin-RevId: 8c8f7d7c76
release-1.5
Megvii Engine Team (huangxinda), 4 years ago
parent commit 3597a6dbd7
10 changed files with 146 additions and 7 deletions
  1. dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp  (+47, −0)
  2. dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp  (+12, −0)
  3. dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp  (+42, −0)
  4. dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s2.cpp  (+10, −0)
  5. dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp  (+3, −0)
  6. dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp  (+3, −0)
  7. dnn/src/common/nchw_nchwxx_valid.h  (+10, −6)
  8. dnn/test/arm_common/conv_bias_multi_thread.cpp  (+16, −0)
  9. dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp  (+1, −0)
  10. src/plugin/impl/opr_footprint.cpp  (+2, −1)

dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp  (+47, −0)

@@ -47,6 +47,52 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 1, T, T2, T3, T4> {
}
};
////////////////////stride 1///////////////////

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 1, oc_block, ow_block,
1> {
MEGDNN_ATTRIBUTE_TARGET("dotprod")
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_hight = 1;
constexpr int filter_width = 4;
constexpr int weight_reg = 2;
constexpr int src_reg = 2;

constexpr int oc_step = 4;
constexpr int ic_step = 1;
constexpr int pack_iw_len = 4;
constexpr int simd_len = 16;

const int ld_bias = oc_step;
const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;

int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias);
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) {
int8x16_t src[src_reg];
int8x16_t weight[c_dim][weight_reg];
// row 0
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, src_ptr + 0 * iw * pack_iw_len, 0);
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src,
weight);

src_ptr += ic_stride;
weight_ptr += filter_hight * filter_width * oc_step;
}
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};
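
For reference, the computation this new specialization covers is just a 1x1, stride-1 convolution that reads an NCHW int8 source and writes NCHW44 (channels packed by 4) output; the SIMD kernel above only changes how that work is tiled and vectorized. A minimal scalar sketch of the same semantics, with made-up names and a plain dense weight matrix instead of the packed src/weight layouts (pack_iw_len etc.) the kernel actually consumes:

#include <cstdint>

// Scalar reference for a 1x1, stride-1 conv: NCHW int8 input, int32
// accumulators written in NCHW44 order (oc grouped by 4). Illustrative
// only; the real kernel vectorizes this with dot-product instructions.
void conv1x1s1_nchw_to_nchw44_ref(const int8_t* src, const int8_t* weight,
                                  const int32_t* bias, int32_t* dst,
                                  int ic, int oc, int ih, int iw) {
    const int hw = ih * iw;
    for (int oc_idx = 0; oc_idx < oc; ++oc_idx) {
        for (int p = 0; p < hw; ++p) {
            int32_t acc = bias ? bias[oc_idx] : 0;
            for (int ic_idx = 0; ic_idx < ic; ++ic_idx) {
                // A 1x1 filter makes every output pixel a dot product over ic.
                acc += int32_t(src[ic_idx * hw + p]) *
                       int32_t(weight[oc_idx * ic + ic_idx]);
            }
            // NCHW44 indexing: (oc/4, h, w, oc%4).
            dst[(oc_idx / 4) * hw * 4 + p * 4 + (oc_idx % 4)] = acc;
        }
    }
}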

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block,
@@ -441,6 +487,7 @@ void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter,
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)

#define DISPATCH_CONV_KERN(stride) \
GET_BIAS_MODE_PARAM(stride, 1) \
GET_BIAS_MODE_PARAM(stride, 2) \
GET_BIAS_MODE_PARAM(stride, 3) \
GET_BIAS_MODE_PARAM(stride, 5) \


dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp  (+12, −0)

@@ -60,6 +60,17 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 2, T, T2, T3, T4> {

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 1, oc_block, ow_block,
2> {
MEGDNN_ATTRIBUTE_TARGET("dotprod")
static void impl(const int8_t*, const int8_t*, const int32_t*, int8_t*, int,
int, int, int, const Op&) {
megdnn_assert(0, "not impl");
}
};
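
The stride-2 counterpart above is deliberately a stub: the dispatch macros instantiate a kernel for every (filter, stride) pair they name, so a definition has to exist, but the validity check in nchw_nchwxx_valid.h (further down in this commit) only admits 1x1 filters with stride 1, so this body should be unreachable. A stripped-down sketch of that pattern, with hypothetical names:

#include <cassert>

// Primary template: one struct per (filter, stride) combination.
template <int filter, int stride>
struct Kern;

// Supported combination: 1x1 with stride 1.
template <>
struct Kern<1, 1> {
    static void run() { /* real 1x1 stride-1 kernel body */ }
};

// Placeholder: 1x1 with any other stride is rejected before dispatch,
// so this trap mirrors the megdnn_assert(0, "not impl") stub above.
template <int stride>
struct Kern<1, stride> {
    static void run() { assert(0 && "1x1 only supported with stride 1"); }
};

int main() {
    Kern<1, 1>::run();  // the only combination that should ever be reached
    return 0;
}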

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int ow_block>
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block,
2> {
MEGDNN_ATTRIBUTE_TARGET("dotprod")
@@ -429,6 +440,7 @@ void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter,
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)

#define DISPATCH_CONV_KERN(stride) \
GET_BIAS_MODE_PARAM(stride, 1) \
GET_BIAS_MODE_PARAM(stride, 2) \
GET_BIAS_MODE_PARAM(stride, 3) \
GET_BIAS_MODE_PARAM(stride, 5) \


dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp  (+42, −0)

@@ -113,6 +113,47 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, 1, T, T2, T3, T4> {
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 1, oc_block, 1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
int iw, int ld_dst_oc, const Op& op) {
constexpr int stride = 1;
constexpr int filter_height = 1;
constexpr int filter_width = 4;
constexpr int oc_step = 4;
constexpr int loop_ic_step = 1;
constexpr int simd_len = 16;
constexpr int pack_iw_len = 16;
constexpr int src_reg = 8;
constexpr int weight_reg = 1;

const int ic_stride = ih * iw * pack_iw_len;
const int ld_weight_oc = oc_step * filter_height * filter_width * ic;
constexpr int c_dim = OCHelper<oc_block>::val;
int32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step);

for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride;
int8x16_t src[src_reg];
int8x16_t dot4_weight[c_dim][weight_reg];
int16x8_t temp_c[4];
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>(
dot4_weight, weight_ptr, ld_weight_oc);
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>(
src, nchw_src_ptr + 0 * iw * pack_iw_len, 0);
cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c);

weight_ptr += oc_step * filter_height * filter_width;
}

store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>(
c, op, dst_ptr, ld_dst_oc);
}
};
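
Unlike the dot-product path, this non-dot kernel has to widen int8 products through 16-bit intermediates (the temp_c registers) before they land in the int32 accumulators. A hedged, minimal NEON illustration of that widening multiply-accumulate step, using the standard vmull_s8/vpadalq_s16 intrinsics rather than the exact code cal_helper expands to:

#include <arm_neon.h>

// Widening int8 multiply-accumulate (requires an ARM NEON target):
// multiply two int8x8_t vectors into int16x8_t, then pairwise-add the
// 16-bit products into an int32x4_t accumulator.
inline int32x4_t mla_s8_widen(int32x4_t acc, int8x8_t a, int8x8_t b) {
    int16x8_t prod = vmull_s8(a, b);  // 8 lanes of int16 products
    return vpadalq_s16(acc, prod);    // fold pairs into 4 int32 lanes
}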


template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 2, oc_block, 1> {
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
@@ -547,6 +588,7 @@ struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, 1> {
INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)

#define INSTANCE_CONV_KERN(stride) \
INSTANCE_BIAS_MODE_PARAM(stride, 1) \
INSTANCE_BIAS_MODE_PARAM(stride, 2) \
INSTANCE_BIAS_MODE_PARAM(stride, 3) \
INSTANCE_BIAS_MODE_PARAM(stride, 5) \


dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s2.cpp  (+10, −0)

@@ -1033,6 +1033,15 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 2, oc_block, stride> {
}
};

template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
int stride>
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 1, oc_block, stride> {
static void impl(const int8_t*, const int8_t*, const int32_t*, int8_t*, int,
int, int, int, const Op&) {
megdnn_assert(0, "not impl nchw_nchw44 1x1 s2");
}
};

enum PACK_MODE { NO_PAD = 0, FIRST_PAD = 1, LAST_PAD = 2 };
template <PACK_MODE mode>
MEGDNN_ALWAYS_INLINE void pack_src_one_line(const int8_t* inptr, int8_t* outptr,
@@ -1398,6 +1407,7 @@ struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, 2> {
INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS)

#define INSTANCE_CONV_KERN(stride) \
INSTANCE_BIAS_MODE_PARAM(stride, 1) \
INSTANCE_BIAS_MODE_PARAM(stride, 2) \
INSTANCE_BIAS_MODE_PARAM(stride, 3) \
INSTANCE_BIAS_MODE_PARAM(stride, 5) \


dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp  (+3, −0)

@@ -291,6 +291,9 @@ ConvBiasImpl::AlgoS8DirectNCHWNCHW44::dispatch_kerns(

#define DISPATCH_CONV_KERN(stride) \
switch (param.filter_meta.spatial[0]) { \
case 1: \
GET_BIAS_MODE_PARAM(stride, 1) \
break; \
case 2: \
GET_BIAS_MODE_PARAM(stride, 2) \
break; \


dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp  (+3, −0)

@@ -245,6 +245,9 @@ ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::dispatch_kerns(

#define DISPATCH_CONV_KERN(stride) \
switch (param.filter_meta.spatial[0]) { \
case 1: \
GET_BIAS_MODE_PARAM(stride, 1) \
break; \
case 2: \
GET_BIAS_MODE_PARAM(stride, 2) \
break; \
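
Both algo files gain the same case 1 branch, so a filter spatial size of 1 now reaches the new kernels through the existing runtime switch over filter_meta.spatial[0]. A stand-alone sketch of that dispatch pattern, with hypothetical names:

#include <cstdio>

// Hypothetical templated kernel; the filter size is a compile-time parameter.
template <int filter_size>
void run_kernel() {
    std::printf("running the %dx%d kernel\n", filter_size, filter_size);
}

// Runtime filter size picks the matching instantiation, mirroring the
// switch that DISPATCH_CONV_KERN expands to above.
void dispatch(int filter_size) {
    switch (filter_size) {
        case 1: run_kernel<1>(); break;  // newly added in this commit
        case 2: run_kernel<2>(); break;
        case 3: run_kernel<3>(); break;
        case 5: run_kernel<5>(); break;
        default: break;  // remaining sizes elided; invalid shapes are rejected earlier
    }
}

int main() {
    dispatch(1);
    return 0;
}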


dnn/src/common/nchw_nchwxx_valid.h  (+10, −6)

@@ -74,9 +74,11 @@ inline bool nchw_nchwxx_valid<NCHW44_INT8>(
nonline_mode == param::ConvBias::NonlineMode::H_SWISH;
bool ok_src_dst =
fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1;
- bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
- (fm.spatial[0] == 2 || fm.spatial[0] == 3 ||
- fm.spatial[0] == 5 || fm.spatial[0] == 7);
+ bool ok_filter =
+ fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
+ (fm.spatial[0] == 2 || fm.spatial[0] == 3 || fm.spatial[0] == 5 ||
+ fm.spatial[0] == 7 ||
+ (fm.spatial[0] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1));
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == fm.stride[1] &&
(fm.stride[0] == 1 || fm.stride[1] == 2);
@@ -126,9 +128,11 @@ inline bool nchw_nchwxx_valid<NCHW44_INT8_DOT>(
nonline_mode == param::ConvBias::NonlineMode::H_SWISH;
bool ok_src_dst =
fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1;
- bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
- (fm.spatial[0] == 2 || fm.spatial[0] == 3 ||
- fm.spatial[0] == 5 || fm.spatial[0] == 7);
+ bool ok_filter =
+ fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
+ (fm.spatial[0] == 2 || fm.spatial[0] == 3 || fm.spatial[0] == 5 ||
+ fm.spatial[0] == 7 ||
+ (fm.spatial[0] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1));
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == fm.stride[1] &&
(fm.stride[0] == 1 || fm.stride[1] == 2);
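
This updated ok_filter term is what actually gates the new kernels: square 2/3/5/7 filters are still accepted for both strides, and a 1x1 filter is accepted only when both strides are 1. A compact restatement of the predicate for readability; it is a paraphrase, not the header's exact code, and it assumes the square-filter and equal-stride checks (fm.spatial[0] == fm.spatial[1], fm.stride[0] == fm.stride[1]) have already passed as in ok_filter/ok_slide above:

// Paraphrased filter/stride acceptance for the NCHW -> NCHW44 int8 paths.
inline bool filter_stride_ok(int filter, int stride) {
    const bool existing_sizes =
            filter == 2 || filter == 3 || filter == 5 || filter == 7;
    const bool pointwise_s1 = filter == 1 && stride == 1;  // new in this commit
    return (existing_sizes || pointwise_s1) && (stride == 1 || stride == 2);
}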


dnn/test/arm_common/conv_bias_multi_thread.cpp  (+16, −0)

@@ -487,6 +487,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S2) {
handle(), "S8_CONV_NCHW_NCHW44");
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S1_F1) {
checker_conv_bias_qint8x8x8(
get_nchw44_conv_bias_args({1}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1,
false, true),
handle(), "S8_CONV_NCHW_NCHW44");
}

/*****************************quint8 direct****************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1) {
checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
@@ -517,6 +524,15 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) {
checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44_S1_F1) {
auto args = get_nchw44_conv_bias_args({1}, QUAN_NLMODE, BR_AND_NO_BIASMODE,
1, false, true);
for (auto&& arg : args) {
arg.param.format = param::ConvBias::Format::NCHW44_DOT;
}
checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) {
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
{2, 3, 5, 7}, 1, false, false, false),


dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp  (+1, −0)

@@ -635,6 +635,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44) {
benchmark_impl(param, shape_arg, ".+", RUNS, {4, {4, 5, 6, 7}},
{1, {7}}, data_type);
};
+ bench_case(1, 2, 64, 160, 160, 1, 1, 0, 1, true);
bench_case(1, 3, 64, 224, 224, 7, 1, 3, 2, true);
bench_case(1, 64, 64, 56, 56, 3, 1, 1, 1);
bench_case(1, 128, 128, 28, 28, 3, 1, 1, 1);


src/plugin/impl/opr_footprint.cpp  (+2, −1)

@@ -131,7 +131,8 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
if (param.format == Param::Format::NCHW44 ||
param.format == Param::Format::NCHW44_DOT) {
//! if channel wise weight layout is {group/4, FH, FW, 1, 1, 4}
- if (filter_shape[1] == 1 && filter_shape[2] == 1) {
+ if (filter_shape[1] == 1 && filter_shape[2] == 1 &&
+ filter_shape.ndim == 6) {
group *= 4;
}
size_t computation = dst_shape.total_nr_elems() * fh * fw *
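
The extra filter_shape.ndim == 6 test keeps the group *= 4 correction tied to the channel-wise NCHW44 weight layout ({group/4, FH, FW, 1, 1, 4}); without it, the new hybrid 1x1 filters, which also have FH == FW == 1 but a different rank, would inflate the group count and the reported FLOPs. A small worked example of the footprint arithmetic; the exact expression is truncated above, so this assumes the usual "2 ops per multiply-add" convention and hypothetical shapes:

#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical dense 1x1 conv: N=1, IC=64, OC=128, OH=OW=56, group=1.
    const uint64_t n = 1, ic = 64, oc = 128, oh = 56, ow = 56, fh = 1, fw = 1;
    const uint64_t dst_elems = n * oc * oh * ow;
    // Roughly dst elements * fh * fw * (ic per group) * 2 (multiply + add).
    const uint64_t flops = dst_elems * fh * fw * ic * 2;
    std::printf("approx FLOPs: %llu\n", static_cast<unsigned long long>(flops));
    return 0;
}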

