Browse Source

refactor(dnn): move arch-dependant code to arch.h

GitOrigin-RevId: 52350144b1
tags/v1.6.0-rc1
Megvii Engine Team 3 years ago
parent
commit
56c1b626bf
3 changed files with 25 additions and 29 deletions
  1. +10
    -0
      dnn/include/megdnn/arch.h
  2. +1
    -9
      dnn/src/common/resize.cuh
  3. +14
    -20
      dnn/src/common/rounding_converter.cuh

+ 10
- 0
dnn/include/megdnn/arch.h View File

@@ -140,6 +140,16 @@
#define MEGDNN_DEVICE
#endif

#if MEGDNN_CC_CUDA
#define MEGDNN_FORCE_INLINE __forceinline__
#else
#if __GNUC__ || __has_attribute(always_inline)
#define MEGDNN_FORCE_INLINE inline __attribute__((always_inline))
#else
#define MEGDNN_FORCE_INLINE inline
#endif
#endif

#if defined(_MSC_VER) || defined(WIN32)
#define ATTR_ALIGNED(v) __declspec(align(v))
#else


+ 1
- 9
dnn/src/common/resize.cuh View File

@@ -13,18 +13,10 @@

#include "megdnn/arch.h"

#if MEGDNN_CC_HOST && !defined(__host__)
#if __GNUC__ || __has_attribute(always_inline)
#define __forceinline__ inline __attribute__((always_inline))
#else
#define __forceinline__ inline
#endif
#endif

namespace megdnn {
namespace resize {

MEGDNN_HOST MEGDNN_DEVICE __forceinline__ void interpolate_cubic(
MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE void interpolate_cubic(
float x, float* coeffs) {
const float A = -0.75f;



+ 14
- 20
dnn/src/common/rounding_converter.cuh View File

@@ -12,17 +12,6 @@
#pragma once
#include "megdnn/dtype.h"

#if MEGDNN_CC_HOST && !defined(__host__)
#define MEGDNN_HOST_DEVICE_SELF_DEFINE
#define __host__
#define __device__
#if __GNUC__ || __has_attribute(always_inline)
#define __forceinline__ inline __attribute__((always_inline))
#else
#define __forceinline__ inline
#endif
#endif

namespace megdnn {
namespace rounding {

@@ -31,7 +20,8 @@ struct RoundingConverter;

template <>
struct RoundingConverter<float> {
__host__ __device__ __forceinline__ float operator()(float x) const {
MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE float operator()(
float x) const {
return x;
}
};
@@ -40,7 +30,7 @@ struct RoundingConverter<float> {

template <>
struct RoundingConverter<half_float::half> {
__host__ __device__ __forceinline__ half_float::half operator()(
MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE half_float::half operator()(
float x) const {
return static_cast<half_float::half>(x);
}
@@ -48,8 +38,8 @@ struct RoundingConverter<half_float::half> {

template <>
struct RoundingConverter<half_bfloat16::bfloat16> {
__host__ __device__ __forceinline__ half_bfloat16::bfloat16 operator()(
float x) const {
MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE half_bfloat16::bfloat16
operator()(float x) const {
return static_cast<half_bfloat16::bfloat16>(x);
}
};
@@ -58,7 +48,8 @@ struct RoundingConverter<half_bfloat16::bfloat16> {

template <>
struct RoundingConverter<int8_t> {
__host__ __device__ __forceinline__ int8_t operator()(float x) const {
MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE int8_t
operator()(float x) const {
#if MEGDNN_CC_HOST
using std::round;
#endif
@@ -68,11 +59,12 @@ struct RoundingConverter<int8_t> {

template <>
struct RoundingConverter<uint8_t> {
__host__ __device__ __forceinline__ uint8_t operator()(float x) const {
MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE uint8_t
operator()(float x) const {
#if MEGDNN_CC_HOST
using std::round;
using std::max;
using std::min;
using std::round;
#endif
x = min(255.0f, max(0.0f, x)); //! FIXME!!! check other places
return static_cast<uint8_t>(round(x));
@@ -81,7 +73,8 @@ struct RoundingConverter<uint8_t> {

template <>
struct RoundingConverter<dt_qint4> {
__host__ __device__ __forceinline__ dt_qint4 operator()(float x) const {
MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE dt_qint4
operator()(float x) const {
#if MEGDNN_CC_HOST
using std::round;
#endif
@@ -91,7 +84,8 @@ struct RoundingConverter<dt_qint4> {

template <>
struct RoundingConverter<dt_quint4> {
__host__ __device__ __forceinline__ dt_quint4 operator()(float x) const {
MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE dt_quint4
operator()(float x) const {
#if MEGDNN_CC_HOST
using std::round;
#endif


Loading…
Cancel
Save