diff --git a/dnn/src/x86/conv_bias/postprocess_helper.h b/dnn/src/x86/conv_bias/postprocess_helper.h index 1b6a82e5..72d83899 100644 --- a/dnn/src/x86/conv_bias/postprocess_helper.h +++ b/dnn/src/x86/conv_bias/postprocess_helper.h @@ -32,7 +32,7 @@ namespace x86 { thin_function run = \ OpCallerUnary<_op<_simd_type, ctype, ctype>, _simd_type>::run; \ run(static_cast(conv_dst_ptr), reinterpret_cast(dst_ptr), \ - bias_type, dst_type, N* OC* OH* OW); + bias_type, dst_type, N* OC* OH* OW* pack_oc_size); #define CALL_BINARY_BROADCAST(_op, _simd_type) \ thin_function(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ OH* OW); +#define CALL_BINARY_BROADCAST_NCHWXX(_op, _simd_type) \ + thin_function \ + run = OpCallerBinary< \ + _op<_simd_type, ctype, ctype>, _simd_type, \ + megdnn::x86::BcastType::VEC_BCAST101xX>::run; \ + run(static_cast(conv_dst_ptr), static_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, N, OC, \ + OH* OW, pack_oc_size); + #define CALL_BINARY(_op, _simd_type) \ thin_function \ @@ -53,7 +64,7 @@ namespace x86 { megdnn::x86::BcastType::VEC_VEC>::run; \ run(static_cast(conv_dst_ptr), static_cast(bias_ptr), \ reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, \ - N* OC* OH* OW); + N* OC* OH* OW* pack_oc_size); #define cb_unary(_simd_type) \ if (elem_mode == megdnn::param::Elemwise::Mode::RELU) { \ @@ -93,19 +104,24 @@ namespace x86 { cb_binary(CALLER, SIMDType::NONE) \ } -#define FOR_BIAS(bias_mode) \ - switch (bias_mode) { \ - case BiasMode::NO_BIAS: \ - FOR_NONLINEAR_NOBIAS(); \ - break; \ - case BiasMode::BROADCAST_CHANNEL_BIAS: \ - FOR_NONLINEAR(CALL_BINARY_BROADCAST); \ - break; \ - case BiasMode::BIAS: \ - FOR_NONLINEAR(CALL_BINARY); \ - break; \ - default: \ - break; \ +#define FOR_BIAS(bias_mode) \ + switch (bias_mode) { \ + case BiasMode::NO_BIAS: \ + FOR_NONLINEAR_NOBIAS(); \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + if (pack_oc_size == 1) { \ + FOR_NONLINEAR(CALL_BINARY_BROADCAST); \ + } else { \ + megdnn_assert(pack_oc_size == 4, "Only support nchw44 in x86"); \ + FOR_NONLINEAR(CALL_BINARY_BROADCAST_NCHWXX); \ + } \ + break; \ + case BiasMode::BIAS: \ + FOR_NONLINEAR(CALL_BINARY); \ + break; \ + default: \ + break; \ } template < @@ -119,7 +135,9 @@ struct PostProcess { DType dst_type, size_t N, size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { MEGDNN_MARK_USED_VAR(pack_oc_size); - megdnn_assert(pack_oc_size == 1, "PostProcess only support nchw in x86"); + megdnn_assert( + pack_oc_size == 1 || pack_oc_size == 4, + "PostProcess only support nchw/44 in x86"); megdnn::param::Elemwise::Mode elem_mode = megdnn::param::Elemwise::Mode::ADD; if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) { switch (nonlineMode) { @@ -320,16 +338,21 @@ struct PostProcess { CALLER(AddOp, SIMDType::NONE) \ } -#define FOR_BIAS(bias_mode) \ - switch (bias_mode) { \ - case BiasMode::BIAS: \ - FOR_SIMD(CALL_BINARY); \ - break; \ - case BiasMode::BROADCAST_CHANNEL_BIAS: \ - FOR_SIMD(CALL_BINARY_BROADCAST); \ - break; \ - default: \ - break; \ +#define FOR_BIAS(bias_mode) \ + switch (bias_mode) { \ + case BiasMode::BIAS: \ + FOR_SIMD(CALL_BINARY); \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + if (pack_oc_size == 1) { \ + FOR_SIMD(CALL_BINARY_BROADCAST); \ + } else { \ + megdnn_assert(pack_oc_size == 4, "Only support nchw44 in x86"); \ + FOR_SIMD(CALL_BINARY_BROADCAST_NCHWXX); \ + } \ + break; \ + default: \ + break; \ } template diff --git a/dnn/src/x86/elemwise_op.h b/dnn/src/x86/elemwise_op.h index 6bad74c0..18df3ae0 100644 --- a/dnn/src/x86/elemwise_op.h +++ b/dnn/src/x86/elemwise_op.h @@ -54,6 +54,33 @@ cb(dt_uint8, __m256i, "avx2", uint8_t, __m256i, mm256, si256, epi8, SIMDType::AV cb(dt_float32, float, "avx2", float, __m256, mm256, ps, ps, SIMDType::AVX2); #undef cb + +//! visitor for handle BCAST101xX(4) at AVX2, load 128, broadcast to 256 +template +struct ParamElemVisitorHalfBoardCast; + +#define cb( \ + _ctype, _simd_ptr_type, load_half_fuc, half_type, _simd_type, board_cast_func) \ + template <> \ + struct ParamElemVisitorHalfBoardCast<_ctype, SIMDType::AVX2> { \ + MEGDNN_ATTRIBUTE_TARGET("avx2") \ + _simd_type operator()(const _ctype* src) const { \ + half_type tmp = \ + load_half_fuc(reinterpret_cast<_simd_ptr_type const*>(src)); \ + return board_cast_func(tmp, tmp); \ + } \ + } + +cb(dt_qint32, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i); +cb(dt_qint8, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i); +cb(dt_quint8, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i); +cb(dt_int32, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i); +cb(dt_int16, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i); +cb(dt_int8, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i); +cb(dt_uint8, __m128i, _mm_loadu_si128, __m128i, __m256i, _mm256_set_m128i); +cb(dt_float32, float, _mm_load_ps, __m128, __m256, _mm256_set_m128); + +#undef cb /*! * \brief broadcast type * BCAST_x[0]x[1]...: x[i] == !stride[i] @@ -71,7 +98,8 @@ enum BcastType { BCAST101_VEC_BCAST101, VEC_BCAST101_VEC, VEC_SCALAR_VEC, - VEC_SCALAR_SCALAR + VEC_SCALAR_SCALAR, + VEC_BCAST101xX }; ///////////////////////////////// OpCaller ///////////////////////////// @@ -227,6 +255,106 @@ struct OpCallerBinary { }; #undef OP_CALLER +template +struct OpCallerBinary { + MEGDNN_ATTRIBUTE_TARGET("sse4.2") + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride, + size_t channel_block_dim) { + megdnn_assert(channel_block_dim == 4, "only imp for nchw44"); + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitor vis1; + for (size_t b = 0; b < batch; b++) { + const typename Op::src_ctype* src1_ptr = src1; + for (size_t c = 0; c < channel; c++) { + auto src1_block_ptr = src1_ptr + c * channel_block_dim; + auto channel_block_vec = vis1(src1_block_ptr); + size_t img_index = 0; + auto src0_offset = Op::SIMD_WIDTH / channel_block_dim; + for (; img_index + 2 * src0_offset <= channel_stride; + img_index += 2 * src0_offset) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{channel_block_vec, channel_block_vec}}, dst); + src0 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } + for (; img_index < channel_stride; img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { + op(*src0, *(src1_block_ptr + c_iter), dst); + src0++; + dst++; + } + } + } + } + } +}; + +template +struct OpCallerBinary { + MEGDNN_ATTRIBUTE_TARGET("avx2") + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride, + size_t channel_block_dim) { + megdnn_assert(channel_block_dim == 4, "only imp for nchw44"); + Op op(src0_dtype, src1_dtype, dst_dtype); + ParamElemVisitor vis0; + ParamElemVisitorHalfBoardCast vis1; + for (size_t b = 0; b < batch; b++) { + const typename Op::src_ctype* src1_ptr = src1; + for (size_t c = 0; c < channel; c++) { + auto src1_block_ptr = src1_ptr + c * channel_block_dim; + auto channel_block_vec = vis1(src1_block_ptr); + size_t img_index = 0; + auto src0_offset = Op::SIMD_WIDTH / channel_block_dim; + for (; img_index + 2 * src0_offset <= channel_stride; + img_index += 2 * src0_offset) { + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, + {{channel_block_vec, channel_block_vec}}, dst); + src0 += Op::SIMD_WIDTH * 2; + dst += Op::SIMD_WIDTH * 2; + } + for (; img_index < channel_stride; img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { + op(*src0, *(src1_block_ptr + c_iter), dst); + src0++; + dst++; + } + } + } + } + } +}; + +template +struct OpCallerBinary { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride, + size_t channel_block_dim) { + Op op(src0_dtype, src1_dtype, dst_dtype); + for (size_t b = 0; b < batch; b++) { + auto src1_ptr = src1; + for (size_t cb = 0; cb < channel; cb++) { + auto src1_block_ptr = src1_ptr + cb * channel_block_dim; + for (size_t img_index = 0; img_index < channel_stride; img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { + op(*src0, *(src1_block_ptr + c_iter), dst); + src0++; + dst++; + } + } + } + } + } +}; + #define OP_CALLER(simd_type, target_simd) \ template \ struct OpCallerBinary { \