@@ -47,6 +47,52 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 1, T, T2, T3, T4> { | |||||
} | } | ||||
}; | }; | ||||
////////////////////stride 1/////////////////// | ////////////////////stride 1/////////////////// | ||||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
int ow_block> | |||||
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 1, oc_block, ow_block, | |||||
1> { | |||||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||||
int iw, int ld_dst_oc, const Op& op) { | |||||
constexpr int stride = 1; | |||||
constexpr int filter_hight = 1; | |||||
constexpr int filter_width = 4; | |||||
constexpr int weight_reg = 2; | |||||
constexpr int src_reg = 2; | |||||
constexpr int oc_step = 4; | |||||
constexpr int ic_step = 1; | |||||
constexpr int pack_iw_len = 4; | |||||
constexpr int simd_len = 16; | |||||
const int ld_bias = oc_step; | |||||
const int ic_stride = ih * iw * pack_iw_len; | |||||
const int ld_weight_oc = oc_step * filter_hight * filter_width * ic; | |||||
constexpr int c_dim = OCHelper<oc_block>::val; | |||||
int32x4_t c[c_dim][8]; | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, ld_bias); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { | |||||
int8x16_t src[src_reg]; | |||||
int8x16_t weight[c_dim][weight_reg]; | |||||
// row 0 | |||||
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||||
src, src_ptr + 0 * iw * pack_iw_len, 0); | |||||
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( | |||||
weight, weight_ptr, ld_weight_oc); | |||||
cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, | |||||
weight); | |||||
src_ptr += ic_stride; | |||||
weight_ptr += filter_hight * filter_width * oc_step; | |||||
} | |||||
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||||
c, op, dst_ptr, ld_dst_oc); | |||||
} | |||||
}; | |||||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | ||||
int ow_block> | int ow_block> | ||||
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block, | struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block, | ||||
@@ -441,6 +487,7 @@ void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, | |||||
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) | GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) | ||||
#define DISPATCH_CONV_KERN(stride) \ | #define DISPATCH_CONV_KERN(stride) \ | ||||
GET_BIAS_MODE_PARAM(stride, 1) \ | |||||
GET_BIAS_MODE_PARAM(stride, 2) \ | GET_BIAS_MODE_PARAM(stride, 2) \ | ||||
GET_BIAS_MODE_PARAM(stride, 3) \ | GET_BIAS_MODE_PARAM(stride, 3) \ | ||||
GET_BIAS_MODE_PARAM(stride, 5) \ | GET_BIAS_MODE_PARAM(stride, 5) \ | ||||
@@ -60,6 +60,17 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 2, T, T2, T3, T4> { | |||||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | ||||
int ow_block> | int ow_block> | ||||
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 1, oc_block, ow_block, | |||||
2> { | |||||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||||
static void impl(const int8_t*, const int8_t*, const int32_t*, int8_t*, int, | |||||
int, int, int, const Op&) { | |||||
megdnn_assert(0, "not impl"); | |||||
} | |||||
}; | |||||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
int ow_block> | |||||
struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block, | struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block, | ||||
2> { | 2> { | ||||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | MEGDNN_ATTRIBUTE_TARGET("dotprod") | ||||
@@ -429,6 +440,7 @@ void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, | |||||
GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) | GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) | ||||
#define DISPATCH_CONV_KERN(stride) \ | #define DISPATCH_CONV_KERN(stride) \ | ||||
GET_BIAS_MODE_PARAM(stride, 1) \ | |||||
GET_BIAS_MODE_PARAM(stride, 2) \ | GET_BIAS_MODE_PARAM(stride, 2) \ | ||||
GET_BIAS_MODE_PARAM(stride, 3) \ | GET_BIAS_MODE_PARAM(stride, 3) \ | ||||
GET_BIAS_MODE_PARAM(stride, 5) \ | GET_BIAS_MODE_PARAM(stride, 5) \ | ||||
@@ -113,6 +113,47 @@ struct ShiftCalHelper<src_idx, weight_idx, 1, 1, T, T2, T3, T4> { | |||||
}; | }; | ||||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block> | template <BiasMode bias_mode, typename Op, int remain_w, int oc_block> | ||||
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 1, oc_block, 1> { | |||||
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | |||||
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | |||||
int iw, int ld_dst_oc, const Op& op) { | |||||
constexpr int stride = 1; | |||||
constexpr int filter_height = 1; | |||||
constexpr int filter_width = 4; | |||||
constexpr int oc_step = 4; | |||||
constexpr int loop_ic_step = 1; | |||||
constexpr int simd_len = 16; | |||||
constexpr int pack_iw_len = 16; | |||||
constexpr int src_reg = 8; | |||||
constexpr int weight_reg = 1; | |||||
const int ic_stride = ih * iw * pack_iw_len; | |||||
const int ld_weight_oc = oc_step * filter_height * filter_width * ic; | |||||
constexpr int c_dim = OCHelper<oc_block>::val; | |||||
int32x4_t c[c_dim][8]; | |||||
init_ocx_ow8<c_dim, bias_mode, remain_w>(c, bias_ptr, oc_step); | |||||
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { | |||||
const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; | |||||
int8x16_t src[src_reg]; | |||||
int8x16_t dot4_weight[c_dim][weight_reg]; | |||||
int16x8_t temp_c[4]; | |||||
load_helper<weight_reg, 0, simd_len, c_dim, Vld1q_s8>( | |||||
dot4_weight, weight_ptr, ld_weight_oc); | |||||
load_helper<src_reg, 0, simd_len, 0, Vld1q_s8>( | |||||
src, nchw_src_ptr + 0 * iw * pack_iw_len, 0); | |||||
cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); | |||||
weight_ptr += oc_step * filter_height * filter_width; | |||||
} | |||||
store_ocx_ow8_remain_static_dt<c_dim, remain_w, Op, dt_qint8*>( | |||||
c, op, dst_ptr, ld_dst_oc); | |||||
} | |||||
}; | |||||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block> | |||||
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 2, oc_block, 1> { | struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 2, oc_block, 1> { | ||||
static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, | ||||
const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, | ||||
@@ -547,6 +588,7 @@ struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, 1> { | |||||
INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) | INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) | ||||
#define INSTANCE_CONV_KERN(stride) \ | #define INSTANCE_CONV_KERN(stride) \ | ||||
INSTANCE_BIAS_MODE_PARAM(stride, 1) \ | |||||
INSTANCE_BIAS_MODE_PARAM(stride, 2) \ | INSTANCE_BIAS_MODE_PARAM(stride, 2) \ | ||||
INSTANCE_BIAS_MODE_PARAM(stride, 3) \ | INSTANCE_BIAS_MODE_PARAM(stride, 3) \ | ||||
INSTANCE_BIAS_MODE_PARAM(stride, 5) \ | INSTANCE_BIAS_MODE_PARAM(stride, 5) \ | ||||
@@ -1033,6 +1033,15 @@ struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 2, oc_block, stride> { | |||||
} | } | ||||
}; | }; | ||||
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block, | |||||
int stride> | |||||
struct KerNeonXXs2NchwNchw44<bias_mode, Op, remain_w, 1, oc_block, stride> { | |||||
static void impl(const int8_t*, const int8_t*, const int32_t*, int8_t*, int, | |||||
int, int, int, const Op&) { | |||||
megdnn_assert(0, "not impl nchw_nchw44 1x1 s2"); | |||||
} | |||||
}; | |||||
enum PACK_MODE { NO_PAD = 0, FIRST_PAD = 1, LAST_PAD = 2 }; | enum PACK_MODE { NO_PAD = 0, FIRST_PAD = 1, LAST_PAD = 2 }; | ||||
template <PACK_MODE mode> | template <PACK_MODE mode> | ||||
MEGDNN_ALWAYS_INLINE void pack_src_one_line(const int8_t* inptr, int8_t* outptr, | MEGDNN_ALWAYS_INLINE void pack_src_one_line(const int8_t* inptr, int8_t* outptr, | ||||
@@ -1398,6 +1407,7 @@ struct ConvDiectStrideInt8NchwNchw44<bias_mode, Op, filter_size, 2> { | |||||
INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) | INSTANCE_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) | ||||
#define INSTANCE_CONV_KERN(stride) \ | #define INSTANCE_CONV_KERN(stride) \ | ||||
INSTANCE_BIAS_MODE_PARAM(stride, 1) \ | |||||
INSTANCE_BIAS_MODE_PARAM(stride, 2) \ | INSTANCE_BIAS_MODE_PARAM(stride, 2) \ | ||||
INSTANCE_BIAS_MODE_PARAM(stride, 3) \ | INSTANCE_BIAS_MODE_PARAM(stride, 3) \ | ||||
INSTANCE_BIAS_MODE_PARAM(stride, 5) \ | INSTANCE_BIAS_MODE_PARAM(stride, 5) \ | ||||
@@ -291,6 +291,9 @@ ConvBiasImpl::AlgoS8DirectNCHWNCHW44::dispatch_kerns( | |||||
#define DISPATCH_CONV_KERN(stride) \ | #define DISPATCH_CONV_KERN(stride) \ | ||||
switch (param.filter_meta.spatial[0]) { \ | switch (param.filter_meta.spatial[0]) { \ | ||||
case 1: \ | |||||
GET_BIAS_MODE_PARAM(stride, 1) \ | |||||
break; \ | |||||
case 2: \ | case 2: \ | ||||
GET_BIAS_MODE_PARAM(stride, 2) \ | GET_BIAS_MODE_PARAM(stride, 2) \ | ||||
break; \ | break; \ | ||||
@@ -245,6 +245,9 @@ ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::dispatch_kerns( | |||||
#define DISPATCH_CONV_KERN(stride) \ | #define DISPATCH_CONV_KERN(stride) \ | ||||
switch (param.filter_meta.spatial[0]) { \ | switch (param.filter_meta.spatial[0]) { \ | ||||
case 1: \ | |||||
GET_BIAS_MODE_PARAM(stride, 1) \ | |||||
break; \ | |||||
case 2: \ | case 2: \ | ||||
GET_BIAS_MODE_PARAM(stride, 2) \ | GET_BIAS_MODE_PARAM(stride, 2) \ | ||||
break; \ | break; \ | ||||
@@ -74,9 +74,11 @@ inline bool nchw_nchwxx_valid<NCHW44_INT8>( | |||||
nonline_mode == param::ConvBias::NonlineMode::H_SWISH; | nonline_mode == param::ConvBias::NonlineMode::H_SWISH; | ||||
bool ok_src_dst = | bool ok_src_dst = | ||||
fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1; | fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1; | ||||
bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] && | |||||
(fm.spatial[0] == 2 || fm.spatial[0] == 3 || | |||||
fm.spatial[0] == 5 || fm.spatial[0] == 7); | |||||
bool ok_filter = | |||||
fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] && | |||||
(fm.spatial[0] == 2 || fm.spatial[0] == 3 || fm.spatial[0] == 5 || | |||||
fm.spatial[0] == 7 || | |||||
(fm.spatial[0] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1)); | |||||
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && | bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && | ||||
fm.stride[0] == fm.stride[1] && | fm.stride[0] == fm.stride[1] && | ||||
(fm.stride[0] == 1 || fm.stride[1] == 2); | (fm.stride[0] == 1 || fm.stride[1] == 2); | ||||
@@ -126,9 +128,11 @@ inline bool nchw_nchwxx_valid<NCHW44_INT8_DOT>( | |||||
nonline_mode == param::ConvBias::NonlineMode::H_SWISH; | nonline_mode == param::ConvBias::NonlineMode::H_SWISH; | ||||
bool ok_src_dst = | bool ok_src_dst = | ||||
fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1; | fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1; | ||||
bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] && | |||||
(fm.spatial[0] == 2 || fm.spatial[0] == 3 || | |||||
fm.spatial[0] == 5 || fm.spatial[0] == 7); | |||||
bool ok_filter = | |||||
fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] && | |||||
(fm.spatial[0] == 2 || fm.spatial[0] == 3 || fm.spatial[0] == 5 || | |||||
fm.spatial[0] == 7 || | |||||
(fm.spatial[0] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1)); | |||||
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && | bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && | ||||
fm.stride[0] == fm.stride[1] && | fm.stride[0] == fm.stride[1] && | ||||
(fm.stride[0] == 1 || fm.stride[1] == 2); | (fm.stride[0] == 1 || fm.stride[1] == 2); | ||||
@@ -487,6 +487,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S2) { | |||||
handle(), "S8_CONV_NCHW_NCHW44"); | handle(), "S8_CONV_NCHW_NCHW44"); | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S1_F1) { | |||||
checker_conv_bias_qint8x8x8( | |||||
get_nchw44_conv_bias_args({1}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1, | |||||
false, true), | |||||
handle(), "S8_CONV_NCHW_NCHW44"); | |||||
} | |||||
/*****************************quint8 direct****************************/ | /*****************************quint8 direct****************************/ | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1) { | ||||
checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( | checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args( | ||||
@@ -517,6 +524,15 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) { | |||||
checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44"); | checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44"); | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44_S1_F1) { | |||||
auto args = get_nchw44_conv_bias_args({1}, QUAN_NLMODE, BR_AND_NO_BIASMODE, | |||||
1, false, true); | |||||
for (auto&& arg : args) { | |||||
arg.param.format = param::ConvBias::Format::NCHW44_DOT; | |||||
} | |||||
checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44"); | |||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) { | ||||
checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( | checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args( | ||||
{2, 3, 5, 7}, 1, false, false, false), | {2, 3, 5, 7}, 1, false, false, false), | ||||
@@ -635,6 +635,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44) { | |||||
benchmark_impl(param, shape_arg, ".+", RUNS, {4, {4, 5, 6, 7}}, | benchmark_impl(param, shape_arg, ".+", RUNS, {4, {4, 5, 6, 7}}, | ||||
{1, {7}}, data_type); | {1, {7}}, data_type); | ||||
}; | }; | ||||
bench_case(1, 2, 64, 160, 160, 1, 1, 0, 1, true); | |||||
bench_case(1, 3, 64, 224, 224, 7, 1, 3, 2, true); | bench_case(1, 3, 64, 224, 224, 7, 1, 3, 2, true); | ||||
bench_case(1, 64, 64, 56, 56, 3, 1, 1, 1); | bench_case(1, 64, 64, 56, 56, 3, 1, 1, 1); | ||||
bench_case(1, 128, 128, 28, 28, 3, 1, 1, 1); | bench_case(1, 128, 128, 28, 28, 3, 1, 1, 1); | ||||
@@ -131,7 +131,8 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, | |||||
if (param.format == Param::Format::NCHW44 || | if (param.format == Param::Format::NCHW44 || | ||||
param.format == Param::Format::NCHW44_DOT) { | param.format == Param::Format::NCHW44_DOT) { | ||||
//! if channel wise weight layout is {group/4, FH, FW, 1, 1, 4} | //! if channel wise weight layout is {group/4, FH, FW, 1, 1, 4} | ||||
if (filter_shape[1] == 1 && filter_shape[2] == 1) { | |||||
if (filter_shape[1] == 1 && filter_shape[2] == 1 && | |||||
filter_shape.ndim == 6) { | |||||
group *= 4; | group *= 4; | ||||
} | } | ||||
size_t computation = dst_shape.total_nr_elems() * fh * fw * | size_t computation = dst_shape.total_nr_elems() * fh * fw * | ||||