diff --git a/dnn/src/arm_common/elemwise/neon_mathfun.cpp b/dnn/src/arm_common/elemwise/neon_mathfun.cpp
index beffb726..1d356d54 100644
--- a/dnn/src/arm_common/elemwise/neon_mathfun.cpp
+++ b/dnn/src/arm_common/elemwise/neon_mathfun.cpp
@@ -371,6 +371,69 @@ v4sf tan_ps_f32(v4sf x) {
 #undef c_cephes_log_q1
 #undef c_cephes_log_q2
 
+static const struct {
+    float lower_range;
+    float upper_range;
+    float alpha_9;
+    float alpha_7;
+    float alpha_5;
+    float alpha_3;
+    float alpha_1;
+    float beta_10;
+    float beta_8;
+    float beta_6;
+    float beta_4;
+    float beta_2;
+    float beta_0;
+    float one_half;
+} sigmoid_constants = {
+        -18.0f,
+        18.0f,
+        4.37031012579801e-11f,
+        1.15627324459942e-07f,
+        6.08574864600143e-05f,
+        8.51377133304701e-03f,
+        2.48287947061529e-01f,
+        6.10247389755681e-13f,
+        5.76102136993427e-09f,
+        6.29106785017040e-06f,
+        1.70198817374094e-03f,
+        1.16817656904453e-01f,
+        9.93151921023180e-01f,
+        0.5f,
+};
+
+v4sf sigmoid_ps_f32(v4sf src) {
+    auto val = vmaxq_f32(vdupq_n_f32(sigmoid_constants.lower_range), src);
+    val = vminq_f32(vdupq_n_f32(sigmoid_constants.upper_range), val);
+    auto squared = vmulq_f32(val, val);
+    auto p = vmlaq_f32(
+            vdupq_n_f32(sigmoid_constants.alpha_7), squared,
+            vdupq_n_f32(sigmoid_constants.alpha_9));
+    p = vmlaq_f32(vdupq_n_f32(sigmoid_constants.alpha_5), p, squared);
+    p = vmlaq_f32(vdupq_n_f32(sigmoid_constants.alpha_3), p, squared);
+    p = vmlaq_f32(vdupq_n_f32(sigmoid_constants.alpha_1), p, squared);
+    p = vmulq_f32(p, val);
+    auto q = vmlaq_f32(
+            vdupq_n_f32(sigmoid_constants.beta_8), squared,
+            vdupq_n_f32(sigmoid_constants.beta_10));
+    q = vmlaq_f32(vdupq_n_f32(sigmoid_constants.beta_6), q, squared);
+    q = vmlaq_f32(vdupq_n_f32(sigmoid_constants.beta_4), q, squared);
+    q = vmlaq_f32(vdupq_n_f32(sigmoid_constants.beta_2), q, squared);
+    q = vmlaq_f32(vdupq_n_f32(sigmoid_constants.beta_0), q, squared);
+    return vaddq_f32(div_ps_f32(p, q), vdupq_n_f32(sigmoid_constants.one_half));
+}
+
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+float16x8_t sigmoid_ps_f16(float16x8_t x) {
+    float32x4_t low = vcvt_f32_f16(vget_low_f16(x));
+    float32x4_t high = vcvt_f32_f16(vget_high_f16(x));
+    low = sigmoid_ps_f32(low);
+    high = sigmoid_ps_f32(high);
+    return vcombine_f16(vcvt_f16_f32(low), vcvt_f16_f32(high));
+}
+#endif
+
 } // namespace arm_common
 } // namespace megdnn
diff --git a/dnn/src/arm_common/elemwise/neon_mathfun.h b/dnn/src/arm_common/elemwise/neon_mathfun.h
index 11937ecc..8172e8d5 100644
--- a/dnn/src/arm_common/elemwise/neon_mathfun.h
+++ b/dnn/src/arm_common/elemwise/neon_mathfun.h
@@ -54,11 +54,38 @@
 v4sf cos_ps_f32(v4sf x);
 
 v4sf tan_ps_f32(v4sf x);
+
+static inline v4sf div_ps_f32(v4sf x, v4sf y) {
+#if MEGDNN_AARCH64
+    return vdivq_f32(x, y);
+#else
+    //! armv7 not support vdiv, so compute the reciprocal and iterate again
+    float32x4_t recp = vrecpeq_f32(y);
+    recp = vmulq_f32(vrecpsq_f32(y, recp), recp);
+    return vmulq_f32(x, recp);
+#endif
+}
+
+v4sf sigmoid_ps_f32(v4sf x);
 
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 /**
  * \brief compute for 8 half at once, the inner just invoke exp_ps_f32 twice
  */
 float16x8_t exp_ps_f16(float16x8_t x);
+
+static inline float16x8_t div_ps_f16(float16x8_t x, float16x8_t y) {
+#if MEGDNN_AARCH64
+    return vdivq_f16(x, y);
+#else
+    //! armv7 not support vdiv, so compute the reciprocal and iterate again
+    float16x8_t recp = vrecpeq_f16(y);
+    recp = vmulq_f16(vrecpsq_f16(y, recp), recp);
+    return vmulq_f16(x, recp);
+#endif
+}
+
+float16x8_t sigmoid_ps_f16(float16x8_t x);
+
 #endif
 
 } // namespace arm_common
diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_sigmoid.h b/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_sigmoid.h
index 56c11219..2f8755ae 100644
--- a/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_sigmoid.h
+++ b/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_sigmoid.h
@@ -47,24 +47,14 @@ struct FuseAddSigmoidOp;
         vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]);                 \
     }                                                                         \
     _neon_type operator()(const _neon_type& src0, const _neon_type& src1) const { \
-        auto zero_val = vdupq_n_##_func_suffix(0.f);                          \
-        auto one_val = vdupq_n_##_func_suffix(1.f);                           \
         auto val1 = src0.val[0];                                              \
         auto val2 = src0.val[1];                                              \
         auto val3 = src1.val[0];                                              \
         auto val4 = src1.val[1];                                              \
         val1 = vaddq_##_func_suffix(val1, val3);                              \
         val2 = vaddq_##_func_suffix(val2, val4);                              \
-        val1 = vsubq_##_func_suffix(zero_val, val1);                          \
-        val2 = vsubq_##_func_suffix(zero_val, val2);                          \
-        val1 = exp_ps_##_func_suffix(val1);                                   \
-        val2 = exp_ps_##_func_suffix(val2);                                   \
-        auto recipe1 = vaddq_##_func_suffix(one_val, val1);                   \
-        auto recipe2 = vaddq_##_func_suffix(one_val, val2);                   \
-        val1 = vrecpeq_##_func_suffix(recipe1);                               \
-        val2 = vrecpeq_##_func_suffix(recipe2);                               \
-        val1 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(recipe1, val1), val1); \
-        val2 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(recipe2, val2), val2); \
+        val1 = sigmoid_ps_##_func_suffix(val1);                               \
+        val2 = sigmoid_ps_##_func_suffix(val2);                               \
         return {{val1, val2}};                                                \
     }                                                                         \
 };
diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/sigmoid.h b/dnn/src/arm_common/elemwise_helper/kimpl/sigmoid.h
index 3d10636c..cc635b84 100644
--- a/dnn/src/arm_common/elemwise_helper/kimpl/sigmoid.h
+++ b/dnn/src/arm_common/elemwise_helper/kimpl/sigmoid.h
@@ -33,34 +33,27 @@ struct SigmoidOpBase : UnaryOpBase {
 
 template <typename src_ctype, typename dst_ctype = src_ctype>
 struct SigmoidOp;
 
-#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width)        \
-    template <>                                                               \
-    struct SigmoidOp<_ctype> : SigmoidOpBase<_ctype> {                        \
-        using SigmoidOpBase::SigmoidOpBase;                                   \
-        using SigmoidOpBase::operator();                                      \
-        constexpr static size_t SIMD_WIDTH = _simd_width;                     \
-        void operator()(const _neon_type2& src, _ctype* dst) const {          \
-            auto vitem = operator()(src);                                     \
-            vst1q_##_func_suffix(dst, vitem.val[0]);                          \
-            vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]);             \
-        }                                                                     \
-        void operator()(const _neon_type& src, _ctype* dst) const {           \
-            auto vitem = operator()(src);                                     \
-            vst1q_##_func_suffix(dst, vitem);                                 \
-        }                                                                     \
-        _neon_type2 operator()(const _neon_type2& src) const {                \
-            return {{operator()(src.val[0]), operator()(src.val[1])}};        \
-        }                                                                     \
-        _neon_type operator()(const _neon_type& src) const {                  \
-            auto zero_val = vdupq_n_##_func_suffix(0.f);                      \
-            auto one_val = vdupq_n_##_func_suffix(1.f);                       \
-            auto val1 = vsubq_##_func_suffix(zero_val, src);                  \
-            val1 = exp_ps_##_func_suffix(val1);                               \
-            auto recipe1 = vaddq_##_func_suffix(one_val, val1);               \
-            val1 = vrecpeq_##_func_suffix(recipe1);                           \
-            val1 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(recipe1, val1), val1); \
-            return val1;                                                      \
-        }                                                                     \
-    };
+#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width)      \
+    template <>                                                             \
+    struct SigmoidOp<_ctype> : SigmoidOpBase<_ctype> {                      \
+        using SigmoidOpBase::SigmoidOpBase;                                 \
+        using SigmoidOpBase::operator();                                    \
+        constexpr static size_t SIMD_WIDTH = _simd_width;                   \
+        void operator()(const _neon_type2& src, _ctype* dst) const {        \
+            auto vitem = operator()(src);                                   \
+            vst1q_##_func_suffix(dst, vitem.val[0]);                        \
+            vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]);           \
+        }                                                                   \
+        void operator()(const _neon_type& src, _ctype* dst) const {         \
+            auto vitem = operator()(src);                                   \
+            vst1q_##_func_suffix(dst, vitem);                               \
+        }                                                                   \
+        _neon_type2 operator()(const _neon_type2& src) const {              \
+            return {{operator()(src.val[0]), operator()(src.val[1])}};      \
+        }                                                                   \
+        _neon_type operator()(const _neon_type& src) const {                \
+            return sigmoid_ps_##_func_suffix(src);                          \
+        }                                                                   \
+    };
 
 OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4)
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
diff --git a/imperative/python/test/unit/utils/test_network.py b/imperative/python/test/unit/utils/test_network.py
index f223d69d..ca5abe68 100644
--- a/imperative/python/test/unit/utils/test_network.py
+++ b/imperative/python/test/unit/utils/test_network.py
@@ -318,7 +318,7 @@ def test_add_remove_output():
 
     out = g.run(a.numpy(), b.numpy())
     np.testing.assert_equal(out["new_o1"], ((a + b) * 3).numpy())
-    np.testing.assert_equal(out["new_o2"], (F.sigmoid((a + b))).numpy())
+    np.testing.assert_almost_equal(out["new_o2"], (F.sigmoid((a + b))).numpy())
 
 
 def test_query():