Browse Source

fix(dnn): specialize pow to make it consistent

GitOrigin-RevId: cff3bbbadd
release-1.7
Megvii Engine Team 3 years ago
parent
commit
c50858ee13
1 changed files with 26 additions and 0 deletions
  1. +26
    -0
      dnn/src/common/elemwise/kern_defs.cuh

+ 26
- 0
dnn/src/common/elemwise/kern_defs.cuh View File

@@ -26,6 +26,8 @@
#include <algorithm>
using std::max;
using std::min;

#define rsqrtf(x) (1.f / sqrt(x))
#endif

#ifndef MEGDNN_ELEMWISE_MODE_ENABLE
@@ -93,6 +95,30 @@ __device__ __host__ inline float gelu_grad(float x, float dy) {
return dy * (normcdf_v + x * phi);
}

__device__ __host__ inline bool feq(float a, float b) {
return fabsf(a - b) < 1e-6;
}

__device__ __host__ inline float dispatch_powf(float x, float y) {
#define CALL_IF(_v, _stmt) \
do { \
if (feq(y, _v)) { \
return _stmt; \
} \
} while (0)

CALL_IF(2.f, x * x);
CALL_IF(0.5f, sqrtf(x));
CALL_IF(-0.5f, rsqrtf(x));
CALL_IF(0.f, 1.f);
CALL_IF(1.f, x);
CALL_IF(3.f, x * x * x);
CALL_IF(-1.f, 1.f / x);
CALL_IF(-2.f, 1.f / (x * x));
#undef CALL_IF
return powf(x, y);
}

#include "src/common/elemwise/each_mode.inl"

template <megcorePlatform_t plat, uint32_t mode, typename dtype>


Loading…
Cancel
Save