
perf(dnn): slightly improve arm neon transcendental function performance

GitOrigin-RevId: 210d88f81e

tags/v1.7.0.m1
Megvii Engine Team, 3 years ago
parent commit ead611e11d

2 changed files with 63 additions and 90 deletions:
  1. dnn/src/arm_common/elemwise/neon_mathfun.cpp (+55, -88)
  2. dnn/src/arm_common/elemwise/neon_mathfun.h (+8, -2)
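In short, the commit rewrites the Horner-style polynomial evaluations in log_ps_f32, exp_ps_f32, sincos_ps_f32 and sigmoid_ps_f32 so that each separate multiply/add (or vmlaq_f32) step goes through a new fma_ps_f32 helper, which lowers to the fused vfmaq_f32 when the target advertises __ARM_FEATURE_FMA. A minimal sketch of the pattern, using the helper exactly as added in neon_mathfun.h below (the horner_step_* wrappers are illustrative names, not from the source):

#include <arm_neon.h>

/* Helper as added in neon_mathfun.h below: computes c + a * b per lane,
 * as a single fused instruction when the target supports FMA. */
#if defined(__ARM_FEATURE_FMA)
#define fma_ps_f32(c, b, a) vfmaq_f32((c), (a), (b))
#else
#define fma_ps_f32(c, b, a) vmlaq_f32((c), (a), (b))
#endif

/* Before: one Horner step costs a separate multiply and add. */
static inline float32x4_t horner_step_before(float32x4_t y, float32x4_t x, float coeff) {
    y = vmulq_f32(y, x);
    return vaddq_f32(y, vdupq_n_f32(coeff));
}

/* After: the same step as one fused multiply-add (FMLA on ARMv8),
 * which also drops the intermediate rounding of the separate multiply. */
static inline float32x4_t horner_step_after(float32x4_t y, float32x4_t x, float coeff) {
    return fma_ps_f32(vdupq_n_f32(coeff), y, x); /* coeff + y * x */
}

Note the fused form is not bit-identical to multiply-then-add (one rounding instead of two), which is presumably acceptable for these approximations.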

dnn/src/arm_common/elemwise/neon_mathfun.cpp (+55, -88)

@@ -86,11 +86,11 @@ v4sf log_ps_f32(v4sf x) {
     e = vaddq_f32(e, one);
 
     /* part2:
-       if( x < SQRTHF ) {
-         e -= 1;
-         x = x + x - 1.0;
-       } else { x = x - 1.0; }
-    */
+     * if( x < SQRTHF ) {
+     *   e -= 1;
+     *   x = x + x - 1.0;
+     * } else { x = x - 1.0; }
+     */
     v4su mask = vcltq_f32(x, vdupq_n_f32(c_cephes_SQRTHF));
     v4sf tmp = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(x), mask));
     x = vsubq_f32(x, one);
@@ -101,38 +101,26 @@ v4sf log_ps_f32(v4sf x) {
     v4sf z = vmulq_f32(x, x);
 
     v4sf y = vdupq_n_f32(c_cephes_log_p0);
-    y = vmulq_f32(y, x);
-    y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p1));
-    y = vmulq_f32(y, x);
-    y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p2));
-    y = vmulq_f32(y, x);
-    y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p3));
-    y = vmulq_f32(y, x);
-    y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p4));
-    y = vmulq_f32(y, x);
-    y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p5));
-    y = vmulq_f32(y, x);
-    y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p6));
-    y = vmulq_f32(y, x);
-    y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p7));
-    y = vmulq_f32(y, x);
-    y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p8));
+    y = fma_ps_f32(vdupq_n_f32(c_cephes_log_p1), y, x);
+    y = fma_ps_f32(vdupq_n_f32(c_cephes_log_p2), y, x);
+    y = fma_ps_f32(vdupq_n_f32(c_cephes_log_p3), y, x);
+    y = fma_ps_f32(vdupq_n_f32(c_cephes_log_p4), y, x);
+    y = fma_ps_f32(vdupq_n_f32(c_cephes_log_p5), y, x);
+    y = fma_ps_f32(vdupq_n_f32(c_cephes_log_p6), y, x);
+    y = fma_ps_f32(vdupq_n_f32(c_cephes_log_p7), y, x);
+    y = fma_ps_f32(vdupq_n_f32(c_cephes_log_p8), y, x);
     y = vmulq_f32(y, x);
 
     y = vmulq_f32(y, z);
 
-    tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q1));
-    y = vaddq_f32(y, tmp);
+    y = fma_ps_f32(y, e, vdupq_n_f32(c_cephes_log_q1));
 
-    tmp = vmulq_f32(z, vdupq_n_f32(0.5f));
-    y = vsubq_f32(y, tmp);
+    y = vmlsq_f32(y, z, vdupq_n_f32(0.5f));
 
-    tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q2));
     x = vaddq_f32(x, y);
-    x = vaddq_f32(x, tmp);
+    x = fma_ps_f32(x, e, vdupq_n_f32(c_cephes_log_q2));
     x = vreinterpretq_f32_u32(vorrq_u32(
-            vreinterpretq_u32_f32(x),
-            invalid_mask)); // negative arg will be NAN
+            vreinterpretq_u32_f32(x), invalid_mask)); // negative arg will be NAN
     return x;
 }

@@ -159,7 +147,7 @@ v4sf exp_ps_f32(v4sf x) {
     x = vmaxq_f32(x, vdupq_n_f32(c_exp_lo));
 
     /* express exp(x) as exp(g + n*log(2)) */
-    fx = vmlaq_f32(vdupq_n_f32(0.5f), x, vdupq_n_f32(c_cephes_LOG2EF));
+    fx = fma_ps_f32(vdupq_n_f32(0.5f), x, vdupq_n_f32(c_cephes_LOG2EF));
 
     /* perform a floorf */
     tmp = vcvtq_f32_s32(vcvtq_s32_f32(fx));
@@ -175,34 +163,20 @@ v4sf exp_ps_f32(v4sf x) {
     x = vsubq_f32(x, tmp);
     x = vsubq_f32(x, z);
 
-    static const float cephes_exp_p[6] = {c_cephes_exp_p0, c_cephes_exp_p1,
-                                          c_cephes_exp_p2, c_cephes_exp_p3,
-                                          c_cephes_exp_p4, c_cephes_exp_p5};
-    v4sf y = vld1q_dup_f32(cephes_exp_p + 0);
-    v4sf c1 = vld1q_dup_f32(cephes_exp_p + 1);
-    v4sf c2 = vld1q_dup_f32(cephes_exp_p + 2);
-    v4sf c3 = vld1q_dup_f32(cephes_exp_p + 3);
-    v4sf c4 = vld1q_dup_f32(cephes_exp_p + 4);
-    v4sf c5 = vld1q_dup_f32(cephes_exp_p + 5);
-
-    y = vmulq_f32(y, x);
     z = vmulq_f32(x, x);
-    y = vaddq_f32(y, c1);
-    y = vmulq_f32(y, x);
-    y = vaddq_f32(y, c2);
-    y = vmulq_f32(y, x);
-    y = vaddq_f32(y, c3);
-    y = vmulq_f32(y, x);
-    y = vaddq_f32(y, c4);
-    y = vmulq_f32(y, x);
-    y = vaddq_f32(y, c5);
-
-    y = vmulq_f32(y, z);
-    y = vaddq_f32(y, x);
+
+    v4sf y = vdupq_n_f32(c_cephes_exp_p0);
+    y = fma_ps_f32(vdupq_n_f32(c_cephes_exp_p1), y, x);
+    y = fma_ps_f32(vdupq_n_f32(c_cephes_exp_p2), y, x);
+    y = fma_ps_f32(vdupq_n_f32(c_cephes_exp_p3), y, x);
+    y = fma_ps_f32(vdupq_n_f32(c_cephes_exp_p4), y, x);
+    y = fma_ps_f32(vdupq_n_f32(c_cephes_exp_p5), y, x);
+
+    y = fma_ps_f32(x, y, z);
     y = vaddq_f32(y, one);
 
     /* build 2^n */
-    int32x4_t mm;
+    v4si mm;
     mm = vcvtq_s32_f32(fx);
     mm = vaddq_s32(mm, vdupq_n_s32(0x7f));
     mm = vshlq_n_s32(mm, 23);
@@ -249,8 +223,9 @@ float16x8_t exp_ps_f16(float16x8_t x) {
    almost no extra price so both sin_ps_f32 and cos_ps_f32 make use of
    sincos_ps_f32..
 */
-void sincos_ps_f32(v4sf x, v4sf* ysin, v4sf* ycos) { // any x
-    v4sf xmm1, xmm2, xmm3, y;
+void sincos_ps_f32(v4sf x, v4sf* ysin, v4sf* ycos) {
+    // any x
+    v4sf y;
 
     v4su emm2;
 
@@ -269,44 +244,36 @@ void sincos_ps_f32(v4sf x, v4sf* ysin, v4sf* ycos) { // any x
     y = vcvtq_f32_u32(emm2);
 
     /* get the polynom selection mask
-       there is one polynom for 0 <= x <= Pi/4
-       and another one for Pi/4<x<=Pi/2
-       Both branches will be computed.
-    */
+     * there is one polynom for 0 <= x <= Pi/4
+     * and another one for Pi/4<x<=Pi/2
+     *
+     * Both branches will be computed.
+     */
     v4su poly_mask = vtstq_u32(emm2, vdupq_n_u32(2));
 
     /* The magic pass: "Extended precision modular arithmetic"
-       x = ((x - y * DP1) - y * DP2) - y * DP3; */
-    xmm1 = vmulq_n_f32(y, c_minus_cephes_DP1);
-    xmm2 = vmulq_n_f32(y, c_minus_cephes_DP2);
-    xmm3 = vmulq_n_f32(y, c_minus_cephes_DP3);
-    x = vaddq_f32(x, xmm1);
-    x = vaddq_f32(x, xmm2);
-    x = vaddq_f32(x, xmm3);
+     * x = ((x - y * DP1) - y * DP2) - y * DP3; */
+    x = fma_ps_f32(x, y, vdupq_n_f32(c_minus_cephes_DP1));
+    x = fma_ps_f32(x, y, vdupq_n_f32(c_minus_cephes_DP2));
+    x = fma_ps_f32(x, y, vdupq_n_f32(c_minus_cephes_DP3));
 
     sign_mask_sin = veorq_u32(sign_mask_sin, vtstq_u32(emm2, vdupq_n_u32(4)));
     sign_mask_cos = vtstq_u32(vsubq_u32(emm2, vdupq_n_u32(2)), vdupq_n_u32(4));
 
     /* Evaluate the first polynom (0 <= x <= Pi/4) in y1,
-       and the second polynom (Pi/4 <= x <= 0) in y2 */
+     * and the second polynom (Pi/4 <= x <= 0) in y2 */
     v4sf z = vmulq_f32(x, x);
     v4sf y1, y2;
 
-    y1 = vmulq_n_f32(z, c_coscof_p0);
-    y2 = vmulq_n_f32(z, c_sincof_p0);
-    y1 = vaddq_f32(y1, vdupq_n_f32(c_coscof_p1));
-    y2 = vaddq_f32(y2, vdupq_n_f32(c_sincof_p1));
-    y1 = vmulq_f32(y1, z);
-    y2 = vmulq_f32(y2, z);
-    y1 = vaddq_f32(y1, vdupq_n_f32(c_coscof_p2));
-    y2 = vaddq_f32(y2, vdupq_n_f32(c_sincof_p2));
+    y1 = fma_ps_f32(vdupq_n_f32(c_coscof_p1), z, vdupq_n_f32(c_coscof_p0));
+    y2 = fma_ps_f32(vdupq_n_f32(c_sincof_p1), z, vdupq_n_f32(c_sincof_p0));
+    y1 = fma_ps_f32(vdupq_n_f32(c_coscof_p2), y1, z);
+    y2 = fma_ps_f32(vdupq_n_f32(c_sincof_p2), y2, z);
     y1 = vmulq_f32(y1, z);
     y2 = vmulq_f32(y2, z);
     y1 = vmulq_f32(y1, z);
-    y2 = vmulq_f32(y2, x);
-    y1 = vsubq_f32(y1, vmulq_f32(z, vdupq_n_f32(0.5f)));
-    y2 = vaddq_f32(y2, x);
+    y1 = vmlsq_f32(y1, z, vdupq_n_f32(0.5f));
+    y2 = fma_ps_f32(x, y2, x);
    y1 = vaddq_f32(y1, vdupq_n_f32(1));
 
     /* select the correct result from the two polynoms */
@@ -407,20 +374,20 @@ v4sf sigmoid_ps_f32(v4sf src) {
     auto val = vmaxq_f32(vdupq_n_f32(sigmoid_constants.lower_range), src);
     val = vminq_f32(vdupq_n_f32(sigmoid_constants.upper_range), val);
     auto squared = vmulq_f32(val, val);
-    auto p = vmlaq_f32(
+    auto p = fma_ps_f32(
             vdupq_n_f32(sigmoid_constants.alpha_7), squared,
             vdupq_n_f32(sigmoid_constants.alpha_9));
-    p = vmlaq_f32(vdupq_n_f32(sigmoid_constants.alpha_5), p, squared);
-    p = vmlaq_f32(vdupq_n_f32(sigmoid_constants.alpha_3), p, squared);
-    p = vmlaq_f32(vdupq_n_f32(sigmoid_constants.alpha_1), p, squared);
+    p = fma_ps_f32(vdupq_n_f32(sigmoid_constants.alpha_5), p, squared);
+    p = fma_ps_f32(vdupq_n_f32(sigmoid_constants.alpha_3), p, squared);
+    p = fma_ps_f32(vdupq_n_f32(sigmoid_constants.alpha_1), p, squared);
     p = vmulq_f32(p, val);
-    auto q = vmlaq_f32(
+    auto q = fma_ps_f32(
             vdupq_n_f32(sigmoid_constants.beta_8), squared,
             vdupq_n_f32(sigmoid_constants.beta_10));
-    q = vmlaq_f32(vdupq_n_f32(sigmoid_constants.beta_6), q, squared);
-    q = vmlaq_f32(vdupq_n_f32(sigmoid_constants.beta_4), q, squared);
-    q = vmlaq_f32(vdupq_n_f32(sigmoid_constants.beta_2), q, squared);
-    q = vmlaq_f32(vdupq_n_f32(sigmoid_constants.beta_0), q, squared);
+    q = fma_ps_f32(vdupq_n_f32(sigmoid_constants.beta_6), q, squared);
+    q = fma_ps_f32(vdupq_n_f32(sigmoid_constants.beta_4), q, squared);
+    q = fma_ps_f32(vdupq_n_f32(sigmoid_constants.beta_2), q, squared);
+    q = fma_ps_f32(vdupq_n_f32(sigmoid_constants.beta_0), q, squared);
     return vaddq_f32(div_ps_f32(p, q), vdupq_n_f32(sigmoid_constants.one_half));
 }
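
For context, the sigmoid_ps_f32 hunk above evaluates a clamped rational approximation, roughly sigmoid(x) ~ 0.5 + x*P(x^2)/Q(x^2), with every Horner step of P and Q now one fused multiply-add. A scalar sketch of the same structure, assuming the sigmoid_constants field names seen in the diff (their values are not shown here):

#include <math.h>

/* Mirrors the fields referenced in the diff; the actual values live in
 * sigmoid_constants elsewhere in the codebase. */
struct sigmoid_consts {
    float lower_range, upper_range, one_half;
    float alpha_9, alpha_7, alpha_5, alpha_3, alpha_1;
    float beta_10, beta_8, beta_6, beta_4, beta_2, beta_0;
};

static float sigmoid_scalar_sketch(float x, struct sigmoid_consts c) {
    x = fmaxf(c.lower_range, fminf(c.upper_range, x)); /* clamp as in the diff */
    float z = x * x;
    /* numerator P: odd polynomial in x, one fused step per coefficient */
    float p = fmaf(c.alpha_9, z, c.alpha_7);
    p = fmaf(p, z, c.alpha_5);
    p = fmaf(p, z, c.alpha_3);
    p = fmaf(p, z, c.alpha_1);
    p *= x;
    /* denominator Q: even polynomial in x */
    float q = fmaf(c.beta_10, z, c.beta_8);
    q = fmaf(q, z, c.beta_6);
    q = fmaf(q, z, c.beta_4);
    q = fmaf(q, z, c.beta_2);
    q = fmaf(q, z, c.beta_0);
    return p / q + c.one_half; /* sigmoid(x) ~ 0.5 + x*P(z)/Q(z) */
}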



dnn/src/arm_common/elemwise/neon_mathfun.h (+8, -2)

@@ -54,7 +54,7 @@ v4sf cos_ps_f32(v4sf x);
 
 v4sf tan_ps_f32(v4sf x);
 
-static inline v4sf div_ps_f32(v4sf x, v4sf y) {
+static inline v4sf div_ps_f32(v4sf& x, v4sf& y) {
 #if MEGDNN_AARCH64
     return vdivq_f32(x, y);
 #else
@@ -65,6 +65,12 @@ static inline v4sf div_ps_f32(v4sf x, v4sf y) {
 #endif
 }
 
+#if defined(__ARM_FEATURE_FMA)
+#define fma_ps_f32(c, b, a) vfmaq_f32((c), (a), (b))
+#else
+#define fma_ps_f32(c, b, a) vmlaq_f32((c), (a), (b))
+#endif
+
 v4sf sigmoid_ps_f32(v4sf x);
 
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
@@ -73,7 +79,7 @@ v4sf sigmoid_ps_f32(v4sf x);
 */
 float16x8_t exp_ps_f16(float16x8_t x);
 
-static inline float16x8_t div_ps_f16(float16x8_t x, float16x8_t y) {
+static inline float16x8_t div_ps_f16(float16x8_t& x, float16x8_t& y) {
 #if MEGDNN_AARCH64
     return vdivq_f16(x, y);
 #else
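
One note on the div_ps_f32/div_ps_f16 helpers touched above: the hardware divide (vdivq_f32/vdivq_f16) only exists on AArch64, so the #else branch, whose body falls outside these hunks, has to synthesize division. A typical NEON fallback, sketched here rather than quoted from the file, is a reciprocal estimate refined by Newton-Raphson steps:

#include <arm_neon.h>

/* Typical ARMv7 NEON division fallback: vrecpeq_f32 gives a coarse 1/y
 * estimate; each vrecpsq_f32 step (computes 2 - y*r) roughly doubles its
 * accurate bits when multiplied back into the estimate. */
static inline float32x4_t div_fallback_sketch(float32x4_t x, float32x4_t y) {
    float32x4_t r = vrecpeq_f32(y);      /* initial reciprocal estimate */
    r = vmulq_f32(vrecpsq_f32(y, r), r); /* Newton-Raphson refinement 1 */
    r = vmulq_f32(vrecpsq_f32(y, r), r); /* Newton-Raphson refinement 2 */
    return vmulq_f32(x, r);              /* x * (1/y) */
}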

