GitOrigin-RevId: 9ea411d0e1
release-1.7
@@ -19,6 +19,8 @@ | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
#pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace arm_common; | using namespace arm_common; | ||||
using namespace fp16; | using namespace fp16; | ||||
@@ -284,7 +284,7 @@ void channel_wise_nchw88::do_conv_kern_stride1_3x3( | |||||
const __fp16* src, const __fp16* filter, const __fp16* bias, | const __fp16* src, const __fp16* filter, const __fp16* bias, | ||||
__fp16* dst, const size_t IH, const size_t IW, const size_t OH, | __fp16* dst, const size_t IH, const size_t IW, const size_t OH, | ||||
const size_t OW, const size_t PH, const size_t PW) { | const size_t OW, const size_t PH, const size_t PW) { | ||||
if (IH == OH && IW == OW && PH == 1 && PW == 1) { | |||||
if (IH == OH && IW == OW && IH >= 3 && IW >= 3 && PH == 1 && PW == 1) { | |||||
do_conv_kern_3x3_stride1_padding1<bias_mode, Op>(src, dst, filter, bias, | do_conv_kern_3x3_stride1_padding1<bias_mode, Op>(src, dst, filter, bias, | ||||
OH, OW); | OH, OW); | ||||
return; | return; | ||||
@@ -0,0 +1,316 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" | |||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/arm_common/utils.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/fallback/conv_bias/common.h" | |||||
#pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
namespace { | |||||
#if defined(__ARM_FEATURE_FMA) | |||||
#define Vfmaq_f32(d, n, m) vfmaq_f32(d, n, m) | |||||
#else | |||||
#define Vfmaq_f32(d, n, m) vmlaq_f32(d, n, m) | |||||
#endif | |||||
template <int shift> | |||||
static inline void shift_src(float32x4_t rsrc[3][4]) { | |||||
float32x4_t t[4]; | |||||
t[0] = rsrc[0][(shift + 0) % 4]; | |||||
t[1] = rsrc[0][(shift + 1) % 4]; | |||||
t[2] = rsrc[0][(shift + 2) % 4]; | |||||
t[3] = rsrc[0][(shift + 3) % 4]; | |||||
rsrc[0][0] = t[0]; | |||||
rsrc[0][1] = t[1]; | |||||
rsrc[0][2] = t[2]; | |||||
rsrc[0][3] = t[3]; | |||||
t[0] = rsrc[1][(shift + 0) % 4]; | |||||
t[1] = rsrc[1][(shift + 1) % 4]; | |||||
t[2] = rsrc[1][(shift + 2) % 4]; | |||||
t[3] = rsrc[1][(shift + 3) % 4]; | |||||
rsrc[1][0] = t[0]; | |||||
rsrc[1][1] = t[1]; | |||||
rsrc[1][2] = t[2]; | |||||
rsrc[1][3] = t[3]; | |||||
t[0] = rsrc[2][(shift + 0) % 4]; | |||||
t[1] = rsrc[2][(shift + 1) % 4]; | |||||
t[2] = rsrc[2][(shift + 2) % 4]; | |||||
t[3] = rsrc[2][(shift + 3) % 4]; | |||||
rsrc[2][0] = t[0]; | |||||
rsrc[2][1] = t[1]; | |||||
rsrc[2][2] = t[2]; | |||||
rsrc[2][3] = t[3]; | |||||
} | |||||
template <BiasMode bias_mode> | |||||
static inline float32x4_t load_bias(const float* bias, | |||||
const float32x4_t& init) { | |||||
if (bias_mode == BiasMode::BIAS) { | |||||
return vld1q_f32(bias); | |||||
} else { | |||||
return init; | |||||
} | |||||
} | |||||
template <int BW, int bw, bool has_top, bool has_bottom, BiasMode bias_mode> | |||||
struct compute_element { | |||||
template <typename Op> | |||||
static inline void call(const float*& src0, const float*& src1, | |||||
const float*& src2, float*& dst, const float*& bias, | |||||
const float32x4_t& init, float32x4_t rsrc[3][4], | |||||
float32x4_t rfilter[3][3], const Op& op) { | |||||
#define RSRC(i, j) rsrc[i][((j) + bw) % 4] | |||||
float32x4_t rdst = load_bias<bias_mode>(bias, init); | |||||
if (has_top) { | |||||
RSRC(0, 3) = vld1q_f32(src0 + 8); | |||||
} | |||||
{ RSRC(1, 3) = vld1q_f32(src1 + 8); } | |||||
if (has_bottom) { | |||||
RSRC(2, 3) = vld1q_f32(src2 + 8); | |||||
} | |||||
if (has_top) { | |||||
rdst = Vfmaq_f32(rdst, RSRC(0, 0), rfilter[0][0]); | |||||
rdst = Vfmaq_f32(rdst, RSRC(0, 1), rfilter[0][1]); | |||||
rdst = Vfmaq_f32(rdst, RSRC(0, 2), rfilter[0][2]); | |||||
} | |||||
{ | |||||
rdst = Vfmaq_f32(rdst, RSRC(1, 0), rfilter[1][0]); | |||||
rdst = Vfmaq_f32(rdst, RSRC(1, 1), rfilter[1][1]); | |||||
rdst = Vfmaq_f32(rdst, RSRC(1, 2), rfilter[1][2]); | |||||
} | |||||
if (has_bottom) { | |||||
rdst = Vfmaq_f32(rdst, RSRC(2, 0), rfilter[2][0]); | |||||
rdst = Vfmaq_f32(rdst, RSRC(2, 1), rfilter[2][1]); | |||||
rdst = Vfmaq_f32(rdst, RSRC(2, 2), rfilter[2][2]); | |||||
} | |||||
vst1q_f32(dst, op(rdst)); | |||||
if (has_top) { | |||||
src0 += 4; | |||||
} | |||||
{ src1 += 4; } | |||||
if (has_bottom) { | |||||
src2 += 4; | |||||
} | |||||
dst += 4; | |||||
bias += 4; | |||||
compute_element<BW, bw + 1, has_top, has_bottom, bias_mode>::call( | |||||
src0, src1, src2, dst, bias, init, rsrc, rfilter, op); | |||||
#undef RSRC | |||||
} | |||||
}; | |||||
template <int BW, bool has_top, bool has_bottom, BiasMode bias_mode> | |||||
struct compute_element<BW, BW, has_top, has_bottom, bias_mode> { | |||||
template <typename... Types> | |||||
static inline void call(Types... args) {} | |||||
}; | |||||
template <bool has_top, bool has_bottom, BiasMode bias_mode> | |||||
struct compute_element_right { | |||||
template <typename Op> | |||||
static inline void call(float*& dst, const float*& bias, | |||||
const float32x4_t& init, float32x4_t rsrc[3][4], | |||||
float32x4_t rfilter[3][3], const Op& op) { | |||||
float32x4_t rdst = load_bias<bias_mode>(bias, init); | |||||
if (has_top) { | |||||
rdst = Vfmaq_f32(rdst, rsrc[0][0], rfilter[0][0]); | |||||
rdst = Vfmaq_f32(rdst, rsrc[0][1], rfilter[0][1]); | |||||
rdst = Vfmaq_f32(rdst, rsrc[0][2], rfilter[0][2]); | |||||
} | |||||
{ | |||||
rdst = Vfmaq_f32(rdst, rsrc[1][0], rfilter[1][0]); | |||||
rdst = Vfmaq_f32(rdst, rsrc[1][1], rfilter[1][1]); | |||||
rdst = Vfmaq_f32(rdst, rsrc[1][2], rfilter[1][2]); | |||||
} | |||||
if (has_bottom) { | |||||
rdst = Vfmaq_f32(rdst, rsrc[2][0], rfilter[2][0]); | |||||
rdst = Vfmaq_f32(rdst, rsrc[2][1], rfilter[2][1]); | |||||
rdst = Vfmaq_f32(rdst, rsrc[2][2], rfilter[2][2]); | |||||
} | |||||
vst1q_f32(dst, op(rdst)); | |||||
dst += 4; | |||||
bias += 4; | |||||
} | |||||
}; | |||||
template <bool has_top, bool has_bottom, BiasMode bias_mode> | |||||
struct compute_element_right_pad { | |||||
template <typename Op> | |||||
static inline void call(float*& dst, const float*& bias, | |||||
const float32x4_t& init, float32x4_t rsrc[3][4], | |||||
float32x4_t rfilter[3][3], const Op& op) { | |||||
float32x4_t rdst = load_bias<bias_mode>(bias, init); | |||||
if (has_top) { | |||||
rdst = Vfmaq_f32(rdst, rsrc[0][1], rfilter[0][0]); | |||||
rdst = Vfmaq_f32(rdst, rsrc[0][2], rfilter[0][1]); | |||||
} | |||||
{ | |||||
rdst = Vfmaq_f32(rdst, rsrc[1][1], rfilter[1][0]); | |||||
rdst = Vfmaq_f32(rdst, rsrc[1][2], rfilter[1][1]); | |||||
} | |||||
if (has_bottom) { | |||||
rdst = Vfmaq_f32(rdst, rsrc[2][1], rfilter[2][0]); | |||||
rdst = Vfmaq_f32(rdst, rsrc[2][2], rfilter[2][1]); | |||||
} | |||||
vst1q_f32(dst, op(rdst)); | |||||
dst += 4; | |||||
bias += 4; | |||||
} | |||||
}; | |||||
template <bool has_top, bool has_bottom, BiasMode bias_mode> | |||||
struct compute_row { | |||||
template <typename Op> | |||||
static inline void call(const float*& src0, const float*& src1, | |||||
const float*& src2, float*& dst, const float*& bias, | |||||
const float32x4_t& init, float32x4_t rsrc[3][4], | |||||
float32x4_t rfilter[3][3], int W, const Op& op) { | |||||
if (has_top) { | |||||
rsrc[0][0] = vdupq_n_f32(0); | |||||
rsrc[0][1] = vld1q_f32(src0 + 0); | |||||
rsrc[0][2] = vld1q_f32(src0 + 4); | |||||
} | |||||
{ | |||||
rsrc[1][0] = vdupq_n_f32(0); | |||||
rsrc[1][1] = vld1q_f32(src1 + 0); | |||||
rsrc[1][2] = vld1q_f32(src1 + 4); | |||||
} | |||||
if (has_bottom) { | |||||
rsrc[2][0] = vdupq_n_f32(0); | |||||
rsrc[2][1] = vld1q_f32(src2 + 0); | |||||
rsrc[2][2] = vld1q_f32(src2 + 4); | |||||
} | |||||
int w = 0; | |||||
const float* src0_ptr = src0; | |||||
const float* src1_ptr = src1; | |||||
const float* src2_ptr = src2; | |||||
float* dst_ptr = dst; | |||||
const float* bias_ptr = bias; | |||||
for (; w + 3 < W - 2; w += 4) { | |||||
compute_element<4, 0, has_top, has_bottom, bias_mode>::call( | |||||
src0_ptr, src1_ptr, src2_ptr, dst_ptr, bias_ptr, init, rsrc, | |||||
rfilter, op); | |||||
} | |||||
if (w + 1 < W - 2) { | |||||
compute_element<2, 0, has_top, has_bottom, bias_mode>::call( | |||||
src0_ptr, src1_ptr, src2_ptr, dst_ptr, bias_ptr, init, rsrc, | |||||
rfilter, op); | |||||
shift_src<2>(rsrc); | |||||
w += 2; | |||||
} | |||||
if (w < W - 2) { | |||||
compute_element<1, 0, has_top, has_bottom, bias_mode>::call( | |||||
src0_ptr, src1_ptr, src2_ptr, dst_ptr, bias_ptr, init, rsrc, | |||||
rfilter, op); | |||||
shift_src<1>(rsrc); | |||||
w += 1; | |||||
} | |||||
// compute rightmost 2 elements seperately | |||||
compute_element_right<has_top, has_bottom, bias_mode>::call( | |||||
dst_ptr, bias_ptr, init, rsrc, rfilter, op); | |||||
compute_element_right_pad<has_top, has_bottom, bias_mode>::call( | |||||
dst_ptr, bias_ptr, init, rsrc, rfilter, op); | |||||
src0 += W * 4; | |||||
src1 += W * 4; | |||||
src2 += W * 4; | |||||
dst += W * 4; | |||||
bias += W * 4; | |||||
} | |||||
}; | |||||
} // namespace | |||||
template <BiasMode bias_mode, typename Op> | |||||
void channel_wise_nchw44_float::do_conv_kern_3x3_stride1_padding1( | |||||
const float* src, float* dst, const float* filter, const float* bias, | |||||
int H, int W) { | |||||
Op op; | |||||
float32x4_t init = vdupq_n_f32(0); | |||||
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||||
init = vld1q_f32(bias); | |||||
} | |||||
const float* src0 = src - W * 4; | |||||
const float* src1 = src; | |||||
const float* src2 = src + W * 4; | |||||
float32x4_t rfilter[3][3]; | |||||
rfilter[0][0] = vld1q_f32(filter + 0); | |||||
rfilter[0][1] = vld1q_f32(filter + 4); | |||||
rfilter[0][2] = vld1q_f32(filter + 8); | |||||
rfilter[1][0] = vld1q_f32(filter + 12); | |||||
rfilter[1][1] = vld1q_f32(filter + 16); | |||||
rfilter[1][2] = vld1q_f32(filter + 20); | |||||
rfilter[2][0] = vld1q_f32(filter + 24); | |||||
rfilter[2][1] = vld1q_f32(filter + 28); | |||||
rfilter[2][2] = vld1q_f32(filter + 32); | |||||
float32x4_t rsrc[3][4]; | |||||
compute_row<false, true, bias_mode>::call(src0, src1, src2, dst, bias, init, | |||||
rsrc, rfilter, W, op); | |||||
for (int h = 1; h < H - 1; h += 1) { | |||||
compute_row<true, true, bias_mode>::call(src0, src1, src2, dst, bias, | |||||
init, rsrc, rfilter, W, op); | |||||
} | |||||
compute_row<true, false, bias_mode>::call(src0, src1, src2, dst, bias, init, | |||||
rsrc, rfilter, W, op); | |||||
} | |||||
#define INSTANTIATION(bias, Op) \ | |||||
template void \ | |||||
channel_wise_nchw44_float::do_conv_kern_3x3_stride1_padding1<bias, Op>( \ | |||||
const float*, float*, const float*, const float*, int, int); | |||||
#define FOR_OP(bias) \ | |||||
INSTANTIATION(bias, SigmoidOp<dt_float32>) \ | |||||
INSTANTIATION(bias, ReluOp<dt_float32>) \ | |||||
INSTANTIATION(bias, HSwishOp<dt_float32>) \ | |||||
INSTANTIATION(bias, NoneOp<dt_float32>) | |||||
#define FOR_BIAS \ | |||||
FOR_OP(BiasMode::NO_BIAS) \ | |||||
FOR_OP(BiasMode::BROADCAST_CHANNEL_BIAS) \ | |||||
FOR_OP(BiasMode::BIAS) | |||||
FOR_BIAS | |||||
#undef FOR_BIAS | |||||
#undef FOR_OP | |||||
#undef INSTANTIATION | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,31 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "src/arm_common/conv_bias/opr_impl.h" | |||||
#include "src/fallback/conv_bias/common.h" | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
namespace channel_wise_nchw44_float { | |||||
template <BiasMode bias_mode, typename Op> | |||||
void do_conv_kern_3x3_stride1_padding1(const float* src, float* dst, | |||||
const float* filter, const float* bias, | |||||
int H, int W); | |||||
} // namespace channel_wise_nchw44_float | |||||
} // namespace arm_common | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,288 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" | |||||
#include "src/arm_common/elemwise_op.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/arm_common/utils.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/fallback/conv_bias/common.h" | |||||
#pragma GCC diagnostic ignored "-Wunused-parameter" | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
namespace { | |||||
#if defined(__ARM_FEATURE_FMA) | |||||
#define Vfmaq_f32(d, n, m) vfmaq_f32(d, n, m) | |||||
#else | |||||
#define Vfmaq_f32(d, n, m) vmlaq_f32(d, n, m) | |||||
#endif | |||||
template <int shift> | |||||
static inline void shift_src(float32x4_t rsrc[6]) { | |||||
float32x4_t t[6]; | |||||
t[0] = rsrc[(shift + 0) % 6]; | |||||
t[1] = rsrc[(shift + 1) % 6]; | |||||
t[2] = rsrc[(shift + 2) % 6]; | |||||
t[3] = rsrc[(shift + 3) % 6]; | |||||
t[4] = rsrc[(shift + 4) % 6]; | |||||
t[5] = rsrc[(shift + 5) % 6]; | |||||
rsrc[0] = t[0]; | |||||
rsrc[1] = t[1]; | |||||
rsrc[2] = t[2]; | |||||
rsrc[3] = t[3]; | |||||
rsrc[4] = t[4]; | |||||
rsrc[5] = t[5]; | |||||
} | |||||
static inline void load_filter(const float* filter, float32x4_t rfilter[5]) { | |||||
rfilter[0] = vld1q_f32(filter + 0); | |||||
rfilter[1] = vld1q_f32(filter + 4); | |||||
rfilter[2] = vld1q_f32(filter + 8); | |||||
rfilter[3] = vld1q_f32(filter + 12); | |||||
rfilter[4] = vld1q_f32(filter + 16); | |||||
} | |||||
template <BiasMode bias_mode> | |||||
static inline float32x4_t load_bias(const float* bias, | |||||
const float32x4_t& init) { | |||||
if (bias_mode == BiasMode::BIAS) { | |||||
return vld1q_f32(bias); | |||||
} else { | |||||
return init; | |||||
} | |||||
} | |||||
template <int BW, int bw, BiasMode bias_mode, bool need_load_bias, | |||||
bool need_do_op> | |||||
struct compute_element { | |||||
template <typename Op> | |||||
static inline void call(const float*& src, float*& dst, const float*& bias, | |||||
const float32x4_t& init, float32x4_t rsrc[6], | |||||
float32x4_t rfilter[5], const Op& op) { | |||||
#define RSRC(i) rsrc[((i) + bw) % 6] | |||||
float32x4_t rdst; | |||||
if (need_load_bias) { | |||||
rdst = load_bias<bias_mode>(bias, init); | |||||
} else { | |||||
rdst = vld1q_f32(dst); | |||||
} | |||||
RSRC(5) = vld1q_f32(src + 12); | |||||
rdst = Vfmaq_f32(rdst, RSRC(0), rfilter[0]); | |||||
rdst = Vfmaq_f32(rdst, RSRC(1), rfilter[1]); | |||||
rdst = Vfmaq_f32(rdst, RSRC(2), rfilter[2]); | |||||
rdst = Vfmaq_f32(rdst, RSRC(3), rfilter[3]); | |||||
rdst = Vfmaq_f32(rdst, RSRC(4), rfilter[4]); | |||||
if (need_do_op) { | |||||
rdst = op(rdst); | |||||
} | |||||
vst1q_f32(dst, rdst); | |||||
src += 4; | |||||
dst += 4; | |||||
bias += 4; | |||||
compute_element<BW, bw + 1, bias_mode, need_load_bias, | |||||
need_do_op>::call(src, dst, bias, init, rsrc, rfilter, | |||||
op); | |||||
#undef RSRC | |||||
} | |||||
}; | |||||
template <int BW, BiasMode bias_mode, bool need_load_bias, bool need_do_op> | |||||
struct compute_element<BW, BW, bias_mode, need_load_bias, need_do_op> { | |||||
template <typename... Types> | |||||
static inline void call(Types... args) {} | |||||
}; | |||||
template <size_t padding, BiasMode bias_mode, bool need_load_bias, | |||||
bool need_do_op> | |||||
struct compute_element_right { | |||||
template <typename Op> | |||||
static inline void call(float*& dst, const float*& bias, | |||||
const float32x4_t& init, float32x4_t rsrc[6], | |||||
float32x4_t rfilter[5], const Op& op) { | |||||
float32x4_t rdst; | |||||
if (need_load_bias) { | |||||
rdst = load_bias<bias_mode>(bias, init); | |||||
} else { | |||||
rdst = vld1q_f32(dst); | |||||
} | |||||
rdst = Vfmaq_f32(rdst, rsrc[0 + padding], rfilter[0]); | |||||
rdst = Vfmaq_f32(rdst, rsrc[1 + padding], rfilter[1]); | |||||
rdst = Vfmaq_f32(rdst, rsrc[2 + padding], rfilter[2]); | |||||
if (padding < 2) { | |||||
rdst = Vfmaq_f32(rdst, rsrc[3 + padding], rfilter[3]); | |||||
} | |||||
if (padding < 1) { | |||||
rdst = Vfmaq_f32(rdst, rsrc[4 + padding], rfilter[4]); | |||||
} | |||||
if (need_do_op) { | |||||
rdst = op(rdst); | |||||
} | |||||
vst1q_f32(dst, rdst); | |||||
dst += 4; | |||||
bias += 4; | |||||
} | |||||
}; | |||||
template <BiasMode bias_mode, bool need_load_bias, bool need_do_op> | |||||
struct compute_row_src_1x5 { | |||||
template <typename Op> | |||||
static inline void call(const float* src, float* dst, const float* bias, | |||||
const float32x4_t& init, float32x4_t rsrc[6], | |||||
float32x4_t rfilter[5], int W, const Op& op) { | |||||
rsrc[0] = vdupq_n_f32(0); | |||||
rsrc[1] = vdupq_n_f32(0); | |||||
rsrc[2] = vld1q_f32(src + 0); | |||||
rsrc[3] = vld1q_f32(src + 4); | |||||
rsrc[4] = vld1q_f32(src + 8); | |||||
int w = 0; | |||||
for (; w + 5 < W - 3; w += 6) { | |||||
compute_element<6, 0, bias_mode, need_load_bias, need_do_op>::call( | |||||
src, dst, bias, init, rsrc, rfilter, op); | |||||
} | |||||
if (w + 3 < W - 3) { | |||||
compute_element<4, 0, bias_mode, need_load_bias, need_do_op>::call( | |||||
src, dst, bias, init, rsrc, rfilter, op); | |||||
shift_src<4>(rsrc); | |||||
w += 4; | |||||
} | |||||
if (w + 1 < W - 3) { | |||||
compute_element<2, 0, bias_mode, need_load_bias, need_do_op>::call( | |||||
src, dst, bias, init, rsrc, rfilter, op); | |||||
shift_src<2>(rsrc); | |||||
w += 2; | |||||
} | |||||
if (w < W - 3) { | |||||
compute_element<1, 0, bias_mode, need_load_bias, need_do_op>::call( | |||||
src, dst, bias, init, rsrc, rfilter, op); | |||||
shift_src<1>(rsrc); | |||||
w += 1; | |||||
} | |||||
// compute rightmost 3 elements seperately | |||||
compute_element_right<0, bias_mode, need_load_bias, need_do_op>::call( | |||||
dst, bias, init, rsrc, rfilter, op); | |||||
compute_element_right<1, bias_mode, need_load_bias, need_do_op>::call( | |||||
dst, bias, init, rsrc, rfilter, op); | |||||
compute_element_right<2, bias_mode, need_load_bias, need_do_op>::call( | |||||
dst, bias, init, rsrc, rfilter, op); | |||||
} | |||||
}; | |||||
template <size_t top_padding, size_t bottom_padding, BiasMode bias_mode> | |||||
struct compute_row { | |||||
template <typename Op> | |||||
static inline void call(const float*& src, float*& dst, const float* filter, | |||||
const float*& bias, const float32x4_t& init, | |||||
float32x4_t rsrc[6], float32x4_t rfilter[5], int W, | |||||
const Op& op) { | |||||
if (top_padding < 1) { | |||||
load_filter(filter + 0, rfilter); | |||||
compute_row_src_1x5<bias_mode, top_padding == 0, false>::call( | |||||
src - W * 8, dst, bias, init, rsrc, rfilter, W, op); | |||||
} | |||||
if (top_padding < 2) { | |||||
load_filter(filter + 20, rfilter); | |||||
compute_row_src_1x5<bias_mode, top_padding == 1, false>::call( | |||||
src - W * 4, dst, bias, init, rsrc, rfilter, W, op); | |||||
} | |||||
{ | |||||
load_filter(filter + 40, rfilter); | |||||
compute_row_src_1x5<bias_mode, top_padding == 2, | |||||
bottom_padding == 2>::call(src, dst, bias, init, | |||||
rsrc, rfilter, W, | |||||
op); | |||||
} | |||||
if (bottom_padding < 2) { | |||||
load_filter(filter + 60, rfilter); | |||||
compute_row_src_1x5<bias_mode, false, bottom_padding == 1>::call( | |||||
src + W * 4, dst, bias, init, rsrc, rfilter, W, op); | |||||
} | |||||
if (bottom_padding < 1) { | |||||
load_filter(filter + 80, rfilter); | |||||
compute_row_src_1x5<bias_mode, false, bottom_padding == 0>::call( | |||||
src + W * 8, dst, bias, init, rsrc, rfilter, W, op); | |||||
} | |||||
src += W * 4; | |||||
dst += W * 4; | |||||
bias += W * 4; | |||||
} | |||||
}; | |||||
} // namespace | |||||
template <BiasMode bias_mode, typename Op> | |||||
void channel_wise_nchw44_float::do_conv_kern_5x5_stride1_padding2( | |||||
const float* src, float* dst, const float* filter, const float* bias, | |||||
int H, int W) { | |||||
Op op; | |||||
float32x4_t init = vdupq_n_f32(0); | |||||
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | |||||
init = vld1q_f32(bias); | |||||
} | |||||
float32x4_t rsrc[6]; | |||||
float32x4_t rfilter[5]; | |||||
compute_row<2, 0, bias_mode>::call(src, dst, filter, bias, init, rsrc, | |||||
rfilter, W, op); | |||||
compute_row<1, 0, bias_mode>::call(src, dst, filter, bias, init, rsrc, | |||||
rfilter, W, op); | |||||
for (int h = 2; h < H - 2; h += 1) { | |||||
compute_row<0, 0, bias_mode>::call(src, dst, filter, bias, init, rsrc, | |||||
rfilter, W, op); | |||||
} | |||||
compute_row<0, 1, bias_mode>::call(src, dst, filter, bias, init, rsrc, | |||||
rfilter, W, op); | |||||
compute_row<0, 2, bias_mode>::call(src, dst, filter, bias, init, rsrc, | |||||
rfilter, W, op); | |||||
} | |||||
#define INSTANTIATION(bias, Op) \ | |||||
template void \ | |||||
channel_wise_nchw44_float::do_conv_kern_5x5_stride1_padding2<bias, Op>( \ | |||||
const float*, float*, const float*, const float*, int, int); | |||||
#define FOR_OP(bias) \ | |||||
INSTANTIATION(bias, SigmoidOp<dt_float32>) \ | |||||
INSTANTIATION(bias, ReluOp<dt_float32>) \ | |||||
INSTANTIATION(bias, HSwishOp<dt_float32>) \ | |||||
INSTANTIATION(bias, NoneOp<dt_float32>) | |||||
#define FOR_BIAS \ | |||||
FOR_OP(BiasMode::NO_BIAS) \ | |||||
FOR_OP(BiasMode::BROADCAST_CHANNEL_BIAS) \ | |||||
FOR_OP(BiasMode::BIAS) | |||||
FOR_BIAS | |||||
#undef FOR_BIAS | |||||
#undef FOR_OP | |||||
#undef INSTANTIATION | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,31 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "src/arm_common/conv_bias/opr_impl.h" | |||||
#include "src/fallback/conv_bias/common.h" | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
namespace channel_wise_nchw44_float { | |||||
template <BiasMode bias_mode, typename Op> | |||||
void do_conv_kern_5x5_stride1_padding2(const float* src, float* dst, | |||||
const float* filter, const float* bias, | |||||
int H, int W); | |||||
} // namespace channel_wise_nchw44_float | |||||
} // namespace arm_common | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -11,6 +11,8 @@ | |||||
*/ | */ | ||||
#include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h" | #include "src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h" | ||||
#include "src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h" | |||||
#include "src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h" | |||||
#include "src/arm_common/elemwise_op.h" | #include "src/arm_common/elemwise_op.h" | ||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/arm_common/utils.h" | #include "src/arm_common/utils.h" | ||||
@@ -413,6 +415,13 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3( | |||||
const float* src, const float* filter, const float* bias, float* dst, | const float* src, const float* filter, const float* bias, float* dst, | ||||
const size_t IH, const size_t IW, const size_t OH, const size_t OW, | const size_t IH, const size_t IW, const size_t OH, const size_t OW, | ||||
const size_t PH, const size_t PW) { | const size_t PH, const size_t PW) { | ||||
if (IH == OH && IW == OW && IH >= 3 && IW >= 3 && PH == 1 && PW == 1) { | |||||
channel_wise_nchw44_float::do_conv_kern_3x3_stride1_padding1<bias_mode, | |||||
Op>( | |||||
src, dst, filter, bias, OH, OW); | |||||
return; | |||||
} | |||||
float32x4_t kernel[9]; | float32x4_t kernel[9]; | ||||
load_vec<9>(kernel, filter); | load_vec<9>(kernel, filter); | ||||
Op op; | Op op; | ||||
@@ -424,10 +433,7 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3( | |||||
size_t ow_start = PW; | size_t ow_start = PW; | ||||
size_t oh_end = IH + PH - 2; | size_t oh_end = IH + PH - 2; | ||||
size_t ow_end = IW + PW - 2; | size_t ow_end = IW + PW - 2; | ||||
if (PH == 1 && PW == 1) { | |||||
PaddingComputeK3P1<bias_mode, Op>::compute(src, bias, dst, 1, IH, IW, | |||||
OH, OW, kernel, init); | |||||
} else if (PH || PW) { | |||||
if (PH || PW) { | |||||
PaddingCompute<bias_mode, Op>::compute(src, bias, dst, 3, 1, IH, IW, OH, | PaddingCompute<bias_mode, Op>::compute(src, bias, dst, 3, 1, IH, IW, OH, | ||||
OW, PH, PW, kernel, init); | OW, PH, PW, kernel, init); | ||||
} | } | ||||
@@ -557,6 +563,13 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5( | |||||
const float* src, const float* filter, const float* bias, float* dst, | const float* src, const float* filter, const float* bias, float* dst, | ||||
const size_t IH, const size_t IW, const size_t OH, const size_t OW, | const size_t IH, const size_t IW, const size_t OH, const size_t OW, | ||||
const size_t PH, const size_t PW) { | const size_t PH, const size_t PW) { | ||||
if (IH == OH && IW == OW && IH >= 5 && IW >= 5 && PH == 2 && PW == 2) { | |||||
channel_wise_nchw44_float::do_conv_kern_5x5_stride1_padding2<bias_mode, | |||||
Op>( | |||||
src, dst, filter, bias, OH, OW); | |||||
return; | |||||
} | |||||
Op op; | Op op; | ||||
float32x4_t init = vdupq_n_f32(0.f); | float32x4_t init = vdupq_n_f32(0.f); | ||||
if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { | ||||