|
|
@@ -12,17 +12,6 @@ |
|
|
|
#pragma once |
|
|
|
#include "megdnn/dtype.h" |
|
|
|
|
|
|
|
#if MEGDNN_CC_HOST && !defined(__host__) |
|
|
|
#define MEGDNN_HOST_DEVICE_SELF_DEFINE |
|
|
|
#define __host__ |
|
|
|
#define __device__ |
|
|
|
#if __GNUC__ || __has_attribute(always_inline) |
|
|
|
#define __forceinline__ inline __attribute__((always_inline)) |
|
|
|
#else |
|
|
|
#define __forceinline__ inline |
|
|
|
#endif |
|
|
|
#endif |
|
|
|
|
|
|
|
namespace megdnn { |
|
|
|
namespace rounding { |
|
|
|
|
|
|
@@ -31,7 +20,8 @@ struct RoundingConverter; |
|
|
|
|
|
|
|
template <> |
|
|
|
struct RoundingConverter<float> { |
|
|
|
__host__ __device__ __forceinline__ float operator()(float x) const { |
|
|
|
MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE float operator()( |
|
|
|
float x) const { |
|
|
|
return x; |
|
|
|
} |
|
|
|
}; |
|
|
@@ -40,7 +30,7 @@ struct RoundingConverter<float> { |
|
|
|
|
|
|
|
template <> |
|
|
|
struct RoundingConverter<half_float::half> { |
|
|
|
__host__ __device__ __forceinline__ half_float::half operator()( |
|
|
|
MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE half_float::half operator()( |
|
|
|
float x) const { |
|
|
|
return static_cast<half_float::half>(x); |
|
|
|
} |
|
|
@@ -48,8 +38,8 @@ struct RoundingConverter<half_float::half> { |
|
|
|
|
|
|
|
template <> |
|
|
|
struct RoundingConverter<half_bfloat16::bfloat16> { |
|
|
|
__host__ __device__ __forceinline__ half_bfloat16::bfloat16 operator()( |
|
|
|
float x) const { |
|
|
|
MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE half_bfloat16::bfloat16 |
|
|
|
operator()(float x) const { |
|
|
|
return static_cast<half_bfloat16::bfloat16>(x); |
|
|
|
} |
|
|
|
}; |
|
|
@@ -58,7 +48,8 @@ struct RoundingConverter<half_bfloat16::bfloat16> { |
|
|
|
|
|
|
|
template <> |
|
|
|
struct RoundingConverter<int8_t> { |
|
|
|
__host__ __device__ __forceinline__ int8_t operator()(float x) const { |
|
|
|
MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE int8_t |
|
|
|
operator()(float x) const { |
|
|
|
#if MEGDNN_CC_HOST |
|
|
|
using std::round; |
|
|
|
#endif |
|
|
@@ -68,11 +59,12 @@ struct RoundingConverter<int8_t> { |
|
|
|
|
|
|
|
template <> |
|
|
|
struct RoundingConverter<uint8_t> { |
|
|
|
__host__ __device__ __forceinline__ uint8_t operator()(float x) const { |
|
|
|
MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE uint8_t |
|
|
|
operator()(float x) const { |
|
|
|
#if MEGDNN_CC_HOST |
|
|
|
using std::round; |
|
|
|
using std::max; |
|
|
|
using std::min; |
|
|
|
using std::round; |
|
|
|
#endif |
|
|
|
x = min(255.0f, max(0.0f, x)); //! FIXME!!! check other places |
|
|
|
return static_cast<uint8_t>(round(x)); |
|
|
@@ -81,7 +73,8 @@ struct RoundingConverter<uint8_t> { |
|
|
|
|
|
|
|
template <> |
|
|
|
struct RoundingConverter<dt_qint4> { |
|
|
|
__host__ __device__ __forceinline__ dt_qint4 operator()(float x) const { |
|
|
|
MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE dt_qint4 |
|
|
|
operator()(float x) const { |
|
|
|
#if MEGDNN_CC_HOST |
|
|
|
using std::round; |
|
|
|
#endif |
|
|
@@ -91,7 +84,8 @@ struct RoundingConverter<dt_qint4> { |
|
|
|
|
|
|
|
template <> |
|
|
|
struct RoundingConverter<dt_quint4> { |
|
|
|
__host__ __device__ __forceinline__ dt_quint4 operator()(float x) const { |
|
|
|
MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE dt_quint4 |
|
|
|
operator()(float x) const { |
|
|
|
#if MEGDNN_CC_HOST |
|
|
|
using std::round; |
|
|
|
#endif |
|
|
|