#pragma once
#include "src/arm_common/elemwise_helper/elemwise_op.h"
namespace megdnn {
namespace elemwise {
#if MEGDNN_AARCH64
//! Hand-tuned aarch64 kernel for elementwise sigmoid over a contiguous
//! float32 buffer (VEC broadcast mode).
//!
//! The asm main loop evaluates the same rational-polynomial approximation as
//! arm_common::sigmoid_ps_f32 — sigmoid(x) ~ P(x)/Q(x) + 0.5 with x clamped
//! to [sigmoid_constants.lower_range, sigmoid_constants.upper_range] — but
//! processes 24 floats (6 q-registers) per iteration. Elements left over
//! after the 24-wide loop are handled by the generic SigmoidOp tail below.
template <>
struct OpCallerUnary<arm_common::SigmoidOp<float, float>, VEC> {
    //! \param src      input buffer of nr_elems contiguous floats
    //! \param dst      output buffer of nr_elems contiguous floats
    //! \param nr_elems element count (the two DType params are unused here)
    static void run(const float* src, float* dst, DType, DType, size_t nr_elems) {
        // Number of 24-element iterations the asm loop performs, and the
        // element index where the generic tail code takes over.
        size_t x6_iter = nr_elems / (4 * 6);
        size_t offset = x6_iter * 4 * 6;

        // One q-register per polynomial constant; bound to the asm via "=w"
        // so the compiler keeps each pinned in a vector register for the
        // whole loop.
        float32x4_t lower_range;
        float32x4_t upper_range;
        float32x4_t alpha_9;
        float32x4_t alpha_7;
        float32x4_t alpha_5;
        float32x4_t alpha_3;
        float32x4_t alpha_1;
        float32x4_t beta_10;
        float32x4_t beta_8;
        float32x4_t beta_6;
        float32x4_t beta_4;
        float32x4_t beta_2;
        float32x4_t beta_0;
        float32x4_t one_half;

        // The constants are loaded by walking a float pointer across
        // sigmoid_constants, so the ld1r sequence below must match the
        // struct's field order exactly.
        const float* const_ptr = &(arm_common::sigmoid_constants.lower_range);
        if (x6_iter > 0) {
            /**
             * Register plan:
             * q0 - q5 : squared
             * q6 - q11 : p
             * q12- q17 : val(temp), q
             * q18- q31 : const
             */
            asm volatile(
                    // Broadcast-load each constant (4 bytes apiece),
                    // post-incrementing const_ptr through the struct.
                    "ld1r {%[lower_range].4s}, [%[const_ptr]], #4\n"
                    "ld1r {%[upper_range].4s}, [%[const_ptr]], #4\n"
                    "ld1r {%[alpha_9].4s}, [%[const_ptr]], #4\n"
                    "ld1r {%[alpha_7].4s}, [%[const_ptr]], #4\n"
                    "ld1r {%[alpha_5].4s}, [%[const_ptr]], #4\n"
                    "ld1r {%[alpha_3].4s}, [%[const_ptr]], #4\n"
                    "ld1r {%[alpha_1].4s}, [%[const_ptr]], #4\n"
                    "ld1r {%[beta_10].4s}, [%[const_ptr]], #4\n"
                    "ld1r {%[beta_8].4s}, [%[const_ptr]], #4\n"
                    "ld1r {%[beta_6].4s}, [%[const_ptr]], #4\n"
                    "ld1r {%[beta_4].4s}, [%[const_ptr]], #4\n"
                    "ld1r {%[beta_2].4s}, [%[const_ptr]], #4\n"
                    "ld1r {%[beta_0].4s}, [%[const_ptr]], #4\n"
                    "ld1r {%[one_half].4s}, [%[const_ptr]], #4\n"

                    "1:\n"
                    // Load 24 inputs (6 x 128-bit).
                    "ldr q12, [%[a_ptr]] \n"
                    "ldr q13, [%[a_ptr], #16]\n"
                    "ldr q14, [%[a_ptr], #32]\n"
                    "ldr q15, [%[a_ptr], #48]\n"
                    "ldr q16, [%[a_ptr], #64]\n"
                    "ldr q17, [%[a_ptr], #80]\n"
                    // auto val = vmaxq_f32(vdupq_n_f32(sigmoid_constants.lower_range), src);
                    "fmax v12.4s, v12.4s, %[lower_range].4s\n"
                    "fmax v13.4s, v13.4s, %[lower_range].4s\n"
                    "fmax v14.4s, v14.4s, %[lower_range].4s\n"
                    "fmax v15.4s, v15.4s, %[lower_range].4s\n"
                    "fmax v16.4s, v16.4s, %[lower_range].4s\n"
                    "fmax v17.4s, v17.4s, %[lower_range].4s\n"
                    "add %[a_ptr], %[a_ptr], #96\n"

                    // val = vminq_f32(vdupq_n_f32(sigmoid_constants.upper_range), val);
                    "fmin v12.4s, v12.4s, %[upper_range].4s\n"
                    "fmin v13.4s, v13.4s, %[upper_range].4s\n"
                    "fmin v14.4s, v14.4s, %[upper_range].4s\n"
                    "fmin v15.4s, v15.4s, %[upper_range].4s\n"
                    "fmin v16.4s, v16.4s, %[upper_range].4s\n"
                    "fmin v17.4s, v17.4s, %[upper_range].4s\n"

                    //! auto squared = vmulq_f32(val, val);
                    "fmul v0.4s, v12.4s, v12.4s\n"
                    "fmul v1.4s, v13.4s, v13.4s\n"
                    "fmul v2.4s, v14.4s, v14.4s\n"
                    "fmul v3.4s, v15.4s, v15.4s\n"
                    "fmul v4.4s, v16.4s, v16.4s\n"
                    "fmul v5.4s, v17.4s, v17.4s\n"
                    // auto p = fma_ps_f32(
                    //         vdupq_n_f32(sigmoid_constants.alpha_7), squared,
                    //         vdupq_n_f32(sigmoid_constants.alpha_9));
                    // i.e. p = alpha_9 * x^2 + alpha_7 (Horner step, split into
                    // fmul+fadd rather than fmla so the constant registers stay
                    // read-only).
                    "fmul v6.4s, v0.4s, %[alpha_9].4s\n"
                    "fmul v7.4s, v1.4s, %[alpha_9].4s\n"
                    "fmul v8.4s, v2.4s, %[alpha_9].4s\n"
                    "fmul v9.4s, v3.4s, %[alpha_9].4s\n"
                    "fmul v10.4s, v4.4s, %[alpha_9].4s\n"
                    "fmul v11.4s, v5.4s, %[alpha_9].4s\n"
                    "fadd v6.4s, v6.4s, %[alpha_7].4s\n"
                    "fadd v7.4s, v7.4s, %[alpha_7].4s\n"
                    "fadd v8.4s, v8.4s, %[alpha_7].4s\n"
                    "fadd v9.4s, v9.4s, %[alpha_7].4s\n"
                    "fadd v10.4s, v10.4s, %[alpha_7].4s\n"
                    "fadd v11.4s, v11.4s, %[alpha_7].4s\n"

                    // p = fma_ps_f32(vdupq_n_f32(sigmoid_constants.alpha_5), p, squared);
                    "fmul v6.4s, v6.4s, v0.4s\n"
                    "fmul v7.4s, v7.4s, v1.4s\n"
                    "fmul v8.4s, v8.4s, v2.4s\n"
                    "fmul v9.4s, v9.4s, v3.4s\n"
                    "fmul v10.4s, v10.4s, v4.4s\n"
                    "fmul v11.4s, v11.4s, v5.4s\n"
                    "fadd v6.4s, v6.4s, %[alpha_5].4s\n"
                    "fadd v7.4s, v7.4s, %[alpha_5].4s\n"
                    "fadd v8.4s, v8.4s, %[alpha_5].4s\n"
                    "fadd v9.4s, v9.4s, %[alpha_5].4s\n"
                    "fadd v10.4s, v10.4s, %[alpha_5].4s\n"
                    "fadd v11.4s, v11.4s, %[alpha_5].4s\n"

                    // p = fma_ps_f32(vdupq_n_f32(sigmoid_constants.alpha_3), p, squared);
                    "fmul v6.4s, v6.4s, v0.4s\n"
                    "fmul v7.4s, v7.4s, v1.4s\n"
                    "fmul v8.4s, v8.4s, v2.4s\n"
                    "fmul v9.4s, v9.4s, v3.4s\n"
                    "fmul v10.4s, v10.4s, v4.4s\n"
                    "fmul v11.4s, v11.4s, v5.4s\n"
                    "fadd v6.4s, v6.4s, %[alpha_3].4s\n"
                    "fadd v7.4s, v7.4s, %[alpha_3].4s\n"
                    "fadd v8.4s, v8.4s, %[alpha_3].4s\n"
                    "fadd v9.4s, v9.4s, %[alpha_3].4s\n"
                    "fadd v10.4s, v10.4s, %[alpha_3].4s\n"
                    "fadd v11.4s, v11.4s, %[alpha_3].4s\n"

                    // p = fma_ps_f32(vdupq_n_f32(sigmoid_constants.alpha_1), p, squared);
                    "fmul v6.4s, v6.4s, v0.4s\n"
                    "fmul v7.4s, v7.4s, v1.4s\n"
                    "fmul v8.4s, v8.4s, v2.4s\n"
                    "fmul v9.4s, v9.4s, v3.4s\n"
                    "fmul v10.4s, v10.4s, v4.4s\n"
                    "fmul v11.4s, v11.4s, v5.4s\n"
                    "fadd v6.4s, v6.4s, %[alpha_1].4s\n"
                    "fadd v7.4s, v7.4s, %[alpha_1].4s\n"
                    "fadd v8.4s, v8.4s, %[alpha_1].4s\n"
                    "fadd v9.4s, v9.4s, %[alpha_1].4s\n"
                    "fadd v10.4s, v10.4s, %[alpha_1].4s\n"
                    "fadd v11.4s, v11.4s, %[alpha_1].4s\n"

                    // p = vmulq_f32(p, val);  v12-v17 (clamped val) are dead
                    // after this and get reused below for q.
                    "fmul v6.4s, v6.4s, v12.4s\n"
                    "fmul v7.4s, v7.4s, v13.4s\n"
                    "fmul v8.4s, v8.4s, v14.4s\n"
                    "fmul v9.4s, v9.4s, v15.4s\n"
                    "fmul v10.4s, v10.4s, v16.4s\n"
                    "fmul v11.4s, v11.4s, v17.4s\n"

                    // auto q = fma_ps_f32(
                    //         vdupq_n_f32(sigmoid_constants.beta_8), squared,
                    //         vdupq_n_f32(sigmoid_constants.beta_10));
                    "fmul v12.4s, v0.4s, %[beta_10].4s\n"
                    "fmul v13.4s, v1.4s, %[beta_10].4s\n"
                    "fmul v14.4s, v2.4s, %[beta_10].4s\n"
                    "fmul v15.4s, v3.4s, %[beta_10].4s\n"
                    "fmul v16.4s, v4.4s, %[beta_10].4s\n"
                    "fmul v17.4s, v5.4s, %[beta_10].4s\n"
                    "fadd v12.4s, v12.4s, %[beta_8].4s\n"
                    "fadd v13.4s, v13.4s, %[beta_8].4s\n"
                    "fadd v14.4s, v14.4s, %[beta_8].4s\n"
                    "fadd v15.4s, v15.4s, %[beta_8].4s\n"
                    "fadd v16.4s, v16.4s, %[beta_8].4s\n"
                    "fadd v17.4s, v17.4s, %[beta_8].4s\n"

                    // q = fma_ps_f32(vdupq_n_f32(sigmoid_constants.beta_6), q, squared);
                    "fmul v12.4s, v12.4s, v0.4s\n"
                    "fmul v13.4s, v13.4s, v1.4s\n"
                    "fmul v14.4s, v14.4s, v2.4s\n"
                    "fmul v15.4s, v15.4s, v3.4s\n"
                    "fmul v16.4s, v16.4s, v4.4s\n"
                    "fmul v17.4s, v17.4s, v5.4s\n"
                    "fadd v12.4s, v12.4s, %[beta_6].4s\n"
                    "fadd v13.4s, v13.4s, %[beta_6].4s\n"
                    "fadd v14.4s, v14.4s, %[beta_6].4s\n"
                    "fadd v15.4s, v15.4s, %[beta_6].4s\n"
                    "fadd v16.4s, v16.4s, %[beta_6].4s\n"
                    "fadd v17.4s, v17.4s, %[beta_6].4s\n"

                    // q = fma_ps_f32(vdupq_n_f32(sigmoid_constants.beta_4), q, squared);
                    "fmul v12.4s, v12.4s, v0.4s\n"
                    "fmul v13.4s, v13.4s, v1.4s\n"
                    "fmul v14.4s, v14.4s, v2.4s\n"
                    "fmul v15.4s, v15.4s, v3.4s\n"
                    "fmul v16.4s, v16.4s, v4.4s\n"
                    "fmul v17.4s, v17.4s, v5.4s\n"
                    "fadd v12.4s, v12.4s, %[beta_4].4s\n"
                    "fadd v13.4s, v13.4s, %[beta_4].4s\n"
                    "fadd v14.4s, v14.4s, %[beta_4].4s\n"
                    "fadd v15.4s, v15.4s, %[beta_4].4s\n"
                    "fadd v16.4s, v16.4s, %[beta_4].4s\n"
                    "fadd v17.4s, v17.4s, %[beta_4].4s\n"

                    // q = fma_ps_f32(vdupq_n_f32(sigmoid_constants.beta_2), q, squared);
                    "fmul v12.4s, v12.4s, v0.4s\n"
                    "fmul v13.4s, v13.4s, v1.4s\n"
                    "fmul v14.4s, v14.4s, v2.4s\n"
                    "fmul v15.4s, v15.4s, v3.4s\n"
                    "fmul v16.4s, v16.4s, v4.4s\n"
                    "fmul v17.4s, v17.4s, v5.4s\n"
                    "fadd v12.4s, v12.4s, %[beta_2].4s\n"
                    "fadd v13.4s, v13.4s, %[beta_2].4s\n"
                    "fadd v14.4s, v14.4s, %[beta_2].4s\n"
                    "fadd v15.4s, v15.4s, %[beta_2].4s\n"
                    "fadd v16.4s, v16.4s, %[beta_2].4s\n"
                    "fadd v17.4s, v17.4s, %[beta_2].4s\n"

                    // q = fma_ps_f32(vdupq_n_f32(sigmoid_constants.beta_0), q, squared);
                    "fmul v12.4s, v12.4s, v0.4s\n"
                    "fmul v13.4s, v13.4s, v1.4s\n"
                    "fmul v14.4s, v14.4s, v2.4s\n"
                    "fmul v15.4s, v15.4s, v3.4s\n"
                    "fmul v16.4s, v16.4s, v4.4s\n"
                    "fmul v17.4s, v17.4s, v5.4s\n"
                    "fadd v12.4s, v12.4s, %[beta_0].4s\n"
                    "fadd v13.4s, v13.4s, %[beta_0].4s\n"
                    "fadd v14.4s, v14.4s, %[beta_0].4s\n"
                    "fadd v15.4s, v15.4s, %[beta_0].4s\n"
                    "fadd v16.4s, v16.4s, %[beta_0].4s\n"
                    "fadd v17.4s, v17.4s, %[beta_0].4s\n"

                    // vaddq_f32(div_ps_f32(p, q), vdupq_n_f32(sigmoid_constants.one_half));
                    // The subs is interleaved between fdiv and fadd to hide
                    // part of the divide latency.
                    "fdiv v12.4s, v6.4s, v12.4s\n"
                    "fdiv v13.4s, v7.4s, v13.4s\n"
                    "fdiv v14.4s, v8.4s, v14.4s\n"
                    "fdiv v15.4s, v9.4s, v15.4s\n"
                    "fdiv v16.4s, v10.4s, v16.4s\n"
                    "fdiv v17.4s, v11.4s, v17.4s\n"
                    "subs %w[x6_iter], %w[x6_iter], #1\n"
                    "fadd v12.4s, v12.4s, %[one_half].4s\n"
                    "fadd v13.4s, v13.4s, %[one_half].4s\n"
                    "fadd v14.4s, v14.4s, %[one_half].4s\n"
                    "fadd v15.4s, v15.4s, %[one_half].4s\n"
                    "fadd v16.4s, v16.4s, %[one_half].4s\n"
                    "fadd v17.4s, v17.4s, %[one_half].4s\n"

                    // save it
                    "str q12, [%[d_ptr]] \n"
                    "str q13, [%[d_ptr], #16]\n"
                    "str q14, [%[d_ptr], #32]\n"
                    "str q15, [%[d_ptr], #48]\n"
                    "str q16, [%[d_ptr], #64]\n"
                    "str q17, [%[d_ptr], #80]\n"
                    "add %[d_ptr], %[d_ptr], #96\n"

                    "bne 1b\n"

                    "2:\n"
                    : [a_ptr] "+r"(src), [d_ptr] "+r"(dst), [const_ptr] "+r"(const_ptr),
                      [x6_iter] "+r"(x6_iter), [lower_range] "=w"(lower_range),
                      [alpha_9] "=w"(alpha_9), [upper_range] "=w"(upper_range),
                      [alpha_7] "=w"(alpha_7), [alpha_5] "=w"(alpha_5),
                      [alpha_3] "=w"(alpha_3), [alpha_1] "=w"(alpha_1),
                      [beta_10] "=w"(beta_10), [beta_8] "=w"(beta_8),
                      [beta_6] "=w"(beta_6), [beta_4] "=w"(beta_4),
                      [beta_2] "=w"(beta_2), [beta_0] "=w"(beta_0),
                      [one_half] "=w"(one_half)
                    :
                    // NOTE(review): x1/x2/x8/x9 do not appear in the asm text
                    // above — confirm they are really clobbered before
                    // removing them from this list.
                    : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
                      "v11", "v12", "v13", "v14", "v15", "v16", "v17", "x1", "x2", "x8",
                      "x9", "cc", "memory");
        }

        // Tail: the remaining nr_elems - offset elements go through the
        // generic NEON/scalar SigmoidOp (two vectors at a time, then one,
        // then element-wise). src/dst were already advanced by the asm loop.
        using Op = arm_common::SigmoidOp<float, float>;
        Op op;
        ParamElemVisitorV2<float> vis2;
        size_t i = offset;
        for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) {
            op(vis2(src, src + Op::SIMD_WIDTH), dst);
            src += Op::SIMD_WIDTH * 2;
            dst += Op::SIMD_WIDTH * 2;
        }
        for (; i + Op::SIMD_WIDTH * 1 <= nr_elems; i += Op::SIMD_WIDTH * 1) {
            op(vld1q_f32(src), dst);
            src += Op::SIMD_WIDTH;
            dst += Op::SIMD_WIDTH;
        }
        for (; i < nr_elems; i++) {
            op(*src, dst);
            src++;
            dst++;
        }
    }
};
#endif

}  // namespace elemwise
}  // namespace megdnn
newline at end of file diff --git a/dnn/src/arm_common/elemwise/neon_mathfun.cpp b/dnn/src/arm_common/elemwise/neon_mathfun.cpp index 04c6bd7c..c091f7df 100644 --- a/dnn/src/arm_common/elemwise/neon_mathfun.cpp +++ b/dnn/src/arm_common/elemwise/neon_mathfun.cpp @@ -320,69 +320,6 @@ v4sf tan_ps_f32(v4sf x) { #undef c_cephes_log_q1 #undef c_cephes_log_q2 -static const struct { - float lower_range; - float upper_range; - float alpha_9; - float alpha_7; - float alpha_5; - float alpha_3; - float alpha_1; - float beta_10; - float beta_8; - float beta_6; - float beta_4; - float beta_2; - float beta_0; - float one_half; -} sigmoid_constants = { - -18.0f, - 18.0f, - 4.37031012579801e-11f, - 1.15627324459942e-07f, - 6.08574864600143e-05f, - 8.51377133304701e-03f, - 2.48287947061529e-01f, - 6.10247389755681e-13f, - 5.76102136993427e-09f, - 6.29106785017040e-06f, - 1.70198817374094e-03f, - 1.16817656904453e-01f, - 9.93151921023180e-01f, - 0.5f, -}; - -v4sf sigmoid_ps_f32(v4sf src) { - auto val = vmaxq_f32(vdupq_n_f32(sigmoid_constants.lower_range), src); - val = vminq_f32(vdupq_n_f32(sigmoid_constants.upper_range), val); - auto squared = vmulq_f32(val, val); - auto p = fma_ps_f32( - vdupq_n_f32(sigmoid_constants.alpha_7), squared, - vdupq_n_f32(sigmoid_constants.alpha_9)); - p = fma_ps_f32(vdupq_n_f32(sigmoid_constants.alpha_5), p, squared); - p = fma_ps_f32(vdupq_n_f32(sigmoid_constants.alpha_3), p, squared); - p = fma_ps_f32(vdupq_n_f32(sigmoid_constants.alpha_1), p, squared); - p = vmulq_f32(p, val); - auto q = fma_ps_f32( - vdupq_n_f32(sigmoid_constants.beta_8), squared, - vdupq_n_f32(sigmoid_constants.beta_10)); - q = fma_ps_f32(vdupq_n_f32(sigmoid_constants.beta_6), q, squared); - q = fma_ps_f32(vdupq_n_f32(sigmoid_constants.beta_4), q, squared); - q = fma_ps_f32(vdupq_n_f32(sigmoid_constants.beta_2), q, squared); - q = fma_ps_f32(vdupq_n_f32(sigmoid_constants.beta_0), q, squared); - return vaddq_f32(div_ps_f32(p, q), vdupq_n_f32(sigmoid_constants.one_half)); -} - -#if 
//! Coefficients of the rational-polynomial sigmoid approximation
//! sigmoid(x) ~ P(x)/Q(x) + one_half, evaluated for x clamped to
//! [lower_range, upper_range]. Shared by sigmoid_ps_f32 below and by the
//! hand-written aarch64 asm kernel (dnn/src/aarch64/elemwise/sigmoid.h).
//!
//! NOTE: the aarch64 kernel loads these fields by walking a float pointer
//! from &lower_range with post-increment, so the declaration order of the
//! fields is load-bearing — do not reorder or insert members.
static const struct {
    float lower_range;  //!< clamp floor applied to the input
    float upper_range;  //!< clamp ceiling applied to the input
    float alpha_9;      //!< numerator coefficients P (odd powers, highest first)
    float alpha_7;
    float alpha_5;
    float alpha_3;
    float alpha_1;
    float beta_10;      //!< denominator coefficients Q (even powers, highest first)
    float beta_8;
    float beta_6;
    float beta_4;
    float beta_2;
    float beta_0;
    float one_half;     //!< final additive offset (0.5)
} sigmoid_constants = {
        -18.0f,
        18.0f,
        4.37031012579801e-11f,
        1.15627324459942e-07f,
        6.08574864600143e-05f,
        8.51377133304701e-03f,
        2.48287947061529e-01f,
        6.10247389755681e-13f,
        5.76102136993427e-09f,
        6.29106785017040e-06f,
        1.70198817374094e-03f,
        1.16817656904453e-01f,
        9.93151921023180e-01f,
        0.5f,
};
//! Kept in the header as `static inline` so the compiler can inline it into
//! every elementwise kernel; do not move the definition into the .cpp file.
//!
//! Rational-polynomial approximation of sigmoid for one float32x4 lane set:
//! the input is clamped to [lower_range, upper_range], then
//! sigmoid(x) ~ P(x)/Q(x) + 0.5 is evaluated with Horner's rule in x^2
//! (fma_ps_f32(c, b, a) computes c + b * a).
static inline v4sf sigmoid_ps_f32(v4sf x) {
    // Clamp into the range where the approximation is valid.
    v4sf clamped = vmaxq_f32(vdupq_n_f32(sigmoid_constants.lower_range), x);
    clamped = vminq_f32(vdupq_n_f32(sigmoid_constants.upper_range), clamped);
    v4sf x2 = vmulq_f32(clamped, clamped);

    // Numerator P: odd polynomial, evaluated as clamped * p(x^2).
    v4sf num = fma_ps_f32(
            vdupq_n_f32(sigmoid_constants.alpha_7), x2,
            vdupq_n_f32(sigmoid_constants.alpha_9));
    num = fma_ps_f32(vdupq_n_f32(sigmoid_constants.alpha_5), num, x2);
    num = fma_ps_f32(vdupq_n_f32(sigmoid_constants.alpha_3), num, x2);
    num = fma_ps_f32(vdupq_n_f32(sigmoid_constants.alpha_1), num, x2);
    num = vmulq_f32(num, clamped);

    // Denominator Q: even polynomial in x^2.
    v4sf den = fma_ps_f32(
            vdupq_n_f32(sigmoid_constants.beta_8), x2,
            vdupq_n_f32(sigmoid_constants.beta_10));
    den = fma_ps_f32(vdupq_n_f32(sigmoid_constants.beta_6), den, x2);
    den = fma_ps_f32(vdupq_n_f32(sigmoid_constants.beta_4), den, x2);
    den = fma_ps_f32(vdupq_n_f32(sigmoid_constants.beta_2), den, x2);
    den = fma_ps_f32(vdupq_n_f32(sigmoid_constants.beta_0), den, x2);

    // sigmoid(x) ~ P/Q + 0.5
    return vaddq_f32(div_ps_f32(num, den), vdupq_n_f32(sigmoid_constants.one_half));
}

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
//! fp16 variant: widen each half to f32, run the f32 approximation, narrow
//! the results back and recombine.
static inline float16x8_t sigmoid_ps_f16(float16x8_t v) {
    float32x4_t lo32 = vcvt_f32_f16(vget_low_f16(v));
    float32x4_t hi32 = vcvt_f32_f16(vget_high_f16(v));
    lo32 = sigmoid_ps_f32(lo32);
    hi32 = sigmoid_ps_f32(hi32);
    return vcombine_f16(vcvt_f16_f32(lo32), vcvt_f16_f32(hi32));
}
#endif
"src/arm_common/elemwise_helper/elemwise_op.h" - +#if MEGDNN_AARCH64 +#include "src/aarch64/elemwise/sigmoid.h" +#endif #include "src/common/utils.h" #include "src/naive/handle.h"