|
@@ -26,6 +26,8 @@ |
|
|
#include <algorithm> |
|
|
#include <algorithm> |
|
|
using std::max; |
|
|
using std::max; |
|
|
using std::min; |
|
|
using std::min; |
|
|
|
|
|
|
|
|
|
|
|
#define rsqrtf(x) (1.f / sqrt(x)) |
|
|
#endif |
|
|
#endif |
|
|
|
|
|
|
|
|
#ifndef MEGDNN_ELEMWISE_MODE_ENABLE |
|
|
#ifndef MEGDNN_ELEMWISE_MODE_ENABLE |
|
@@ -93,6 +95,30 @@ __device__ __host__ inline float gelu_grad(float x, float dy) { |
|
|
return dy * (normcdf_v + x * phi); |
|
|
return dy * (normcdf_v + x * phi); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Approximate equality test for two floats.
//
// Returns true when the absolute difference between `a` and `b` is strictly
// smaller than 1e-6. The tolerance is absolute (not scaled by the magnitude
// of the operands). If either operand is NaN the comparison is false, so the
// function returns false.
__device__ __host__ inline bool feq(float a, float b) {
    // Keep the tolerance a float literal: a bare `1e-6` is a double and would
    // promote the whole comparison to double precision.
    return fabsf(a - b) < 1e-6f;
}
|
|
|
|
|
|
|
|
|
|
|
// Raises `x` to the power `y`, special-casing a fixed set of exponents.
//
// The exponent is matched with feq(), i.e. approximately rather than bit-exact.
// Candidates are tried in this order: 2, 0.5, -0.5, 0, 1, 3, -1, -2. The first
// match returns the corresponding closed-form expression; if none matches, the
// result is powf(x, y).
// NOTE(review): the shortcuts are presumably there to avoid the cost of the
// general powf() for common exponents - confirm with the original author.
__device__ __host__ inline float dispatch_powf(float x, float y) {
    // Square.
    if (feq(y, 2.f)) {
        return x * x;
    }
    // Square root.
    if (feq(y, 0.5f)) {
        return sqrtf(x);
    }
    // Reciprocal square root.
    if (feq(y, -0.5f)) {
        return rsqrtf(x);
    }
    // Zeroth power: constant one, independent of x.
    if (feq(y, 0.f)) {
        return 1.f;
    }
    // Identity.
    if (feq(y, 1.f)) {
        return x;
    }
    // Cube.
    if (feq(y, 3.f)) {
        return x * x * x;
    }
    // Reciprocal.
    if (feq(y, -1.f)) {
        return 1.f / x;
    }
    // Reciprocal of the square.
    if (feq(y, -2.f)) {
        return 1.f / (x * x);
    }

    // No special case applies: use the general power function.
    return powf(x, y);
}
|
|
|
|
|
|
|
|
#include "src/common/elemwise/each_mode.inl" |
|
|
#include "src/common/elemwise/each_mode.inl" |
|
|
|
|
|
|
|
|
template <megcorePlatform_t plat, uint32_t mode, typename dtype> |
|
|
template <megcorePlatform_t plat, uint32_t mode, typename dtype> |
|
|