diff --git a/dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.cpp b/dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.cpp
index a84762f6..fbd5fa7a 100644
--- a/dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.cpp
+++ b/dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.cpp
@@ -19,6 +19,8 @@
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+
 using namespace megdnn;
 using namespace arm_common;
 using namespace fp16;
diff --git a/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.cpp b/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.cpp
index af4e4285..9457c2fc 100644
--- a/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.cpp
+++ b/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.cpp
@@ -284,7 +284,7 @@ void channel_wise_nchw88::do_conv_kern_stride1_3x3(
         const __fp16* src, const __fp16* filter, const __fp16* bias, __fp16* dst,
         const size_t IH, const size_t IW, const size_t OH, const size_t OW,
         const size_t PH, const size_t PW) {
-    if (IH == OH && IW == OW && PH == 1 && PW == 1) {
+    if (IH == OH && IW == OW && IH >= 3 && IW >= 3 && PH == 1 && PW == 1) {
         do_conv_kern_3x3_stride1_padding1<bias_mode, Op>(
                 src, dst, filter, bias, OH, OW);
         return;
diff --git a/dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp b/dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp
new file mode 100644
index 00000000..dc5a8798
--- /dev/null
+++ b/dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp
@@ -0,0 +1,316 @@
+/**
+ * \file dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h"
+#include "src/arm_common/elemwise_op.h"
+#include "src/arm_common/simd_macro/marm_neon.h"
+#include "src/arm_common/utils.h"
+#include "src/common/utils.h"
+#include "src/fallback/conv_bias/common.h"
+
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+
+using namespace megdnn;
+using namespace arm_common;
+
+namespace {
+
+#if defined(__ARM_FEATURE_FMA)
+#define Vfmaq_f32(d, n, m) vfmaq_f32(d, n, m)
+#else
+#define Vfmaq_f32(d, n, m) vmlaq_f32(d, n, m)
+#endif
+
+template <int shift>
+static inline void shift_src(float32x4_t rsrc[3][4]) {
+    float32x4_t t[4];
+
+    t[0] = rsrc[0][(shift + 0) % 4];
+    t[1] = rsrc[0][(shift + 1) % 4];
+    t[2] = rsrc[0][(shift + 2) % 4];
+    t[3] = rsrc[0][(shift + 3) % 4];
+    rsrc[0][0] = t[0];
+    rsrc[0][1] = t[1];
+    rsrc[0][2] = t[2];
+    rsrc[0][3] = t[3];
+
+    t[0] = rsrc[1][(shift + 0) % 4];
+    t[1] = rsrc[1][(shift + 1) % 4];
+    t[2] = rsrc[1][(shift + 2) % 4];
+    t[3] = rsrc[1][(shift + 3) % 4];
+    rsrc[1][0] = t[0];
+    rsrc[1][1] = t[1];
+    rsrc[1][2] = t[2];
+    rsrc[1][3] = t[3];
+
+    t[0] = rsrc[2][(shift + 0) % 4];
+    t[1] = rsrc[2][(shift + 1) % 4];
+    t[2] = rsrc[2][(shift + 2) % 4];
+    t[3] = rsrc[2][(shift + 3) % 4];
+    rsrc[2][0] = t[0];
+    rsrc[2][1] = t[1];
+    rsrc[2][2] = t[2];
+    rsrc[2][3] = t[3];
+}
+
+template <BiasMode bias_mode>
+static inline float32x4_t load_bias(const float* bias,
+                                    const float32x4_t& init) {
+    if (bias_mode == BiasMode::BIAS) {
+        return vld1q_f32(bias);
+    } else {
+        return init;
+    }
+}
+
+template <int BW, int bw, bool has_top, bool has_bottom, BiasMode bias_mode>
+struct compute_element {
+    template <typename Op>
+    static inline void call(const float*& src0, const float*& src1,
+                            const float*& src2, float*& dst, const float*& bias,
+                            const float32x4_t& init, float32x4_t rsrc[3][4],
+                            float32x4_t rfilter[3][3], const Op& op) {
+#define RSRC(i, j) rsrc[i][((j) + bw) % 4]
+        float32x4_t rdst = load_bias<bias_mode>(bias, init);
+        if (has_top) {
+            RSRC(0, 3) = vld1q_f32(src0 + 8);
+        }
+        { RSRC(1, 3) = vld1q_f32(src1 + 8); }
+        if (has_bottom) {
+            RSRC(2, 3) = vld1q_f32(src2 + 8);
+        }
+
+        if (has_top) {
+            rdst = Vfmaq_f32(rdst, RSRC(0, 0), rfilter[0][0]);
+            rdst = Vfmaq_f32(rdst, RSRC(0, 1), rfilter[0][1]);
+            rdst = Vfmaq_f32(rdst, RSRC(0, 2), rfilter[0][2]);
+        }
+        {
+            rdst = Vfmaq_f32(rdst, RSRC(1, 0), rfilter[1][0]);
+            rdst = Vfmaq_f32(rdst, RSRC(1, 1), rfilter[1][1]);
+            rdst = Vfmaq_f32(rdst, RSRC(1, 2), rfilter[1][2]);
+        }
+        if (has_bottom) {
+            rdst = Vfmaq_f32(rdst, RSRC(2, 0), rfilter[2][0]);
+            rdst = Vfmaq_f32(rdst, RSRC(2, 1), rfilter[2][1]);
+            rdst = Vfmaq_f32(rdst, RSRC(2, 2), rfilter[2][2]);
+        }
+
+        vst1q_f32(dst, op(rdst));
+
+        if (has_top) {
+            src0 += 4;
+        }
+        { src1 += 4; }
+        if (has_bottom) {
+            src2 += 4;
+        }
+        dst += 4;
+        bias += 4;
+        compute_element<BW, bw + 1, has_top, has_bottom, bias_mode>::call(
+                src0, src1, src2, dst, bias, init, rsrc, rfilter, op);
+#undef RSRC
+    }
+};
+
+template <int BW, bool has_top, bool has_bottom, BiasMode bias_mode>
+struct compute_element<BW, BW, has_top, has_bottom, bias_mode> {
+    template <typename... Types>
+    static inline void call(Types... args) {}
+};
+
+template <bool has_top, bool has_bottom, BiasMode bias_mode>
+struct compute_element_right {
+    template <typename Op>
+    static inline void call(float*& dst, const float*& bias,
+                            const float32x4_t& init, float32x4_t rsrc[3][4],
+                            float32x4_t rfilter[3][3], const Op& op) {
+        float32x4_t rdst = load_bias<bias_mode>(bias, init);
+
+        if (has_top) {
+            rdst = Vfmaq_f32(rdst, rsrc[0][0], rfilter[0][0]);
+            rdst = Vfmaq_f32(rdst, rsrc[0][1], rfilter[0][1]);
+            rdst = Vfmaq_f32(rdst, rsrc[0][2], rfilter[0][2]);
+        }
+        {
+            rdst = Vfmaq_f32(rdst, rsrc[1][0], rfilter[1][0]);
+            rdst = Vfmaq_f32(rdst, rsrc[1][1], rfilter[1][1]);
+            rdst = Vfmaq_f32(rdst, rsrc[1][2], rfilter[1][2]);
+        }
+        if (has_bottom) {
+            rdst = Vfmaq_f32(rdst, rsrc[2][0], rfilter[2][0]);
+            rdst = Vfmaq_f32(rdst, rsrc[2][1], rfilter[2][1]);
+            rdst = Vfmaq_f32(rdst, rsrc[2][2], rfilter[2][2]);
+        }
+
+        vst1q_f32(dst, op(rdst));
+
+        dst += 4;
+        bias += 4;
+    }
+};
+
+template <bool has_top, bool has_bottom, BiasMode bias_mode>
+struct compute_element_right_pad {
+    template <typename Op>
+    static inline void call(float*& dst, const float*& bias,
+                            const float32x4_t& init, float32x4_t rsrc[3][4],
+                            float32x4_t rfilter[3][3], const Op& op) {
+        float32x4_t rdst = load_bias<bias_mode>(bias, init);
+
+        if (has_top) {
+            rdst = Vfmaq_f32(rdst, rsrc[0][1], rfilter[0][0]);
+            rdst = Vfmaq_f32(rdst, rsrc[0][2], rfilter[0][1]);
+        }
+        {
+            rdst = Vfmaq_f32(rdst, rsrc[1][1], rfilter[1][0]);
+            rdst = Vfmaq_f32(rdst, rsrc[1][2], rfilter[1][1]);
+        }
+        if (has_bottom) {
+            rdst = Vfmaq_f32(rdst, rsrc[2][1], rfilter[2][0]);
+            rdst = Vfmaq_f32(rdst, rsrc[2][2], rfilter[2][1]);
+        }
+
+        vst1q_f32(dst, op(rdst));
+        dst += 4;
+        bias += 4;
+    }
+};
+
+template <bool has_top, bool has_bottom, BiasMode bias_mode>
+struct compute_row {
+    template <typename Op>
+    static inline void call(const float*& src0, const float*& src1,
+                            const float*& src2, float*& dst, const float*& bias,
+                            const float32x4_t& init, float32x4_t rsrc[3][4],
+                            float32x4_t rfilter[3][3], int W, const Op& op) {
+        if (has_top) {
+            rsrc[0][0] = vdupq_n_f32(0);
+            rsrc[0][1] = vld1q_f32(src0 + 0);
+            rsrc[0][2] = vld1q_f32(src0 + 4);
+        }
+        {
+            rsrc[1][0] = vdupq_n_f32(0);
+            rsrc[1][1] = vld1q_f32(src1 + 0);
+            rsrc[1][2] = vld1q_f32(src1 + 4);
+        }
+        if (has_bottom) {
+            rsrc[2][0] = vdupq_n_f32(0);
+            rsrc[2][1] = vld1q_f32(src2 + 0);
+            rsrc[2][2] = vld1q_f32(src2 + 4);
+        }
+
+        int w = 0;
+        const float* src0_ptr = src0;
+        const float* src1_ptr = src1;
+        const float* src2_ptr = src2;
+        float* dst_ptr = dst;
+        const float* bias_ptr = bias;
+
+        for (; w + 3 < W - 2; w += 4) {
+            compute_element<4, 0, has_top, has_bottom, bias_mode>::call(
+                    src0_ptr, src1_ptr, src2_ptr, dst_ptr, bias_ptr, init, rsrc,
+                    rfilter, op);
+        }
+        if (w + 1 < W - 2) {
+            compute_element<2, 0, has_top, has_bottom, bias_mode>::call(
+                    src0_ptr, src1_ptr, src2_ptr, dst_ptr, bias_ptr, init, rsrc,
+                    rfilter, op);
+            shift_src<2>(rsrc);
+            w += 2;
+        }
+        if (w < W - 2) {
+            compute_element<1, 0, has_top, has_bottom, bias_mode>::call(
+                    src0_ptr, src1_ptr, src2_ptr, dst_ptr, bias_ptr, init, rsrc,
+                    rfilter, op);
+            shift_src<1>(rsrc);
+            w += 1;
+        }
+        // compute rightmost 2 elements separately
+        compute_element_right<has_top, has_bottom, bias_mode>::call(
+                dst_ptr, bias_ptr, init, rsrc, rfilter, op);
+        compute_element_right_pad<has_top, has_bottom, bias_mode>::call(
+                dst_ptr, bias_ptr, init, rsrc, rfilter, op);
+
+        src0 += W * 4;
+        src1 += W * 4;
+        src2 += W * 4;
+        dst += W * 4;
+        bias += W * 4;
+    }
+};
+
+} // namespace
+
+template <BiasMode bias_mode, typename Op>
+void channel_wise_nchw44_float::do_conv_kern_3x3_stride1_padding1(
+        const float* src, float* dst, const float* filter, const float* bias,
+        int H, int W) {
+    Op op;
+
+    float32x4_t init = vdupq_n_f32(0);
+    if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
+        init = vld1q_f32(bias);
+    }
+
+    const float* src0 = src - W * 4;
+    const float* src1 = src;
+    const float* src2 = src + W * 4;
+
+    float32x4_t rfilter[3][3];
+    rfilter[0][0] = vld1q_f32(filter + 0);
+    rfilter[0][1] = vld1q_f32(filter + 4);
+    rfilter[0][2] = vld1q_f32(filter + 8);
+    rfilter[1][0] = vld1q_f32(filter + 12);
+    rfilter[1][1] = vld1q_f32(filter + 16);
+    rfilter[1][2] = vld1q_f32(filter + 20);
+    rfilter[2][0] = vld1q_f32(filter + 24);
+    rfilter[2][1] = vld1q_f32(filter + 28);
+    rfilter[2][2] = vld1q_f32(filter + 32);
+
+    float32x4_t rsrc[3][4];
+
+    compute_row<false, true, bias_mode>::call(src0, src1, src2, dst, bias, init,
+                                              rsrc, rfilter, W, op);
+
+    for (int h = 1; h < H - 1; h += 1) {
+        compute_row<true, true, bias_mode>::call(src0, src1, src2, dst, bias,
+                                                 init, rsrc, rfilter, W, op);
+    }
+
+    compute_row<true, false, bias_mode>::call(src0, src1, src2, dst, bias, init,
+                                              rsrc, rfilter, W, op);
+}
+
+#define INSTANTIATION(bias, Op)                                             \
+    template void                                                           \
+    channel_wise_nchw44_float::do_conv_kern_3x3_stride1_padding1<bias, Op>( \
+            const float*, float*, const float*, const float*, int, int);
+
+#define FOR_OP(bias)                           \
+    INSTANTIATION(bias, SigmoidOp<dt_float32>) \
+    INSTANTIATION(bias, ReluOp<dt_float32>)    \
+    INSTANTIATION(bias, HSwishOp<dt_float32>)  \
+    INSTANTIATION(bias, NoneOp<dt_float32>)
+
+#define FOR_BIAS                             \
+    FOR_OP(BiasMode::NO_BIAS)                \
+    FOR_OP(BiasMode::BROADCAST_CHANNEL_BIAS) \
+    FOR_OP(BiasMode::BIAS)
+
+FOR_BIAS
+
+#undef FOR_BIAS
+#undef FOR_OP
+#undef INSTANTIATION
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h b/dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h
new file mode 100644
index 00000000..51669ec2
--- /dev/null
+++ b/dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h
@@ -0,0 +1,31 @@
+/**
+ * \file dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#pragma once
+
+#include "src/arm_common/conv_bias/opr_impl.h"
+#include "src/fallback/conv_bias/common.h"
+
+namespace megdnn {
+namespace arm_common {
+namespace channel_wise_nchw44_float {
+
+template <BiasMode bias_mode, typename Op>
+void do_conv_kern_3x3_stride1_padding1(const float* src, float* dst,
+                                       const float* filter, const float* bias,
+                                       int H, int W);
+
+} // namespace channel_wise_nchw44_float
+} // namespace arm_common
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp b/dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp
new file mode 100644
index 00000000..cb886600
--- /dev/null
+++ b/dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp
@@ -0,0 +1,288 @@
+/**
+ * \file dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h"
+#include "src/arm_common/elemwise_op.h"
+#include "src/arm_common/simd_macro/marm_neon.h"
+#include "src/arm_common/utils.h"
+#include "src/common/utils.h"
+#include "src/fallback/conv_bias/common.h"
+
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+
+using namespace megdnn;
+using namespace arm_common;
+
+namespace {
+
+#if defined(__ARM_FEATURE_FMA)
+#define Vfmaq_f32(d, n, m) vfmaq_f32(d, n, m)
+#else
+#define Vfmaq_f32(d, n, m) vmlaq_f32(d, n, m)
+#endif
+
+template <int shift>
+static inline void shift_src(float32x4_t rsrc[6]) {
+    float32x4_t t[6];
+
+    t[0] = rsrc[(shift + 0) % 6];
+    t[1] = rsrc[(shift + 1) % 6];
+    t[2] = rsrc[(shift + 2) % 6];
+    t[3] = rsrc[(shift + 3) % 6];
+    t[4] = rsrc[(shift + 4) % 6];
+    t[5] = rsrc[(shift + 5) % 6];
+    rsrc[0] = t[0];
+    rsrc[1] = t[1];
+    rsrc[2] = t[2];
+    rsrc[3] = t[3];
+    rsrc[4] = t[4];
+    rsrc[5] = t[5];
+}
+
+static inline void load_filter(const float* filter, float32x4_t rfilter[5]) {
+    rfilter[0] = vld1q_f32(filter + 0);
+    rfilter[1] = vld1q_f32(filter + 4);
+    rfilter[2] = vld1q_f32(filter + 8);
+    rfilter[3] = vld1q_f32(filter + 12);
+    rfilter[4] = vld1q_f32(filter + 16);
+}
+
+template <BiasMode bias_mode>
+static inline float32x4_t load_bias(const float* bias,
+                                    const float32x4_t& init) {
+    if (bias_mode == BiasMode::BIAS) {
+        return vld1q_f32(bias);
+    } else {
+        return init;
+    }
+}
+
+template <int BW, int bw, BiasMode bias_mode, bool need_load_bias,
+          bool need_do_op>
+struct compute_element {
+    template <typename Op>
+    static inline void call(const float*& src, float*& dst, const float*& bias,
+                            const float32x4_t& init, float32x4_t rsrc[6],
+                            float32x4_t rfilter[5], const Op& op) {
+#define RSRC(i) rsrc[((i) + bw) % 6]
+        float32x4_t rdst;
+        if (need_load_bias) {
+            rdst = load_bias<bias_mode>(bias, init);
+        } else {
+            rdst = vld1q_f32(dst);
+        }
+        RSRC(5) = vld1q_f32(src + 12);
+
+        rdst = Vfmaq_f32(rdst, RSRC(0), rfilter[0]);
+        rdst = Vfmaq_f32(rdst, RSRC(1), rfilter[1]);
+        rdst = Vfmaq_f32(rdst, RSRC(2), rfilter[2]);
+        rdst = Vfmaq_f32(rdst, RSRC(3), rfilter[3]);
+        rdst = Vfmaq_f32(rdst, RSRC(4), rfilter[4]);
+
+        if (need_do_op) {
+            rdst = op(rdst);
+        }
+        vst1q_f32(dst, rdst);
+
+        src += 4;
+        dst += 4;
+        bias += 4;
+        compute_element<BW, bw + 1, bias_mode, need_load_bias,
+                        need_do_op>::call(src, dst, bias, init, rsrc, rfilter,
+                                          op);
+#undef RSRC
+    }
+};
+
+template <int BW, BiasMode bias_mode, bool need_load_bias, bool need_do_op>
+struct compute_element<BW, BW, bias_mode, need_load_bias, need_do_op> {
+    template <typename... Types>
+    static inline void call(Types... args) {}
+};
+
+template <int padding, BiasMode bias_mode, bool need_load_bias, bool need_do_op>
+struct compute_element_right {
+    template <typename Op>
+    static inline void call(float*& dst, const float*& bias,
+                            const float32x4_t& init, float32x4_t rsrc[6],
+                            float32x4_t rfilter[5], const Op& op) {
+        float32x4_t rdst;
+        if (need_load_bias) {
+            rdst = load_bias<bias_mode>(bias, init);
+        } else {
+            rdst = vld1q_f32(dst);
+        }
+
+        rdst = Vfmaq_f32(rdst, rsrc[0 + padding], rfilter[0]);
+        rdst = Vfmaq_f32(rdst, rsrc[1 + padding], rfilter[1]);
+        rdst = Vfmaq_f32(rdst, rsrc[2 + padding], rfilter[2]);
+        if (padding < 2) {
+            rdst = Vfmaq_f32(rdst, rsrc[3 + padding], rfilter[3]);
+        }
+        if (padding < 1) {
+            rdst = Vfmaq_f32(rdst, rsrc[4 + padding], rfilter[4]);
+        }
+
+        if (need_do_op) {
+            rdst = op(rdst);
+        }
+        vst1q_f32(dst, rdst);
+
+        dst += 4;
+        bias += 4;
+    }
+};
+
+template <BiasMode bias_mode, bool need_load_bias, bool need_do_op>
+struct compute_row_src_1x5 {
+    template <typename Op>
+    static inline void call(const float* src, float* dst, const float* bias,
+                            const float32x4_t& init, float32x4_t rsrc[6],
+                            float32x4_t rfilter[5], int W, const Op& op) {
+        rsrc[0] = vdupq_n_f32(0);
+        rsrc[1] = vdupq_n_f32(0);
+        rsrc[2] = vld1q_f32(src + 0);
+        rsrc[3] = vld1q_f32(src + 4);
+        rsrc[4] = vld1q_f32(src + 8);
+
+        int w = 0;
+
+        for (; w + 5 < W - 3; w += 6) {
+            compute_element<6, 0, bias_mode, need_load_bias, need_do_op>::call(
+                    src, dst, bias, init, rsrc, rfilter, op);
+        }
+        if (w + 3 < W - 3) {
+            compute_element<4, 0, bias_mode, need_load_bias, need_do_op>::call(
+                    src, dst, bias, init, rsrc, rfilter, op);
+            shift_src<4>(rsrc);
+            w += 4;
+        }
+        if (w + 1 < W - 3) {
+            compute_element<2, 0, bias_mode, need_load_bias, need_do_op>::call(
+                    src, dst, bias, init, rsrc, rfilter, op);
+            shift_src<2>(rsrc);
+            w += 2;
+        }
+        if (w < W - 3) {
+            compute_element<1, 0, bias_mode, need_load_bias, need_do_op>::call(
+                    src, dst, bias, init, rsrc, rfilter, op);
+            shift_src<1>(rsrc);
+            w += 1;
+        }
+        // compute rightmost 3 elements separately
+        compute_element_right<0, bias_mode, need_load_bias, need_do_op>::call(
+                dst, bias, init, rsrc, rfilter, op);
+        compute_element_right<1, bias_mode, need_load_bias, need_do_op>::call(
+                dst, bias, init, rsrc, rfilter, op);
+        compute_element_right<2, bias_mode, need_load_bias, need_do_op>::call(
+                dst, bias, init, rsrc, rfilter, op);
+    }
+};
+
+template <int top_padding, int bottom_padding, BiasMode bias_mode>
+struct compute_row {
+    template <typename Op>
+    static inline void call(const float*& src, float*& dst, const float* filter,
+                            const float*& bias, const float32x4_t& init,
+                            float32x4_t rsrc[6], float32x4_t rfilter[5], int W,
+                            const Op& op) {
+        if (top_padding < 1) {
+            load_filter(filter + 0, rfilter);
+            compute_row_src_1x5<bias_mode, true, false>::call(
+                    src - W * 8, dst, bias, init, rsrc, rfilter, W, op);
+        }
+
+        if (top_padding < 2) {
+            load_filter(filter + 20, rfilter);
+            compute_row_src_1x5<bias_mode, top_padding == 1, false>::call(
+                    src - W * 4, dst, bias, init, rsrc, rfilter, W, op);
+        }
+
+        {
+            load_filter(filter + 40, rfilter);
+            compute_row_src_1x5<bias_mode, top_padding == 2,
+                                bottom_padding == 2>::call(src, dst, bias, init,
+                                                           rsrc, rfilter, W,
+                                                           op);
+        }
+
+        if (bottom_padding < 2) {
+            load_filter(filter + 60, rfilter);
+            compute_row_src_1x5<bias_mode, false, bottom_padding == 1>::call(
+                    src + W * 4, dst, bias, init, rsrc, rfilter, W, op);
+        }
+
+        if (bottom_padding < 1) {
+            load_filter(filter + 80, rfilter);
+            compute_row_src_1x5<bias_mode, false, true>::call(
+                    src + W * 8, dst, bias, init, rsrc, rfilter, W, op);
+        }
+        src += W * 4;
+        dst += W * 4;
+        bias += W * 4;
+    }
+};
+
+} // namespace
+
+template <BiasMode bias_mode, typename Op>
+void channel_wise_nchw44_float::do_conv_kern_5x5_stride1_padding2(
+        const float* src, float* dst, const float* filter, const float* bias,
+        int H, int W) {
+    Op op;
+
+    float32x4_t init = vdupq_n_f32(0);
+    if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {
+        init = vld1q_f32(bias);
+    }
+
+    float32x4_t rsrc[6];
+    float32x4_t rfilter[5];
+
+    compute_row<2, 0, bias_mode>::call(src, dst, filter, bias, init, rsrc,
+                                       rfilter, W, op);
+    compute_row<1, 0, bias_mode>::call(src, dst, filter, bias, init, rsrc,
+                                       rfilter, W, op);
+    for (int h = 2; h < H - 2; h += 1) {
+        compute_row<0, 0, bias_mode>::call(src, dst, filter, bias, init, rsrc,
+                                           rfilter, W, op);
+    }
+    compute_row<0, 1, bias_mode>::call(src, dst, filter, bias, init, rsrc,
+                                       rfilter, W, op);
+    compute_row<0, 2, bias_mode>::call(src, dst, filter, bias, init, rsrc,
+                                       rfilter, W, op);
+}
+
+#define INSTANTIATION(bias, Op)                                             \
+    template void                                                           \
+    channel_wise_nchw44_float::do_conv_kern_5x5_stride1_padding2<bias, Op>( \
+            const float*, float*, const float*, const float*, int, int);
+
+#define FOR_OP(bias)                           \
+    INSTANTIATION(bias, SigmoidOp<dt_float32>) \
+    INSTANTIATION(bias, ReluOp<dt_float32>)    \
+    INSTANTIATION(bias, HSwishOp<dt_float32>)  \
+    INSTANTIATION(bias, NoneOp<dt_float32>)
+
+#define FOR_BIAS                             \
+    FOR_OP(BiasMode::NO_BIAS)                \
+    FOR_OP(BiasMode::BROADCAST_CHANNEL_BIAS) \
+    FOR_OP(BiasMode::BIAS)
+
+FOR_BIAS
+
+#undef FOR_BIAS
+#undef FOR_OP
+#undef INSTANTIATION
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h b/dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h
new file mode 100644
index 00000000..28b04380
--- /dev/null
+++ b/dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h
@@ -0,0 +1,31 @@
+/**
+ * \file dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#pragma once
+
+#include "src/arm_common/conv_bias/opr_impl.h"
+#include "src/fallback/conv_bias/common.h"
+
+namespace megdnn {
+namespace arm_common {
+namespace channel_wise_nchw44_float {
+
+template <BiasMode bias_mode, typename Op>
+void do_conv_kern_5x5_stride1_padding2(const float* src, float* dst,
+                                       const float* filter, const float* bias,
+                                       int H, int W);
+
+} // namespace channel_wise_nchw44_float
+} // namespace arm_common
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.cpp b/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.cpp
index 3a8a30e5..bb849a40 100644
--- a/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.cpp
+++ b/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.cpp
@@ -11,6 +11,8 @@
  */
 
 #include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h"
+#include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h"
+#include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h"
 #include "src/arm_common/elemwise_op.h"
 #include "src/arm_common/simd_macro/marm_neon.h"
 #include "src/arm_common/utils.h"
@@ -413,6 +415,13 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3(
         const float* src, const float* filter, const float* bias, float* dst,
         const size_t IH, const size_t IW, const size_t OH, const size_t OW,
         const size_t PH, const size_t PW) {
+    if (IH == OH && IW == OW && IH >= 3 && IW >= 3 && PH == 1 && PW == 1) {
+        channel_wise_nchw44_float::do_conv_kern_3x3_stride1_padding1<bias_mode,
+                                                                     Op>(
+                src, dst, filter, bias, OH, OW);
+        return;
+    }
+
     float32x4_t kernel[9];
     load_vec<9>(kernel, filter);
     Op op;
@@ -424,10 +433,7 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3(
     size_t ow_start = PW;
     size_t oh_end = IH + PH - 2;
     size_t ow_end = IW + PW - 2;
-    if (PH == 1 && PW == 1) {
-        PaddingComputeK3P1<bias_mode, Op>::compute(src, bias, dst, 1, IH, IW,
-                                                   OH, OW, kernel, init);
-    } else if (PH || PW) {
+    if (PH || PW) {
         PaddingCompute<bias_mode, Op>::compute(src, bias, dst, 3, 1, IH, IW, OH,
                                                OW, PH, PW, kernel, init);
     }
@@ -557,6 +563,13 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5(
         const float* src, const float* filter, const float* bias, float* dst,
         const size_t IH, const size_t IW, const size_t OH, const size_t OW,
         const size_t PH, const size_t PW) {
+    if (IH == OH && IW == OW && IH >= 5 && IW >= 5 && PH == 2 && PW == 2) {
+        channel_wise_nchw44_float::do_conv_kern_5x5_stride1_padding2<bias_mode,
+                                                                     Op>(
+                src, dst, filter, bias, OH, OW);
+        return;
+    }
+
     Op op;
     float32x4_t init = vdupq_n_f32(0.f);
     if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) {