Browse Source

refactor(dnn): move arch-dependent code to arch.h

GitOrigin-RevId: 52350144b1
tags/v1.6.0-rc1
Megvii Engine Team 3 years ago
parent
commit
56c1b626bf
3 changed files with 25 additions and 29 deletions
  1. +10
    -0
      dnn/include/megdnn/arch.h
  2. +1
    -9
      dnn/src/common/resize.cuh
  3. +14
    -20
      dnn/src/common/rounding_converter.cuh

+ 10
- 0
dnn/include/megdnn/arch.h View File

@@ -140,6 +140,16 @@
#define MEGDNN_DEVICE #define MEGDNN_DEVICE
#endif #endif


//! MEGDNN_FORCE_INLINE: function specifier that asks the compiler to always
//! inline the function it decorates.
//!   - when MEGDNN_CC_CUDA is non-zero: expands to __forceinline__
//!   - on GNU-compatible compilers, or when always_inline is reported as
//!     supported: expands to inline __attribute__((always_inline))
//!   - otherwise: falls back to a plain inline
#if MEGDNN_CC_CUDA
#define MEGDNN_FORCE_INLINE __forceinline__
#elif defined(__GNUC__)
#define MEGDNN_FORCE_INLINE inline __attribute__((always_inline))
#elif defined(__has_attribute)
//! __has_attribute(...) can only be parsed by preprocessors that provide it,
//! so it is probed with defined() first and queried in a nested #if rather
//! than being combined with the __GNUC__ test in a single expression
#if __has_attribute(always_inline)
#define MEGDNN_FORCE_INLINE inline __attribute__((always_inline))
#else
#define MEGDNN_FORCE_INLINE inline
#endif
#else
#define MEGDNN_FORCE_INLINE inline
#endif

#if defined(_MSC_VER) || defined(WIN32) #if defined(_MSC_VER) || defined(WIN32)
#define ATTR_ALIGNED(v) __declspec(align(v)) #define ATTR_ALIGNED(v) __declspec(align(v))
#else #else


+ 1
- 9
dnn/src/common/resize.cuh View File

@@ -13,18 +13,10 @@


#include "megdnn/arch.h" #include "megdnn/arch.h"


#if MEGDNN_CC_HOST && !defined(__host__)
#if __GNUC__ || __has_attribute(always_inline)
#define __forceinline__ inline __attribute__((always_inline))
#else
#define __forceinline__ inline
#endif
#endif

namespace megdnn { namespace megdnn {
namespace resize { namespace resize {


MEGDNN_HOST MEGDNN_DEVICE __forceinline__ void interpolate_cubic(
MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE void interpolate_cubic(
float x, float* coeffs) { float x, float* coeffs) {
const float A = -0.75f; const float A = -0.75f;




+ 14
- 20
dnn/src/common/rounding_converter.cuh View File

@@ -12,17 +12,6 @@
#pragma once #pragma once
#include "megdnn/dtype.h" #include "megdnn/dtype.h"


#if MEGDNN_CC_HOST && !defined(__host__)
#define MEGDNN_HOST_DEVICE_SELF_DEFINE
#define __host__
#define __device__
#if __GNUC__ || __has_attribute(always_inline)
#define __forceinline__ inline __attribute__((always_inline))
#else
#define __forceinline__ inline
#endif
#endif

namespace megdnn { namespace megdnn {
namespace rounding { namespace rounding {


@@ -31,7 +20,8 @@ struct RoundingConverter;


template <> template <>
struct RoundingConverter<float> { struct RoundingConverter<float> {
__host__ __device__ __forceinline__ float operator()(float x) const {
MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE float operator()(
float x) const {
return x; return x;
} }
}; };
@@ -40,7 +30,7 @@ struct RoundingConverter<float> {


template <> template <>
struct RoundingConverter<half_float::half> { struct RoundingConverter<half_float::half> {
__host__ __device__ __forceinline__ half_float::half operator()(
MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE half_float::half operator()(
float x) const { float x) const {
return static_cast<half_float::half>(x); return static_cast<half_float::half>(x);
} }
@@ -48,8 +38,8 @@ struct RoundingConverter<half_float::half> {


template <> template <>
struct RoundingConverter<half_bfloat16::bfloat16> { struct RoundingConverter<half_bfloat16::bfloat16> {
__host__ __device__ __forceinline__ half_bfloat16::bfloat16 operator()(
float x) const {
MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE half_bfloat16::bfloat16
operator()(float x) const {
return static_cast<half_bfloat16::bfloat16>(x); return static_cast<half_bfloat16::bfloat16>(x);
} }
}; };
@@ -58,7 +48,8 @@ struct RoundingConverter<half_bfloat16::bfloat16> {


template <> template <>
struct RoundingConverter<int8_t> { struct RoundingConverter<int8_t> {
__host__ __device__ __forceinline__ int8_t operator()(float x) const {
MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE int8_t
operator()(float x) const {
#if MEGDNN_CC_HOST #if MEGDNN_CC_HOST
using std::round; using std::round;
#endif #endif
@@ -68,11 +59,12 @@ struct RoundingConverter<int8_t> {


template <> template <>
struct RoundingConverter<uint8_t> { struct RoundingConverter<uint8_t> {
__host__ __device__ __forceinline__ uint8_t operator()(float x) const {
MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE uint8_t
operator()(float x) const {
#if MEGDNN_CC_HOST #if MEGDNN_CC_HOST
using std::round;
using std::max; using std::max;
using std::min; using std::min;
using std::round;
#endif #endif
x = min(255.0f, max(0.0f, x)); //! FIXME!!! check other places x = min(255.0f, max(0.0f, x)); //! FIXME!!! check other places
return static_cast<uint8_t>(round(x)); return static_cast<uint8_t>(round(x));
@@ -81,7 +73,8 @@ struct RoundingConverter<uint8_t> {


template <> template <>
struct RoundingConverter<dt_qint4> { struct RoundingConverter<dt_qint4> {
__host__ __device__ __forceinline__ dt_qint4 operator()(float x) const {
MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE dt_qint4
operator()(float x) const {
#if MEGDNN_CC_HOST #if MEGDNN_CC_HOST
using std::round; using std::round;
#endif #endif
@@ -91,7 +84,8 @@ struct RoundingConverter<dt_qint4> {


template <> template <>
struct RoundingConverter<dt_quint4> { struct RoundingConverter<dt_quint4> {
__host__ __device__ __forceinline__ dt_quint4 operator()(float x) const {
MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE dt_quint4
operator()(float x) const {
#if MEGDNN_CC_HOST #if MEGDNN_CC_HOST
using std::round; using std::round;
#endif #endif


Loading…
Cancel
Save