@@ -47,6 +47,52 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 1, T, T2, T3, T4> { | |||
} | |||
}; | |||
////////////////////stride 1/////////////////// | |||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||
int ow_block> | |||
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 1, oc_block, ow_block, | |||
1> { | |||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||
int iw, int ld_dst_oc, const Op& op) { | |||
constexpr int stride = 1; | |||
constexpr int filter_hight = 1; | |||
constexpr int filter_width = 4; | |||
constexpr int weight_reg = 2; | |||
constexpr int src_reg = 2; | |||
constexpr int oc_step = 4; | |||
constexpr int ic_step = 1; | |||
constexpr int pack_iw_len = 4; | |||
constexpr int simd_len = 16; | |||
const int ld_bias = oc_step; | |||
const int ic_stride = ih * iw * pack_iw_len; | |||
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; | |||
constexpr int c_dim = OCHelper<oc_block>::val; | |||
int32x4_t c[c_dim][8]; | |||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||
int8x16_t src[src_reg]; | |||
int8x16_t weight[c_dim][weight_reg]; | |||
// row 0 | |||
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||
src, src_ptr + 0 * iw * pack_iw_len, 0); | |||
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( | |||
weight, weight_ptr, ld_weight_oc); | |||
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||
weight); | |||
src_ptr += ic_stride; | |||
weight_ptr += filter_hight * filter_width * oc_step; | |||
} | |||
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||
c, op, dst_ptr, ld_dst_oc); | |||
} | |||
}; | |||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||
int ow_block> | |||
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block, | |||
@@ -441,6 +487,7 @@ void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, | |||
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) | |||
#define DISPATCH_CONV_KERN(stride) \ | |||
GET_BIAS_MODE_PARAM(stride, 1) \ | |||
GET_BIAS_MODE_PARAM(stride, 2) \ | |||
GET_BIAS_MODE_PARAM(stride, 3) \ | |||
GET_BIAS_MODE_PARAM(stride, 5) \ | |||
@@ -60,6 +60,17 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 2, T, T2, T3, T4> { | |||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||
int ow_block> | |||
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 1, oc_block, ow_block, | |||
2> { | |||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
static void impl(const int8_t*, const int8_t*, const int32_t*, int8_t*, int, | |||
int, int, int, const Op&) { | |||
megdnn_assert(0, "not impl"); | |||
} | |||
}; | |||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||
int ow_block> | |||
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block, | |||
2> { | |||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
@@ -429,6 +440,7 @@ void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, | |||
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) | |||
#define DISPATCH_CONV_KERN(stride) \ | |||
GET_BIAS_MODE_PARAM(stride, 1) \ | |||
GET_BIAS_MODE_PARAM(stride, 2) \ | |||
GET_BIAS_MODE_PARAM(stride, 3) \ | |||
GET_BIAS_MODE_PARAM(stride, 5) \ | |||
@@ -113,6 +113,47 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, 1, T, T2, T3, T4> { | |||
}; | |||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block> | |||
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 1, oc_block, 1> { | |||
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||
int iw, int ld_dst_oc, const Op& op) { | |||
constexpr int stride = 1; | |||
constexpr int filter_height = 1; | |||
constexpr int filter_width = 4; | |||
constexpr int oc_step = 4; | |||
constexpr int loop_ic_step = 1; | |||
constexpr int simd_len = 16; | |||
constexpr int pack_iw_len = 16; | |||
constexpr int src_reg = 8; | |||
constexpr int weight_reg = 1; | |||
const int ic_stride = ih * iw * pack_iw_len; | |||
const int ld_weight_oc = oc_step * filter_height * filter_width * ic; | |||
constexpr int c_dim = OCHelper<oc_block>::val; | |||
int32x4_t c[c_dim][8]; | |||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; | |||
int8x16_t src[src_reg]; | |||
int8x16_t dot4_weight[c_dim][weight_reg]; | |||
int16x8_t temp_c[4]; | |||
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( | |||
dot4_weight, weight_ptr, ld_weight_oc); | |||
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||
src, nchw_src_ptr + 0 * iw * pack_iw_len, 0); | |||
cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); | |||
weight_ptr += oc_step * filter_height * filter_width; | |||
} | |||
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||
c, op, dst_ptr, ld_dst_oc); | |||
} | |||
}; | |||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block> | |||
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 2, oc_block, 1> { | |||
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||
@@ -547,6 +588,7 @@ struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, 1> { | |||
INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) | |||
#define INSTANCE_CONV_KERN(stride) \ | |||
INSTANCE_BIAS_MODE_PARAM(stride, 1) \ | |||
INSTANCE_BIAS_MODE_PARAM(stride, 2) \ | |||
INSTANCE_BIAS_MODE_PARAM(stride, 3) \ | |||
INSTANCE_BIAS_MODE_PARAM(stride, 5) \ | |||
@@ -1033,6 +1033,15 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 2, oc_block, stride> { | |||
} | |||
}; | |||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||
int stride> | |||
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 1, oc_block, stride> { | |||
static void impl(const int8_t*, const int8_t*, const int32_t*, int8_t*, int, | |||
int, int, int, const Op&) { | |||
megdnn_assert(0, "not impl nchw_nchw44 1x1 s2"); | |||
} | |||
}; | |||
enum PACK_MODE { NO_PAD = 0, FIRST_PAD = 1, LAST_PAD = 2 }; | |||
template <PACK_MODE mode> | |||
MEGDNN_ALWAYS_INLINE void pack_src_one_line(const int8_t* inptr, int8_t* outptr, | |||
@@ -1398,6 +1407,7 @@ struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, 2> { | |||
INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) | |||
#define INSTANCE_CONV_KERN(stride) \ | |||
INSTANCE_BIAS_MODE_PARAM(stride, 1) \ | |||
INSTANCE_BIAS_MODE_PARAM(stride, 2) \ | |||
INSTANCE_BIAS_MODE_PARAM(stride, 3) \ | |||
INSTANCE_BIAS_MODE_PARAM(stride, 5) \ | |||
@@ -291,6 +291,9 @@ ConvBiasImpl::AlgoS8DirectNCHWNCHW44::dispatch_kerns( | |||
#define DISPATCH_CONV_KERN(stride) \ | |||
switch (param.filter_meta.spatial[0]) { \ | |||
case 1: \ | |||
GET_BIAS_MODE_PARAM(stride, 1) \ | |||
break; \ | |||
case 2: \ | |||
GET_BIAS_MODE_PARAM(stride, 2) \ | |||
break; \ | |||
@@ -245,6 +245,9 @@ ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::dispatch_kerns( | |||
#define DISPATCH_CONV_KERN(stride) \ | |||
switch (param.filter_meta.spatial[0]) { \ | |||
case 1: \ | |||
GET_BIAS_MODE_PARAM(stride, 1) \ | |||
break; \ | |||
case 2: \ | |||
GET_BIAS_MODE_PARAM(stride, 2) \ | |||
break; \ | |||
@@ -74,9 +74,11 @@ inline bool nchw_nchwxx_valid<NCHW44_INT8>( | |||
nonline_mode == param::ConvBias::NonlineMode::H_SWISH; | |||
bool ok_src_dst = | |||
fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1; | |||
bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] && | |||
(fm.spatial[0] == 2 || fm.spatial[0] == 3 || | |||
fm.spatial[0] == 5 || fm.spatial[0] == 7); | |||
bool ok_filter = | |||
fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] && | |||
(fm.spatial[0] == 2 || fm.spatial[0] == 3 || fm.spatial[0] == 5 || | |||
fm.spatial[0] == 7 || | |||
(fm.spatial[0] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1)); | |||
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
fm.stride[0] == fm.stride[1] && | |||
(fm.stride[0] == 1 || fm.stride[1] == 2); | |||
@@ -126,9 +128,11 @@ inline bool nchw_nchwxx_valid<NCHW44_INT8_DOT>( | |||
nonline_mode == param::ConvBias::NonlineMode::H_SWISH; | |||
bool ok_src_dst = | |||
fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1; | |||
bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] && | |||
(fm.spatial[0] == 2 || fm.spatial[0] == 3 || | |||
fm.spatial[0] == 5 || fm.spatial[0] == 7); | |||
bool ok_filter = | |||
fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] && | |||
(fm.spatial[0] == 2 || fm.spatial[0] == 3 || fm.spatial[0] == 5 || | |||
fm.spatial[0] == 7 || | |||
(fm.spatial[0] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1)); | |||
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
fm.stride[0] == fm.stride[1] && | |||
(fm.stride[0] == 1 || fm.stride[1] == 2); | |||
@@ -487,6 +487,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S2) { | |||
handle(), "S8_CONV_NCHW_NCHW44"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S1_F1) { | |||
checker_conv_bias_qint8x8x8( | |||
get_nchw44_conv_bias_args({1}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1, | |||
false, true), | |||
handle(), "S8_CONV_NCHW_NCHW44"); | |||
} | |||
/*****************************quint8 direct****************************/ | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1) { | |||
checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( | |||
@@ -517,6 +524,15 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) { | |||
checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44_S1_F1) { | |||
auto args = get_nchw44_conv_bias_args({1}, QUAN_NLMODE, BR_AND_NO_BIASMODE, | |||
1, false, true); | |||
for (auto&& arg : args) { | |||
arg.param.format = param::ConvBias::Format::NCHW44_DOT; | |||
} | |||
checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) { | |||
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( | |||
{2, 3, 5, 7}, 1, false, false, false), | |||
@@ -635,6 +635,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44) { | |||
benchmark_impl(param, shape_arg, ".+", RUNS, {4, {4, 5, 6, 7}}, | |||
{1, {7}}, data_type); | |||
}; | |||
bench_case(1, 2, 64, 160, 160, 1, 1, 0, 1, true); | |||
bench_case(1, 3, 64, 224, 224, 7, 1, 3, 2, true); | |||
bench_case(1, 64, 64, 56, 56, 3, 1, 1, 1); | |||
bench_case(1, 128, 128, 28, 28, 3, 1, 1, 1); | |||
@@ -131,7 +131,8 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, | |||
if (param.format == Param::Format::NCHW44 || | |||
param.format == Param::Format::NCHW44_DOT) { | |||
//! if channel wise weight layout is {group/4, FH, FW, 1, 1, 4} | |||
if (filter_shape[1] == 1 && filter_shape[2] == 1) { | |||
if (filter_shape[1] == 1 && filter_shape[2] == 1 && | |||
filter_shape.ndim == 6) { | |||
group *= 4; | |||
} | |||
size_t computation = dst_shape.total_nr_elems() * fh * fw * | |||