diff --git a/dnn/src/arm_common/conv_bias/f16/algos.h b/dnn/src/arm_common/conv_bias/f16/algos.h index 5210ada2..787c6dfe 100644 --- a/dnn/src/arm_common/conv_bias/f16/algos.h +++ b/dnn/src/arm_common/conv_bias/f16/algos.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once @@ -107,7 +108,7 @@ public: virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; - ConvAlgoTypePack get_algo_type() const override{ + ConvAlgoTypePack get_algo_type() const override { return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; } MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_FP16) @@ -132,6 +133,26 @@ public: MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_FP16) }; +class ConvBiasImpl::AlgoF16ChannelWiseNCHW88 final : public AlgoBase { + SmallVector get_kimpls(const NCBKernSizeParam& param) const; + +public: + AlgoAttribute attribute() const override { + return AlgoAttribute::REPRODUCIBLE; + } + const char* name() const override { return "F16_CHANNEL_WISE_NCHW88"; } + bool usable(const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + + size_t get_workspace(const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + const NCBKernSizeParam& param) const override; + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; + } + MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHWNWISE_NCHW88_F16) +}; + } // namespace arm_common } // namespace megdnn #endif diff --git a/dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.cpp b/dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.cpp new file mode 100644 index 00000000..a84762f6 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.cpp @@ -0,0 +1,320 @@ +/** + * \file dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include "src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h" +#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/arm_common/utils.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +using namespace megdnn; +using namespace arm_common; +using namespace fp16; + +namespace { + +template +static inline void shift_src(float16x8_t rsrc[3][4]) { + float16x8_t t[4]; + + t[0] = rsrc[0][(shift + 0) % 4]; + t[1] = rsrc[0][(shift + 1) % 4]; + t[2] = rsrc[0][(shift + 2) % 4]; + t[3] = rsrc[0][(shift + 3) % 4]; + rsrc[0][0] = t[0]; + rsrc[0][1] = t[1]; + rsrc[0][2] = t[2]; + rsrc[0][3] = t[3]; + + t[0] = rsrc[1][(shift + 0) % 4]; + t[1] = rsrc[1][(shift + 1) % 4]; + t[2] = rsrc[1][(shift + 2) % 4]; + t[3] = rsrc[1][(shift + 3) % 4]; + rsrc[1][0] = t[0]; + rsrc[1][1] = t[1]; + rsrc[1][2] = t[2]; + rsrc[1][3] = t[3]; + + t[0] = rsrc[2][(shift + 0) % 4]; + t[1] = rsrc[2][(shift + 1) % 4]; + t[2] = rsrc[2][(shift + 2) % 4]; + t[3] = rsrc[2][(shift + 3) % 4]; + rsrc[2][0] = t[0]; + rsrc[2][1] = t[1]; + rsrc[2][2] = t[2]; + rsrc[2][3] = t[3]; +} + +template +static inline float16x8_t load_bias(const float16_t* bias, + const float16x8_t& init) { + if (bias_mode == BiasMode::BIAS) { + return vld1q_f16(bias); + } else { + return init; + } +} + +template +struct compute_element { + template + static inline void call(const float16_t*& src0, const float16_t*& src1, + const float16_t*& src2, float16_t*& dst, + const float16_t*& bias, const float16x8_t& init, + float16x8_t rsrc[3][4], float16x8_t rfilter[3][3], + const Op& op) { +#define RSRC(i, j) rsrc[i][((j) + bw) % 4] + float16x8_t rdst = load_bias(bias, init); + if (has_top) { + RSRC(0, 3) = vld1q_f16(src0 + 16); + } + { RSRC(1, 3) = vld1q_f16(src1 + 16); } + if (has_bottom) { + RSRC(2, 3) = vld1q_f16(src2 + 16); + } + + if (has_top) { + rdst = vfmaq_f16(rdst, RSRC(0, 0), rfilter[0][0]); + rdst = vfmaq_f16(rdst, RSRC(0, 1), rfilter[0][1]); + rdst = vfmaq_f16(rdst, RSRC(0, 2), rfilter[0][2]); + } + { + rdst = vfmaq_f16(rdst, RSRC(1, 0), rfilter[1][0]); + rdst = vfmaq_f16(rdst, RSRC(1, 1), rfilter[1][1]); + rdst = vfmaq_f16(rdst, RSRC(1, 2), rfilter[1][2]); + } + if (has_bottom) { + rdst = vfmaq_f16(rdst, RSRC(2, 0), rfilter[2][0]); + rdst = vfmaq_f16(rdst, RSRC(2, 1), rfilter[2][1]); + rdst = vfmaq_f16(rdst, RSRC(2, 2), rfilter[2][2]); + } + + vst1q_f16(dst, op(rdst)); + + if (has_top) { + src0 += 8; + } + { src1 += 8; } + if (has_bottom) { + src2 += 8; + } + dst += 8; + bias += 8; + compute_element::call( + src0, src1, src2, dst, bias, init, rsrc, rfilter, op); +#undef RSRC + } +}; + +template +struct compute_element { + template + static inline void call(const float16_t*& src0, const float16_t*& src1, + const float16_t*& src2, float16_t*& dst, + const float16_t*& bias, const float16x8_t& init, + float16x8_t rsrc[3][4], float16x8_t rfilter[3][3], + const Op& op) {} +}; + +template +struct compute_element_right { + template + static inline void call(float16_t*& dst, const float16_t*& bias, + const float16x8_t& init, float16x8_t rsrc[3][4], + float16x8_t rfilter[3][3], const Op& op) { + float16x8_t rdst = load_bias(bias, init); + + if (has_top) { + rdst = vfmaq_f16(rdst, rsrc[0][0], rfilter[0][0]); + rdst = vfmaq_f16(rdst, rsrc[0][1], rfilter[0][1]); + rdst = vfmaq_f16(rdst, rsrc[0][2], rfilter[0][2]); + } + { + rdst = vfmaq_f16(rdst, rsrc[1][0], rfilter[1][0]); + rdst = vfmaq_f16(rdst, rsrc[1][1], rfilter[1][1]); + 
rdst = vfmaq_f16(rdst, rsrc[1][2], rfilter[1][2]); + } + if (has_bottom) { + rdst = vfmaq_f16(rdst, rsrc[2][0], rfilter[2][0]); + rdst = vfmaq_f16(rdst, rsrc[2][1], rfilter[2][1]); + rdst = vfmaq_f16(rdst, rsrc[2][2], rfilter[2][2]); + } + + vst1q_f16(dst, op(rdst)); + + dst += 8; + bias += 8; + } +}; + +template +struct compute_element_right_pad { + template + static inline void call(float16_t*& dst, const float16_t*& bias, + const float16x8_t& init, float16x8_t rsrc[3][4], + float16x8_t rfilter[3][3], const Op& op) { + float16x8_t rdst = load_bias(bias, init); + + if (has_top) { + rdst = vfmaq_f16(rdst, rsrc[0][1], rfilter[0][0]); + rdst = vfmaq_f16(rdst, rsrc[0][2], rfilter[0][1]); + } + { + rdst = vfmaq_f16(rdst, rsrc[1][1], rfilter[1][0]); + rdst = vfmaq_f16(rdst, rsrc[1][2], rfilter[1][1]); + } + if (has_bottom) { + rdst = vfmaq_f16(rdst, rsrc[2][1], rfilter[2][0]); + rdst = vfmaq_f16(rdst, rsrc[2][2], rfilter[2][1]); + } + + vst1q_f16(dst, op(rdst)); + dst += 8; + bias += 8; + } +}; + +template +struct compute_row { + template + static inline void call(const float16_t*& src0, const float16_t*& src1, + const float16_t*& src2, float16_t*& dst, + const float16_t*& bias, const float16x8_t& init, + float16x8_t rsrc[3][4], float16x8_t rfilter[3][3], + int W, const Op& op) { + if (has_top) { + rsrc[0][0] = vdupq_n_f16(0); + rsrc[0][1] = vld1q_f16(src0 + 0); + rsrc[0][2] = vld1q_f16(src0 + 8); + } + { + rsrc[1][0] = vdupq_n_f16(0); + rsrc[1][1] = vld1q_f16(src1 + 0); + rsrc[1][2] = vld1q_f16(src1 + 8); + } + if (has_bottom) { + rsrc[2][0] = vdupq_n_f16(0); + rsrc[2][1] = vld1q_f16(src2 + 0); + rsrc[2][2] = vld1q_f16(src2 + 8); + } + + int w = 0; + const float16_t* src0_ptr = src0; + const float16_t* src1_ptr = src1; + const float16_t* src2_ptr = src2; + float16_t* dst_ptr = dst; + const float16_t* bias_ptr = bias; + + for (; w + 3 < W - 2; w += 4) { + compute_element<4, 0, has_top, has_bottom, bias_mode>::call( + src0_ptr, src1_ptr, src2_ptr, dst_ptr, bias_ptr, init, rsrc, + rfilter, op); + } + if (w + 1 < W - 2) { + compute_element<2, 0, has_top, has_bottom, bias_mode>::call( + src0_ptr, src1_ptr, src2_ptr, dst_ptr, bias_ptr, init, rsrc, + rfilter, op); + shift_src<2>(rsrc); + w += 2; + } + if (w < W - 2) { + compute_element<1, 0, has_top, has_bottom, bias_mode>::call( + src0_ptr, src1_ptr, src2_ptr, dst_ptr, bias_ptr, init, rsrc, + rfilter, op); + shift_src<1>(rsrc); + w += 1; + } + compute_element_right::call( + dst_ptr, bias_ptr, init, rsrc, rfilter, op); + compute_element_right_pad::call( + dst_ptr, bias_ptr, init, rsrc, rfilter, op); + + src0 += W * 8; + src1 += W * 8; + src2 += W * 8; + dst += W * 8; + bias += W * 8; + } +}; + +} // namespace + +template +void channel_wise_nchw88::do_conv_kern_3x3_stride1_padding1( + const float16_t* src, float16_t* dst, const float16_t* filter, + const float16_t* bias, int H, int W) { + Op op; + + float16x8_t init; + if (bias_mode == BiasMode::NO_BIAS) { + init = vdupq_n_f16(__fp16(0.f)); + } else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + init = vld1q_f16(bias); + } + + const float16_t* src0 = src - W * 8; + const float16_t* src1 = src; + const float16_t* src2 = src + W * 8; + + float16x8_t rfilter[3][3]; + rfilter[0][0] = vld1q_f16(filter + 0); + rfilter[0][1] = vld1q_f16(filter + 8); + rfilter[0][2] = vld1q_f16(filter + 16); + rfilter[1][0] = vld1q_f16(filter + 24); + rfilter[1][1] = vld1q_f16(filter + 32); + rfilter[1][2] = vld1q_f16(filter + 40); + rfilter[2][0] = vld1q_f16(filter + 48); + rfilter[2][1] = vld1q_f16(filter + 56); + 
rfilter[2][2] = vld1q_f16(filter + 64); + + float16x8_t rsrc[3][4]; + + compute_row::call(src0, src1, src2, dst, bias, init, + rsrc, rfilter, W, op); + + for (int h = 1; h < H - 1; h += 1) { + compute_row::call(src0, src1, src2, dst, bias, + init, rsrc, rfilter, W, op); + } + + compute_row::call(src0, src1, src2, dst, bias, init, + rsrc, rfilter, W, op); +} + +#define INSTANTIATION(bias, Op) \ + template void \ + channel_wise_nchw88::do_conv_kern_3x3_stride1_padding1( \ + const float16_t*, float16_t*, const float16_t*, const float16_t*, \ + int, int); + +#define FOR_OP(bias) \ + INSTANTIATION(bias, SigmoidOp<__fp16>) \ + INSTANTIATION(bias, ReluOp<__fp16>) \ + INSTANTIATION(bias, HSwishOp<__fp16>) \ + INSTANTIATION(bias, NoneOp<__fp16>) + +#define FOR_BIAS \ + FOR_OP(BiasMode::NO_BIAS) \ + FOR_OP(BiasMode::BROADCAST_CHANNEL_BIAS) \ + FOR_OP(BiasMode::BIAS) + +FOR_BIAS + +#undef FOR_BIAS +#undef FOR_OP +#undef INSTANTIATION + +#endif +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h b/dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h new file mode 100644 index 00000000..0a4aa5ee --- /dev/null +++ b/dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h @@ -0,0 +1,37 @@ +/** + * \file dnn/src/arm_common/conv_bias/fp16/channel_wise_3x3_s1p1_nchw88_kern.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once + +#include "src/arm_common/conv_bias/opr_impl.h" +#include "src/fallback/conv_bias/common.h" + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +namespace megdnn { +namespace arm_common { +namespace fp16 { +namespace channel_wise_nchw88 { + +template +void do_conv_kern_3x3_stride1_padding1(const __fp16* src, __fp16* dst, + const __fp16* filter, const __fp16* bias, + int H, int W); + +} // namespace channel_wise_nchw88 +} // namespace fp16 +} // namespace arm_common +} // namespace megdnn + +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_algo.cpp b/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_algo.cpp new file mode 100644 index 00000000..face6042 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_algo.cpp @@ -0,0 +1,183 @@ +/** + * \file dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_algo.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include "src/arm_common/conv_bias/f16/algos.h" +#include "src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h" +#include "src/arm_common/elemwise_op.h" + +#include "midout.h" + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +using namespace megdnn; +using namespace arm_common; +using namespace fp16; + +using conv_fun = std::function; + +MIDOUT_DECL(conv_bias_fp16_channel_wise_nchw88) + +bool ConvBiasImpl::AlgoF16ChannelWiseNCHW88::usable( + const NCBKernSizeParam& param, AlgoSelectionStrategy) const { + auto&& fm = param.filter_meta; + auto FH = fm.spatial[0]; + size_t OC = fm.ocpg; + size_t IC = fm.icpg; + size_t GROUP = fm.group; + bool ok_type = (param.src_type.enumv() == DTypeEnum::Float16 && + param.filter_type.enumv() == DTypeEnum::Float16 && + param.bias_type.enumv() == DTypeEnum::Float16 && + param.dst_type.enumv() == DTypeEnum::Float16); + bool ok_format = OC == 1 && IC == 1 && GROUP % 8 == 0 && + fm.format == param::Convolution::Format::NCHW88; + bool ok_filter = fm.spatial_ndim == 2 && FH == fm.spatial[1] && + (FH == 2 || FH == 3 || FH == 5); + bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == fm.stride[1] && + (fm.stride[0] == 1 || fm.stride[0] == 2); + bool ok_conv = !fm.should_flip; + bool ok_comp = param.compute_mode == Param::ComputeMode::DEFAULT; + return ok_type && ok_format && ok_filter && ok_slide && ok_conv && ok_comp; +} + +size_t ConvBiasImpl::AlgoF16ChannelWiseNCHW88::get_workspace( + const NCBKernSizeParam&) const { + return 0; +} + +SmallVector +ConvBiasImpl::AlgoF16ChannelWiseNCHW88::dispatch_kerns( + const NCBKernSizeParam& param) const { + const constexpr size_t pack_group_size = 8_z; + auto fm = param.filter_meta; + const int batch = param.n; + const int group = fm.group; + const int stride = fm.stride[0]; + + conv_fun do_conv_fun = nullptr; + // NOTE: remain_w is not used to gen hash of midout for compatible with +// shape runtime +#define DO_CONV_KERN_FUN(_stride, filter, bias_mode, op) \ + MIDOUT_BEGIN(conv_bias_fp16_channel_wise_nchw88, \ + midout_iv(#_stride #filter #bias_mode #op##_hash)) { \ + do_conv_fun = channel_wise_nchw88:: \ + do_conv_kern_##_stride##_##filter##x##filter; \ + } \ + MIDOUT_END(); + +#define GET_OP_PARAM(_stride, filter, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN(_stride, filter, bias_mode, NoneOp<__fp16>) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN(_stride, filter, bias_mode, ReluOp<__fp16>) \ + break; \ + case param::ConvBias::NonlineMode::SIGMOID: \ + DO_CONV_KERN_FUN(_stride, filter, bias_mode, SigmoidOp<__fp16>) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN(_stride, filter, bias_mode, HSwishOp<__fp16>) \ + break; \ + default: \ + megdnn_assert(0, "not supported nonline mode"); \ + break; \ + } + +#define GET_BIAS_MODE_PARAM(_stride, filter) \ + switch (param.bias_mode) { \ + case BiasMode::NO_BIAS: \ + GET_OP_PARAM(_stride, filter, BiasMode::NO_BIAS) \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + GET_OP_PARAM(_stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ + break; \ + case BiasMode::BIAS: \ + GET_OP_PARAM(_stride, filter, BiasMode::BIAS) \ + break; \ + default: \ + megdnn_assert(0, "not supported bias mode"); \ + break; \ + } + +#define DISPATCH_CONV_KERN(_stride) \ + switch (param.filter_meta.spatial[0]) { \ + case 2: \ + GET_BIAS_MODE_PARAM(_stride, 2) \ + break; \ + case 3: \ + GET_BIAS_MODE_PARAM(_stride, 3) \ + break; \ + case 5: \ 
+ GET_BIAS_MODE_PARAM(_stride, 5) \ + break; \ + default: \ + megdnn_assert(0, "not supported stride"); \ + break; \ + } + +#define DISPATCH_STRIDE() \ + if (1 == stride) { \ + DISPATCH_CONV_KERN(stride1); \ + } else { \ + DISPATCH_CONV_KERN(stride2); \ + } + + DISPATCH_STRIDE(); + +#undef DO_CONV_KERN_FUN +#undef GET_REMAIN_W_PARAM +#undef GET_OP_PARAM +#undef GET_BIAS_MODE_PARAM +#undef DISPATCH_CONV_KERN +#undef DISPATCH_STRIDE + + megdnn_assert(do_conv_fun, "conv filter not supported"); + + SmallVector ret_kerns; + + CpuNDRange ncb_range = {static_cast(batch), + static_cast(group / pack_group_size)}; + auto do_conv = [do_conv_fun](const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { + size_t PH = kern_param.filter_meta.padding[0]; + size_t PW = kern_param.filter_meta.padding[1]; + size_t OH = kern_param.osz[0]; + size_t OW = kern_param.osz[1]; + size_t IH = kern_param.isz[0]; + size_t IW = kern_param.isz[1]; + + size_t batch_id = ncb_index.ndrange_id[0]; + size_t group_id = ncb_index.ndrange_id[1]; + const __fp16* sptr = + reinterpret_cast(kern_param.src( + batch_id, group_id, 0, pack_group_size)); + const __fp16* fptr = reinterpret_cast( + kern_param.filter(group_id, pack_group_size)); + __fp16* dst = reinterpret_cast<__fp16*>(kern_param.dst( + batch_id, group_id, 0, pack_group_size)); + const __fp16* bptr = + reinterpret_cast(kern_param.bias( + batch_id, group_id, 0, pack_group_size)); + + do_conv_fun(sptr, fptr, bptr, dst, IH, IW, OH, OW, PH, PW); + }; + ret_kerns.push_back({do_conv, ncb_range}); + return ret_kerns; +} + +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.cpp b/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.cpp new file mode 100644 index 00000000..af4e4285 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.cpp @@ -0,0 +1,911 @@ +/** + * \file dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include "src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h" +#include "src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h" +#include "src/arm_common/elemwise_op.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/arm_common/utils.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/common.h" + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +using namespace megdnn; +using namespace arm_common; +using namespace fp16; + +namespace { + +template +void load_vec(float16x8_t* dst, const __fp16* src); + +#define cb(i) dst[i] = vld1q_f16(src + i * 8); +#define LOAD_MACRO(n) \ + template <> \ + inline void load_vec(float16x8_t * dst, const __fp16* src) { \ + UNROLL_CALL_NOWRAPPER(n, cb); \ + } +LOAD_MACRO(2); +LOAD_MACRO(3); +LOAD_MACRO(4); +LOAD_MACRO(5); +LOAD_MACRO(6); +LOAD_MACRO(7); +LOAD_MACRO(8); +LOAD_MACRO(9); +LOAD_MACRO(25); +#undef cb +#undef LOAD_MACRO + +template +void compute_vec(float16x8_t& dst, float16x8_t* src, float16x8_t* filter); + +#define cb(i) dst = vfmaq_f16(dst, src[i], filter[i]); +#define COMPUTE_MACRO(n) \ + template <> \ + inline void compute_vec(float16x8_t & dst, float16x8_t * src, \ + float16x8_t * filter) { \ + UNROLL_CALL_NOWRAPPER(n, cb); \ + } +COMPUTE_MACRO(2); +COMPUTE_MACRO(3); +COMPUTE_MACRO(5); +#undef cb +#undef COMPUTE_MACRO + +template +struct load_bias_vec; + +#define cb_bias(i) dst[i] = vld1q_f16((bptr) + i * 8); +#define cb_init(i) dst[i] = init; + +#define INIT_BIAS_MACRO(n) \ + template \ + struct load_bias_vec { \ + static void impl(float16x8_t* dst, const float16x8_t& init, \ + const __fp16* bptr) { \ + if (bias_mode == BiasMode::BIAS) { \ + UNROLL_CALL_NOWRAPPER(n, cb_bias); \ + } else { \ + UNROLL_CALL_NOWRAPPER(n, cb_init); \ + } \ + } \ + }; + +INIT_BIAS_MACRO(1); +INIT_BIAS_MACRO(2); +INIT_BIAS_MACRO(4); +#undef cb_bias +#undef cb_init +#undef INIT_BIAS_MACRO +} // namespace + +#define COMPUTE_PADDING_KERNEL(oh) \ + do { \ + int iw = ow * stride - PW; \ + float16x8_t result; \ + load_bias_vec::impl(&result, init, \ + bias + (oh)*OW * 8 + ow * 8); \ + for (int kh = 0; kh < fh; kh++) { \ + if (kh + ih < 0 || kh + ih >= static_cast(IH)) \ + continue; \ + for (int kw = 0; kw < fh; kw++) { \ + if (kw + iw < 0 || kw + iw >= static_cast(IW)) \ + continue; \ + const __fp16* sptr = src + (kh + ih) * IW * 8 + (kw + iw) * 8; \ + result = vfmaq_f16(result, kernel[kh * fh + kw], \ + vld1q_f16(sptr)); \ + } \ + } \ + __fp16* output = dst + (oh)*OW * 8 + ow * 8; \ + op(result, output); \ + } while (0) + +#define COMPUTE_PADDING_TOP() \ + do { \ + size_t oh_start = (PH + stride - 1) / stride; \ + for (size_t oh = 0; oh < oh_start; oh++) { \ + int ih = oh * stride - PH; \ + for (size_t ow = 0; ow < OW; ow++) { \ + COMPUTE_PADDING_KERNEL(oh); \ + } \ + } \ + } while (0) + +#define COMPUTE_PADDING_LEFT(n) \ + do { \ + for (int i = 0; i < n; ++i) { \ + size_t ow_start = (PW + stride - 1) / stride; \ + int ih = (oh + i) * stride - PH; \ + for (size_t ow = 0; ow < ow_start; ow++) { \ + COMPUTE_PADDING_KERNEL(oh + i); \ + } \ + } \ + } while (0) + +#define COMPUTE_PADDING_RIGHT(n) \ + do { \ + for (int i = 0; i < n; ++i) { \ + size_t ow_end = (IW + PW - fh) / stride + 1; \ + int ih = (oh + i) * stride - PH; \ + for (size_t ow = ow_end; ow < OW; ow++) { \ + COMPUTE_PADDING_KERNEL(oh + i); \ + } \ + } \ + } while (0) + +#define COMPUTE_PADDING_BOTTOM() \ + do { \ + size_t oh_end = (IH + PH - fh) / stride + 1; \ + for (size_t oh = oh_end; oh < OH; oh++) { \ + int ih = oh * 
stride - PH; \ + for (size_t ow = 0; ow < OW; ow++) { \ + COMPUTE_PADDING_KERNEL(oh); \ + } \ + } \ + } while (0) + +template +void channel_wise_nchw88::do_conv_kern_stride1_2x2( + const __fp16* src, const __fp16* filter, const __fp16* bias, + __fp16* dst, const size_t IH, const size_t IW, const size_t OH, + const size_t OW, const size_t PH, const size_t PW) { + float16x8_t kernel[4]; + load_vec<4>(kernel, filter); + Op op; + float16x8_t init; + if (bias_mode == BiasMode::NO_BIAS) { + init = vdupq_n_f16(__fp16(0.f)); + } else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + init = vld1q_f16(bias); + } + constexpr int fh = 2; + constexpr int stride = 1; + size_t oh_start = PH; + size_t ow_start = PW; + size_t oh_end = IH + PH - 1; + size_t ow_end = IW + PW - 1; +#define COMPUTE_2X2(dst, src, kernel) \ + compute_vec<2>(dst[0], &src[0], kernel); \ + compute_vec<2>(dst[1], &src[1], kernel); \ + compute_vec<2>(dst[2], &src[2], kernel); \ + compute_vec<2>(dst[3], &src[3], kernel) + + size_t oh = oh_start; + COMPUTE_PADDING_TOP(); + for (; oh + 1 < oh_end; oh += 2) { + COMPUTE_PADDING_LEFT(2); + + size_t ih = oh - oh_start; + size_t ow = ow_start; + for (; ow + 3 < ow_end; ow += 4) { + size_t iw = ow - ow_start; + const __fp16* input = src + ih * IW * 8 + iw * 8; + __fp16* output = dst + oh * OW * 8 + ow * 8; + float16x8_t dst_v[2][4]; + load_bias_vec::impl(dst_v[0], init, + bias + oh * OW * 8 + ow * 8); + load_bias_vec::impl( + dst_v[1], init, bias + (oh + 1) * OW * 8 + ow * 8); + float16x8_t src_v[3][5]; + load_vec<5>(src_v[0], input); + COMPUTE_2X2(dst_v[0], src_v[0], &kernel[0]); + load_vec<5>(src_v[1], input + IW * 8); + COMPUTE_2X2(dst_v[0], src_v[1], &kernel[2]); + COMPUTE_2X2(dst_v[1], src_v[1], &kernel[0]); + load_vec<5>(src_v[2], input + 2 * IW * 8); + COMPUTE_2X2(dst_v[1], src_v[2], &kernel[2]); + + op({{dst_v[0][0], dst_v[0][1]}}, output); + op({{dst_v[0][2], dst_v[0][3]}}, output + 16); + op({{dst_v[1][0], dst_v[1][1]}}, output + OW * 8); + op({{dst_v[1][2], dst_v[1][3]}}, output + OW * 8 + 16); + } + for (; ow < ow_end; ow++) { + size_t iw = ow - ow_start; + const __fp16* input = src + ih * IW * 8 + iw * 8; + __fp16* output = dst + oh * OW * 8 + ow * 8; + float16x8_t dst_v[2]; + load_bias_vec::impl(&dst_v[0], init, + bias + oh * OW * 8 + ow * 8); + load_bias_vec::impl( + &dst_v[1], init, bias + (oh + 1) * OW * 8 + ow * 8); + float16x8_t src_v[3][2]; + load_vec<2>(src_v[0], input); + compute_vec<2>(dst_v[0], &src_v[0][0], &kernel[0]); + load_vec<2>(src_v[1], input + IW * 8); + compute_vec<2>(dst_v[0], &src_v[1][0], &kernel[2]); + compute_vec<2>(dst_v[1], &src_v[1][0], &kernel[0]); + load_vec<2>(src_v[2], input + 2 * IW * 8); + compute_vec<2>(dst_v[1], &src_v[2][0], &kernel[2]); + + op(dst_v[0], output); + op(dst_v[1], output + OW * 8); + } + + COMPUTE_PADDING_RIGHT(2); + } + for (; oh < oh_end; oh++) { + COMPUTE_PADDING_LEFT(1); + + size_t ih = oh - oh_start; + size_t ow = ow_start; + for (; ow + 3 < ow_end; ow += 4) { + size_t iw = ow - ow_start; + const __fp16* input = src + ih * IW * 8 + iw * 8; + __fp16* output = dst + oh * OW * 8 + ow * 8; + float16x8_t dst_v[1][4]; + load_bias_vec::impl(dst_v[0], init, + bias + oh * OW * 8 + ow * 8); + float16x8_t src_v[2][5]; + load_vec<5>(src_v[0], input); + COMPUTE_2X2(dst_v[0], src_v[0], &kernel[0]); + load_vec<5>(src_v[1], input + IW * 8); + COMPUTE_2X2(dst_v[0], src_v[1], &kernel[2]); + + op({{dst_v[0][0], dst_v[0][1]}}, output); + op({{dst_v[0][2], dst_v[0][3]}}, output + 16); + } + for (; ow < ow_end; ow++) { + size_t iw = ow - 
ow_start; + const __fp16* input = src + ih * IW * 8 + iw * 8; + __fp16* output = dst + oh * OW * 8 + ow * 8; + float16x8_t dst_v; + load_bias_vec::impl(&dst_v, init, + bias + oh * OW * 8 + ow * 8); + float16x8_t src_v[2][2]; + load_vec<2>(src_v[0], input); + compute_vec<2>(dst_v, &src_v[0][0], &kernel[0]); + load_vec<2>(src_v[1], input + IW * 8); + compute_vec<2>(dst_v, &src_v[1][0], &kernel[2]); + + op(dst_v, output); + } + COMPUTE_PADDING_RIGHT(1); + } + COMPUTE_PADDING_BOTTOM(); + +#undef COMPUTE_2X2 +} + +template +void channel_wise_nchw88::do_conv_kern_stride1_3x3( + const __fp16* src, const __fp16* filter, const __fp16* bias, + __fp16* dst, const size_t IH, const size_t IW, const size_t OH, + const size_t OW, const size_t PH, const size_t PW) { + if (IH == OH && IW == OW && PH == 1 && PW == 1) { + do_conv_kern_3x3_stride1_padding1(src, dst, filter, bias, + OH, OW); + return; + } + + float16x8_t kernel[9]; + load_vec<9>(kernel, filter); + Op op; + float16x8_t init; + if (bias_mode == BiasMode::NO_BIAS) { + init = vdupq_n_f16(__fp16(0.f)); + } else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + init = vld1q_f16(bias); + } + constexpr int fh = 3; + constexpr int stride = 1; + size_t oh_start = PH; + size_t ow_start = PW; + size_t oh_end = IH + PH - 2; + size_t ow_end = IW + PW - 2; + + size_t oh = oh_start; + COMPUTE_PADDING_TOP(); + for (; oh < oh_end; oh += 1) { + COMPUTE_PADDING_LEFT(1); + + size_t ih = oh - PH; + size_t ow = ow_start; + for (; ow + 1 < ow_end; ow += 2) { + size_t iw = ow - PW; + const __fp16* input = src + ih * IW * 8 + iw * 8; + __fp16* output = dst + oh * OW * 8 + ow * 8; + float16x8_t dst_v[1][2]; + load_bias_vec::impl(dst_v[0], init, + bias + oh * OW * 8 + ow * 8); + float16x8_t src_v[3][4]; + load_vec<4>(src_v[0], input); + load_vec<4>(src_v[1], input + IW * 8); + load_vec<4>(src_v[2], input + 2 * IW * 8); + compute_vec<3>(dst_v[0][0], &src_v[0][0], &kernel[0]); + compute_vec<3>(dst_v[0][1], &src_v[0][1], &kernel[0]); + compute_vec<3>(dst_v[0][0], &src_v[1][0], &kernel[3]); + compute_vec<3>(dst_v[0][1], &src_v[1][1], &kernel[3]); + compute_vec<3>(dst_v[0][0], &src_v[2][0], &kernel[6]); + compute_vec<3>(dst_v[0][1], &src_v[2][1], &kernel[6]); + + op({{dst_v[0][0], dst_v[0][1]}}, output); + } + for (; ow < ow_end; ow++) { + size_t iw = ow - PW; + const __fp16* input = src + ih * IW * 8 + iw * 8; + __fp16* output = dst + oh * OW * 8 + ow * 8; + float16x8_t dst_v[1]; + load_bias_vec::impl(&dst_v[0], init, + bias + oh * OW * 8 + ow * 8); + float16x8_t src_v[3][3]; + load_vec<3>(src_v[0], input); + load_vec<3>(src_v[1], input + IW * 8); + load_vec<3>(src_v[2], input + 2 * IW * 8); + compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]); + compute_vec<3>(dst_v[0], &src_v[1][0], &kernel[3]); + compute_vec<3>(dst_v[0], &src_v[2][0], &kernel[6]); + + op(dst_v[0], output); + } + + COMPUTE_PADDING_RIGHT(1); + } + COMPUTE_PADDING_BOTTOM(); +} + +template +void channel_wise_nchw88::do_conv_kern_stride1_5x5( + const __fp16* src, const __fp16* filter, const __fp16* bias, + __fp16* dst, const size_t IH, const size_t IW, const size_t OH, + const size_t OW, const size_t PH, const size_t PW) { + float16x8_t kernel[25]; + load_vec<25>(kernel, filter); + Op op; + float16x8_t init; + if (bias_mode == BiasMode::NO_BIAS) { + init = vdupq_n_f16(__fp16(0.f)); + } else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + init = vld1q_f16(bias); + } + constexpr int fh = 5; + constexpr int stride = 1; + size_t oh_start = PH; + size_t ow_start = PW; + size_t oh_end = IH + PH - 4; + 
size_t ow_end = IW + PW - 4; + + size_t oh = oh_start; + + COMPUTE_PADDING_TOP(); + for (; oh + 1 < oh_end; oh += 2) { + COMPUTE_PADDING_LEFT(2); + + size_t ih = oh - PH; + size_t ow = ow_start; + for (; ow + 1 < ow_end; ow += 2) { + size_t iw = ow - PW; + const __fp16* input = src + ih * IW * 8 + iw * 8; + __fp16* output = dst + oh * OW * 8 + ow * 8; + float16x8_t dst_v[2][2]; + load_bias_vec::impl(dst_v[0], init, + bias + oh * OW * 8 + ow * 8); + load_bias_vec::impl( + dst_v[1], init, bias + (oh + 1) * OW * 8 + ow * 8); + float16x8_t kernel[2][5]; + float16x8_t src_v[2][6]; +#define COMPUTE_5X5_4(i, dst, src, kernel0, kernel1) \ + load_vec<5>(kernel0, filter + i * 5 * 8); \ + load_vec<6>(src, input + i * IW * 8); \ + compute_vec<5>(dst[0][0], &src[0], kernel0); \ + compute_vec<5>(dst[0][1], &src[1], kernel0); \ + compute_vec<5>(dst[1][0], &src[0], kernel1); \ + compute_vec<5>(dst[1][1], &src[1], kernel1) + // line 0 + load_vec<5>(kernel[0], filter); + load_vec<6>(src_v[0], input); + compute_vec<5>(dst_v[0][0], &src_v[0][0], kernel[0]); + compute_vec<5>(dst_v[0][1], &src_v[0][1], kernel[0]); + // line 1 + COMPUTE_5X5_4(1, dst_v, src_v[1], kernel[1], kernel[0]); + // line 2 + COMPUTE_5X5_4(2, dst_v, src_v[0], kernel[0], kernel[1]); + // line 3 + COMPUTE_5X5_4(3, dst_v, src_v[1], kernel[1], kernel[0]); + // line 4 + COMPUTE_5X5_4(4, dst_v, src_v[0], kernel[0], kernel[1]); + // line 5 + load_vec<6>(src_v[1], input + 5 * IW * 8); + compute_vec<5>(dst_v[1][0], &src_v[1][0], kernel[0]); + compute_vec<5>(dst_v[1][1], &src_v[1][1], kernel[0]); +#undef COMPUTE_5X5_4 + op({{dst_v[0][0], dst_v[0][1]}}, output); + op({{dst_v[1][0], dst_v[1][1]}}, output + OW * 8); + } + for (; ow < ow_end; ow++) { + size_t iw = ow - PW; + const __fp16* input = src + ih * IW * 8 + iw * 8; + __fp16* output = dst + oh * OW * 8 + ow * 8; + float16x8_t dst_v[2][1]; + load_bias_vec::impl(dst_v[0], init, + bias + oh * OW * 8 + ow * 8); + load_bias_vec::impl( + dst_v[1], init, bias + (oh + 1) * OW * 8 + ow * 8); + float16x8_t kernel[2][5]; + float16x8_t src_v[2][5]; +#define COMPUTE_5X5_2(i, dst, src, kernel0, kernel1) \ + load_vec<5>(kernel0, filter + i * 5 * 8); \ + load_vec<6>(src, input + i * IW * 8); \ + compute_vec<5>(dst[0][0], &src[0], kernel0); \ + compute_vec<5>(dst[1][0], &src[0], kernel1); + // line 0 + load_vec<5>(kernel[0], filter); + load_vec<5>(src_v[0], input); + compute_vec<5>(dst_v[0][0], &src_v[0][0], kernel[0]); + // line 1 + COMPUTE_5X5_2(1, dst_v, src_v[1], kernel[1], kernel[0]); + // line 2 + COMPUTE_5X5_2(2, dst_v, src_v[0], kernel[0], kernel[1]); + // line 3 + COMPUTE_5X5_2(3, dst_v, src_v[1], kernel[1], kernel[0]); + // line 4 + COMPUTE_5X5_2(4, dst_v, src_v[0], kernel[0], kernel[1]); + // line 5 + load_vec<5>(src_v[1], input + 5 * IW * 8); + compute_vec<5>(dst_v[1][0], &src_v[1][0], kernel[0]); +#undef COMPUTE_5X5_2 + op(dst_v[0][0], output); + op(dst_v[1][0], output + OW * 8); + } + COMPUTE_PADDING_RIGHT(2); + } + for (; oh < oh_end; oh++) { + COMPUTE_PADDING_LEFT(1); + + size_t ih = oh - PH; + size_t ow = ow_start; + for (; ow + 1 < ow_end; ow += 2) { + size_t iw = ow - PW; + const __fp16* input = src + ih * IW * 8 + iw * 8; + __fp16* output = dst + oh * OW * 8 + ow * 8; + float16x8_t dst_v[1][2]; + load_bias_vec::impl(dst_v[0], init, + bias + oh * OW * 8 + ow * 8); + float16x8_t kernel[2][5]; + float16x8_t src_v[2][6]; +#define COMPUTE_5X5_2(i, dst, src, kernel) \ + load_vec<5>(kernel, filter + i * 5 * 8); \ + load_vec<6>(src, input + i * IW * 8); \ + compute_vec<5>(dst[0][0], &src[0], kernel); 
\ + compute_vec<5>(dst[0][1], &src[1], kernel) + // line 0 + COMPUTE_5X5_2(0, dst_v, src_v[0], kernel[0]); + // line 1 + COMPUTE_5X5_2(1, dst_v, src_v[1], kernel[1]); + // line 2 + COMPUTE_5X5_2(2, dst_v, src_v[0], kernel[0]); + // line 3 + COMPUTE_5X5_2(3, dst_v, src_v[1], kernel[1]); + // line 4 + COMPUTE_5X5_2(4, dst_v, src_v[0], kernel[0]); +#undef COMPUTE_5X5_2 + op({{dst_v[0][0], dst_v[0][1]}}, output); + } + for (; ow < ow_end; ow++) { + size_t iw = ow - PW; + const __fp16* input = src + ih * IW * 8 + iw * 8; + __fp16* output = dst + oh * OW * 8 + ow * 8; + float16x8_t dst_v; + load_bias_vec::impl(&dst_v, init, + bias + oh * OW * 8 + ow * 8); + float16x8_t kernel[2][5]; + float16x8_t src_v[2][5]; +#define COMPUTE_5X5_1(i, dst, src, kernel) \ + load_vec<5>(kernel, filter + i * 5 * 8); \ + load_vec<6>(src, input + i * IW * 8); \ + compute_vec<5>(dst, &src[0], kernel) + // line 0 + COMPUTE_5X5_1(0, dst_v, src_v[0], kernel[0]); + // line 1 + COMPUTE_5X5_1(1, dst_v, src_v[1], kernel[1]); + // line 2 + COMPUTE_5X5_1(2, dst_v, src_v[0], kernel[0]); + // line 3 + COMPUTE_5X5_1(3, dst_v, src_v[1], kernel[1]); + // line 4 + COMPUTE_5X5_1(4, dst_v, src_v[0], kernel[0]); +#undef COMPUTE_5X5_1 + op(dst_v, output); + } + COMPUTE_PADDING_RIGHT(1); + } + COMPUTE_PADDING_BOTTOM(); +} + +template +void channel_wise_nchw88::do_conv_kern_stride2_2x2( + const __fp16* src, const __fp16* filter, const __fp16* bias, + __fp16* dst, const size_t IH, const size_t IW, const size_t OH, + const size_t OW, const size_t PH, const size_t PW) { + float16x8_t kernel[4]; + load_vec<4>(kernel, filter); + Op op; + float16x8_t init; + if (bias_mode == BiasMode::NO_BIAS) { + init = vdupq_n_f16(__fp16(0.f)); + } else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + init = vld1q_f16(bias); + } + constexpr int fh = 2; + constexpr int stride = 2; + size_t oh_start = (PH + 1) / 2; + size_t ow_start = (PW + 1) / 2; + size_t oh_end = (IH + PH) / 2; + size_t ow_end = (IW + PW) / 2; + +#define COMPUTE_2X2(dst, src, kernel) \ + compute_vec<2>(dst[0], &src[0], kernel); \ + compute_vec<2>(dst[1], &src[2], kernel); \ + compute_vec<2>(dst[2], &src[4], kernel); \ + compute_vec<2>(dst[3], &src[6], kernel) + size_t oh = oh_start; + COMPUTE_PADDING_TOP(); + for (; oh < oh_end; oh++) { + COMPUTE_PADDING_LEFT(1); + size_t ih = oh * 2 - PH; + size_t ow = ow_start; + for (; ow + 3 < ow_end; ow += 4) { + size_t iw = ow * 2 - PW; + const __fp16* input = src + ih * IW * 8 + iw * 8; + __fp16* output = dst + oh * OW * 8 + ow * 8; + float16x8_t dst_v[4]; + load_bias_vec::impl(&dst_v[0], init, + bias + oh * OW * 8 + ow * 8); + float16x8_t src_v[2][8]; + load_vec<8>(src_v[0], input); + COMPUTE_2X2(dst_v, src_v[0], &kernel[0]); + load_vec<8>(src_v[1], input + IW * 8); + COMPUTE_2X2(dst_v, src_v[1], &kernel[2]); +#undef COMPUTE_2X2 + op({{dst_v[0], dst_v[1]}}, output); + op({{dst_v[2], dst_v[3]}}, output + 16); + } + for (; ow < ow_end; ow++) { + size_t iw = ow * 2 - PW; + const __fp16* input = src + ih * IW * 8 + iw * 8; + __fp16* output = dst + oh * OW * 8 + ow * 8; + float16x8_t dst_v; + load_bias_vec::impl(&dst_v, init, + bias + oh * OW * 8 + ow * 8); + float16x8_t src_v[2][2]; + load_vec<2>(src_v[0], input); + compute_vec<2>(dst_v, &src_v[0][0], &kernel[0]); + load_vec<2>(src_v[1], input + IW * 8); + compute_vec<2>(dst_v, &src_v[1][0], &kernel[2]); + + op(dst_v, output); + } + COMPUTE_PADDING_RIGHT(1); + } + COMPUTE_PADDING_BOTTOM(); +#undef COMPUTE_2X2 +} + +template +void channel_wise_nchw88::do_conv_kern_stride2_3x3( + const __fp16* src, 
const __fp16* filter, const __fp16* bias, + __fp16* dst, const size_t IH, const size_t IW, const size_t OH, + const size_t OW, const size_t PH, const size_t PW) { + float16x8_t kernel[9]; + load_vec<9>(kernel, filter); + Op op; + float16x8_t init; + if (bias_mode == BiasMode::NO_BIAS) { + init = vdupq_n_f16(__fp16(0.f)); + } else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + init = vld1q_f16(bias); + } + constexpr int fh = 3; + constexpr int stride = 2; + size_t oh_start = (PH + 1) / 2; + size_t ow_start = (PW + 1) / 2; + size_t oh_end = (IH + PH - 3) / 2 + 1; + size_t ow_end = (IW + PW - 3) / 2 + 1; + + size_t oh = oh_start; + COMPUTE_PADDING_TOP(); + for (; oh + 1 < oh_end; oh += 2) { + COMPUTE_PADDING_LEFT(2); + size_t ih = oh * 2 - PH; + size_t ow = ow_start; + for (; ow + 1 < ow_end; ow += 2) { + size_t iw = ow * 2 - PW; + const __fp16* input = src + ih * IW * 8 + iw * 8; + __fp16* output = dst + oh * OW * 8 + ow * 8; + float16x8_t dst_v[2][2]; + load_bias_vec::impl(dst_v[0], init, + bias + oh * OW * 8 + ow * 8); + load_bias_vec::impl( + dst_v[1], init, bias + (oh + 1) * OW * 8 + ow * 8); + float16x8_t src_v[2][5]; + load_vec<5>(src_v[0], input); + compute_vec<3>(dst_v[0][0], &src_v[0][0], &kernel[0]); + compute_vec<3>(dst_v[0][1], &src_v[0][2], &kernel[0]); + load_vec<5>(src_v[1], input + IW * 8); + compute_vec<3>(dst_v[0][0], &src_v[1][0], &kernel[3]); + compute_vec<3>(dst_v[0][1], &src_v[1][2], &kernel[3]); + load_vec<5>(src_v[0], input + 2 * IW * 8); + compute_vec<3>(dst_v[0][0], &src_v[0][0], &kernel[6]); + compute_vec<3>(dst_v[0][1], &src_v[0][2], &kernel[6]); + compute_vec<3>(dst_v[1][0], &src_v[0][0], &kernel[0]); + compute_vec<3>(dst_v[1][1], &src_v[0][2], &kernel[0]); + load_vec<5>(src_v[1], input + 3 * IW * 8); + compute_vec<3>(dst_v[1][0], &src_v[1][0], &kernel[3]); + compute_vec<3>(dst_v[1][1], &src_v[1][2], &kernel[3]); + load_vec<5>(src_v[0], input + 4 * IW * 8); + compute_vec<3>(dst_v[1][0], &src_v[0][0], &kernel[6]); + compute_vec<3>(dst_v[1][1], &src_v[0][2], &kernel[6]); + + op({{dst_v[0][0], dst_v[0][1]}}, output); + op({{dst_v[1][0], dst_v[1][1]}}, output + OW * 8); + } + for (; ow < ow_end; ow++) { + size_t iw = ow * 2 - PW; + const __fp16* input = src + ih * IW * 8 + iw * 8; + __fp16* output = dst + oh * OW * 8 + ow * 8; + float16x8_t dst_v[2]; + load_bias_vec::impl(&dst_v[0], init, + bias + oh * OW * 8 + ow * 8); + load_bias_vec::impl( + &dst_v[1], init, bias + (oh + 1) * OW * 8 + ow * 8); + float16x8_t src_v[2][3]; + load_vec<3>(src_v[0], input); + compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]); + load_vec<3>(src_v[1], input + IW * 8); + compute_vec<3>(dst_v[0], &src_v[1][0], &kernel[3]); + load_vec<3>(src_v[0], input + 2 * IW * 8); + compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[6]); + compute_vec<3>(dst_v[1], &src_v[0][0], &kernel[0]); + load_vec<3>(src_v[1], input + 3 * IW * 8); + compute_vec<3>(dst_v[1], &src_v[1][0], &kernel[3]); + load_vec<3>(src_v[0], input + 4 * IW * 8); + compute_vec<3>(dst_v[1], &src_v[0][0], &kernel[6]); + + op(dst_v[0], output); + op(dst_v[1], output + OW * 8); + } + COMPUTE_PADDING_RIGHT(2); + } + for (; oh < oh_end; oh++) { + COMPUTE_PADDING_LEFT(1); + size_t ih = oh * 2 - PH; + size_t ow = ow_start; + for (; ow + 1 < ow_end; ow += 2) { + size_t iw = ow * 2 - PW; + const __fp16* input = src + ih * IW * 8 + iw * 8; + __fp16* output = dst + oh * OW * 8 + ow * 8; + float16x8_t dst_v[2]; + load_bias_vec::impl(&dst_v[0], init, + bias + oh * OW * 8 + ow * 8); + float16x8_t src_v[3][5]; + load_vec<5>(src_v[0], input); + 
compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]); + compute_vec<3>(dst_v[1], &src_v[0][2], &kernel[0]); + load_vec<5>(src_v[1], input + IW * 8); + compute_vec<3>(dst_v[0], &src_v[1][0], &kernel[3]); + compute_vec<3>(dst_v[1], &src_v[1][2], &kernel[3]); + load_vec<5>(src_v[2], input + 2 * IW * 8); + compute_vec<3>(dst_v[0], &src_v[2][0], &kernel[6]); + compute_vec<3>(dst_v[1], &src_v[2][2], &kernel[6]); + op({{dst_v[0], dst_v[1]}}, output); + } + for (; ow < ow_end; ow++) { + size_t iw = ow * 2 - PW; + const __fp16* input = src + ih * IW * 8 + iw * 8; + __fp16* output = dst + oh * OW * 8 + ow * 8; + float16x8_t dst_v; + load_bias_vec::impl(&dst_v, init, + bias + oh * OW * 8 + ow * 8); + float16x8_t src_v[3][3]; + load_vec<3>(src_v[0], input); + compute_vec<3>(dst_v, &src_v[0][0], &kernel[0]); + load_vec<3>(src_v[1], input + IW * 8); + compute_vec<3>(dst_v, &src_v[1][0], &kernel[3]); + load_vec<3>(src_v[2], input + 2 * IW * 8); + compute_vec<3>(dst_v, &src_v[2][0], &kernel[6]); + op(dst_v, output); + } + COMPUTE_PADDING_RIGHT(1); + } + COMPUTE_PADDING_BOTTOM(); +} + +template +void channel_wise_nchw88::do_conv_kern_stride2_5x5( + const __fp16* src, const __fp16* filter, const __fp16* bias, + __fp16* dst, const size_t IH, const size_t IW, const size_t OH, + const size_t OW, const size_t PH, const size_t PW) { + float16x8_t kernel[25]; + load_vec<25>(kernel, filter); + Op op; + float16x8_t init; + if (bias_mode == BiasMode::NO_BIAS) { + init = vdupq_n_f16(__fp16(0.f)); + } else if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { + init = vld1q_f16(bias); + } + constexpr int fh = 5; + constexpr int stride = 2; + size_t oh_start = (PH + stride - 1) / stride; + size_t ow_start = (PW + stride - 1) / stride; + size_t oh_end = (IH + PH - 5) / stride + 1; + size_t ow_end = (IW + PW - 5) / stride + 1; + + size_t oh = oh_start; + COMPUTE_PADDING_TOP(); + for (; oh + 1 < oh_end; oh += 2) { + COMPUTE_PADDING_LEFT(2); + size_t ih = oh * stride - PH; + size_t ow = ow_start; + for (; ow + 1 < ow_end; ow += 2) { + size_t iw = ow * stride - PW; + const __fp16* input = src + ih * IW * 8 + iw * 8; + __fp16* output = dst + oh * OW * 8 + ow * 8; + float16x8_t dst_v[2][2]; + load_bias_vec::impl(dst_v[0], init, + bias + oh * OW * 8 + ow * 8); + load_bias_vec::impl( + dst_v[1], init, bias + (oh + 1) * OW * 8 + ow * 8); + float16x8_t kernel[3][5]; + float16x8_t src_v[2][7]; +#define COMPUTE_5X5_4(i, dst, src, kernel0, kernel1) \ + load_vec<5>(kernel0, filter + i * 5 * 8); \ + load_vec<7>(src, input + i * IW * 8); \ + compute_vec<5>(dst[0][0], &src[0], kernel0); \ + compute_vec<5>(dst[0][1], &src[2], kernel0); \ + compute_vec<5>(dst[1][0], &src[0], kernel1); \ + compute_vec<5>(dst[1][1], &src[2], kernel1) + +#define COMPUTE_5X5_2(i, dst, src, kernel) \ + load_vec<7>(src, input + i * IW * 8); \ + compute_vec<5>(dst[0], &src[0], kernel); \ + compute_vec<5>(dst[1], &src[2], kernel) + // line 0 + load_vec<5>(kernel[0], filter); + COMPUTE_5X5_2(0, dst_v[0], src_v[0], kernel[0]); + // line 1 + load_vec<5>(kernel[1], filter + 5 * 8); + COMPUTE_5X5_2(1, dst_v[0], src_v[1], kernel[1]); + // line 2 + COMPUTE_5X5_4(2, dst_v, src_v[0], kernel[2], kernel[0]); + // line 3 + COMPUTE_5X5_4(3, dst_v, src_v[1], kernel[0], kernel[1]); + // line 4 + COMPUTE_5X5_4(4, dst_v, src_v[0], kernel[1], kernel[2]); + // line 5 + COMPUTE_5X5_2(5, dst_v[1], src_v[1], kernel[0]); + // line 6 + COMPUTE_5X5_2(6, dst_v[1], src_v[0], kernel[1]); +#undef COMPUTE_5X5_4 +#undef COMPUTE_5X5_2 + op({{dst_v[0][0], dst_v[0][1]}}, output); + op({{dst_v[1][0], 
dst_v[1][1]}}, output + OW * 8); + } + for (; ow < ow_end; ow++) { + size_t iw = ow * stride - PW; + const __fp16* input = src + ih * IW * 8 + iw * 8; + __fp16* output = dst + oh * OW * 8 + ow * 8; + float16x8_t dst_v[2]; + load_bias_vec::impl(&dst_v[0], init, + bias + oh * OW * 8 + ow * 8); + load_bias_vec::impl( + &dst_v[1], init, bias + (oh + 1) * OW * 8 + ow * 8); + float16x8_t kernel[3][5]; + float16x8_t src_v[2][5]; +#define COMPUTE_5X5_2(i, dst, src, kernel0, kernel1) \ + load_vec<5>(kernel0, filter + i * 5 * 8); \ + load_vec<5>(src, input + i * IW * 8); \ + compute_vec<5>(dst[0], &src[0], kernel0); \ + compute_vec<5>(dst[1], &src[0], kernel1); + +#define COMPUTE_5X5_1(i, dst, src, kernel) \ + load_vec<5>(src, input + i * IW * 8); \ + compute_vec<5>(dst, &src[0], kernel); \ + // line 0 + load_vec<5>(kernel[0], filter); + COMPUTE_5X5_1(0, dst_v[0], src_v[0], kernel[0]); + // line 1 + load_vec<5>(kernel[1], filter + 5 * 8); + COMPUTE_5X5_1(1, dst_v[0], src_v[1], kernel[1]); + // line 2 + COMPUTE_5X5_2(2, dst_v, src_v[0], kernel[2], kernel[0]); + // line 3 + COMPUTE_5X5_2(3, dst_v, src_v[1], kernel[0], kernel[1]); + // line 4 + COMPUTE_5X5_2(4, dst_v, src_v[0], kernel[1], kernel[2]); + // line 5 + COMPUTE_5X5_1(5, dst_v[1], src_v[1], kernel[0]); + // line 6 + COMPUTE_5X5_1(6, dst_v[1], src_v[0], kernel[1]); +#undef COMPUTE_5X5_2 +#undef COMPUTE_5X5_1 + op(dst_v[0], output); + op(dst_v[1], output + OW * 8); + } + COMPUTE_PADDING_RIGHT(2); + } + for (; oh < oh_end; oh++) { + COMPUTE_PADDING_LEFT(1); + size_t ih = oh * stride - PH; + size_t ow = ow_start; + for (; ow < ow_end; ow++) { + size_t iw = ow * stride - PW; + const __fp16* input = src + ih * IW * 8 + iw * 8; + __fp16* output = dst + oh * OW * 8 + ow * 8; + float16x8_t dst_v; + load_bias_vec::impl(&dst_v, init, + bias + oh * OW * 8 + ow * 8); + float16x8_t kernel[2][5]; + float16x8_t src_v[2][5]; +#define COMPUTE_5X5_1(i, dst, src, kernel) \ + load_vec<5>(kernel, filter + i * 5 * 8); \ + load_vec<6>(src, input + i * IW * 8); \ + compute_vec<5>(dst, &src[0], kernel) + // line 0 + COMPUTE_5X5_1(0, dst_v, src_v[0], kernel[0]); + // line 1 + COMPUTE_5X5_1(1, dst_v, src_v[1], kernel[1]); + // line 2 + COMPUTE_5X5_1(2, dst_v, src_v[0], kernel[0]); + // line 3 + COMPUTE_5X5_1(3, dst_v, src_v[1], kernel[1]); + // line 4 + COMPUTE_5X5_1(4, dst_v, src_v[0], kernel[0]); +#undef COMPUTE_5X5_1 + op(dst_v, output); + } + COMPUTE_PADDING_RIGHT(1); + } + COMPUTE_PADDING_BOTTOM(); +} + +#define INSTANTIATION(stride, i, bias, Op) \ + template void \ + channel_wise_nchw88::do_conv_kern_##stride##_##i##x##i( \ + const __fp16*, const __fp16*, const __fp16*, __fp16*, \ + const size_t, const size_t, const size_t, const size_t, \ + const size_t, const size_t); + +#define FOR_OP(stride, i, bias) \ + INSTANTIATION(stride, i, bias, SigmoidOp<__fp16>) \ + INSTANTIATION(stride, i, bias, ReluOp<__fp16>) \ + INSTANTIATION(stride, i, bias, HSwishOp<__fp16>) \ + INSTANTIATION(stride, i, bias, NoneOp<__fp16>) + +#define FOR_BIAS(stride, i) \ + FOR_OP(stride, i, BiasMode::NO_BIAS) \ + FOR_OP(stride, i, BiasMode::BROADCAST_CHANNEL_BIAS) \ + FOR_OP(stride, i, BiasMode::BIAS) + +#define FOR_FILTER(stride) \ + FOR_BIAS(stride, 2) \ + FOR_BIAS(stride, 3) \ + FOR_BIAS(stride, 5) + +#define FOR_STRIDE \ + FOR_FILTER(stride1) \ + FOR_FILTER(stride2) + +FOR_STRIDE + +#undef FOR_STRIDE +#undef FOR_FILTER +#undef FOR_BIAS +#undef FOR_OP +#undef INSTANTIATION + +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h 
b/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h new file mode 100644 index 00000000..033e8d56 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h @@ -0,0 +1,49 @@ +/** + * \file dnn/src/arm_common/conv_bias/fp16/channel_wise_nchw88_kern.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once + +#include "src/arm_common/conv_bias/opr_impl.h" +#include "src/fallback/conv_bias/common.h" + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +namespace megdnn { +namespace arm_common { +namespace fp16 { +namespace channel_wise_nchw88 { + +#define KERN(stride, i) \ + template \ + void do_conv_kern_##stride##_##i##x##i( \ + const __fp16* src, const __fp16* filter, const __fp16* bias, \ + __fp16* dst, const size_t IH, const size_t IW, const size_t OH, \ + const size_t OW, const size_t PH, const size_t PW); + +KERN(stride1, 2) +KERN(stride1, 3) +KERN(stride1, 5) + +KERN(stride2, 2) +KERN(stride2, 3) +KERN(stride2, 5) + +#undef KERN + +} // namespace channel_wise_nchw88 +} // namespace fp16 +} // namespace arm_common +} // namespace megdnn + +#endif + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp index fc4fe356..e3bb59c9 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.cpp +++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp @@ -85,6 +85,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC AlgoF16Direct f16_direct; AlgoF16DirectStride1 f16_direct_stride1; + AlgoF16ChannelWiseNCHW88 f16_channel_wise_nchw88; #endif SmallVector> refhold; @@ -119,6 +120,7 @@ public: #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC m_direct_algos.emplace_back(&f16_direct_stride1); m_direct_algos.emplace_back(&f16_direct); + m_direct_algos.emplace_back(&f16_channel_wise_nchw88); #endif m_direct_algos.emplace_back(&i8x8x16_direct); m_direct_algos.emplace_back(&i8x8x16_stride2_filter2); diff --git a/dnn/src/arm_common/conv_bias/opr_impl.h b/dnn/src/arm_common/conv_bias/opr_impl.h index d145fe57..225288e3 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.h +++ b/dnn/src/arm_common/conv_bias/opr_impl.h @@ -96,6 +96,7 @@ private: #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC class AlgoF16Direct; class AlgoF16DirectStride1; + class AlgoF16ChannelWiseNCHW88; #endif class AlgoPack; diff --git a/dnn/src/fallback/conv_bias/opr_impl.h b/dnn/src/fallback/conv_bias/opr_impl.h index 79180b70..0e9abf60 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.h +++ b/dnn/src/fallback/conv_bias/opr_impl.h @@ -238,6 +238,7 @@ public: ARM_COMMON_WINOGRAD_F23_8X8_FP16, ARM_COMMON_DIRECT_FP16, ARM_COMMON_DIRECT_STRD1_FP16, + ARM_COMMON_CHWNWISE_NCHW88_F16, ARM_COMMON_WINOGRAD_F23_4X4_FP32, ARM_COMMON_WINOGRAD_F63_FP32, ARM_COMMON_WINOGRAD_F63_4X4_FP32, diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index 9e56345d..1c697a18 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -148,6 +148,81 @@ std::vector get_nchw44_channel_wise_args( return args; } +std::vector get_nchw88_channel_wise_args( + std::vector kernel, size_t stride, bool no_bias, + bool 
no_nonlinemode, bool no_full_bias) { + using namespace conv_bias; + using Param = param::ConvBias; + using NLMode = param::ConvBias::NonlineMode; + std::vector args; + + auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel, + size_t stride, NLMode nlmode, bool pad) { + Param param; + param.stride_h = stride; + param.stride_w = stride; + if (pad) { + param.pad_h = kernel / 2; + param.pad_w = kernel / 2; + } else { + param.pad_h = 0; + param.pad_w = 0; + } + param.nonlineMode = nlmode; + param.format = param::ConvBias::Format::NCHW88; + param.sparse = param::ConvBias::Sparse::GROUP; + + args.emplace_back(param, TensorShape{n, group, h, w, 8}, + TensorShape{group, 1, 1, kernel, kernel, 8}, + TensorShape{}); + if (!no_bias) { + args.emplace_back(param, TensorShape{n, group, h, w, 8}, + TensorShape{group, 1, 1, kernel, kernel, 8}, + TensorShape{1, group, 1, 1, 8}); + } + if (!no_full_bias) { + args.emplace_back( + param, TensorShape{n, group, h, w, 8}, + TensorShape{group, 1, 1, kernel, kernel, 8}, + TensorShape{n, group, + (h + 2 * param.pad_w - kernel) / stride + 1, + (w + 2 * param.pad_w - kernel) / stride + 1, + 8}); + } + }; + + std::vector nonlinemode = {NLMode::IDENTITY}; + if (!no_nonlinemode) { + nonlinemode.emplace_back(NLMode::RELU); + nonlinemode.emplace_back(NLMode::H_SWISH); + } + for (size_t n : {1, 2}) { + for (auto nlmode : nonlinemode) { + for (bool pad : {true}) { + for (size_t group : {1, 2, 4, 7, 8, 128}) { + for (size_t size : {4, 6, 7, 9, 15, 40}) { + for (size_t kern : kernel) { + pack(n, group, size, size, kern, stride, nlmode, + pad); + } + } + } + } + for (bool pad : {false}) { + for (size_t group : {1, 2, 7, 128}) { + for (size_t size : {7, 9, 15, 40}) { + for (size_t kern : kernel) { + pack(n, group, size, size, kern, stride, nlmode, + pad); + } + } + } + } + } + } + return args; +} + void checker_conv_bias_qint8x8x8(std::vector args, Handle* handle, const char* algo_name) { Checker checker(handle); @@ -317,6 +392,26 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1) { checker_conv_bias_f16(get_conv_bias_args({2, 3, 5}, 1, false, false, false), handle(), rng, "F16STRD1", 0.03); } +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP16_NCHW88_1) { + NormalRNG rng(1); + checker_conv_bias_f16( + get_nchw88_channel_wise_args({2, 3}, 1, false, false, false), + handle(), rng, "F16_CHANNEL_WISE_NCHW88", 0.03); +} + +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP16_NCHW88_2) { + NormalRNG rng(1); + checker_conv_bias_f16( + get_nchw88_channel_wise_args({5}, 1, false, false, false), handle(), + rng, "F16_CHANNEL_WISE_NCHW88", 0.03); +} + +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP16_NCHW88) { + NormalRNG rng(1); + checker_conv_bias_f16( + get_nchw88_channel_wise_args({2, 3, 5}, 2, false, false, false), + handle(), rng, "F16_CHANNEL_WISE_NCHW88", 0.03); +} #endif /**********************************algo 8816 direct************************/ diff --git a/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp b/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp index b50937b2..18884324 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp @@ -400,6 +400,68 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECTF16_STR1) { benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}}, {1, {4}}, data_type); } +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, 
BENCHMARK_CONVBIAS_F16_NCHW88) { + constexpr size_t RUNS = 50; + + std::string algo_name = "F16_CHANNEL_WISE_NCHW88"; + printf("Benchmarker F16_CHANNEL_WISE_NCHW88 algo\n"); + std::vector data_type = {dtype::Float16(), dtype::Float16(), + dtype::Float16(), dtype::Float16()}; + + auto bench_case = [&](size_t N, size_t IC, size_t H, size_t W, size_t FS, + size_t P, size_t S) { + param::ConvBias param; + param.nonlineMode = param::ConvBias::NonlineMode::RELU; + param.pad_h = P; + param.pad_w = P; + param.stride_h = S; + param.stride_w = S; + param.sparse = param::ConvBias::Sparse::GROUP; + param.format = param::ConvBias::Format::NCHW88; + + size_t group = IC; + size_t OC = IC; + SmallVector shapes{ + {N, IC, H, W, 8}, + {group, 1, 1, FS, FS, 8}, + {1, OC, 1, 1, 8}, + {}, + {N, OC, (H + 2 * P - FS) / S + 1, (W + 2 * P - FS) / S + 1, 8}}; + TensorShape dst{N, OC, (H + 2 * P - FS) / S + 1, + (W + 2 * P - FS) / S + 1, 8}; + float computations = + ((IC / group) * FS * FS * dst.total_nr_elems() * 2 + + dst.total_nr_elems()) * + 1e-6; + std::vector, float>> shape_arg = { + std::make_pair(shapes, computations)}; + + benchmark_impl(param, shape_arg, algo_name, RUNS, {4, {4, 5, 6, 7}}, + {1, {7}}, data_type); + }; + + bench_case(1, 64, 100, 100, 5, 2, 1); + bench_case(1, 64, 56, 56, 5, 2, 1); + bench_case(1, 64, 28, 28, 5, 2, 1); + bench_case(1, 64, 100, 100, 5, 2, 2); + bench_case(1, 64, 56, 56, 5, 2, 2); + bench_case(1, 64, 28, 28, 5, 2, 2); + + bench_case(1, 64, 100, 100, 3, 1, 1); + bench_case(1, 64, 56, 56, 3, 1, 1); + bench_case(1, 64, 28, 28, 3, 1, 1); + bench_case(1, 64, 100, 100, 3, 1, 2); + bench_case(1, 64, 56, 56, 3, 1, 2); + bench_case(1, 64, 28, 28, 3, 1, 2); + + bench_case(1, 64, 100, 100, 2, 0, 1); + bench_case(1, 64, 56, 56, 2, 0, 1); + bench_case(1, 64, 28, 28, 2, 0, 1); + bench_case(1, 64, 100, 100, 2, 0, 2); + bench_case(1, 64, 56, 56, 2, 0, 2); + bench_case(1, 64, 28, 28, 2, 0, 2); +} + #endif TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_DIRECT_INT8x8x16) {
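For reference, the layout assumption behind the kernels added above: NCHW88 packs channels in blocks of 8, so with OC == IC == 1 per group a channel-wise (depthwise) convolution reads and writes exactly one 8-lane block per output pixel; the float16x8_t loads and vfmaq_f16 accumulations in channel_wise_nchw88_kern.cpp are a hand-vectorized form of the scalar computation sketched below. This is a minimal sketch only (plain C++, float accumulation, BROADCAST_CHANNEL_BIAS handling only); the function name and signature are illustrative and not part of the patch.

```cpp
#include <cstddef>
#include <vector>

// Scalar reference for channel-wise (depthwise) convolution on NCHW88 data,
// matching the semantics the NEON kernels above vectorize.
//   src:    [G, IH, IW, 8]   one 8-channel block per group (OC == IC == 1)
//   filter: [G, FH, FW, 8]
//   bias:   [G, 8]           BROADCAST_CHANNEL_BIAS case only (illustrative)
//   dst:    [G, OH, OW, 8]
void depthwise_nchw88_ref(const std::vector<float>& src,
                          const std::vector<float>& filter,
                          const std::vector<float>& bias,
                          std::vector<float>& dst, size_t G, size_t IH,
                          size_t IW, size_t FH, size_t FW, size_t PH,
                          size_t PW, size_t stride) {
    const size_t OH = (IH + 2 * PH - FH) / stride + 1;
    const size_t OW = (IW + 2 * PW - FW) / stride + 1;
    dst.assign(G * OH * OW * 8, 0.f);
    for (size_t g = 0; g < G; ++g) {
        for (size_t oh = 0; oh < OH; ++oh) {
            for (size_t ow = 0; ow < OW; ++ow) {
                for (size_t c = 0; c < 8; ++c) {  // the 8 packed channels
                    float acc = bias[g * 8 + c];
                    for (size_t fh = 0; fh < FH; ++fh) {
                        for (size_t fw = 0; fw < FW; ++fw) {
                            const ptrdiff_t ih =
                                    static_cast<ptrdiff_t>(oh * stride + fh) -
                                    static_cast<ptrdiff_t>(PH);
                            const ptrdiff_t iw =
                                    static_cast<ptrdiff_t>(ow * stride + fw) -
                                    static_cast<ptrdiff_t>(PW);
                            if (ih < 0 || iw < 0 ||
                                ih >= static_cast<ptrdiff_t>(IH) ||
                                iw >= static_cast<ptrdiff_t>(IW))
                                continue;  // zero padding
                            acc += src[((g * IH + ih) * IW + iw) * 8 + c] *
                                   filter[((g * FH + fh) * FW + fw) * 8 + c];
                        }
                    }
                    dst[((g * OH + oh) * OW + ow) * 8 + c] = acc;
                }
            }
        }
    }
}
```

The patch reaches the same result by instantiating one specialized NEON kernel per (stride, filter size, bias mode, activation) combination through the GET_OP_PARAM / GET_BIAS_MODE_PARAM / DISPATCH_CONV_KERN macros, with the padded border rows and columns handled by the COMPUTE_PADDING_* helpers instead of the per-element bounds check used in this sketch.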