|
|
@@ -11,13 +11,9 @@ |
|
|
|
#include "hcc_detail/hcc_defs_prologue.h" |
|
|
|
|
|
|
|
#include "./bitonic_sort.h.hip" |
|
|
|
// #include "src/cuda/query_blocksize.cuh" |
|
|
|
// #include "megdnn/dtype.h" |
|
|
|
#include "megdnn/dtype.h" |
|
|
|
|
|
|
|
// Alias __syncwarp to __syncthreads for the rest of this translation unit. |
// The "#if __CUDACC_VER_MAJOR__ < 9" / "#endif" pair around the define is |
// commented out, so the substitution is unconditional: the preprocessor |
// rewrites every later __syncwarp token to __syncthreads. |
// NOTE(review): __syncthreads is a block-wide barrier that all threads of |
// the block must reach, whereas __syncwarp is a warp-level barrier. Call |
// sites are not visible here - confirm none sit in divergent branches and |
// none pass a mask argument (it would be forwarded to __syncthreads). |
// #if __CUDACC_VER_MAJOR__ < 9 |
|
|
|
// #pragma message "warp sync disabled due to insufficient cuda version" |
|
|
|
#define __syncwarp __syncthreads |
|
|
|
// #endif |
|
|
|
|
|
|
|
#include <algorithm> |
|
|
|
#include <cmath> |
|
|
@@ -84,17 +80,17 @@ struct NumTrait<int32_t> { |
|
|
|
static __device__ __forceinline__ int32_t min() { return INT_MIN; } |
|
|
|
}; |
|
|
|
|
|
|
|
// #if !MEGDNN_DISABLE_FLOAT16 |
|
|
|
// template <> |
|
|
|
// struct NumTrait<dt_float16> { |
|
|
|
// static __device__ __forceinline__ dt_float16 max() { |
|
|
|
// return std::numeric_limits<dt_float16>::max(); |
|
|
|
// } |
|
|
|
// static __device__ __forceinline__ dt_float16 min() { |
|
|
|
// return std::numeric_limits<dt_float16>::lowest(); |
|
|
|
// } |
|
|
|
// }; |
|
|
|
// #endif |
|
|
|
// Specialization of NumTrait for dt_float16. It is compiled only when |
// MEGDNN_DISABLE_FLOAT16 evaluates to zero. Both bounds are forwarded from |
// std::numeric_limits<dt_float16>. |
// NOTE(review): these are __device__ functions; this assumes the |
// std::numeric_limits<dt_float16> members are callable from device code - |
// TODO confirm against the dt_float16 definition in megdnn/dtype.h. |
#if !MEGDNN_DISABLE_FLOAT16 |
|
|
|
template <> |
|
|
|
struct NumTrait<dt_float16> { |
|
|
|
// Returns std::numeric_limits<dt_float16>::max(). |
static __device__ __forceinline__ dt_float16 max() { |
|
|
|
return std::numeric_limits<dt_float16>::max(); |
|
|
|
} |
|
|
|
// Returns std::numeric_limits<dt_float16>::lowest(), not ::min(). |
// NOTE(review): presumably because ::min() of a floating-point type is the |
// smallest positive value rather than the most negative one - confirm that |
// the dt_float16 specialization follows the standard convention. |
static __device__ __forceinline__ dt_float16 min() { |
|
|
|
return std::numeric_limits<dt_float16>::lowest(); |
|
|
|
} |
|
|
|
}; |
|
|
|
#endif |
|
|
|
|
|
|
|
struct LessThan { |
|
|
|
template <typename Key, typename Value> |
|
|
@@ -310,7 +306,7 @@ namespace rocm { |
|
|
|
|
|
|
|
// Invoke the INST macro (defined outside this view) for the (float, int) |
// and (int32_t, int) pairs, and for (dt_float16, int) wrapped in |
// DNN_INC_FLOAT16. |
// NOTE(review): presumably INST emits explicit template instantiations and |
// DNN_INC_FLOAT16 drops its argument when float16 support is disabled - |
// neither definition is visible here, confirm against their declarations. |
INST(float, int); |
|
|
|
INST(int32_t, int); |
|
|
|
// DNN_INC_FLOAT16(INST(dt_float16, int)); |
|
|
|
DNN_INC_FLOAT16(INST(dt_float16, int)); |
|
|
|
// Undefine INST so the macro name does not remain defined past this point. |
#undef INST |
|
|
|
|
|
|
|
} // namespace megdnn |
|
|
|