Browse Source

refactor(dnn/rocm): remove some useless includes

GitOrigin-RevId: 3d2c315a36
release-1.6
Megvii Engine Team 3 years ago
parent
commit
83cf4ee64e
5 changed files with 21 additions and 32 deletions
  1. +1
    -1
      dnn/src/rocm/argsort/argsort.cpp.hip
  2. +1
    -0
      dnn/src/rocm/argsort/argsort.h.hip
  3. +0
    -2
      dnn/src/rocm/argsort/backward.cpp.hip
  4. +13
    -17
      dnn/src/rocm/argsort/bitonic_sort.cpp.hip
  5. +6
    -12
      dnn/src/rocm/topk/topk_radix.cpp.hip

+ 1
- 1
dnn/src/rocm/argsort/argsort.cpp.hip View File

@@ -174,7 +174,7 @@ template void argsort::forward<dtype>(const dtype*, dtype*, int*, void*, \
ARGSORT_FOREACH_CTYPE(INST_FORWARD) ARGSORT_FOREACH_CTYPE(INST_FORWARD)
INST_CUB_SORT(uint32_t) INST_CUB_SORT(uint32_t)
// INST_CUB_SORT(uint64_t)
INST_CUB_SORT(uint64_t)
#undef INST_CUB_SORT #undef INST_CUB_SORT
#undef INST_FORWARD #undef INST_FORWARD
} }


+ 1
- 0
dnn/src/rocm/argsort/argsort.h.hip View File

@@ -40,6 +40,7 @@ void forward(const dtype* sptr, dtype* dptr, int* iptr, void* workspace,
const int* iptr_src = NULL); const int* iptr_src = NULL);


//! iterate over all supported data types //! iterate over all supported data types
// device_radix_sort does not support dt_float16 dtype(half_float::half in rocm)
#define ARGSORT_FOREACH_CTYPE(cb) \ #define ARGSORT_FOREACH_CTYPE(cb) \
cb(float) cb(int32_t) // DNN_INC_FLOAT16(cb(dt_float16)) cb(float) cb(int32_t) // DNN_INC_FLOAT16(cb(dt_float16))




+ 0
- 2
dnn/src/rocm/argsort/backward.cpp.hip View File

@@ -14,8 +14,6 @@
#include "./argsort.h.hip" #include "./argsort.h.hip"
#include "./backward.h.hip" #include "./backward.h.hip"


// #include "src/rocm/utils.h"

using namespace megdnn; using namespace megdnn;
using namespace rocm; using namespace rocm;
using namespace argsort; using namespace argsort;


+ 13
- 17
dnn/src/rocm/argsort/bitonic_sort.cpp.hip View File

@@ -11,13 +11,9 @@
#include "hcc_detail/hcc_defs_prologue.h" #include "hcc_detail/hcc_defs_prologue.h"


#include "./bitonic_sort.h.hip" #include "./bitonic_sort.h.hip"
// #include "src/cuda/query_blocksize.cuh"
// #include "megdnn/dtype.h"
#include "megdnn/dtype.h"


// #if __CUDACC_VER_MAJOR__ < 9
// #pragma message "warp sync disabled due to insufficient cuda version"
#define __syncwarp __syncthreads #define __syncwarp __syncthreads
// #endif


#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
@@ -84,17 +80,17 @@ struct NumTrait<int32_t> {
static __device__ __forceinline__ int32_t min() { return INT_MIN; } static __device__ __forceinline__ int32_t min() { return INT_MIN; }
}; };


// #if !MEGDNN_DISABLE_FLOAT16
// template <>
// struct NumTrait<dt_float16> {
// static __device__ __forceinline__ dt_float16 max() {
// return std::numeric_limits<dt_float16>::max();
// }
// static __device__ __forceinline__ dt_float16 min() {
// return std::numeric_limits<dt_float16>::lowest();
// }
// };
// #endif
#if !MEGDNN_DISABLE_FLOAT16
template <>
struct NumTrait<dt_float16> {
static __device__ __forceinline__ dt_float16 max() {
return std::numeric_limits<dt_float16>::max();
}
static __device__ __forceinline__ dt_float16 min() {
return std::numeric_limits<dt_float16>::lowest();
}
};
#endif


struct LessThan { struct LessThan {
template <typename Key, typename Value> template <typename Key, typename Value>
@@ -310,7 +306,7 @@ namespace rocm {


INST(float, int); INST(float, int);
INST(int32_t, int); INST(int32_t, int);
// DNN_INC_FLOAT16(INST(dt_float16, int));
DNN_INC_FLOAT16(INST(dt_float16, int));
#undef INST #undef INST


} // namespace megdnn } // namespace megdnn


+ 6
- 12
dnn/src/rocm/topk/topk_radix.cpp.hip View File

@@ -18,13 +18,7 @@
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>


#if __CUDACC_VER_MAJOR__ < 9
#pragma message "topk is a little slower on cuda earlier than 9.0"
// on cuda 9.0 and later, due to thread-divergent branches we should use
// __syncwarp; and I am too lazy to implement a correct legacy version, so just
// use __syncthreads instead for older cuda
#define __syncwarp __syncthreads #define __syncwarp __syncthreads
#endif


using namespace megdnn; using namespace megdnn;
using namespace rocm; using namespace rocm;
@@ -256,12 +250,12 @@ static __global__ void update_prefix_and_k(const uint32_t* bucket_cnt,
} }
} }


//if ((cumsum_bucket_cnt[NR_BUCKET] < kv) |
// (cumsum_bucket_cnt[i] != cumsum_bucket_cnt[i - 1] + sum)) {
// // impossible
// int* bad = 0x0;
// *bad = 23;
//}
if ((cumsum_bucket_cnt[NR_BUCKET] < kv) |
(cumsum_bucket_cnt[i] != cumsum_bucket_cnt[i - 1] + sum)) {
// impossible
int* bad = 0x0;
*bad = 23;
}
} }


static uint32_t get_grid_dim_x(uint32_t length) { static uint32_t get_grid_dim_x(uint32_t length) {


Loading…
Cancel
Save